From a6f60c9e076f557776b9d6790e21690f6774cf8f Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Fri, 17 Apr 2026 15:22:37 -0700
Subject: [PATCH] feat: implement Sparrow backend (v0.1.0)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Full FastAPI backend for the AI music continuation editor:

Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3

API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub

Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)

Tests: 30/30 passing across service units and API integration
---
 app/api/deps.py                |  17 +++
 app/api/endpoints/chains.py    |  90 +++++++++++++++
 app/api/endpoints/events.py    |  54 +++++++++
 app/api/endpoints/export.py    |  64 +++++++++++
 app/api/endpoints/gpu.py       |  73 ++++++++++++
 app/api/endpoints/nodes.py     | 182 +++++++++++++++++++++++++++++
 app/api/endpoints/stems.py     |  67 +++++++++++
 app/api/events_store.py        |  34 ++++++
 app/api/routes.py              |  14 +++
 app/main.py                    |  66 +++++++++++
 app/models/schemas/chain.py    |  27 +++++
 app/models/schemas/node.py     |  49 ++++++++
 app/services/chain.py          | 203 +++++++++++++++++++++++++++++++++
 app/services/export.py         |  80 +++++++++++++
 app/services/musicgen.py       | 157 +++++++++++++++++++++++++
 app/services/stems.py          |  75 ++++++++++++
 app/tiers.py                   |  52 +++++++++
 tests/conftest.py              |  65 +++++++++++
 tests/test_api_chains.py       |  69 +++++++++++
 tests/test_api_nodes.py        | 102 +++++++++++++++++
 tests/test_chain_service.py    | 150 ++++++++++++++++++++++++
 tests/test_export_service.py   |  34 ++++++
 tests/test_musicgen_service.py |  36 ++++++
 23 files changed, 1760 insertions(+)
 create mode 100644 app/api/deps.py
 create mode 100644 app/api/endpoints/chains.py
 create mode 100644 app/api/endpoints/events.py
 create mode 100644 app/api/endpoints/export.py
 create mode 100644 app/api/endpoints/gpu.py
 create mode 100644 app/api/endpoints/nodes.py
 create mode 100644 app/api/endpoints/stems.py
 create mode 100644 app/api/events_store.py
 create mode 100644 app/api/routes.py
 create mode 100644 app/main.py
 create mode 100644 app/models/schemas/chain.py
 create mode 100644 app/models/schemas/node.py
 create mode 100644 app/services/chain.py
 create mode 100644 app/services/export.py
 create mode 100644 app/services/musicgen.py
 create mode 100644 app/services/stems.py
 create mode 100644 app/tiers.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_api_chains.py
 create mode 100644 tests/test_api_nodes.py
 create mode 100644 tests/test_chain_service.py
 create mode 100644 tests/test_export_service.py
 create mode 100644 tests/test_musicgen_service.py

