Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: paid-tier stem separation (Demucs), gated via tiers.py
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service unit tests and API integration tests
54 lines · 1.7 KiB · Python
# app/api/endpoints/events.py — SSE stream for node status transitions
#
# Clients connect to GET /api/chains/{chain_id}/events and receive a
# server-sent event whenever a node in that chain changes status. While the
# stream is idle, the server also sends periodic keepalive comment lines.
#
# Event format:
#   event: node-status
#   data: {"node_id": "...", "status": "generating"|"ready"|"error", ...}

from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
from fastapi import APIRouter, Depends
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
from app.api.deps import get_conn
|
|
from app.api.events_store import subscribe, unsubscribe
|
|
|
|
router = APIRouter(
    prefix="/api/chains",
    tags=["events"],
)
logger = logging.getLogger(__name__)

# Idle interval in seconds: if no event has been delivered to a client for
# this long, the stream emits a keepalive comment so proxies and browsers do
# not drop the connection. It is not a fixed-rate ping.
_KEEPALIVE_S = 15
|
|
|
|
|
|
@router.get("/{chain_id}/events")
async def chain_events(chain_id: str, conn=Depends(get_conn)) -> EventSourceResponse:
    """
    SSE stream for node status transitions in a chain.

    Each payload received on this chain's subscriber queue is emitted as a
    ``node-status`` event whose ``data`` is the JSON-encoded payload. If no
    payload arrives within ``_KEEPALIVE_S`` seconds, a ``keepalive`` comment
    is sent instead so intermediaries keep the connection open. The stream
    runs until the client disconnects (i.e. the generator is cancelled or
    closed).

    Args:
        chain_id: Chain whose node-status events are streamed.
        conn: Connection from ``get_conn``. NOTE(review): unused by this
            handler; presumably intended for a chain-existence check —
            confirm or drop.

    Returns:
        An ``EventSourceResponse`` that drives the event generator.
    """

    async def generator():
        # Subscribe inside the generator, not in the handler body, so every
        # subscription is paired with the ``finally`` below. A generator that
        # is never iterated (response torn down before streaming starts)
        # never runs its ``finally``, which would leak the queue.
        q = subscribe(chain_id)
        try:
            while True:
                try:
                    event = await asyncio.wait_for(q.get(), timeout=_KEEPALIVE_S)
                except asyncio.TimeoutError:
                    # Idle: send a keepalive comment to prevent proxy/browser timeout.
                    yield {"comment": "keepalive"}
                    continue
                yield {
                    "event": "node-status",
                    "data": json.dumps(event),
                }
        finally:
            # Runs on cancellation (client disconnect), close, and errors
            # alike. CancelledError is deliberately not caught so it
            # propagates to the caller instead of being swallowed.
            unsubscribe(chain_id, q)
            logger.debug("SSE client disconnected from chain %s", chain_id)

    return EventSourceResponse(generator())
|