feat(musicgen): cf-musicgen module — MusicGen inference server
FastAPI service wrapping facebook/musicgen-* models.
Exposes POST /continue {audio, description?, duration_s, prompt_duration_s, format} → audio/wav (or audio/mpeg).
Registered in VRAM tiers (8GB+).
This commit is contained in:
parent
146fe97227
commit
8b357064ce
6 changed files with 418 additions and 0 deletions
1
circuitforge_core/musicgen/__init__.py
Normal file
1
circuitforge_core/musicgen/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""circuitforge_core.musicgen — music continuation service (BSL 1.1)."""
|
||||
138
circuitforge_core/musicgen/app.py
Normal file
138
circuitforge_core/musicgen/app.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
cf-musicgen FastAPI service — managed by cf-orch.
|
||||
|
||||
Endpoints:
|
||||
GET /health -> {"status": "ok", "model": str, "vram_mb": int}
|
||||
POST /continue -> audio bytes (Content-Type: audio/wav or audio/mpeg)
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.musicgen.app \
|
||||
--model facebook/musicgen-melody \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
The service streams back raw audio bytes. Headers include:
|
||||
X-Duration-S generated duration in seconds
|
||||
X-Prompt-Duration-S how many seconds of the input were used as prompt
|
||||
X-Model model name
|
||||
X-Sample-Rate output sample rate (32000 for all MusicGen variants)
|
||||
|
||||
Model weights are cached at /Library/Assets/LLM/musicgen/.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from circuitforge_core.musicgen.backends.base import (
|
||||
MODEL_MELODY,
|
||||
MODEL_SMALL,
|
||||
AudioFormat,
|
||||
MusicGenBackend,
|
||||
make_musicgen_backend,
|
||||
)
|
||||
|
||||
# Maps the requested output format to the HTTP Content-Type header value;
# unknown formats fall back to "audio/wav" at the lookup site.
_CONTENT_TYPES: dict[str, str] = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
}

app = FastAPI(title="cf-musicgen", version="0.1.0")

# Module-global backend instance. Stays None until the __main__ entry point
# constructs it; endpoints answer 503 while it is unset.
_backend: MusicGenBackend | None = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
return {
|
||||
"status": "ok",
|
||||
"model": _backend.model_name,
|
||||
"vram_mb": _backend.vram_mb,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/continue")
|
||||
async def continue_audio(
|
||||
audio: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, OGG, ...)"),
|
||||
description: Annotated[str | None, Form()] = None,
|
||||
duration_s: Annotated[float, Form()] = 15.0,
|
||||
prompt_duration_s: Annotated[float, Form()] = 10.0,
|
||||
format: Annotated[AudioFormat, Form()] = "wav",
|
||||
) -> Response:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
if duration_s <= 0 or duration_s > 60:
|
||||
raise HTTPException(422, detail="duration_s must be between 0 and 60")
|
||||
if prompt_duration_s <= 0 or prompt_duration_s > 30:
|
||||
raise HTTPException(422, detail="prompt_duration_s must be between 0 and 30")
|
||||
|
||||
audio_bytes = await audio.read()
|
||||
if not audio_bytes:
|
||||
raise HTTPException(400, detail="Empty audio file")
|
||||
|
||||
try:
|
||||
result = _backend.continue_audio(
|
||||
audio_bytes,
|
||||
description=description or None,
|
||||
duration_s=duration_s,
|
||||
prompt_duration_s=prompt_duration_s,
|
||||
format=format,
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.exception("Music continuation failed")
|
||||
raise HTTPException(500, detail=str(exc)) from exc
|
||||
|
||||
return Response(
|
||||
content=result.audio_bytes,
|
||||
media_type=_CONTENT_TYPES.get(result.format, "audio/wav"),
|
||||
headers={
|
||||
"X-Duration-S": str(round(result.duration_s, 3)),
|
||||
"X-Prompt-Duration-S": str(round(result.prompt_duration_s, 3)),
|
||||
"X-Model": result.model,
|
||||
"X-Sample-Rate": str(result.sample_rate),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse CLI options for the cf-musicgen service.

    Returns:
        argparse.Namespace with model, port, host, gpu_id, device, mock.
    """
    p = argparse.ArgumentParser(description="cf-musicgen service")
    p.add_argument(
        "--model",
        default=MODEL_MELODY,
        # "mock" is accepted here so `--model mock` can reach the
        # `args.model == "mock"` check in __main__; previously argparse
        # rejected the value, making that branch unreachable.
        choices=[
            MODEL_MELODY,
            MODEL_SMALL,
            "facebook/musicgen-medium",
            "facebook/musicgen-large",
            "mock",
        ],
        help="MusicGen model variant",
    )
    p.add_argument("--port", type=int, default=8006)
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--gpu-id", type=int, default=0,
                   help="CUDA device index (sets CUDA_VISIBLE_DEVICES)")
    p.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    p.add_argument("--mock", action="store_true",
                   help="Run with mock backend (no GPU, for testing)")
    return p.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
