sparrow/app/services/chain.py
pyr0ball a6f60c9e07 feat: implement Sparrow backend (v0.1.0)
Full FastAPI backend for the AI music continuation editor:

Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3

API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub

Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)

Tests: 30/30 passing across service units and API integration
2026-04-17 15:22:37 -07:00

203 lines
7 KiB
Python

# app/services/chain.py — chain + node CRUD and branching tree logic
from __future__ import annotations
import sqlite3
import time
import uuid
from typing import Any
# ── Chain CRUD ────────────────────────────────────────────────────────────────
def create_chain(conn: sqlite3.Connection, name: str) -> dict[str, Any]:
    """Insert a new, empty chain and return its summary record.

    The chain receives a fresh UUID4 string id and a ``time.time()``
    creation timestamp; the insert is committed before returning.

    Returns:
        A dict with ``id``, ``name``, ``created_at`` and ``node_count``
        (always 0, since a brand-new chain owns no nodes).
    """
    new_id = str(uuid.uuid4())
    created = time.time()
    record: dict[str, Any] = {"id": new_id, "name": name, "created_at": created}
    conn.execute(
        "INSERT INTO chains (id, name, created_at) VALUES (?, ?, ?)",
        (record["id"], record["name"], record["created_at"]),
    )
    conn.commit()
    record["node_count"] = 0
    return record
def list_chains(conn: sqlite3.Connection) -> list[dict[str, Any]]:
    """Return every chain, newest first, with the number of nodes it owns.

    The LEFT JOIN keeps chains that have no nodes; their ``node_count`` is 0.
    Rows are converted with ``dict(row)``, which assumes ``conn.row_factory``
    is ``sqlite3.Row`` — NOTE(review): confirm where the connection is set up.
    """
    cursor = conn.execute("""
        SELECT c.id, c.name, c.created_at,
               COUNT(n.id) AS node_count
        FROM chains c
        LEFT JOIN nodes n ON n.chain_id = c.id
        GROUP BY c.id
        ORDER BY c.created_at DESC
    """)
    return list(map(dict, cursor.fetchall()))
def get_chain(conn: sqlite3.Connection, chain_id: str) -> dict[str, Any] | None:
    """Fetch one chain together with all of its nodes.

    Returns ``None`` when no chain has ``chain_id``. Otherwise returns a dict
    with ``id``, ``name``, ``created_at`` and ``nodes``; ``nodes`` is the flat
    list produced by ``_get_nodes_for_chain`` (ordered by ``created_at``).
    Column access by name assumes ``conn.row_factory`` is ``sqlite3.Row``.
    """
    chain_row = conn.execute(
        "SELECT id, name, created_at FROM chains WHERE id = ?", (chain_id,)
    ).fetchone()
    if chain_row is not None:
        result: dict[str, Any] = {
            column: chain_row[column] for column in ("id", "name", "created_at")
        }
        result["nodes"] = _get_nodes_for_chain(conn, chain_id)
        return result
    return None
def delete_chain(conn: sqlite3.Connection, chain_id: str) -> bool:
    """Delete the chain row identified by ``chain_id`` and commit.

    Returns ``True`` if a row was removed, ``False`` if no such chain existed.
    Only the ``chains`` row is deleted explicitly. Whether the chain's nodes
    are removed too depends on the schema (e.g. ``ON DELETE CASCADE`` with
    ``PRAGMA foreign_keys`` enabled) — NOTE(review): confirm.
    """
    removed = conn.execute("DELETE FROM chains WHERE id = ?", (chain_id,)).rowcount
    conn.commit()
    return removed > 0
