feat(musicgen): cf-musicgen module — MusicGen inference server

FastAPI service wrapping facebook/musicgen-* models.
Exposes POST /generate {prompt, duration_s} → audio/wav.
Registered in VRAM tiers (8GB+).
This commit is contained in:
pyr0ball 2026-04-24 15:23:09 -07:00
parent 146fe97227
commit 8b357064ce
6 changed files with 418 additions and 0 deletions

View file

@ -0,0 +1 @@
"""circuitforge_core.musicgen — music continuation service (BSL 1.1)."""

View file

@ -0,0 +1,138 @@
"""
cf-musicgen FastAPI service managed by cf-orch.
Endpoints:
GET /health -> {"status": "ok", "model": str, "vram_mb": int}
POST /continue -> audio bytes (Content-Type: audio/wav or audio/mpeg)
Usage:
python -m circuitforge_core.musicgen.app \
--model facebook/musicgen-melody \
--port 8006 \
--gpu-id 0
The service streams back raw audio bytes. Headers include:
X-Duration-S generated duration in seconds
X-Prompt-Duration-S how many seconds of the input were used as prompt
X-Model model name
X-Sample-Rate output sample rate (32000 for all MusicGen variants)
Model weights are cached at /Library/Assets/LLM/musicgen/.
"""
from __future__ import annotations
import argparse
import logging
import os
from typing import Annotated
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from circuitforge_core.musicgen.backends.base import (
MODEL_MELODY,
MODEL_SMALL,
AudioFormat,
MusicGenBackend,
make_musicgen_backend,
)
# Maps output format names ("wav"/"mp3") to the HTTP Content-Type used in
# /continue responses; unknown formats fall back to audio/wav at the call site.
_CONTENT_TYPES: dict[str, str] = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
}
app = FastAPI(title="cf-musicgen", version="0.1.0")
# Module-level backend singleton. It is assigned in the __main__ block before
# uvicorn starts, so request handlers must tolerate None (they answer 503).
_backend: MusicGenBackend | None = None
@app.get("/health")
def health() -> dict:
    """Liveness probe.

    Returns the loaded model name and its VRAM estimate, or 503 while the
    backend has not yet been initialised by the __main__ startup code.
    """
    backend = _backend
    if backend is None:
        raise HTTPException(503, detail="backend not initialised")
    return {"status": "ok", "model": backend.model_name, "vram_mb": backend.vram_mb}
@app.post("/continue")
async def continue_audio(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, OGG, ...)"),
    description: Annotated[str | None, Form()] = None,
    duration_s: Annotated[float, Form()] = 15.0,
    prompt_duration_s: Annotated[float, Form()] = 10.0,
    format: Annotated[AudioFormat, Form()] = "wav",
) -> Response:
    """Generate a musical continuation of the uploaded audio.

    Validates the requested durations, delegates to the backend, and streams
    the generated audio back with metadata in X-* response headers.
    """
    if _backend is None:
        raise HTTPException(503, detail="backend not initialised")
    if duration_s <= 0 or duration_s > 60:
        raise HTTPException(422, detail="duration_s must be between 0 and 60")
    if prompt_duration_s <= 0 or prompt_duration_s > 30:
        raise HTTPException(422, detail="prompt_duration_s must be between 0 and 30")

    payload = await audio.read()
    if not payload:
        raise HTTPException(400, detail="Empty audio file")

    try:
        result = _backend.continue_audio(
            payload,
            description=description or None,
            duration_s=duration_s,
            prompt_duration_s=prompt_duration_s,
            format=format,
        )
    except Exception as exc:
        # Surface backend failures as 500 with the message for debuggability.
        logging.exception("Music continuation failed")
        raise HTTPException(500, detail=str(exc)) from exc

    meta = {
        "X-Duration-S": str(round(result.duration_s, 3)),
        "X-Prompt-Duration-S": str(round(result.prompt_duration_s, 3)),
        "X-Model": result.model,
        "X-Sample-Rate": str(result.sample_rate),
    }
    return Response(
        content=result.audio_bytes,
        media_type=_CONTENT_TYPES.get(result.format, "audio/wav"),
        headers=meta,
    )
