sparrow/tests/test_api_chains.py
pyr0ball a6f60c9e07 feat: implement Sparrow backend (v0.1.0)
Full FastAPI backend for the AI music continuation editor:

Services
- chain.py: chain + node CRUD, commit/discard, recursive CTE spine query
- musicgen.py: MusicGenClient with cf-orch allocation + mock mode (CF_MUSICGEN_MOCK=1)
- stems.py: Demucs 4-stem separation subprocess wrapper + mock mode
- export.py: ffmpeg concat demuxer to stitch committed spine into WAV/MP3
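The recursive CTE spine query in chain.py is not shown in this view. Here is a minimal sketch of the idea. It assumes a hypothetical `nodes` table with `id`, `chain_id`, `parent_id` and `is_committed` columns (only the last two are confirmed by fields the API tests check) and at most one committed child per node. The query walks from the committed root down through committed children:

```python
from __future__ import annotations

import sqlite3

# Hypothetical schema: only parent_id / is_committed are confirmed by the
# API tests; the real chain.py tables may differ.
SCHEMA = """
CREATE TABLE nodes (
    id           TEXT PRIMARY KEY,
    chain_id     TEXT NOT NULL,
    parent_id    TEXT REFERENCES nodes(id),
    is_committed INTEGER NOT NULL DEFAULT 0
);
"""

# Seed with the committed root, then repeatedly join committed children.
# Assumes at most one committed child per node, so each depth has one row.
SPINE_QUERY = """
WITH RECURSIVE spine(id, depth) AS (
    SELECT id, 0 FROM nodes
    WHERE chain_id = ? AND parent_id IS NULL AND is_committed = 1
    UNION ALL
    SELECT n.id, s.depth + 1
    FROM nodes AS n
    JOIN spine AS s ON n.parent_id = s.id
    WHERE n.is_committed = 1
)
SELECT id FROM spine ORDER BY depth;
"""


def committed_spine(conn: sqlite3.Connection, chain_id: str) -> list[str]:
    """Return the node ids on the committed spine, root first."""
    return [row[0] for row in conn.execute(SPINE_QUERY, (chain_id,))]


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
conn.executemany(
    "INSERT INTO nodes VALUES (?, ?, ?, ?)",
    [
        ("root", "c1", None, 1),
        ("a", "c1", "root", 1),
        ("b", "c1", "root", 0),  # uncommitted sibling branch: skipped
        ("a2", "c1", "a", 1),
    ],
)
print(committed_spine(conn, "c1"))  # ['root', 'a', 'a2']
```

Ordering by `depth` rather than relying on the CTE's emission order keeps the result in playback order, which is the order an export step stitching the spine together would need.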

API endpoints
- chains: CRUD, multipart audio upload (WAV/MP3/FLAC/OGG/M4A/AIFF)
- nodes: branch creation (202 + BackgroundTasks), commit, discard, audio stream
- gpu: cf-orch capacity status; session allocation stubbed pending cf-orch#43
- stems: Paid-tier stem separation (Demucs, gated via tiers.py)
- export: POST /{chain_id}/export → FileResponse download
- events: per-chain SSE stream of node-status events, via asyncio.Queue pub/sub
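The upload endpoint rejects unsupported formats with 422, as `test_upload_unsupported_format` exercises. One way to express that check is a hypothetical `is_supported_upload` helper. It assumes the check is keyed on the filename extension, though the real route may also inspect the MIME type or probe the audio:

```python
from __future__ import annotations

from pathlib import PurePath

# Hypothetical whitelist mirroring the formats named in the commit message.
SUPPORTED_AUDIO_SUFFIXES = frozenset(
    {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aiff"}
)


def is_supported_upload(filename: str | None) -> bool:
    """True if the upload's extension is one of the accepted audio formats."""
    if not filename:
        return False  # multipart parts may arrive without a filename
    return PurePath(filename).suffix.lower() in SUPPORTED_AUDIO_SUFFIXES


print(is_supported_upload("test.wav"), is_supported_upload("track.txt"))  # True False
```

In a FastAPI route, this would pair with `raise HTTPException(status_code=422, ...)` when the helper returns False. Lowercasing the suffix also accepts names like `Take1.FLAC`.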

Infrastructure
- lifespan: reads SPARROW_DB_PATH/DATA_DIR at startup (not import time)
- events_store: subscribe/unsubscribe/broadcast pattern for SSE
- CORS: open in dev, SPARROW_CORS_ORIGINS in production
- Background generation opens its own DB connection (WAL-safe)
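The events_store subscribe/unsubscribe/broadcast pattern can be sketched with one `asyncio.Queue` per connected SSE client. The `EventStore` class and the event dict shape below are hypothetical and are not the module's actual API:

```python
from __future__ import annotations

import asyncio
from collections import defaultdict


class EventStore:
    """Hypothetical per-chain fan-out: each SSE client owns one asyncio.Queue."""

    def __init__(self) -> None:
        self._subscribers: defaultdict[str, set[asyncio.Queue]] = defaultdict(set)

    def subscribe(self, chain_id: str) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()  # unbounded
        self._subscribers[chain_id].add(queue)
        return queue

    def unsubscribe(self, chain_id: str, queue: asyncio.Queue) -> None:
        subs = self._subscribers.get(chain_id)
        if subs is None:
            return
        subs.discard(queue)
        if not subs:
            # Drop empty sets so idle chains don't accumulate.
            del self._subscribers[chain_id]

    def broadcast(self, chain_id: str, event: dict) -> None:
        # Snapshot the subscriber set before fanning out.
        for queue in list(self._subscribers.get(chain_id, ())):
            queue.put_nowait(event)  # never blocks on an unbounded queue


async def _demo() -> dict:
    store = EventStore()
    queue = store.subscribe("chain-1")
    store.broadcast("chain-2", {"node_id": "x", "status": "ready"})  # other chain
    store.broadcast("chain-1", {"node_id": "n1", "status": "ready"})
    event = await queue.get()
    store.unsubscribe("chain-1", queue)
    return event


print(asyncio.run(_demo()))  # {'node_id': 'n1', 'status': 'ready'}
```

Because `broadcast` calls `put_nowait` on unbounded queues, it never blocks the publisher. It is not thread-safe, though. If it is called from a sync background task (Starlette runs those in a threadpool), schedule it on the event loop with `loop.call_soon_threadsafe` rather than calling it directly.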

Tests: 30/30 passing across service unit tests and API integration tests
2026-04-17 15:22:37 -07:00


# tests/test_api_chains.py: integration tests for chain + upload endpoints
from __future__ import annotations

import io


def test_create_and_list_chains(client):
    resp = client.post("/api/chains/", json={"name": "my chain"})
    assert resp.status_code == 201
    data = resp.json()
    assert data["name"] == "my chain"
    assert "id" in data

    resp = client.get("/api/chains/")
    assert resp.status_code == 200
    assert len(resp.json()) == 1


def test_get_chain_not_found(client):
    resp = client.get("/api/chains/doesnotexist")
    assert resp.status_code == 404


def test_delete_chain(client):
    resp = client.post("/api/chains/", json={"name": "to delete"})
    chain_id = resp.json()["id"]

    resp = client.delete(f"/api/chains/{chain_id}")
    assert resp.status_code == 204

    resp = client.get(f"/api/chains/{chain_id}")
    assert resp.status_code == 404


def test_upload_root_node(client, tmp_path):
    resp = client.post("/api/chains/", json={"name": "upload test"})
    chain_id = resp.json()["id"]

    # 44 bytes (the size of a canonical WAV header), but only the "RIFF" magic
    # is real: the WAVE/fmt fields are zero-filled, so this is not decodable.
    # If torchaudio isn't installed, _probe_duration returns 0.0 gracefully.
    wav_bytes = b"RIFF" + b"\x00" * 40
    resp = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("test.wav", io.BytesIO(wav_bytes), "audio/wav")},
    )
    assert resp.status_code == 201
    node = resp.json()
    assert node["status"] == "ready"
    assert node["is_committed"] is True
    assert node["parent_id"] is None


def test_upload_unsupported_format(client):
    resp = client.post("/api/chains/", json={"name": "format test"})
    chain_id = resp.json()["id"]

    resp = client.post(
        f"/api/chains/{chain_id}/upload",
        files={"file": ("track.txt", io.BytesIO(b"not audio"), "text/plain")},
    )
    assert resp.status_code == 422


def test_upload_chain_not_found(client):
    resp = client.post(
        "/api/chains/nonexistent/upload",
        files={"file": ("test.wav", io.BytesIO(b""), "audio/wav")},
    )
    assert resp.status_code == 404
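The upload test's 44-byte buffer carries only the RIFF magic, so it is not a decodable WAV. Sometimes a genuinely valid file is wanted, for example to exercise the duration probe itself. In that case the standard-library `wave` module can build one in memory. `make_silent_wav` is a hypothetical helper name, not part of the test suite:

```python
import io
import wave


def make_silent_wav(duration_s: float = 0.1, sample_rate: int = 8000) -> bytes:
    """Return a mono, 16-bit PCM WAV file of silence, built in memory."""
    n_frames = round(duration_s * sample_rate)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit
        wav.setframerate(sample_rate)
        wav.writeframes(b"\x00\x00" * n_frames)
    # wave does not close a file object it didn't open, so buf is still readable.
    return buf.getvalue()


data = make_silent_wav()
print(len(data), data[:4], data[8:12])  # 1644 b'RIFF' b'WAVE'
```

It would slot into the existing test as `files={"file": ("test.wav", io.BytesIO(make_silent_wav()), "audio/wav")}`. The result is a 44-byte header followed by 800 frames of 2-byte samples.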