Full FastAPI backend for the AI music continuation editor:
Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation is stubbed pending cf-orch#43
- stems: paid-tier stem separation (Demucs; access gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: SSE stream (node-status events) per chain via asyncio Queue pub/sub
Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not at import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev; restricted to SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
Tests: 30/30 passing across service unit tests and API integration tests
69 lines, 2 KiB, Python
# tests/test_api_chains.py — integration tests for chain + upload endpoints
|
|
from __future__ import annotations

import io
import wave
|
|
|
|
|
|
def test_create_and_list_chains(client):
    """Creating a chain returns 201 with its data, and that chain then appears in the listing."""
    resp = client.post("/api/chains/", json={"name": "my chain"})
    assert resp.status_code == 201
    created = resp.json()
    assert created["name"] == "my chain"
    assert "id" in created

    resp = client.get("/api/chains/")
    assert resp.status_code == 200
    chains = resp.json()
    # The exact count assumes the ``client`` fixture starts each test from an
    # empty database; the id check confirms the listed chain is the one we
    # just created rather than merely "some" chain.
    assert len(chains) == 1
    assert chains[0]["id"] == created["id"]
|
|
|
|
|
|
def test_get_chain_not_found(client):
    """Fetching a chain id that was never created must answer 404."""
    missing = client.get("/api/chains/doesnotexist")
    assert missing.status_code == 404
|
|
|
|
|
|
def test_delete_chain(client):
    """Deleting a chain returns 204, after which the chain can no longer be fetched."""
    resp = client.post("/api/chains/", json={"name": "to delete"})
    # Check the setup step explicitly so a failed create is reported as such,
    # instead of as a KeyError on the missing "id" field below.
    assert resp.status_code == 201
    chain_id = resp.json()["id"]

    resp = client.delete(f"/api/chains/{chain_id}")
    assert resp.status_code == 204

    resp = client.get(f"/api/chains/{chain_id}")
    assert resp.status_code == 404
|
|
|
|
|
|
def test_upload_root_node(client, tmp_path):
    """Uploading audio to a new chain creates a ready, committed root node.

    ``tmp_path`` is not used by the body; it is kept so the fixture signature
    stays unchanged.
    """
    resp = client.post("/api/chains/", json={"name": "upload test"})
    # Fail clearly on setup rather than with a KeyError on "id".
    assert resp.status_code == 201
    chain_id = resp.json()["id"]

    # Build a genuinely valid WAV file: RIFF/WAVE header plus 0.1 s of 16-bit
    # mono silence at 44.1 kHz. A bare b"RIFF" followed by zero padding lacks
    # the "WAVE", "fmt " and "data" chunks, so it is not a WAV file and would
    # only exercise the endpoint's tolerance of malformed input.
    # If torchaudio isn't installed, _probe_duration returns 0.0 gracefully.
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(44100)
        wav.writeframes(b"\x00\x00" * 4410)
    wav_bytes = buf.getvalue()

    resp = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("test.wav", io.BytesIO(wav_bytes), "audio/wav")},
    )
    assert resp.status_code == 201
    node = resp.json()
    assert node["status"] == "ready"
    assert node["is_committed"] is True
    assert node["parent_id"] is None
|
|
|
|
|
|
def test_upload_unsupported_format(client):
    """A non-audio upload (``track.txt`` sent as ``text/plain``) is rejected with 422."""
    resp = client.post("/api/chains/", json={"name": "format test"})
    # Fail clearly on setup rather than with a KeyError on "id".
    assert resp.status_code == 201
    chain_id = resp.json()["id"]

    resp = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("track.txt", io.BytesIO(b"not audio"), "text/plain")},
    )
    assert resp.status_code == 422
|
|
|
|
|
|
def test_upload_chain_not_found(client):
    """Uploading a file to a chain id that does not exist must answer 404."""
    empty_wav = ("test.wav", io.BytesIO(b""), "audio/wav")
    outcome = client.post("/api/chains/nonexistent/upload", files={"file": empty_wav})
    assert outcome.status_code == 404
|