feat: implement Sparrow backend (v0.1.0)
Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service units and API integration
This commit is contained in:
parent
c3e18e83e5
commit
a6f60c9e07
23 changed files with 1760 additions and 0 deletions
17
app/api/deps.py
Normal file
17
app/api/deps.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# app/api/deps.py — FastAPI dependency providers
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def get_conn(request: Request) -> sqlite3.Connection:
|
||||
"""Yield the shared SQLite connection from app state."""
|
||||
return request.app.state.conn
|
||||
|
||||
|
||||
def get_data_dir(request: Request) -> str:
|
||||
"""Return the configured data directory."""
|
||||
return request.app.state.data_dir
|
||||
90
app/api/endpoints/chains.py
Normal file
90
app/api/endpoints/chains.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
# app/api/endpoints/chains.py — chain CRUD + root audio upload
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
|
||||
from app.api.deps import get_conn, get_data_dir
|
||||
from app.models.schemas.chain import ChainCreate, ChainRow, ChainTree
|
||||
from app.models.schemas.node import NodeRow
|
||||
from app.services import chain as chain_svc
|
||||
|
||||
router = APIRouter(prefix="/api/chains", tags=["chains"])
|
||||
|
||||
_ALLOWED_AUDIO = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aiff"}
|
||||
|
||||
|
||||
@router.post("/", response_model=ChainRow, status_code=201)
|
||||
def create_chain(body: ChainCreate, conn=Depends(get_conn)) -> dict:
|
||||
return chain_svc.create_chain(conn, body.name)
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ChainRow])
|
||||
def list_chains(conn=Depends(get_conn)) -> list:
|
||||
return chain_svc.list_chains(conn)
|
||||
|
||||
|
||||
@router.get("/{chain_id}", response_model=ChainTree)
|
||||
def get_chain(chain_id: str, conn=Depends(get_conn)) -> dict:
|
||||
chain = chain_svc.get_chain(conn, chain_id)
|
||||
if chain is None:
|
||||
raise HTTPException(status_code=404, detail="Chain not found.")
|
||||
return chain
|
||||
|
||||
|
||||
@router.delete("/{chain_id}", status_code=204)
|
||||
def delete_chain(chain_id: str, conn=Depends(get_conn)) -> None:
|
||||
if not chain_svc.delete_chain(conn, chain_id):
|
||||
raise HTTPException(status_code=404, detail="Chain not found.")
|
||||
|
||||
|
||||
@router.post("/{chain_id}/upload", response_model=NodeRow, status_code=201)
|
||||
async def upload_root(
|
||||
chain_id: str,
|
||||
file: UploadFile = File(...),
|
||||
conn=Depends(get_conn),
|
||||
data_dir: str = Depends(get_data_dir),
|
||||
) -> dict:
|
||||
"""
|
||||
Upload source audio and create the root node for a chain.
|
||||
|
||||
Accepts WAV, MP3, FLAC, OGG, M4A, AIFF. The file is stored in
|
||||
data/chains/{chain_id}/nodes/root_{node_id}{ext} and a root node
|
||||
(always committed) is created in the DB.
|
||||
"""
|
||||
chain = chain_svc.get_chain(conn, chain_id)
|
||||
if chain is None:
|
||||
raise HTTPException(status_code=404, detail="Chain not found.")
|
||||
|
||||
suffix = Path(file.filename or "audio.wav").suffix.lower()
|
||||
if suffix not in _ALLOWED_AUDIO:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Unsupported audio format '{suffix}'. "
|
||||
f"Allowed: {', '.join(sorted(_ALLOWED_AUDIO))}",
|
||||
)
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
dest_dir = Path(data_dir) / "chains" / chain_id / "nodes"
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
audio_path = str(dest_dir / f"root_{node_id}{suffix}")
|
||||
|
||||
contents = await file.read()
|
||||
Path(audio_path).write_bytes(contents)
|
||||
|
||||
# Get duration via torchaudio if available, else 0.0
|
||||
duration_s = _probe_duration(audio_path)
|
||||
|
||||
return chain_svc.create_root_node(conn, chain_id, audio_path, duration_s)
|
||||
|
||||
|
||||
def _probe_duration(path: str) -> float:
|
||||
try:
|
||||
import torchaudio
|
||||
info = torchaudio.info(path)
|
||||
return info.num_frames / info.sample_rate
|
||||
except Exception:
|
||||
return 0.0
|
||||
54
app/api/endpoints/events.py
Normal file
54
app/api/endpoints/events.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# app/api/endpoints/events.py — SSE stream for node status transitions
|
||||
#
|
||||
# Clients connect to GET /api/chains/{chain_id}/events and receive
|
||||
# server-sent events whenever a node in that chain changes status.
|
||||
#
|
||||
# Event format:
|
||||
# event: node-status
|
||||
# data: {"node_id": "...", "status": "generating"|"ready"|"error", ...}
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.deps import get_conn
|
||||
from app.api.events_store import subscribe, unsubscribe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/chains", tags=["events"])
|
||||
|
||||
_KEEPALIVE_S = 15 # send a comment ping every N seconds to keep the connection alive
|
||||
|
||||
|
||||
@router.get("/{chain_id}/events")
|
||||
async def chain_events(chain_id: str, conn=Depends(get_conn)) -> EventSourceResponse:
|
||||
"""
|
||||
SSE stream for node status transitions in a chain.
|
||||
|
||||
Emits 'node-status' events. Closes when the client disconnects.
|
||||
"""
|
||||
q = subscribe(chain_id)
|
||||
|
||||
async def generator():
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
event = await asyncio.wait_for(q.get(), timeout=_KEEPALIVE_S)
|
||||
yield {
|
||||
"event": "node-status",
|
||||
"data": json.dumps(event),
|
||||
}
|
||||
except asyncio.TimeoutError:
|
||||
# Send keepalive comment to prevent proxy/browser timeout
|
||||
yield {"comment": "keepalive"}
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
unsubscribe(chain_id, q)
|
||||
logger.debug("SSE client disconnected from chain %s", chain_id)
|
||||
|
||||
return EventSourceResponse(generator())
|
||||
64
app/api/endpoints/export.py
Normal file
64
app/api/endpoints/export.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# app/api/endpoints/export.py — stitch committed spine and serve as download
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.api.deps import get_conn, get_data_dir
|
||||
from app.models.schemas.node import ExportRequest
|
||||
from app.services import chain as chain_svc
|
||||
from app.services.export import ExportFormat, make_export_path, stitch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/chains", tags=["export"])
|
||||
|
||||
|
||||
@router.post("/{chain_id}/export")
|
||||
async def export_chain(
|
||||
chain_id: str,
|
||||
body: ExportRequest,
|
||||
conn=Depends(get_conn),
|
||||
data_dir: str = Depends(get_data_dir),
|
||||
) -> FileResponse:
|
||||
"""
|
||||
Stitch node audio files into a single download.
|
||||
|
||||
If body.node_ids is provided, those nodes are used in order.
|
||||
Otherwise the committed spine (root → leaf) is used.
|
||||
Raises 409 if any targeted node has no audio.
|
||||
"""
|
||||
if body.node_ids is not None:
|
||||
nodes = [chain_svc.get_node(conn, nid) for nid in body.node_ids]
|
||||
if any(n is None for n in nodes):
|
||||
raise HTTPException(status_code=404, detail="One or more nodes not found.")
|
||||
else:
|
||||
nodes = chain_svc.get_committed_spine(conn, chain_id)
|
||||
if not nodes:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Chain not found or has no committed nodes.",
|
||||
)
|
||||
|
||||
audio_paths = [n["audio_path"] for n in nodes if n] # type: ignore[index]
|
||||
missing = [n["id"] for n in nodes if n and n["audio_path"] is None] # type: ignore[index]
|
||||
if missing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Nodes {missing} have no audio yet.",
|
||||
)
|
||||
|
||||
fmt: ExportFormat = body.format # type: ignore[assignment]
|
||||
output_path = make_export_path(data_dir, chain_id, fmt)
|
||||
|
||||
try:
|
||||
await stitch(audio_paths, output_path, fmt)
|
||||
except Exception as exc:
|
||||
logger.exception("Export failed for chain %s: %s", chain_id, exc)
|
||||
raise HTTPException(status_code=500, detail=f"Export failed: {exc}") from exc
|
||||
|
||||
media_type = "audio/wav" if fmt == "wav" else "audio/mpeg"
|
||||
filename = f"sparrow_export_{chain_id[:8]}.{fmt}"
|
||||
return FileResponse(output_path, media_type=media_type, filename=filename)
|
||||
73
app/api/endpoints/gpu.py
Normal file
73
app/api/endpoints/gpu.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
# app/api/endpoints/gpu.py — cf-orch GPU status and session allocation
|
||||
#
|
||||
# GET /api/gpu/status — available capacity from cf-orch
|
||||
# POST /api/gpu/connect — session-held allocation (Premium tier, stub)
|
||||
# DELETE /api/gpu/disconnect — release session allocation (Premium tier, stub)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/gpu", tags=["gpu"])
|
||||
|
||||
_ORCH_URL = os.environ.get("CF_ORCH_URL", "").rstrip("/")
|
||||
_SERVICE = "cf-musicgen"
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def gpu_status() -> dict:
|
||||
"""
|
||||
Return current cf-orch capacity for cf-musicgen.
|
||||
|
||||
Returns {"available": False, "reason": "..."} if cf-orch is unreachable
|
||||
or unconfigured (e.g. mock mode).
|
||||
"""
|
||||
if not _ORCH_URL:
|
||||
return {
|
||||
"available": False,
|
||||
"reason": "CF_ORCH_URL not configured — running in mock mode.",
|
||||
"mock": True,
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(
|
||||
f"{_ORCH_URL}/api/services/{_SERVICE}/status"
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"cf-orch returned {exc.response.status_code}.",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"cf-orch unreachable: {exc}",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/connect", status_code=501)
|
||||
async def gpu_connect() -> dict:
|
||||
"""
|
||||
Session-held GPU allocation (Premium tier).
|
||||
|
||||
Not yet implemented — tracked in cf-orch #43.
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Session-held GPU allocation is not yet implemented (cf-orch #43).",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/disconnect", status_code=501)
|
||||
async def gpu_disconnect() -> dict:
|
||||
"""Release a session-held GPU allocation (Premium tier, stub)."""
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Session-held GPU allocation is not yet implemented (cf-orch #43).",
|
||||
)
|
||||
182
app/api/endpoints/nodes.py
Normal file
182
app/api/endpoints/nodes.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
# app/api/endpoints/nodes.py — branch generation, commit, discard, audio stream
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.api.deps import get_conn, get_data_dir
|
||||
from app.api.events_store import broadcast
|
||||
from app.models.schemas.node import BranchRequest, NodeRow
|
||||
from app.services import chain as chain_svc
|
||||
from app.services.musicgen import MusicGenClient, make_output_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/nodes", tags=["nodes"])
|
||||
|
||||
_musicgen = MusicGenClient()
|
||||
|
||||
|
||||
@router.post("/{parent_id}/branch", response_model=NodeRow, status_code=202)
|
||||
async def create_branch(
|
||||
parent_id: str,
|
||||
body: BranchRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
conn=Depends(get_conn),
|
||||
data_dir: str = Depends(get_data_dir),
|
||||
) -> dict:
|
||||
"""
|
||||
Create a branch node from parent_id and kick off async generation.
|
||||
|
||||
Returns 202 immediately with the node in 'pending' state.
|
||||
Status transitions (pending → generating → ready/error) are pushed
|
||||
via the chain's SSE stream.
|
||||
"""
|
||||
parent = chain_svc.get_node(conn, parent_id)
|
||||
if parent is None:
|
||||
raise HTTPException(status_code=404, detail="Parent node not found.")
|
||||
if parent["audio_path"] is None:
|
||||
raise HTTPException(
|
||||
status_code=409, detail="Parent node has no audio yet."
|
||||
)
|
||||
|
||||
node = chain_svc.create_branch_node(
|
||||
conn,
|
||||
parent_id=parent_id,
|
||||
chain_id=parent["chain_id"],
|
||||
prompt=body.prompt,
|
||||
energy=body.energy,
|
||||
tempo_feel=body.tempo_feel,
|
||||
density=body.density,
|
||||
cfg_coef=body.cfg_coef,
|
||||
prompt_duration_s=body.prompt_duration_s,
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
_run_generation,
|
||||
node_id=node["id"],
|
||||
chain_id=parent["chain_id"],
|
||||
source_audio_path=parent["audio_path"],
|
||||
prompt=body.prompt,
|
||||
duration_s=body.duration_s,
|
||||
cfg_coef=body.cfg_coef,
|
||||
prompt_duration_s=body.prompt_duration_s,
|
||||
data_dir=data_dir,
|
||||
db_path=conn.execute("PRAGMA database_list").fetchone()[2],
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@router.post("/{node_id}/commit", response_model=NodeRow)
|
||||
def commit_node(node_id: str, conn=Depends(get_conn)) -> dict:
|
||||
"""
|
||||
Promote a branch node to the committed spine.
|
||||
|
||||
Discards all non-committed siblings (same parent). The root node is
|
||||
always committed and cannot be targeted here.
|
||||
"""
|
||||
node = chain_svc.get_node(conn, node_id)
|
||||
if node is None:
|
||||
raise HTTPException(status_code=404, detail="Node not found.")
|
||||
if node["status"] != "ready":
|
||||
raise HTTPException(
|
||||
status_code=409, detail="Only 'ready' nodes can be committed."
|
||||
)
|
||||
result = chain_svc.commit_node(conn, node_id)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Node not found.")
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{node_id}", status_code=204)
|
||||
def delete_node(node_id: str, conn=Depends(get_conn)) -> None:
|
||||
"""
|
||||
Discard a non-committed branch node.
|
||||
|
||||
Root nodes (parent_id IS NULL) and committed nodes cannot be deleted.
|
||||
"""
|
||||
if not chain_svc.delete_node(conn, node_id):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Node not found or cannot be deleted (root or committed).",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{node_id}/audio")
|
||||
def stream_audio(node_id: str, conn=Depends(get_conn)) -> FileResponse:
|
||||
"""Stream the audio file for a ready node."""
|
||||
node = chain_svc.get_node(conn, node_id)
|
||||
if node is None:
|
||||
raise HTTPException(status_code=404, detail="Node not found.")
|
||||
if node["audio_path"] is None or not Path(node["audio_path"]).exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not available.")
|
||||
return FileResponse(
|
||||
node["audio_path"],
|
||||
media_type="audio/wav",
|
||||
filename=f"{node_id}.wav",
|
||||
)
|
||||
|
||||
|
||||
# ── Background generation task ────────────────────────────────────────────────
|
||||
|
||||
async def _run_generation(
|
||||
*,
|
||||
node_id: str,
|
||||
chain_id: str,
|
||||
source_audio_path: str,
|
||||
prompt: str,
|
||||
duration_s: float,
|
||||
cfg_coef: float,
|
||||
prompt_duration_s: float,
|
||||
data_dir: str,
|
||||
db_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Background task: run MusicGen and update node status.
|
||||
|
||||
Opens its own DB connection so it doesn't block the request thread.
|
||||
"""
|
||||
from app.db.store import get_connection
|
||||
|
||||
conn = get_connection(db_path)
|
||||
try:
|
||||
chain_svc.update_node_status(conn, node_id, "generating")
|
||||
await broadcast(chain_id, {"node_id": node_id, "status": "generating"})
|
||||
|
||||
output_path = make_output_path(data_dir, chain_id, node_id)
|
||||
actual_duration = await _musicgen.generate(
|
||||
source_audio_path=source_audio_path,
|
||||
output_path=output_path,
|
||||
prompt=prompt,
|
||||
duration_s=duration_s,
|
||||
cfg_coef=cfg_coef,
|
||||
prompt_duration_s=prompt_duration_s,
|
||||
)
|
||||
|
||||
chain_svc.update_node_status(
|
||||
conn, node_id, "ready",
|
||||
audio_path=output_path,
|
||||
duration_s=actual_duration,
|
||||
)
|
||||
await broadcast(chain_id, {
|
||||
"node_id": node_id,
|
||||
"status": "ready",
|
||||
"audio_path": output_path,
|
||||
"duration_s": actual_duration,
|
||||
})
|
||||
logger.info("Node %s ready (%.1fs)", node_id, actual_duration)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("Generation failed for node %s: %s", node_id, exc)
|
||||
chain_svc.update_node_status(conn, node_id, "error", error_msg=str(exc))
|
||||
await broadcast(chain_id, {
|
||||
"node_id": node_id,
|
||||
"status": "error",
|
||||
"error_msg": str(exc),
|
||||
})
|
||||
finally:
|
||||
conn.close()
|
||||
67
app/api/endpoints/stems.py
Normal file
67
app/api/endpoints/stems.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
# app/api/endpoints/stems.py — 4-stem separation (Paid tier)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
|
||||
from app.api.deps import get_conn, get_data_dir
|
||||
from app.models.schemas.node import StemResult
|
||||
from app.services import chain as chain_svc
|
||||
from app.services import stems as stems_svc
|
||||
from app.tiers import require_tier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/stems", tags=["stems"])
|
||||
|
||||
|
||||
@router.post("/{node_id}", response_model=StemResult, status_code=202)
|
||||
async def separate_stems(
|
||||
node_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
tier: str | None = None,
|
||||
conn=Depends(get_conn),
|
||||
data_dir: str = Depends(get_data_dir),
|
||||
) -> dict:
|
||||
"""
|
||||
Run 4-stem separation (Demucs htdemucs) on a ready node.
|
||||
|
||||
Requires Paid tier. Returns stem file paths immediately — paths will
|
||||
be valid once the background task completes. Poll GET /api/nodes/{node_id}
|
||||
or listen on the chain SSE stream for completion signals.
|
||||
|
||||
tier: passed as query param, e.g. ?tier=paid
|
||||
"""
|
||||
require_tier("stems", tier)
|
||||
|
||||
node = chain_svc.get_node(conn, node_id)
|
||||
if node is None:
|
||||
raise HTTPException(status_code=404, detail="Node not found.")
|
||||
if node["audio_path"] is None:
|
||||
raise HTTPException(status_code=409, detail="Node has no audio.")
|
||||
|
||||
stems_dir = stems_svc.make_stems_dir(data_dir, node["chain_id"], node_id)
|
||||
stem_paths = stems_svc._stem_paths(stems_dir, Path(node["audio_path"]).stem)
|
||||
|
||||
background_tasks.add_task(
|
||||
_run_separation,
|
||||
source_audio_path=node["audio_path"],
|
||||
stems_dir=stems_dir,
|
||||
)
|
||||
|
||||
return StemResult(
|
||||
node_id=node_id,
|
||||
vocals=stem_paths["vocals"],
|
||||
drums=stem_paths["drums"],
|
||||
bass=stem_paths["bass"],
|
||||
other=stem_paths["other"],
|
||||
)
|
||||
|
||||
|
||||
async def _run_separation(source_audio_path: str, stems_dir: str) -> None:
|
||||
try:
|
||||
await stems_svc.separate(source_audio_path, stems_dir)
|
||||
logger.info("Stems complete: %s", stems_dir)
|
||||
except Exception as exc:
|
||||
logger.exception("Stem separation failed: %s", exc)
|
||||
34
app/api/events_store.py
Normal file
34
app/api/events_store.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
# app/api/events_store.py — Chain-scoped pub/sub for SSE node-status events
|
||||
#
|
||||
# Each SSE connection subscribes with subscribe(chain_id) → Queue.
|
||||
# Background generation tasks call broadcast(chain_id, event) to push
|
||||
# node state transitions (pending → generating → ready/error) to all listeners.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
# chain_id → list of subscriber queues
|
||||
_listeners: dict[str, list[asyncio.Queue[Any]]] = defaultdict(list)
|
||||
|
||||
|
||||
def subscribe(chain_id: str) -> asyncio.Queue[Any]:
|
||||
"""Register a new SSE listener for chain_id. Returns its Queue."""
|
||||
q: asyncio.Queue[Any] = asyncio.Queue()
|
||||
_listeners[chain_id].append(q)
|
||||
return q
|
||||
|
||||
|
||||
def unsubscribe(chain_id: str, q: asyncio.Queue[Any]) -> None:
|
||||
"""Remove a listener Queue. Safe to call if already removed."""
|
||||
try:
|
||||
_listeners[chain_id].remove(q)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
async def broadcast(chain_id: str, event: dict[str, Any]) -> None:
|
||||
"""Push event dict to all active listeners for chain_id."""
|
||||
for q in list(_listeners.get(chain_id, [])):
|
||||
await q.put(event)
|
||||
14
app/api/routes.py
Normal file
14
app/api/routes.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# app/api/routes.py — assemble all endpoint routers
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import chains, events, export, gpu, nodes, stems
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(chains.router)
|
||||
router.include_router(nodes.router)
|
||||
router.include_router(gpu.router)
|
||||
router.include_router(stems.router)
|
||||
router.include_router(export.router)
|
||||
router.include_router(events.router)
|
||||
66
app/main.py
Normal file
66
app/main.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# app/main.py — Sparrow FastAPI application factory
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.routes import router
|
||||
from app.db.store import get_connection, run_migrations
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Read env at startup — not at import time — so tests can inject values
|
||||
db_path = os.environ.get("SPARROW_DB_PATH", "data/sparrow.db")
|
||||
data_dir = os.environ.get("SPARROW_DATA_DIR", "data")
|
||||
env = os.environ.get("SPARROW_ENV", "development")
|
||||
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(data_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = get_connection(db_path)
|
||||
run_migrations(conn)
|
||||
app.state.conn = conn
|
||||
app.state.data_dir = data_dir
|
||||
|
||||
logger.info("Sparrow started (env=%s, db=%s, data=%s)", env, db_path, data_dir)
|
||||
yield
|
||||
|
||||
conn.close()
|
||||
logger.info("Sparrow shut down.")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
env = os.environ.get("SPARROW_ENV", "development")
|
||||
app = FastAPI(
|
||||
title="Sparrow",
|
||||
description="AI music continuation — branching chain editor",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS: open in dev, tighten in production via env
|
||||
origins = (
|
||||
["*"] if env == "development"
|
||||
else os.environ.get("SPARROW_CORS_ORIGINS", "").split(",")
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
27
app/models/schemas/chain.py
Normal file
27
app/models/schemas/chain.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# app/models/schemas/chain.py — Chain Pydantic models
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChainCreate(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ChainRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
created_at: float
|
||||
node_count: int = 0
|
||||
|
||||
|
||||
class ChainTree(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
created_at: float
|
||||
nodes: list["NodeRow"] = []
|
||||
|
||||
|
||||
# Avoid circular import — NodeRow imported at runtime
|
||||
from app.models.schemas.node import NodeRow # noqa: E402
|
||||
ChainTree.model_rebuild()
|
||||
49
app/models/schemas/node.py
Normal file
49
app/models/schemas/node.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
# app/models/schemas/node.py — Node Pydantic models
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
NodeStatus = Literal["pending", "generating", "ready", "error"]
|
||||
|
||||
|
||||
class NodeRow(BaseModel):
|
||||
id: str
|
||||
chain_id: str
|
||||
parent_id: str | None
|
||||
audio_path: str | None
|
||||
duration_s: float | None
|
||||
status: NodeStatus
|
||||
is_committed: bool
|
||||
prompt: str
|
||||
energy: float | None
|
||||
tempo_feel: float | None
|
||||
density: float | None
|
||||
cfg_coef: float
|
||||
prompt_duration_s: float
|
||||
error_msg: str | None
|
||||
created_at: float
|
||||
children: list["NodeRow"] = []
|
||||
|
||||
|
||||
class BranchRequest(BaseModel):
|
||||
prompt: str = ""
|
||||
energy: float | None = None
|
||||
tempo_feel: float | None = None
|
||||
density: float | None = None
|
||||
cfg_coef: float = 3.0
|
||||
prompt_duration_s: float = 10.0
|
||||
duration_s: float = 15.0
|
||||
|
||||
|
||||
class StemResult(BaseModel):
|
||||
node_id: str
|
||||
vocals: str
|
||||
drums: str
|
||||
bass: str
|
||||
other: str
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
format: Literal["wav", "mp3"] = "wav"
|
||||
node_ids: list[str] | None = None # if None, use committed spine
|
||||
203
app/services/chain.py
Normal file
203
app/services/chain.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
# app/services/chain.py — chain + node CRUD and branching tree logic
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ── Chain CRUD ────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_chain(conn: sqlite3.Connection, name: str) -> dict[str, Any]:
|
||||
chain_id = str(uuid.uuid4())
|
||||
now = time.time()
|
||||
conn.execute(
|
||||
"INSERT INTO chains (id, name, created_at) VALUES (?, ?, ?)",
|
||||
(chain_id, name, now),
|
||||
)
|
||||
conn.commit()
|
||||
return {"id": chain_id, "name": name, "created_at": now, "node_count": 0}
|
||||
|
||||
|
||||
def list_chains(conn: sqlite3.Connection) -> list[dict[str, Any]]:
|
||||
rows = conn.execute("""
|
||||
SELECT c.id, c.name, c.created_at,
|
||||
COUNT(n.id) AS node_count
|
||||
FROM chains c
|
||||
LEFT JOIN nodes n ON n.chain_id = c.id
|
||||
GROUP BY c.id
|
||||
ORDER BY c.created_at DESC
|
||||
""").fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def get_chain(conn: sqlite3.Connection, chain_id: str) -> dict[str, Any] | None:
|
||||
row = conn.execute(
|
||||
"SELECT id, name, created_at FROM chains WHERE id = ?", (chain_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
nodes = _get_nodes_for_chain(conn, chain_id)
|
||||
return {"id": row["id"], "name": row["name"], "created_at": row["created_at"], "nodes": nodes}
|
||||
|
||||
|
||||
def delete_chain(conn: sqlite3.Connection, chain_id: str) -> bool:
|
||||
cur = conn.execute("DELETE FROM chains WHERE id = ?", (chain_id,))
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
# ── Node CRUD ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_root_node(
|
||||
conn: sqlite3.Connection,
|
||||
chain_id: str,
|
||||
audio_path: str,
|
||||
duration_s: float,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the root node (uploaded source audio). Always committed."""
|
||||
node_id = str(uuid.uuid4())
|
||||
now = time.time()
|
||||
conn.execute("""
|
||||
INSERT INTO nodes
|
||||
(id, chain_id, parent_id, audio_path, duration_s, status,
|
||||
is_committed, prompt, cfg_coef, prompt_duration_s, created_at)
|
||||
VALUES (?, ?, NULL, ?, ?, 'ready', 1, '', 3.0, 10.0, ?)
|
||||
""", (node_id, chain_id, audio_path, duration_s, now))
|
||||
conn.commit()
|
||||
return get_node(conn, node_id)
|
||||
|
||||
|
||||
def get_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
|
||||
row = conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
|
||||
return _row_to_dict(row) if row else None
|
||||
|
||||
|
||||
def create_branch_node(
|
||||
conn: sqlite3.Connection,
|
||||
parent_id: str,
|
||||
chain_id: str,
|
||||
prompt: str,
|
||||
energy: float | None,
|
||||
tempo_feel: float | None,
|
||||
density: float | None,
|
||||
cfg_coef: float,
|
||||
prompt_duration_s: float,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new branch node in 'pending' state. Caller updates to 'generating'."""
|
||||
node_id = str(uuid.uuid4())
|
||||
now = time.time()
|
||||
conn.execute("""
|
||||
INSERT INTO nodes
|
||||
(id, chain_id, parent_id, audio_path, duration_s, status,
|
||||
is_committed, prompt, energy, tempo_feel, density,
|
||||
cfg_coef, prompt_duration_s, created_at)
|
||||
VALUES (?, ?, ?, NULL, NULL, 'pending', 0, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (node_id, chain_id, parent_id, prompt, energy, tempo_feel, density,
|
||||
cfg_coef, prompt_duration_s, now))
|
||||
conn.commit()
|
||||
return get_node(conn, node_id)
|
||||
|
||||
|
||||
def update_node_status(
|
||||
conn: sqlite3.Connection,
|
||||
node_id: str,
|
||||
status: str,
|
||||
audio_path: str | None = None,
|
||||
duration_s: float | None = None,
|
||||
error_msg: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
fields: list[str] = ["status = ?"]
|
||||
values: list[Any] = [status]
|
||||
if audio_path is not None:
|
||||
fields.append("audio_path = ?")
|
||||
values.append(audio_path)
|
||||
if duration_s is not None:
|
||||
fields.append("duration_s = ?")
|
||||
values.append(duration_s)
|
||||
if error_msg is not None:
|
||||
fields.append("error_msg = ?")
|
||||
values.append(error_msg)
|
||||
values.append(node_id)
|
||||
conn.execute(f"UPDATE nodes SET {', '.join(fields)} WHERE id = ?", values)
|
||||
conn.commit()
|
||||
return get_node(conn, node_id)
|
||||
|
||||
|
||||
def commit_node(conn: sqlite3.Connection, node_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Promote a branch node to the committed spine.
|
||||
|
||||
Discard (delete) all sibling branch nodes (same parent, not committed).
|
||||
The committed node and all its ancestors remain. Root is always committed.
|
||||
"""
|
||||
node = get_node(conn, node_id)
|
||||
if node is None:
|
||||
return None
|
||||
|
||||
parent_id = node.get("parent_id")
|
||||
if parent_id:
|
||||
# Discard non-committed siblings (same parent)
|
||||
conn.execute("""
|
||||
DELETE FROM nodes
|
||||
WHERE parent_id = ? AND id != ? AND is_committed = 0
|
||||
""", (parent_id, node_id))
|
||||
|
||||
conn.execute("UPDATE nodes SET is_committed = 1 WHERE id = ?", (node_id,))
|
||||
conn.commit()
|
||||
return get_node(conn, node_id)
|
||||
|
||||
|
||||
def delete_node(conn: sqlite3.Connection, node_id: str) -> bool:
|
||||
"""Delete a branch node. Root nodes (parent_id IS NULL) cannot be deleted."""
|
||||
row = conn.execute(
|
||||
"SELECT parent_id FROM nodes WHERE id = ?", (node_id,)
|
||||
).fetchone()
|
||||
if row is None or row["parent_id"] is None:
|
||||
return False # root node; refuse
|
||||
cur = conn.execute(
|
||||
"DELETE FROM nodes WHERE id = ? AND is_committed = 0", (node_id,)
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
def get_committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Return committed nodes ordered root → leaf along the committed spine.
|
||||
|
||||
Uses a recursive CTE to walk the tree from the root through committed nodes.
|
||||
"""
|
||||
rows = conn.execute("""
|
||||
WITH RECURSIVE spine(id, parent_id, depth) AS (
|
||||
SELECT id, parent_id, 0
|
||||
FROM nodes
|
||||
WHERE chain_id = ? AND parent_id IS NULL AND is_committed = 1
|
||||
UNION ALL
|
||||
SELECT n.id, n.parent_id, spine.depth + 1
|
||||
FROM nodes n
|
||||
JOIN spine ON n.parent_id = spine.id
|
||||
WHERE n.is_committed = 1 AND n.chain_id = ?
|
||||
)
|
||||
SELECT n.* FROM nodes n
|
||||
JOIN spine ON spine.id = n.id
|
||||
ORDER BY spine.depth
|
||||
""", (chain_id, chain_id)).fetchall()
|
||||
return [_row_to_dict(r) for r in rows]
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
|
||||
d = dict(row)
|
||||
d["is_committed"] = bool(d.get("is_committed", 0))
|
||||
d["children"] = []
|
||||
return d
|
||||
|
||||
|
||||
def _get_nodes_for_chain(conn: sqlite3.Connection, chain_id: str) -> list[dict[str, Any]]:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM nodes WHERE chain_id = ? ORDER BY created_at", (chain_id,)
|
||||
).fetchall()
|
||||
return [_row_to_dict(r) for r in rows]
|
||||
80
app/services/export.py
Normal file
80
app/services/export.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
# app/services/export.py — Stitch committed spine audio files via ffmpeg
|
||||
#
|
||||
# Uses ffmpeg concat demuxer (list-form args, no shell injection).
|
||||
# WAV output: copy codec. MP3 output: libmp3lame 320k.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ExportFormat = Literal["wav", "mp3"]
|
||||
|
||||
|
||||
async def stitch(
|
||||
audio_paths: list[str],
|
||||
output_path: str,
|
||||
fmt: ExportFormat = "wav",
|
||||
) -> str:
|
||||
"""
|
||||
Concatenate audio_paths in order and write to output_path.
|
||||
|
||||
Returns output_path on success. Raises RuntimeError on ffmpeg failure.
|
||||
Requires ffmpeg on PATH.
|
||||
"""
|
||||
if not audio_paths:
|
||||
raise ValueError("No audio paths to stitch.")
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Single file: just copy, no concatenation needed
|
||||
if len(audio_paths) == 1:
|
||||
shutil.copy2(audio_paths[0], output_path)
|
||||
return output_path
|
||||
|
||||
# Write ffmpeg concat file list to a tempfile
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".txt", delete=False
|
||||
) as flist:
|
||||
for p in audio_paths:
|
||||
# ffmpeg concat list format requires one "file '/abs/path'" per line
|
||||
flist.write(f"file '{p}'\n")
|
||||
flist_path = flist.name
|
||||
|
||||
try:
|
||||
codec_args: list[str] = (
|
||||
["-c:a", "copy"] if fmt == "wav"
|
||||
else ["-c:a", "libmp3lame", "-b:a", "320k"]
|
||||
)
|
||||
# create_subprocess_exec uses a list of args — no shell interpolation
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg", "-y",
|
||||
"-f", "concat",
|
||||
"-safe", "0",
|
||||
"-i", flist_path,
|
||||
*codec_args,
|
||||
output_path,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, stderr = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"ffmpeg concat failed (exit {proc.returncode}): {stderr.decode()}"
|
||||
)
|
||||
finally:
|
||||
os.unlink(flist_path)
|
||||
|
||||
logger.info("Stitched %d files -> %s", len(audio_paths), output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
def make_export_path(data_dir: str, chain_id: str, fmt: ExportFormat) -> str:
|
||||
"""Standard export output path."""
|
||||
return str(Path(data_dir) / "chains" / chain_id / f"export.{fmt}")
|
||||
157
app/services/musicgen.py
Normal file
157
app/services/musicgen.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
# app/services/musicgen.py — MusicGen client (cf-orch allocation + /continue call)
|
||||
#
|
||||
# Mock mode (CF_MUSICGEN_MOCK=1): copies source file, adds 1s of silence padding.
|
||||
# Real mode: allocates cf-musicgen from cf-orch, calls POST /continue, releases.
|
||||
#
|
||||
# cf_core.musicgen.app is not yet implemented (tracked in cf-core #49).
|
||||
# Until it is, real mode will fail at allocation time with a clear error.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ORCH_URL = os.environ.get("CF_ORCH_URL", "").rstrip("/")
|
||||
_SERVICE = "cf-musicgen"
|
||||
_MOCK = os.environ.get("CF_MUSICGEN_MOCK", "") == "1"
|
||||
|
||||
|
||||
class MusicGenClient:
|
||||
"""
|
||||
Allocates a cf-musicgen instance from cf-orch and calls POST /continue.
|
||||
|
||||
Each generate() call allocates, generates, and releases. For Premium tier
|
||||
session-held allocations, subclass and override _allocate/_release.
|
||||
"""
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
source_audio_path: str,
|
||||
output_path: str,
|
||||
prompt: str,
|
||||
duration_s: float,
|
||||
cfg_coef: float,
|
||||
prompt_duration_s: float,
|
||||
) -> float:
|
||||
"""
|
||||
Generate a continuation of source_audio_path.
|
||||
|
||||
Writes the result to output_path and returns the actual duration_s.
|
||||
Raises RuntimeError on cf-orch allocation failure or generation error.
|
||||
"""
|
||||
if _MOCK:
|
||||
return await _mock_generate(source_audio_path, output_path, duration_s)
|
||||
|
||||
service_url = await self._allocate()
|
||||
try:
|
||||
return await _call_continue(
|
||||
service_url=service_url,
|
||||
audio_path=source_audio_path,
|
||||
output_path=output_path,
|
||||
prompt=prompt,
|
||||
duration_s=duration_s,
|
||||
cfg_coef=cfg_coef,
|
||||
prompt_duration_s=prompt_duration_s,
|
||||
)
|
||||
finally:
|
||||
await self._release(service_url)
|
||||
|
||||
async def _allocate(self) -> str:
|
||||
"""Request a cf-musicgen allocation from cf-orch. Returns the service URL."""
|
||||
if not _ORCH_URL:
|
||||
raise RuntimeError(
|
||||
"CF_ORCH_URL is not configured. Set it in .env or use CF_MUSICGEN_MOCK=1."
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
|
||||
json={"requester": "sparrow"},
|
||||
)
|
||||
if resp.status_code == 503:
|
||||
raise RuntimeError("No cf-musicgen capacity available — all GPUs busy.")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
url = data.get("url") or data.get("service_url")
|
||||
if not url:
|
||||
raise RuntimeError(f"cf-orch allocation response missing URL: {data}")
|
||||
logger.info("Allocated cf-musicgen at %s", url)
|
||||
return url
|
||||
|
||||
async def _release(self, service_url: str) -> None:
|
||||
"""Release the cf-musicgen allocation back to cf-orch."""
|
||||
if not _ORCH_URL:
|
||||
return
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
await client.delete(f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
|
||||
json={"url": service_url})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to release cf-musicgen allocation: %s", exc)
|
||||
|
||||
|
||||
# ── Real /continue call ───────────────────────────────────────────────────────
|
||||
|
||||
async def _call_continue(
|
||||
service_url: str,
|
||||
audio_path: str,
|
||||
output_path: str,
|
||||
prompt: str,
|
||||
duration_s: float,
|
||||
cfg_coef: float,
|
||||
prompt_duration_s: float,
|
||||
) -> float:
|
||||
"""Call POST /continue on the allocated cf-musicgen service."""
|
||||
payload = {
|
||||
"audio_path": audio_path,
|
||||
"output_path": output_path,
|
||||
"prompt": prompt,
|
||||
"duration_s": duration_s,
|
||||
"cfg_coef": cfg_coef,
|
||||
"prompt_duration_s": prompt_duration_s,
|
||||
}
|
||||
# MusicGen can take 30–120s depending on duration and hardware
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
resp = await client.post(f"{service_url.rstrip('/')}/continue", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return float(data.get("duration_s", duration_s))
|
||||
|
||||
|
||||
# ── Mock generation ───────────────────────────────────────────────────────────
|
||||
|
||||
async def _mock_generate(
|
||||
source_audio_path: str,
|
||||
output_path: str,
|
||||
duration_s: float,
|
||||
) -> float:
|
||||
"""
|
||||
Mock: copy the source file to output_path with a short simulated delay.
|
||||
|
||||
Simulates generation latency so the UI state machine (pending → generating → ready)
|
||||
exercises all transitions during development.
|
||||
"""
|
||||
await asyncio.sleep(2.0) # simulate generation time
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_audio_path, output_path)
|
||||
logger.info("Mock generation: copied %s → %s", source_audio_path, output_path)
|
||||
|
||||
# Return the actual duration from the copied file
|
||||
try:
|
||||
import torchaudio
|
||||
info = torchaudio.info(output_path)
|
||||
return info.num_frames / info.sample_rate
|
||||
except Exception:
|
||||
return duration_s
|
||||
|
||||
|
||||
def make_output_path(data_dir: str, chain_id: str, node_id: str) -> str:
|
||||
"""Standard output path for a generated node's audio file."""
|
||||
return str(Path(data_dir) / "chains" / chain_id / "nodes" / f"{node_id}.wav")
|
||||
75
app/services/stems.py
Normal file
75
app/services/stems.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
# app/services/stems.py — Demucs 4-stem separation subprocess wrapper
|
||||
#
|
||||
# Mock mode (CF_STEMS_MOCK=1): copies source file to all 4 stem paths.
|
||||
# Real mode: runs demucs htdemucs model via asyncio.create_subprocess_exec
|
||||
# (list-form, no shell=True — not vulnerable to injection).
|
||||
#
|
||||
# Demucs output layout: {output_dir}/htdemucs/{track.stem}/{stem}.wav
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MOCK = os.environ.get("CF_STEMS_MOCK", "") == "1"
|
||||
_STEMS = ("vocals", "drums", "bass", "other")
|
||||
|
||||
|
||||
async def separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
|
||||
"""
|
||||
Split source_audio_path into 4 stems.
|
||||
|
||||
Returns a dict of stem name to absolute file path.
|
||||
Raises RuntimeError on demucs failure.
|
||||
"""
|
||||
if _MOCK:
|
||||
return await _mock_separate(source_audio_path, output_dir)
|
||||
return await _real_separate(source_audio_path, output_dir)
|
||||
|
||||
|
||||
async def _real_separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
|
||||
track = Path(source_audio_path)
|
||||
logger.info("Running demucs on %s -> %s", track.name, output_dir)
|
||||
# create_subprocess_exec takes a list of args — no shell interpolation
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python", "-m", "demucs",
|
||||
"--model", "htdemucs",
|
||||
"--out", output_dir,
|
||||
str(track),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, stderr = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"demucs failed (exit {proc.returncode}): {stderr.decode()}")
|
||||
return _stem_paths(output_dir, track.stem)
|
||||
|
||||
|
||||
async def _mock_separate(source_audio_path: str, output_dir: str) -> dict[str, str]:
|
||||
"""Mock: copy source to all 4 stems with a short simulated delay."""
|
||||
await asyncio.sleep(1.5)
|
||||
track = Path(source_audio_path)
|
||||
stem_dir = Path(output_dir) / "htdemucs" / track.stem
|
||||
stem_dir.mkdir(parents=True, exist_ok=True)
|
||||
paths: dict[str, str] = {}
|
||||
for stem in _STEMS:
|
||||
dest = str(stem_dir / f"{stem}.wav")
|
||||
shutil.copy2(source_audio_path, dest)
|
||||
paths[stem] = dest
|
||||
logger.info("Mock stems: copied %s -> %s", track.name, stem_dir)
|
||||
return paths
|
||||
|
||||
|
||||
def _stem_paths(output_dir: str, track_stem: str) -> dict[str, str]:
|
||||
"""Build expected output paths from demucs's standard directory layout."""
|
||||
stem_dir = Path(output_dir) / "htdemucs" / track_stem
|
||||
return {s: str(stem_dir / f"{s}.wav") for s in _STEMS}
|
||||
|
||||
|
||||
def make_stems_dir(data_dir: str, chain_id: str, node_id: str) -> str:
|
||||
"""Standard stems output directory for a node."""
|
||||
return str(Path(data_dir) / "chains" / chain_id / "stems" / node_id)
|
||||
52
app/tiers.py
Normal file
52
app/tiers.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# app/tiers.py — tier gates for Sparrow
|
||||
#
|
||||
# BYOK does not apply to Sparrow — "bring your own" means bring your own
|
||||
# hardware (self-hosting). is_self_hosted() checks the cf-orch URL target.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def is_self_hosted() -> bool:
|
||||
"""
|
||||
True when cf-orch is on localhost or a LAN address.
|
||||
|
||||
Self-hosted users get unrestricted access equivalent to Paid tier.
|
||||
Cloud users are gated by the standard tier model.
|
||||
"""
|
||||
orch_url = os.environ.get("CF_ORCH_URL", "")
|
||||
if not orch_url:
|
||||
return True # no orch = local dev mode = self-hosted
|
||||
host = orch_url.split("//")[-1].split(":")[0].split("/")[0]
|
||||
return host in ("localhost", "127.0.0.1") or host.startswith("10.") or \
|
||||
host.startswith("192.168.") or host.startswith("172.")
|
||||
|
||||
|
||||
def require_tier(feature: str, tier: str | None = None) -> None:
|
||||
"""
|
||||
Gate a feature by tier. Raises 402 if the caller does not qualify.
|
||||
|
||||
For self-hosted instances all features are unlocked.
|
||||
tier=None means the request carries no tier — treat as Free.
|
||||
"""
|
||||
if is_self_hosted():
|
||||
return
|
||||
|
||||
_TIER_RANK = {"free": 0, "paid": 1, "premium": 2}
|
||||
_FEATURE_MIN: dict[str, str] = {
|
||||
"stems": "paid",
|
||||
"parallel_branch": "paid",
|
||||
"session_allocation": "premium",
|
||||
"priority_queue": "premium",
|
||||
}
|
||||
|
||||
min_tier = _FEATURE_MIN.get(feature, "free")
|
||||
caller_rank = _TIER_RANK.get(tier or "free", 0)
|
||||
required_rank = _TIER_RANK[min_tier]
|
||||
|
||||
if caller_rank < required_rank:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"Feature '{feature}' requires {min_tier} tier.",
|
||||
)
|
||||
65
tests/conftest.py
Normal file
65
tests/conftest.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# tests/conftest.py — shared fixtures for Sparrow test suite
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Use mock mode for all tests — no GPU or ffmpeg required
|
||||
os.environ.setdefault("CF_MUSICGEN_MOCK", "1")
|
||||
os.environ.setdefault("CF_STEMS_MOCK", "1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_db(tmp_path):
|
||||
"""Temporary SQLite DB path."""
|
||||
return str(tmp_path / "test.db")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_data(tmp_path):
|
||||
"""Temporary data directory."""
|
||||
d = tmp_path / "data"
|
||||
d.mkdir()
|
||||
return str(d)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn(tmp_db):
|
||||
"""SQLite connection with migrations applied."""
|
||||
from app.db.store import get_connection, run_migrations
|
||||
c = get_connection(tmp_db)
|
||||
run_migrations(c)
|
||||
yield c
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(tmp_db, tmp_data):
|
||||
"""FastAPI test app with isolated DB and data dir."""
|
||||
os.environ["SPARROW_DB_PATH"] = tmp_db
|
||||
os.environ["SPARROW_DATA_DIR"] = tmp_data
|
||||
os.environ["SPARROW_ENV"] = "development"
|
||||
|
||||
from app.main import create_app
|
||||
return create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Synchronous test client."""
|
||||
with TestClient(app, raise_server_exceptions=True) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client(app):
|
||||
"""Async test client."""
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as c:
|
||||
yield c
|
||||
69
tests/test_api_chains.py
Normal file
69
tests/test_api_chains.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# tests/test_api_chains.py — integration tests for chain + upload endpoints
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
|
||||
def test_create_and_list_chains(client):
|
||||
resp = client.post("/api/chains/", json={"name": "my chain"})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "my chain"
|
||||
assert "id" in data
|
||||
|
||||
resp = client.get("/api/chains/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
|
||||
def test_get_chain_not_found(client):
|
||||
resp = client.get("/api/chains/doesnotexist")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_chain(client):
|
||||
resp = client.post("/api/chains/", json={"name": "to delete"})
|
||||
chain_id = resp.json()["id"]
|
||||
|
||||
resp = client.delete(f"/api/chains/{chain_id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
resp = client.get(f"/api/chains/{chain_id}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_upload_root_node(client, tmp_path):
|
||||
resp = client.post("/api/chains/", json={"name": "upload test"})
|
||||
chain_id = resp.json()["id"]
|
||||
|
||||
# Minimal valid WAV header (44 bytes) so torchaudio doesn't crash
|
||||
# If torchaudio isn't installed, _probe_duration returns 0.0 gracefully
|
||||
wav_bytes = b"RIFF" + b"\x00" * 40
|
||||
resp = client.post(
|
||||
f"/api/chains/{chain_id}/upload",
|
||||
files={"file": ("test.wav", io.BytesIO(wav_bytes), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
node = resp.json()
|
||||
assert node["status"] == "ready"
|
||||
assert node["is_committed"] is True
|
||||
assert node["parent_id"] is None
|
||||
|
||||
|
||||
def test_upload_unsupported_format(client):
|
||||
resp = client.post("/api/chains/", json={"name": "format test"})
|
||||
chain_id = resp.json()["id"]
|
||||
|
||||
resp = client.post(
|
||||
f"/api/chains/{chain_id}/upload",
|
||||
files={"file": ("track.txt", io.BytesIO(b"not audio"), "text/plain")},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_upload_chain_not_found(client):
|
||||
resp = client.post(
|
||||
"/api/chains/nonexistent/upload",
|
||||
files={"file": ("test.wav", io.BytesIO(b""), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
102
tests/test_api_nodes.py
Normal file
102
tests/test_api_nodes.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# tests/test_api_nodes.py — integration tests for branch/commit/discard endpoints
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
|
||||
def _create_chain_with_root(client, tmp_path=None):
|
||||
"""Helper: create a chain and upload a root node."""
|
||||
resp = client.post("/api/chains/", json={"name": "test chain"})
|
||||
chain_id = resp.json()["id"]
|
||||
|
||||
resp = client.post(
|
||||
f"/api/chains/{chain_id}/upload",
|
||||
files={"file": ("root.wav", io.BytesIO(b"RIFF" + b"\x00" * 40), "audio/wav")},
|
||||
)
|
||||
node = resp.json()
|
||||
return chain_id, node["id"]
|
||||
|
||||
|
||||
def test_create_branch_returns_202(client, tmp_path):
|
||||
chain_id, root_id = _create_chain_with_root(client)
|
||||
|
||||
resp = client.post(
|
||||
f"/api/nodes/{root_id}/branch",
|
||||
json={"prompt": "upbeat jazz", "duration_s": 10.0},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
node = resp.json()
|
||||
assert node["status"] == "pending"
|
||||
assert node["parent_id"] == root_id
|
||||
assert node["prompt"] == "upbeat jazz"
|
||||
|
||||
|
||||
def test_branch_parent_not_found(client):
|
||||
resp = client.post(
|
||||
"/api/nodes/nonexistent/branch",
|
||||
json={"prompt": "test"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_commit_requires_ready_status(client, conn):
|
||||
"""
|
||||
Directly insert a pending node into the shared DB (bypasses BackgroundTasks
|
||||
so we control the status without triggering mock generation).
|
||||
"""
|
||||
from app.services import chain as svc
|
||||
|
||||
chain_id, root_id = _create_chain_with_root(client)
|
||||
# Create branch via service (no background task), stays "pending"
|
||||
branch = svc.create_branch_node(
|
||||
conn, root_id, chain_id, "test", None, None, None, 3.0, 10.0
|
||||
)
|
||||
resp = client.post(f"/api/nodes/{branch['id']}/commit")
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
def test_delete_root_node_refused(client):
|
||||
chain_id, root_id = _create_chain_with_root(client)
|
||||
resp = client.delete(f"/api/nodes/{root_id}")
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
def test_delete_node_not_found(client):
|
||||
resp = client.delete("/api/nodes/doesnotexist")
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
def test_audio_not_found_for_pending_node(client, conn):
|
||||
"""Pending node (no audio_path yet) returns 404 on the audio endpoint."""
|
||||
from app.services import chain as svc
|
||||
|
||||
chain_id, root_id = _create_chain_with_root(client)
|
||||
branch = svc.create_branch_node(
|
||||
conn, root_id, chain_id, "test", None, None, None, 3.0, 10.0
|
||||
)
|
||||
resp = client.get(f"/api/nodes/{branch['id']}/audio")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_mock_generation_completes(client, conn):
|
||||
"""
|
||||
With CF_MUSICGEN_MOCK=1, TestClient runs BackgroundTasks synchronously,
|
||||
so the node transitions to 'ready' before client.post() returns.
|
||||
We verify the completed state via the shared conn fixture.
|
||||
"""
|
||||
from app.services import chain as svc
|
||||
|
||||
chain_id, root_id = _create_chain_with_root(client)
|
||||
resp = client.post(
|
||||
f"/api/nodes/{root_id}/branch",
|
||||
json={"prompt": "test mock", "duration_s": 5.0},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
branch_id = resp.json()["id"]
|
||||
|
||||
# Background task completed synchronously — node is ready
|
||||
node = svc.get_node(conn, branch_id)
|
||||
assert node is not None
|
||||
assert node["status"] == "ready", (
|
||||
f"Expected ready, got: {node['status']}, error: {node.get('error_msg')}"
|
||||
)
|
||||
150
tests/test_chain_service.py
Normal file
150
tests/test_chain_service.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
# tests/test_chain_service.py — unit tests for chain/node CRUD service
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from app.services import chain as svc
|
||||
|
||||
|
||||
def test_create_and_list_chain(conn):
|
||||
chain = svc.create_chain(conn, "test chain")
|
||||
assert chain["name"] == "test chain"
|
||||
assert chain["node_count"] == 0
|
||||
|
||||
chains = svc.list_chains(conn)
|
||||
assert len(chains) == 1
|
||||
assert chains[0]["id"] == chain["id"]
|
||||
|
||||
|
||||
def test_get_chain_not_found(conn):
|
||||
assert svc.get_chain(conn, "nonexistent") is None
|
||||
|
||||
|
||||
def test_delete_chain(conn):
|
||||
chain = svc.create_chain(conn, "deleteme")
|
||||
assert svc.delete_chain(conn, chain["id"])
|
||||
assert svc.get_chain(conn, chain["id"]) is None
|
||||
|
||||
|
||||
def test_create_root_node(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "root test")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close() # empty file, just needs to exist
|
||||
|
||||
node = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
assert node["status"] == "ready"
|
||||
assert node["is_committed"] is True
|
||||
assert node["parent_id"] is None
|
||||
assert node["audio_path"] == audio
|
||||
assert node["duration_s"] == 30.0
|
||||
|
||||
|
||||
def test_create_branch_node(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "branch test")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close()
|
||||
root = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
|
||||
branch = svc.create_branch_node(
|
||||
conn,
|
||||
parent_id=root["id"],
|
||||
chain_id=chain["id"],
|
||||
prompt="upbeat jazz",
|
||||
energy=0.8,
|
||||
tempo_feel=None,
|
||||
density=None,
|
||||
cfg_coef=3.0,
|
||||
prompt_duration_s=10.0,
|
||||
)
|
||||
assert branch["status"] == "pending"
|
||||
assert branch["is_committed"] is False
|
||||
assert branch["parent_id"] == root["id"]
|
||||
assert branch["prompt"] == "upbeat jazz"
|
||||
|
||||
|
||||
def test_update_node_status(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "status test")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close()
|
||||
root = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
branch = svc.create_branch_node(
|
||||
conn, root["id"], chain["id"], "test", None, None, None, 3.0, 10.0
|
||||
)
|
||||
|
||||
updated = svc.update_node_status(
|
||||
conn, branch["id"], "generating"
|
||||
)
|
||||
assert updated["status"] == "generating"
|
||||
|
||||
out_path = str(tmp_path / "out.wav")
|
||||
open(out_path, "wb").close()
|
||||
done = svc.update_node_status(
|
||||
conn, branch["id"], "ready", audio_path=out_path, duration_s=15.0
|
||||
)
|
||||
assert done["status"] == "ready"
|
||||
assert done["audio_path"] == out_path
|
||||
assert done["duration_s"] == 15.0
|
||||
|
||||
|
||||
def test_commit_node_discards_siblings(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "commit test")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close()
|
||||
root = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
|
||||
b1 = svc.create_branch_node(
|
||||
conn, root["id"], chain["id"], "branch 1", None, None, None, 3.0, 10.0
|
||||
)
|
||||
b2 = svc.create_branch_node(
|
||||
conn, root["id"], chain["id"], "branch 2", None, None, None, 3.0, 10.0
|
||||
)
|
||||
|
||||
# Make both ready
|
||||
for nid in (b1["id"], b2["id"]):
|
||||
svc.update_node_status(conn, nid, "ready",
|
||||
audio_path=audio, duration_s=15.0)
|
||||
|
||||
# Commit b1 — b2 should be deleted
|
||||
committed = svc.commit_node(conn, b1["id"])
|
||||
assert committed["is_committed"] is True
|
||||
assert svc.get_node(conn, b2["id"]) is None
|
||||
|
||||
|
||||
def test_delete_node_refuses_root(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "del test")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close()
|
||||
root = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
assert svc.delete_node(conn, root["id"]) is False
|
||||
|
||||
|
||||
def test_delete_node_refuses_committed(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "del committed")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close()
|
||||
root = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
branch = svc.create_branch_node(
|
||||
conn, root["id"], chain["id"], "", None, None, None, 3.0, 10.0
|
||||
)
|
||||
svc.update_node_status(conn, branch["id"], "ready",
|
||||
audio_path=audio, duration_s=15.0)
|
||||
svc.commit_node(conn, branch["id"])
|
||||
assert svc.delete_node(conn, branch["id"]) is False
|
||||
|
||||
|
||||
def test_committed_spine_order(conn, tmp_path):
|
||||
chain = svc.create_chain(conn, "spine test")
|
||||
audio = str(tmp_path / "audio.wav")
|
||||
open(audio, "wb").close()
|
||||
root = svc.create_root_node(conn, chain["id"], audio, 30.0)
|
||||
|
||||
child = svc.create_branch_node(
|
||||
conn, root["id"], chain["id"], "child", None, None, None, 3.0, 10.0
|
||||
)
|
||||
svc.update_node_status(conn, child["id"], "ready",
|
||||
audio_path=audio, duration_s=15.0)
|
||||
svc.commit_node(conn, child["id"])
|
||||
|
||||
spine = svc.get_committed_spine(conn, chain["id"])
|
||||
assert len(spine) == 2
|
||||
assert spine[0]["id"] == root["id"]
|
||||
assert spine[1]["id"] == child["id"]
|
||||
34
tests/test_export_service.py
Normal file
34
tests/test_export_service.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
# tests/test_export_service.py — unit tests for export stitch service
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stitch_single_file(tmp_path):
|
||||
"""Single file: should copy without invoking ffmpeg."""
|
||||
from app.services.export import stitch
|
||||
|
||||
src = tmp_path / "segment.wav"
|
||||
src.write_bytes(b"WAV DATA")
|
||||
out = str(tmp_path / "output.wav")
|
||||
|
||||
result = await stitch([str(src)], out)
|
||||
assert result == out
|
||||
assert open(out, "rb").read() == b"WAV DATA"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stitch_empty_raises(tmp_path):
|
||||
from app.services.export import stitch
|
||||
|
||||
with pytest.raises(ValueError, match="No audio paths"):
|
||||
await stitch([], str(tmp_path / "out.wav"))
|
||||
|
||||
|
||||
def test_make_export_path():
|
||||
from app.services.export import make_export_path
|
||||
|
||||
p = make_export_path("data", "chain-abc", "mp3")
|
||||
assert p == "data/chains/chain-abc/export.mp3"
|
||||
assert make_export_path("data", "chain-abc", "wav").endswith(".wav")
|
||||
36
tests/test_musicgen_service.py
Normal file
36
tests/test_musicgen_service.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# tests/test_musicgen_service.py — unit tests for MusicGenClient mock mode
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
# Ensure mock mode
|
||||
os.environ["CF_MUSICGEN_MOCK"] = "1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_generate_copies_file(tmp_path):
|
||||
from app.services.musicgen import MusicGenClient
|
||||
|
||||
source = tmp_path / "source.wav"
|
||||
source.write_bytes(b"RIFF" + b"\x00" * 40)
|
||||
output = str(tmp_path / "output.wav")
|
||||
|
||||
client = MusicGenClient()
|
||||
duration = await client.generate(
|
||||
source_audio_path=str(source),
|
||||
output_path=output,
|
||||
prompt="test",
|
||||
duration_s=10.0,
|
||||
cfg_coef=3.0,
|
||||
prompt_duration_s=5.0,
|
||||
)
|
||||
|
||||
assert os.path.exists(output)
|
||||
assert duration >= 0.0 # torchaudio may not parse our fake WAV; 0.0 is fine
|
||||
|
||||
|
||||
def test_make_output_path():
|
||||
from app.services.musicgen import make_output_path
|
||||
path = make_output_path("data", "chain-abc", "node-xyz")
|
||||
assert path == "data/chains/chain-abc/nodes/node-xyz.wav"
|
||||
Loading…
Reference in a new issue