Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service units and API integration
157 lines
5.6 KiB
Python
157 lines
5.6 KiB
Python
# app/services/musicgen.py — MusicGen client (cf-orch allocation + /continue call)
|
||
#
|
||
# Mock mode (CF_MUSICGEN_MOCK=1): copies the source file to the output path after a short simulated delay.
|
||
# Real mode: allocates cf-musicgen from cf-orch, calls POST /continue, releases.
|
||
#
|
||
# cf_core.musicgen.app is not yet implemented (tracked in cf-core #49).
|
||
# Until it is, real mode will fail at allocation time with a clear error.
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import shutil
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL of the cf-orch coordinator, read once at import time. Trailing
# slashes are stripped so endpoint paths can be appended with a single "/".
# An empty string means "not configured".
_ORCH_URL = os.getenv("CF_ORCH_URL", "").rstrip("/")

# Service name used in the cf-orch allocation endpoint path.
_SERVICE = "cf-musicgen"

# Mock mode is enabled only when CF_MUSICGEN_MOCK is exactly "1".
_MOCK = os.getenv("CF_MUSICGEN_MOCK", "") == "1"
|
||
|
||
|
||
class MusicGenClient:
    """
    Allocates a cf-musicgen instance from cf-orch and calls POST /continue.

    Each generate() call allocates, generates, and releases. For Premium tier
    session-held allocations, subclass and override _allocate/_release.
    """

    async def generate(
        self,
        source_audio_path: str,
        output_path: str,
        prompt: str,
        duration_s: float,
        cfg_coef: float,
        prompt_duration_s: float,
    ) -> float:
        """
        Generate a continuation of source_audio_path.

        In mock mode (module flag _MOCK) the call is delegated to
        _mock_generate and cf-orch is never contacted. Otherwise an
        allocation is requested, POST /continue is called on the allocated
        service, and the allocation is released afterwards whether or not
        generation succeeded.

        Returns:
            The duration in seconds reported for the generated audio.

        Raises:
            RuntimeError: CF_ORCH_URL is not configured, cf-orch reports no
                capacity (HTTP 503), or the allocation response has no URL.
            httpx.HTTPError: any other transport failure or non-success
                status from cf-orch or from the cf-musicgen service.
        """
        if _MOCK:
            return await _mock_generate(source_audio_path, output_path, duration_s)

        service_url = await self._allocate()
        try:
            return await _call_continue(
                service_url=service_url,
                audio_path=source_audio_path,
                output_path=output_path,
                prompt=prompt,
                duration_s=duration_s,
                cfg_coef=cfg_coef,
                prompt_duration_s=prompt_duration_s,
            )
        finally:
            # Always hand the allocation back, even when generation failed.
            await self._release(service_url)

    async def _allocate(self) -> str:
        """Request a cf-musicgen allocation from cf-orch. Returns the service URL.

        Raises:
            RuntimeError: CF_ORCH_URL is unset, cf-orch answers 503, or the
                response JSON carries neither "url" nor "service_url".
            httpx.HTTPError: transport failure or any other error status.
        """
        if not _ORCH_URL:
            raise RuntimeError(
                "CF_ORCH_URL is not configured. Set it in .env or use CF_MUSICGEN_MOCK=1."
            )
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                json={"requester": "sparrow"},
            )
            # 503 gets a dedicated, user-readable message; everything else
            # falls through to the generic status check.
            if resp.status_code == 503:
                raise RuntimeError("No cf-musicgen capacity available — all GPUs busy.")
            resp.raise_for_status()
            data = resp.json()
            url = data.get("url") or data.get("service_url")
            if not url:
                raise RuntimeError(f"cf-orch allocation response missing URL: {data}")
            logger.info("Allocated cf-musicgen at %s", url)
            return url

    async def _release(self, service_url: str) -> None:
        """Release the cf-musicgen allocation back to cf-orch.

        Best-effort: every failure is logged as a warning and swallowed so a
        release problem never masks the outcome of the generation itself.
        """
        if not _ORCH_URL:
            return
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # httpx's delete() shortcut does not accept a request body
                # (passing json= raises TypeError, which the handler below
                # would swallow, so the allocation was never released).
                # request("DELETE", ...) is the supported way to send one.
                resp = await client.request(
                    "DELETE",
                    f"{_ORCH_URL}/api/services/{_SERVICE}/allocations",
                    json={"url": service_url},
                )
                # Surface rejected releases in the log instead of ignoring them.
                resp.raise_for_status()
        except Exception as exc:
            logger.warning("Failed to release cf-musicgen allocation: %s", exc)
|
||
|
||
|
||
# ── Real /continue call ───────────────────────────────────────────────────────
|
||
|
||
async def _call_continue(
    service_url: str,
    audio_path: str,
    output_path: str,
    prompt: str,
    duration_s: float,
    cfg_coef: float,
    prompt_duration_s: float,
) -> float:
    """Send the continuation request to an allocated cf-musicgen service.

    Posts the generation parameters as JSON to ``<service_url>/continue`` and
    returns the ``duration_s`` value from the JSON reply, falling back to the
    requested ``duration_s`` when the reply omits that key.

    Raises:
        httpx.HTTPStatusError: the service answered with an error status.
    """
    endpoint = service_url.rstrip("/") + "/continue"
    request_body = dict(
        audio_path=audio_path,
        output_path=output_path,
        prompt=prompt,
        duration_s=duration_s,
        cfg_coef=cfg_coef,
        prompt_duration_s=prompt_duration_s,
    )
    # Generous timeout: MusicGen can take 30–120s depending on duration and hardware.
    async with httpx.AsyncClient(timeout=180.0) as http:
        response = await http.post(endpoint, json=request_body)
        response.raise_for_status()
        reply = response.json()
        reported = reply.get("duration_s", duration_s)
        return float(reported)
|
||
|
||
|
||
# ── Mock generation ───────────────────────────────────────────────────────────
|
||
|
||
async def _mock_generate(
    source_audio_path: str,
    output_path: str,
    duration_s: float,
) -> float:
    """
    Stand-in for real generation: wait briefly, then copy source to output.

    The artificial delay lets the UI state machine (pending → generating → ready)
    go through every transition during development.

    Returns the duration in seconds read from the copied file via torchaudio,
    or the requested duration_s if torchaudio is unavailable or cannot read it.
    """
    # Simulated generation latency.
    await asyncio.sleep(2.0)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source_audio_path, output_path)
    logger.info("Mock generation: copied %s → %s", source_audio_path, output_path)

    # Prefer the real length of the copied file; any failure (including a
    # missing torchaudio install) falls back to the requested duration.
    try:
        import torchaudio

        metadata = torchaudio.info(output_path)
        measured_s = metadata.num_frames / metadata.sample_rate
    except Exception:
        measured_s = duration_s
    return measured_s
|
||
|
||
|
||
def make_output_path(data_dir: str, chain_id: str, node_id: str) -> str:
    """Build the canonical audio file path for a generated node.

    The layout is ``<data_dir>/chains/<chain_id>/nodes/<node_id>.wav``.
    Nothing is created on disk; only the path string is returned.
    """
    filename = f"{node_id}.wav"
    node_audio = Path(data_dir).joinpath("chains", chain_id, "nodes", filename)
    return str(node_audio)
|