# app/services/chain.py — chain + node CRUD and branching tree logic
#
# Rows are read by column name (row["id"], dict(row)), so *conn* is expected to
# use a mapping-style row factory such as sqlite3.Row.
#
# Every write runs inside ``with conn:``. The sqlite3 connection context manager
# commits when the block succeeds and rolls back when it raises. A failed
# multi-statement operation (e.g. commit_node) is therefore never left
# half-applied in an open transaction for a later, unrelated commit to persist.
from __future__ import annotations

import sqlite3
import time
import uuid
from typing import Any


# ── Chain CRUD ────────────────────────────────────────────────────────────────

def create_chain(conn: sqlite3.Connection, name: str) -> dict[str, Any]:
    """Insert a new, empty chain and return its summary dict.

    The dict has the same keys as a ``list_chains`` row: ``id`` (a fresh
    UUID4 string), ``name``, ``created_at`` (``time.time()`` seconds) and
    ``node_count`` (always 0 for a new chain).
    """
    chain_id = str(uuid.uuid4())
    now = time.time()
    with conn:
        conn.execute(
            "INSERT INTO chains (id, name, created_at) VALUES (?, ?, ?)",
            (chain_id, name, now),
        )
    return {"id": chain_id, "name": name, "created_at": now, "node_count": 0}


def list_chains(conn: sqlite3.Connection) -> list[dict[str, Any]]:
    """Return every chain with its node count, newest first.

    The LEFT JOIN keeps chains that have no nodes yet (``node_count`` 0).
    """
    rows = conn.execute("""
        SELECT c.id, c.name, c.created_at, COUNT(n.id) AS node_count
        FROM chains c
        LEFT JOIN nodes n ON n.chain_id = c.id
        GROUP BY c.id
        ORDER BY c.created_at DESC
    """).fetchall()
    return [dict(r) for r in rows]


def get_chain(conn: sqlite3.Connection, chain_id: str) -> dict[str, Any] | None:
    """Return a chain together with all of its nodes, or ``None`` if unknown.

    ``nodes`` is a flat list of every node in the chain, ordered by
    ``created_at``. The tree shape is carried only by each node's
    ``parent_id``.
    """
    row = conn.execute(
        "SELECT id, name, created_at FROM chains WHERE id = ?", (chain_id,)
    ).fetchone()
    if row is None:
        return None
    nodes = _get_nodes_for_chain(conn, chain_id)
    return {
        "id": row["id"],
        "name": row["name"],
        "created_at": row["created_at"],
        "nodes": nodes,
    }


def delete_chain(conn: sqlite3.Connection, chain_id: str) -> bool:
    """Delete a chain and all of its nodes; return True if the chain existed.

    Nodes are deleted explicitly rather than relying on an ``ON DELETE
    CASCADE``. SQLite only enforces foreign keys when ``PRAGMA foreign_keys``
    is enabled on the connection, and it is off by default. Both deletes share
    one transaction. Audio files referenced by ``audio_path`` are not touched
    here.
    """
    with conn:
        conn.execute("DELETE FROM nodes WHERE chain_id = ?", (chain_id,))
        cur = conn.execute("DELETE FROM chains WHERE id = ?", (chain_id,))
    return cur.rowcount > 0


# ── Node CRUD ─────────────────────────────────────────────────────────────────

def create_root_node(
    conn: sqlite3.Connection,
    chain_id: str,
    audio_path: str,
    duration_s: float,
) -> dict[str, Any]:
    """Create the root node (uploaded source audio) of *chain_id*.

    The root is inserted with:

    - no parent;
    - status ``'ready'``;
    - already committed;
    - an empty prompt;
    - fixed generation settings (``cfg_coef`` 3.0, ``prompt_duration_s`` 10.0).

    Returns the stored node as a dict (see ``_row_to_dict``).
    """
    node_id = str(uuid.uuid4())
    now = time.time()
    with conn:
        conn.execute("""
            INSERT INTO nodes
                (id, chain_id, parent_id, audio_path, duration_s, status,
                 is_committed, prompt, cfg_coef, prompt_duration_s, created_at)
            VALUES (?, ?, NULL, ?, ?, 'ready', 1, '', 3.0, 10.0, ?)
        """, (node_id, chain_id, audio_path, duration_s, now))
    return get_node(conn, node_id)


def get_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """Return the node with *node_id* as a dict, or ``None`` if it does not exist."""
    row = conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
    return _row_to_dict(row) if row else None


def create_branch_node(
    conn: sqlite3.Connection,
    parent_id: str,
    chain_id: str,
    prompt: str,
    energy: float | None,
    tempo_feel: float | None,
    density: float | None,
    cfg_coef: float,
    prompt_duration_s: float,
) -> dict[str, Any]:
    """Create a new, uncommitted child of *parent_id* in ``'pending'`` state.

    No audio exists yet, so ``audio_path`` and ``duration_s`` are stored as NULL.
    ``energy`` / ``tempo_feel`` / ``density`` are stored as NULL when ``None``.
    The caller is responsible for advancing the status (e.g. to
    ``'generating'``) via ``update_node_status``. Returns the stored node.
    """
    node_id = str(uuid.uuid4())
    now = time.time()
    with conn:
        conn.execute("""
            INSERT INTO nodes
                (id, chain_id, parent_id, audio_path, duration_s, status,
                 is_committed, prompt, energy, tempo_feel, density,
                 cfg_coef, prompt_duration_s, created_at)
            VALUES (?, ?, ?, NULL, NULL, 'pending', 0, ?, ?, ?, ?, ?, ?, ?)
        """, (node_id, chain_id, parent_id, prompt, energy, tempo_feel, density,
              cfg_coef, prompt_duration_s, now))
    return get_node(conn, node_id)