def _parse_args() -> argparse.Namespace:
    """Parse the cf-musicgen command line (model, bind address, GPU, mock)."""
    parser = argparse.ArgumentParser(description="cf-musicgen service")
    parser.add_argument(
        "--model",
        default=MODEL_MELODY,
        choices=[MODEL_MELODY, MODEL_SMALL, "facebook/musicgen-medium", "facebook/musicgen-large"],
        help="MusicGen model variant",
    )
    parser.add_argument("--port", type=int, default=8006)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument(
        "--gpu-id",
        type=int,
        default=0,
        help="CUDA device index (sets CUDA_VISIBLE_DEVICES)",
    )
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument(
        "--mock",
        action="store_true",
        help="Run with mock backend (no GPU, for testing)",
    )
    return parser.parse_args()
if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    cli = _parse_args()
    # Pin the process to the requested GPU before any CUDA context is created.
    if cli.device == "cuda" and not cli.mock:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(cli.gpu_id))
    use_mock = cli.mock or cli.model == "mock"
    _backend = make_musicgen_backend(
        model_name=cli.model,
        mock=use_mock,
        # The mock backend never needs CUDA; force CPU in that case.
        device="cpu" if use_mock else cli.device,
    )
    uvicorn.run(app, host=cli.host, port=cli.port, log_level="info")

View file

@ -0,0 +1 @@
"""MusicGen backend implementations."""

View file