diff --git a/app/api/deps.py b/app/api/deps.py
new file mode 100644
index 0000000..9fca206
--- /dev/null
+++ b/app/api/deps.py
@@ -0,0 +1,17 @@
+# app/api/deps.py — FastAPI dependency providers
+from __future__ import annotations
+
+import os
+import sqlite3
+
+from fastapi import Request
+
+
+def get_conn(request: Request) -> sqlite3.Connection:
+    """Return the shared SQLite connection stored on app state (request.app.state.conn)."""
+    return request.app.state.conn
+
+
+def get_data_dir(request: Request) -> str:
+    """Return the configured data directory stored on app state (request.app.state.data_dir)."""
+    return request.app.state.data_dir
diff --git a/app/api/endpoints/chains.py b/app/api/endpoints/chains.py
new file mode 100644
index 0000000..706320f
--- /dev/null
+++ b/app/api/endpoints/chains.py
@@ -0,0 +1,90 @@
+# app/api/endpoints/chains.py — chain CRUD + root audio upload
+from __future__ import annotations
+
+import os
+import uuid
+from pathlib import Path
+
+from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
+
+from app.api.deps import get_conn, get_data_dir
+from app.models.schemas.chain import ChainCreate, ChainRow, ChainTree
+from app.models.schemas.node import NodeRow
+from app.services import chain as chain_svc
+
+router = APIRouter(prefix="/api/chains", tags=["chains"])
+
+_ALLOWED_AUDIO = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aiff"}
+
+
+@router.post("/", response_model=ChainRow, status_code=201)
+def create_chain(body: ChainCreate, conn=Depends(get_conn)) -> dict:
+    return chain_svc.create_chain(conn, body.name)
+
+
+@router.get("/", response_model=list[ChainRow])
+def list_chains(conn=Depends(get_conn)) -> list:
+    return chain_svc.list_chains(conn)
+
+
+@router.get("/{chain_id}", response_model=ChainTree)
+def get_chain(chain_id: str, conn=Depends(get_conn)) -> dict:
+    chain = chain_svc.get_chain(conn, chain_id)
+    if chain is None:
+        raise HTTPException(status_code=404, detail="Chain not found.")
+    return chain
+
+
+@router.delete("/{chain_id}", status_code=204)
+def delete_chain(chain_id: str, conn=Depends(get_conn)) -> None:
+    if not chain_svc.delete_chain(conn, chain_id):
+        raise HTTPException(status_code=404, detail="Chain not found.")
+
+
+@router.post("/{chain_id}/upload", response_model=NodeRow, status_code=201)
+async def upload_root(
+    chain_id: str,
+    file: UploadFile = File(...),
+    conn=Depends(get_conn),
+    data_dir: str = Depends(get_data_dir),
+) -> dict:
+    """
+    Upload source audio and create the root node for a chain.
+
+    Accepts WAV, MP3, FLAC, OGG, M4A, AIFF (422 otherwise). The file is stored in
+    {data_dir}/chains/{chain_id}/nodes/root_{uuid}{ext}; the uuid is generated here and
+    is not the DB node id. A root node (always committed) is then created in the DB.
+    """
+    chain = chain_svc.get_chain(conn, chain_id)
+    if chain is None:
+        raise HTTPException(status_code=404, detail="Chain not found.")
+
+    suffix = Path(file.filename or "audio.wav").suffix.lower()
+    if suffix not in _ALLOWED_AUDIO:
+        raise HTTPException(
+            status_code=422,
+            detail=f"Unsupported audio format '{suffix}'. "
+                   f"Allowed: {', '.join(sorted(_ALLOWED_AUDIO))}",
+        )
+
+    node_id = str(uuid.uuid4())
+    dest_dir = Path(data_dir) / "chains" / chain_id / "nodes"
+    dest_dir.mkdir(parents=True, exist_ok=True)
+    audio_path = str(dest_dir / f"root_{node_id}{suffix}")
+
+    contents = await file.read()
+    Path(audio_path).write_bytes(contents)
+
+    # Probe duration via torchaudio; _probe_duration returns 0.0 if the import or probe fails
+    duration_s = _probe_duration(audio_path)
+
+    return chain_svc.create_root_node(conn, chain_id, audio_path, duration_s)
+
+
+def _probe_duration(path: str) -> float:
+    try:
+        import torchaudio
+        info = torchaudio.info(path)
+        return info.num_frames / info.sample_rate
+    except Exception:
+        return 0.0
diff --git a/app/api/endpoints/events.py b/app/api/endpoints/events.py
new file mode 100644
index 0000000..e1033a3
--- /dev/null
+++ b/app/api/endpoints/events.py
@@ -0,0 +1,54 @@
+# app/api/endpoints/events.py — SSE stream for node status transitions
+#
+# Clients connect to GET /api/chains/{chain_id}/events and receive
+# server-sent events whenever a node in that chain changes status.
+#
+# Event format:
+#   event: node-status
+#   data: {"node_id": "...", "status": "generating"|"ready"|"error", ...}
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+
+from fastapi import APIRouter, Depends
+from sse_starlette.sse import EventSourceResponse
+
+from app.api.deps import get_conn
+from app.api.events_store import subscribe, unsubscribe
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/chains", tags=["events"])
+
+_KEEPALIVE_S = 15  # send a comment ping every N seconds to keep the connection alive
+
+
+@router.get("/{chain_id}/events")
+async def chain_events(chain_id: str, conn=Depends(get_conn)) -> EventSourceResponse:
+    """
+    SSE stream for node status transitions in a chain.
+
+    Emits 'node-status' events. Closes when the client disconnects.
+    """
+    q = subscribe(chain_id)
+
+    async def generator():
+        try:
+            while True:
+                try:
+                    event = await asyncio.wait_for(q.get(), timeout=_KEEPALIVE_S)
+                    yield {
+                        "event": "node-status",
+                        "data": json.dumps(event),
+                    }
+                except asyncio.TimeoutError:
+                    # Send keepalive comment to prevent proxy/browser timeout
+                    yield {"comment": "keepalive"}
+        except asyncio.CancelledError:
+            raise  # propagate cancellation; the finally block still unsubscribes
+        finally:
+            unsubscribe(chain_id, q)
+            logger.debug("SSE client disconnected from chain %s", chain_id)
+
+    return EventSourceResponse(generator())
diff --git a/app/api/endpoints/export.py b/app/api/endpoints/export.py
new file mode 100644
index 0000000..8973648
--- /dev/null
+++ b/app/api/endpoints/export.py
@@ -0,0 +1,64 @@
+# app/api/endpoints/export.py — stitch committed spine and serve as download
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi.responses import FileResponse
+
+from app.api.deps import get_conn, get_data_dir
+from app.models.schemas.node import ExportRequest
+from app.services import chain as chain_svc
+from app.services.export import ExportFormat, make_export_path, stitch
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/chains", tags=["export"])
+
+
+@router.post("/{chain_id}/export")
+async def export_chain(
+    chain_id: str,
+    body: ExportRequest,
+    conn=Depends(get_conn),
+    data_dir: str = Depends(get_data_dir),
+) -> FileResponse:
+    """
+    Stitch node audio files into a single download.
+
+    If body.node_ids is provided, those nodes (all must belong to chain_id) are used in order.
+    Otherwise the committed spine (root → leaf) is used.
+    Raises 404 for unknown/foreign nodes, 409 if any targeted node has no audio, 500 on ffmpeg failure.
+    """
+    if body.node_ids is not None:
+        nodes = [chain_svc.get_node(conn, nid) for nid in body.node_ids]
+        if any(n is None or n["chain_id"] != chain_id for n in nodes):
+            raise HTTPException(status_code=404, detail="One or more nodes not found.")
+    else:
+        nodes = chain_svc.get_committed_spine(conn, chain_id)
+        if not nodes:
+            raise HTTPException(
+                status_code=404,
+                detail="Chain not found or has no committed nodes.",
+            )
+
+    audio_paths = [n["audio_path"] for n in nodes if n]  # type: ignore[index]
+    missing = [n["id"] for n in nodes if n and n["audio_path"] is None]  # type: ignore[index]
+    if missing:
+        raise HTTPException(
+            status_code=409,
+            detail=f"Nodes {missing} have no audio yet.",
+        )
+
+    fmt: ExportFormat = body.format  # type: ignore[assignment]
+    output_path = make_export_path(data_dir, chain_id, fmt)
+
+    try:
+        await stitch(audio_paths, output_path, fmt)
+    except Exception as exc:
+        logger.exception("Export failed for chain %s: %s", chain_id, exc)
+        raise HTTPException(status_code=500, detail=f"Export failed: {exc}") from exc
+
+    media_type = "audio/wav" if fmt == "wav" else "audio/mpeg"
+    filename = f"sparrow_export_{chain_id[:8]}.{fmt}"
+    return FileResponse(output_path, media_type=media_type, filename=filename)
diff --git a/app/api/endpoints/gpu.py b/app/api/endpoints/gpu.py
new file mode 100644
index 0000000..0322ac7
--- /dev/null
+++ b/app/api/endpoints/gpu.py
@@ -0,0 +1,73 @@
+# app/api/endpoints/gpu.py — cf-orch GPU status and session
allocation
+#
+# GET    /api/gpu/status     — available capacity from cf-orch
+# POST   /api/gpu/connect    — session-held allocation (Premium tier, stub)
+# DELETE /api/gpu/disconnect — release session allocation (Premium tier, stub)
+from __future__ import annotations
+
+import logging
+import os
+
+import httpx
+from fastapi import APIRouter, HTTPException
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/gpu", tags=["gpu"])
+
+_ORCH_URL = os.environ.get("CF_ORCH_URL", "").rstrip("/")
+_SERVICE = "cf-musicgen"
+
+
+@router.get("/status")
+async def gpu_status() -> dict:
+    """
+    Return current cf-orch capacity for cf-musicgen.
+
+    Returns {"available": False, "reason": "...", "mock": True} when CF_ORCH_URL is
+    unset (mock mode). Raises 502 if cf-orch is unreachable or answers with an error status.
+    """
+    if not _ORCH_URL:
+        return {
+            "available": False,
+            "reason": "CF_ORCH_URL not configured — running in mock mode.",
+            "mock": True,
+        }
+    try:
+        async with httpx.AsyncClient(timeout=5.0) as client:
+            resp = await client.get(
+                f"{_ORCH_URL}/api/services/{_SERVICE}/status"
+            )
+            resp.raise_for_status()
+            return resp.json()
+    except httpx.HTTPStatusError as exc:
+        raise HTTPException(
+            status_code=502,
+            detail=f"cf-orch returned {exc.response.status_code}.",
+        ) from exc
+    except Exception as exc:
+        raise HTTPException(
+            status_code=502,
+            detail=f"cf-orch unreachable: {exc}",
+        ) from exc
+
+
+@router.post("/connect", status_code=501)
+async def gpu_connect() -> dict:
+    """
+    Session-held GPU allocation (Premium tier).
+
+    Not yet implemented — tracked in cf-orch #43.
+    """
+    raise HTTPException(
+        status_code=501,
+        detail="Session-held GPU allocation is not yet implemented (cf-orch #43).",
+    )
+
+
+@router.delete("/disconnect", status_code=501)
+async def gpu_disconnect() -> dict:
+    """Release a session-held GPU allocation (Premium tier, stub)."""
+    raise HTTPException(
+        status_code=501,
+        detail="Session-held GPU allocation is not yet implemented (cf-orch #43).",
+    )
diff --git a/app/api/endpoints/nodes.py b/app/api/endpoints/nodes.py
new file mode 100644
index 0000000..7c7c6eb
--- /dev/null
+++ b/app/api/endpoints/nodes.py
@@ -0,0 +1,182 @@
+# app/api/endpoints/nodes.py — branch generation, commit, discard, audio stream
+from __future__ import annotations
+
+import asyncio
+import logging
+from pathlib import Path
+
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
+from fastapi.responses import FileResponse
+
+from app.api.deps import get_conn, get_data_dir
+from app.api.events_store import broadcast
+from app.models.schemas.node import BranchRequest, NodeRow
+from app.services import chain as chain_svc
+from app.services.musicgen import MusicGenClient, make_output_path
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/nodes", tags=["nodes"])
+
+_musicgen = MusicGenClient()
+
+
+@router.post("/{parent_id}/branch", response_model=NodeRow, status_code=202)
+async def create_branch(
+    parent_id: str,
+    body: BranchRequest,
+    background_tasks: BackgroundTasks,
+    conn=Depends(get_conn),
+    data_dir: str = Depends(get_data_dir),
+) -> dict:
+    """
+    Create a branch node from parent_id and kick off async generation.
+
+    Returns 202 immediately with the node in 'pending' state.
+    Status transitions (pending → generating → ready/error) are pushed
+    via the chain's SSE stream.
+    """
+    parent = chain_svc.get_node(conn, parent_id)
+    if parent is None:
+        raise HTTPException(status_code=404, detail="Parent node not found.")
+    if parent["audio_path"] is None:
+        raise HTTPException(
+            status_code=409, detail="Parent node has no audio yet."
+        )
+
+    node = chain_svc.create_branch_node(
+        conn,
+        parent_id=parent_id,
+        chain_id=parent["chain_id"],
+        prompt=body.prompt,
+        energy=body.energy,
+        tempo_feel=body.tempo_feel,
+        density=body.density,
+        cfg_coef=body.cfg_coef,
+        prompt_duration_s=body.prompt_duration_s,
+    )
+
+    background_tasks.add_task(
+        _run_generation,
+        node_id=node["id"],
+        chain_id=parent["chain_id"],
+        source_audio_path=parent["audio_path"],
+        prompt=body.prompt,
+        duration_s=body.duration_s,
+        cfg_coef=body.cfg_coef,
+        prompt_duration_s=body.prompt_duration_s,
+        data_dir=data_dir,
+        db_path=conn.execute("PRAGMA database_list").fetchone()[2],
+    )
+
+    return node
+
+
+@router.post("/{node_id}/commit", response_model=NodeRow)
+def commit_node(node_id: str, conn=Depends(get_conn)) -> dict:
+    """
+    Promote a branch node to the committed spine.
+
+    Discards all non-committed siblings (same parent). The root node is
+    always committed and cannot be targeted here.
+    """
+    node = chain_svc.get_node(conn, node_id)
+    if node is None:
+        raise HTTPException(status_code=404, detail="Node not found.")
+    if node["status"] != "ready":
+        raise HTTPException(
+            status_code=409, detail="Only 'ready' nodes can be committed."
+        )
+    result = chain_svc.commit_node(conn, node_id)
+    if result is None:
+        raise HTTPException(status_code=404, detail="Node not found.")
+    return result
+
+
+@router.delete("/{node_id}", status_code=204)
+def delete_node(node_id: str, conn=Depends(get_conn)) -> None:
+    """
+    Discard a non-committed branch node.
+
+    Root nodes (parent_id IS NULL) and committed nodes cannot be deleted.
+    """
+    if not chain_svc.delete_node(conn, node_id):
+        raise HTTPException(
+            status_code=409,
+            detail="Node not found or cannot be deleted (root or committed).",
+        )
+
+
+@router.get("/{node_id}/audio")
+def stream_audio(node_id: str, conn=Depends(get_conn)) -> FileResponse:
+    """Stream a node's audio file; 404 if the node or file is missing. Root uploads may be non-WAV."""
+    node = chain_svc.get_node(conn, node_id)
+    if node is None:
+        raise HTTPException(status_code=404, detail="Node not found.")
+    if node["audio_path"] is None or not Path(node["audio_path"]).exists():
+        raise HTTPException(status_code=404, detail="Audio not available.")
+    suffix = Path(node["audio_path"]).suffix.lower()
+    return FileResponse(
+        node["audio_path"], filename=f"{node_id}{suffix}",
+        media_type="audio/wav" if suffix == ".wav" else None,  # None → inferred from filename
+    )
+
+
+# ── Background generation task ────────────────────────────────────────────────
+
+async def _run_generation(
+    *,
+    node_id: str,
+    chain_id: str,
+    source_audio_path: str,
+    prompt: str,
+    duration_s: float,
+    cfg_coef: float,
+    prompt_duration_s: float,
+    data_dir: str,
+    db_path: str,
+) -> None:
+    """
+    Background task: run MusicGen and update node status.
+
+    Opens its own DB connection so it doesn't block the request thread.
+    """
+    from app.db.store import get_connection
+
+    conn = get_connection(db_path)
+    try:
+        chain_svc.update_node_status(conn, node_id, "generating")
+        await broadcast(chain_id, {"node_id": node_id, "status": "generating"})
+
+        output_path = make_output_path(data_dir, chain_id, node_id)
+        actual_duration = await _musicgen.generate(
+            source_audio_path=source_audio_path,
+            output_path=output_path,
+            prompt=prompt,
+            duration_s=duration_s,
+            cfg_coef=cfg_coef,
+            prompt_duration_s=prompt_duration_s,
+        )
+
+        chain_svc.update_node_status(
+            conn, node_id, "ready",
+            audio_path=output_path,
+            duration_s=actual_duration,
+        )
+        await broadcast(chain_id, {
+            "node_id": node_id,
+            "status": "ready",
+            "audio_path": output_path,
+            "duration_s": actual_duration,
+        })
+        logger.info("Node %s ready (%.1fs)", node_id, actual_duration)
+
+    except Exception as exc:
+        logger.exception("Generation failed for node %s: %s", node_id, exc)
+        chain_svc.update_node_status(conn, node_id, "error", error_msg=str(exc))
+        await broadcast(chain_id, {
+            "node_id": node_id,
+            "status": "error",
+            "error_msg": str(exc),
+        })
+    finally:
+        conn.close()
diff --git a/app/api/endpoints/stems.py b/app/api/endpoints/stems.py
new file mode 100644
index 0000000..6024beb
--- /dev/null
+++ b/app/api/endpoints/stems.py
@@ -0,0 +1,67 @@
+# app/api/endpoints/stems.py — 4-stem separation (Paid tier)
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
+
+from app.api.deps import get_conn, get_data_dir
+from app.models.schemas.node import StemResult
+from app.services import chain as chain_svc
+from app.services import stems as stems_svc
+from app.tiers import require_tier
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/stems", tags=["stems"])
+
+
+@router.post("/{node_id}", response_model=StemResult, status_code=202)
+async def separate_stems(
+    node_id: str,
+    background_tasks:
BackgroundTasks,
+    tier: str | None = None,
+    conn=Depends(get_conn),
+    data_dir: str = Depends(get_data_dir),
+) -> dict:
+    """
+    Run 4-stem separation (Demucs htdemucs) on a ready node.
+
+    Requires Paid tier. Returns the expected stem file paths immediately — the
+    files exist only once the background task completes. NOTE(review): this module
+    only logs completion and emits no SSE event, so clients must check the paths.
+
+    tier: passed as query param, e.g. ?tier=paid
+    """
+    require_tier("stems", tier)
+
+    node = chain_svc.get_node(conn, node_id)
+    if node is None:
+        raise HTTPException(status_code=404, detail="Node not found.")
+    if node["audio_path"] is None:
+        raise HTTPException(status_code=409, detail="Node has no audio.")
+
+    stems_dir = stems_svc.make_stems_dir(data_dir, node["chain_id"], node_id)
+    stem_paths = stems_svc._stem_paths(stems_dir, Path(node["audio_path"]).stem)
+
+    background_tasks.add_task(
+        _run_separation,
+        source_audio_path=node["audio_path"],
+        stems_dir=stems_dir,
+    )
+
+    return StemResult(
+        node_id=node_id,
+        vocals=stem_paths["vocals"],
+        drums=stem_paths["drums"],
+        bass=stem_paths["bass"],
+        other=stem_paths["other"],
+    )
+
+
+async def _run_separation(source_audio_path: str, stems_dir: str) -> None:
+    try:
+        await stems_svc.separate(source_audio_path, stems_dir)
+        logger.info("Stems complete: %s", stems_dir)
+    except Exception as exc:
+        logger.exception("Stem separation failed: %s", exc)
diff --git a/app/api/events_store.py b/app/api/events_store.py
new file mode 100644
index 0000000..b8b0642
--- /dev/null
+++ b/app/api/events_store.py
@@ -0,0 +1,34 @@
+# app/api/events_store.py — Chain-scoped pub/sub for SSE node-status events
+#
+# Each SSE connection subscribes with subscribe(chain_id) → Queue.
+# Background generation tasks call broadcast(chain_id, event) to push
+# node state transitions (pending → generating → ready/error) to all listeners.
+from __future__ import annotations
+
+import asyncio
+from collections import defaultdict
+from typing import Any
+
+# chain_id → list of subscriber queues
+_listeners: dict[str, list[asyncio.Queue[Any]]] = defaultdict(list)
+
+
+def subscribe(chain_id: str) -> asyncio.Queue[Any]:
+    """Register a new SSE listener for chain_id. Returns its Queue."""
+    q: asyncio.Queue[Any] = asyncio.Queue()
+    _listeners[chain_id].append(q)
+    return q
+
+
+def unsubscribe(chain_id: str, q: asyncio.Queue[Any]) -> None:
+    """Remove a listener Queue. Safe to call if already removed or chain_id is unknown."""
+    # pop + conditional re-insert: a chain with no listeners leaves no empty entry behind
+    remaining = [x for x in _listeners.pop(chain_id, ()) if x is not q]
+    if remaining:
+        _listeners[chain_id] = remaining
+
+
+async def broadcast(chain_id: str, event: dict[str, Any]) -> None:
+    """Push event dict to all active listeners for chain_id."""
+    for q in list(_listeners.get(chain_id, [])):
+        await q.put(event)
diff --git a/app/api/routes.py b/app/api/routes.py
new file mode 100644
index 0000000..133bb49
--- /dev/null
+++ b/app/api/routes.py
@@ -0,0 +1,14 @@
+# app/api/routes.py — assemble all endpoint routers
+from __future__ import annotations
+
+from fastapi import APIRouter
+
+from app.api.endpoints import chains, events, export, gpu, nodes, stems
+
+router = APIRouter()
+router.include_router(chains.router)
+router.include_router(nodes.router)
+router.include_router(gpu.router)
+router.include_router(stems.router)
+router.include_router(export.router)
+router.include_router(events.router)
diff --git a/app/main.py b/app/main.py
new file mode 100644
index 0000000..f899003
--- /dev/null
+++ b/app/main.py
@@ -0,0 +1,66 @@
+# app/main.py — Sparrow FastAPI application factory
+from __future__ import annotations
+
+import logging
+import os
+from contextlib import asynccontextmanager
+from pathlib import Path
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from app.api.routes import router
+from app.db.store import get_connection, run_migrations
+
+logger = logging.getLogger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Read env at startup — not at import time — so tests can inject values
+    db_path = os.environ.get("SPARROW_DB_PATH", "data/sparrow.db")
+    data_dir = os.environ.get("SPARROW_DATA_DIR", "data")
+    env = os.environ.get("SPARROW_ENV", "development")
+
+    Path(db_path).parent.mkdir(parents=True, exist_ok=True)
+    Path(data_dir).mkdir(parents=True, exist_ok=True)
+
+    conn = get_connection(db_path)
+    run_migrations(conn)
+    app.state.conn = conn
+    app.state.data_dir = data_dir
+
+    logger.info("Sparrow started (env=%s, db=%s, data=%s)", env, db_path, data_dir)
+    yield
+
+    conn.close()
+    logger.info("Sparrow shut down.")
+
+
+def create_app() -> FastAPI:
+    env = os.environ.get("SPARROW_ENV", "development")
+    app = FastAPI(
+        title="Sparrow",
+        description="AI music continuation — branching chain editor",
+        version="0.1.0",
+        lifespan=lifespan,
+    )
+
+    # CORS: open in dev; otherwise the comma-separated SPARROW_CORS_ORIGINS (blank entries dropped)
+    raw_origins = os.environ.get("SPARROW_CORS_ORIGINS", "")
+    origins = ["*"] if env == "development" else [
+        o.strip() for o in raw_origins.split(",") if o.strip()
+    ]
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=origins,
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    app.include_router(router)
+    return app
+
+
+app = create_app()
diff --git a/app/models/schemas/chain.py b/app/models/schemas/chain.py
new file mode 100644
index 0000000..79a7abc
--- /dev/null
+++ b/app/models/schemas/chain.py
@@ -0,0 +1,27 @@
+# app/models/schemas/chain.py — Chain Pydantic models
+from __future__ import annotations
+
+from pydantic import BaseModel
+
+
+class ChainCreate(BaseModel):
+    name: str
+
+
+class ChainRow(BaseModel):
+    id: str
+    name: str
+    created_at: float
+    node_count: int = 0
+
+
+class ChainTree(BaseModel):
+    id: str
+    name: str
+    created_at: float
+    nodes: list["NodeRow"] = []
+
+
+# Avoid circular import — NodeRow imported at runtime
+from app.models.schemas.node import NodeRow  # noqa: E402
+ChainTree.model_rebuild()
diff --git a/app/models/schemas/node.py
b/app/models/schemas/node.py
new file mode 100644
index 0000000..ad65768
--- /dev/null
+++ b/app/models/schemas/node.py
@@ -0,0 +1,49 @@
+# app/models/schemas/node.py — Node Pydantic models
+from __future__ import annotations
+
+from typing import Literal
+from pydantic import BaseModel
+
+NodeStatus = Literal["pending", "generating", "ready", "error"]
+
+
+class NodeRow(BaseModel):
+    id: str
+    chain_id: str
+    parent_id: str | None
+    audio_path: str | None
+    duration_s: float | None
+    status: NodeStatus
+    is_committed: bool
+    prompt: str
+    energy: float | None
+    tempo_feel: float | None
+    density: float | None
+    cfg_coef: float
+    prompt_duration_s: float
+    error_msg: str | None
+    created_at: float
+    children: list["NodeRow"] = []
+
+
+class BranchRequest(BaseModel):
+    prompt: str = ""
+    energy: float | None = None
+    tempo_feel: float | None = None
+    density: float | None = None
+    cfg_coef: float = 3.0
+    prompt_duration_s: float = 10.0
+    duration_s: float = 15.0
+
+
+class StemResult(BaseModel):
+    node_id: str
+    vocals: str
+    drums: str
+    bass: str
+    other: str
+
+
+class ExportRequest(BaseModel):
+    format: Literal["wav", "mp3"] = "wav"
+    node_ids: list[str] | None = None  # explicit nodes, stitched in list order; None → committed spine
diff --git a/app/services/chain.py b/app/services/chain.py
new file mode 100644
index 0000000..8f3f5e8
--- /dev/null
+++ b/app/services/chain.py
@@ -0,0 +1,203 @@
+# app/services/chain.py — chain + node CRUD and branching tree logic
+from __future__ import annotations
+
+import sqlite3
+import time
+import uuid
+from typing import Any
+
+
+# ── Chain CRUD ────────────────────────────────────────────────────────────────
+
+def create_chain(conn: sqlite3.Connection, name: str) -> dict[str, Any]:
+    chain_id = str(uuid.uuid4())
+    now = time.time()
+    conn.execute(
+        "INSERT INTO chains (id, name, created_at) VALUES (?, ?, ?)",
+        (chain_id, name, now),
+    )
+    conn.commit()
+    return {"id": chain_id, "name": name, "created_at": now, "node_count": 0}
+
+
+def list_chains(conn:
sqlite3.Connection) -> list[dict[str, Any]]:
+    rows = conn.execute("""
+        SELECT c.id, c.name, c.created_at,
+               COUNT(n.id) AS node_count
+        FROM chains c
+        LEFT JOIN nodes n ON n.chain_id = c.id
+        GROUP BY c.id
+        ORDER BY c.created_at DESC
+    """).fetchall()
+    return [dict(r) for r in rows]
+
+
+def get_chain(conn: sqlite3.Connection, chain_id: str) -> dict[str, Any] | None:
+    row = conn.execute(
+        "SELECT id, name, created_at FROM chains WHERE id = ?", (chain_id,)
+    ).fetchone()
+    if row is None:
+        return None
+    nodes = _get_nodes_for_chain(conn, chain_id)
+    return {"id": row["id"], "name": row["name"], "created_at": row["created_at"], "nodes": nodes}
+
+
+def delete_chain(conn: sqlite3.Connection, chain_id: str) -> bool:
+    cur = conn.execute("DELETE FROM chains WHERE id = ?", (chain_id,))
+    conn.commit()
+    return cur.rowcount > 0
+
+
+# ── Node CRUD ─────────────────────────────────────────────────────────────────
+
+def create_root_node(
+    conn: sqlite3.Connection,
+    chain_id: str,
+    audio_path: str,
+    duration_s: float,
+) -> dict[str, Any]:
+    """Create the root node (uploaded source audio). Always committed; status is 'ready'."""
+    node_id = str(uuid.uuid4())
+    now = time.time()
+    conn.execute("""
+        INSERT INTO nodes
+            (id, chain_id, parent_id, audio_path, duration_s, status,
+             is_committed, prompt, cfg_coef, prompt_duration_s, created_at)
+        VALUES (?, ?, NULL, ?, ?, 'ready', 1, '', 3.0, 10.0, ?)
+    """, (node_id, chain_id, audio_path, duration_s, now))
+    conn.commit()
+    return get_node(conn, node_id)
+
+
+def get_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
+    row = conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
+    return _row_to_dict(row) if row else None
+
+
+def create_branch_node(
+    conn: sqlite3.Connection,
+    parent_id: str,
+    chain_id: str,
+    prompt: str,
+    energy: float | None,
+    tempo_feel: float | None,
+    density: float | None,
+    cfg_coef: float,
+    prompt_duration_s: float,
+) -> dict[str, Any]:
+    """Create a new, uncommitted branch node in 'pending' state. Caller updates to 'generating'."""
+    node_id = str(uuid.uuid4())
+    now = time.time()
+    conn.execute("""
+        INSERT INTO nodes
+            (id, chain_id, parent_id, audio_path, duration_s, status,
+             is_committed, prompt, energy, tempo_feel, density,
+             cfg_coef, prompt_duration_s, created_at)
+        VALUES (?, ?, ?, NULL, NULL, 'pending', 0, ?, ?, ?, ?, ?, ?, ?)
+    """, (node_id, chain_id, parent_id, prompt, energy, tempo_feel, density,
+          cfg_coef, prompt_duration_s, now))
+    conn.commit()
+    return get_node(conn, node_id)
+
+
+def update_node_status(
+    conn: sqlite3.Connection,
+    node_id: str,
+    status: str,
+    audio_path: str | None = None,
+    duration_s: float | None = None,
+    error_msg: str | None = None,
+) -> dict[str, Any] | None:
+    fields: list[str] = ["status = ?"]
+    values: list[Any] = [status]
+    if audio_path is not None:
+        fields.append("audio_path = ?")
+        values.append(audio_path)
+    if duration_s is not None:
+        fields.append("duration_s = ?")
+        values.append(duration_s)
+    if error_msg is not None:
+        fields.append("error_msg = ?")
+        values.append(error_msg)
+    values.append(node_id)
+    conn.execute(f"UPDATE nodes SET {', '.join(fields)} WHERE id = ?", values)
+    conn.commit()
+    return get_node(conn, node_id)
+
+
+def commit_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
+    """
+    Promote a branch node to the committed spine.
+
+    Discard (delete) all sibling branch nodes (same parent, not committed).
+    Already-committed siblings and all ancestors are left as-is. Root is always committed.
+    """
+    node = get_node(conn, node_id)
+    if node is None:
+        return None
+
+    parent_id = node.get("parent_id")
+    if parent_id:
+        # Discard non-committed siblings (same parent); skipped for a root node (no parent)
+        conn.execute("""
+            DELETE FROM nodes
+            WHERE parent_id = ? AND id != ? AND is_committed = 0
+        """, (parent_id, node_id))
+
+    conn.execute("UPDATE nodes SET is_committed = 1 WHERE id = ?", (node_id,))
+    conn.commit()
+    return get_node(conn, node_id)
+
+
+def delete_node(conn: sqlite3.Connection, node_id: str) -> bool:
+    """Delete a non-committed branch node. Returns False for unknown ids, roots and committed nodes."""
+    row = conn.execute(
+        "SELECT parent_id FROM nodes WHERE id = ?", (node_id,)
+    ).fetchone()
+    if row is None or row["parent_id"] is None:
+        return False  # unknown id or root node; refuse
+    cur = conn.execute(
+        "DELETE FROM nodes WHERE id = ? AND is_committed = 0", (node_id,)
+    )
+    conn.commit()
+    return cur.rowcount > 0
+
+
+def get_committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
+    """
+    Return committed nodes ordered root → leaf along the committed spine.
+
+    Uses a recursive CTE to walk the tree from the root through committed nodes.
+    """
+    rows = conn.execute("""
+        WITH RECURSIVE spine(id, parent_id, depth) AS (
+            SELECT id, parent_id, 0
+            FROM nodes
+            WHERE chain_id = ? AND parent_id IS NULL AND is_committed = 1
+            UNION ALL
+            SELECT n.id, n.parent_id, spine.depth + 1
+            FROM nodes n
+            JOIN spine ON n.parent_id = spine.id
+            WHERE n.is_committed = 1 AND n.chain_id = ?
logger = logging.getLogger(__name__)

ExportFormat = Literal["wav", "mp3"]


def _concat_entry(path: str) -> str:
    """Return one ``file '...'`` line for an ffmpeg concat list."""
    # ffmpeg resolves relative entries against the directory of the list file
    # (a temp dir here), not the process working directory, so pin the path.
    absolute = os.path.abspath(path)
    # Inside single quotes ffmpeg has no escape character: a literal quote is
    # written by closing the quote, emitting \' and reopening it.
    escaped = absolute.replace("'", "'\\''")
    return f"file '{escaped}'\n"


async def stitch(
    audio_paths: list[str],
    output_path: str,
    fmt: ExportFormat = "wav",
) -> str:
    """
    Concatenate audio_paths in order and write to output_path.

    A single input is copied byte-for-byte without invoking ffmpeg, except
    when an MP3 export is requested from a non-MP3 source: that case is sent
    through ffmpeg so the output really is MP3 rather than a renamed copy.

    Returns output_path on success.
    Raises ValueError when audio_paths is empty.
    Raises RuntimeError when ffmpeg is missing from PATH or exits non-zero.
    """
    if not audio_paths:
        raise ValueError("No audio paths to stitch.")

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    if len(audio_paths) == 1:
        only = audio_paths[0]
        already_target_format = fmt == "wav" or Path(only).suffix.lower() == ".mp3"
        if already_target_format:
            shutil.copy2(only, output_path)
            return output_path

    codec_args: list[str] = (
        ["-c:a", "copy"] if fmt == "wav"
        else ["-c:a", "libmp3lame", "-b:a", "320k"]
    )

    flist_path: str | None = None
    try:
        # delete=False: the file must outlive the `with` so ffmpeg can read it.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False, encoding="utf-8"
        ) as flist:
            flist_path = flist.name
            flist.writelines(_concat_entry(p) for p in audio_paths)

        try:
            # create_subprocess_exec uses a list of args — no shell interpolation
            proc = await asyncio.create_subprocess_exec(
                "ffmpeg", "-y",
                "-f", "concat",
                "-safe", "0",  # absolute paths are rejected without this
                "-i", flist_path,
                *codec_args,
                output_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except FileNotFoundError as exc:
            raise RuntimeError("ffmpeg concat failed: ffmpeg not found on PATH") from exc

        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            # ffmpeg may echo file names in arbitrary encodings; never let the
            # decode mask the real failure.
            detail = stderr.decode(errors="replace")
            raise RuntimeError(
                f"ffmpeg concat failed (exit {proc.returncode}): {detail}"
            )
    finally:
        if flist_path is not None:
            os.unlink(flist_path)

    logger.info("Stitched %d files -> %s", len(audio_paths), output_path)
    return output_path


def make_export_path(data_dir: str, chain_id: str, fmt: ExportFormat) -> str:
    """Standard export output path: ``{data_dir}/chains/{chain_id}/export.{fmt}``."""
    return str(Path(data_dir) / "chains" / chain_id / f"export.{fmt}")
class MusicGenClient:
    """
    Allocates a cf-musicgen instance from cf-orch and calls POST /continue.

    Each generate() call allocates, generates, and releases. For Premium tier
    session-held allocations, subclass and override _allocate/_release.
    """

    async def generate(
        self,
        source_audio_path: str,
        output_path: str,
        prompt: str,
        duration_s: float,
        cfg_coef: float,
        prompt_duration_s: float,
    ) -> float:
        """
        Generate a continuation of source_audio_path.

        Writes the result to output_path and returns the actual duration_s.
        In mock mode no allocation is made and the call is delegated to
        ``_mock_generate``.

        Raises RuntimeError when cf-orch is not configured, has no capacity,
        or returns an allocation without a URL.  HTTP-level failures from
        cf-orch or the service propagate as httpx exceptions.
        """
        if _MOCK:
            return await _mock_generate(source_audio_path, output_path, duration_s)

        service_url = await self._allocate()
        try:
            return await _call_continue(
                service_url=service_url,
                audio_path=source_audio_path,
                output_path=output_path,
                prompt=prompt,
                duration_s=duration_s,
                cfg_coef=cfg_coef,
                prompt_duration_s=prompt_duration_s,
            )
        finally:
            # Always hand the allocation back, even when generation failed.
            await self._release(service_url)

    async def _allocate(self) -> str:
        """Request a cf-musicgen allocation from cf-orch. Returns the service URL."""
        if not _ORCH_URL:
            raise RuntimeError(
                "CF_ORCH_URL is not configured. Set it in .env or use CF_MUSICGEN_MOCK=1."
            )
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                json={"requester": "sparrow"},
            )
            # 503 is mapped to a readable capacity message; any other error
            # status is raised by raise_for_status() below.
            if resp.status_code == 503:
                raise RuntimeError("No cf-musicgen capacity available — all GPUs busy.")
            resp.raise_for_status()
            data = resp.json()
            url = data.get("url") or data.get("service_url")
            if not url:
                raise RuntimeError(f"cf-orch allocation response missing URL: {data}")
            logger.info("Allocated cf-musicgen at %s", url)
            return url

    async def _release(self, service_url: str) -> None:
        """
        Release the cf-musicgen allocation back to cf-orch.

        Best-effort: every failure is logged as a warning and swallowed so a
        release problem never masks the outcome of the generation itself.
        """
        if not _ORCH_URL:
            return
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # httpx's delete() helper accepts no request body (passing
                # json= raises TypeError, which the handler below used to
                # swallow, so nothing was ever released).  A DELETE with a
                # body has to go through the generic request() method.
                resp = await client.request(
                    "DELETE",
                    f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                    json={"url": service_url},
                )
                # Surface a rejected release in the log instead of ignoring it.
                resp.raise_for_status()
        except Exception as exc:
            logger.warning("Failed to release cf-musicgen allocation: %s", exc)
def make_output_path(data_dir: str, chain_id: str, node_id: str) -> str:
    """
    Build the conventional location of a generated node's audio file.

    The result is ``{data_dir}/chains/{chain_id}/nodes/{node_id}.wav`` rendered
    with the platform's path separator.  Nothing is created on disk.
    """
    node_file = Path(data_dir, "chains", chain_id, "nodes", f"{node_id}.wav")
    return str(node_file)
async def _real_separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
    """
    Run the demucs ``htdemucs`` model on source_audio_path.

    Returns a dict of stem name to the path demucs is expected to have written
    under output_dir.  Raises RuntimeError when demucs exits non-zero or when
    any expected stem file is absent afterwards.
    """
    import sys

    track = Path(source_audio_path)
    logger.info("Running demucs on %s -> %s", track.name, output_dir)
    # create_subprocess_exec takes a list of args — no shell interpolation
    proc = await asyncio.create_subprocess_exec(
        # Use the interpreter running this service so demucs is resolved from
        # the same environment; a bare "python" depends on PATH and may be a
        # different interpreter (or not exist at all).
        sys.executable or "python", "-m", "demucs",
        # The demucs CLI selects the pretrained model with -n/--name.
        "--name", "htdemucs",
        "--out", output_dir,
        str(track),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        detail = stderr.decode(errors="replace")
        raise RuntimeError(f"demucs failed (exit {proc.returncode}): {detail}")

    paths = _stem_paths(output_dir, track.stem)
    # A zero exit code does not prove the files landed where we expect them;
    # fail here rather than hand callers paths that do not exist.
    missing = [name for name, path in paths.items() if not Path(path).is_file()]
    if missing:
        raise RuntimeError(
            f"demucs finished but stems are missing: {', '.join(missing)}"
        )
    return paths
def is_self_hosted() -> bool:
    """
    True when cf-orch is on localhost or a LAN address.

    Self-hosted users get unrestricted access equivalent to Paid tier.
    Cloud users are gated by the standard tier model.

    The host is taken from ``CF_ORCH_URL`` (read on every call).  It counts as
    local when it is ``localhost``, a loopback address, or an IPv4 address in
    one of the private ranges 10.0.0.0/8, 172.16.0.0/12 or 192.168.0.0/16.
    Any other DNS name or address, and any URL that cannot be parsed, is
    treated as cloud so that a malformed value can never unlock paid features.
    """
    import ipaddress
    from urllib.parse import urlsplit

    orch_url = os.environ.get("CF_ORCH_URL", "")
    if not orch_url:
        return True  # no orch = local dev mode = self-hosted

    try:
        # urlsplit only recognises the authority part when it follows "//",
        # so scheme-less values such as "192.168.1.5:7700" get one prepended.
        target = orch_url if "//" in orch_url else f"//{orch_url}"
        host = urlsplit(target).hostname or ""
    except ValueError:
        return False

    if host == "localhost":
        return True

    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        # A DNS name.  Prefix matching on names is unsafe ("10.example.com"
        # is a public host), so names other than localhost are never local.
        return False

    if address.is_loopback:
        return True
    # Only 172.16.0.0/12 is private; the rest of 172.0.0.0/8 is public space.
    lan_blocks = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
    return any(address in ipaddress.ip_network(block) for block in lan_blocks)
@pytest.fixture
def app(tmp_db, tmp_data, monkeypatch):
    """
    FastAPI test app with isolated DB and data dir.

    The SPARROW_* variables are set through ``monkeypatch`` so they are
    restored when the test finishes; writing to ``os.environ`` directly would
    leave paths to already-deleted temp directories visible to later tests.
    """
    monkeypatch.setenv("SPARROW_DB_PATH", tmp_db)
    monkeypatch.setenv("SPARROW_DATA_DIR", tmp_data)
    monkeypatch.setenv("SPARROW_ENV", "development")

    # Imported lazily so the environment above is in place before the app
    # module is loaded.
    from app.main import create_app
    return create_app()
def test_upload_root_node(client, tmp_path):
    """Uploading audio to a fresh chain yields a ready, committed root node."""
    created = client.post("/api/chains/", json={"name": "upload test"})
    chain_id = created.json()["id"]

    # 44 bytes: the "RIFF" magic followed by zero padding.
    # NOTE(review): this is not a complete WAV header (no WAVE/fmt chunk), so
    # the test presumably relies on the upload path tolerating audio it cannot
    # decode — confirm against _probe_duration.
    payload = io.BytesIO(b"RIFF" + b"\x00" * 40)
    uploaded = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("test.wav", payload, "audio/wav")},
    )
    assert uploaded.status_code == 201

    root = uploaded.json()
    assert root["status"] == "ready"
    assert root["is_committed"] is True
    assert root["parent_id"] is None
def _create_chain_with_root(client, tmp_path=None):
    """
    Helper: create a chain and upload a root node.

    Returns ``(chain_id, root_node_id)``.  ``tmp_path`` is accepted for call
    compatibility and is not used.

    Both setup requests are checked so that a broken create/upload endpoint
    fails here with the response body, instead of as a KeyError on "id" inside
    whichever test called the helper.
    """
    resp = client.post("/api/chains/", json={"name": "test chain"})
    assert resp.status_code == 201, resp.text
    chain_id = resp.json()["id"]

    resp = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("root.wav", io.BytesIO(b"RIFF" + b"\x00" * 40), "audio/wav")},
    )
    assert resp.status_code == 201, resp.text
    node = resp.json()
    return chain_id, node["id"]
def test_audio_not_found_for_pending_node(client, conn):
    """A branch that has no audio_path yet answers 404 on the audio endpoint."""
    from app.services import chain as svc

    chain_id, root_id = _create_chain_with_root(client)
    # Inserted through the service layer, so no generation task is scheduled
    # and the node is still in the state create_branch_node gives it.
    pending = svc.create_branch_node(
        conn,
        parent_id=root_id,
        chain_id=chain_id,
        prompt="test",
        energy=None,
        tempo_feel=None,
        density=None,
        cfg_coef=3.0,
        prompt_duration_s=10.0,
    )

    audio_url = f"/api/nodes/{pending['id']}/audio"
    assert client.get(audio_url).status_code == 404
def test_create_root_node(conn, tmp_path):
    """A root node starts out ready and committed, with no parent."""
    chain_id = svc.create_chain(conn, "root test")["id"]

    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")  # empty file, just needs to exist
    audio = str(audio_file)

    root = svc.create_root_node(conn, chain_id, audio, 30.0)

    assert root["status"] == "ready"
    assert root["is_committed"] is True
    assert root["parent_id"] is None
    assert root["audio_path"] == audio
    assert root["duration_s"] == 30.0
"wb").close() + root = svc.create_root_node(conn, chain["id"], audio, 30.0) + + branch = svc.create_branch_node( + conn, + parent_id=root["id"], + chain_id=chain["id"], + prompt="upbeat jazz", + energy=0.8, + tempo_feel=None, + density=None, + cfg_coef=3.0, + prompt_duration_s=10.0, + ) + assert branch["status"] == "pending" + assert branch["is_committed"] is False + assert branch["parent_id"] == root["id"] + assert branch["prompt"] == "upbeat jazz" + + +def test_update_node_status(conn, tmp_path): + chain = svc.create_chain(conn, "status test") + audio = str(tmp_path / "audio.wav") + open(audio, "wb").close() + root = svc.create_root_node(conn, chain["id"], audio, 30.0) + branch = svc.create_branch_node( + conn, root["id"], chain["id"], "test", None, None, None, 3.0, 10.0 + ) + + updated = svc.update_node_status( + conn, branch["id"], "generating" + ) + assert updated["status"] == "generating" + + out_path = str(tmp_path / "out.wav") + open(out_path, "wb").close() + done = svc.update_node_status( + conn, branch["id"], "ready", audio_path=out_path, duration_s=15.0 + ) + assert done["status"] == "ready" + assert done["audio_path"] == out_path + assert done["duration_s"] == 15.0 + + +def test_commit_node_discards_siblings(conn, tmp_path): + chain = svc.create_chain(conn, "commit test") + audio = str(tmp_path / "audio.wav") + open(audio, "wb").close() + root = svc.create_root_node(conn, chain["id"], audio, 30.0) + + b1 = svc.create_branch_node( + conn, root["id"], chain["id"], "branch 1", None, None, None, 3.0, 10.0 + ) + b2 = svc.create_branch_node( + conn, root["id"], chain["id"], "branch 2", None, None, None, 3.0, 10.0 + ) + + # Make both ready + for nid in (b1["id"], b2["id"]): + svc.update_node_status(conn, nid, "ready", + audio_path=audio, duration_s=15.0) + + # Commit b1 — b2 should be deleted + committed = svc.commit_node(conn, b1["id"]) + assert committed["is_committed"] is True + assert svc.get_node(conn, b2["id"]) is None + + +def 
def test_committed_spine_order(conn, tmp_path):
    """The spine lists committed nodes root first, then the committed child."""
    chain_id = svc.create_chain(conn, "spine test")["id"]
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")
    audio = str(audio_file)

    root_id = svc.create_root_node(conn, chain_id, audio, 30.0)["id"]
    child_id = svc.create_branch_node(
        conn, root_id, chain_id, "child", None, None, None, 3.0, 10.0
    )["id"]
    svc.update_node_status(
        conn, child_id, "ready", audio_path=audio, duration_s=15.0
    )
    svc.commit_node(conn, child_id)

    spine_ids = [node["id"] for node in svc.get_committed_spine(conn, chain_id)]
    assert spine_ids == [root_id, child_id]
"output.wav") + + result = await stitch([str(src)], out) + assert result == out + assert open(out, "rb").read() == b"WAV DATA" + + +@pytest.mark.asyncio +async def test_stitch_empty_raises(tmp_path): + from app.services.export import stitch + + with pytest.raises(ValueError, match="No audio paths"): + await stitch([], str(tmp_path / "out.wav")) + + +def test_make_export_path(): + from app.services.export import make_export_path + + p = make_export_path("data", "chain-abc", "mp3") + assert p == "data/chains/chain-abc/export.mp3" + assert make_export_path("data", "chain-abc", "wav").endswith(".wav") diff --git a/tests/test_musicgen_service.py b/tests/test_musicgen_service.py new file mode 100644 index 0000000..d032fb2 --- /dev/null +++ b/tests/test_musicgen_service.py @@ -0,0 +1,36 @@ +# tests/test_musicgen_service.py — unit tests for MusicGenClient mock mode +from __future__ import annotations + +import os +import pytest + +# Ensure mock mode +os.environ["CF_MUSICGEN_MOCK"] = "1" + + +@pytest.mark.asyncio +async def test_mock_generate_copies_file(tmp_path): + from app.services.musicgen import MusicGenClient + + source = tmp_path / "source.wav" + source.write_bytes(b"RIFF" + b"\x00" * 40) + output = str(tmp_path / "output.wav") + + client = MusicGenClient() + duration = await client.generate( + source_audio_path=str(source), + output_path=output, + prompt="test", + duration_s=10.0, + cfg_coef=3.0, + prompt_duration_s=5.0, + ) + + assert os.path.exists(output) + assert duration >= 0.0 # torchaudio may not parse our fake WAV; 0.0 is fine + + +def test_make_output_path(): + from app.services.musicgen import make_output_path + path = make_output_path("data", "chain-abc", "node-xyz") + assert path == "data/chains/chain-abc/nodes/node-xyz.wav"