args = _parse_args()
|
||||
|
||||
if args.device == "cuda" and not args.mock:
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))
|
||||
|
||||
mock = args.mock or args.model == "mock"
|
||||
device = "cpu" if mock else args.device
|
||||
|
||||
_backend = make_musicgen_backend(model_name=args.model, mock=mock, device=device)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
1
circuitforge_core/musicgen/backends/__init__.py
Normal file
1
circuitforge_core/musicgen/backends/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""MusicGen backend implementations."""
|
||||
128
circuitforge_core/musicgen/backends/audiocraft.py
Normal file
128
circuitforge_core/musicgen/backends/audiocraft.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""
|
||||
AudioCraft MusicGen backend — music continuation via Meta's MusicGen.
|
||||
|
||||
Models are downloaded to /Library/Assets/LLM/musicgen/ (HF hub cache).
|
||||
The melody model (~8 GB VRAM) is the default; small (~1.5 GB) is available
|
||||
for lower-VRAM nodes.
|
||||
|
||||
Continuation workflow:
|
||||
1. Decode input audio with torchaudio (any format ffmpeg understands)
|
||||
2. Trim to the last `prompt_duration_s` seconds — this anchors the generation
|
||||
3. Call model.generate_continuation(prompt_waveform, prompt_sample_rate, ...)
|
||||
4. Output tensor is the NEW audio only (not prompt + continuation)
|
||||
5. Encode to the requested format and return
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from circuitforge_core.musicgen.backends.base import (
|
||||
AudioFormat,
|
||||
MusicContinueResult,
|
||||
decode_audio,
|
||||
encode_audio,
|
||||
)
|
||||
|
||||
# All MusicGen/AudioCraft weights land here — consistent with other CF model dirs.
_MUSICGEN_CACHE = "/Library/Assets/LLM/musicgen"

# VRAM estimates (MB) per model variant.
# NOTE(review): these look like rough tiering figures — confirm against
# measured usage on target hardware before relying on them for scheduling.
_VRAM_MB: dict[str, int] = {
    "facebook/musicgen-small": 1500,
    "facebook/musicgen-medium": 4500,
    "facebook/musicgen-melody": 8000,
    "facebook/musicgen-large": 8500,
}

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioCraftBackend:
    """MusicGen backend using Meta's AudioCraft library.

    Loads the model eagerly in __init__ and keeps it resident; one instance
    serves all requests for the process lifetime.
    """

    def __init__(self, model_name: str = "facebook/musicgen-melody", device: str = "cuda") -> None:
        # Redirect HF hub cache before the first import so weights go to /Library/Assets
        os.environ.setdefault("HF_HOME", _MUSICGEN_CACHE)
        os.makedirs(_MUSICGEN_CACHE, exist_ok=True)

        # Imported lazily so this module can be imported (e.g. by the factory)
        # without audiocraft installed, and so HF_HOME above takes effect first.
        from audiocraft.models import MusicGen  # noqa: PLC0415

        logger.info("Loading MusicGen model: %s on %s", model_name, device)
        self._model = MusicGen.get_pretrained(model_name, device=device)
        self._model_name = model_name
        self._device = device
        logger.info("MusicGen ready: %s", model_name)

    @property
    def model_name(self) -> str:
        # HF hub identifier, e.g. "facebook/musicgen-melody".
        return self._model_name

    @property
    def vram_mb(self) -> int:
        # Unknown variants fall back to the melody-sized estimate (8000 MB).
        return _VRAM_MB.get(self._model_name, 8000)

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Generate `duration_s` seconds of music continuing `audio_bytes`.

        The last `prompt_duration_s` seconds of the input condition the
        generation; the returned result contains only the new audio.
        """
        import torch

        # Decode input audio -> [C, T] tensor
        wav, sr = decode_audio(audio_bytes)

        # Trim to the last `prompt_duration_s` seconds to form the conditioning prompt.
        # Using the end of the track (not the beginning) gives the model the musical
        # context closest to where we want to continue.
        max_prompt_samples = int(prompt_duration_s * sr)
        if wav.shape[-1] > max_prompt_samples:
            wav = wav[..., -max_prompt_samples:]

        # MusicGen expects [batch, channels, time]
        prompt_tensor = wav.unsqueeze(0).to(self._device)

        # Build descriptions list — one entry per batch item (batch=1 here)
        descriptions = [description] if description else [None]

        # Fixed sampling settings applied to every request; only duration varies.
        self._model.set_generation_params(
            duration=duration_s,
            top_k=250,
            temperature=1.0,
            cfg_coef=3.0,
        )

        logger.info(
            "Generating %.1fs continuation (prompt=%.1fs) model=%s",
            duration_s,
            prompt_duration_s,
            self._model_name,
        )

        # Inference only: no_grad avoids autograd bookkeeping and its VRAM cost.
        with torch.no_grad():
            output = self._model.generate_continuation(
                prompt=prompt_tensor,
                prompt_sample_rate=sr,
                descriptions=descriptions,
                progress=True,
            )

        # output: [batch, channels, time] at model sample rate (32 kHz)
        output_wav = output[0]  # [C, T]
        model_sr = self._model.sample_rate

        actual_duration_s = output_wav.shape[-1] / model_sr
        # NOTE(review): encode_audio may silently fall back to WAV bytes when
        # mp3 encoding is unavailable, while `format` below still reports the
        # requested format — verify callers tolerate that mismatch.
        audio_bytes_out = encode_audio(output_wav, model_sr, format)

        return MusicContinueResult(
            audio_bytes=audio_bytes_out,
            sample_rate=model_sr,
            duration_s=actual_duration_s,
            format=format,
            model=self._model_name,
            prompt_duration_s=prompt_duration_s,
        )