def update_node_status(
    conn: sqlite3.Connection,
    node_id: str,
    status: str,
    audio_path: str | None = None,
    duration_s: float | None = None,
    error_msg: str | None = None,
) -> dict[str, Any] | None:
    """Set a node's status and, optionally, its audio/duration/error fields.

    Optional arguments left as ``None`` keep the column's current value.
    Passing ``None`` therefore cannot clear a column. Returns the updated
    node, or ``None`` if *node_id* does not exist.
    """
    updates: dict[str, Any] = {"status": status}
    optional = {"audio_path": audio_path, "duration_s": duration_s, "error_msg": error_msg}
    updates.update({col: val for col, val in optional.items() if val is not None})
    # Column names come only from the fixed keys above; values are bound.
    set_clause = ", ".join(f"{col} = ?" for col in updates)
    with conn:
        conn.execute(
            f"UPDATE nodes SET {set_clause} WHERE id = ?",
            [*updates.values(), node_id],
        )
    return get_node(conn, node_id)


def commit_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """Promote a branch node to the committed spine.

    Every *uncommitted* sibling (same parent), together with its whole
    subtree, is discarded. The node's own descendants are kept. The sibling
    discard and the promotion happen in a single transaction. A root node
    (no parent) is only re-marked as committed.

    Returns the committed node, or ``None`` if *node_id* does not exist.

    NOTE(review): already-committed siblings are kept. Two committed children
    of one parent would make ``get_committed_spine`` return both at the same
    depth; confirm callers only commit one child per parent.
    NOTE(review): ancestors are not marked committed here. A node committed
    under an uncommitted parent is unreachable from ``get_committed_spine``.
    """
    node = get_node(conn, node_id)
    if node is None:
        return None
    parent_id = node.get("parent_id")
    with conn:
        if parent_id:
            _delete_subtrees(
                conn,
                "SELECT id FROM nodes WHERE parent_id = ? AND id != ? AND is_committed = 0",
                (parent_id, node_id),
            )
        conn.execute("UPDATE nodes SET is_committed = 1 WHERE id = ?", (node_id,))
    return get_node(conn, node_id)


def delete_node(conn: sqlite3.Connection, node_id: str) -> bool:
    """Delete an uncommitted branch node and its whole subtree.

    Returns False when:

    - the node does not exist;
    - it is a root (``parent_id IS NULL``);
    - it is committed.

    Otherwise returns True. Descendants are removed too, so no row is left
    with a dangling ``parent_id``.
    """
    row = conn.execute(
        "SELECT parent_id FROM nodes WHERE id = ?", (node_id,)
    ).fetchone()
    if row is None or row["parent_id"] is None:
        return False  # unknown node, or root node; refuse
    with conn:
        deleted = _delete_subtrees(
            conn,
            "SELECT id FROM nodes WHERE id = ? AND is_committed = 0",
            (node_id,),
        )
    # The seed is empty (nothing deleted) exactly when the node is committed.
    return deleted > 0


def get_committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """Return committed nodes ordered root → leaf along the committed spine.

    A recursive CTE starts at the chain's committed root and walks down
    through committed children only, tracking depth for the final ordering.
    """
    rows = conn.execute("""
        WITH RECURSIVE spine(id, parent_id, depth) AS (
            SELECT id, parent_id, 0
            FROM nodes
            WHERE chain_id = ? AND parent_id IS NULL AND is_committed = 1
            UNION ALL
            SELECT n.id, n.parent_id, spine.depth + 1
            FROM nodes n
            JOIN spine ON n.parent_id = spine.id
            WHERE n.is_committed = 1 AND n.chain_id = ?
        )
        SELECT n.*
        FROM nodes n
        JOIN spine ON spine.id = n.id
        ORDER BY spine.depth
    """, (chain_id, chain_id)).fetchall()
    return [_row_to_dict(r) for r in rows]


# ── Helpers ───────────────────────────────────────────────────────────────────

def _delete_subtrees(
    conn: sqlite3.Connection, seed_sql: str, params: tuple[Any, ...]
) -> int:
    """Delete the nodes selected by *seed_sql* plus all of their descendants.

    *seed_sql* must be a trusted, module-internal ``SELECT id FROM nodes ...``
    query. It is interpolated into the statement, while *params* are bound.
    Returns the number of rows deleted and does not commit.
    """
    # The statement deliberately starts with DELETE, with the CTE inside the
    # subquery. Python's sqlite3 decides whether a statement is DML from its
    # leading keyword, which drives both ``rowcount`` and the implicit BEGIN.
    # UNION (not UNION ALL) guarantees termination even if data contained a cycle.
    cur = conn.execute(
        f"""
        DELETE FROM nodes WHERE id IN (
            WITH RECURSIVE doomed(id) AS (
                {seed_sql}
                UNION
                SELECT n.id FROM nodes n JOIN doomed ON n.parent_id = doomed.id
            )
            SELECT id FROM doomed
        )
        """,
        params,
    )
    return cur.rowcount


def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
    """Convert a ``nodes`` row to a dict for API output.

    ``is_committed`` is normalised from the stored integer to a bool. An empty
    ``children`` list is added as a placeholder; nothing in this module fills it.
    """
    d = dict(row)
    d["is_committed"] = bool(d.get("is_committed", 0))
    d["children"] = []
    return d


def _get_nodes_for_chain(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """Return every node of *chain_id* (committed or not), oldest first."""
    rows = conn.execute(
        "SELECT * FROM nodes WHERE chain_id = ? ORDER BY created_at", (chain_id,)
    ).fetchall()
    return [_row_to_dict(r) for r in rows]