sparrow/app/services/musicgen.py
pyr0ball a6f60c9e07 feat: implement Sparrow backend (v0.1.0)
Full FastAPI backend for the AI music continuation editor:

Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3

API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub

Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)

Tests: 30/30 passing across service units and API integration
2026-04-17 15:22:37 -07:00

157 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/services/musicgen.py — MusicGen client (cf-orch allocation + /continue call)
#
# Mock mode (CF_MUSICGEN_MOCK=1): copies the source file unchanged after a ~2s simulated delay (no padding).
# Real mode: allocates cf-musicgen from cf-orch, calls POST /continue, releases.
#
# cf_core.musicgen.app is not yet implemented (tracked in cf-core #49).
# Until it is, real mode will fail at allocation time with a clear error.
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import time
import uuid
from pathlib import Path
import httpx
logger = logging.getLogger(__name__)

# Base URL of the cf-orch coordinator, without a trailing slash so request
# paths can be appended as f"{_ORCH_URL}/api/...". Empty means "not configured".
_ORCH_URL = os.getenv("CF_ORCH_URL", "").rstrip("/")

# Service name this module allocates from cf-orch.
_SERVICE = "cf-musicgen"

# Mock mode is enabled only when CF_MUSICGEN_MOCK is exactly "1".
# NOTE: evaluated once, at import time.
_MOCK = os.getenv("CF_MUSICGEN_MOCK", "") == "1"
class MusicGenClient:
    """
    Allocates a cf-musicgen instance from cf-orch and calls POST /continue.

    Each generate() call allocates, generates, and releases. For Premium tier
    session-held allocations, subclass and override _allocate/_release.
    """

    async def generate(
        self,
        source_audio_path: str,
        output_path: str,
        prompt: str,
        duration_s: float,
        cfg_coef: float,
        prompt_duration_s: float,
    ) -> float:
        """
        Generate a continuation of source_audio_path.

        Writes the result to output_path and returns the actual duration_s.
        In mock mode (CF_MUSICGEN_MOCK=1) no cf-orch allocation is made.

        Raises:
            RuntimeError: cf-orch is not configured, has no capacity (503),
                or returned an allocation without a service URL.
            httpx.HTTPError: transport failures and non-2xx responses (via
                raise_for_status) from cf-orch or cf-musicgen propagate as-is.
        """
        if _MOCK:
            return await _mock_generate(source_audio_path, output_path, duration_s)
        service_url = await self._allocate()
        try:
            return await _call_continue(
                service_url=service_url,
                audio_path=source_audio_path,
                output_path=output_path,
                prompt=prompt,
                duration_s=duration_s,
                cfg_coef=cfg_coef,
                prompt_duration_s=prompt_duration_s,
            )
        finally:
            # Hand the allocation back even when generation raised.
            await self._release(service_url)

    async def _allocate(self) -> str:
        """
        Request a cf-musicgen allocation from cf-orch and return its service URL.

        Raises:
            RuntimeError: CF_ORCH_URL is unset, cf-orch answered 503 (no
                capacity), or the response body carried no usable URL.
            httpx.HTTPStatusError: any other non-2xx response.
        """
        if not _ORCH_URL:
            raise RuntimeError(
                "CF_ORCH_URL is not configured. Set it in .env or use CF_MUSICGEN_MOCK=1."
            )
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                json={"requester": "sparrow"},
            )
        if resp.status_code == 503:
            raise RuntimeError("No cf-musicgen capacity available — all GPUs busy.")
        resp.raise_for_status()
        data = resp.json()
        # Accept either key name for the URL; a non-object body (e.g. a JSON
        # list) is treated as "missing URL" instead of crashing on .get().
        url = (data.get("url") or data.get("service_url")) if isinstance(data, dict) else None
        if not url:
            raise RuntimeError(f"cf-orch allocation response missing URL: {data}")
        logger.info("Allocated cf-musicgen at %s", url)
        return url

    async def _release(self, service_url: str) -> None:
        """
        Release the cf-musicgen allocation back to cf-orch (best-effort).

        Errors are logged rather than raised: this runs in generate()'s
        ``finally``, where raising would replace the generation's own
        result or exception.
        """
        if not _ORCH_URL:
            return
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # httpx's .delete() helper does not accept a request body
                # (passing json= raises TypeError), so use request("DELETE").
                resp = await client.request(
                    "DELETE",
                    f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                    json={"url": service_url},
                )
                resp.raise_for_status()
        except Exception as exc:
            logger.warning("Failed to release cf-musicgen allocation: %s", exc)
