feat(musicgen): cf-musicgen module — MusicGen inference server
FastAPI service wrapping facebook/musicgen-* models.
Exposes POST /generate {prompt, duration_s} → audio/wav.
Registered in VRAM tiers (8GB+).
This commit is contained in:
parent
146fe97227
commit
8b357064ce
6 changed files with 418 additions and 0 deletions
1
circuitforge_core/musicgen/__init__.py
Normal file
1
circuitforge_core/musicgen/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""circuitforge_core.musicgen — music continuation service (BSL 1.1)."""
|
||||||
138
circuitforge_core/musicgen/app.py
Normal file
138
circuitforge_core/musicgen/app.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""
|
||||||
|
cf-musicgen FastAPI service — managed by cf-orch.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
GET /health -> {"status": "ok", "model": str, "vram_mb": int}
|
||||||
|
POST /continue -> audio bytes (Content-Type: audio/wav or audio/mpeg)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m circuitforge_core.musicgen.app \
|
||||||
|
--model facebook/musicgen-melody \
|
||||||
|
--port 8006 \
|
||||||
|
--gpu-id 0
|
||||||
|
|
||||||
|
The service streams back raw audio bytes. Headers include:
|
||||||
|
X-Duration-S generated duration in seconds
|
||||||
|
X-Prompt-Duration-S how many seconds of the input were used as prompt
|
||||||
|
X-Model model name
|
||||||
|
X-Sample-Rate output sample rate (32000 for all MusicGen variants)
|
||||||
|
|
||||||
|
Model weights are cached at /Library/Assets/LLM/musicgen/.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from circuitforge_core.musicgen.backends.base import (
|
||||||
|
MODEL_MELODY,
|
||||||
|
MODEL_SMALL,
|
||||||
|
AudioFormat,
|
||||||
|
MusicGenBackend,
|
||||||
|
make_musicgen_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
_CONTENT_TYPES: dict[str, str] = {
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
}
|
||||||
|
|
||||||
|
app = FastAPI(title="cf-musicgen", version="0.1.0")
|
||||||
|
_backend: MusicGenBackend | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
def health() -> dict:
    """Liveness/readiness probe for cf-orch.

    Responds 503 until the module-level backend has been initialised by the
    ``__main__`` entry point; afterwards reports the loaded model name and
    its VRAM estimate so the orchestrator can place it in a VRAM tier.
    """
    backend = _backend
    if backend is None:
        raise HTTPException(503, detail="backend not initialised")
    return {
        "status": "ok",
        "model": backend.model_name,
        "vram_mb": backend.vram_mb,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/continue")
async def continue_audio(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, OGG, ...)"),
    description: Annotated[str | None, Form()] = None,
    duration_s: Annotated[float, Form()] = 15.0,
    prompt_duration_s: Annotated[float, Form()] = 10.0,
    format: Annotated[AudioFormat, Form()] = "wav",
) -> Response:
    """Generate a musical continuation of the uploaded audio.

    The response body is the raw encoded audio; generation metadata travels
    in ``X-*`` response headers rather than a JSON envelope.

    Raises 503 before the backend is initialised, 422 for out-of-range
    durations, 400 for an empty upload, and 500 if generation itself fails.
    """
    # `format` shadows the builtin, but it is the public form-field name,
    # so it stays.
    if _backend is None:
        raise HTTPException(503, detail="backend not initialised")
    if duration_s <= 0 or duration_s > 60:
        raise HTTPException(422, detail="duration_s must be between 0 and 60")
    if prompt_duration_s <= 0 or prompt_duration_s > 30:
        raise HTTPException(422, detail="prompt_duration_s must be between 0 and 30")

    payload = await audio.read()
    if not payload:
        raise HTTPException(400, detail="Empty audio file")

    try:
        result = _backend.continue_audio(
            payload,
            description=description or None,  # empty form field -> unconditioned
            duration_s=duration_s,
            prompt_duration_s=prompt_duration_s,
            format=format,
        )
    except Exception as exc:
        logging.exception("Music continuation failed")
        raise HTTPException(500, detail=str(exc)) from exc

    meta_headers = {
        "X-Duration-S": str(round(result.duration_s, 3)),
        "X-Prompt-Duration-S": str(round(result.prompt_duration_s, 3)),
        "X-Model": result.model,
        "X-Sample-Rate": str(result.sample_rate),
    }
    return Response(
        content=result.audio_bytes,
        media_type=_CONTENT_TYPES.get(result.format, "audio/wav"),
        headers=meta_headers,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
    """Parse CLI options for the cf-musicgen service entry point."""
    parser = argparse.ArgumentParser(description="cf-musicgen service")
    parser.add_argument(
        "--model",
        default=MODEL_MELODY,
        choices=[MODEL_MELODY, MODEL_SMALL, "facebook/musicgen-medium", "facebook/musicgen-large"],
        help="MusicGen model variant",
    )
    parser.add_argument("--port", type=int, default=8006)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument(
        "--gpu-id",
        type=int,
        default=0,
        help="CUDA device index (sets CUDA_VISIBLE_DEVICES)",
    )
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument(
        "--mock",
        action="store_true",
        help="Run with mock backend (no GPU, for testing)",
    )
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    args = _parse_args()

    # Pin the process to the requested GPU before torch is imported by the
    # backend. NOTE(review): setdefault means an already-exported
    # CUDA_VISIBLE_DEVICES silently wins over --gpu-id — confirm that is the
    # intended precedence under cf-orch.
    if args.device == "cuda" and not args.mock:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))

    # `--model mock` is accepted as an alias for --mock; the mock backend
    # never touches the GPU.
    use_mock = args.mock or args.model == "mock"

    _backend = make_musicgen_backend(
        model_name=args.model,
        mock=use_mock,
        device="cpu" if use_mock else args.device,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||||
1
circuitforge_core/musicgen/backends/__init__.py
Normal file
1
circuitforge_core/musicgen/backends/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""MusicGen backend implementations."""
|
||||||
128
circuitforge_core/musicgen/backends/audiocraft.py
Normal file
128
circuitforge_core/musicgen/backends/audiocraft.py
Normal file
|
|
@ -0,0 +1,128 @@
|
||||||
|
"""
|
||||||
|
AudioCraft MusicGen backend — music continuation via Meta's MusicGen.
|
||||||
|
|
||||||
|
Models are downloaded to /Library/Assets/LLM/musicgen/ (HF hub cache).
|
||||||
|
The melody model (~8 GB VRAM) is the default; small (~1.5 GB) is available
|
||||||
|
for lower-VRAM nodes.
|
||||||
|
|
||||||
|
Continuation workflow:
|
||||||
|
1. Decode input audio with torchaudio (any format ffmpeg understands)
|
||||||
|
2. Trim to the last `prompt_duration_s` seconds — this anchors the generation
|
||||||
|
3. Call model.generate_continuation(prompt_waveform, prompt_sample_rate, ...)
|
||||||
|
4. Output tensor is the NEW audio only (not prompt + continuation)
|
||||||
|
5. Encode to the requested format and return
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from circuitforge_core.musicgen.backends.base import (
|
||||||
|
AudioFormat,
|
||||||
|
MusicContinueResult,
|
||||||
|
decode_audio,
|
||||||
|
encode_audio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All MusicGen/AudioCraft weights land here — consistent with other CF model dirs.
|
||||||
|
_MUSICGEN_CACHE = "/Library/Assets/LLM/musicgen"
|
||||||
|
|
||||||
|
# VRAM estimates (MB) per model variant
|
||||||
|
_VRAM_MB: dict[str, int] = {
|
||||||
|
"facebook/musicgen-small": 1500,
|
||||||
|
"facebook/musicgen-medium": 4500,
|
||||||
|
"facebook/musicgen-melody": 8000,
|
||||||
|
"facebook/musicgen-large": 8500,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioCraftBackend:
    """MusicGen backend using Meta's AudioCraft library."""

    def __init__(self, model_name: str = "facebook/musicgen-melody", device: str = "cuda") -> None:
        """Load the given MusicGen variant onto `device` (blocking).

        HF_HOME is pointed at /Library/Assets before audiocraft is first
        imported so downloaded weights land in the shared model directory
        (setdefault: an externally set HF_HOME wins).
        """
        os.environ.setdefault("HF_HOME", _MUSICGEN_CACHE)
        os.makedirs(_MUSICGEN_CACHE, exist_ok=True)

        # Deferred import: keeps the module importable on nodes without
        # audiocraft installed (e.g. when only the mock backend is used).
        from audiocraft.models import MusicGen  # noqa: PLC0415

        logger.info("Loading MusicGen model: %s on %s", model_name, device)
        self._model = MusicGen.get_pretrained(model_name, device=device)
        self._model_name = model_name
        self._device = device
        logger.info("MusicGen ready: %s", model_name)

    @property
    def model_name(self) -> str:
        """HF model id this backend was constructed with."""
        return self._model_name

    @property
    def vram_mb(self) -> int:
        """Static VRAM estimate; unknown variants assume the melody size."""
        return _VRAM_MB.get(self._model_name, 8000)

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Generate a continuation of `audio_bytes` and encode it as `format`.

        The last `prompt_duration_s` seconds of the input condition the
        generation; `description` optionally steers it with text.
        """
        import torch

        # Decode the incoming bytes to a [C, T] tensor plus its sample rate.
        waveform, in_rate = decode_audio(audio_bytes)

        # Condition on the *tail* of the track — the musical context closest
        # to where the continuation begins.
        prompt_samples = int(prompt_duration_s * in_rate)
        if waveform.shape[-1] > prompt_samples:
            waveform = waveform[..., -prompt_samples:]

        # [C, T] -> [1, C, T]; MusicGen wants a batch dimension.
        prompt = waveform.unsqueeze(0).to(self._device)

        # One description per batch item (batch size is 1 here).
        descriptions = [description] if description else [None]

        self._model.set_generation_params(
            duration=duration_s,
            top_k=250,
            temperature=1.0,
            cfg_coef=3.0,
        )

        logger.info(
            "Generating %.1fs continuation (prompt=%.1fs) model=%s",
            duration_s,
            prompt_duration_s,
            self._model_name,
        )

        with torch.no_grad():
            generated = self._model.generate_continuation(
                prompt=prompt,
                prompt_sample_rate=in_rate,
                descriptions=descriptions,
                progress=True,
            )

        # NOTE(review): the module docstring claims the output excludes the
        # prompt; confirm against the installed audiocraft version — some
        # releases return prompt + continuation from generate_continuation.
        out_wav = generated[0]  # [B, C, T] -> [C, T]
        out_rate = self._model.sample_rate  # 32 kHz for MusicGen variants

        produced_s = out_wav.shape[-1] / out_rate
        encoded = encode_audio(out_wav, out_rate, format)

        return MusicContinueResult(
            audio_bytes=encoded,
            sample_rate=out_rate,
            duration_s=produced_s,
            format=format,
            model=self._model_name,
            prompt_duration_s=prompt_duration_s,
        )
|
||||||
97
circuitforge_core/musicgen/backends/base.py
Normal file
97
circuitforge_core/musicgen/backends/base.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""
|
||||||
|
MusicGenBackend Protocol — backend-agnostic music continuation interface.
|
||||||
|
|
||||||
|
All backends accept an audio prompt (raw bytes, any ffmpeg-readable format) and
|
||||||
|
return MusicContinueResult with the generated continuation as audio bytes.
|
||||||
|
|
||||||
|
The continuation is the *new* audio only (not prompt + continuation). Callers
|
||||||
|
that want a seamless joined file can concatenate the original + result themselves.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
AudioFormat = Literal["wav", "mp3"]
|
||||||
|
|
||||||
|
MODEL_SMALL = "facebook/musicgen-small"
|
||||||
|
MODEL_MELODY = "facebook/musicgen-melody"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MusicContinueResult:
    """Outcome of one music-continuation request.

    Immutable value object returned by every ``MusicGenBackend``.
    ``audio_bytes`` holds the generated continuation only (not the prompt),
    already encoded as ``format``.
    """

    # Encoded audio payload in `format`.
    audio_bytes: bytes
    # Sample rate of the generated audio in Hz.
    sample_rate: int
    # Actual generated duration in seconds (may differ from the request).
    duration_s: float
    # Encoding of `audio_bytes`: "wav" or "mp3".
    format: AudioFormat
    # Name of the model that produced the audio (e.g. an HF model id, or "mock").
    model: str
    # Seconds of the input that were used as the conditioning prompt.
    prompt_duration_s: float
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class MusicGenBackend(Protocol):
    """Structural interface every music-continuation backend implements.

    Backends take raw audio bytes (any ffmpeg-readable container) plus
    generation options and return a :class:`MusicContinueResult`.
    """

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Generate `duration_s` seconds of continuation for `audio_bytes`,
        conditioned on the last `prompt_duration_s` seconds of input and an
        optional text `description`, encoded as `format`."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier of the loaded model (reported by /health)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Estimated VRAM footprint in MB, used for tier placement."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
def encode_audio(wav_tensor, sample_rate: int, format: AudioFormat) -> bytes:
    """Encode a [C, T] or [1, C, T] torch tensor to audio bytes.

    Accepts [T], [C, T] or [1, C, T]; the tensor is normalised to [C, T],
    converted to float32 on CPU (what torchaudio.save expects), then
    serialised. If mp3 encoding is unavailable (no ffmpeg backend) the
    function falls back to WAV bytes — callers that requested mp3 should be
    aware the payload may actually be WAV in that case.

    torch/torchaudio are imported lazily so this module stays importable
    without the heavyweight audio stack.
    """
    import torch
    import torchaudio

    wav = wav_tensor
    if wav.dim() == 3:
        wav = wav.squeeze(0)  # [1, C, T] -> [C, T]
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # [T] -> [1, T]
    wav = wav.to(torch.float32).cpu()

    # Module-level `io` is used; the original re-imported it locally for no
    # benefit.
    buf = io.BytesIO()
    if format == "wav":
        torchaudio.save(buf, wav, sample_rate, format="wav")
    elif format == "mp3":
        try:
            torchaudio.save(buf, wav, sample_rate, format="mp3")
        except Exception:
            # ffmpeg backend not available; fall back to wav. Log instead of
            # failing the request, but don't do it silently — the caller will
            # still label the result as mp3.
            import logging

            logging.getLogger(__name__).warning(
                "mp3 encoding unavailable; falling back to wav"
            )
            buf = io.BytesIO()
            torchaudio.save(buf, wav, sample_rate, format="wav")
    return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_audio(audio_bytes: bytes) -> tuple:
    """Decode arbitrary audio bytes to (waveform [C, T], sample_rate).

    torchaudio is imported lazily so the module can be loaded without the
    audio stack installed (e.g. alongside the mock backend).
    """
    import torchaudio

    return torchaudio.load(io.BytesIO(audio_bytes))
|
||||||
|
|
||||||
|
|
||||||
|
def make_musicgen_backend(
    model_name: str = MODEL_MELODY,
    *,
    mock: bool = False,
    device: str = "cuda",
) -> MusicGenBackend:
    """Factory for a music-continuation backend.

    Returns the silent mock backend when `mock` is true, otherwise the real
    AudioCraft backend loading `model_name` onto `device`. Both imports are
    deferred so the heavyweight torch/audiocraft stack is only pulled in
    when actually requested.
    """
    if mock:
        from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend

        return MockMusicGenBackend()

    from circuitforge_core.musicgen.backends.audiocraft import AudioCraftBackend

    return AudioCraftBackend(model_name=model_name, device=device)
|
||||||
53
circuitforge_core/musicgen/backends/mock.py
Normal file
53
circuitforge_core/musicgen/backends/mock.py
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
"""
|
||||||
|
Mock MusicGenBackend — returns silent WAV audio; no GPU required.
|
||||||
|
|
||||||
|
Used in unit tests and CI where GPU is unavailable.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import struct
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from circuitforge_core.musicgen.backends.base import AudioFormat, MusicContinueResult
|
||||||
|
|
||||||
|
|
||||||
|
class MockMusicGenBackend:
    """Returns a silent WAV file of the requested duration."""

    @property
    def model_name(self) -> str:
        """Sentinel model id reported by /health."""
        return "mock"

    @property
    def vram_mb(self) -> int:
        """The mock never touches the GPU."""
        return 0

    def continue_audio(
        self,
        audio_bytes: bytes,
        *,
        description: str | None = None,
        duration_s: float = 15.0,
        prompt_duration_s: float = 10.0,
        format: AudioFormat = "wav",
    ) -> MusicContinueResult:
        """Ignore the input and synthesise `duration_s` seconds of silence.

        Always produces 16-bit mono WAV at 32 kHz (MusicGen's native rate),
        regardless of the requested `format` — the result's format field
        reflects that by always reading "wav".
        """
        rate = 32000
        frame_count = int(duration_s * rate)

        out = io.BytesIO()
        with wave.open(out, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)  # 16-bit PCM
            writer.setframerate(rate)
            writer.writeframes(bytes(2 * frame_count))  # zeroed samples = silence

        return MusicContinueResult(
            audio_bytes=out.getvalue(),
            sample_rate=rate,
            duration_s=duration_s,
            format="wav",
            model="mock",
            prompt_duration_s=prompt_duration_s,
        )
|
||||||
Loading…
Reference in a new issue