From 8b357064ce3afb924b872caebe20bb4807a774ed Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Fri, 24 Apr 2026 15:23:09 -0700 Subject: [PATCH] =?UTF-8?q?feat(musicgen):=20cf-musicgen=20module=20?= =?UTF-8?q?=E2=80=94=20MusicGen=20inference=20server?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FastAPI service wrapping facebook/musicgen-* models. Exposes POST /generate {prompt, duration_s} → audio/wav. Registered in VRAM tiers (8GB+). --- circuitforge_core/musicgen/__init__.py | 1 + circuitforge_core/musicgen/app.py | 138 ++++++++++++++++++ .../musicgen/backends/__init__.py | 1 + .../musicgen/backends/audiocraft.py | 128 ++++++++++++++++ circuitforge_core/musicgen/backends/base.py | 97 ++++++++++++ circuitforge_core/musicgen/backends/mock.py | 53 +++++++ 6 files changed, 418 insertions(+) create mode 100644 circuitforge_core/musicgen/__init__.py create mode 100644 circuitforge_core/musicgen/app.py create mode 100644 circuitforge_core/musicgen/backends/__init__.py create mode 100644 circuitforge_core/musicgen/backends/audiocraft.py create mode 100644 circuitforge_core/musicgen/backends/base.py create mode 100644 circuitforge_core/musicgen/backends/mock.py diff --git a/circuitforge_core/musicgen/__init__.py b/circuitforge_core/musicgen/__init__.py new file mode 100644 index 0000000..e0a473d --- /dev/null +++ b/circuitforge_core/musicgen/__init__.py @@ -0,0 +1 @@ +"""circuitforge_core.musicgen — music continuation service (BSL 1.1).""" diff --git a/circuitforge_core/musicgen/app.py b/circuitforge_core/musicgen/app.py new file mode 100644 index 0000000..a2a99c7 --- /dev/null +++ b/circuitforge_core/musicgen/app.py @@ -0,0 +1,138 @@ +""" +cf-musicgen FastAPI service — managed by cf-orch. 
+ +Endpoints: + GET /health -> {"status": "ok", "model": str, "vram_mb": int} + POST /continue -> audio bytes (Content-Type: audio/wav or audio/mpeg) + +Usage: + python -m circuitforge_core.musicgen.app \ + --model facebook/musicgen-melody \ + --port 8006 \ + --gpu-id 0 + +The service streams back raw audio bytes. Headers include: + X-Duration-S generated duration in seconds + X-Prompt-Duration-S how many seconds of the input were used as prompt + X-Model model name + X-Sample-Rate output sample rate (32000 for all MusicGen variants) + +Model weights are cached at /Library/Assets/LLM/musicgen/. +""" +from __future__ import annotations + +import argparse +import logging +import os +from typing import Annotated + +from fastapi import FastAPI, File, Form, HTTPException, UploadFile +from fastapi.responses import Response + +from circuitforge_core.musicgen.backends.base import ( + MODEL_MELODY, + MODEL_SMALL, + AudioFormat, + MusicGenBackend, + make_musicgen_backend, +) + +_CONTENT_TYPES: dict[str, str] = { + "wav": "audio/wav", + "mp3": "audio/mpeg", +} + +app = FastAPI(title="cf-musicgen", version="0.1.0") +_backend: MusicGenBackend | None = None + + +@app.get("/health") +def health() -> dict: + if _backend is None: + raise HTTPException(503, detail="backend not initialised") + return { + "status": "ok", + "model": _backend.model_name, + "vram_mb": _backend.vram_mb, + } + + +@app.post("/continue") +async def continue_audio( + audio: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, OGG, ...)"), + description: Annotated[str | None, Form()] = None, + duration_s: Annotated[float, Form()] = 15.0, + prompt_duration_s: Annotated[float, Form()] = 10.0, + format: Annotated[AudioFormat, Form()] = "wav", +) -> Response: + if _backend is None: + raise HTTPException(503, detail="backend not initialised") + if duration_s <= 0 or duration_s > 60: + raise HTTPException(422, detail="duration_s must be between 0 and 60") + if prompt_duration_s <= 0 or prompt_duration_s > 
30: + raise HTTPException(422, detail="prompt_duration_s must be between 0 and 30") + + audio_bytes = await audio.read() + if not audio_bytes: + raise HTTPException(400, detail="Empty audio file") + + try: + result = _backend.continue_audio( + audio_bytes, + description=description or None, + duration_s=duration_s, + prompt_duration_s=prompt_duration_s, + format=format, + ) + except Exception as exc: + logging.exception("Music continuation failed") + raise HTTPException(500, detail=str(exc)) from exc + + return Response( + content=result.audio_bytes, + media_type=_CONTENT_TYPES.get(result.format, "audio/wav"), + headers={ + "X-Duration-S": str(round(result.duration_s, 3)), + "X-Prompt-Duration-S": str(round(result.prompt_duration_s, 3)), + "X-Model": result.model, + "X-Sample-Rate": str(result.sample_rate), + }, + ) + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="cf-musicgen service") + p.add_argument( + "--model", + default=MODEL_MELODY, + choices=[MODEL_MELODY, MODEL_SMALL, "facebook/musicgen-medium", "facebook/musicgen-large"], + help="MusicGen model variant", + ) + p.add_argument("--port", type=int, default=8006) + p.add_argument("--host", default="0.0.0.0") + p.add_argument("--gpu-id", type=int, default=0, + help="CUDA device index (sets CUDA_VISIBLE_DEVICES)") + p.add_argument("--device", default="cuda", choices=["cuda", "cpu"]) + p.add_argument("--mock", action="store_true", + help="Run with mock backend (no GPU, for testing)") + return p.parse_args() + + +if __name__ == "__main__": + import uvicorn + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + args = _parse_args() + + if args.device == "cuda" and not args.mock: + os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id)) + + mock = args.mock or args.model == "mock" + device = "cpu" if mock else args.device + + _backend = make_musicgen_backend(model_name=args.model, mock=mock, device=device) + + 
uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/circuitforge_core/musicgen/backends/__init__.py b/circuitforge_core/musicgen/backends/__init__.py new file mode 100644 index 0000000..7f83ce0 --- /dev/null +++ b/circuitforge_core/musicgen/backends/__init__.py @@ -0,0 +1 @@ +"""MusicGen backend implementations.""" diff --git a/circuitforge_core/musicgen/backends/audiocraft.py b/circuitforge_core/musicgen/backends/audiocraft.py new file mode 100644 index 0000000..5b52591 --- /dev/null +++ b/circuitforge_core/musicgen/backends/audiocraft.py @@ -0,0 +1,128 @@ +""" +AudioCraft MusicGen backend — music continuation via Meta's MusicGen. + +Models are downloaded to /Library/Assets/LLM/musicgen/ (HF hub cache). +The melody model (~8 GB VRAM) is the default; small (~1.5 GB) is available +for lower-VRAM nodes. + +Continuation workflow: + 1. Decode input audio with torchaudio (any format ffmpeg understands) + 2. Trim to the last `prompt_duration_s` seconds — this anchors the generation + 3. Call model.generate_continuation(prompt_waveform, prompt_sample_rate, ...) + 4. Output tensor is the NEW audio only (not prompt + continuation) + 5. Encode to the requested format and return +""" +from __future__ import annotations + +import logging +import os + +from circuitforge_core.musicgen.backends.base import ( + AudioFormat, + MusicContinueResult, + decode_audio, + encode_audio, +) + +# All MusicGen/AudioCraft weights land here — consistent with other CF model dirs. 
_MUSICGEN_CACHE = "/Library/Assets/LLM/musicgen"

