# app/api/endpoints/chains.py — chain CRUD + root audio upload
from __future__ import annotations

import logging
import os
import uuid
from pathlib import Path

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile

from app.api.deps import get_conn, get_data_dir
from app.models.schemas.chain import ChainCreate, ChainRow, ChainTree
from app.models.schemas.node import NodeRow
from app.services import chain as chain_svc

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/chains", tags=["chains"])

# Lower-cased file extensions accepted by the root-audio upload endpoint.
_ALLOWED_AUDIO = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aiff"}

# Uploads are copied to disk in pieces of this size so that a large audio
# file is never held in memory in full.
_UPLOAD_CHUNK_BYTES = 1024 * 1024


@router.post("/", response_model=ChainRow, status_code=201)
def create_chain(body: ChainCreate, conn=Depends(get_conn)) -> dict:
    """Create a chain named ``body.name`` and return what the service returns."""
    return chain_svc.create_chain(conn, body.name)


@router.get("/", response_model=list[ChainRow])
def list_chains(conn=Depends(get_conn)) -> list:
    """Return the chains reported by the chain service."""
    return chain_svc.list_chains(conn)


@router.get("/{chain_id}", response_model=ChainTree)
def get_chain(chain_id: str, conn=Depends(get_conn)) -> dict:
    """Return the chain identified by ``chain_id``.

    Raises:
        HTTPException: 404 when the service returns ``None`` for the id.
    """
    chain = chain_svc.get_chain(conn, chain_id)
    if chain is None:
        raise HTTPException(status_code=404, detail="Chain not found.")
    return chain


@router.delete("/{chain_id}", status_code=204)
def delete_chain(chain_id: str, conn=Depends(get_conn)) -> None:
    """Delete the chain identified by ``chain_id``.

    Raises:
        HTTPException: 404 when the service reports nothing was deleted
            (a falsy return value).
    """
    if not chain_svc.delete_chain(conn, chain_id):
        raise HTTPException(status_code=404, detail="Chain not found.")


@router.post("/{chain_id}/upload", response_model=NodeRow, status_code=201)
async def upload_root(
    chain_id: str,
    file: UploadFile = File(...),
    conn=Depends(get_conn),
    data_dir: str = Depends(get_data_dir),
) -> dict:
    """
    Upload source audio and create the root node for a chain.

    Accepts WAV, MP3, FLAC, OGG, M4A, AIFF (decided by the lower-cased file
    extension; a missing filename is treated as ``audio.wav``). The file is
    stored in {data_dir}/chains/{chain_id}/nodes/root_{uuid}{ext} and a root
    node is created through ``chain_svc.create_root_node``.

    The upload is copied to disk in fixed-size chunks, and the stored file is
    removed again if writing it or creating the node fails, so a failed
    request does not leave an orphaned audio file behind.

    Raises:
        HTTPException: 404 when the chain does not exist; 422 when the file
            extension is not in the allowed set.
    """
    chain = chain_svc.get_chain(conn, chain_id)
    if chain is None:
        raise HTTPException(status_code=404, detail="Chain not found.")

    suffix = Path(file.filename or "audio.wav").suffix.lower()
    if suffix not in _ALLOWED_AUDIO:
        raise HTTPException(
            status_code=422,
            detail=f"Unsupported audio format '{suffix}'. "
            f"Allowed: {', '.join(sorted(_ALLOWED_AUDIO))}",
        )

    # NOTE(review): this UUID only names the file; it is not passed to
    # create_root_node, so it may differ from the id of the node row that is
    # created — confirm whether the two are expected to match.
    node_id = str(uuid.uuid4())
    dest_dir = Path(data_dir) / "chains" / chain_id / "nodes"
    dest_dir.mkdir(parents=True, exist_ok=True)
    audio_path = dest_dir / f"root_{node_id}{suffix}"

    try:
        with audio_path.open("wb") as out:
            while chunk := await file.read(_UPLOAD_CHUNK_BYTES):
                out.write(chunk)

        # Get duration via torchaudio if available, else 0.0
        duration_s = _probe_duration(str(audio_path))
        return chain_svc.create_root_node(
            conn, chain_id, str(audio_path), duration_s
        )
    except BaseException:
        # BaseException so task cancellation also cleans up; the error is
        # always re-raised after the partial/orphaned file is removed.
        audio_path.unlink(missing_ok=True)
        raise


def _probe_duration(path: str) -> float:
    """Return the duration of the audio file at ``path`` in seconds.

    Best-effort: computed as ``num_frames / sample_rate`` from
    ``torchaudio.info``. Returns ``0.0`` when torchaudio cannot be imported
    or the probe fails for any reason; failures are logged instead of being
    silently discarded.
    """
    try:
        import torchaudio
    except Exception:  # ImportError, or any other failure while importing
        logger.debug(
            "torchaudio unavailable; duration of %s recorded as 0.0",
            path,
            exc_info=True,
        )
        return 0.0

    try:
        info = torchaudio.info(path)
        return info.num_frames / info.sample_rate
    except Exception:
        logger.warning(
            "Could not probe duration of %s; recording 0.0",
            path,
            exc_info=True,
        )
        return 0.0