Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service units and API integration
203 lines
7 KiB
Python
203 lines
7 KiB
Python
# app/services/chain.py — chain + node CRUD and branching tree logic
|
|
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
|
|
# ── Chain CRUD ────────────────────────────────────────────────────────────────
|
|
|
|
def create_chain(conn: sqlite3.Connection, name: str) -> dict[str, Any]:
    """Insert a new, empty chain and return its summary record.

    The returned dict has the same keys as the rows produced by
    ``list_chains`` (``id``, ``name``, ``created_at``, ``node_count``), with
    ``node_count`` fixed at 0 because no nodes exist yet.
    """
    record: dict[str, Any] = {
        "id": str(uuid.uuid4()),
        "name": name,
        "created_at": time.time(),
        "node_count": 0,
    }
    conn.execute(
        "INSERT INTO chains (id, name, created_at) VALUES (?, ?, ?)",
        (record["id"], record["name"], record["created_at"]),
    )
    conn.commit()
    return record
|
|
|
|
|
|
def list_chains(conn: sqlite3.Connection) -> list[dict[str, Any]]:
    """Return every chain, newest first, each annotated with its node count.

    The LEFT JOIN keeps chains that have no nodes yet; for those,
    ``COUNT(n.id)`` yields ``node_count == 0``. Rows are converted with
    ``dict(row)``, so the connection is expected to use a mapping-style
    row factory such as ``sqlite3.Row``.
    """
    query = (
        "SELECT c.id, c.name, c.created_at, COUNT(n.id) AS node_count "
        "FROM chains c "
        "LEFT JOIN nodes n ON n.chain_id = c.id "
        "GROUP BY c.id "
        "ORDER BY c.created_at DESC"
    )
    cursor = conn.execute(query)
    return list(map(dict, cursor.fetchall()))
|
|
|
|
|
|
def get_chain(conn: sqlite3.Connection, chain_id: str) -> dict[str, Any] | None:
    """Fetch one chain together with all of its nodes.

    Returns ``None`` when no chain has ``chain_id``. Otherwise returns the
    chain's ``id``/``name``/``created_at`` plus ``nodes``: a flat list of node
    dicts ordered by ``created_at``. No tree nesting is done here, so each
    node's ``children`` list is left empty.
    """
    chain_row = conn.execute(
        "SELECT id, name, created_at FROM chains WHERE id = ?", (chain_id,)
    ).fetchone()
    if chain_row is not None:
        result: dict[str, Any] = {
            key: chain_row[key] for key in ("id", "name", "created_at")
        }
        result["nodes"] = _get_nodes_for_chain(conn, chain_id)
        return result
    return None
|
|
|
|
|
|
def delete_chain(conn: sqlite3.Connection, chain_id: str) -> bool:
    """Delete a chain row and report whether anything was removed.

    Returns ``True`` if a row with ``chain_id`` existed and was deleted,
    ``False`` otherwise.

    NOTE(review): only the ``chains`` row is deleted explicitly. Whether the
    chain's nodes go with it depends on the schema (e.g. an ON DELETE CASCADE
    foreign key), which is not visible in this module.
    """
    deleted = conn.execute("DELETE FROM chains WHERE id = ?", (chain_id,)).rowcount
    conn.commit()
    return deleted > 0
|
|
|
|
|
|
# ── Node CRUD ─────────────────────────────────────────────────────────────────
|
|
|
|
def create_root_node(
    conn: sqlite3.Connection,
    chain_id: str,
    audio_path: str,
    duration_s: float,
) -> dict[str, Any]:
    """Insert the root node of a chain (the uploaded source audio).

    The root has no parent (``parent_id`` NULL) and starts with:

    - status ``'ready'``, and committed from the outset;
    - an empty prompt;
    - fixed generation settings ``cfg_coef=3.0`` and
      ``prompt_duration_s=10.0``.

    Returns the stored node as read back by ``get_node``.
    """
    new_id = str(uuid.uuid4())
    params = (new_id, chain_id, audio_path, duration_s, time.time())
    conn.execute(
        "INSERT INTO nodes "
        "(id, chain_id, parent_id, audio_path, duration_s, status, "
        "is_committed, prompt, cfg_coef, prompt_duration_s, created_at) "
        "VALUES (?, ?, NULL, ?, ?, 'ready', 1, '', 3.0, 10.0, ?)",
        params,
    )
    conn.commit()
    return get_node(conn, new_id)