# ── Node CRUD ─────────────────────────────────────────────────────────────────
def create_root_node(
    conn: sqlite3.Connection,
    chain_id: str,
    audio_path: str,
    duration_s: float,
) -> dict[str, Any]:
    """Insert the root node of a chain (the uploaded source audio) and return it.

    The root has no parent, starts in status ``'ready'`` and is always
    committed (``is_committed = 1``). It is stored with an empty prompt,
    ``cfg_coef = 3.0`` and ``prompt_duration_s = 10.0``. The insert is
    committed, then the stored row is re-read via ``get_node``.
    """
    root_id = str(uuid.uuid4())
    params = (root_id, chain_id, audio_path, duration_s, time.time())
    conn.execute("""
        INSERT INTO nodes
            (id, chain_id, parent_id, audio_path, duration_s, status,
             is_committed, prompt, cfg_coef, prompt_duration_s, created_at)
        VALUES (?, ?, NULL, ?, ?, 'ready', 1, '', 3.0, 10.0, ?)
    """, params)
    conn.commit()
    return get_node(conn, root_id)
def get_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """Look up a single node by id.

    Returns the node converted by ``_row_to_dict`` (``is_committed`` as a
    bool, an empty ``children`` list attached), or ``None`` if no node has
    ``node_id``.
    """
    found = conn.execute(
        "SELECT * FROM nodes WHERE id = ?", (node_id,)
    ).fetchone()
    if found is None:
        return None
    return _row_to_dict(found)
def create_branch_node(
    conn: sqlite3.Connection,
    parent_id: str,
    chain_id: str,
    prompt: str,
    energy: float | None,
    tempo_feel: float | None,
    density: float | None,
    cfg_coef: float,
    prompt_duration_s: float,
) -> dict[str, Any]:
    """Insert a new uncommitted child of ``parent_id`` and return it.

    The node starts in status ``'pending'`` with NULL ``audio_path`` and
    ``duration_s`` and ``is_committed = 0``; the caller updates it to
    ``'generating'``. Generation parameters are stored exactly as given
    (``energy``, ``tempo_feel`` and ``density`` may be ``None``).

    ``parent_id`` is not validated here. NOTE(review): presumably callers
    check that it exists and belongs to ``chain_id``.

    The insert is committed, then the stored row is re-read via ``get_node``.
    """
    branch_id = str(uuid.uuid4())
    created = time.time()
    row_values = (
        branch_id, chain_id, parent_id, prompt,
        energy, tempo_feel, density,
        cfg_coef, prompt_duration_s, created,
    )
    conn.execute("""
        INSERT INTO nodes
            (id, chain_id, parent_id, audio_path, duration_s, status,
             is_committed, prompt, energy, tempo_feel, density,
             cfg_coef, prompt_duration_s, created_at)
        VALUES (?, ?, ?, NULL, NULL, 'pending', 0, ?, ?, ?, ?, ?, ?, ?)
    """, row_values)
    conn.commit()
    return get_node(conn, branch_id)
def update_node_status(
    conn: sqlite3.Connection,
    node_id: str,
    status: str,
    audio_path: str | None = None,
    duration_s: float | None = None,
    error_msg: str | None = None,
) -> dict[str, Any] | None:
    """Set a node's status and, optionally, its audio/duration/error columns.

    ``status`` is always written. Each of ``audio_path``, ``duration_s`` and
    ``error_msg`` is written only when it is not ``None``. Passing ``None``
    leaves the stored value untouched, so this function cannot reset those
    columns to NULL.

    Commits, then returns the re-read node via ``get_node`` (``None`` if
    ``node_id`` does not exist).
    """
    optional_columns = (
        ("audio_path", audio_path),
        ("duration_s", duration_s),
        ("error_msg", error_msg),
    )
    assignments: dict[str, Any] = {"status": status}
    assignments.update(
        (column, value) for column, value in optional_columns if value is not None
    )
    # Column names come only from the fixed literals above, never from the
    # caller, so interpolating them is safe; all values remain bound params.
    set_clause = ", ".join(f"{column} = ?" for column in assignments)
    conn.execute(
        f"UPDATE nodes SET {set_clause} WHERE id = ?",
        [*assignments.values(), node_id],
    )
    conn.commit()
    return get_node(conn, node_id)