# Steady-state VRAM estimates (MB) per model variant, surfaced via /health so
# the orchestrator can fit services into VRAM tiers.
_VRAM_MB: dict[str, int] = {
    "facebook/musicgen-small": 1500,
    "facebook/musicgen-medium": 4500,
    "facebook/musicgen-melody": 8000,
    "facebook/musicgen-large": 8500,
}

logger = logging.getLogger(__name__)


class AudioCraftBackend:
    """MusicGen backend using Meta's AudioCraft library."""

    def __init__(self, model_name: str = "facebook/musicgen-melody", device: str = "cuda") -> None:
        # Redirect the HF hub cache before audiocraft (and transformers under
        # it) is imported, so the weights land under /Library/Assets.
        os.environ.setdefault("HF_HOME", _MUSICGEN_CACHE)
        os.makedirs(_MUSICGEN_CACHE, exist_ok=True)

        from audiocraft.models import MusicGen  # noqa: PLC0415 — deferred heavy import

        logger.info("Loading MusicGen model: %s on %s", model_name, device)
        self._model = MusicGen.get_pretrained(model_name, device=device)
        self._model_name = model_name
        self._device = device
        logger.info("MusicGen ready: %s", model_name)

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        # Unknown variants are assumed to need the melody-class footprint.
        return _VRAM_MB.get(self._model_name, 8000)

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Generate `duration_s` seconds of new audio continuing `audio_bytes`.

        The last `prompt_duration_s` seconds of the input condition the model;
        the returned audio is the continuation only.

        FIX: the result now reports the prompt duration actually used — when
        the input is shorter than `prompt_duration_s`, the whole input is the
        prompt, and the previous code over-reported X-Prompt-Duration-S.
        """
        import torch  # noqa: PLC0415

        # Decode input audio -> [C, T] tensor
        wav, sr = decode_audio(audio_bytes)

        # Use the END of the track as the conditioning prompt — it is the
        # musical context closest to where the continuation starts.
        available_s = wav.shape[-1] / sr
        max_prompt_samples = int(prompt_duration_s * sr)
        if wav.shape[-1] > max_prompt_samples:
            wav = wav[..., -max_prompt_samples:]
        # The prompt actually used may be shorter than requested.
        used_prompt_s = min(prompt_duration_s, available_s)

        # MusicGen expects [batch, channels, time]
        prompt_tensor = wav.unsqueeze(0).to(self._device)

        # One description per batch item (batch=1 here); None means
        # unconditional continuation.
        descriptions = [description] if description else [None]

        self._model.set_generation_params(
            duration=duration_s,
            top_k=250,
            temperature=1.0,
            cfg_coef=3.0,
        )

        logger.info(
            "Generating %.1fs continuation (prompt=%.1fs) model=%s",
            duration_s,
            used_prompt_s,
            self._model_name,
        )

        with torch.no_grad():
            output = self._model.generate_continuation(
                prompt=prompt_tensor,
                prompt_sample_rate=sr,
                descriptions=descriptions,
                progress=True,
            )

        # output: [batch, channels, time] at the model's native rate (32 kHz)
        output_wav = output[0]  # [C, T]
        model_sr = self._model.sample_rate

        actual_duration_s = output_wav.shape[-1] / model_sr
        audio_bytes_out = encode_audio(output_wav, model_sr, format)

        return MusicContinueResult(
            audio_bytes=audio_bytes_out,
            sample_rate=model_sr,
            duration_s=actual_duration_s,
            format=format,
            model=self._model_name,
            prompt_duration_s=used_prompt_s,
        )


# --- circuitforge_core/musicgen/backends/base.py ---
+""" +from __future__ import annotations + +import io +from dataclasses import dataclass +from typing import Literal, Protocol, runtime_checkable + +AudioFormat = Literal["wav", "mp3"] + +MODEL_SMALL = "facebook/musicgen-small" +MODEL_MELODY = "facebook/musicgen-melody" + + +@dataclass(frozen=True) +class MusicContinueResult: + audio_bytes: bytes + sample_rate: int + duration_s: float + format: AudioFormat + model: str + prompt_duration_s: float + + +@runtime_checkable +class MusicGenBackend(Protocol): + def continue_audio( + self, + audio_bytes: bytes, + *, + description: str | None = None, + duration_s: float = 15.0, + prompt_duration_s: float = 10.0, + format: AudioFormat = "wav", + ) -> MusicContinueResult: ... + + @property + def model_name(self) -> str: ... + + @property + def vram_mb(self) -> int: ... + + +def encode_audio(wav_tensor, sample_rate: int, format: AudioFormat) -> bytes: + """Encode a [C, T] or [1, C, T] torch tensor to audio bytes.""" + import io + import torch + import torchaudio + + wav = wav_tensor + if wav.dim() == 3: + wav = wav.squeeze(0) # [1, C, T] -> [C, T] + if wav.dim() == 1: + wav = wav.unsqueeze(0) # [T] -> [1, T] + wav = wav.to(torch.float32).cpu() + + buf = io.BytesIO() + if format == "wav": + torchaudio.save(buf, wav, sample_rate, format="wav") + elif format == "mp3": + try: + torchaudio.save(buf, wav, sample_rate, format="mp3") + except Exception: + # ffmpeg backend not available; fall back to wav + buf = io.BytesIO() + torchaudio.save(buf, wav, sample_rate, format="wav") + return buf.getvalue() + + +def decode_audio(audio_bytes: bytes) -> tuple: + """Decode arbitrary audio bytes to (waveform [C, T], sample_rate).""" + import io + import torchaudio + + buf = io.BytesIO(audio_bytes) + wav, sr = torchaudio.load(buf) + return wav, sr + + +def make_musicgen_backend( + model_name: str = MODEL_MELODY, + *, + mock: bool = False, + device: str = "cuda", +) -> MusicGenBackend: + if mock: + from circuitforge_core.musicgen.backends.mock 
import MockMusicGenBackend + return MockMusicGenBackend() + from circuitforge_core.musicgen.backends.audiocraft import AudioCraftBackend + return AudioCraftBackend(model_name=model_name, device=device) diff --git a/circuitforge_core/musicgen/backends/mock.py b/circuitforge_core/musicgen/backends/mock.py new file mode 100644 index 0000000..d0cbbf4 --- /dev/null +++ b/circuitforge_core/musicgen/backends/mock.py @@ -0,0 +1,53 @@ +""" +Mock MusicGenBackend — returns silent WAV audio; no GPU required. + +Used in unit tests and CI where GPU is unavailable. +""" +from __future__ import annotations + +import io +import struct +import wave + +from circuitforge_core.musicgen.backends.base import AudioFormat, MusicContinueResult + + +class MockMusicGenBackend: + """Returns a silent WAV file of the requested duration.""" + + @property + def model_name(self) -> str: + return "mock" + + @property + def vram_mb(self) -> int: + return 0 + + def continue_audio( + self, + audio_bytes: bytes, + *, + description: str | None = None, + duration_s: float = 15.0, + prompt_duration_s: float = 10.0, + format: AudioFormat = "wav", + ) -> MusicContinueResult: + sample_rate = 32000 + n_samples = int(duration_s * sample_rate) + silent_samples = b"\x00\x00" * n_samples # 16-bit PCM silence + + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(silent_samples) + + return MusicContinueResult( + audio_bytes=buf.getvalue(), + sample_rate=sample_rate, + duration_s=duration_s, + format="wav", + model="mock", + prompt_duration_s=prompt_duration_s, + )