diff --git a/circuitforge_core/audio/__init__.py b/circuitforge_core/audio/__init__.py new file mode 100644 index 0000000..b8ceebb --- /dev/null +++ b/circuitforge_core/audio/__init__.py @@ -0,0 +1,29 @@ +""" +circuitforge_core.audio — shared PCM and audio signal utilities. + +MIT licensed. No model weights. No HuggingFace. Dependency: numpy only +(scipy optional for high-quality resampling). + +Consumers: + cf-voice — replaces hand-rolled PCM conversion in stt.py / context.py + Sparrow — torchaudio stitching, export, acoustic analysis + Avocet — audio preprocessing for classifier training corpus + Linnet — chunk accumulation for real-time tone annotation +""" +from circuitforge_core.audio.convert import ( + bytes_to_float32, + float32_to_pcm, + pcm_to_float32, +) +from circuitforge_core.audio.gate import is_silent +from circuitforge_core.audio.resample import resample +from circuitforge_core.audio.buffer import ChunkAccumulator + +__all__ = [ + "bytes_to_float32", + "float32_to_pcm", + "pcm_to_float32", + "is_silent", + "resample", + "ChunkAccumulator", +] diff --git a/circuitforge_core/audio/buffer.py b/circuitforge_core/audio/buffer.py new file mode 100644 index 0000000..aaba3c2 --- /dev/null +++ b/circuitforge_core/audio/buffer.py @@ -0,0 +1,67 @@ +""" +ChunkAccumulator — collect fixed-size audio chunks into a classify window. + +Used by cf-voice and Linnet to gather N × 100ms frames before firing +a classification pass. The window size trades latency against context: +a 2-second window (20 × 100ms) gives the classifier enough signal to +detect tone/affect reliably without lagging the conversation. +""" +from __future__ import annotations + +from collections import deque + +import numpy as np + + +class ChunkAccumulator: + """Accumulate audio chunks and flush when the window is full. + + Args: + window_chunks: Number of chunks to collect before is_ready() is True. + dtype: numpy dtype of the accumulated array. Default float32. 
+ """ + + def __init__(self, window_chunks: int, *, dtype: np.dtype = np.float32) -> None: + if window_chunks < 1: + raise ValueError(f"window_chunks must be >= 1, got {window_chunks}") + self._window = window_chunks + self._dtype = dtype + self._buf: deque[np.ndarray] = deque() + + def accumulate(self, chunk: np.ndarray) -> None: + """Add a chunk to the buffer. Oldest chunks are dropped once the + buffer exceeds window_chunks to bound memory.""" + self._buf.append(chunk.astype(self._dtype)) + while len(self._buf) > self._window: + self._buf.popleft() + + def is_ready(self) -> bool: + """True when window_chunks have been accumulated.""" + return len(self._buf) >= self._window + + def flush(self) -> np.ndarray: + """Concatenate accumulated chunks and reset the buffer. + + Returns: + float32 ndarray of concatenated audio. + + Raises: + RuntimeError: if fewer than window_chunks have been accumulated. + """ + if not self.is_ready(): + raise RuntimeError( + f"Not enough chunks accumulated: have {len(self._buf)}, " + f"need {self._window}. Check is_ready() before calling flush()." + ) + result = np.concatenate(list(self._buf), axis=-1).astype(self._dtype) + self._buf.clear() + return result + + def reset(self) -> None: + """Discard all buffered audio without returning it.""" + self._buf.clear() + + @property + def chunk_count(self) -> int: + """Current number of buffered chunks.""" + return len(self._buf) diff --git a/circuitforge_core/audio/convert.py b/circuitforge_core/audio/convert.py new file mode 100644 index 0000000..ed0e23e --- /dev/null +++ b/circuitforge_core/audio/convert.py @@ -0,0 +1,50 @@ +""" +PCM / float32 conversion utilities. + +All functions operate on raw audio bytes or numpy arrays. No torch dependency. 
+ +Standard pipeline: + bytes (int16 PCM) -> float32 ndarray -> signal processing -> bytes (int16 PCM) +""" +from __future__ import annotations + +import numpy as np + + +def pcm_to_float32(pcm_bytes: bytes, *, dtype: np.dtype = np.int16) -> np.ndarray: + """Convert raw PCM bytes to a float32 numpy array in [-1.0, 1.0]. + + Args: + pcm_bytes: Raw PCM audio bytes. + dtype: Sample dtype of the input. Default: int16 (standard mic input). + + Returns: + float32 ndarray, values in [-1.0, 1.0]. + """ + scale = np.iinfo(dtype).max + return np.frombuffer(pcm_bytes, dtype=dtype).astype(np.float32) / scale + + +def bytes_to_float32(pcm_bytes: bytes) -> np.ndarray: + """Alias for pcm_to_float32 with default int16 dtype. + + Matches the naming used in cf-voice context.py for easier migration. + """ + return pcm_to_float32(pcm_bytes) + + +def float32_to_pcm(audio: np.ndarray, *, dtype: np.dtype = np.int16) -> bytes: + """Convert a float32 ndarray in [-1.0, 1.0] to raw PCM bytes. + + Clips to [-1.0, 1.0] before scaling to prevent wraparound distortion. + + Args: + audio: float32 ndarray, values nominally in [-1.0, 1.0]. + dtype: Target PCM sample dtype. Default: int16. + + Returns: + Raw PCM bytes. + """ + scale = np.iinfo(dtype).max + clipped = np.clip(audio, -1.0, 1.0) + return (clipped * scale).astype(dtype).tobytes() diff --git a/circuitforge_core/audio/gate.py b/circuitforge_core/audio/gate.py new file mode 100644 index 0000000..7843bd3 --- /dev/null +++ b/circuitforge_core/audio/gate.py @@ -0,0 +1,44 @@ +""" +Energy gate — silence detection via RMS amplitude. +""" +from __future__ import annotations + +import numpy as np + +# Default threshold extracted from cf-voice stt.py. +# Signals below this RMS level are considered silent. +_DEFAULT_RMS_THRESHOLD = 0.005 + + +def is_silent( + audio: np.ndarray, + *, + rms_threshold: float = _DEFAULT_RMS_THRESHOLD, +) -> bool: + """Return True when the audio clip is effectively silent. 
+ + Uses root-mean-square amplitude as the energy estimate. This is a fast + frame-level gate — not a VAD model. Use it to skip inference on empty + audio frames before they hit a more expensive transcription or + classification pipeline. + + Args: + audio: float32 ndarray, values in [-1.0, 1.0]. + rms_threshold: Clips with RMS below this value are silent. + Default 0.005 is conservative — genuine speech at + normal mic levels sits well above this. + + Returns: + True if silent, False if the clip contains meaningful signal. + """ + if audio.size == 0: + return True + rms = float(np.sqrt(np.mean(audio.astype(np.float32) ** 2))) + return rms < rms_threshold + + +def rms(audio: np.ndarray) -> float: + """Return the RMS amplitude of an audio array.""" + if audio.size == 0: + return 0.0 + return float(np.sqrt(np.mean(audio.astype(np.float32) ** 2))) diff --git a/circuitforge_core/audio/resample.py b/circuitforge_core/audio/resample.py new file mode 100644 index 0000000..2a36446 --- /dev/null +++ b/circuitforge_core/audio/resample.py @@ -0,0 +1,39 @@ +""" +Audio resampling — change sample rate of a float32 audio array. + +Uses scipy.signal.resample_poly when available (high-quality, anti-aliased). +Falls back to linear interpolation via numpy when scipy is absent — acceptable +for 16kHz speech but not for music. +""" +from __future__ import annotations + +import numpy as np + + +def resample(audio: np.ndarray, from_hz: int, to_hz: int) -> np.ndarray: + """Resample audio from one sample rate to another. + + Args: + audio: float32 ndarray, shape (samples,) or (channels, samples). + from_hz: Source sample rate in Hz. + to_hz: Target sample rate in Hz. + + Returns: + Resampled float32 ndarray at to_hz. 
+ """ + if from_hz == to_hz: + return audio.astype(np.float32) + + try: + from scipy.signal import resample_poly # type: ignore[import] + from math import gcd + g = gcd(from_hz, to_hz) + up, down = to_hz // g, from_hz // g + return resample_poly(audio.astype(np.float32), up, down, axis=-1) + except ImportError: + # Numpy linear interpolation fallback — lower quality but no extra deps. + # Adequate for 16kHz ↔ 8kHz conversion on speech; avoid for music. + n_out = int(len(audio) * to_hz / from_hz) + x_old = np.linspace(0, 1, len(audio), endpoint=False) + x_new = np.linspace(0, 1, n_out, endpoint=False) + return np.interp(x_new, x_old, audio.astype(np.float32)).astype(np.float32) diff --git a/circuitforge_core/db/base.py b/circuitforge_core/db/base.py index e3bbf2e..9de3212 100644 --- a/circuitforge_core/db/base.py +++ b/circuitforge_core/db/base.py @@ -23,7 +23,7 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection: if cloud_mode and key: from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore conn = _sqlcipher.connect(str(db_path), timeout=30) - conn.execute(f"PRAGMA key='{key}'") + conn.execute("PRAGMA key=?", (key,)) return conn # timeout=30: retry for up to 30s when another writer holds the lock (WAL mode # allows concurrent readers but only one writer at a time). 
diff --git a/tests/test_audio/__init__.py b/tests/test_audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_audio/test_buffer.py b/tests/test_audio/test_buffer.py new file mode 100644 index 0000000..2b99ab0 --- /dev/null +++ b/tests/test_audio/test_buffer.py @@ -0,0 +1,72 @@ +import numpy as np +import pytest +from circuitforge_core.audio.buffer import ChunkAccumulator + + +def _chunk(value: float = 0.0, size: int = 1600) -> np.ndarray: + return np.full(size, value, dtype=np.float32) + + +def test_not_ready_initially(): + acc = ChunkAccumulator(window_chunks=3) + assert acc.is_ready() is False + assert acc.chunk_count == 0 + + +def test_ready_after_window_filled(): + acc = ChunkAccumulator(window_chunks=3) + for _ in range(3): + acc.accumulate(_chunk()) + assert acc.is_ready() is True + assert acc.chunk_count == 3 + + +def test_flush_returns_concatenated(): + acc = ChunkAccumulator(window_chunks=2) + acc.accumulate(_chunk(0.1, size=100)) + acc.accumulate(_chunk(0.2, size=100)) + result = acc.flush() + assert result.shape == (200,) + assert np.allclose(result[:100], 0.1) + assert np.allclose(result[100:], 0.2) + + +def test_flush_clears_buffer(): + acc = ChunkAccumulator(window_chunks=2) + acc.accumulate(_chunk()) + acc.accumulate(_chunk()) + acc.flush() + assert acc.chunk_count == 0 + assert acc.is_ready() is False + + +def test_flush_raises_when_not_ready(): + acc = ChunkAccumulator(window_chunks=3) + acc.accumulate(_chunk()) + with pytest.raises(RuntimeError, match="Not enough chunks"): + acc.flush() + + +def test_reset_clears_buffer(): + acc = ChunkAccumulator(window_chunks=2) + acc.accumulate(_chunk()) + acc.accumulate(_chunk()) + acc.reset() + assert acc.chunk_count == 0 + + +def test_oldest_dropped_when_overfilled(): + # Accumulate more than window_chunks — oldest should be evicted + acc = ChunkAccumulator(window_chunks=2) + acc.accumulate(_chunk(1.0, size=10)) # will be evicted + acc.accumulate(_chunk(2.0, size=10)) + 
acc.accumulate(_chunk(3.0, size=10)) + assert acc.chunk_count == 2 + result = acc.flush() + assert np.allclose(result[:10], 2.0) + assert np.allclose(result[10:], 3.0) + + +def test_invalid_window_raises(): + with pytest.raises(ValueError): + ChunkAccumulator(window_chunks=0) diff --git a/tests/test_audio/test_convert.py b/tests/test_audio/test_convert.py new file mode 100644 index 0000000..6052d52 --- /dev/null +++ b/tests/test_audio/test_convert.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest +from circuitforge_core.audio.convert import pcm_to_float32, float32_to_pcm, bytes_to_float32 + + +def _silence_pcm(n_samples: int = 1024) -> bytes: + return (np.zeros(n_samples, dtype=np.int16)).tobytes() + + +def _sine_pcm(freq_hz: float = 440.0, sample_rate: int = 16000, duration_s: float = 0.1) -> bytes: + t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False) + samples = (np.sin(2 * np.pi * freq_hz * t) * 16000).astype(np.int16) + return samples.tobytes() + + +def test_pcm_to_float32_silence(): + result = pcm_to_float32(_silence_pcm()) + assert result.dtype == np.float32 + assert np.allclose(result, 0.0) + + +def test_pcm_to_float32_range(): + # Full-scale positive int16 -> ~1.0 + pcm = np.array([32767], dtype=np.int16).tobytes() + result = pcm_to_float32(pcm) + assert abs(result[0] - 1.0) < 0.001 + + # Full-scale negative int16 -> ~-1.0 + pcm = np.array([-32768], dtype=np.int16).tobytes() + result = pcm_to_float32(pcm) + assert abs(result[0] - (-32768 / 32767)) < 0.001 + + +def test_pcm_roundtrip(): + original = _sine_pcm() + float_audio = pcm_to_float32(original) + recovered = float32_to_pcm(float_audio) + # Roundtrip through float32 introduces tiny quantisation error — within 1 LSB + orig_arr = np.frombuffer(original, dtype=np.int16) + recv_arr = np.frombuffer(recovered, dtype=np.int16) + assert np.max(np.abs(orig_arr.astype(np.int32) - recv_arr.astype(np.int32))) <= 1 + + +def test_float32_to_pcm_clips(): + # Values outside [-1.0, 1.0] 
must be clipped, not wrap. + # int16 is asymmetric: max=32767, min=-32768. Scaling by 32767 means + # -1.0 → -32767, not -32768 — that's expected and correct. + audio = np.array([2.0, -2.0], dtype=np.float32) + result = float32_to_pcm(audio) + samples = np.frombuffer(result, dtype=np.int16) + assert samples[0] == 32767 + assert samples[1] == -32767 + + +def test_bytes_to_float32_alias(): + pcm = _sine_pcm() + assert np.allclose(bytes_to_float32(pcm), pcm_to_float32(pcm)) diff --git a/tests/test_audio/test_gate.py b/tests/test_audio/test_gate.py new file mode 100644 index 0000000..16c0ebb --- /dev/null +++ b/tests/test_audio/test_gate.py @@ -0,0 +1,38 @@ +import numpy as np +from circuitforge_core.audio.gate import is_silent, rms + + +def test_silence_detected(): + audio = np.zeros(1600, dtype=np.float32) + assert is_silent(audio) is True + + +def test_speech_level_not_silent(): + # Sine at 0.1 amplitude — well above 0.005 RMS + t = np.linspace(0, 0.1, 1600, endpoint=False) + audio = (np.sin(2 * np.pi * 440 * t) * 0.1).astype(np.float32) + assert is_silent(audio) is False + + +def test_just_below_threshold(): + # RMS exactly at 0.004 — should be silent + audio = np.full(1600, 0.004, dtype=np.float32) + assert is_silent(audio, rms_threshold=0.005) is True + + +def test_just_above_threshold(): + audio = np.full(1600, 0.006, dtype=np.float32) + assert is_silent(audio, rms_threshold=0.005) is False + + +def test_empty_array_is_silent(): + assert is_silent(np.array([], dtype=np.float32)) is True + + +def test_rms_zero_for_silence(): + assert rms(np.zeros(100, dtype=np.float32)) == 0.0 + + +def test_rms_nonzero_for_signal(): + audio = np.ones(100, dtype=np.float32) * 0.5 + assert abs(rms(audio) - 0.5) < 1e-6 diff --git a/tests/test_audio/test_resample.py b/tests/test_audio/test_resample.py new file mode 100644 index 0000000..bbab4bd --- /dev/null +++ b/tests/test_audio/test_resample.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +from 
circuitforge_core.audio.resample import resample
+
+
+def _sine(freq_hz: float, sample_rate: int, duration_s: float = 0.5) -> np.ndarray:
+    t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
+    return np.sin(2 * np.pi * freq_hz * t).astype(np.float32)
+
+
+def test_same_rate_is_noop():
+    audio = _sine(440.0, 16000)
+    result = resample(audio, 16000, 16000)
+    assert np.allclose(result, audio, atol=1e-5)
+
+
+def test_output_length_correct():
+    audio = _sine(440.0, 16000, duration_s=1.0)
+    result = resample(audio, 16000, 8000)
+    assert len(result) == 8000
+
+
+def test_upsample_output_length():
+    audio = _sine(440.0, 8000, duration_s=1.0)
+    result = resample(audio, 8000, 16000)
+    assert len(result) == 16000
+
+
+def test_output_dtype_float32():
+    audio = _sine(440.0, 16000)
+    result = resample(audio, 16000, 8000)
+    assert result.dtype == np.float32
+
+
+def test_energy_preserved_approximately():
+    # RMS should be approximately the same after resampling a sine wave
+    audio = _sine(440.0, 16000, duration_s=1.0)
+    result = resample(audio, 16000, 8000)
+    rms_in = float(np.sqrt(np.mean(audio ** 2)))
+    rms_out = float(np.sqrt(np.mean(result ** 2)))
+    assert abs(rms_in - rms_out) < 0.05  # 0.05 absolute RMS tolerance — resampling is not exactly power-preserving
diff --git a/tests/test_db.py b/tests/test_db.py
index 1b2a018..aa0362e 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -1,9 +1,15 @@
+import importlib.util
 import sqlite3
 import tempfile
 from pathlib import Path
 import pytest
 from circuitforge_core.db import get_connection, run_migrations
 
+sqlcipher_available = pytest.mark.skipif(
+    importlib.util.find_spec("pysqlcipher3") is None,
+    reason="pysqlcipher3 not installed",
+)
+
 
 def test_get_connection_returns_sqlite_connection(tmp_path):
     db = tmp_path / "test.db"
@@ -61,3 +67,38 @@ def test_run_migrations_applies_in_order(tmp_path):
     run_migrations(conn, migrations_dir)
     conn.execute("INSERT INTO foo (name) VALUES ('bar')")
     conn.close()
+
+
+# ── 
SQLCipher PRAGMA key tests (skipped when pysqlcipher3 not installed) ────── + + +@sqlcipher_available +def test_sqlcipher_key_with_special_chars_does_not_inject(tmp_path, monkeypatch): + """Key containing a single quote must not cause a SQL syntax error. + + Regression for: conn.execute(f"PRAGMA key='{key}'") — if key = "x'--" + the f-string form produced a broken PRAGMA statement. Parameterized + form (PRAGMA key=?) must handle this safely. + """ + monkeypatch.setenv("CLOUD_MODE", "1") + db = tmp_path / "enc.db" + tricky_key = "pass'word\"--inject" + # Must not raise; if the f-string form were used, this would produce + # a syntax error or silently set an incorrect key. + conn = get_connection(db, key=tricky_key) + conn.execute("CREATE TABLE t (x INTEGER)") + conn.close() + + +@sqlcipher_available +def test_sqlcipher_wrong_key_raises(tmp_path, monkeypatch): + """Opening an encrypted DB with the wrong key should raise, not silently corrupt.""" + monkeypatch.setenv("CLOUD_MODE", "1") + db = tmp_path / "enc.db" + conn = get_connection(db, key="correct-key") + conn.execute("CREATE TABLE t (x INTEGER)") + conn.close() + + with pytest.raises(Exception): + bad = get_connection(db, key="wrong-key") + bad.execute("SELECT * FROM t") # should raise on bad key diff --git a/tests/test_musicgen/__init__.py b/tests/test_musicgen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_musicgen/test_app.py b/tests/test_musicgen/test_app.py new file mode 100644 index 0000000..84dc2da --- /dev/null +++ b/tests/test_musicgen/test_app.py @@ -0,0 +1,124 @@ +""" +Tests for the cf-musicgen FastAPI app using mock backend. 
+""" +import io +import os +import wave + +import pytest +from fastapi.testclient import TestClient + +import circuitforge_core.musicgen.app as musicgen_app +from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend + + +@pytest.fixture(autouse=True) +def inject_mock_backend(): + """Inject mock backend before each test; restore None after.""" + original = musicgen_app._backend + musicgen_app._backend = MockMusicGenBackend() + yield + musicgen_app._backend = original + + +@pytest.fixture() +def client(): + return TestClient(musicgen_app.app) + + +def _silent_wav(duration_s: float = 1.0) -> bytes: + n = int(duration_s * 16000) + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(b"\x00\x00" * n) + return buf.getvalue() + + +# ── /health ─────────────────────────────────────────────────────────────────── + + +def test_health_returns_ok(client): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["model"] == "mock" + assert data["vram_mb"] == 0 + + +def test_health_503_when_no_backend(client): + musicgen_app._backend = None + resp = client.get("/health") + assert resp.status_code == 503 + + +# ── /continue ───────────────────────────────────────────────────────────────── + + +def test_continue_returns_audio(client): + resp = client.post( + "/continue", + data={"duration_s": "5.0"}, + files={"audio": ("test.wav", _silent_wav(), "audio/wav")}, + ) + assert resp.status_code == 200 + assert resp.headers["content-type"] == "audio/wav" + + +def test_continue_duration_header(client): + resp = client.post( + "/continue", + data={"duration_s": "7.0"}, + files={"audio": ("test.wav", _silent_wav(), "audio/wav")}, + ) + assert resp.status_code == 200 + assert float(resp.headers["x-duration-s"]) == pytest.approx(7.0) + + +def test_continue_model_header(client): + resp = client.post( + "/continue", + 
data={"duration_s": "5.0"}, + files={"audio": ("test.wav", _silent_wav(), "audio/wav")}, + ) + assert resp.headers["x-model"] == "mock" + + +def test_continue_rejects_zero_duration(client): + resp = client.post( + "/continue", + data={"duration_s": "0"}, + files={"audio": ("test.wav", _silent_wav(), "audio/wav")}, + ) + assert resp.status_code == 422 + + +def test_continue_rejects_too_long_duration(client): + resp = client.post( + "/continue", + data={"duration_s": "61"}, + files={"audio": ("test.wav", _silent_wav(), "audio/wav")}, + ) + assert resp.status_code == 422 + + +def test_continue_rejects_empty_audio(client): + resp = client.post( + "/continue", + data={"duration_s": "5.0"}, + files={"audio": ("empty.wav", b"", "audio/wav")}, + ) + assert resp.status_code == 400 + + +def test_continue_503_when_no_backend(client): + musicgen_app._backend = None + resp = client.post( + "/continue", + data={"duration_s": "5.0"}, + files={"audio": ("test.wav", _silent_wav(), "audio/wav")}, + ) + assert resp.status_code == 503 diff --git a/tests/test_musicgen/test_mock_backend.py b/tests/test_musicgen/test_mock_backend.py new file mode 100644 index 0000000..c4fad49 --- /dev/null +++ b/tests/test_musicgen/test_mock_backend.py @@ -0,0 +1,97 @@ +""" +Tests for the mock MusicGen backend and shared audio encode/decode utilities. + +All tests run without a GPU or AudioCraft install. 
+""" +import io +import wave + +import pytest +from circuitforge_core.musicgen.backends.base import ( + MODEL_MELODY, + MODEL_SMALL, + MusicGenBackend, + MusicContinueResult, + make_musicgen_backend, +) +from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend + + +# ── Mock backend ────────────────────────────────────────────────────────────── + + +def test_mock_satisfies_protocol(): + backend = MockMusicGenBackend() + assert isinstance(backend, MusicGenBackend) + + +def test_mock_model_name(): + assert MockMusicGenBackend().model_name == "mock" + + +def test_mock_vram_mb(): + assert MockMusicGenBackend().vram_mb == 0 + + +def _silent_wav(duration_s: float = 1.0, sample_rate: int = 16000) -> bytes: + n = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(b"\x00\x00" * n) + return buf.getvalue() + + +def test_mock_returns_result(): + backend = MockMusicGenBackend() + result = backend.continue_audio(_silent_wav(), duration_s=5.0) + assert isinstance(result, MusicContinueResult) + + +def test_mock_duration_matches_request(): + backend = MockMusicGenBackend() + result = backend.continue_audio(_silent_wav(), duration_s=7.5) + assert result.duration_s == 7.5 + + +def test_mock_returns_valid_wav(): + backend = MockMusicGenBackend() + result = backend.continue_audio(_silent_wav(), duration_s=2.0) + assert result.format == "wav" + buf = io.BytesIO(result.audio_bytes) + with wave.open(buf, "rb") as wf: + assert wf.getnframes() > 0 + + +def test_mock_sample_rate(): + backend = MockMusicGenBackend() + result = backend.continue_audio(_silent_wav()) + assert result.sample_rate == 32000 + + +def test_mock_prompt_duration_passthrough(): + backend = MockMusicGenBackend() + result = backend.continue_audio(_silent_wav(), prompt_duration_s=8.0) + assert result.prompt_duration_s == 8.0 + + +def test_mock_description_ignored(): + backend = 
MockMusicGenBackend() + # Should not raise regardless of description + result = backend.continue_audio(_silent_wav(), description="upbeat jazz") + assert result is not None + + +# ── make_musicgen_backend factory ───────────────────────────────────────────── + + +def test_factory_returns_mock_when_flag_set(): + backend = make_musicgen_backend(mock=True) + assert isinstance(backend, MockMusicGenBackend) + + +def test_factory_mock_for_mock_model_name(): + backend = make_musicgen_backend(model_name="mock", mock=True) + assert isinstance(backend, MockMusicGenBackend)