def commit_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """
    Promote a branch node to the committed spine.

    Sets ``is_committed = 1`` on ``node_id`` and deletes every *uncommitted*
    sibling (same ``parent_id``, different id). Siblings that are already
    committed are kept. The node's ancestors are never touched. A root node
    (no ``parent_id``) is simply marked committed.

    The sibling DELETE and the UPDATE run as one transaction. If either
    statement raises, both are rolled back and the exception propagates.
    Without this, a failed UPDATE would leave the DELETE pending on the
    connection, and the next unrelated ``conn.commit()`` would persist a
    half-applied commit.

    Descendants of deleted siblings are not deleted explicitly here. Whether
    they go too depends on the schema's foreign keys — NOTE(review): confirm.

    Returns the updated node, or ``None`` if ``node_id`` does not exist.
    """
    node = get_node(conn, node_id)
    if node is None:
        return None
    parent_id = node.get("parent_id")
    # ``with conn`` commits on success and rolls back on any exception.
    with conn:
        if parent_id:
            # Discard non-committed siblings (same parent).
            conn.execute(
                """
                DELETE FROM nodes
                WHERE parent_id = ? AND id != ? AND is_committed = 0
                """,
                (parent_id, node_id),
            )
        conn.execute("UPDATE nodes SET is_committed = 1 WHERE id = ?", (node_id,))
    return get_node(conn, node_id)
def delete_node(conn: sqlite3.Connection, node_id: str) -> bool:
    """Discard an uncommitted branch node.

    Returns ``False`` without deleting anything when the node does not exist
    or is a root (``parent_id IS NULL``). Otherwise it deletes the node only
    if it is uncommitted, commits, and returns whether a row was removed.
    Committed nodes are therefore refused as well.

    Column access by name assumes ``conn.row_factory`` is ``sqlite3.Row``.
    """
    found = conn.execute(
        "SELECT parent_id FROM nodes WHERE id = ?", (node_id,)
    ).fetchone()
    missing_or_root = found is None or found["parent_id"] is None
    if missing_or_root:
        return False  # unknown id or root node; refuse
    removed = conn.execute(
        "DELETE FROM nodes WHERE id = ? AND is_committed = 0", (node_id,)
    ).rowcount
    conn.commit()
    return removed > 0
def get_committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """
    Return committed nodes ordered root → leaf along the committed spine.

    A recursive CTE starts at the chain's committed root(s)
    (``parent_id IS NULL AND is_committed = 1``). It then repeatedly follows
    committed children in the same chain, recording each node's depth.
    Uncommitted nodes, and anything below them, are never reached.

    Rows are ordered by depth, then ``created_at``, then ``id``. The extra
    keys make the order deterministic when several committed nodes share a
    depth. That can happen because ``commit_node`` keeps already-committed
    siblings. With only ``depth`` as the sort key, SQLite leaves the order of
    such ties unspecified.
    """
    rows = conn.execute("""
        WITH RECURSIVE spine(id, parent_id, depth) AS (
            SELECT id, parent_id, 0
            FROM nodes
            WHERE chain_id = ? AND parent_id IS NULL AND is_committed = 1
            UNION ALL
            SELECT n.id, n.parent_id, spine.depth + 1
            FROM nodes n
            JOIN spine ON n.parent_id = spine.id
            WHERE n.is_committed = 1 AND n.chain_id = ?
        )
        SELECT n.* FROM nodes n
        JOIN spine ON spine.id = n.id
        ORDER BY spine.depth, n.created_at, n.id
    """, (chain_id, chain_id)).fetchall()
    return [_row_to_dict(r) for r in rows]
# ── Helpers ───────────────────────────────────────────────────────────────────
def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
d = dict(row)
d["is_committed"] = bool(d.get("is_committed", 0))
d["children"] = []
return d
def _get_nodes_for_chain(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """Return every node of ``chain_id`` as a flat list, ordered by ``created_at``.

    Each row is converted with ``_row_to_dict``. No parent/child nesting is
    built here.
    """
    cursor = conn.execute(
        "SELECT * FROM nodes WHERE chain_id = ? ORDER BY created_at", (chain_id,)
    )
    return list(map(_row_to_dict, cursor.fetchall()))