Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service units and API integration
182 lines · 5.8 KiB · Python
# app/api/endpoints/nodes.py — branch generation, commit, discard, audio stream
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
from fastapi.responses import FileResponse
|
|
|
|
from app.api.deps import get_conn, get_data_dir
|
|
from app.api.events_store import broadcast
|
|
from app.models.schemas.node import BranchRequest, NodeRow
|
|
from app.services import chain as chain_svc
|
|
from app.services.musicgen import MusicGenClient, make_output_path
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this file is served under the /api/nodes prefix.
router = APIRouter(tags=["nodes"], prefix="/api/nodes")

# One MusicGen client instance, created at import time and used by the
# background generation task defined in this module.
_musicgen = MusicGenClient()
|
|
|
|
|
|
@router.post("/{parent_id}/branch", response_model=NodeRow, status_code=202)
async def create_branch(
    parent_id: str,
    body: BranchRequest,
    background_tasks: BackgroundTasks,
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Queue a new continuation that branches off ``parent_id``.

    The branch node is created synchronously and returned with HTTP 202;
    the actual audio generation runs afterwards as a background task.
    Later status changes (pending → generating → ready/error) reach
    clients through the chain's SSE stream.

    Raises:
        HTTPException: 404 when the parent node does not exist, 409 when
            the parent has no audio to continue from.
    """
    parent = chain_svc.get_node(conn, parent_id)
    if parent is None:
        raise HTTPException(status_code=404, detail="Parent node not found.")

    source_audio_path = parent["audio_path"]
    if source_audio_path is None:
        raise HTTPException(
            status_code=409, detail="Parent node has no audio yet."
        )

    chain_id = parent["chain_id"]
    node = chain_svc.create_branch_node(
        conn,
        parent_id=parent_id,
        chain_id=chain_id,
        prompt=body.prompt,
        energy=body.energy,
        tempo_feel=body.tempo_feel,
        density=body.density,
        cfg_coef=body.cfg_coef,
        prompt_duration_s=body.prompt_duration_s,
    )

    # The background task opens its own connection, so it is handed the
    # database file path rather than this request's connection. Index 2 of
    # the first row from PRAGMA database_list is used as that path.
    database_row = conn.execute("PRAGMA database_list").fetchone()

    background_tasks.add_task(
        _run_generation,
        node_id=node["id"],
        chain_id=chain_id,
        source_audio_path=source_audio_path,
        prompt=body.prompt,
        duration_s=body.duration_s,
        cfg_coef=body.cfg_coef,
        prompt_duration_s=body.prompt_duration_s,
        data_dir=data_dir,
        db_path=database_row[2],
    )

    return node
|
|
|
|
|
|
@router.post("/{node_id}/commit", response_model=NodeRow)
def commit_node(node_id: str, conn=Depends(get_conn)) -> dict:
    """
    Commit a branch node onto the chain's spine.

    The commit itself is delegated to ``chain_svc.commit_node``, which is
    responsible for discarding the node's non-committed siblings (same
    parent). The root node is always committed and is not meant to be
    targeted by this endpoint.

    Raises:
        HTTPException: 404 when the node does not exist (either on lookup
            or when the service reports no result), 409 when the node's
            status is anything other than ``'ready'``.
    """
    existing = chain_svc.get_node(conn, node_id)
    if existing is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    is_ready = existing["status"] == "ready"
    if not is_ready:
        raise HTTPException(
            status_code=409, detail="Only 'ready' nodes can be committed."
        )

    committed = chain_svc.commit_node(conn, node_id)
    if committed is None:
        # The service returned nothing even though the lookup above found
        # the node; reported to the client as not found.
        raise HTTPException(status_code=404, detail="Node not found.")
    return committed
|
|
|
|
|
|
@router.delete("/{node_id}", status_code=204)
def delete_node(node_id: str, conn=Depends(get_conn)) -> None:
    """
    Discard a branch node that has not been committed.

    Root nodes (parent_id IS NULL) and committed nodes cannot be deleted.
    The service reports only success or failure, so a missing node and an
    undeletable node both produce the same 409 response.

    Raises:
        HTTPException: 409 when the service did not delete the node.
    """
    removed = chain_svc.delete_node(conn, node_id)
    if removed:
        return

    raise HTTPException(
        status_code=409,
        detail="Node not found or cannot be deleted (root or committed).",
    )
|
|
|
|
|
|
@router.get("/{node_id}/audio")
def stream_audio(node_id: str, conn=Depends(get_conn)) -> FileResponse:
    """
    Return the stored audio file for a node.

    The response's media type and download filename follow the extension
    of the stored file, so audio that is not WAV is not mislabelled as
    ``audio/wav``. Files with a ``.wav`` extension, or with an extension
    this endpoint does not recognise, are served as ``audio/wav`` under
    ``<node_id>.wav``.

    Raises:
        HTTPException: 404 when the node does not exist, has no audio
            path recorded, or the recorded path is not an existing
            regular file.
    """
    node = chain_svc.get_node(conn, node_id)
    if node is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    stored_path = node["audio_path"]
    # is_file() rather than exists(): a directory at this path cannot be
    # streamed and would otherwise fail later inside FileResponse.
    if stored_path is None or not Path(stored_path).is_file():
        raise HTTPException(status_code=404, detail="Audio not available.")

    media_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".flac": "audio/flac",
        ".ogg": "audio/ogg",
        ".m4a": "audio/mp4",
        ".aif": "audio/aiff",
        ".aiff": "audio/aiff",
    }
    suffix = Path(stored_path).suffix.lower()
    media_type = media_types.get(suffix)
    if media_type is None:
        # Unknown or missing extension: keep the historical WAV labelling.
        suffix, media_type = ".wav", "audio/wav"

    return FileResponse(
        stored_path,
        media_type=media_type,
        filename=f"{node_id}{suffix}",
    )
|
|
|
|
|
|
# ── Background generation task ────────────────────────────────────────────────
|
|
|
|
async def _run_generation(
    *,
    node_id: str,
    chain_id: str,
    source_audio_path: str,
    prompt: str,
    duration_s: float,
    cfg_coef: float,
    prompt_duration_s: float,
    data_dir: str,
    db_path: str,
) -> None:
    """
    Background task: run MusicGen for one node and record the outcome.

    Opens its own DB connection (from ``db_path``) instead of reusing the
    request's connection, and closes it when the task ends. The node is
    moved to ``generating``, then to ``ready`` with the output path and
    the duration returned by the generator, or to ``error`` with the
    exception text. Every transition is also broadcast to the chain's
    subscribers.

    Any exception raised while generating is logged and reported as an
    ``error`` event rather than propagated. If writing the ``error``
    status to the database fails as well, that failure is logged and the
    ``error`` event is still broadcast so subscribers see a terminal state.
    """
    from app.db.store import get_connection

    conn = get_connection(db_path)
    try:
        chain_svc.update_node_status(conn, node_id, "generating")
        await broadcast(chain_id, {"node_id": node_id, "status": "generating"})

        output_path = make_output_path(data_dir, chain_id, node_id)
        actual_duration = await _musicgen.generate(
            source_audio_path=source_audio_path,
            output_path=output_path,
            prompt=prompt,
            duration_s=duration_s,
            cfg_coef=cfg_coef,
            prompt_duration_s=prompt_duration_s,
        )

        chain_svc.update_node_status(
            conn,
            node_id,
            "ready",
            audio_path=output_path,
            duration_s=actual_duration,
        )
        await broadcast(
            chain_id,
            {
                "node_id": node_id,
                "status": "ready",
                "audio_path": output_path,
                "duration_s": actual_duration,
            },
        )
        logger.info("Node %s ready (%.1fs)", node_id, actual_duration)

    except Exception as exc:
        logger.exception("Generation failed for node %s: %s", node_id, exc)
        error_msg = str(exc)
        try:
            chain_svc.update_node_status(
                conn, node_id, "error", error_msg=error_msg
            )
        except Exception:
            # The database may be the very thing that failed. Do not let a
            # second failure here prevent the error event from being sent.
            logger.exception(
                "Could not record error status for node %s", node_id
            )
        await broadcast(
            chain_id,
            {
                "node_id": node_id,
                "status": "error",
                "error_msg": error_msg,
            },
        )
    finally:
        conn.close()
|