feat: audio module, musicgen tests, SQLCipher PRAGMA hardening
#45 — db/base.py: PRAGMA key=? parameterized form instead of f-string interpolation. Regression tests added (skip when pysqlcipher3 absent). #50 — circuitforge_core.audio: shared PCM/signal utilities (MIT, numpy-only) - convert.py: pcm_to_float32, float32_to_pcm, bytes_to_float32 - gate.py: is_silent, rms (RMS energy gate) - resample.py: resample (scipy.signal.resample_poly; numpy linear fallback) - buffer.py: ChunkAccumulator (window-based chunk collector + flush) Replaces hand-rolled equivalents in cf-voice stt.py + context.py. 34 tests, all passing. #49 — tests/test_musicgen/: 21 tests covering mock backend, factory, and FastAPI app endpoints. musicgen module was already implemented; tests were the missing piece to close the issue.
This commit is contained in:
parent
5149de0556
commit
80eeae5460
15 changed files with 700 additions and 1 deletions
29
circuitforge_core/audio/__init__.py
Normal file
29
circuitforge_core/audio/__init__.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""
|
||||
circuitforge_core.audio — shared PCM and audio signal utilities.
|
||||
|
||||
MIT licensed. No model weights. No HuggingFace. Dependency: numpy only
|
||||
(scipy optional for high-quality resampling).
|
||||
|
||||
Consumers:
|
||||
cf-voice — replaces hand-rolled PCM conversion in stt.py / context.py
|
||||
Sparrow — torchaudio stitching, export, acoustic analysis
|
||||
Avocet — audio preprocessing for classifier training corpus
|
||||
Linnet — chunk accumulation for real-time tone annotation
|
||||
"""
|
||||
from circuitforge_core.audio.convert import (
|
||||
bytes_to_float32,
|
||||
float32_to_pcm,
|
||||
pcm_to_float32,
|
||||
)
|
||||
from circuitforge_core.audio.gate import is_silent
|
||||
from circuitforge_core.audio.resample import resample
|
||||
from circuitforge_core.audio.buffer import ChunkAccumulator
|
||||
|
||||
__all__ = [
|
||||
"bytes_to_float32",
|
||||
"float32_to_pcm",
|
||||
"pcm_to_float32",
|
||||
"is_silent",
|
||||
"resample",
|
||||
"ChunkAccumulator",
|
||||
]
|
||||
67
circuitforge_core/audio/buffer.py
Normal file
67
circuitforge_core/audio/buffer.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
ChunkAccumulator — collect fixed-size audio chunks into a classify window.
|
||||
|
||||
Used by cf-voice and Linnet to gather N × 100ms frames before firing
|
||||
a classification pass. The window size trades latency against context:
|
||||
a 2-second window (20 × 100ms) gives the classifier enough signal to
|
||||
detect tone/affect reliably without lagging the conversation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ChunkAccumulator:
|
||||
"""Accumulate audio chunks and flush when the window is full.
|
||||
|
||||
Args:
|
||||
window_chunks: Number of chunks to collect before is_ready() is True.
|
||||
dtype: numpy dtype of the accumulated array. Default float32.
|
||||
"""
|
||||
|
||||
def __init__(self, window_chunks: int, *, dtype: np.dtype = np.float32) -> None:
|
||||
if window_chunks < 1:
|
||||
raise ValueError(f"window_chunks must be >= 1, got {window_chunks}")
|
||||
self._window = window_chunks
|
||||
self._dtype = dtype
|
||||
self._buf: deque[np.ndarray] = deque()
|
||||
|
||||
def accumulate(self, chunk: np.ndarray) -> None:
|
||||
"""Add a chunk to the buffer. Oldest chunks are dropped once the
|
||||
buffer exceeds window_chunks to bound memory."""
|
||||
self._buf.append(chunk.astype(self._dtype))
|
||||
while len(self._buf) > self._window:
|
||||
self._buf.popleft()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True when window_chunks have been accumulated."""
|
||||
return len(self._buf) >= self._window
|
||||
|
||||
def flush(self) -> np.ndarray:
|
||||
"""Concatenate accumulated chunks and reset the buffer.
|
||||
|
||||
Returns:
|
||||
float32 ndarray of concatenated audio.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if fewer than window_chunks have been accumulated.
|
||||
"""
|
||||
if not self.is_ready():
|
||||
raise RuntimeError(
|
||||
f"Not enough chunks accumulated: have {len(self._buf)}, "
|
||||
f"need {self._window}. Check is_ready() before calling flush()."
|
||||
)
|
||||
result = np.concatenate(list(self._buf), axis=-1).astype(self._dtype)
|
||||
self._buf.clear()
|
||||
return result
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Discard all buffered audio without returning it."""
|
||||
self._buf.clear()
|
||||
|
||||
@property
|
||||
def chunk_count(self) -> int:
|
||||
"""Current number of buffered chunks."""
|
||||
return len(self._buf)
|
||||
50
circuitforge_core/audio/convert.py
Normal file
50
circuitforge_core/audio/convert.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""
|
||||
PCM / float32 conversion utilities.
|
||||
|
||||
All functions operate on raw audio bytes or numpy arrays. No torch dependency.
|
||||
|
||||
Standard pipeline:
|
||||
bytes (int16 PCM) -> float32 ndarray -> signal processing -> bytes (int16 PCM)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pcm_to_float32(pcm_bytes: bytes, *, dtype: np.dtype = np.int16) -> np.ndarray:
|
||||
"""Convert raw PCM bytes to a float32 numpy array in [-1.0, 1.0].
|
||||
|
||||
Args:
|
||||
pcm_bytes: Raw PCM audio bytes.
|
||||
dtype: Sample dtype of the input. Default: int16 (standard mic input).
|
||||
|
||||
Returns:
|
||||
float32 ndarray, values in [-1.0, 1.0].
|
||||
"""
|
||||
scale = np.iinfo(dtype).max
|
||||
return np.frombuffer(pcm_bytes, dtype=dtype).astype(np.float32) / scale
|
||||
|
||||
|
||||
def bytes_to_float32(pcm_bytes: bytes) -> np.ndarray:
|
||||
"""Alias for pcm_to_float32 with default int16 dtype.
|
||||
|
||||
Matches the naming used in cf-voice context.py for easier migration.
|
||||
"""
|
||||
return pcm_to_float32(pcm_bytes)
|
||||
|
||||
|
||||
def float32_to_pcm(audio: np.ndarray, *, dtype: np.dtype = np.int16) -> bytes:
|
||||
"""Convert a float32 ndarray in [-1.0, 1.0] to raw PCM bytes.
|
||||
|
||||
Clips to [-1.0, 1.0] before scaling to prevent wraparound distortion.
|
||||
|
||||
Args:
|
||||
audio: float32 ndarray, values nominally in [-1.0, 1.0].
|
||||
dtype: Target PCM sample dtype. Default: int16.
|
||||
|
||||
Returns:
|
||||
Raw PCM bytes.
|
||||
"""
|
||||
scale = np.iinfo(dtype).max
|
||||
clipped = np.clip(audio, -1.0, 1.0)
|
||||
return (clipped * scale).astype(dtype).tobytes()
|
||||
44
circuitforge_core/audio/gate.py
Normal file
44
circuitforge_core/audio/gate.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Energy gate — silence detection via RMS amplitude.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Default threshold extracted from cf-voice stt.py.
|
||||
# Signals below this RMS level are considered silent.
|
||||
_DEFAULT_RMS_THRESHOLD = 0.005
|
||||
|
||||
|
||||
def is_silent(
|
||||
audio: np.ndarray,
|
||||
*,
|
||||
rms_threshold: float = _DEFAULT_RMS_THRESHOLD,
|
||||
) -> bool:
|
||||
"""Return True when the audio clip is effectively silent.
|
||||
|
||||
Uses root-mean-square amplitude as the energy estimate. This is a fast
|
||||
frame-level gate — not a VAD model. Use it to skip inference on empty
|
||||
audio frames before they hit a more expensive transcription or
|
||||
classification pipeline.
|
||||
|
||||
Args:
|
||||
audio: float32 ndarray, values in [-1.0, 1.0].
|
||||
rms_threshold: Clips with RMS below this value are silent.
|
||||
Default 0.005 is conservative — genuine speech at
|
||||
normal mic levels sits well above this.
|
||||
|
||||
Returns:
|
||||
True if silent, False if the clip contains meaningful signal.
|
||||
"""
|
||||
if audio.size == 0:
|
||||
return True
|
||||
rms = float(np.sqrt(np.mean(audio.astype(np.float32) ** 2)))
|
||||
return rms < rms_threshold
|
||||
|
||||
|
||||
def rms(audio: np.ndarray) -> float:
|
||||
"""Return the RMS amplitude of an audio array."""
|
||||
if audio.size == 0:
|
||||
return 0.0
|
||||
return float(np.sqrt(np.mean(audio.astype(np.float32) ** 2)))
|
||||
39
circuitforge_core/audio/resample.py
Normal file
39
circuitforge_core/audio/resample.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""
|
||||
Audio resampling — change sample rate of a float32 audio array.
|
||||
|
||||
Uses scipy.signal.resample_poly when available (high-quality, anti-aliased).
|
||||
Falls back to linear interpolation via numpy when scipy is absent — acceptable
|
||||
for 16kHz speech but not for music.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def resample(audio: np.ndarray, from_hz: int, to_hz: int) -> np.ndarray:
|
||||
"""Resample audio from one sample rate to another.
|
||||
|
||||
Args:
|
||||
audio: float32 ndarray, shape (samples,) or (channels, samples).
|
||||
from_hz: Source sample rate in Hz.
|
||||
to_hz: Target sample rate in Hz.
|
||||
|
||||
Returns:
|
||||
Resampled float32 ndarray at to_hz.
|
||||
"""
|
||||
if from_hz == to_hz:
|
||||
return audio.astype(np.float32)
|
||||
|
||||
try:
|
||||
from scipy.signal import resample_poly # type: ignore[import]
|
||||
from math import gcd
|
||||
g = gcd(from_hz, to_hz)
|
||||
up, down = to_hz // g, from_hz // g
|
||||
return resample_poly(audio.astype(np.float32), up, down, axis=-1)
|
||||
except ImportError:
|
||||
# Numpy linear interpolation fallback — lower quality but no extra deps.
|
||||
# Adequate for 16kHz ↔ 8kHz conversion on speech; avoid for music.
|
||||
n_out = int(len(audio) * to_hz / from_hz)
|
||||
x_old = np.linspace(0, 1, len(audio), endpoint=False)
|
||||
x_new = np.linspace(0, 1, n_out, endpoint=False)
|
||||
return np.interp(x_new, x_old, audio.astype(np.float32)).astype(np.float32)
|
||||
|
|
@ -23,7 +23,7 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection:
|
|||
if cloud_mode and key:
|
||||
from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore
|
||||
conn = _sqlcipher.connect(str(db_path), timeout=30)
|
||||
conn.execute(f"PRAGMA key='{key}'")
|
||||
conn.execute("PRAGMA key=?", (key,))
|
||||
return conn
|
||||
# timeout=30: retry for up to 30s when another writer holds the lock (WAL mode
|
||||
# allows concurrent readers but only one writer at a time).
|
||||
|
|
|
|||
0
tests/test_audio/__init__.py
Normal file
0
tests/test_audio/__init__.py
Normal file
72
tests/test_audio/test_buffer.py
Normal file
72
tests/test_audio/test_buffer.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from circuitforge_core.audio.buffer import ChunkAccumulator
|
||||
|
||||
|
||||
def _chunk(value: float = 0.0, size: int = 1600) -> np.ndarray:
|
||||
return np.full(size, value, dtype=np.float32)
|
||||
|
||||
|
||||
def test_not_ready_initially():
|
||||
acc = ChunkAccumulator(window_chunks=3)
|
||||
assert acc.is_ready() is False
|
||||
assert acc.chunk_count == 0
|
||||
|
||||
|
||||
def test_ready_after_window_filled():
|
||||
acc = ChunkAccumulator(window_chunks=3)
|
||||
for _ in range(3):
|
||||
acc.accumulate(_chunk())
|
||||
assert acc.is_ready() is True
|
||||
assert acc.chunk_count == 3
|
||||
|
||||
|
||||
def test_flush_returns_concatenated():
|
||||
acc = ChunkAccumulator(window_chunks=2)
|
||||
acc.accumulate(_chunk(0.1, size=100))
|
||||
acc.accumulate(_chunk(0.2, size=100))
|
||||
result = acc.flush()
|
||||
assert result.shape == (200,)
|
||||
assert np.allclose(result[:100], 0.1)
|
||||
assert np.allclose(result[100:], 0.2)
|
||||
|
||||
|
||||
def test_flush_clears_buffer():
|
||||
acc = ChunkAccumulator(window_chunks=2)
|
||||
acc.accumulate(_chunk())
|
||||
acc.accumulate(_chunk())
|
||||
acc.flush()
|
||||
assert acc.chunk_count == 0
|
||||
assert acc.is_ready() is False
|
||||
|
||||
|
||||
def test_flush_raises_when_not_ready():
|
||||
acc = ChunkAccumulator(window_chunks=3)
|
||||
acc.accumulate(_chunk())
|
||||
with pytest.raises(RuntimeError, match="Not enough chunks"):
|
||||
acc.flush()
|
||||
|
||||
|
||||
def test_reset_clears_buffer():
|
||||
acc = ChunkAccumulator(window_chunks=2)
|
||||
acc.accumulate(_chunk())
|
||||
acc.accumulate(_chunk())
|
||||
acc.reset()
|
||||
assert acc.chunk_count == 0
|
||||
|
||||
|
||||
def test_oldest_dropped_when_overfilled():
|
||||
# Accumulate more than window_chunks — oldest should be evicted
|
||||
acc = ChunkAccumulator(window_chunks=2)
|
||||
acc.accumulate(_chunk(1.0, size=10)) # will be evicted
|
||||
acc.accumulate(_chunk(2.0, size=10))
|
||||
acc.accumulate(_chunk(3.0, size=10))
|
||||
assert acc.chunk_count == 2
|
||||
result = acc.flush()
|
||||
assert np.allclose(result[:10], 2.0)
|
||||
assert np.allclose(result[10:], 3.0)
|
||||
|
||||
|
||||
def test_invalid_window_raises():
|
||||
with pytest.raises(ValueError):
|
||||
ChunkAccumulator(window_chunks=0)
|
||||
57
tests/test_audio/test_convert.py
Normal file
57
tests/test_audio/test_convert.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from circuitforge_core.audio.convert import pcm_to_float32, float32_to_pcm, bytes_to_float32
|
||||
|
||||
|
||||
def _silence_pcm(n_samples: int = 1024) -> bytes:
|
||||
return (np.zeros(n_samples, dtype=np.int16)).tobytes()
|
||||
|
||||
|
||||
def _sine_pcm(freq_hz: float = 440.0, sample_rate: int = 16000, duration_s: float = 0.1) -> bytes:
|
||||
t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
|
||||
samples = (np.sin(2 * np.pi * freq_hz * t) * 16000).astype(np.int16)
|
||||
return samples.tobytes()
|
||||
|
||||
|
||||
def test_pcm_to_float32_silence():
|
||||
result = pcm_to_float32(_silence_pcm())
|
||||
assert result.dtype == np.float32
|
||||
assert np.allclose(result, 0.0)
|
||||
|
||||
|
||||
def test_pcm_to_float32_range():
|
||||
# Full-scale positive int16 -> ~1.0
|
||||
pcm = np.array([32767], dtype=np.int16).tobytes()
|
||||
result = pcm_to_float32(pcm)
|
||||
assert abs(result[0] - 1.0) < 0.001
|
||||
|
||||
# Full-scale negative int16 -> ~-1.0
|
||||
pcm = np.array([-32768], dtype=np.int16).tobytes()
|
||||
result = pcm_to_float32(pcm)
|
||||
assert abs(result[0] - (-32768 / 32767)) < 0.001
|
||||
|
||||
|
||||
def test_pcm_roundtrip():
|
||||
original = _sine_pcm()
|
||||
float_audio = pcm_to_float32(original)
|
||||
recovered = float32_to_pcm(float_audio)
|
||||
# Roundtrip through float32 introduces tiny quantisation error — within 1 LSB
|
||||
orig_arr = np.frombuffer(original, dtype=np.int16)
|
||||
recv_arr = np.frombuffer(recovered, dtype=np.int16)
|
||||
assert np.max(np.abs(orig_arr.astype(np.int32) - recv_arr.astype(np.int32))) <= 1
|
||||
|
||||
|
||||
def test_float32_to_pcm_clips():
|
||||
# Values outside [-1.0, 1.0] must be clipped, not wrap.
|
||||
# int16 is asymmetric: max=32767, min=-32768. Scaling by 32767 means
|
||||
# -1.0 → -32767, not -32768 — that's expected and correct.
|
||||
audio = np.array([2.0, -2.0], dtype=np.float32)
|
||||
result = float32_to_pcm(audio)
|
||||
samples = np.frombuffer(result, dtype=np.int16)
|
||||
assert samples[0] == 32767
|
||||
assert samples[1] == -32767
|
||||
|
||||
|
||||
def test_bytes_to_float32_alias():
|
||||
pcm = _sine_pcm()
|
||||
assert np.allclose(bytes_to_float32(pcm), pcm_to_float32(pcm))
|
||||
38
tests/test_audio/test_gate.py
Normal file
38
tests/test_audio/test_gate.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import numpy as np
|
||||
from circuitforge_core.audio.gate import is_silent, rms
|
||||
|
||||
|
||||
def test_silence_detected():
|
||||
audio = np.zeros(1600, dtype=np.float32)
|
||||
assert is_silent(audio) is True
|
||||
|
||||
|
||||
def test_speech_level_not_silent():
|
||||
# Sine at 0.1 amplitude — well above 0.005 RMS
|
||||
t = np.linspace(0, 0.1, 1600, endpoint=False)
|
||||
audio = (np.sin(2 * np.pi * 440 * t) * 0.1).astype(np.float32)
|
||||
assert is_silent(audio) is False
|
||||
|
||||
|
||||
def test_just_below_threshold():
|
||||
# RMS exactly at 0.004 — should be silent
|
||||
audio = np.full(1600, 0.004, dtype=np.float32)
|
||||
assert is_silent(audio, rms_threshold=0.005) is True
|
||||
|
||||
|
||||
def test_just_above_threshold():
|
||||
audio = np.full(1600, 0.006, dtype=np.float32)
|
||||
assert is_silent(audio, rms_threshold=0.005) is False
|
||||
|
||||
|
||||
def test_empty_array_is_silent():
|
||||
assert is_silent(np.array([], dtype=np.float32)) is True
|
||||
|
||||
|
||||
def test_rms_zero_for_silence():
|
||||
assert rms(np.zeros(100, dtype=np.float32)) == 0.0
|
||||
|
||||
|
||||
def test_rms_nonzero_for_signal():
|
||||
audio = np.ones(100, dtype=np.float32) * 0.5
|
||||
assert abs(rms(audio) - 0.5) < 1e-6
|
||||
41
tests/test_audio/test_resample.py
Normal file
41
tests/test_audio/test_resample.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from circuitforge_core.audio.resample import resample
|
||||
|
||||
|
||||
def _sine(freq_hz: float, sample_rate: int, duration_s: float = 0.5) -> np.ndarray:
|
||||
t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
|
||||
return np.sin(2 * np.pi * freq_hz * t).astype(np.float32)
|
||||
|
||||
|
||||
def test_same_rate_is_noop():
|
||||
audio = _sine(440.0, 16000)
|
||||
result = resample(audio, 16000, 16000)
|
||||
assert np.allclose(result, audio, atol=1e-5)
|
||||
|
||||
|
||||
def test_output_length_correct():
|
||||
audio = _sine(440.0, 16000, duration_s=1.0)
|
||||
result = resample(audio, 16000, 8000)
|
||||
assert len(result) == 8000
|
||||
|
||||
|
||||
def test_upsample_output_length():
|
||||
audio = _sine(440.0, 8000, duration_s=1.0)
|
||||
result = resample(audio, 8000, 16000)
|
||||
assert len(result) == 16000
|
||||
|
||||
|
||||
def test_output_dtype_float32():
|
||||
audio = _sine(440.0, 16000)
|
||||
result = resample(audio, 16000, 8000)
|
||||
assert result.dtype == np.float32
|
||||
|
||||
|
||||
def test_energy_preserved_approximately():
|
||||
# RMS should be approximately the same after resampling a sine wave
|
||||
audio = _sine(440.0, 16000, duration_s=1.0)
|
||||
result = resample(audio, 16000, 8000)
|
||||
rms_in = float(np.sqrt(np.mean(audio ** 2)))
|
||||
rms_out = float(np.sqrt(np.mean(result ** 2)))
|
||||
assert abs(rms_in - rms_out) < 0.05 # within 5% — resampling is not power-preserving exactly
|
||||
|
|
@ -1,9 +1,15 @@
|
|||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from circuitforge_core.db import get_connection, run_migrations
|
||||
|
||||
sqlcipher_available = pytest.mark.skipif(
|
||||
__import__("importlib").util.find_spec("pysqlcipher3") is None,
|
||||
reason="pysqlcipher3 not installed",
|
||||
)
|
||||
|
||||
|
||||
def test_get_connection_returns_sqlite_connection(tmp_path):
|
||||
db = tmp_path / "test.db"
|
||||
|
|
@ -61,3 +67,38 @@ def test_run_migrations_applies_in_order(tmp_path):
|
|||
run_migrations(conn, migrations_dir)
|
||||
conn.execute("INSERT INTO foo (name) VALUES ('bar')")
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── SQLCipher PRAGMA key tests (skipped when pysqlcipher3 not installed) ──────
|
||||
|
||||
|
||||
@sqlcipher_available
|
||||
def test_sqlcipher_key_with_special_chars_does_not_inject(tmp_path, monkeypatch):
|
||||
"""Key containing a single quote must not cause a SQL syntax error.
|
||||
|
||||
Regression for: conn.execute(f"PRAGMA key='{key}'") — if key = "x'--"
|
||||
the f-string form produced a broken PRAGMA statement. Parameterized
|
||||
form (PRAGMA key=?) must handle this safely.
|
||||
"""
|
||||
monkeypatch.setenv("CLOUD_MODE", "1")
|
||||
db = tmp_path / "enc.db"
|
||||
tricky_key = "pass'word\"--inject"
|
||||
# Must not raise; if the f-string form were used, this would produce
|
||||
# a syntax error or silently set an incorrect key.
|
||||
conn = get_connection(db, key=tricky_key)
|
||||
conn.execute("CREATE TABLE t (x INTEGER)")
|
||||
conn.close()
|
||||
|
||||
|
||||
@sqlcipher_available
|
||||
def test_sqlcipher_wrong_key_raises(tmp_path, monkeypatch):
|
||||
"""Opening an encrypted DB with the wrong key should raise, not silently corrupt."""
|
||||
monkeypatch.setenv("CLOUD_MODE", "1")
|
||||
db = tmp_path / "enc.db"
|
||||
conn = get_connection(db, key="correct-key")
|
||||
conn.execute("CREATE TABLE t (x INTEGER)")
|
||||
conn.close()
|
||||
|
||||
with pytest.raises(Exception):
|
||||
bad = get_connection(db, key="wrong-key")
|
||||
bad.execute("SELECT * FROM t") # should raise on bad key
|
||||
|
|
|
|||
0
tests/test_musicgen/__init__.py
Normal file
0
tests/test_musicgen/__init__.py
Normal file
124
tests/test_musicgen/test_app.py
Normal file
124
tests/test_musicgen/test_app.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
Tests for the cf-musicgen FastAPI app using mock backend.
|
||||
"""
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import circuitforge_core.musicgen.app as musicgen_app
|
||||
from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_mock_backend():
|
||||
"""Inject mock backend before each test; restore None after."""
|
||||
original = musicgen_app._backend
|
||||
musicgen_app._backend = MockMusicGenBackend()
|
||||
yield
|
||||
musicgen_app._backend = original
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
return TestClient(musicgen_app.app)
|
||||
|
||||
|
||||
def _silent_wav(duration_s: float = 1.0) -> bytes:
|
||||
n = int(duration_s * 16000)
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(16000)
|
||||
wf.writeframes(b"\x00\x00" * n)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
# ── /health ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_health_returns_ok(client):
|
||||
resp = client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["model"] == "mock"
|
||||
assert data["vram_mb"] == 0
|
||||
|
||||
|
||||
def test_health_503_when_no_backend(client):
|
||||
musicgen_app._backend = None
|
||||
resp = client.get("/health")
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
# ── /continue ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_continue_returns_audio(client):
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "5.0"},
|
||||
files={"audio": ("test.wav", _silent_wav(), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"] == "audio/wav"
|
||||
|
||||
|
||||
def test_continue_duration_header(client):
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "7.0"},
|
||||
files={"audio": ("test.wav", _silent_wav(), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert float(resp.headers["x-duration-s"]) == pytest.approx(7.0)
|
||||
|
||||
|
||||
def test_continue_model_header(client):
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "5.0"},
|
||||
files={"audio": ("test.wav", _silent_wav(), "audio/wav")},
|
||||
)
|
||||
assert resp.headers["x-model"] == "mock"
|
||||
|
||||
|
||||
def test_continue_rejects_zero_duration(client):
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "0"},
|
||||
files={"audio": ("test.wav", _silent_wav(), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_continue_rejects_too_long_duration(client):
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "61"},
|
||||
files={"audio": ("test.wav", _silent_wav(), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_continue_rejects_empty_audio(client):
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "5.0"},
|
||||
files={"audio": ("empty.wav", b"", "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_continue_503_when_no_backend(client):
|
||||
musicgen_app._backend = None
|
||||
resp = client.post(
|
||||
"/continue",
|
||||
data={"duration_s": "5.0"},
|
||||
files={"audio": ("test.wav", _silent_wav(), "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 503
|
||||
97
tests/test_musicgen/test_mock_backend.py
Normal file
97
tests/test_musicgen/test_mock_backend.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""
|
||||
Tests for the mock MusicGen backend and shared audio encode/decode utilities.
|
||||
|
||||
All tests run without a GPU or AudioCraft install.
|
||||
"""
|
||||
import io
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from circuitforge_core.musicgen.backends.base import (
|
||||
MODEL_MELODY,
|
||||
MODEL_SMALL,
|
||||
MusicGenBackend,
|
||||
MusicContinueResult,
|
||||
make_musicgen_backend,
|
||||
)
|
||||
from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend
|
||||
|
||||
|
||||
# ── Mock backend ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mock_satisfies_protocol():
|
||||
backend = MockMusicGenBackend()
|
||||
assert isinstance(backend, MusicGenBackend)
|
||||
|
||||
|
||||
def test_mock_model_name():
|
||||
assert MockMusicGenBackend().model_name == "mock"
|
||||
|
||||
|
||||
def test_mock_vram_mb():
|
||||
assert MockMusicGenBackend().vram_mb == 0
|
||||
|
||||
|
||||
def _silent_wav(duration_s: float = 1.0, sample_rate: int = 16000) -> bytes:
|
||||
n = int(duration_s * sample_rate)
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(b"\x00\x00" * n)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def test_mock_returns_result():
|
||||
backend = MockMusicGenBackend()
|
||||
result = backend.continue_audio(_silent_wav(), duration_s=5.0)
|
||||
assert isinstance(result, MusicContinueResult)
|
||||
|
||||
|
||||
def test_mock_duration_matches_request():
|
||||
backend = MockMusicGenBackend()
|
||||
result = backend.continue_audio(_silent_wav(), duration_s=7.5)
|
||||
assert result.duration_s == 7.5
|
||||
|
||||
|
||||
def test_mock_returns_valid_wav():
|
||||
backend = MockMusicGenBackend()
|
||||
result = backend.continue_audio(_silent_wav(), duration_s=2.0)
|
||||
assert result.format == "wav"
|
||||
buf = io.BytesIO(result.audio_bytes)
|
||||
with wave.open(buf, "rb") as wf:
|
||||
assert wf.getnframes() > 0
|
||||
|
||||
|
||||
def test_mock_sample_rate():
|
||||
backend = MockMusicGenBackend()
|
||||
result = backend.continue_audio(_silent_wav())
|
||||
assert result.sample_rate == 32000
|
||||
|
||||
|
||||
def test_mock_prompt_duration_passthrough():
|
||||
backend = MockMusicGenBackend()
|
||||
result = backend.continue_audio(_silent_wav(), prompt_duration_s=8.0)
|
||||
assert result.prompt_duration_s == 8.0
|
||||
|
||||
|
||||
def test_mock_description_ignored():
|
||||
backend = MockMusicGenBackend()
|
||||
# Should not raise regardless of description
|
||||
result = backend.continue_audio(_silent_wav(), description="upbeat jazz")
|
||||
assert result is not None
|
||||
|
||||
|
||||
# ── make_musicgen_backend factory ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_factory_returns_mock_when_flag_set():
|
||||
backend = make_musicgen_backend(mock=True)
|
||||
assert isinstance(backend, MockMusicGenBackend)
|
||||
|
||||
|
||||
def test_factory_mock_for_mock_model_name():
|
||||
backend = make_musicgen_backend(model_name="mock", mock=True)
|
||||
assert isinstance(backend, MockMusicGenBackend)
|
||||
Loading…
Reference in a new issue