Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: paid-tier stem separation (Demucs; access gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev; limited to the origins in SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service units and API integration
90 lines
2.9 KiB
Python
# app/api/endpoints/chains.py — chain CRUD + root audio upload
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
|
|
|
from app.api.deps import get_conn, get_data_dir
|
|
from app.models.schemas.chain import ChainCreate, ChainRow, ChainTree
|
|
from app.models.schemas.node import NodeRow
|
|
from app.services import chain as chain_svc
|
|
|
|
# Router for chain endpoints: every route below is mounted under /api/chains
# and grouped under the "chains" tag in the generated OpenAPI schema.
router = APIRouter(prefix="/api/chains", tags=["chains"])
|
|
|
|
_ALLOWED_AUDIO = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aiff"}
|
|
|
|
|
|
@router.post("/", response_model=ChainRow, status_code=201)
def create_chain(body: ChainCreate, conn=Depends(get_conn)) -> dict:
    """Create a chain named ``body.name`` and return the stored row (HTTP 201)."""
    created = chain_svc.create_chain(conn, body.name)
    return created
|
|
|
|
|
|
@router.get("/", response_model=list[ChainRow])
def list_chains(conn=Depends(get_conn)) -> list:
    """Return all chain rows as reported by the chain service."""
    rows = chain_svc.list_chains(conn)
    return rows
|
|
|
|
|
|
@router.get("/{chain_id}", response_model=ChainTree)
def get_chain(chain_id: str, conn=Depends(get_conn)) -> dict:
    """Return chain ``chain_id`` shaped as a ``ChainTree``.

    Raises:
        HTTPException: 404 when no chain with that id exists.
    """
    found = chain_svc.get_chain(conn, chain_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Chain not found.")
|
|
|
|
|
|
@router.delete("/{chain_id}", status_code=204)
def delete_chain(chain_id: str, conn=Depends(get_conn)) -> None:
    """Delete chain ``chain_id``; responds 204 with no body on success.

    Raises:
        HTTPException: 404 when the service reports nothing was deleted.
    """
    deleted = chain_svc.delete_chain(conn, chain_id)
    if deleted:
        return None
    raise HTTPException(status_code=404, detail="Chain not found.")
|
|
|
|
|
|
@router.post("/{chain_id}/upload", response_model=NodeRow, status_code=201)
async def upload_root(
    chain_id: str,
    file: UploadFile = File(...),
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Upload source audio and create the root node for a chain.

    Accepts WAV, MP3, FLAC, OGG, M4A, AIFF. The file is stored in
    data/chains/{chain_id}/nodes/root_{node_id}{ext} and a root node
    (always committed) is created in the DB.

    The upload is streamed to disk in chunks (never fully buffered in
    memory), and both the copy and the duration probe run in a worker
    thread so the event loop stays responsive. If anything fails before
    the root node is recorded, the written file is removed again.

    Raises:
        HTTPException: 404 if the chain does not exist; 422 if the file
            extension is not allowed or the upload is empty.
    """
    import asyncio

    chain = chain_svc.get_chain(conn, chain_id)
    if chain is None:
        raise HTTPException(status_code=404, detail="Chain not found.")

    suffix = Path(file.filename or "audio.wav").suffix.lower()
    if suffix not in _ALLOWED_AUDIO:
        raise HTTPException(
            status_code=422,
            detail=f"Unsupported audio format '{suffix}'. "
            f"Allowed: {', '.join(sorted(_ALLOWED_AUDIO))}",
        )

    node_id = str(uuid.uuid4())
    dest_dir = Path(data_dir) / "chains" / chain_id / "nodes"
    audio_path = str(dest_dir / f"root_{node_id}{suffix}")

    recorded = False
    try:
        # Disk I/O is blocking; run it off the event loop so other requests
        # (including open SSE streams) are not stalled during large uploads.
        written = await asyncio.to_thread(_save_upload, file.file, audio_path)
        if written == 0:
            raise HTTPException(
                status_code=422, detail="Uploaded audio file is empty."
            )

        # Duration is best-effort metadata; the probe may import torchaudio,
        # so it also runs in a worker thread.
        duration_s = await asyncio.to_thread(_probe_duration, audio_path)

        node = chain_svc.create_root_node(conn, chain_id, audio_path, duration_s)
        recorded = True
        return node
    finally:
        # No DB row references the file unless create_root_node succeeded;
        # remove it so failed uploads do not leave orphaned audio on disk.
        if not recorded:
            try:
                Path(audio_path).unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; keep the original error visible


def _save_upload(src, dest: str, chunk_size: int = 1 << 20) -> int:
    """Copy the binary file-like *src* to a new file at *dest* in chunks.

    Creates the parent directories of *dest* as needed and rewinds *src*
    first. Returns the number of bytes written.
    """
    target = Path(dest)
    target.parent.mkdir(parents=True, exist_ok=True)
    src.seek(0)
    written = 0
    with target.open("wb") as out:
        while chunk := src.read(chunk_size):
            out.write(chunk)
            written += len(chunk)
    return written
|
|
|
|
|
|
def _probe_duration(path: str) -> float:
|
|
try:
|
|
import torchaudio
|
|
info = torchaudio.info(path)
|
|
return info.num_frames / info.sample_rate
|
|
except Exception:
|
|
return 0.0
|