# cf_voice/stt.py — faster-whisper STT wrapper
#
# BSL 1.1 when real inference models are integrated.
# Requires the [inference] extras: pip install cf-voice[inference]
from __future__ import annotations

import asyncio
import logging
import os
from dataclasses import dataclass

import numpy as np

logger = logging.getLogger(__name__)

# Rough VRAM footprint in MB, keyed by Whisper model size name.
_VRAM_ESTIMATES_MB: dict[str, int] = {
    "tiny": 150,
    "base": 300,
    "small": 500,
    "medium": 1500,
    "large": 3000,
    "large-v2": 3000,
    "large-v3": 3500,
}

# Estimate reported for model names that are not in _VRAM_ESTIMATES_MB.
_DEFAULT_VRAM_MB = 1500

# Suffix carried by the English-only checkpoints (e.g. "small.en").
_ENGLISH_ONLY_SUFFIX = ".en"

# Minimum audio duration in seconds before attempting transcription.
# Whisper hallucinates on very short clips.
_MIN_DURATION_S = 0.3

# Sample rate (Hz) that incoming audio is expected to use.
_SAMPLE_RATE_HZ = 16_000

# Number of trailing words kept as the rolling session prompt.
_SESSION_PROMPT_WORDS = 50


@dataclass
class STTResult:
    """
    Outcome of transcribing one audio chunk.

    Attributes
    ----------
    text:
        Stripped segment texts joined with single spaces; empty when the
        clip was too short to transcribe.
    language:
        Language reported by the model, or "en" for clips that were skipped
        because they were shorter than the minimum duration.
    duration_s:
        Clip length in seconds, computed from the sample count at 16 kHz.
    is_final:
        True when the clip is at least 1 second long and the reported
        language probability exceeds 0.5.
    """

    text: str
    language: str
    duration_s: float
    is_final: bool


class WhisperSTT:
    """
    Async wrapper around faster-whisper for real-time chunk transcription.

    Runs transcription in a thread pool executor so it never blocks the event
    loop. Maintains a rolling 50-word session prompt to improve context
    continuity across 2-second windows.

    NOTE(review): the rolling prompt is read and updated from the executor
    thread without locking, so overlapping transcribe_chunk_async calls on
    one instance presumably need to be serialized by the caller — confirm.

    Usage
    -----
        stt = WhisperSTT.from_env()
        result = await stt.transcribe_chunk_async(pcm_int16_bytes)
        print(result.text)
    """

    def __init__(
        self,
        model_name: str = "small",
        device: str = "auto",
        compute_type: str | None = None,
    ) -> None:
        """
        Load the Whisper model.

        Parameters
        ----------
        model_name:
            Model identifier handed to faster-whisper's WhisperModel.
        device:
            Device handed to WhisperModel. "auto" resolves to "cuda" when
            torch is importable and reports CUDA as available, else "cpu".
        compute_type:
            Compute type handed to WhisperModel. None selects "float16" on
            CUDA and "int8" otherwise.

        Raises
        ------
        ImportError
            If faster-whisper is not installed.
        """
        # Imported lazily so the module can be imported without the
        # [inference] extras installed.
        try:
            from faster_whisper import WhisperModel
        except ImportError as exc:
            raise ImportError(
                "faster-whisper is required for real STT. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        if device == "auto":
            # torch is optional; without it we fall back to CPU.
            try:
                import torch

                device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                device = "cpu"

        if compute_type is None:
            compute_type = "float16" if device == "cuda" else "int8"

        logger.info("Loading Whisper %s on %s (%s)", model_name, device, compute_type)
        self._model = WhisperModel(
            model_name, device=device, compute_type=compute_type
        )
        self._device = device
        self._model_name = model_name
        self._session_prompt: str = ""

    @classmethod
    def from_env(cls) -> "WhisperSTT":
        """
        Construct from CF_VOICE_WHISPER_MODEL and CF_VOICE_DEVICE env vars.

        Unset variables fall back to "small" and "auto" respectively.
        """
        return cls(
            model_name=os.environ.get("CF_VOICE_WHISPER_MODEL", "small"),
            device=os.environ.get("CF_VOICE_DEVICE", "auto"),
        )

    @property
    def vram_mb(self) -> int:
        """
        Estimated VRAM usage in MB for the configured model name.

        English-only variants ("tiny.en", "small.en", ...) use the estimate
        of their multilingual counterpart. Names that are not in the table
        fall back to 1500 MB. The compute type is not taken into account.
        """
        name = self._model_name
        if name.endswith(_ENGLISH_ONLY_SUFFIX):
            name = name[: -len(_ENGLISH_ONLY_SUFFIX)]
        return _VRAM_ESTIMATES_MB.get(name, _DEFAULT_VRAM_MB)

    def _transcribe_sync(self, audio_float32: np.ndarray) -> STTResult:
        """
        Synchronous transcription — always call via transcribe_chunk_async.

        Parameters
        ----------
        audio_float32:
            Audio samples; the duration is derived from the sample count
            assuming a 16 kHz sample rate.

        Returns
        -------
        STTResult
            An empty, non-final result for clips shorter than the minimum
            duration; otherwise the transcription of the clip.
        """
        duration = len(audio_float32) / _SAMPLE_RATE_HZ
        if duration < _MIN_DURATION_S:
            # Too short to transcribe reliably; the model is not invoked and
            # the rolling prompt is left untouched.
            return STTResult(
                text="", language="en", duration_s=duration, is_final=False
            )

        segments, info = self._model.transcribe(
            audio_float32,
            language=None,
            initial_prompt=self._session_prompt or None,
            vad_filter=False,  # silence gating happens upstream in MicVoiceIO
            word_timestamps=False,
            beam_size=3,
            temperature=0.0,
        )
        text = " ".join(s.text.strip() for s in segments).strip()

        # Rolling context: keep last ~50 words so the next chunk has prior text
        if text:
            words = (self._session_prompt + " " + text).split()
            self._session_prompt = " ".join(words[-_SESSION_PROMPT_WORDS:])

        return STTResult(
            text=text,
            language=info.language,
            duration_s=duration,
            is_final=duration >= 1.0 and info.language_probability > 0.5,
        )

    async def transcribe_chunk_async(self, pcm_int16: bytes) -> STTResult:
        """
        Transcribe a raw PCM Int16 chunk, non-blocking.

        pcm_int16 should be 16kHz mono bytes. Typical input is 20 × 100ms
        chunks accumulated by MicVoiceIO (2-second window = 64000 bytes).

        Samples are scaled to float32 by dividing by 32768 before the
        synchronous transcription runs in the event loop's default executor.

        Raises
        ------
        ValueError
            Raised by numpy.frombuffer when the byte length is not a
            multiple of 2 (an incomplete Int16 sample).
        """
        audio = (
            np.frombuffer(pcm_int16, dtype=np.int16).astype(np.float32) / 32768.0
        )
        # We are inside a coroutine, so a loop is guaranteed to be running;
        # get_running_loop() is the recommended accessor in that context.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio)

    def reset_session(self) -> None:
        """Clear the rolling prompt. Call at the start of each new conversation."""
        self._session_prompt = ""