feat: implement Sparrow backend (v0.1.0)

Full FastAPI backend for the AI music continuation editor:

Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3

API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub

Infrastructure
- lifespan: reads SPARROW_DB_PATH and SPARROW_DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)

Tests: 30/30 passing across service units and API integration
This commit is contained in:
pyr0ball 2026-04-17 15:22:37 -07:00
parent c3e18e83e5
commit a6f60c9e07
23 changed files with 1760 additions and 0 deletions

17
app/api/deps.py Normal file
View file

@ -0,0 +1,17 @@
# app/api/deps.py — FastAPI dependency providers
from __future__ import annotations
import os
import sqlite3
from fastapi import Request
def get_conn(request: Request) -> sqlite3.Connection:
    """Return the shared SQLite connection stored on ``app.state.conn``."""
    app_state = request.app.state
    return app_state.conn
def get_data_dir(request: Request) -> str:
    """Return the data directory stored on ``app.state.data_dir``."""
    app_state = request.app.state
    return app_state.data_dir

View file

@ -0,0 +1,90 @@
# app/api/endpoints/chains.py — chain CRUD + root audio upload
from __future__ import annotations
import os
import uuid
from pathlib import Path
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from app.api.deps import get_conn, get_data_dir
from app.models.schemas.chain import ChainCreate, ChainRow, ChainTree
from app.models.schemas.node import NodeRow
from app.services import chain as chain_svc
router = APIRouter(prefix="/api/chains", tags=["chains"])
_ALLOWED_AUDIO = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aiff"}
@router.post("/", response_model=ChainRow, status_code=201)
def create_chain(body: ChainCreate, conn=Depends(get_conn)) -> dict:
    """Create a new chain named ``body.name`` and return its row (HTTP 201)."""
    new_chain = chain_svc.create_chain(conn, body.name)
    return new_chain
@router.get("/", response_model=list[ChainRow])
def list_chains(conn=Depends(get_conn)) -> list:
    """Return every chain as reported by the chain service."""
    all_chains = chain_svc.list_chains(conn)
    return all_chains
