# app/api/endpoints/nodes.py — branch generation, commit, discard, audio stream from __future__ import annotations import asyncio import logging from pathlib import Path from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from fastapi.responses import FileResponse from app.api.deps import get_conn, get_data_dir from app.api.events_store import broadcast from app.models.schemas.node import BranchRequest, NodeRow from app.services import chain as chain_svc from app.services.musicgen import MusicGenClient, make_output_path logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/nodes", tags=["nodes"]) _musicgen = MusicGenClient() @router.post("/{parent_id}/branch", response_model=NodeRow, status_code=202) async def create_branch( parent_id: str, body: BranchRequest, background_tasks: BackgroundTasks, conn=Depends(get_conn), data_dir: str = Depends(get_data_dir), ) -> dict: """ Create a branch node from parent_id and kick off async generation. Returns 202 immediately with the node in 'pending' state. Status transitions (pending → generating → ready/error) are pushed via the chain's SSE stream. """ parent = chain_svc.get_node(conn, parent_id) if parent is None: raise HTTPException(status_code=404, detail="Parent node not found.") if parent["audio_path"] is None: raise HTTPException( status_code=409, detail="Parent node has no audio yet." ) node = chain_svc.create_branch_node( conn, parent_id=parent_id, chain_id=parent["chain_id"], prompt=body.prompt, energy=body.energy, tempo_feel=body.tempo_feel, density=body.density, cfg_coef=body.cfg_coef, prompt_duration_s=body.prompt_duration_s, ) background_tasks.add_task( _run_generation, node_id=node["id"], chain_id=parent["chain_id"], source_audio_path=parent["audio_path"], prompt=body.prompt, duration_s=body.duration_s, cfg_coef=body.cfg_coef, prompt_duration_s=body.prompt_duration_s, data_dir=data_dir, db_path=conn.execute("PRAGMA database_list").fetchone()[2], ) return node @router.post("/{node_id}/commit", response_model=NodeRow) def commit_node(node_id: str, conn=Depends(get_conn)) -> dict: """ Promote a branch node to the committed spine. Discards all non-committed siblings (same parent). The root node is always committed and cannot be targeted here. """ node = chain_svc.get_node(conn, node_id) if node is None: raise HTTPException(status_code=404, detail="Node not found.") if node["status"] != "ready": raise HTTPException( status_code=409, detail="Only 'ready' nodes can be committed." ) result = chain_svc.commit_node(conn, node_id) if result is None: raise HTTPException(status_code=404, detail="Node not found.") return result @router.delete("/{node_id}", status_code=204) def delete_node(node_id: str, conn=Depends(get_conn)) -> None: """ Discard a non-committed branch node. Root nodes (parent_id IS NULL) and committed nodes cannot be deleted. """ if not chain_svc.delete_node(conn, node_id): raise HTTPException( status_code=409, detail="Node not found or cannot be deleted (root or committed).", ) @router.get("/{node_id}/audio") def stream_audio(node_id: str, conn=Depends(get_conn)) -> FileResponse: """Stream the audio file for a ready node.""" node = chain_svc.get_node(conn, node_id) if node is None: raise HTTPException(status_code=404, detail="Node not found.") if node["audio_path"] is None or not Path(node["audio_path"]).exists(): raise HTTPException(status_code=404, detail="Audio not available.") return FileResponse( node["audio_path"], media_type="audio/wav", filename=f"{node_id}.wav", ) # ── Background generation task ──────────────────────────────────────────────── async def _run_generation( *, node_id: str, chain_id: str, source_audio_path: str, prompt: str, duration_s: float, cfg_coef: float, prompt_duration_s: float, data_dir: str, db_path: str, ) -> None: """ Background task: run MusicGen and update node status. Opens its own DB connection so it doesn't block the request thread. """ from app.db.store import get_connection conn = get_connection(db_path) try: chain_svc.update_node_status(conn, node_id, "generating") await broadcast(chain_id, {"node_id": node_id, "status": "generating"}) output_path = make_output_path(data_dir, chain_id, node_id) actual_duration = await _musicgen.generate( source_audio_path=source_audio_path, output_path=output_path, prompt=prompt, duration_s=duration_s, cfg_coef=cfg_coef, prompt_duration_s=prompt_duration_s, ) chain_svc.update_node_status( conn, node_id, "ready", audio_path=output_path, duration_s=actual_duration, ) await broadcast(chain_id, { "node_id": node_id, "status": "ready", "audio_path": output_path, "duration_s": actual_duration, }) logger.info("Node %s ready (%.1fs)", node_id, actual_duration) except Exception as exc: logger.exception("Generation failed for node %s: %s", node_id, exc) chain_svc.update_node_status(conn, node_id, "error", error_msg=str(exc)) await broadcast(chain_id, { "node_id": node_id, "status": "error", "error_msg": str(exc), }) finally: conn.close()