sparrow/app/api/endpoints/nodes.py
pyr0ball a6f60c9e07 feat: implement Sparrow backend (v0.1.0)
Full FastAPI backend for the AI music continuation editor:

Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3

API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub

Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)

Tests: 30/30 passing across service units and API integration
2026-04-17 15:22:37 -07:00

182 lines
5.8 KiB
Python

# app/api/endpoints/nodes.py — branch generation, commit, discard, audio stream
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse
from app.api.deps import get_conn, get_data_dir
from app.api.events_store import broadcast
from app.models.schemas.node import BranchRequest, NodeRow
from app.services import chain as chain_svc
from app.services.musicgen import MusicGenClient, make_output_path
# Module logger, the endpoint router (all paths prefixed with /api/nodes),
# and a single MusicGen client shared by every background generation task.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["nodes"], prefix="/api/nodes")
_musicgen = MusicGenClient()
@router.post("/{parent_id}/branch", response_model=NodeRow, status_code=202)
def create_branch(
    parent_id: str,
    body: BranchRequest,
    background_tasks: BackgroundTasks,
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Create a branch node from parent_id and kick off async generation.

    Returns 202 immediately with the node in 'pending' state. Status
    transitions (pending → generating → ready/error) are pushed via the
    chain's SSE stream by the background generation task.

    Declared as a plain ``def``, like the other endpoints in this module,
    so FastAPI runs it in its threadpool: everything here is blocking
    SQLite work and nothing is awaited, so an ``async def`` would stall
    the event loop for the duration of these queries.

    Raises:
        HTTPException(404): the parent node does not exist.
        HTTPException(409): the parent node has no audio yet.
    """
    parent = chain_svc.get_node(conn, parent_id)
    if parent is None:
        raise HTTPException(status_code=404, detail="Parent node not found.")
    if parent["audio_path"] is None:
        raise HTTPException(
            status_code=409, detail="Parent node has no audio yet."
        )

    chain_id = parent["chain_id"]
    node = chain_svc.create_branch_node(
        conn,
        parent_id=parent_id,
        chain_id=chain_id,
        prompt=body.prompt,
        energy=body.energy,
        tempo_feel=body.tempo_feel,
        density=body.density,
        cfg_coef=body.cfg_coef,
        prompt_duration_s=body.prompt_duration_s,
    )

    # The background task opens its own connection, so it needs the file
    # path of the database behind this one. PRAGMA database_list yields
    # (seq, name, file) rows and the first row is the "main" database.
    # NOTE(review): for an in-memory database the file column is empty and
    # the task would open a different, empty database.
    db_path = conn.execute("PRAGMA database_list").fetchone()[2]

    background_tasks.add_task(
        _run_generation,
        node_id=node["id"],
        chain_id=chain_id,
        source_audio_path=parent["audio_path"],
        prompt=body.prompt,
        duration_s=body.duration_s,
        cfg_coef=body.cfg_coef,
        prompt_duration_s=body.prompt_duration_s,
        data_dir=data_dir,
        db_path=db_path,
    )
    return node
@router.post("/{node_id}/commit", response_model=NodeRow)
def commit_node(node_id: str, conn=Depends(get_conn)) -> dict:
    """
    Promote a 'ready' branch node onto the committed spine.

    Committing discards all of the node's non-committed siblings (nodes
    sharing its parent). The root node is always committed and is not a
    target for this endpoint.

    Raises:
        HTTPException(404): the node does not exist.
        HTTPException(409): the node's status is anything other than 'ready'.
    """
    target = chain_svc.get_node(conn, node_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    if target["status"] != "ready":
        raise HTTPException(
            status_code=409,
            detail="Only 'ready' nodes can be committed.",
        )

    committed = chain_svc.commit_node(conn, node_id)
    if committed is None:
        # A None result from the service is reported as not-found.
        raise HTTPException(status_code=404, detail="Node not found.")
    return committed
@router.delete("/{node_id}", status_code=204)
def delete_node(node_id: str, conn=Depends(get_conn)) -> None:
    """
    Discard a branch node that has not been committed.

    Root nodes (parent_id IS NULL) and committed nodes cannot be deleted.
    Only the truthiness of the service result is checked, so an unknown id
    and a refused deletion both surface as the same 409.
    """
    removed = chain_svc.delete_node(conn, node_id)
    if removed:
        return
    raise HTTPException(
        status_code=409,
        detail="Node not found or cannot be deleted (root or committed).",
    )
# Content-Type for the audio file extensions this service handles. Chain
# uploads accept WAV/MP3/FLAC/OGG/M4A/AIFF; NOTE(review): confirm whether
# uploads are stored untranscoded (if so, root-node audio may not be WAV).
_AUDIO_MEDIA_TYPES = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
    ".m4a": "audio/mp4",
    ".aif": "audio/aiff",
    ".aiff": "audio/aiff",
}


@router.get("/{node_id}/audio")
def stream_audio(node_id: str, conn=Depends(get_conn)) -> FileResponse:
    """
    Stream the audio file stored for a node.

    The Content-Type and download filename follow the file's extension, so
    non-WAV audio is not mislabelled as WAV. Unrecognised extensions keep
    the previous defaults ("audio/wav" and a ".wav" filename).

    Raises:
        HTTPException(404): the node is unknown, or it has no audio file on disk.
    """
    node = chain_svc.get_node(conn, node_id)
    if node is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    audio_path = node["audio_path"]
    if audio_path is None or not Path(audio_path).exists():
        raise HTTPException(status_code=404, detail="Audio not available.")

    suffix = Path(audio_path).suffix.lower()
    if suffix not in _AUDIO_MEDIA_TYPES:
        suffix = ".wav"
    return FileResponse(
        audio_path,
        media_type=_AUDIO_MEDIA_TYPES[suffix],
        filename=f"{node_id}{suffix}",
    )
# ── Background generation task ────────────────────────────────────────────────
async def _run_generation(
    *,
    node_id: str,
    chain_id: str,
    source_audio_path: str,
    prompt: str,
    duration_s: float,
    cfg_coef: float,
    prompt_duration_s: float,
    data_dir: str,
    db_path: str,
) -> None:
    """
    Background task: run MusicGen for one branch node and record the outcome.

    Moves the node to 'generating', then to 'ready' (with its output path
    and actual duration) or to 'error'. Each transition is written to the DB
    and broadcast on the chain's SSE stream. Opens its own DB connection
    from ``db_path`` rather than reusing the request's connection.

    Failure handling:
      * Any exception before the node is persisted as 'ready' (including
        failing to open the DB) is logged and reported as 'error'.
      * The 'ready' broadcast runs after the try block. A failure while
        announcing success therefore cannot overwrite a node that was
        already stored as 'ready'.
    """
    from app.db.store import get_connection

    conn = None
    try:
        conn = get_connection(db_path)
        chain_svc.update_node_status(conn, node_id, "generating")
        await broadcast(chain_id, {"node_id": node_id, "status": "generating"})
        output_path = make_output_path(data_dir, chain_id, node_id)
        actual_duration = await _musicgen.generate(
            source_audio_path=source_audio_path,
            output_path=output_path,
            prompt=prompt,
            duration_s=duration_s,
            cfg_coef=cfg_coef,
            prompt_duration_s=prompt_duration_s,
        )
        chain_svc.update_node_status(
            conn, node_id, "ready",
            audio_path=output_path,
            duration_s=actual_duration,
        )
    except Exception as exc:
        logger.exception("Generation failed for node %s: %s", node_id, exc)
        await _report_generation_error(
            conn, chain_id=chain_id, node_id=node_id, exc=exc
        )
    else:
        await broadcast(chain_id, {
            "node_id": node_id,
            "status": "ready",
            "audio_path": output_path,
            "duration_s": actual_duration,
        })
        logger.info("Node %s ready (%.1fs)", node_id, actual_duration)
    finally:
        if conn is not None:
            conn.close()


async def _report_generation_error(
    conn, *, chain_id: str, node_id: str, exc: Exception
) -> None:
    """Persist (best effort) and broadcast the 'error' status for a failed node.

    ``conn`` is None when the DB connection could not be opened. In that
    case nothing is written, but the error event is still broadcast.
    """
    # Some exceptions stringify to ""; fall back to the class name so the
    # stored and broadcast error message is never blank.
    error_msg = str(exc) or type(exc).__name__
    if conn is not None:
        try:
            chain_svc.update_node_status(
                conn, node_id, "error", error_msg=error_msg
            )
        except Exception:
            # Still broadcast below so SSE subscribers learn of the failure.
            logger.exception(
                "Could not persist error status for node %s", node_id
            )
    await broadcast(chain_id, {
        "node_id": node_id,
        "status": "error",
        "error_msg": error_msg,
    })