- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor, rolling 50-word session prompt for cross-chunk context continuity) - cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS - cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1; speaker_at() helper for Navigation v0.2.x wiring - cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta - cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset - cf_voice/context.py: classify_chunk() split into mock/real paths; real path decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint) - pyproject.toml: inference extras expanded (faster-whisper, sounddevice, librosa, python-dotenv) - .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK, CF_VOICE_CONFIDENCE_THRESHOLD Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper preprocessing and session prompt pattern).
142 lines
4.5 KiB
Python
142 lines
4.5 KiB
Python
# cf_voice/stt.py — faster-whisper STT wrapper
|
||
#
|
||
# BSL 1.1 when real inference models are integrated.
|
||
# Requires the [inference] extras: pip install cf-voice[inference]
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
from dataclasses import dataclass
|
||
|
||
import numpy as np
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
_VRAM_ESTIMATES_MB: dict[str, int] = {
|
||
"tiny": 150, "base": 300, "small": 500,
|
||
"medium": 1500, "large": 3000, "large-v2": 3000, "large-v3": 3500,
|
||
}
|
||
|
||
# Minimum audio duration in seconds before attempting transcription.
|
||
# Whisper hallucinates on very short clips.
|
||
_MIN_DURATION_S = 0.3
|
||
|
||
|
||
@dataclass
class STTResult:
    """Result of transcribing one PCM audio chunk."""

    # Transcribed text; "" when the chunk was below the minimum duration.
    text: str
    # Language code detected by Whisper ("en" placeholder for skipped short chunks).
    language: str
    # Chunk length in seconds, derived from sample count at 16 kHz.
    duration_s: float
    # True when the chunk is >= 1 s and language detection probability > 0.5.
    is_final: bool
|
||
|
||
|
||
class WhisperSTT:
|
||
"""
|
||
Async wrapper around faster-whisper for real-time chunk transcription.
|
||
|
||
Runs transcription in a thread pool executor so it never blocks the event
|
||
loop. Maintains a rolling 50-word session prompt to improve context
|
||
continuity across 2-second windows.
|
||
|
||
Usage
|
||
-----
|
||
stt = WhisperSTT.from_env()
|
||
result = await stt.transcribe_chunk_async(pcm_int16_bytes)
|
||
print(result.text)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
model_name: str = "small",
|
||
device: str = "auto",
|
||
compute_type: str | None = None,
|
||
) -> None:
|
||
try:
|
||
from faster_whisper import WhisperModel
|
||
except ImportError as exc:
|
||
raise ImportError(
|
||
"faster-whisper is required for real STT. "
|
||
"Install with: pip install cf-voice[inference]"
|
||
) from exc
|
||
|
||
if device == "auto":
|
||
try:
|
||
import torch
|
||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
except ImportError:
|
||
device = "cpu"
|
||
|
||
if compute_type is None:
|
||
compute_type = "float16" if device == "cuda" else "int8"
|
||
|
||
logger.info("Loading Whisper %s on %s (%s)", model_name, device, compute_type)
|
||
self._model = WhisperModel(
|
||
model_name, device=device, compute_type=compute_type
|
||
)
|
||
self._device = device
|
||
self._model_name = model_name
|
||
self._session_prompt: str = ""
|
||
|
||
@classmethod
|
||
def from_env(cls) -> "WhisperSTT":
|
||
"""Construct from CF_VOICE_WHISPER_MODEL and CF_VOICE_DEVICE env vars."""
|
||
return cls(
|
||
model_name=os.environ.get("CF_VOICE_WHISPER_MODEL", "small"),
|
||
device=os.environ.get("CF_VOICE_DEVICE", "auto"),
|
||
)
|
||
|
||
@property
|
||
def vram_mb(self) -> int:
|
||
"""Estimated VRAM usage in MB for this model/compute_type combination."""
|
||
return _VRAM_ESTIMATES_MB.get(self._model_name, 1500)
|
||
|
||
def _transcribe_sync(self, audio_float32: np.ndarray) -> STTResult:
|
||
"""Synchronous transcription — always call via transcribe_chunk_async."""
|
||
duration = len(audio_float32) / 16_000.0
|
||
|
||
if duration < _MIN_DURATION_S:
|
||
return STTResult(
|
||
text="", language="en", duration_s=duration, is_final=False
|
||
)
|
||
|
||
segments, info = self._model.transcribe(
|
||
audio_float32,
|
||
language=None,
|
||
initial_prompt=self._session_prompt or None,
|
||
vad_filter=False, # silence gating happens upstream in MicVoiceIO
|
||
word_timestamps=False,
|
||
beam_size=3,
|
||
temperature=0.0,
|
||
)
|
||
|
||
text = " ".join(s.text.strip() for s in segments).strip()
|
||
|
||
# Rolling context: keep last ~50 words so the next chunk has prior text
|
||
if text:
|
||
words = (self._session_prompt + " " + text).split()
|
||
self._session_prompt = " ".join(words[-50:])
|
||
|
||
return STTResult(
|
||
text=text,
|
||
language=info.language,
|
||
duration_s=duration,
|
||
is_final=duration >= 1.0 and info.language_probability > 0.5,
|
||
)
|
||
|
||
async def transcribe_chunk_async(self, pcm_int16: bytes) -> STTResult:
|
||
"""
|
||
Transcribe a raw PCM Int16 chunk, non-blocking.
|
||
|
||
pcm_int16 should be 16kHz mono bytes. Typical input is 20 × 100ms
|
||
chunks accumulated by MicVoiceIO (2-second window = 64000 bytes).
|
||
"""
|
||
audio = (
|
||
np.frombuffer(pcm_int16, dtype=np.int16).astype(np.float32) / 32768.0
|
||
)
|
||
loop = asyncio.get_event_loop()
|
||
return await loop.run_in_executor(None, self._transcribe_sync, audio)
|
||
|
||
def reset_session(self) -> None:
|
||
"""Clear the rolling prompt. Call at the start of each new conversation."""
|
||
self._session_prompt = ""
|