- platforms/: eBay platform adapter (snipe integration layer)
- docs/: developer guide, module reference, getting-started docs
- scripts/: utility scripts for development and deployment
129 lines · 5.3 KiB · Python
#!/usr/bin/env python
|
|
"""
|
|
Standalone music continuation test — no service required.
|
|
|
|
Usage:
|
|
conda run -n cf python scripts/test_musicgen.py \
|
|
--input "/Library/Audio/Music/KAESUL/Schedule I - Original Soundtrack (2025)/KAESUL - Schedule I - Original Soundtrack - 17 - the life i lead (reveal trailer).mp3"
|
|
|
|
Options:
|
|
--input PATH Audio file to continue (any ffmpeg-readable format)
|
|
--output PATH Output WAV path (default: /tmp/continuation_output.wav)
|
|
--model MODEL MusicGen variant (default: facebook/musicgen-melody)
|
|
--duration SECS Seconds of new audio to generate (default: 30)
|
|
--prompt-duration SECS Seconds from end of song to condition on (default: 10)
|
|
--description TEXT Optional style hint, e.g. "dark ambient electronic"
|
|
--device DEVICE cuda or cpu (default: cuda)
|
|
--join Concatenate original prompt segment + continuation in output
|
|
|
|
The generated file is saved to --output. Open it in any audio player to listen.
|
|
Model weights download to /Library/Assets/LLM/musicgen/ on first run (~8 GB for melody).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
# Redirect HF cache before any audiocraft import
|
|
os.environ.setdefault("HF_HOME", "/Library/Assets/LLM/musicgen")
|
|
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s %(levelname)s %(message)s",
|
|
)
|
|
log = logging.getLogger("test_musicgen")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
p = argparse.ArgumentParser(description="cf-musicgen standalone test")
|
|
p.add_argument("--input", required=True, help="Input audio file path")
|
|
p.add_argument("--output", default="/tmp/continuation_output.wav")
|
|
p.add_argument("--model", default="facebook/musicgen-melody")
|
|
p.add_argument("--duration", type=float, default=30.0,
|
|
help="Seconds of new audio to generate")
|
|
p.add_argument("--prompt-duration", type=float, default=10.0,
|
|
help="Seconds from end of song used as prompt")
|
|
p.add_argument("--description", default=None,
|
|
help="Optional text description to guide the style")
|
|
p.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
|
|
p.add_argument("--join", action="store_true",
|
|
help="Prepend the prompt segment to the output file")
|
|
return p.parse_args()
|
|
|
|
|
|
def main() -> None:
|
|
args = parse_args()
|
|
|
|
if not os.path.exists(args.input):
|
|
log.error("Input file not found: %s", args.input)
|
|
sys.exit(1)
|
|
|
|
log.info("Input: %s", args.input)
|
|
log.info("Model: %s", args.model)
|
|
log.info("Duration: %.1fs | Prompt: %.1fs", args.duration, args.prompt_duration)
|
|
if args.description:
|
|
log.info("Style hint: %s", args.description)
|
|
|
|
import torch
|
|
import torchaudio
|
|
|
|
log.info("Loading model (weights -> /Library/Assets/LLM/musicgen/)")
|
|
from audiocraft.models import MusicGen
|
|
|
|
model = MusicGen.get_pretrained(args.model, device=args.device)
|
|
model.set_generation_params(duration=args.duration, top_k=250, temperature=1.0, cfg_coef=3.0)
|
|
|
|
# Load input audio
|
|
wav, sr = torchaudio.load(args.input)
|
|
log.info("Loaded audio: %.1fs @ %d Hz (%d ch)", wav.shape[-1] / sr, sr, wav.shape[0])
|
|
|
|
# Trim to last prompt_duration_s seconds
|
|
max_prompt_samples = int(args.prompt_duration * sr)
|
|
prompt_wav = wav[..., -max_prompt_samples:] if wav.shape[-1] > max_prompt_samples else wav
|
|
log.info("Using %.1fs prompt from end of track", prompt_wav.shape[-1] / sr)
|
|
|
|
# MusicGen expects [batch, channels, time]
|
|
prompt_tensor = prompt_wav.unsqueeze(0).to(args.device)
|
|
|
|
log.info("Generating %.1fs of continuation ...", args.duration)
|
|
t0 = time.time()
|
|
|
|
with torch.no_grad():
|
|
output = model.generate_continuation(
|
|
prompt=prompt_tensor,
|
|
prompt_sample_rate=sr,
|
|
descriptions=[args.description],
|
|
progress=True,
|
|
)
|
|
|
|
elapsed = time.time() - t0
|
|
model_sr = model.sample_rate
|
|
output_wav = output[0].cpu() # [C, T]
|
|
actual_s = output_wav.shape[-1] / model_sr
|
|
log.info("Done in %.1fs -> %.1fs of audio at %d Hz", elapsed, actual_s, model_sr)
|
|
|
|
if args.join:
|
|
# Resample prompt to model sample rate so concatenation is seamless
|
|
prompt_resampled = torchaudio.functional.resample(prompt_wav, sr, model_sr)
|
|
# Reconcile channel count: MusicGen outputs 1ch; prompt may be stereo.
|
|
# Convert to mono by averaging if needed so cat doesn't blow up.
|
|
if prompt_resampled.shape[0] != output_wav.shape[0]:
|
|
if output_wav.shape[0] == 1 and prompt_resampled.shape[0] > 1:
|
|
prompt_resampled = prompt_resampled.mean(dim=0, keepdim=True)
|
|
elif prompt_resampled.shape[0] == 1 and output_wav.shape[0] > 1:
|
|
prompt_resampled = prompt_resampled.expand_as(output_wav)
|
|
output_wav = torch.cat([prompt_resampled, output_wav], dim=-1)
|
|
total_s = output_wav.shape[-1] / model_sr
|
|
log.info("Joined prompt + continuation: %.1fs total", total_s)
|
|
|
|
os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
|
|
torchaudio.save(args.output, output_wav, model_sr)
|
|
log.info("Saved: %s", args.output)
|
|
log.info("Play: ffplay %r (or open in any audio player)", args.output)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|