|
|
|
|
|
|
def get_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """Look up a single node by id.

    Returns the node as produced by ``_row_to_dict``: all ``nodes`` columns,
    ``is_committed`` coerced to ``bool``, and an empty ``children`` list.
    Returns ``None`` if no node has ``node_id``.
    """
    found = conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
    if found is None:
        return None
    return _row_to_dict(found)
|
|
|
|
|
|
def create_branch_node(
    conn: sqlite3.Connection,
    parent_id: str,
    chain_id: str,
    prompt: str,
    energy: float | None,
    tempo_feel: float | None,
    density: float | None,
    cfg_coef: float,
    prompt_duration_s: float,
) -> dict[str, Any]:
    """Insert a new uncommitted child of ``parent_id`` in status ``'pending'``.

    The node has no audio yet: ``audio_path`` and ``duration_s`` are NULL.
    The generation parameters (prompt, energy, tempo_feel, density, cfg_coef,
    prompt_duration_s) are stored as given. The caller is expected to move
    the node to ``'generating'`` — presumably via ``update_node_status``.
    Returns the stored node as read back by ``get_node``.

    NOTE(review): nothing here checks that ``parent_id`` exists or belongs to
    ``chain_id``; presumably a caller or a foreign key enforces that.
    """
    new_id = str(uuid.uuid4())
    row_values = (
        new_id, chain_id, parent_id,
        prompt, energy, tempo_feel, density,
        cfg_coef, prompt_duration_s, time.time(),
    )
    conn.execute(
        "INSERT INTO nodes "
        "(id, chain_id, parent_id, audio_path, duration_s, status, "
        "is_committed, prompt, energy, tempo_feel, density, "
        "cfg_coef, prompt_duration_s, created_at) "
        "VALUES (?, ?, ?, NULL, NULL, 'pending', 0, ?, ?, ?, ?, ?, ?, ?)",
        row_values,
    )
    conn.commit()
    return get_node(conn, new_id)
|
|
|
|
|
|
def update_node_status(
    conn: sqlite3.Connection,
    node_id: str,
    status: str,
    audio_path: str | None = None,
    duration_s: float | None = None,
    error_msg: str | None = None,
) -> dict[str, Any] | None:
    """Set a node's ``status`` and, optionally, its audio/error fields.

    ``status`` is always written. Each optional argument is written only when
    it is not ``None``. As a consequence, this function cannot reset a column
    back to NULL.

    Returns the node as re-read by ``get_node``, or ``None`` if ``node_id``
    does not exist.
    """
    optional_columns = (
        ("audio_path", audio_path),
        ("duration_s", duration_s),
        ("error_msg", error_msg),
    )
    assignments = ["status = ?"]
    params: list[Any] = [status]
    for column, value in optional_columns:
        if value is None:
            continue
        assignments.append(f"{column} = ?")
        params.append(value)
    params.append(node_id)
    # Column names come only from the fixed tuple above, never from callers,
    # so splicing them into the SQL is safe; all values stay parameterized.
    sql = "UPDATE nodes SET " + ", ".join(assignments) + " WHERE id = ?"
    conn.execute(sql, params)
    conn.commit()
    return get_node(conn, node_id)