|
||||
97
circuitforge_core/musicgen/backends/base.py
Normal file
97
circuitforge_core/musicgen/backends/base.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""
|
||||
MusicGenBackend Protocol — backend-agnostic music continuation interface.
|
||||
|
||||
All backends accept an audio prompt (raw bytes, any ffmpeg-readable format) and
|
||||
return MusicContinueResult with the generated continuation as audio bytes.
|
||||
|
||||
The continuation is the *new* audio only (not prompt + continuation). Callers
|
||||
that want a seamless joined file can concatenate the original + result themselves.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
AudioFormat = Literal["wav", "mp3"]
|
||||
|
||||
MODEL_SMALL = "facebook/musicgen-small"
|
||||
MODEL_MELODY = "facebook/musicgen-melody"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MusicContinueResult:
    """Outcome of one music-continuation call.

    Holds only the newly generated audio — not prompt + continuation.
    """

    audio_bytes: bytes        # encoded audio payload in `format`
    sample_rate: int          # sample rate of the payload, Hz
    duration_s: float         # length of the generated audio, seconds
    format: AudioFormat       # encoding of audio_bytes ("wav" or "mp3")
    model: str                # name of the model that produced it
    prompt_duration_s: float  # seconds of input used as conditioning prompt
|
||||
|
||||
|
||||
@runtime_checkable
class MusicGenBackend(Protocol):
    """Structural interface all music-continuation backends satisfy.

    runtime_checkable permits isinstance() checks, which verify member
    presence only, not signatures.
    """

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Generate a continuation of `audio_bytes`; result is new audio only."""
        ...

    @property
    def model_name(self) -> str:
        """Model identifier, as reported by the /health endpoint."""
        ...

    @property
    def vram_mb(self) -> int:
        """Estimated VRAM footprint in MB (0 for the mock backend)."""
        ...
|
||||
|
||||
|
||||
def encode_audio(wav_tensor, sample_rate: int, format: AudioFormat) -> bytes:
    """Encode a [C, T] or [1, C, T] torch tensor to audio bytes.

    Args:
        wav_tensor: torch tensor of samples; accepts [T], [C, T] or [1, C, T].
        sample_rate: waveform sample rate in Hz.
        format: "wav" or "mp3".

    Returns:
        Encoded audio bytes. If mp3 encoding is unavailable (no ffmpeg
        backend) the function falls back to WAV and logs a warning —
        callers that label the payload by the *requested* format should
        note the bytes may actually be WAV in that case.
    """
    import io
    import logging

    import torch
    import torchaudio

    wav = wav_tensor
    if wav.dim() == 3:
        wav = wav.squeeze(0)  # [1, C, T] -> [C, T]
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # [T] -> [1, T]
    # torchaudio.save expects a CPU float tensor.
    wav = wav.to(torch.float32).cpu()

    buf = io.BytesIO()
    if format == "wav":
        torchaudio.save(buf, wav, sample_rate, format="wav")
    elif format == "mp3":
        try:
            torchaudio.save(buf, wav, sample_rate, format="mp3")
        except Exception:
            # ffmpeg backend not available; fall back to wav. Previously this
            # happened silently, leaving WAV bytes labelled as mp3 upstream —
            # log so the mismatch is at least visible.
            logging.getLogger(__name__).warning(
                "mp3 encoding unavailable; falling back to wav"
            )
            buf = io.BytesIO()
            torchaudio.save(buf, wav, sample_rate, format="wav")
    return buf.getvalue()
|
||||
|
||||
|
||||
def decode_audio(audio_bytes: bytes) -> tuple:
    """Decode arbitrary audio bytes to (waveform [C, T], sample_rate)."""
    import io
    import torchaudio

    return torchaudio.load(io.BytesIO(audio_bytes))
|
||||
|
||||
|
||||
def make_musicgen_backend(
    model_name: str = MODEL_MELODY,
    *,
    mock: bool = False,
    device: str = "cuda",
) -> MusicGenBackend:
    """Build a MusicGen backend: the mock (no GPU) or the real AudioCraft one.

    Imports are deferred so the unused backend's dependencies are never loaded.
    """
    if not mock:
        from circuitforge_core.musicgen.backends.audiocraft import AudioCraftBackend

        return AudioCraftBackend(model_name=model_name, device=device)

    from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend

    return MockMusicGenBackend()
|
||||
53
circuitforge_core/musicgen/backends/mock.py
Normal file
53
circuitforge_core/musicgen/backends/mock.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
Mock MusicGenBackend — returns silent WAV audio; no GPU required.
|
||||
|
||||
Used in unit tests and CI where GPU is unavailable.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from circuitforge_core.musicgen.backends.base import AudioFormat, MusicContinueResult
|
||||
|
||||
|
||||
class MockMusicGenBackend:
    """Backend double that emits silent WAV audio of the requested length."""

    @property
    def model_name(self) -> str:
        return "mock"

    @property
    def vram_mb(self) -> int:
        # No GPU is used at all.
        return 0

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Return `duration_s` seconds of 16-bit mono PCM silence at 32 kHz.

        The input audio and description are ignored; the `format` request is
        ignored too — output is always WAV, and the result says so.
        """
        rate = 32000
        frame_count = int(duration_s * rate)

        out = io.BytesIO()
        writer = wave.open(out, "wb")
        try:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(rate)
            # Zero-filled buffer: 2 bytes per 16-bit sample, all silence.
            writer.writeframes(bytes(2 * frame_count))
        finally:
            writer.close()

        return MusicContinueResult(
            audio_bytes=out.getvalue(),
            sample_rate=rate,
            duration_s=duration_s,
            format="wav",
            model="mock",
            prompt_duration_s=prompt_duration_s,
        )
|
||||
Loading…
Reference in a new issue