@router.get("/{chain_id}", response_model=ChainTree)
def get_chain(chain_id: str, conn=Depends(get_conn)) -> dict:
    """Return one chain with its nodes; respond 404 when the id is unknown."""
    found = chain_svc.get_chain(conn, chain_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Chain not found.")
@router.delete("/{chain_id}", status_code=204)
def delete_chain(chain_id: str, conn=Depends(get_conn)) -> None:
    """Delete a chain; respond 404 when no row was removed."""
    was_deleted = chain_svc.delete_chain(conn, chain_id)
    if was_deleted:
        return None
    raise HTTPException(status_code=404, detail="Chain not found.")
@router.post("/{chain_id}/upload", response_model=NodeRow, status_code=201)
async def upload_root(
    chain_id: str,
    file: UploadFile = File(...),
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Upload source audio and create the root node for a chain.

    Accepts WAV, MP3, FLAC, OGG, M4A, AIFF (decided by the upload's file
    extension). The file is stored as
    {data_dir}/chains/{chain_id}/nodes/root_{uuid}{ext}; the uuid in the file
    name is generated here and is NOT the id of the DB node, which
    create_root_node assigns itself. The root node is always committed.

    Raises:
        HTTPException 404: the chain does not exist.
        HTTPException 422: the file extension is not an allowed audio format.
    """
    chain = chain_svc.get_chain(conn, chain_id)
    if chain is None:
        raise HTTPException(status_code=404, detail="Chain not found.")

    # Only the extension of the client-supplied name is used, never the name itself.
    suffix = Path(file.filename or "audio.wav").suffix.lower()
    if suffix not in _ALLOWED_AUDIO:
        raise HTTPException(
            status_code=422,
            detail=f"Unsupported audio format '{suffix}'. "
            f"Allowed: {', '.join(sorted(_ALLOWED_AUDIO))}",
        )

    file_id = str(uuid.uuid4())
    dest_dir = Path(data_dir) / "chains" / chain_id / "nodes"
    dest_dir.mkdir(parents=True, exist_ok=True)
    audio_path = dest_dir / f"root_{file_id}{suffix}"

    try:
        # Copy in 1 MiB chunks so a large upload is never held fully in memory.
        with audio_path.open("wb") as out:
            while chunk := await file.read(1024 * 1024):
                out.write(chunk)

        # Get duration via torchaudio if available, else 0.0
        duration_s = _probe_duration(str(audio_path))
        return chain_svc.create_root_node(conn, chain_id, str(audio_path), duration_s)
    except Exception:
        # Do not leave an orphaned audio file behind when the upload or insert fails.
        audio_path.unlink(missing_ok=True)
        raise
def _probe_duration(path: str) -> float:
    """
    Best-effort duration probe for an audio file, in seconds.

    Uses torchaudio's metadata (num_frames / sample_rate). Any failure,
    including torchaudio not being importable, yields 0.0.
    """
    seconds = 0.0
    try:
        import torchaudio

        meta = torchaudio.info(path)
        seconds = meta.num_frames / meta.sample_rate
    except Exception:
        seconds = 0.0
    return seconds

View file

@ -0,0 +1,54 @@
# app/api/endpoints/events.py — SSE stream for node status transitions
#
# Clients connect to GET /api/chains/{chain_id}/events and receive
# server-sent events whenever a node in that chain changes status.
#
# Event format:
# event: node-status
# data: {"node_id": "...", "status": "generating"|"ready"|"error", ...}
from __future__ import annotations
import asyncio
import json
import logging
from fastapi import APIRouter, Depends
from sse_starlette.sse import EventSourceResponse
from app.api.deps import get_conn
from app.api.events_store import subscribe, unsubscribe
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/chains", tags=["events"])
_KEEPALIVE_S = 15 # send a comment ping every N seconds to keep the connection alive
@router.get("/{chain_id}/events")
async def chain_events(chain_id: str, conn=Depends(get_conn)) -> EventSourceResponse:
    """
    SSE stream for node status transitions in a chain.

    Emits 'node-status' events whose data is the JSON-encoded event dict taken
    from this connection's queue, plus a 'keepalive' comment whenever no event
    arrived within _KEEPALIVE_S seconds. The queue is unsubscribed when the
    generator finishes, is closed, or is cancelled (client disconnect).

    NOTE(review): `conn` is accepted but unused here, and the chain id is not
    checked for existence — confirm whether unknown chains should 404.
    """
    q = subscribe(chain_id)

    async def generator():
        try:
            while True:
                try:
                    event = await asyncio.wait_for(q.get(), timeout=_KEEPALIVE_S)
                except asyncio.TimeoutError:
                    # Send keepalive comment to prevent proxy/browser timeout
                    yield {"comment": "keepalive"}
                else:
                    yield {
                        "event": "node-status",
                        "data": json.dumps(event),
                    }
        finally:
            # Runs on cancellation and on generator close alike. CancelledError
            # is deliberately NOT caught: swallowing it hides the cancellation
            # from the task that is driving this generator.
            unsubscribe(chain_id, q)
            logger.debug("SSE client disconnected from chain %s", chain_id)

    return EventSourceResponse(generator())

View file

@ -0,0 +1,64 @@
# app/api/endpoints/export.py — stitch committed spine and serve as download
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from app.api.deps import get_conn, get_data_dir
from app.models.schemas.node import ExportRequest
from app.services import chain as chain_svc
from app.services.export import ExportFormat, make_export_path, stitch
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/chains", tags=["export"])
@router.post("/{chain_id}/export")
async def export_chain(
    chain_id: str,
    body: ExportRequest,
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> FileResponse:
    """
    Stitch node audio files into a single download.

    If body.node_ids is provided, those nodes are used in the given order and
    must all belong to chain_id. Otherwise the committed spine (root → leaf)
    is used.

    Raises:
        HTTPException 422: body.node_ids is an empty list.
        HTTPException 404: a requested node is missing or belongs to another
            chain, or the chain has no committed nodes.
        HTTPException 409: any targeted node has no audio.
        HTTPException 500: stitching failed.
    """
    if body.node_ids is not None:
        if not body.node_ids:
            # Previously an empty list fell through to stitch() and surfaced as a 500.
            raise HTTPException(status_code=422, detail="node_ids must not be empty.")
        nodes = [chain_svc.get_node(conn, nid) for nid in body.node_ids]
        # Nodes of other chains are reported as "not found" so this endpoint
        # cannot be used to export audio across chain boundaries.
        if any(n is None or n["chain_id"] != chain_id for n in nodes):
            raise HTTPException(status_code=404, detail="One or more nodes not found.")
    else:
        nodes = chain_svc.get_committed_spine(conn, chain_id)
        if not nodes:
            raise HTTPException(
                status_code=404,
                detail="Chain not found or has no committed nodes.",
            )

    missing = [n["id"] for n in nodes if n["audio_path"] is None]
    if missing:
        raise HTTPException(
            status_code=409,
            detail=f"Nodes {missing} have no audio yet.",
        )
    audio_paths = [n["audio_path"] for n in nodes]

    fmt: ExportFormat = body.format  # type: ignore[assignment]
    output_path = make_export_path(data_dir, chain_id, fmt)
    try:
        await stitch(audio_paths, output_path, fmt)
    except Exception as exc:
        logger.exception("Export failed for chain %s: %s", chain_id, exc)
        raise HTTPException(status_code=500, detail=f"Export failed: {exc}") from exc

    media_type = "audio/wav" if fmt == "wav" else "audio/mpeg"
    filename = f"sparrow_export_{chain_id[:8]}.{fmt}"
    return FileResponse(output_path, media_type=media_type, filename=filename)

73
app/api/endpoints/gpu.py Normal file
View file

@ -0,0 +1,73 @@
# app/api/endpoints/gpu.py — cf-orch GPU status and session allocation
#
# GET /api/gpu/status — available capacity from cf-orch
# POST /api/gpu/connect — session-held allocation (Premium tier, stub)
# DELETE /api/gpu/disconnect — release session allocation (Premium tier, stub)
from __future__ import annotations
import logging
import os
import httpx
from fastapi import APIRouter, HTTPException
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/gpu", tags=["gpu"])
_ORCH_URL = os.environ.get("CF_ORCH_URL", "").rstrip("/")
_SERVICE = "cf-musicgen"
@router.get("/status")
async def gpu_status() -> dict:
    """
    Report cf-orch capacity for the cf-musicgen service.

    When CF_ORCH_URL is not configured, returns
    {"available": False, "reason": "...", "mock": True} without any network
    call. Otherwise the JSON body of cf-orch's status endpoint is returned
    unchanged.

    Raises:
        HTTPException 502: cf-orch answered with an error status, or any other
            failure occurred while contacting it or decoding its reply.
    """
    if _ORCH_URL:
        status_url = f"{_ORCH_URL}/api/services/{_SERVICE}/status"
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(status_url)
                resp.raise_for_status()
                return resp.json()
        except httpx.HTTPStatusError as exc:
            raise HTTPException(
                status_code=502,
                detail=f"cf-orch returned {exc.response.status_code}.",
            ) from exc
        except Exception as exc:
            raise HTTPException(
                status_code=502,
                detail=f"cf-orch unreachable: {exc}",
            ) from exc

    # No orchestrator configured: answer locally.
    return {
        "available": False,
        "reason": "CF_ORCH_URL not configured — running in mock mode.",
        "mock": True,
    }
@router.post("/connect", status_code=501)
async def gpu_connect() -> dict:
    """
    Session-held GPU allocation (Premium tier).

    Not yet implemented — tracked in cf-orch #43. Always raises
    HTTPException 501.
    """
    not_implemented = HTTPException(
        status_code=501,
        detail="Session-held GPU allocation is not yet implemented (cf-orch #43).",
    )
    raise not_implemented
@router.delete("/disconnect", status_code=501)
async def gpu_disconnect() -> dict:
    """
    Release a session-held GPU allocation (Premium tier, stub).

    Always raises HTTPException 501.
    """
    not_implemented = HTTPException(
        status_code=501,
        detail="Session-held GPU allocation is not yet implemented (cf-orch #43).",
    )
    raise not_implemented

182
app/api/endpoints/nodes.py Normal file
View file

@ -0,0 +1,182 @@
# app/api/endpoints/nodes.py — branch generation, commit, discard, audio stream
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse
from app.api.deps import get_conn, get_data_dir
from app.api.events_store import broadcast
from app.models.schemas.node import BranchRequest, NodeRow
from app.services import chain as chain_svc
from app.services.musicgen import MusicGenClient, make_output_path
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/nodes", tags=["nodes"])
_musicgen = MusicGenClient()
@router.post("/{parent_id}/branch", response_model=NodeRow, status_code=202)
async def create_branch(
    parent_id: str,
    body: BranchRequest,
    background_tasks: BackgroundTasks,
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Create a branch node under parent_id and schedule its generation.

    Responds 202 right away with the freshly inserted node. Later status
    transitions (pending → generating → ready/error) are pushed via the
    chain's SSE stream by the background task.

    Raises:
        HTTPException 404: the parent node does not exist.
        HTTPException 409: the parent node has no audio yet.
    """
    source = chain_svc.get_node(conn, parent_id)
    if source is None:
        raise HTTPException(status_code=404, detail="Parent node not found.")

    source_audio = source["audio_path"]
    if source_audio is None:
        raise HTTPException(
            status_code=409, detail="Parent node has no audio yet."
        )

    owning_chain = source["chain_id"]
    new_node = chain_svc.create_branch_node(
        conn,
        parent_id=parent_id,
        chain_id=owning_chain,
        prompt=body.prompt,
        energy=body.energy,
        tempo_feel=body.tempo_feel,
        density=body.density,
        cfg_coef=body.cfg_coef,
        prompt_duration_s=body.prompt_duration_s,
    )

    # The background task opens its own connection, so hand it the file path
    # of the database behind the request connection.
    db_file = conn.execute("PRAGMA database_list").fetchone()[2]
    generation_kwargs = {
        "node_id": new_node["id"],
        "chain_id": owning_chain,
        "source_audio_path": source_audio,
        "prompt": body.prompt,
        "duration_s": body.duration_s,
        "cfg_coef": body.cfg_coef,
        "prompt_duration_s": body.prompt_duration_s,
        "data_dir": data_dir,
        "db_path": db_file,
    }
    background_tasks.add_task(_run_generation, **generation_kwargs)
    return new_node
@router.post("/{node_id}/commit", response_model=NodeRow)
def commit_node(node_id: str, conn=Depends(get_conn)) -> dict:
    """
    Promote a branch node to the committed spine.

    The chain service discards all non-committed siblings (same parent).

    Raises:
        HTTPException 404: the node does not exist.
        HTTPException 409: the node's status is not 'ready'.
    """
    candidate = chain_svc.get_node(conn, node_id)
    if candidate is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    is_ready = candidate["status"] == "ready"
    if not is_ready:
        raise HTTPException(
            status_code=409, detail="Only 'ready' nodes can be committed."
        )

    committed = chain_svc.commit_node(conn, node_id)
    if committed is not None:
        return committed
    # The node vanished between the lookup and the commit.
    raise HTTPException(status_code=404, detail="Node not found.")
@router.delete("/{node_id}", status_code=204)
def delete_node(node_id: str, conn=Depends(get_conn)) -> None:
    """
    Discard a non-committed branch node.

    Responds 409 when the chain service refuses the delete: the node is
    unknown, is a root node (parent_id IS NULL), or is committed.
    """
    removed = chain_svc.delete_node(conn, node_id)
    if removed:
        return None
    raise HTTPException(
        status_code=409,
        detail="Node not found or cannot be deleted (root or committed).",
    )
# Media types for the audio formats accepted by the root upload endpoint.
_AUDIO_MEDIA_TYPES = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
    ".m4a": "audio/mp4",
    ".aiff": "audio/aiff",
}


@router.get("/{node_id}/audio")
def stream_audio(node_id: str, conn=Depends(get_conn)) -> FileResponse:
    """
    Serve the audio file of a node.

    The media type and download file name follow the stored file's extension,
    because root nodes keep the extension of the uploaded file and are not
    necessarily WAV. Unknown extensions are served as
    application/octet-stream.

    Raises:
        HTTPException 404: the node is unknown, has no audio path, or the
            file is missing on disk.
    """
    node = chain_svc.get_node(conn, node_id)
    if node is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    audio_path = node["audio_path"]
    if audio_path is None or not Path(audio_path).exists():
        raise HTTPException(status_code=404, detail="Audio not available.")

    suffix = Path(audio_path).suffix.lower()
    return FileResponse(
        audio_path,
        media_type=_AUDIO_MEDIA_TYPES.get(suffix, "application/octet-stream"),
        # Fall back to .wav only when the stored path has no extension at all.
        filename=f"{node_id}{suffix or '.wav'}",
    )
# ── Background generation task ────────────────────────────────────────────────
async def _run_generation(
    *,
    node_id: str,
    chain_id: str,
    source_audio_path: str,
    prompt: str,
    duration_s: float,
    cfg_coef: float,
    prompt_duration_s: float,
    data_dir: str,
    db_path: str,
) -> None:
    """
    Background task: run MusicGen for one node and record the outcome.

    Works on its own DB connection (opened from db_path, closed on exit)
    instead of the request's shared one. Each status change is written to the
    node row and then broadcast on the chain's event channel:
    'generating' first, then either 'ready' (with audio_path and duration_s)
    or 'error' (with error_msg). Exceptions are logged, never re-raised.
    """
    from app.db.store import get_connection

    async def _announce(status: str, **extra) -> None:
        # Every event carries node_id and status first, then the extras.
        await broadcast(chain_id, {"node_id": node_id, "status": status, **extra})

    task_conn = get_connection(db_path)
    try:
        chain_svc.update_node_status(task_conn, node_id, "generating")
        await _announce("generating")

        generated_path = make_output_path(data_dir, chain_id, node_id)
        seconds = await _musicgen.generate(
            source_audio_path=source_audio_path,
            output_path=generated_path,
            prompt=prompt,
            duration_s=duration_s,
            cfg_coef=cfg_coef,
            prompt_duration_s=prompt_duration_s,
        )

        chain_svc.update_node_status(
            task_conn, node_id, "ready",
            audio_path=generated_path,
            duration_s=seconds,
        )
        await _announce("ready", audio_path=generated_path, duration_s=seconds)
        logger.info("Node %s ready (%.1fs)", node_id, seconds)
    except Exception as exc:
        logger.exception("Generation failed for node %s: %s", node_id, exc)
        chain_svc.update_node_status(task_conn, node_id, "error", error_msg=str(exc))
        await _announce("error", error_msg=str(exc))
    finally:
        task_conn.close()

View file

@ -0,0 +1,67 @@
# app/api/endpoints/stems.py — 4-stem separation (Paid tier)
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from app.api.deps import get_conn, get_data_dir
from app.models.schemas.node import StemResult
from app.services import chain as chain_svc
from app.services import stems as stems_svc
from app.tiers import require_tier
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/stems", tags=["stems"])
@router.post("/{node_id}", response_model=StemResult, status_code=202)
async def separate_stems(
    node_id: str,
    background_tasks: BackgroundTasks,
    tier: str | None = None,
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Schedule 4-stem separation for a node that already has audio.

    Requires Paid tier; `tier` arrives as a query param, e.g. ?tier=paid.
    Responds 202 immediately with the planned stem file paths — the files
    only exist once the background task has finished.

    NOTE(review): the background task shown in this module only logs its
    outcome; no completion event is broadcast from here — confirm how clients
    are expected to learn that the stems are ready.

    Raises:
        HTTPException 404: the node does not exist.
        HTTPException 409: the node has no audio.
    """
    require_tier("stems", tier)

    target = chain_svc.get_node(conn, node_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Node not found.")

    source_audio = target["audio_path"]
    if source_audio is None:
        raise HTTPException(status_code=409, detail="Node has no audio.")

    stems_dir = stems_svc.make_stems_dir(data_dir, target["chain_id"], node_id)
    planned = stems_svc._stem_paths(stems_dir, Path(source_audio).stem)

    background_tasks.add_task(
        _run_separation,
        source_audio_path=source_audio,
        stems_dir=stems_dir,
    )

    stem_fields = {name: planned[name] for name in ("vocals", "drums", "bass", "other")}
    return StemResult(node_id=node_id, **stem_fields)
async def _run_separation(source_audio_path: str, stems_dir: str) -> None:
    """Background task: run stem separation and log success or failure."""
    try:
        await stems_svc.separate(source_audio_path, stems_dir)
    except Exception as exc:
        # Best-effort task: failures are logged, never propagated.
        logger.exception("Stem separation failed: %s", exc)
    else:
        logger.info("Stems complete: %s", stems_dir)

34
app/api/events_store.py Normal file
View file

@ -0,0 +1,34 @@
# app/api/events_store.py — Chain-scoped pub/sub for SSE node-status events
#
# Each SSE connection subscribes with subscribe(chain_id) → Queue.
# Background generation tasks call broadcast(chain_id, event) to push
# node state transitions (pending → generating → ready/error) to all listeners.
from __future__ import annotations
import asyncio
from collections import defaultdict
from typing import Any
# chain_id → list of subscriber queues. Entries are created on first subscribe
# and dropped again once the last subscriber leaves.
_listeners: dict[str, list[asyncio.Queue[Any]]] = defaultdict(list)


def subscribe(chain_id: str) -> asyncio.Queue[Any]:
    """Register a new SSE listener for chain_id. Returns its Queue."""
    q: asyncio.Queue[Any] = asyncio.Queue()
    _listeners[chain_id].append(q)
    return q


def unsubscribe(chain_id: str, q: asyncio.Queue[Any]) -> None:
    """
    Remove a listener Queue. Safe to call if already removed.

    Unknown chain ids are ignored without creating an entry, and a chain's
    entry is deleted once its last queue is gone, so the registry does not
    accumulate empty lists for chains nobody is listening to.
    """
    queues = _listeners.get(chain_id)
    if queues is None:
        return
    try:
        queues.remove(q)
    except ValueError:
        pass
    if not queues:
        _listeners.pop(chain_id, None)


async def broadcast(chain_id: str, event: dict[str, Any]) -> None:
    """Push event dict to all active listeners for chain_id."""
    # Iterate over a snapshot: listeners may (un)subscribe while we await.
    for q in list(_listeners.get(chain_id, [])):
        await q.put(event)

14
app/api/routes.py Normal file
View file

@ -0,0 +1,14 @@
# app/api/routes.py — assemble all endpoint routers
from __future__ import annotations
from fastapi import APIRouter
from app.api.endpoints import chains, events, export, gpu, nodes, stems
# Top-level router; endpoint routers are mounted in this fixed order.
router = APIRouter()
for _endpoint_module in (chains, nodes, gpu, stems, export, events):
    router.include_router(_endpoint_module.router)

66
app/main.py Normal file
View file

@ -0,0 +1,66 @@
# app/main.py — Sparrow FastAPI application factory
from __future__ import annotations
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import router
from app.db.store import get_connection, run_migrations
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan: open the DB, run migrations, publish shared state.

    Stores the SQLite connection on ``app.state.conn`` and the data directory
    on ``app.state.data_dir``. The connection is closed when the application
    shuts down, and also when startup or the application itself raises.
    """
    # Read env at startup — not at import time — so tests can inject values
    db_path = os.environ.get("SPARROW_DB_PATH", "data/sparrow.db")
    data_dir = os.environ.get("SPARROW_DATA_DIR", "data")
    env = os.environ.get("SPARROW_ENV", "development")

    Path(db_path).parent.mkdir(parents=True, exist_ok=True)
    Path(data_dir).mkdir(parents=True, exist_ok=True)

    conn = get_connection(db_path)
    try:
        run_migrations(conn)
        app.state.conn = conn
        app.state.data_dir = data_dir
        logger.info("Sparrow started (env=%s, db=%s, data=%s)", env, db_path, data_dir)
        yield
    finally:
        # Always release the connection, even if an exception is thrown into
        # the generator at the yield or migrations fail.
        conn.close()
    logger.info("Sparrow shut down.")
def create_app() -> FastAPI:
    """
    Build the Sparrow FastAPI application.

    CORS is fully open when SPARROW_ENV is "development" (the default). In any
    other environment the allowed origins come from SPARROW_CORS_ORIGINS, a
    comma-separated list; surrounding whitespace is ignored and empty entries
    are dropped, so an unset variable allows no cross-origin access at all.
    """
    env = os.environ.get("SPARROW_ENV", "development")
    app = FastAPI(
        title="Sparrow",
        description="AI music continuation — branching chain editor",
        version="0.1.0",
        lifespan=lifespan,
    )

    # CORS: open in dev, tighten in production via env
    if env == "development":
        origins = ["*"]
    else:
        raw_origins = os.environ.get("SPARROW_CORS_ORIGINS", "")
        # "".split(",") would yield [""], which is not a valid origin list.
        origins = [o.strip() for o in raw_origins.split(",") if o.strip()]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)
    return app
app = create_app()

View file

@ -0,0 +1,27 @@
# app/models/schemas/chain.py — Chain Pydantic models
from __future__ import annotations
from pydantic import BaseModel
class ChainCreate(BaseModel):
    """Request body for creating a chain."""
    # Display name stored on the new chain row.
    name: str
class ChainRow(BaseModel):
    """Summary of one chain, as returned by chain creation and listing."""
    id: str
    name: str
    # Creation timestamp; the chain service fills it from time.time().
    created_at: float
    # Number of nodes in the chain; 0 for a freshly created chain.
    node_count: int = 0
class ChainTree(BaseModel):
    """A chain together with its nodes (response model of GET /api/chains/{chain_id})."""
    id: str
    name: str
    # Creation timestamp; the chain service fills it from time.time().
    created_at: float
    # Nodes of the chain; "NodeRow" is a forward reference resolved below.
    nodes: list["NodeRow"] = []
# Avoid circular import — NodeRow imported at runtime
from app.models.schemas.node import NodeRow # noqa: E402
# Resolve the "NodeRow" forward reference now that the name is importable.
ChainTree.model_rebuild()

View file

@ -0,0 +1,49 @@
# app/models/schemas/node.py — Node Pydantic models
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
NodeStatus = Literal["pending", "generating", "ready", "error"]
class NodeRow(BaseModel):
    """One node of a chain: either the uploaded root or a generated branch."""
    id: str
    chain_id: str
    # None for the root node of a chain.
    parent_id: str | None
    # Path of the node's audio file; None until audio exists for the node.
    audio_path: str | None
    duration_s: float | None
    status: NodeStatus
    # True for nodes on the committed spine.
    is_committed: bool
    prompt: str
    energy: float | None
    tempo_feel: float | None
    density: float | None
    cfg_coef: float
    prompt_duration_s: float
    # Failure description; set when generation ended in the 'error' status.
    error_msg: str | None
    created_at: float
    # Self-referencing list of child nodes; defaults to empty.
    children: list["NodeRow"] = []
class BranchRequest(BaseModel):
    """Request body for creating a branch node and starting its generation."""
    prompt: str = ""
    # energy / tempo_feel / density are stored on the node row.
    # NOTE(review): the branch endpoint does not forward them to the
    # generation task — confirm whether that is intentional.
    energy: float | None = None
    tempo_feel: float | None = None
    density: float | None = None
    cfg_coef: float = 3.0
    prompt_duration_s: float = 10.0
    # Requested duration of the generated continuation, passed to generation.
    duration_s: float = 15.0
class StemResult(BaseModel):
    """File paths of the four stems produced for a node."""
    node_id: str
    # One file path per stem.
    vocals: str
    drums: str
    bass: str
    other: str
class ExportRequest(BaseModel):
    """Request body for exporting a chain's audio as one file."""
    # Output container/codec of the stitched file.
    format: Literal["wav", "mp3"] = "wav"
    # Explicit nodes to stitch, in order.
    node_ids: list[str] | None = None # if None, use committed spine

203
app/services/chain.py Normal file
View file

@ -0,0 +1,203 @@
# app/services/chain.py — chain + node CRUD and branching tree logic
from __future__ import annotations
import sqlite3
import time
import uuid
from typing import Any
# ── Chain CRUD ────────────────────────────────────────────────────────────────
def create_chain(conn: sqlite3.Connection, name: str) -> dict[str, Any]:
    """Insert a new chain row with a fresh UUID and return it (node_count 0)."""
    record: dict[str, Any] = {
        "id": str(uuid.uuid4()),
        "name": name,
        "created_at": time.time(),
        "node_count": 0,
    }
    insert_params = (record["id"], record["name"], record["created_at"])
    conn.execute(
        "INSERT INTO chains (id, name, created_at) VALUES (?, ?, ?)",
        insert_params,
    )
    conn.commit()
    return record
def list_chains(conn: sqlite3.Connection) -> list[dict[str, Any]]:
    """Return all chains, newest first, each with its node count."""
    cursor = conn.execute("""
        SELECT c.id, c.name, c.created_at,
        COUNT(n.id) AS node_count
        FROM chains c
        LEFT JOIN nodes n ON n.chain_id = c.id
        GROUP BY c.id
        ORDER BY c.created_at DESC
    """)
    return list(map(dict, cursor.fetchall()))
def get_chain(conn: sqlite3.Connection, chain_id: str) -> dict[str, Any] | None:
    """Return the chain with its nodes under "nodes", or None if the id is unknown."""
    chain_row = conn.execute(
        "SELECT id, name, created_at FROM chains WHERE id = ?", (chain_id,)
    ).fetchone()
    if chain_row is None:
        return None

    chain_nodes = _get_nodes_for_chain(conn, chain_id)
    result: dict[str, Any] = {key: chain_row[key] for key in ("id", "name", "created_at")}
    result["nodes"] = chain_nodes
    return result
def delete_chain(conn: sqlite3.Connection, chain_id: str) -> bool:
    """
    Delete the chain row; return True if a row was removed.

    NOTE(review): only the chains table is touched here — whether the chain's
    nodes go with it depends on schema constraints not visible in this module.
    """
    removed_rows = conn.execute(
        "DELETE FROM chains WHERE id = ?", (chain_id,)
    ).rowcount
    conn.commit()
    return removed_rows > 0
# ── Node CRUD ─────────────────────────────────────────────────────────────────
def create_root_node(
    conn: sqlite3.Connection,
    chain_id: str,
    audio_path: str,
    duration_s: float,
) -> dict[str, Any]:
    """
    Insert the root node of a chain (the uploaded source audio).

    The row is written as 'ready' and committed, with no parent, an empty
    prompt, cfg_coef 3.0 and prompt_duration_s 10.0. Returns the stored node.
    """
    root_id = str(uuid.uuid4())
    created = time.time()
    insert_params = (root_id, chain_id, audio_path, duration_s, created)
    conn.execute("""
        INSERT INTO nodes
        (id, chain_id, parent_id, audio_path, duration_s, status,
        is_committed, prompt, cfg_coef, prompt_duration_s, created_at)
        VALUES (?, ?, NULL, ?, ?, 'ready', 1, '', 3.0, 10.0, ?)
    """, insert_params)
    conn.commit()
    return get_node(conn, root_id)
def get_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """Fetch one node by id as a dict, or None when there is no such row."""
    found = conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
    if not found:
        return None
    return _row_to_dict(found)
def create_branch_node(
    conn: sqlite3.Connection,
    parent_id: str,
    chain_id: str,
    prompt: str,
    energy: float | None,
    tempo_feel: float | None,
    density: float | None,
    cfg_coef: float,
    prompt_duration_s: float,
) -> dict[str, Any]:
    """
    Insert a branch node under parent_id in the 'pending' state.

    The node starts uncommitted and without audio or duration; moving it on
    to 'generating' is left to the caller. Returns the stored node.
    """
    branch_id = str(uuid.uuid4())
    created = time.time()
    insert_params = (
        branch_id, chain_id, parent_id,
        prompt, energy, tempo_feel, density,
        cfg_coef, prompt_duration_s, created,
    )
    conn.execute("""
        INSERT INTO nodes
        (id, chain_id, parent_id, audio_path, duration_s, status,
        is_committed, prompt, energy, tempo_feel, density,
        cfg_coef, prompt_duration_s, created_at)
        VALUES (?, ?, ?, NULL, NULL, 'pending', 0, ?, ?, ?, ?, ?, ?, ?)
    """, insert_params)
    conn.commit()
    return get_node(conn, branch_id)
def update_node_status(
    conn: sqlite3.Connection,
    node_id: str,
    status: str,
    audio_path: str | None = None,
    duration_s: float | None = None,
    error_msg: str | None = None,
) -> dict[str, Any] | None:
    """
    Set a node's status and, optionally, audio_path / duration_s / error_msg.

    Optional values left as None are not written, so existing column values
    are kept. Returns the node after the update, or None if it does not exist.
    """
    optional_columns = (
        ("audio_path", audio_path),
        ("duration_s", duration_s),
        ("error_msg", error_msg),
    )
    assignments: list[str] = ["status = ?"]
    params: list[Any] = [status]
    for column, value in optional_columns:
        if value is None:
            continue
        # Column names come from the fixed tuple above, never from callers.
        assignments.append(f"{column} = ?")
        params.append(value)
    params.append(node_id)

    conn.execute(f"UPDATE nodes SET {', '.join(assignments)} WHERE id = ?", params)
    conn.commit()
    return get_node(conn, node_id)
def commit_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
    """
    Promote a branch node to the committed spine.

    Deletes all sibling branch nodes (same parent, not committed) and marks
    the node itself as committed. Both statements run as one transaction: if
    either fails, neither change is kept on the shared connection.

    Returns the node after the update, or None if node_id does not exist.

    NOTE(review): siblings that are already committed are left untouched, and
    the parent's own committed state is not checked — confirm callers
    guarantee a single committed child per parent.
    """
    node = get_node(conn, node_id)
    if node is None:
        return None

    parent_id = node.get("parent_id")
    # `with conn` commits on success and rolls back on an exception, so a
    # failed UPDATE can no longer leave the sibling DELETE pending until some
    # later, unrelated commit().
    with conn:
        if parent_id:
            # Discard non-committed siblings (same parent)
            conn.execute("""
                DELETE FROM nodes
                WHERE parent_id = ? AND id != ? AND is_committed = 0
            """, (parent_id, node_id))
        conn.execute("UPDATE nodes SET is_committed = 1 WHERE id = ?", (node_id,))
    return get_node(conn, node_id)
def delete_node(conn: sqlite3.Connection, node_id: str) -> bool:
    """
    Delete an uncommitted branch node; return True if a row was removed.

    Unknown ids, root nodes (parent_id IS NULL) and committed nodes are left
    alone and yield False.
    """
    lookup = conn.execute(
        "SELECT parent_id FROM nodes WHERE id = ?", (node_id,)
    ).fetchone()
    is_branch = lookup is not None and lookup["parent_id"] is not None
    if not is_branch:
        return False  # missing or root node; refuse

    removed_rows = conn.execute(
        "DELETE FROM nodes WHERE id = ? AND is_committed = 0", (node_id,)
    ).rowcount
    conn.commit()
    return removed_rows > 0
def get_committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """
    Return the chain's committed nodes ordered root → leaf.

    A recursive CTE starts at the committed root (parent_id IS NULL) and
    follows committed children within the same chain; rows are ordered by
    their depth below the root.
    """
    spine_rows = conn.execute("""
        WITH RECURSIVE spine(id, parent_id, depth) AS (
        SELECT id, parent_id, 0
        FROM nodes
        WHERE chain_id = ? AND parent_id IS NULL AND is_committed = 1
        UNION ALL
        SELECT n.id, n.parent_id, spine.depth + 1
        FROM nodes n
        JOIN spine ON n.parent_id = spine.id
        WHERE n.is_committed = 1 AND n.chain_id = ?
        )
        SELECT n.* FROM nodes n
        JOIN spine ON spine.id = n.id
        ORDER BY spine.depth
    """, (chain_id, chain_id)).fetchall()

    spine: list[dict[str, Any]] = []
    for spine_row in spine_rows:
        spine.append(_row_to_dict(spine_row))
    return spine
# ── Helpers ───────────────────────────────────────────────────────────────────
def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
    """Copy a node row into a dict with a boolean is_committed and empty children."""
    base = dict(row)
    return {
        **base,
        "is_committed": bool(base.get("is_committed", 0)),
        "children": [],
    }
def _get_nodes_for_chain(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
    """Return all nodes of a chain as dicts, oldest first (by created_at)."""
    cursor = conn.execute(
        "SELECT * FROM nodes WHERE chain_id = ? ORDER BY created_at", (chain_id,)
    )
    chain_nodes: list[dict[str, Any]] = []
    for node_row in cursor.fetchall():
        chain_nodes.append(_row_to_dict(node_row))
    return chain_nodes

80
app/services/export.py Normal file
View file

@ -0,0 +1,80 @@
# app/services/export.py — Stitch committed spine audio files via ffmpeg
#
# Uses ffmpeg concat demuxer (list-form args, no shell injection).
# WAV output: copy codec. MP3 output: libmp3lame 320k.
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Literal
logger = logging.getLogger(__name__)
ExportFormat = Literal["wav", "mp3"]


def _concat_entry(path: str) -> str:
    """Render one ``file '...'`` line of an ffmpeg concat list for *path*."""
    # The list file lives in the temp directory, so a relative path would be
    # looked up next to it rather than in the working directory.
    absolute = os.path.abspath(path)
    # Inside single quotes ffmpeg takes everything literally; a quote itself
    # is written by closing the quoting, escaping it, and reopening: '\''
    escaped = absolute.replace("'", "'\\''")
    return f"file '{escaped}'\n"


async def stitch(
    audio_paths: list[str],
    output_path: str,
    fmt: ExportFormat = "wav",
) -> str:
    """
    Concatenate audio_paths in order and write the result to output_path.

    A single input that is already in the requested format is copied as-is.
    Everything else goes through the ffmpeg concat demuxer: MP3 output is
    encoded with libmp3lame at 320k; WAV output stream-copies when every
    input is a .wav file and otherwise lets ffmpeg pick its default WAV
    encoder.

    Returns output_path on success.

    Raises:
        ValueError: audio_paths is empty.
        RuntimeError: ffmpeg exited with a non-zero status.

    Requires ffmpeg on PATH.

    NOTE(review): the concat demuxer reads all inputs as one stream, so inputs
    with differing codecs or parameters may still fail — confirm whether
    mixed-format spines need the concat *filter* instead.
    """
    if not audio_paths:
        raise ValueError("No audio paths to stitch.")
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    inputs_match_format = all(
        Path(p).suffix.lower() == f".{fmt}" for p in audio_paths
    )

    # Single file already in the target format: just copy, no ffmpeg needed.
    # (A plain copy of e.g. a .flac source to export.mp3 would be mislabeled.)
    if len(audio_paths) == 1 and inputs_match_format:
        shutil.copy2(audio_paths[0], output_path)
        return output_path

    # Write ffmpeg concat file list to a tempfile
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as flist:
        flist.writelines(_concat_entry(p) for p in audio_paths)
        flist_path = flist.name

    try:
        codec_args: list[str]
        if fmt == "mp3":
            codec_args = ["-c:a", "libmp3lame", "-b:a", "320k"]
        elif inputs_match_format:
            codec_args = ["-c:a", "copy"]
        else:
            # Non-WAV sources cannot be stream-copied into a WAV container.
            codec_args = []

        # create_subprocess_exec uses a list of args — no shell interpolation
        proc = await asyncio.create_subprocess_exec(
            "ffmpeg", "-y",
            "-f", "concat",
            "-safe", "0",
            "-i", flist_path,
            *codec_args,
            output_path,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            # errors="replace": undecodable bytes must not mask the real failure.
            raise RuntimeError(
                f"ffmpeg concat failed (exit {proc.returncode}): "
                f"{stderr.decode(errors='replace')}"
            )
    finally:
        os.unlink(flist_path)

    logger.info("Stitched %d files -> %s", len(audio_paths), output_path)
    return output_path
def make_export_path(data_dir: str, chain_id: str, fmt: ExportFormat) -> str:
    """Return the export file path: {data_dir}/chains/{chain_id}/export.{fmt}."""
    export_file = Path(data_dir, "chains", chain_id, f"export.{fmt}")
    return str(export_file)

157
app/services/musicgen.py Normal file
View file

@ -0,0 +1,157 @@
# app/services/musicgen.py — MusicGen client (cf-orch allocation + /continue call)
#
# Mock mode (CF_MUSICGEN_MOCK=1): copies source file, adds 1s of silence padding.
# Real mode: allocates cf-musicgen from cf-orch, calls POST /continue, releases.
#
# cf_core.musicgen.app is not yet implemented (tracked in cf-core #49).
# Until it is, real mode will fail at allocation time with a clear error.
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import time
import uuid
from pathlib import Path
import httpx
logger = logging.getLogger(__name__)
_ORCH_URL = os.environ.get("CF_ORCH_URL", "").rstrip("/")
_SERVICE = "cf-musicgen"
_MOCK = os.environ.get("CF_MUSICGEN_MOCK", "") == "1"


class MusicGenClient:
    """
    Allocates a cf-musicgen instance from cf-orch and calls POST /continue.

    Each generate() call allocates, generates, and releases. For Premium tier
    session-held allocations, subclass and override _allocate/_release.
    """

    async def generate(
        self,
        source_audio_path: str,
        output_path: str,
        prompt: str,
        duration_s: float,
        cfg_coef: float,
        prompt_duration_s: float,
    ) -> float:
        """
        Generate a continuation of source_audio_path.

        Writes the result to output_path and returns the actual duration_s.
        In mock mode (CF_MUSICGEN_MOCK=1) no allocation happens and the mock
        generator is used instead. Otherwise the allocation is released again
        whether or not generation succeeds.

        Raises RuntimeError on cf-orch allocation failure; HTTP errors from
        cf-orch or the service propagate as raised by the HTTP client.
        """
        if _MOCK:
            return await _mock_generate(source_audio_path, output_path, duration_s)

        service_url = await self._allocate()
        try:
            return await _call_continue(
                service_url=service_url,
                audio_path=source_audio_path,
                output_path=output_path,
                prompt=prompt,
                duration_s=duration_s,
                cfg_coef=cfg_coef,
                prompt_duration_s=prompt_duration_s,
            )
        finally:
            await self._release(service_url)

    async def _allocate(self) -> str:
        """
        Request a cf-musicgen allocation from cf-orch. Returns the service URL.

        Raises RuntimeError when CF_ORCH_URL is unset, when cf-orch answers
        503 (no capacity), or when the response carries neither "url" nor
        "service_url".
        """
        if not _ORCH_URL:
            raise RuntimeError(
                "CF_ORCH_URL is not configured. Set it in .env or use CF_MUSICGEN_MOCK=1."
            )
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                json={"requester": "sparrow"},
            )
        if resp.status_code == 503:
            raise RuntimeError("No cf-musicgen capacity available — all GPUs busy.")
        resp.raise_for_status()
        data = resp.json()
        url = data.get("url") or data.get("service_url")
        if not url:
            raise RuntimeError(f"cf-orch allocation response missing URL: {data}")
        logger.info("Allocated cf-musicgen at %s", url)
        return url

    async def _release(self, service_url: str) -> None:
        """
        Release the cf-musicgen allocation back to cf-orch.

        Best-effort: any failure is logged as a warning and not raised, so a
        release problem never masks the outcome of generation.
        """
        if not _ORCH_URL:
            return
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # httpx's delete() helper accepts no request body; a DELETE
                # with a JSON payload has to go through request().
                resp = await client.request(
                    "DELETE",
                    f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                    json={"url": service_url},
                )
                resp.raise_for_status()
        except Exception as exc:
            logger.warning("Failed to release cf-musicgen allocation: %s", exc)
# ── Real /continue call ───────────────────────────────────────────────────────
async def _call_continue(
    service_url: str,
    audio_path: str,
    output_path: str,
    prompt: str,
    duration_s: float,
    cfg_coef: float,
    prompt_duration_s: float,
) -> float:
    """
    POST a continuation request to the allocated cf-musicgen service.

    The JSON body carries the source/output paths and the generation
    parameters. Returns the "duration_s" value from the JSON response as a
    float, falling back to the requested duration_s when the key is absent.
    An error status is raised via raise_for_status().
    """
    endpoint = f"{service_url.rstrip('/')}/continue"
    request_body = dict(
        audio_path=audio_path,
        output_path=output_path,
        prompt=prompt,
        duration_s=duration_s,
        cfg_coef=cfg_coef,
        prompt_duration_s=prompt_duration_s,
    )
    # Generation is slow (on the order of 30-120s depending on duration and
    # hardware), hence the generous timeout.
    async with httpx.AsyncClient(timeout=180.0) as http:
        reply = await http.post(endpoint, json=request_body)
        reply.raise_for_status()
        result = reply.json()
    reported = result.get("duration_s", duration_s)
    return float(reported)
# ── Mock generation ───────────────────────────────────────────────────────────
async def _mock_generate(
    source_audio_path: str,
    output_path: str,
    duration_s: float,
) -> float:
    """
    Mock: copy the source file to output_path after a short simulated delay.

    The delay simulates generation latency so the UI state machine
    (pending -> generating -> ready) exercises all transitions during
    development.

    Returns the duration of the copied file as probed by torchaudio, or the
    requested duration_s when torchaudio is unavailable or cannot read the file.
    """
    await asyncio.sleep(2.0)  # simulate generation time
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source_audio_path, output_path)
    # Separator added between the two paths; without it they ran together.
    logger.info("Mock generation: copied %s -> %s", source_audio_path, output_path)
    # Return the actual duration from the copied file
    try:
        # Optional dependency, imported lazily so mock mode works without it.
        import torchaudio

        info = torchaudio.info(output_path)
        return info.num_frames / info.sample_rate
    except Exception:
        # Deliberately broad: a missing torchaudio or an unreadable file both
        # fall back to the requested duration.
        return duration_s
def make_output_path(data_dir: str, chain_id: str, node_id: str) -> str:
    """Return <data_dir>/chains/<chain_id>/nodes/<node_id>.wav as a string."""
    nodes_dir = Path(data_dir).joinpath("chains", chain_id, "nodes")
    return str(nodes_dir / f"{node_id}.wav")

75
app/services/stems.py Normal file
View file

@ -0,0 +1,75 @@
# app/services/stems.py — Demucs 4-stem separation subprocess wrapper
#
# Mock mode (CF_STEMS_MOCK=1): copies source file to all 4 stem paths.
# Real mode: runs demucs htdemucs model via asyncio.create_subprocess_exec
# (list-form, no shell=True — not vulnerable to injection).
#
# Demucs output layout: {output_dir}/htdemucs/{track.stem}/{stem}.wav
from __future__ import annotations
import asyncio
import logging
import os
import shutil
from pathlib import Path
logger = logging.getLogger(__name__)
# Mock mode is enabled only when CF_STEMS_MOCK is exactly "1". The variable is
# read once, at import time.
_MOCK = os.environ.get("CF_STEMS_MOCK", "") == "1"
# Stem names; each maps to a "<stem>.wav" file in the output directory.
_STEMS = ("vocals", "drums", "bass", "other")
async def separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
    """
    Separate source_audio_path into the four stems listed in _STEMS.

    Dispatches to the mock implementation when mock mode is on, otherwise to
    the real demucs runner. Returns a mapping of stem name to file path.
    Raises RuntimeError on demucs failure.
    """
    backend = _mock_separate if _MOCK else _real_separate
    return await backend(source_audio_path, output_dir)
async def _real_separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
    """
    Run demucs (htdemucs model) on source_audio_path as a subprocess.

    Returns a mapping of stem name to the path demucs is expected to write
    under output_dir. Raises RuntimeError when demucs exits non-zero or when
    an expected stem file is missing afterwards.
    """
    # Function-scope import keeps this block self-contained.
    import sys

    track = Path(source_audio_path)
    logger.info("Running demucs on %s -> %s", track.name, output_dir)
    # create_subprocess_exec takes a list of args — no shell interpolation
    proc = await asyncio.create_subprocess_exec(
        # Use the interpreter running this app so demucs is resolved from the
        # same environment; a bare "python" may be a different one or absent.
        sys.executable, "-m", "demucs",
        # The demucs CLI selects the pretrained model with -n/--name.
        "-n", "htdemucs",
        "--out", output_dir,
        str(track),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        # errors="replace" so undecodable tool output cannot hide the failure.
        detail = stderr.decode(errors="replace")
        raise RuntimeError(f"demucs failed (exit {proc.returncode}): {detail}")

    paths = _stem_paths(output_dir, track.stem)
    missing = [name for name, path in paths.items() if not Path(path).is_file()]
    if missing:
        raise RuntimeError(
            f"demucs finished but did not produce stems: {', '.join(missing)}"
        )
    return paths
async def _mock_separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
    """
    Mock separation: after a short pause, copy the source file once per stem.

    Files are written to <output_dir>/htdemucs/<source stem>/<stem>.wav and
    the mapping of stem name to path is returned.
    """
    await asyncio.sleep(1.5)  # stand-in for separation latency
    source = Path(source_audio_path)
    target_dir = Path(output_dir).joinpath("htdemucs", source.stem)
    target_dir.mkdir(parents=True, exist_ok=True)

    result: dict[str, str] = {}
    for name in _STEMS:
        copy_to = str(target_dir / f"{name}.wav")
        shutil.copy2(source_audio_path, copy_to)
        result[name] = copy_to

    logger.info("Mock stems: copied %s -> %s", source.name, target_dir)
    return result
def _stem_paths(output_dir: str, track_stem: str) -> dict[str, str]:
    """Map each stem name to <output_dir>/htdemucs/<track_stem>/<stem>.wav."""
    base = Path(output_dir, "htdemucs", track_stem)
    mapping: dict[str, str] = {}
    for name in _STEMS:
        mapping[name] = str(base / f"{name}.wav")
    return mapping
def make_stems_dir(data_dir: str, chain_id: str, node_id: str) -> str:
    """Return <data_dir>/chains/<chain_id>/stems/<node_id> as a string."""
    stems_root = Path(data_dir).joinpath("chains", chain_id, "stems")
    return str(stems_root / node_id)

52
app/tiers.py Normal file
View file

@ -0,0 +1,52 @@
# app/tiers.py — tier gates for Sparrow
#
# BYOK does not apply to Sparrow — "bring your own" means bring your own
# hardware (self-hosting). is_self_hosted() checks the cf-orch URL target.
from __future__ import annotations
import os
from fastapi import HTTPException
def is_self_hosted() -> bool:
    """
    True when cf-orch is on localhost or a private/loopback IP address.

    Self-hosted users get unrestricted access equivalent to Paid tier.
    Cloud users are gated by the standard tier model.

    CF_ORCH_URL is read on every call. An unset or empty value counts as
    local dev mode (self-hosted). The host is taken from the parsed URL and
    qualifies only if it is the name "localhost" or an IP literal in a
    loopback or private range. Any other DNS name, and any URL that cannot
    be parsed, is treated as not self-hosted so the tier gate stays closed.
    """
    # Function-scope imports keep this block self-contained.
    import ipaddress
    from urllib.parse import urlsplit

    orch_url = os.environ.get("CF_ORCH_URL", "").strip()
    if not orch_url:
        return True  # no orch = local dev mode = self-hosted

    # urlsplit only recognises a host after a "//" authority marker, so add
    # one for bare "host:port" values.
    if "://" not in orch_url and not orch_url.startswith("//"):
        orch_url = f"//{orch_url}"
    try:
        # .hostname drops userinfo, port and IPv6 brackets, and lowercases.
        host = urlsplit(orch_url).hostname
    except ValueError:
        return False
    if not host:
        return False
    if host == "localhost":
        return True

    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        # A DNS name such as "10.example.com": never matched by prefix.
        return False
    # Range checks rather than string prefixes: only 172.16.0.0/12 is
    # private, not every address that starts with "172.".
    return addr.is_loopback or addr.is_private
def require_tier(feature: str, tier: str | None = None) -> None:
    """
    Gate a feature by tier; raises HTTPException(402) if the caller's tier is too low.

    Self-hosted instances pass unconditionally. tier=None (the request
    carries no tier) and unrecognised tier names are ranked as Free. A feature
    with no entry in the minimum-tier table requires only Free.
    """
    if is_self_hosted():
        return

    rank_of = {"free": 0, "paid": 1, "premium": 2}
    minimum_for: dict[str, str] = {
        "stems": "paid",
        "parallel_branch": "paid",
        "session_allocation": "premium",
        "priority_queue": "premium",
    }

    min_tier = minimum_for.get(feature, "free")
    have = rank_of.get(tier or "free", 0)
    need = rank_of[min_tier]
    if have >= need:
        return
    raise HTTPException(
        status_code=402,
        detail=f"Feature '{feature}' requires {min_tier} tier.",
    )

65
tests/conftest.py Normal file
View file

@ -0,0 +1,65 @@
# tests/conftest.py — shared fixtures for Sparrow test suite
from __future__ import annotations
import os
import tempfile
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
# Use mock mode for all tests — no GPU or ffmpeg required.
# setdefault keeps any value already present in the environment. The service
# modules read these flags when they are imported, so the defaults are set
# here at conftest import, before the tests import them.
os.environ.setdefault("CF_MUSICGEN_MOCK", "1")
os.environ.setdefault("CF_STEMS_MOCK", "1")
@pytest.fixture
def tmp_db(tmp_path):
    """Per-test SQLite database path (string) under tmp_path; the file is not created here."""
    db_file = tmp_path.joinpath("test.db")
    return str(db_file)
@pytest.fixture
def tmp_data(tmp_path):
    """Create a per-test "data" directory under tmp_path and return its path as a string."""
    data_root = tmp_path.joinpath("data")
    data_root.mkdir()
    return str(data_root)
@pytest.fixture
def conn(tmp_db):
    """Open a SQLite connection on tmp_db, apply migrations, and close it after the test."""
    from app.db import store

    connection = store.get_connection(tmp_db)
    store.run_migrations(connection)
    yield connection
    connection.close()
@pytest.fixture
def app(tmp_db, tmp_data, monkeypatch):
    """
    FastAPI test app with isolated DB and data dir.

    The SPARROW_* variables are set through monkeypatch so they are restored
    when the test finishes, rather than leaking paths to deleted temp
    directories into later tests. They remain set for the whole test,
    including while fixtures that depend on this one are active.
    """
    monkeypatch.setenv("SPARROW_DB_PATH", tmp_db)
    monkeypatch.setenv("SPARROW_DATA_DIR", tmp_data)
    monkeypatch.setenv("SPARROW_ENV", "development")
    # Imported after the environment is prepared.
    from app.main import create_app

    return create_app()
@pytest.fixture
def client(app):
    """Synchronous TestClient, used as a context manager for the duration of the test."""
    with TestClient(app, raise_server_exceptions=True) as test_client:
        yield test_client
@pytest_asyncio.fixture
async def async_client(app):
    """
    Async httpx client that talks to the app in-process over ASGI.

    NOTE(review): httpx's ASGITransport does not run ASGI lifespan events —
    confirm that tests using this fixture do not rely on startup state.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        yield http

69
tests/test_api_chains.py Normal file
View file

@ -0,0 +1,69 @@
# tests/test_api_chains.py — integration tests for chain + upload endpoints
from __future__ import annotations
import io
def test_create_and_list_chains(client):
    """POST creates a chain (201, name and id in the body); GET then lists exactly one."""
    created = client.post("/api/chains/", json={"name": "my chain"})
    assert created.status_code == 201
    body = created.json()
    assert body["name"] == "my chain"
    assert "id" in body

    listing = client.get("/api/chains/")
    assert listing.status_code == 200
    assert len(listing.json()) == 1
def test_get_chain_not_found(client):
    """Fetching an unknown chain id yields 404."""
    missing = client.get("/api/chains/doesnotexist")
    assert missing.status_code == 404
def test_delete_chain(client):
    """DELETE returns 204 and the chain can no longer be fetched."""
    chain_id = client.post("/api/chains/", json={"name": "to delete"}).json()["id"]
    chain_url = f"/api/chains/{chain_id}"

    deleted = client.delete(chain_url)
    assert deleted.status_code == 204

    refetched = client.get(chain_url)
    assert refetched.status_code == 404
def test_upload_root_node(client, tmp_path):
    """Uploading audio to a chain creates a ready, committed root node (no parent)."""
    chain_id = client.post("/api/chains/", json={"name": "upload test"}).json()["id"]

    # 44-byte payload that starts with the RIFF magic and is otherwise zeros.
    # It is only a stand-in for a WAV file; the test expects the upload to
    # succeed even if the duration probe cannot read it.
    payload = io.BytesIO(b"RIFF" + b"\x00" * 40)
    uploaded = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("test.wav", payload, "audio/wav")},
    )
    assert uploaded.status_code == 201

    root = uploaded.json()
    assert root["status"] == "ready"
    assert root["is_committed"] is True
    assert root["parent_id"] is None
def test_upload_unsupported_format(client):
    """A .txt upload is rejected with 422."""
    chain_id = client.post("/api/chains/", json={"name": "format test"}).json()["id"]
    not_audio = ("track.txt", io.BytesIO(b"not audio"), "text/plain")

    rejected = client.post(f"/api/chains/{chain_id}/upload", files={"file": not_audio})
    assert rejected.status_code == 422
def test_upload_chain_not_found(client):
    """Uploading to an unknown chain id yields 404."""
    empty_wav = ("test.wav", io.BytesIO(b""), "audio/wav")
    outcome = client.post("/api/chains/nonexistent/upload", files={"file": empty_wav})
    assert outcome.status_code == 404

102
tests/test_api_nodes.py Normal file
View file

@ -0,0 +1,102 @@
# tests/test_api_nodes.py — integration tests for branch/commit/discard endpoints
from __future__ import annotations
import io
def _create_chain_with_root(client, tmp_path=None):
    """Helper: create a chain, upload a root node, and return (chain_id, root_node_id)."""
    chain_id = client.post("/api/chains/", json={"name": "test chain"}).json()["id"]

    root_upload = ("root.wav", io.BytesIO(b"RIFF" + b"\x00" * 40), "audio/wav")
    uploaded = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": root_upload},
    )
    root_id = uploaded.json()["id"]
    return chain_id, root_id
def test_create_branch_returns_202(client, tmp_path):
    """Branching from the root answers 202 with a pending child that echoes the prompt."""
    _chain_id, root_id = _create_chain_with_root(client)
    request_body = {"prompt": "upbeat jazz", "duration_s": 10.0}

    accepted = client.post(f"/api/nodes/{root_id}/branch", json=request_body)
    assert accepted.status_code == 202

    child = accepted.json()
    assert child["status"] == "pending"
    assert child["parent_id"] == root_id
    assert child["prompt"] == "upbeat jazz"
def test_branch_parent_not_found(client):
    """Branching from an unknown parent id yields 404."""
    outcome = client.post("/api/nodes/nonexistent/branch", json={"prompt": "test"})
    assert outcome.status_code == 404
def test_commit_requires_ready_status(client, conn):
    """
    Committing a node that is still pending is refused with 409.

    The branch is inserted through the service layer rather than the API, so
    no background generation is scheduled and the status stays "pending".
    """
    from app.services import chain as svc

    chain_id, root_id = _create_chain_with_root(client)
    pending = svc.create_branch_node(
        conn, root_id, chain_id, "test", None, None, None, 3.0, 10.0
    )
    pending_id = pending["id"]

    outcome = client.post(f"/api/nodes/{pending_id}/commit")
    assert outcome.status_code == 409
def test_delete_root_node_refused(client):
    """Deleting the root node is refused with 409."""
    _chain_id, root_id = _create_chain_with_root(client)
    outcome = client.delete(f"/api/nodes/{root_id}")
    assert outcome.status_code == 409
def test_delete_node_not_found(client):
    """
    Deleting an unknown node id answers 409.

    NOTE(review): 409 (not 404) is what this test pins for a missing node —
    confirm that is the intended API contract.
    """
    outcome = client.delete("/api/nodes/doesnotexist")
    assert outcome.status_code == 409
def test_audio_not_found_for_pending_node(client, conn):
    """The audio endpoint answers 404 for a branch created directly via the service layer."""
    from app.services import chain as svc

    chain_id, root_id = _create_chain_with_root(client)
    pending = svc.create_branch_node(
        conn, root_id, chain_id, "test", None, None, None, 3.0, 10.0
    )
    audio_url = f"/api/nodes/{pending['id']}/audio"

    outcome = client.get(audio_url)
    assert outcome.status_code == 404
def test_mock_generation_completes(client, conn):
    """
    With CF_MUSICGEN_MOCK=1 a branch created through the API ends up "ready".

    The test relies on TestClient finishing BackgroundTasks before
    client.post() returns, then reads the final node state through the
    separate conn fixture.
    """
    from app.services import chain as svc

    _chain_id, root_id = _create_chain_with_root(client)
    accepted = client.post(
        f"/api/nodes/{root_id}/branch",
        json={"prompt": "test mock", "duration_s": 5.0},
    )
    assert accepted.status_code == 202

    # The background task has already run, so the stored node should be ready.
    node = svc.get_node(conn, accepted.json()["id"])
    assert node is not None
    assert node["status"] == "ready", (
        f"Expected ready, got: {node['status']}, error: {node.get('error_msg')}"
    )

150
tests/test_chain_service.py Normal file
View file

@ -0,0 +1,150 @@
# tests/test_chain_service.py — unit tests for chain/node CRUD service
from __future__ import annotations
import pytest
from app.services import chain as svc
def test_create_and_list_chain(conn):
    """A new chain has the given name and zero nodes, and is the only one listed."""
    created = svc.create_chain(conn, "test chain")
    assert created["name"] == "test chain"
    assert created["node_count"] == 0

    listed = svc.list_chains(conn)
    assert len(listed) == 1
    assert listed[0]["id"] == created["id"]
def test_get_chain_not_found(conn):
    """get_chain returns None for an unknown id."""
    found = svc.get_chain(conn, "nonexistent")
    assert found is None
def test_delete_chain(conn):
    """delete_chain returns a truthy value and the chain is gone afterwards."""
    chain_id = svc.create_chain(conn, "deleteme")["id"]
    assert svc.delete_chain(conn, chain_id)
    assert svc.get_chain(conn, chain_id) is None
def test_create_root_node(conn, tmp_path):
    """A root node is ready, committed, parentless, and keeps the given path and duration."""
    chain = svc.create_chain(conn, "root test")
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")  # empty file, just needs to exist
    audio = str(audio_file)

    root = svc.create_root_node(conn, chain["id"], audio, 30.0)

    assert root["status"] == "ready"
    assert root["is_committed"] is True
    assert root["parent_id"] is None
    assert root["audio_path"] == audio
    assert root["duration_s"] == 30.0
def test_create_branch_node(conn, tmp_path):
    """A new branch starts pending and uncommitted, linked to its parent, with its prompt stored."""
    chain = svc.create_chain(conn, "branch test")
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")
    root = svc.create_root_node(conn, chain["id"], str(audio_file), 30.0)

    branch_args = dict(
        parent_id=root["id"],
        chain_id=chain["id"],
        prompt="upbeat jazz",
        energy=0.8,
        tempo_feel=None,
        density=None,
        cfg_coef=3.0,
        prompt_duration_s=10.0,
    )
    child = svc.create_branch_node(conn, **branch_args)

    assert child["status"] == "pending"
    assert child["is_committed"] is False
    assert child["parent_id"] == root["id"]
    assert child["prompt"] == "upbeat jazz"
def test_update_node_status(conn, tmp_path):
    """update_node_status moves a branch to generating, then to ready with audio path and duration."""
    chain = svc.create_chain(conn, "status test")
    root_file = tmp_path / "audio.wav"
    root_file.write_bytes(b"")
    root = svc.create_root_node(conn, chain["id"], str(root_file), 30.0)
    branch_id = svc.create_branch_node(
        conn, root["id"], chain["id"], "test", None, None, None, 3.0, 10.0
    )["id"]

    in_progress = svc.update_node_status(conn, branch_id, "generating")
    assert in_progress["status"] == "generating"

    rendered = tmp_path / "out.wav"
    rendered.write_bytes(b"")
    out_path = str(rendered)
    finished = svc.update_node_status(
        conn, branch_id, "ready", audio_path=out_path, duration_s=15.0
    )
    assert finished["status"] == "ready"
    assert finished["audio_path"] == out_path
    assert finished["duration_s"] == 15.0
def test_commit_node_discards_siblings(conn, tmp_path):
    """Committing one of two ready sibling branches marks it committed and removes the other."""
    chain = svc.create_chain(conn, "commit test")
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")
    audio = str(audio_file)
    root = svc.create_root_node(conn, chain["id"], audio, 30.0)

    first = svc.create_branch_node(
        conn, root["id"], chain["id"], "branch 1", None, None, None, 3.0, 10.0
    )
    second = svc.create_branch_node(
        conn, root["id"], chain["id"], "branch 2", None, None, None, 3.0, 10.0
    )

    # Both siblings must be ready before one can be committed.
    svc.update_node_status(conn, first["id"], "ready", audio_path=audio, duration_s=15.0)
    svc.update_node_status(conn, second["id"], "ready", audio_path=audio, duration_s=15.0)

    winner = svc.commit_node(conn, first["id"])
    assert winner["is_committed"] is True
    assert svc.get_node(conn, second["id"]) is None
def test_delete_node_refuses_root(conn, tmp_path):
    """delete_node returns False for the root node."""
    chain = svc.create_chain(conn, "del test")
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")
    root = svc.create_root_node(conn, chain["id"], str(audio_file), 30.0)

    outcome = svc.delete_node(conn, root["id"])
    assert outcome is False
def test_delete_node_refuses_committed(conn, tmp_path):
    """delete_node returns False for a branch that has been committed."""
    chain = svc.create_chain(conn, "del committed")
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")
    audio = str(audio_file)
    root = svc.create_root_node(conn, chain["id"], audio, 30.0)

    branch_id = svc.create_branch_node(
        conn, root["id"], chain["id"], "", None, None, None, 3.0, 10.0
    )["id"]
    svc.update_node_status(conn, branch_id, "ready", audio_path=audio, duration_s=15.0)
    svc.commit_node(conn, branch_id)

    outcome = svc.delete_node(conn, branch_id)
    assert outcome is False
def test_committed_spine_order(conn, tmp_path):
    """The committed spine lists the root first, then its committed child."""
    chain = svc.create_chain(conn, "spine test")
    audio_file = tmp_path / "audio.wav"
    audio_file.write_bytes(b"")
    audio = str(audio_file)
    root = svc.create_root_node(conn, chain["id"], audio, 30.0)

    child = svc.create_branch_node(
        conn, root["id"], chain["id"], "child", None, None, None, 3.0, 10.0
    )
    svc.update_node_status(conn, child["id"], "ready", audio_path=audio, duration_s=15.0)
    svc.commit_node(conn, child["id"])

    spine = svc.get_committed_spine(conn, chain["id"])
    assert len(spine) == 2
    first, second = spine[0], spine[1]
    assert first["id"] == root["id"]
    assert second["id"] == child["id"]

View file

@ -0,0 +1,34 @@
# tests/test_export_service.py — unit tests for export stitch service
from __future__ import annotations
import pytest
@pytest.mark.asyncio
async def test_stitch_single_file(tmp_path):
    """Stitching one segment returns the output path, and the output has the segment's bytes."""
    from app.services.export import stitch

    segment = tmp_path / "segment.wav"
    segment.write_bytes(b"WAV DATA")
    out = str(tmp_path / "output.wav")

    returned = await stitch([str(segment)], out)

    assert returned == out
    with open(out, "rb") as stitched:
        assert stitched.read() == b"WAV DATA"
@pytest.mark.asyncio
async def test_stitch_empty_raises(tmp_path):
    """stitch() with no input paths raises ValueError mentioning "No audio paths"."""
    from app.services.export import stitch

    destination = str(tmp_path / "out.wav")
    with pytest.raises(ValueError, match="No audio paths"):
        await stitch([], destination)
def test_make_export_path():
    """make_export_path builds <data_dir>/chains/<chain_id>/export.<fmt>."""
    from app.services.export import make_export_path

    mp3_path = make_export_path("data", "chain-abc", "mp3")
    assert mp3_path == "data/chains/chain-abc/export.mp3"

    wav_path = make_export_path("data", "chain-abc", "wav")
    assert wav_path.endswith(".wav")

View file

@ -0,0 +1,36 @@
# tests/test_musicgen_service.py — unit tests for MusicGenClient mock mode
from __future__ import annotations
import os
import pytest
# Ensure mock mode. app.services.musicgen reads CF_MUSICGEN_MOCK when it is
# imported, so this assignment only takes effect if it runs before that
# module's first import (the tests below import it inside the function body).
os.environ["CF_MUSICGEN_MOCK"] = "1"
@pytest.mark.asyncio
async def test_mock_generate_copies_file(tmp_path):
    """In mock mode generate() writes the output file and returns a non-negative duration."""
    from app.services.musicgen import MusicGenClient

    source = tmp_path / "source.wav"
    source.write_bytes(b"RIFF" + b"\x00" * 40)
    output = str(tmp_path / "output.wav")

    request = dict(
        source_audio_path=str(source),
        output_path=output,
        prompt="test",
        duration_s=10.0,
        cfg_coef=3.0,
        prompt_duration_s=5.0,
    )
    duration = await MusicGenClient().generate(**request)

    assert os.path.exists(output)
    # The fake WAV may not be parseable, so only the lower bound is checked.
    assert duration >= 0.0
def test_make_output_path():
    """make_output_path builds <data_dir>/chains/<chain_id>/nodes/<node_id>.wav."""
    from app.services.musicgen import make_output_path

    built = make_output_path("data", "chain-abc", "node-xyz")
    assert built == "data/chains/chain-abc/nodes/node-xyz.wav"