diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..2e58c43 --- /dev/null +++ b/.env.example @@ -0,0 +1,31 @@ +# cf-voice environment — copy to .env and fill in values +# cf-voice itself does not auto-load .env; consumers (Linnet, Osprey, etc.) +# load it via python-dotenv in their own startup. For standalone cf-voice +# dev/testing, source this file manually or install python-dotenv. + +# ── HuggingFace ─────────────────────────────────────────────────────────────── +# Required for pyannote.audio speaker diarization model download. +# Get a free token at https://huggingface.co/settings/tokens +# Also accept the gated model terms at: +# https://huggingface.co/pyannote/speaker-diarization-3.1 +# https://huggingface.co/pyannote/segmentation-3.0 +HF_TOKEN= + +# ── Whisper STT ─────────────────────────────────────────────────────────────── +# Model size: tiny | base | small | medium | large-v2 | large-v3 +# Smaller = faster / less VRAM; larger = more accurate. +# Recommended: small (500MB VRAM) for real-time use. +CF_VOICE_WHISPER_MODEL=small + +# ── Compute ─────────────────────────────────────────────────────────────────── +# auto (detect GPU), cuda, cpu +CF_VOICE_DEVICE=auto + +# ── Mock mode ───────────────────────────────────────────────────────────────── +# Set to 1 to use synthetic VoiceFrames — no GPU, mic, or HF token required. +# Unset or 0 for real audio capture. +CF_VOICE_MOCK= + +# ── Tone classifier ─────────────────────────────────────────────────────────── +# Minimum confidence to emit a VoiceFrame (below this = frame skipped). +CF_VOICE_CONFIDENCE_THRESHOLD=0.55 diff --git a/cf_voice/capture.py b/cf_voice/capture.py new file mode 100644 index 0000000..0b4f64e --- /dev/null +++ b/cf_voice/capture.py @@ -0,0 +1,192 @@ +# cf_voice/capture.py — real microphone capture +# +# MIT licensed. This layer handles audio I/O and buffering only. +# Inference (STT, classify) runs here but is imported lazily so that mock +# mode works without the [inference] extras installed. +from __future__ import annotations + +import asyncio +import logging +import time +from typing import AsyncIterator + +import numpy as np + +from cf_voice.models import VoiceFrame +from cf_voice.io import VoiceIO + +logger = logging.getLogger(__name__) + +_SAMPLE_RATE = 16_000 +_CHUNK_FRAMES = 1_600 # 100ms of audio at 16kHz +_WINDOW_CHUNKS = 20 # 20 × 100ms = 2s STT window + +# Skip windows whose RMS is below this (microphone noise floor) +_SILENCE_RMS = 0.008 + + +class MicVoiceIO(VoiceIO): + """ + Real microphone capture producing enriched VoiceFrames. + + Capture loop (sounddevice callback → asyncio.Queue): + ┌──────────────────────────────────────────────────────────────┐ + │ sounddevice InputStream │ + │ 100ms PCM Int16 chunks → _queue │ + └──────────────────────────────────────────────────────────────┘ + ↓ accumulated 20× (2s window) + ┌──────────────────────────────────────────────────────────────┐ + │ Parallel async tasks (both run in thread pool) │ + │ WhisperSTT.transcribe_chunk_async() → transcript │ + │ ToneClassifier.classify_async() → ToneResult │ + └──────────────────────────────────────────────────────────────┘ + ↓ combined + VoiceFrame(label, confidence, speaker_id, shift_magnitude, ...) + + speaker_id is always "speaker_a" until Navigation v0.2.x wires in the + Diarizer. That integration happens in ContextClassifier, not here. + + Requires: pip install cf-voice[inference] + """ + + def __init__( + self, + device_index: int | None = None, + ) -> None: + # Lazy import so that importing MicVoiceIO doesn't blow up without + # inference deps — only instantiation requires them. + try: + from cf_voice.stt import WhisperSTT + from cf_voice.classify import ToneClassifier + except ImportError as exc: + raise ImportError( + "Real audio capture requires the [inference] extras. " + "Install with: pip install cf-voice[inference]" + ) from exc + + self._stt = WhisperSTT.from_env() + self._classifier = ToneClassifier.from_env() + self._device_index = device_index + self._running = False + self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=200) + self._prev_label: str = "" + self._prev_confidence: float = 0.0 + + async def stream(self) -> AsyncIterator[VoiceFrame]: # type: ignore[override] + try: + import sounddevice as sd + except ImportError as exc: + raise ImportError( + "sounddevice is required for microphone capture. " + "Install with: pip install cf-voice[inference]" + ) from exc + + self._running = True + self._stt.reset_session() + + loop = asyncio.get_event_loop() + + def _sd_callback( + indata: np.ndarray, frames: int, time_info, status + ) -> None: + if status: + logger.debug("sounddevice status: %s", status) + pcm = (indata[:, 0] * 32_767).astype(np.int16).tobytes() + try: + loop.call_soon_threadsafe(self._queue.put_nowait, pcm) + except asyncio.QueueFull: + logger.warning("Audio queue full — dropping chunk (inference too slow?)") + + input_stream = sd.InputStream( + samplerate=_SAMPLE_RATE, + channels=1, + dtype="float32", + blocksize=_CHUNK_FRAMES, + device=self._device_index, + callback=_sd_callback, + ) + + window: list[bytes] = [] + session_start = time.monotonic() + + with input_stream: + while self._running: + try: + chunk = await asyncio.wait_for(self._queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + + window.append(chunk) + + if len(window) < _WINDOW_CHUNKS: + continue + + # 2-second window ready + raw = b"".join(window) + window.clear() + + audio = ( + np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32_768.0 + ) + + # Skip silent windows + rms = float(np.sqrt(np.mean(audio ** 2))) + if rms < _SILENCE_RMS: + continue + + # Run STT and tone classification in parallel + stt_task = asyncio.create_task( + self._stt.transcribe_chunk_async(raw) + ) + tone_task = asyncio.create_task( + self._classifier.classify_async(audio) + ) + stt_result, tone = await asyncio.gather(stt_task, tone_task) + + # Update transcript on tone result now that STT is done + # (re-classify with text is cheap enough to do inline) + if stt_result.text: + tone = await self._classifier.classify_async( + audio, stt_result.text + ) + + shift = _compute_shift( + self._prev_label, + self._prev_confidence, + tone.label, + tone.confidence, + ) + self._prev_label = tone.label + self._prev_confidence = tone.confidence + + yield VoiceFrame( + label=tone.label, + confidence=tone.confidence, + speaker_id="speaker_a", # Navigation v0.2.x: diarizer wired here + shift_magnitude=round(shift, 3), + timestamp=round(time.monotonic() - session_start, 2), + ) + + async def stop(self) -> None: + self._running = False + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _compute_shift( + prev_label: str, + prev_confidence: float, + curr_label: str, + curr_confidence: float, +) -> float: + """ + Compute shift_magnitude for a VoiceFrame. + + 0.0 when the label hasn't changed. + Higher when the label changes with high confidence in both directions. + Capped at 1.0. + """ + if not prev_label or curr_label == prev_label: + return 0.0 + # A high-confidence transition in both frames = large shift + return min(1.0, (prev_confidence + curr_confidence) / 2.0) diff --git a/cf_voice/classify.py b/cf_voice/classify.py new file mode 100644 index 0000000..f5dff6a --- /dev/null +++ b/cf_voice/classify.py @@ -0,0 +1,245 @@ +# cf_voice/classify.py — tone / affect classifier +# +# BSL 1.1: real inference. Requires [inference] extras. +# Stub behaviour: raises NotImplementedError if inference deps not installed. +# +# Pipeline: wav2vec2 SER (speech emotion recognition) + librosa prosody +# features → AFFECT_LABELS defined in cf_voice.events. +from __future__ import annotations + +import asyncio +import logging +import os +from dataclasses import dataclass, field +from functools import partial + +import numpy as np + +logger = logging.getLogger(__name__) + +_SAMPLE_RATE = 16_000 + +# Confidence floor — results below this are discarded by the caller +_DEFAULT_THRESHOLD = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55")) + +# wav2vec2 SER model from HuggingFace +# ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition +# Outputs 7 classes: angry, disgust, fear, happy, neutral, sadness, surprise +_SER_MODEL_ID = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition" + +# ── Affect label mapping ────────────────────────────────────────────────────── +# Maps (emotion, prosody_profile) → affect label from cf_voice.events.AFFECT_LABELS +# Prosody profile is a tuple of flags present from _extract_prosody_flags(). + +_EMOTION_BASE: dict[str, str] = { + "angry": "frustrated", + "disgust": "dismissive", + "fear": "apologetic", + "happy": "warm", + "neutral": "neutral", + "sadness": "tired", + "surprise": "confused", +} + +# Prosody-driven overrides: (base_affect, flag) → override affect +_PROSODY_OVERRIDES: dict[tuple[str, str], str] = { + ("neutral", "fast_rate"): "genuine", + ("neutral", "flat_pitch"): "scripted", + ("neutral", "low_energy"): "tired", + ("frustrated", "rising"): "urgent", + ("warm", "rising"): "genuine", + ("tired", "rising"): "optimistic", + ("dismissive", "flat_pitch"): "condescending", +} + +# Affect → human-readable VoiceFrame label (reverse of events._label_to_affect) +_AFFECT_TO_LABEL: dict[str, str] = { + "neutral": "Calm and focused", + "warm": "Enthusiastic", + "frustrated": "Frustrated but contained", + "dismissive": "Politely dismissive", + "apologetic": "Nervous but cooperative", + "urgent": "Warmly impatient", + "condescending": "Politely dismissive", + "scripted": "Calm and focused", # scripted reads as neutral to the observer + "genuine": "Genuinely curious", + "confused": "Confused but engaged", + "tired": "Tired and compliant", + "optimistic": "Guardedly optimistic", +} + + +@dataclass +class ToneResult: + label: str # human-readable VoiceFrame label + affect: str # AFFECT_LABELS key + confidence: float + prosody_flags: list[str] = field(default_factory=list) + + +class ToneClassifier: + """ + Tone/affect classifier: wav2vec2 SER + librosa prosody. + + Loads the model lazily on first call to avoid import-time GPU allocation. + Thread-safe for concurrent classify() calls — the pipeline is stateless + per-call; session state lives in the caller (ContextClassifier). + """ + + def __init__(self, threshold: float = _DEFAULT_THRESHOLD) -> None: + self._threshold = threshold + self._pipeline = None # lazy-loaded + + @classmethod + def from_env(cls) -> "ToneClassifier": + threshold = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55")) + return cls(threshold=threshold) + + def _load_pipeline(self) -> None: + if self._pipeline is not None: + return + try: + from transformers import pipeline as hf_pipeline + except ImportError as exc: + raise ImportError( + "transformers is required for tone classification. " + "Install with: pip install cf-voice[inference]" + ) from exc + + device = 0 if _cuda_available() else -1 + logger.info("Loading SER model %s on device %s", _SER_MODEL_ID, device) + self._pipeline = hf_pipeline( + "audio-classification", + model=_SER_MODEL_ID, + device=device, + ) + + def classify(self, audio_float32: np.ndarray, transcript: str = "") -> ToneResult: + """ + Classify tone/affect from a float32 16kHz mono audio window. + + transcript is used as a weak signal for ambiguous cases (e.g. words + like "unfortunately" bias toward apologetic even on a neutral voice). + """ + self._load_pipeline() + + # Ensure the model sees float32 at the right rate + assert audio_float32.dtype == np.float32, "audio must be float32" + + # Run SER + preds = self._pipeline({"raw": audio_float32, "sampling_rate": _SAMPLE_RATE}) + best = max(preds, key=lambda p: p["score"]) + emotion = best["label"].lower() + confidence = float(best["score"]) + + # Extract prosody features from raw audio + prosody_flags = _extract_prosody_flags(audio_float32) + + # Resolve affect from base emotion + prosody + affect = _EMOTION_BASE.get(emotion, "neutral") + for flag in prosody_flags: + override = _PROSODY_OVERRIDES.get((affect, flag)) + if override: + affect = override + break + + # Weak transcript signals + affect = _apply_transcript_hints(affect, transcript) + + label = _AFFECT_TO_LABEL.get(affect, "Calm and focused") + return ToneResult( + label=label, + affect=affect, + confidence=confidence, + prosody_flags=prosody_flags, + ) + + async def classify_async( + self, audio_float32: np.ndarray, transcript: str = "" + ) -> ToneResult: + """classify() without blocking the event loop.""" + loop = asyncio.get_event_loop() + fn = partial(self.classify, audio_float32, transcript) + return await loop.run_in_executor(None, fn) + + +# ── Prosody helpers ─────────────────────────────────────────────────────────── + +def _extract_prosody_flags(audio: np.ndarray) -> list[str]: + """ + Extract lightweight prosody flags from a float32 16kHz mono window. + Returns a list of string flags consumed by _PROSODY_OVERRIDES. + """ + try: + import librosa + except ImportError: + return [] + + flags: list[str] = [] + + # Energy (RMS) + rms = float(np.sqrt(np.mean(audio ** 2))) + if rms < 0.02: + flags.append("low_energy") + elif rms > 0.15: + flags.append("high_energy") + + # Speech rate approximation via zero-crossing rate + zcr = float(np.mean(librosa.feature.zero_crossing_rate(audio))) + if zcr > 0.12: + flags.append("fast_rate") + elif zcr < 0.04: + flags.append("slow_rate") + + # Pitch contour via YIN + try: + f0 = librosa.yin( + audio, + fmin=librosa.note_to_hz("C2"), + fmax=librosa.note_to_hz("C7"), + sr=_SAMPLE_RATE, + ) + voiced = f0[f0 > 0] + if len(voiced) > 5: + # Rising: last quarter higher than first quarter + q = len(voiced) // 4 + if q > 0 and np.mean(voiced[-q:]) > np.mean(voiced[:q]) * 1.15: + flags.append("rising") + # Flat: variance less than 15Hz + if np.std(voiced) < 15: + flags.append("flat_pitch") + except Exception: + pass # pitch extraction is best-effort + + return flags + + +def _apply_transcript_hints(affect: str, transcript: str) -> str: + """ + Apply weak keyword signals from transcript text to adjust affect. + Only overrides when affect is already ambiguous (neutral/tired). + """ + if not transcript or affect not in ("neutral", "tired"): + return affect + + t = transcript.lower() + apologetic_words = {"sorry", "apologize", "unfortunately", "afraid", "regret"} + urgent_words = {"urgent", "immediately", "asap", "right now", "critical"} + dismissive_words = {"policy", "unable to", "cannot", "not possible", "outside"} + + if any(w in t for w in apologetic_words): + return "apologetic" + if any(w in t for w in urgent_words): + return "urgent" + if any(w in t for w in dismissive_words): + return "dismissive" + + return affect + + +def _cuda_available() -> bool: + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False diff --git a/cf_voice/context.py b/cf_voice/context.py index 84f4c84..7b450f4 100644 --- a/cf_voice/context.py +++ b/cf_voice/context.py @@ -78,26 +78,26 @@ class ContextClassifier: This is the request-response path used by the cf-orch endpoint. The streaming path (async generator) is for continuous consumers. - Stub: audio_b64 is ignored; returns synthetic events from the mock IO. - Real: decode audio, run YAMNet + SER + pyannote, return events. - elcor=True switches subtext format to Mass Effect Elcor prefix style. Generic tone annotation is always present regardless of elcor flag. """ - if not isinstance(self._io, MockVoiceIO): - raise NotImplementedError( - "classify_chunk() requires mock mode. " - "Real audio inference is not yet implemented." - ) - # Generate a synthetic VoiceFrame to derive events from - rng = self._io._rng - import time - label = rng.choice(self._io._labels) + if isinstance(self._io, MockVoiceIO): + return self._classify_chunk_mock(timestamp, prior_frames, elcor) + + return self._classify_chunk_real(audio_b64, timestamp, elcor) + + def _classify_chunk_mock( + self, timestamp: float, prior_frames: int, elcor: bool + ) -> list[AudioEvent]: + """Synthetic path — used in mock mode and CI.""" + rng = self._io._rng # type: ignore[attr-defined] + import time as _time + label = rng.choice(self._io._labels) # type: ignore[attr-defined] shift = rng.uniform(0.1, 0.7) if prior_frames > 0 else 0.0 frame = VoiceFrame( label=label, confidence=rng.uniform(0.6, 0.97), - speaker_id=rng.choice(self._io._speakers), + speaker_id=rng.choice(self._io._speakers), # type: ignore[attr-defined] shift_magnitude=round(shift, 3), timestamp=timestamp, ) @@ -110,6 +110,38 @@ class ContextClassifier: ) return [tone] + def _classify_chunk_real( + self, audio_b64: str, timestamp: float, elcor: bool + ) -> list[AudioEvent]: + """Real inference path — used when CF_VOICE_MOCK is unset.""" + import asyncio + import base64 + import numpy as np + from cf_voice.classify import ToneClassifier + + pcm = base64.b64decode(audio_b64) + audio = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32_768.0 + + # ToneClassifier is stateless per-call, safe to instantiate inline + classifier = ToneClassifier.from_env() + tone_result = classifier.classify(audio) + + frame = VoiceFrame( + label=tone_result.label, + confidence=tone_result.confidence, + speaker_id="speaker_a", + shift_magnitude=0.0, + timestamp=timestamp, + ) + event = tone_event_from_voice_frame( + frame_label=frame.label, + frame_confidence=frame.confidence, + shift_magnitude=frame.shift_magnitude, + timestamp=frame.timestamp, + elcor=elcor, + ) + return [event] + def _enrich(self, frame: VoiceFrame) -> VoiceFrame: """ Apply tone classification to a raw frame. diff --git a/cf_voice/diarize.py b/cf_voice/diarize.py new file mode 100644 index 0000000..217dd51 --- /dev/null +++ b/cf_voice/diarize.py @@ -0,0 +1,146 @@ +# cf_voice/diarize.py — async speaker diarization via pyannote.audio +# +# BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access. +# Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI. +# +# Model used: pyannote/speaker-diarization-3.1 +# Requires accepting gated model terms at: +# https://huggingface.co/pyannote/speaker-diarization-3.1 +# https://huggingface.co/pyannote/segmentation-3.0 +from __future__ import annotations + +import asyncio +import logging +import os +from dataclasses import dataclass + +import numpy as np + +logger = logging.getLogger(__name__) + +_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1" +_SAMPLE_RATE = 16_000 + + +@dataclass +class SpeakerSegment: + """A speaker-labelled time range within an audio window.""" + speaker_id: str # ephemeral local label, e.g. "SPEAKER_00" + start_s: float + end_s: float + + @property + def duration_s(self) -> float: + return self.end_s - self.start_s + + +class Diarizer: + """ + Async wrapper around pyannote.audio speaker diarization pipeline. + + PyAnnote's pipeline is synchronous and GPU-bound. We run it in a thread + pool executor to avoid blocking the asyncio event loop. + + The pipeline is loaded once at construction time (model download on first + use, then cached by HuggingFace). CUDA is used automatically if available. + + Usage + ----- + diarizer = Diarizer.from_env() + segments = await diarizer.diarize_async(audio_float32) + for seg in segments: + print(seg.speaker_id, seg.start_s, seg.end_s) + + Navigation v0.2.x wires this into ContextClassifier so that each + VoiceFrame carries the correct speaker_id from diarization output. + """ + + def __init__(self, hf_token: str) -> None: + try: + from pyannote.audio import Pipeline + except ImportError as exc: + raise ImportError( + "pyannote.audio is required for speaker diarization. " + "Install with: pip install cf-voice[inference]" + ) from exc + + logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL) + self._pipeline = Pipeline.from_pretrained( + _DIARIZATION_MODEL, + use_auth_token=hf_token, + ) + + # Move to GPU if available + try: + import torch + if torch.cuda.is_available(): + self._pipeline.to(torch.device("cuda")) + logger.info("Diarization pipeline on CUDA") + except ImportError: + pass + + @classmethod + def from_env(cls) -> "Diarizer": + """Construct from HF_TOKEN environment variable.""" + token = os.environ.get("HF_TOKEN", "").strip() + if not token: + raise EnvironmentError( + "HF_TOKEN is required for speaker diarization. " + "Set it in your .env or environment. " + "See .env.example for setup instructions." + ) + return cls(hf_token=token) + + def _diarize_sync( + self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE + ) -> list[SpeakerSegment]: + """Synchronous diarization — always call via diarize_async.""" + import torch + + # pyannote expects (channels, samples) float32 tensor + waveform = torch.from_numpy(audio_float32[np.newaxis, :].astype(np.float32)) + diarization = self._pipeline( + {"waveform": waveform, "sample_rate": sample_rate} + ) + + segments: list[SpeakerSegment] = [] + for turn, _, speaker in diarization.itertracks(yield_label=True): + segments.append( + SpeakerSegment( + speaker_id=speaker, + start_s=round(turn.start, 3), + end_s=round(turn.end, 3), + ) + ) + return segments + + async def diarize_async( + self, + audio_float32: np.ndarray, + sample_rate: int = _SAMPLE_RATE, + ) -> list[SpeakerSegment]: + """ + Diarize an audio window without blocking the event loop. + + audio_float32 should be 16kHz mono float32. + Typical input is a 2-second window from MicVoiceIO (32000 samples). + Returns segments ordered by start_s. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, self._diarize_sync, audio_float32, sample_rate + ) + + def speaker_at( + self, segments: list[SpeakerSegment], timestamp_s: float + ) -> str: + """ + Return the speaker_id active at a given timestamp within the window. + + Falls back to "speaker_a" if no segment covers the timestamp + (e.g. during silence or at window boundaries). + """ + for seg in segments: + if seg.start_s <= timestamp_s <= seg.end_s: + return seg.speaker_id + return "speaker_a" diff --git a/cf_voice/io.py b/cf_voice/io.py index 94e6104..0257afd 100644 --- a/cf_voice/io.py +++ b/cf_voice/io.py @@ -106,17 +106,17 @@ class MockVoiceIO(VoiceIO): def make_io( mock: bool | None = None, interval_s: float = 2.5, + device_index: int | None = None, ) -> VoiceIO: """ Factory: return a VoiceIO instance appropriate for the current environment. - mock=True or CF_VOICE_MOCK=1 → MockVoiceIO (no audio hardware needed) - Otherwise → real audio capture (not yet implemented) + mock=True or CF_VOICE_MOCK=1 → MockVoiceIO (no GPU, mic, or HF token needed) + Otherwise → MicVoiceIO (requires [inference] extras) """ use_mock = mock if mock is not None else os.environ.get("CF_VOICE_MOCK", "") == "1" if use_mock: return MockVoiceIO(interval_s=interval_s) - raise NotImplementedError( - "Real audio capture is not yet implemented. " - "Set CF_VOICE_MOCK=1 or pass mock=True to use synthetic frames." - ) + + from cf_voice.capture import MicVoiceIO + return MicVoiceIO(device_index=device_index) diff --git a/cf_voice/stt.py b/cf_voice/stt.py new file mode 100644 index 0000000..7bb3685 --- /dev/null +++ b/cf_voice/stt.py @@ -0,0 +1,142 @@ +# cf_voice/stt.py — faster-whisper STT wrapper +# +# BSL 1.1 when real inference models are integrated. +# Requires the [inference] extras: pip install cf-voice[inference] +from __future__ import annotations + +import asyncio +import logging +import os +from dataclasses import dataclass + +import numpy as np + +logger = logging.getLogger(__name__) + +_VRAM_ESTIMATES_MB: dict[str, int] = { + "tiny": 150, "base": 300, "small": 500, + "medium": 1500, "large": 3000, "large-v2": 3000, "large-v3": 3500, +} + +# Minimum audio duration in seconds before attempting transcription. +# Whisper hallucinates on very short clips. +_MIN_DURATION_S = 0.3 + + +@dataclass +class STTResult: + text: str + language: str + duration_s: float + is_final: bool + + +class WhisperSTT: + """ + Async wrapper around faster-whisper for real-time chunk transcription. + + Runs transcription in a thread pool executor so it never blocks the event + loop. Maintains a rolling 50-word session prompt to improve context + continuity across 2-second windows. + + Usage + ----- + stt = WhisperSTT.from_env() + result = await stt.transcribe_chunk_async(pcm_int16_bytes) + print(result.text) + """ + + def __init__( + self, + model_name: str = "small", + device: str = "auto", + compute_type: str | None = None, + ) -> None: + try: + from faster_whisper import WhisperModel + except ImportError as exc: + raise ImportError( + "faster-whisper is required for real STT. " + "Install with: pip install cf-voice[inference]" + ) from exc + + if device == "auto": + try: + import torch + device = "cuda" if torch.cuda.is_available() else "cpu" + except ImportError: + device = "cpu" + + if compute_type is None: + compute_type = "float16" if device == "cuda" else "int8" + + logger.info("Loading Whisper %s on %s (%s)", model_name, device, compute_type) + self._model = WhisperModel( + model_name, device=device, compute_type=compute_type + ) + self._device = device + self._model_name = model_name + self._session_prompt: str = "" + + @classmethod + def from_env(cls) -> "WhisperSTT": + """Construct from CF_VOICE_WHISPER_MODEL and CF_VOICE_DEVICE env vars.""" + return cls( + model_name=os.environ.get("CF_VOICE_WHISPER_MODEL", "small"), + device=os.environ.get("CF_VOICE_DEVICE", "auto"), + ) + + @property + def vram_mb(self) -> int: + """Estimated VRAM usage in MB for this model/compute_type combination.""" + return _VRAM_ESTIMATES_MB.get(self._model_name, 1500) + + def _transcribe_sync(self, audio_float32: np.ndarray) -> STTResult: + """Synchronous transcription — always call via transcribe_chunk_async.""" + duration = len(audio_float32) / 16_000.0 + + if duration < _MIN_DURATION_S: + return STTResult( + text="", language="en", duration_s=duration, is_final=False + ) + + segments, info = self._model.transcribe( + audio_float32, + language=None, + initial_prompt=self._session_prompt or None, + vad_filter=False, # silence gating happens upstream in MicVoiceIO + word_timestamps=False, + beam_size=3, + temperature=0.0, + ) + + text = " ".join(s.text.strip() for s in segments).strip() + + # Rolling context: keep last ~50 words so the next chunk has prior text + if text: + words = (self._session_prompt + " " + text).split() + self._session_prompt = " ".join(words[-50:]) + + return STTResult( + text=text, + language=info.language, + duration_s=duration, + is_final=duration >= 1.0 and info.language_probability > 0.5, + ) + + async def transcribe_chunk_async(self, pcm_int16: bytes) -> STTResult: + """ + Transcribe a raw PCM Int16 chunk, non-blocking. + + pcm_int16 should be 16kHz mono bytes. Typical input is 20 × 100ms + chunks accumulated by MicVoiceIO (2-second window = 64000 bytes). + """ + audio = ( + np.frombuffer(pcm_int16, dtype=np.int16).astype(np.float32) / 32768.0 + ) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self._transcribe_sync, audio) + + def reset_session(self) -> None: + """Clear the rolling prompt. Call at the start of each new conversation.""" + self._session_prompt = "" diff --git a/pyproject.toml b/pyproject.toml index 20d3d37..0aa0c6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,18 @@ dependencies = [ inference = [ "torch>=2.0", "torchaudio>=2.0", + "numpy>=1.24", + "faster-whisper>=1.0", + "sounddevice>=0.4", "transformers>=4.40", + "librosa>=0.10", "pyannote.audio>=3.1", + "python-dotenv>=1.0", ] dev = [ "pytest>=8.0", "pytest-asyncio>=0.23", + "numpy>=1.24", ] [project.scripts]