circuitforge-core/scripts/test_musicgen.py
pyr0ball 383897f990
Some checks are pending
CI / test (push) Waiting to run
Mirror / mirror (push) Waiting to run
Release — PyPI / release (push) Waiting to run
feat: platforms module + docs + scripts
- platforms/: eBay platform adapter (snipe integration layer)
- docs/: developer guide, module reference, getting-started docs
- scripts/: utility scripts for development and deployment
2026-04-24 15:23:16 -07:00

129 lines
5.3 KiB
Python

#!/usr/bin/env python
"""
Standalone music continuation test — no service required.
Usage:
conda run -n cf python scripts/test_musicgen.py \
--input "/Library/Audio/Music/KAESUL/Schedule I - Original Soundtrack (2025)/KAESUL - Schedule I - Original Soundtrack - 17 - the life i lead (reveal trailer).mp3"
Options:
--input PATH Audio file to continue (any ffmpeg-readable format)
--output PATH Output WAV path (default: /tmp/continuation_output.wav)
--model MODEL MusicGen variant (default: facebook/musicgen-melody)
--duration SECS Seconds of new audio to generate (default: 30)
--prompt-duration SECS Seconds from end of song to condition on (default: 10)
--description TEXT Optional style hint, e.g. "dark ambient electronic"
--device DEVICE cuda or cpu (default: cuda)
--join Concatenate original prompt segment + continuation in output
The generated file is saved to --output. Open it in any audio player to listen.
Model weights download to /Library/Assets/LLM/musicgen/ on first run (~8 GB for melody).
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
import time
# Redirect HF cache before any audiocraft import so model weights download to
# the shared asset volume instead of the default ~/.cache/huggingface.
# NOTE: this must run before `from audiocraft.models import MusicGen` (done
# lazily inside main()) or the override has no effect.
os.environ.setdefault("HF_HOME", "/Library/Assets/LLM/musicgen")
# Timestamped INFO-level logging: generation runs are long, so progress lines
# with wall-clock times are useful.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
# Module-level logger used by parse_args()/main() below.
log = logging.getLogger("test_musicgen")
def parse_args() -> argparse.Namespace:
    """Build the CLI for the standalone test and parse ``sys.argv``.

    Returns the parsed namespace; ``--input`` is the only required flag,
    everything else has a sensible default.
    """
    parser = argparse.ArgumentParser(description="cf-musicgen standalone test")
    parser.add_argument("--input", required=True, help="Input audio file path")
    parser.add_argument("--output", default="/tmp/continuation_output.wav")
    parser.add_argument("--model", default="facebook/musicgen-melody")
    parser.add_argument(
        "--duration",
        type=float,
        default=30.0,
        help="Seconds of new audio to generate",
    )
    parser.add_argument(
        "--prompt-duration",
        type=float,
        default=10.0,
        help="Seconds from end of song used as prompt",
    )
    parser.add_argument(
        "--description",
        default=None,
        help="Optional text description to guide the style",
    )
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument(
        "--join",
        action="store_true",
        help="Prepend the prompt segment to the output file",
    )
    return parser.parse_args()
def main() -> None:
    """Run the standalone continuation test end to end.

    Loads the input track, conditions MusicGen on its final
    ``--prompt-duration`` seconds, generates ``--duration`` seconds of new
    audio, and saves a WAV to ``--output`` (optionally with the prompt
    segment prepended when ``--join`` is given).

    Exits with status 1 on a missing input file or non-positive durations.
    """
    args = parse_args()
    if not os.path.exists(args.input):
        log.error("Input file not found: %s", args.input)
        sys.exit(1)
    # BUGFIX: validate durations up front. With --prompt-duration <= 0 the
    # original slice `wav[..., -0:]` equals `wav[..., 0:]`, silently using the
    # ENTIRE track as the prompt instead of an empty/short slice.
    if args.duration <= 0 or args.prompt_duration <= 0:
        log.error("--duration and --prompt-duration must be positive")
        sys.exit(1)
    log.info("Input: %s", args.input)
    log.info("Model: %s", args.model)
    log.info("Duration: %.1fs | Prompt: %.1fs", args.duration, args.prompt_duration)
    if args.description:
        log.info("Style hint: %s", args.description)
    # Heavy imports are deferred so `--help` and argument errors stay fast,
    # and so the HF_HOME override at module top takes effect first.
    import torch
    import torchaudio
    log.info("Loading model (weights -> /Library/Assets/LLM/musicgen/)")
    from audiocraft.models import MusicGen
    model = MusicGen.get_pretrained(args.model, device=args.device)
    model.set_generation_params(duration=args.duration, top_k=250, temperature=1.0, cfg_coef=3.0)
    # Load input audio; torchaudio returns [channels, time] plus sample rate.
    wav, sr = torchaudio.load(args.input)
    log.info("Loaded audio: %.1fs @ %d Hz (%d ch)", wav.shape[-1] / sr, sr, wav.shape[0])
    # Trim to the last prompt_duration seconds (whole track if shorter).
    max_prompt_samples = int(args.prompt_duration * sr)
    prompt_wav = wav[..., -max_prompt_samples:] if wav.shape[-1] > max_prompt_samples else wav
    log.info("Using %.1fs prompt from end of track", prompt_wav.shape[-1] / sr)
    # MusicGen expects [batch, channels, time]; add the batch dim.
    prompt_tensor = prompt_wav.unsqueeze(0).to(args.device)
    log.info("Generating %.1fs of continuation ...", args.duration)
    t0 = time.time()
    with torch.no_grad():
        # descriptions may contain None — audiocraft treats that as
        # "no text conditioning" for that batch item.
        output = model.generate_continuation(
            prompt=prompt_tensor,
            prompt_sample_rate=sr,
            descriptions=[args.description],
            progress=True,
        )
    elapsed = time.time() - t0
    model_sr = model.sample_rate
    output_wav = output[0].cpu()  # [C, T] at the model's sample rate
    actual_s = output_wav.shape[-1] / model_sr
    log.info("Done in %.1fs -> %.1fs of audio at %d Hz", elapsed, actual_s, model_sr)
    if args.join:
        # Resample prompt to model sample rate so concatenation is seamless
        prompt_resampled = torchaudio.functional.resample(prompt_wav, sr, model_sr)
        # Reconcile channel count: MusicGen outputs 1ch; prompt may be stereo.
        # Convert to mono by averaging if needed so cat doesn't blow up.
        if prompt_resampled.shape[0] != output_wav.shape[0]:
            if output_wav.shape[0] == 1 and prompt_resampled.shape[0] > 1:
                prompt_resampled = prompt_resampled.mean(dim=0, keepdim=True)
            elif prompt_resampled.shape[0] == 1 and output_wav.shape[0] > 1:
                prompt_resampled = prompt_resampled.expand_as(output_wav)
        output_wav = torch.cat([prompt_resampled, output_wav], dim=-1)
        total_s = output_wav.shape[-1] / model_sr
        log.info("Joined prompt + continuation: %.1fs total", total_s)
    # Ensure the destination directory exists before saving.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    torchaudio.save(args.output, output_wav, model_sr)
    log.info("Saved: %s", args.output)
    log.info("Play: ffplay %r (or open in any audio player)", args.output)
# Standard entry guard: run the test only when executed as a script,
# not when imported.
if __name__ == "__main__":
main()