cf-voice/cf_voice/stt.py
pyr0ball fed6388b99 feat: real inference pipeline — STT, tone classifier, diarization, mic capture
- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor,
  rolling 50-word session prompt for cross-chunk context continuity)
- cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags
  (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS
- cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1;
  speaker_at() helper for Navigation v0.2.x wiring
- cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window
  accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta
- cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset
- cf_voice/context.py: classify_chunk() split into mock/real paths; real path
  decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint)
- pyproject.toml: inference extras expanded (faster-whisper, sounddevice,
  librosa, python-dotenv)
- .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK,
  CF_VOICE_CONFIDENCE_THRESHOLD

Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote
setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper
preprocessing and session prompt pattern).
2026-04-06 17:33:51 -07:00

142 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# cf_voice/stt.py — faster-whisper STT wrapper
#
# BSL 1.1 when real inference models are integrated.
# Requires the [inference] extras: pip install cf-voice[inference]
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
import numpy as np
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)

# Rough VRAM footprint per Whisper model name, in MB. Used via the
# WhisperSTT.vram_mb property so callers can budget GPU memory before
# loading a model. Values are estimates, not measurements.
_VRAM_ESTIMATES_MB: dict[str, int] = {
    "tiny": 150, "base": 300, "small": 500,
    "medium": 1500, "large": 3000, "large-v2": 3000, "large-v3": 3500,
}

# Minimum audio duration in seconds before attempting transcription.
# Whisper hallucinates on very short clips.
_MIN_DURATION_S = 0.3
@dataclass
class STTResult:
    """Result of transcribing a single audio chunk."""

    # Transcribed text; empty string when the clip was too short to transcribe.
    text: str
    # Language code reported by Whisper; hard-coded to "en" for too-short clips.
    language: str
    # Clip length in seconds, derived from the sample count.
    duration_s: float
    # True only when the clip is >= 1s and language detection is confident
    # (probability > 0.5); shorter/uncertain chunks are provisional.
    is_final: bool
class WhisperSTT:
"""
Async wrapper around faster-whisper for real-time chunk transcription.
Runs transcription in a thread pool executor so it never blocks the event
loop. Maintains a rolling 50-word session prompt to improve context
continuity across 2-second windows.
Usage
-----
stt = WhisperSTT.from_env()
result = await stt.transcribe_chunk_async(pcm_int16_bytes)
print(result.text)
"""
def __init__(
self,
model_name: str = "small",
device: str = "auto",
compute_type: str | None = None,
) -> None:
try:
from faster_whisper import WhisperModel
except ImportError as exc:
raise ImportError(
"faster-whisper is required for real STT. "
"Install with: pip install cf-voice[inference]"
) from exc
if device == "auto":
try:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
device = "cpu"
if compute_type is None:
compute_type = "float16" if device == "cuda" else "int8"
logger.info("Loading Whisper %s on %s (%s)", model_name, device, compute_type)
self._model = WhisperModel(
model_name, device=device, compute_type=compute_type
)
self._device = device
self._model_name = model_name
self._session_prompt: str = ""
@classmethod
def from_env(cls) -> "WhisperSTT":
"""Construct from CF_VOICE_WHISPER_MODEL and CF_VOICE_DEVICE env vars."""
return cls(
model_name=os.environ.get("CF_VOICE_WHISPER_MODEL", "small"),
device=os.environ.get("CF_VOICE_DEVICE", "auto"),
)
@property
def vram_mb(self) -> int:
"""Estimated VRAM usage in MB for this model/compute_type combination."""
return _VRAM_ESTIMATES_MB.get(self._model_name, 1500)
def _transcribe_sync(self, audio_float32: np.ndarray) -> STTResult:
"""Synchronous transcription — always call via transcribe_chunk_async."""
duration = len(audio_float32) / 16_000.0
if duration < _MIN_DURATION_S:
return STTResult(
text="", language="en", duration_s=duration, is_final=False
)
segments, info = self._model.transcribe(
audio_float32,
language=None,
initial_prompt=self._session_prompt or None,
vad_filter=False, # silence gating happens upstream in MicVoiceIO
word_timestamps=False,
beam_size=3,
temperature=0.0,
)
text = " ".join(s.text.strip() for s in segments).strip()
# Rolling context: keep last ~50 words so the next chunk has prior text
if text:
words = (self._session_prompt + " " + text).split()
self._session_prompt = " ".join(words[-50:])
return STTResult(
text=text,
language=info.language,
duration_s=duration,
is_final=duration >= 1.0 and info.language_probability > 0.5,
)
async def transcribe_chunk_async(self, pcm_int16: bytes) -> STTResult:
"""
Transcribe a raw PCM Int16 chunk, non-blocking.
pcm_int16 should be 16kHz mono bytes. Typical input is 20 × 100ms
chunks accumulated by MicVoiceIO (2-second window = 64000 bytes).
"""
audio = (
np.frombuffer(pcm_int16, dtype=np.int16).astype(np.float32) / 32768.0
)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._transcribe_sync, audio)
def reset_session(self) -> None:
"""Clear the rolling prompt. Call at the start of each new conversation."""
self._session_prompt = ""