@ -0,0 +1,128 @@
"""
AudioCraft MusicGen backend — music continuation via Meta's MusicGen.
Models are downloaded to /Library/Assets/LLM/musicgen/ (HF hub cache).
The melody model (~8 GB VRAM) is the default; small (~1.5 GB) is available
for lower-VRAM nodes.
Continuation workflow:
1. Decode input audio with torchaudio (any format ffmpeg understands)
2. Trim to the last `prompt_duration_s` seconds — this anchors the generation
3. Call model.generate_continuation(prompt_waveform, prompt_sample_rate, ...)
4. Output tensor is the NEW audio only (not prompt + continuation)
5. Encode to the requested format and return
"""
from __future__ import annotations
import logging
import os
from circuitforge_core.musicgen.backends.base import (
AudioFormat,
MusicContinueResult,
decode_audio,
encode_audio,
)
# All MusicGen/AudioCraft weights land here — consistent with other CF model dirs.
_MUSICGEN_CACHE = "/Library/Assets/LLM/musicgen"
# VRAM estimates (MB) per model variant; exposed via AudioCraftBackend.vram_mb
# and the /health endpoint. Presumably used for VRAM-tier scheduling — the
# estimates look hand-tuned, verify against actual measured usage.
_VRAM_MB: dict[str, int] = {
    "facebook/musicgen-small": 1500,
    "facebook/musicgen-medium": 4500,
    "facebook/musicgen-melody": 8000,
    "facebook/musicgen-large": 8500,
}
logger = logging.getLogger(__name__)
class AudioCraftBackend:
    """MusicGen backend using Meta's AudioCraft library.

    Loads one MusicGen variant at construction and reuses it for every
    request. ``continue_audio`` conditions generation on the tail of the
    supplied audio and returns only the newly generated continuation.
    """

    def __init__(self, model_name: str = "facebook/musicgen-melody", device: str = "cuda") -> None:
        """Load *model_name* onto *device*; weights download on first use.

        Raises whatever ``MusicGen.get_pretrained`` raises for an unknown
        model name or unavailable device.
        """
        # Redirect HF hub cache before the first import so weights go to /Library/Assets
        os.environ.setdefault("HF_HOME", _MUSICGEN_CACHE)
        os.makedirs(_MUSICGEN_CACHE, exist_ok=True)
        from audiocraft.models import MusicGen  # noqa: PLC0415

        logger.info("Loading MusicGen model: %s on %s", model_name, device)
        self._model = MusicGen.get_pretrained(model_name, device=device)
        self._model_name = model_name
        self._device = device
        logger.info("MusicGen ready: %s", model_name)

    @property
    def model_name(self) -> str:
        """Hugging Face identifier of the loaded model."""
        return self._model_name

    @property
    def vram_mb(self) -> int:
        """Estimated VRAM footprint in MB (8000 for unknown variants)."""
        return _VRAM_MB.get(self._model_name, 8000)

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Generate *duration_s* seconds of music continuing *audio_bytes*.

        Args:
            audio_bytes: Input audio in any ffmpeg-readable container.
            description: Optional text prompt steering the continuation.
            duration_s: Target length of the generated audio, in seconds.
            prompt_duration_s: Maximum number of trailing seconds of the
                input to use as the conditioning prompt.
            format: Output encoding ("wav" or "mp3").

        Returns:
            MusicContinueResult holding only the NEW audio (not prompt +
            continuation). ``prompt_duration_s`` in the result is the prompt
            length actually used, which can be shorter than requested.
        """
        import torch

        # Decode input audio -> [C, T] tensor
        wav, sr = decode_audio(audio_bytes)
        # Trim to the last `prompt_duration_s` seconds to form the conditioning prompt.
        # Using the end of the track (not the beginning) gives the model the musical
        # context closest to where we want to continue.
        max_prompt_samples = int(prompt_duration_s * sr)
        if wav.shape[-1] > max_prompt_samples:
            wav = wav[..., -max_prompt_samples:]
        # BUGFIX: report the prompt duration actually used. Inputs shorter than
        # the requested prompt yield a shorter prompt, and the service's
        # X-Prompt-Duration-S header is documented as "how many seconds of the
        # input were used as prompt".
        actual_prompt_s = wav.shape[-1] / sr
        # MusicGen expects [batch, channels, time]
        prompt_tensor = wav.unsqueeze(0).to(self._device)
        # Build descriptions list — one entry per batch item (batch=1 here).
        # An empty-string description is treated as "no description".
        descriptions = [description] if description else [None]
        self._model.set_generation_params(
            duration=duration_s,
            top_k=250,
            temperature=1.0,
            cfg_coef=3.0,
        )
        logger.info(
            "Generating %.1fs continuation (prompt=%.1fs) model=%s",
            duration_s,
            actual_prompt_s,
            self._model_name,
        )
        with torch.no_grad():
            output = self._model.generate_continuation(
                prompt=prompt_tensor,
                prompt_sample_rate=sr,
                descriptions=descriptions,
                progress=True,
            )
        # output: [batch, channels, time] at model sample rate (32 kHz)
        output_wav = output[0]  # [C, T]
        model_sr = self._model.sample_rate
        actual_duration_s = output_wav.shape[-1] / model_sr
        audio_bytes_out = encode_audio(output_wav, model_sr, format)
        return MusicContinueResult(
            audio_bytes=audio_bytes_out,
            sample_rate=model_sr,
            duration_s=actual_duration_s,
            format=format,
            model=self._model_name,
            prompt_duration_s=actual_prompt_s,
        )

View file

@ -0,0 +1,97 @@
"""
MusicGenBackend Protocol — backend-agnostic music continuation interface.
All backends accept an audio prompt (raw bytes, any ffmpeg-readable format) and
return MusicContinueResult with the generated continuation as audio bytes.
The continuation is the *new* audio only (not prompt + continuation). Callers
that want a seamless joined file can concatenate the original + result themselves.
"""
from __future__ import annotations
import io
from dataclasses import dataclass
from typing import Literal, Protocol, runtime_checkable
# Supported output encodings for generated audio.
AudioFormat = Literal["wav", "mp3"]
# Canonical HF model ids re-exported for CLI defaults/choices.
MODEL_SMALL = "facebook/musicgen-small"
MODEL_MELODY = "facebook/musicgen-melody"
@dataclass(frozen=True)
class MusicContinueResult:
    """Result of a music-continuation request.

    Holds only the NEW audio generated by the model (not prompt +
    continuation); callers concatenate with the original themselves.
    """

    audio_bytes: bytes        # encoded audio payload, in `format`
    sample_rate: int          # output sample rate in Hz
    duration_s: float         # length of the generated audio, seconds
    format: AudioFormat       # encoding of `audio_bytes` ("wav" or "mp3")
    model: str                # backend model identifier that produced it
    prompt_duration_s: float  # seconds of input audio used as the prompt
@runtime_checkable
class MusicGenBackend(Protocol):
    """Structural interface every music-continuation backend implements."""

    # Generate new audio continuing `audio_bytes`; returns only the new part.
    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult: ...

    # Model identifier (e.g. "facebook/musicgen-melody", or "mock").
    @property
    def model_name(self) -> str: ...

    # Estimated VRAM requirement in MB (0 for the mock backend).
    @property
    def vram_mb(self) -> int: ...
def encode_audio(wav_tensor, sample_rate: int, format: AudioFormat) -> bytes:
    """Encode a torch audio tensor to audio bytes.

    Args:
        wav_tensor: Tensor shaped [T], [C, T], or [1, C, T].
        sample_rate: Sample rate in Hz stamped on the encoded audio.
        format: "wav" or "mp3". When the mp3 encoder (ffmpeg backend) is
            unavailable, the function falls back to wav and logs a warning —
            callers still see their requested `format`, so the surrounding
            result may mislabel the payload.

    Returns:
        Encoded audio bytes.
    """
    import logging
    import torch
    import torchaudio

    wav = wav_tensor
    if wav.dim() == 3:
        wav = wav.squeeze(0)  # [1, C, T] -> [C, T]
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # [T] -> [1, T]
    # torchaudio.save expects float CPU tensors.
    wav = wav.to(torch.float32).cpu()
    buf = io.BytesIO()
    if format == "wav":
        torchaudio.save(buf, wav, sample_rate, format="wav")
    elif format == "mp3":
        try:
            torchaudio.save(buf, wav, sample_rate, format="mp3")
        except Exception:
            # ffmpeg backend not available; fall back to wav. Log it instead
            # of failing silently so mislabelled responses are traceable.
            logging.getLogger(__name__).warning(
                "mp3 encoding failed; falling back to wav", exc_info=True
            )
            buf = io.BytesIO()
            torchaudio.save(buf, wav, sample_rate, format="wav")
    # NOTE(review): any other format value returns empty bytes; AudioFormat's
    # Literal type is the only guard against that today.
    return buf.getvalue()
def decode_audio(audio_bytes: bytes) -> tuple:
    """Decode arbitrary audio bytes into a ([C, T] waveform, sample_rate) pair."""
    import torchaudio

    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
    return waveform, sample_rate
def make_musicgen_backend(
    model_name: str = MODEL_MELODY,
    *,
    mock: bool = False,
    device: str = "cuda",
) -> MusicGenBackend:
    """Factory for music backends.

    Returns the GPU-free mock backend when *mock* is true; otherwise loads
    the AudioCraft implementation for *model_name* on *device*. Imports are
    deferred so the mock path never touches GPU libraries.
    """
    if not mock:
        from circuitforge_core.musicgen.backends.audiocraft import AudioCraftBackend

        return AudioCraftBackend(model_name=model_name, device=device)

    from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend

    return MockMusicGenBackend()

View file

@ -0,0 +1,53 @@
"""
Mock MusicGenBackend — returns silent WAV audio; no GPU required.
Used in unit tests and CI where GPU is unavailable.
"""
from __future__ import annotations
import io
import struct
import wave
from circuitforge_core.musicgen.backends.base import AudioFormat, MusicContinueResult
class MockMusicGenBackend:
    """Stand-in backend that fabricates silent audio; needs no GPU.

    Always emits 16-bit mono PCM WAV at 32 kHz — the requested output
    `format` is ignored and the result is reported as "wav".
    """

    @property
    def model_name(self) -> str:
        """Sentinel model identifier for the mock."""
        return "mock"

    @property
    def vram_mb(self) -> int:
        """The mock consumes no GPU memory."""
        return 0

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Return *duration_s* seconds of silence packaged as a WAV file.

        The input audio, description, and requested format are all ignored.
        """
        rate = 32000
        frame_count = int(duration_s * rate)
        sink = io.BytesIO()
        with wave.open(sink, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(rate)
            writer.writeframes(bytes(2 * frame_count))  # 16-bit PCM silence
        return MusicContinueResult(
            audio_bytes=sink.getvalue(),
            sample_rate=rate,
            duration_s=duration_s,
            format="wav",
            model="mock",
            prompt_duration_s=prompt_duration_s,
        )