# ── Real /continue call ───────────────────────────────────────────────────────
async def _call_continue(
    service_url: str,
    audio_path: str,
    output_path: str,
    prompt: str,
    duration_s: float,
    cfg_coef: float,
    prompt_duration_s: float,
) -> float:
    """
    Call POST /continue on the allocated cf-musicgen service.

    Only file paths and generation parameters are sent (no audio bytes), so
    presumably the service shares a filesystem with this process — confirm.

    Returns the duration reported by the service ("duration_s" in the JSON
    response), falling back to the requested duration_s when the response
    omits it, reports null, or is not a JSON object.

    Raises:
        httpx.HTTPStatusError: the service returned a non-2xx status (the
            response body is logged first for diagnosis).
    """
    payload = {
        "audio_path": audio_path,
        "output_path": output_path,
        "prompt": prompt,
        "duration_s": duration_s,
        "cfg_coef": cfg_coef,
        "prompt_duration_s": prompt_duration_s,
    }
    # MusicGen can take 30-120 s depending on duration and hardware.
    async with httpx.AsyncClient(timeout=180.0) as client:
        resp = await client.post(f"{service_url.rstrip('/')}/continue", json=payload)
    if not resp.is_success:
        # raise_for_status() alone drops the response body; keep the
        # service's error detail in the logs.
        logger.error(
            "cf-musicgen /continue failed (HTTP %s): %s",
            resp.status_code,
            resp.text[:500],
        )
    resp.raise_for_status()
    data = resp.json()
    reported = data.get("duration_s") if isinstance(data, dict) else None
    # A null duration would otherwise crash float(None) with TypeError.
    return float(reported) if reported is not None else float(duration_s)
# ── Mock generation ───────────────────────────────────────────────────────────
async def _mock_generate(
    source_audio_path: str,
    output_path: str,
    duration_s: float,
) -> float:
    """
    Mock: copy the source file to output_path after a short simulated delay.

    Simulates generation latency so the UI state machine (pending → generating → ready)
    exercises all transitions during development. No audio is generated or
    padded; the output is a byte copy of the source file.

    Returns the copied file's duration as read by torchaudio, or duration_s
    when torchaudio is unavailable, cannot read the file, or reports zero
    frames / zero sample rate.
    """
    await asyncio.sleep(2.0)  # simulate generation time
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # shutil.copy2 is blocking file I/O; run it off the event loop.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, shutil.copy2, source_audio_path, output_path)
    logger.info("Mock generation: copied %s -> %s", source_audio_path, output_path)
    # Return the actual duration from the copied file (best-effort).
    try:
        import torchaudio

        info = torchaudio.info(output_path)
        # Some backends report 0 frames for formats they can't size; treat
        # that (or a zero sample rate) as "unknown" rather than 0.0 seconds.
        if info.num_frames > 0 and info.sample_rate > 0:
            return info.num_frames / info.sample_rate
    except Exception as exc:  # torchaudio missing or unable to read the file
        logger.debug("Mock generation: could not read duration of %s: %s", output_path, exc)
    return duration_s
def make_output_path(data_dir: str, chain_id: str, node_id: str) -> str:
    """
    Build the canonical location of a generated node's WAV file.

    Layout: <data_dir>/chains/<chain_id>/nodes/<node_id>.wav, returned as a
    string. The path is only computed; nothing is created on disk.
    """
    wav_name = f"{node_id}.wav"
    target = Path(data_dir, "chains", chain_id, "nodes", wav_name)
    return str(target)