|
|
|
|
|
|
def commit_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """Promote a branch node to the committed spine.

    Two writes are made:

    - every *uncommitted* sibling (same ``parent_id``, different id) is
      deleted;
    - the node itself is marked ``is_committed = 1``.

    They run as one transaction: if either statement raises, both are rolled
    back and the exception propagates. Siblings can therefore never be
    discarded while the node itself remains uncommitted.

    Things this function does not touch:

    - already-committed siblings;
    - ancestors (their ``is_committed`` flags are not changed);
    - siblings of a root node (``parent_id`` NULL), since none are discarded.

    Returns the updated node, or ``None`` if ``node_id`` does not exist.

    NOTE(review): whether descendants of discarded siblings are removed too
    depends on the schema's foreign keys, which are not visible here.
    """
    node = get_node(conn, node_id)
    if node is None:
        return None

    parent_id = node.get("parent_id")
    # ``with conn`` commits on success and rolls back on any exception. Without
    # it, a failing UPDATE would leave the DELETE pending in an open implicit
    # transaction, to be committed by the next unrelated conn.commit().
    with conn:
        if parent_id:
            # Discard non-committed siblings (same parent).
            conn.execute(
                "DELETE FROM nodes "
                "WHERE parent_id = ? AND id != ? AND is_committed = 0",
                (parent_id, node_id),
            )
        conn.execute("UPDATE nodes SET is_committed = 1 WHERE id = ?", (node_id,))
    return get_node(conn, node_id)
|
|
|
|
|
|
def delete_node(conn: sqlite3.Connection, node_id: str) -> bool:
    """Discard an uncommitted branch node.

    Returns ``False`` without writing anything when the node does not exist
    or is a root (``parent_id IS NULL``). Otherwise the node is deleted only
    if it is still uncommitted, so committed nodes are refused as well.

    Returns whether a row was actually removed. Rows are indexed by column
    name, so a mapping-style row factory such as ``sqlite3.Row`` is expected.
    """
    found = conn.execute(
        "SELECT parent_id FROM nodes WHERE id = ?", (node_id,)
    ).fetchone()
    missing_or_root = found is None or found["parent_id"] is None
    if missing_or_root:
        return False  # root node (or unknown id); refuse
    removed = conn.execute(
        "DELETE FROM nodes WHERE id = ? AND is_committed = 0", (node_id,)
    ).rowcount
    conn.commit()
    return removed > 0
|
|
|
|
|
|
def get_committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """Return the chain's committed nodes ordered root → leaf.

    A recursive CTE starts from the committed root of ``chain_id``
    (``parent_id IS NULL``). It then repeatedly follows committed children in
    the same chain, recording each node's depth, and the result is ordered
    by that depth.

    NOTE(review): the query does not enforce a single committed child per
    parent. If two committed siblings exist, both appear at the same depth
    in unspecified relative order.
    """
    spine_sql = """
        WITH RECURSIVE spine(id, parent_id, depth) AS (
            SELECT id, parent_id, 0
            FROM nodes
            WHERE chain_id = :chain_id AND parent_id IS NULL AND is_committed = 1
            UNION ALL
            SELECT child.id, child.parent_id, spine.depth + 1
            FROM nodes AS child
            JOIN spine ON child.parent_id = spine.id
            WHERE child.is_committed = 1 AND child.chain_id = :chain_id
        )
        SELECT n.* FROM nodes n
        JOIN spine ON spine.id = n.id
        ORDER BY spine.depth
    """
    spine_rows = conn.execute(spine_sql, {"chain_id": chain_id}).fetchall()
    return [_row_to_dict(r) for r in spine_rows]
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
    """Convert a ``nodes`` row into the node dict returned by this module.

    Two fields are adjusted after copying the row's columns:

    - ``is_committed`` is coerced to a real ``bool``; a missing column
      becomes ``False``.
    - an empty ``children`` list is attached. Nothing in this module
      populates it.
    """
    converted = dict(row)
    converted.update(
        is_committed=bool(converted.get("is_committed", 0)),
        children=[],
    )
    return converted
|
|
|
|
|
|
def _get_nodes_for_chain(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """Return every node belonging to ``chain_id`` as a flat list, oldest first.

    Both committed and uncommitted nodes are included, each converted by
    ``_row_to_dict``.
    """
    cursor = conn.execute(
        "SELECT * FROM nodes WHERE chain_id = ? ORDER BY created_at", (chain_id,)
    )
    return list(map(_row_to_dict, cursor.fetchall()))
|