feat: real inference pipeline — STT, tone classifier, diarization, mic capture
- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor, rolling 50-word session prompt for cross-chunk context continuity) - cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS - cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1; speaker_at() helper for Navigation v0.2.x wiring - cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta - cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset - cf_voice/context.py: classify_chunk() split into mock/real paths; real path decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint) - pyproject.toml: inference extras expanded (faster-whisper, sounddevice, librosa, python-dotenv) - .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK, CF_VOICE_CONFIDENCE_THRESHOLD Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper preprocessing and session prompt pattern).
This commit is contained in:
parent
6e17da9e93
commit
fed6388b99
8 changed files with 813 additions and 19 deletions
31
.env.example
Normal file
31
.env.example
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
# cf-voice environment — copy to .env and fill in values
|
||||||
|
# cf-voice itself does not auto-load .env; consumers (Linnet, Osprey, etc.)
|
||||||
|
# load it via python-dotenv in their own startup. For standalone cf-voice
|
||||||
|
# dev/testing, source this file manually or install python-dotenv.
|
||||||
|
|
||||||
|
# ── HuggingFace ───────────────────────────────────────────────────────────────
|
||||||
|
# Required for pyannote.audio speaker diarization model download.
|
||||||
|
# Get a free token at https://huggingface.co/settings/tokens
|
||||||
|
# Also accept the gated model terms at:
|
||||||
|
# https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||||
|
# https://huggingface.co/pyannote/segmentation-3.0
|
||||||
|
HF_TOKEN=
|
||||||
|
|
||||||
|
# ── Whisper STT ───────────────────────────────────────────────────────────────
|
||||||
|
# Model size: tiny | base | small | medium | large-v2 | large-v3
|
||||||
|
# Smaller = faster / less VRAM; larger = more accurate.
|
||||||
|
# Recommended: small (500MB VRAM) for real-time use.
|
||||||
|
CF_VOICE_WHISPER_MODEL=small
|
||||||
|
|
||||||
|
# ── Compute ───────────────────────────────────────────────────────────────────
|
||||||
|
# auto (detect GPU), cuda, cpu
|
||||||
|
CF_VOICE_DEVICE=auto
|
||||||
|
|
||||||
|
# ── Mock mode ─────────────────────────────────────────────────────────────────
|
||||||
|
# Set to 1 to use synthetic VoiceFrames — no GPU, mic, or HF token required.
|
||||||
|
# Unset or 0 for real audio capture.
|
||||||
|
CF_VOICE_MOCK=
|
||||||
|
|
||||||
|
# ── Tone classifier ───────────────────────────────────────────────────────────
|
||||||
|
# Minimum confidence to emit a VoiceFrame (below this = frame skipped).
|
||||||
|
CF_VOICE_CONFIDENCE_THRESHOLD=0.55
|
||||||
192
cf_voice/capture.py
Normal file
192
cf_voice/capture.py
Normal file
|
|
@ -0,0 +1,192 @@
|
||||||
|
# cf_voice/capture.py — real microphone capture
|
||||||
|
#
|
||||||
|
# MIT licensed. This layer handles audio I/O and buffering only.
|
||||||
|
# Inference (STT, classify) runs here but is imported lazily so that mock
|
||||||
|
# mode works without the [inference] extras installed.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from cf_voice.models import VoiceFrame
|
||||||
|
from cf_voice.io import VoiceIO
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SAMPLE_RATE = 16_000
|
||||||
|
_CHUNK_FRAMES = 1_600 # 100ms of audio at 16kHz
|
||||||
|
_WINDOW_CHUNKS = 20 # 20 × 100ms = 2s STT window
|
||||||
|
|
||||||
|
# Skip windows whose RMS is below this (microphone noise floor)
|
||||||
|
_SILENCE_RMS = 0.008
|
||||||
|
|
||||||
|
|
||||||
|
class MicVoiceIO(VoiceIO):
    """
    Real microphone capture producing enriched VoiceFrames.

    Pipeline per 2-second window:

      sounddevice InputStream (audio thread)
          → 100ms int16 PCM chunks → asyncio.Queue (bounded, drops on overflow)
          → accumulate _WINDOW_CHUNKS chunks (2s)
          → WhisperSTT.transcribe_chunk_async() and
            ToneClassifier.classify_async() in parallel (thread-pool executors)
          → VoiceFrame(label, confidence, speaker_id, shift_magnitude, timestamp)

    speaker_id is always "speaker_a" until Navigation v0.2.x wires in the
    Diarizer. That integration happens in ContextClassifier, not here.

    Requires: pip install cf-voice[inference]
    """

    def __init__(
        self,
        device_index: int | None = None,
    ) -> None:
        # Lazy import so that importing MicVoiceIO doesn't blow up without
        # inference deps — only instantiation requires them.
        try:
            from cf_voice.stt import WhisperSTT
            from cf_voice.classify import ToneClassifier
        except ImportError as exc:
            raise ImportError(
                "Real audio capture requires the [inference] extras. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        self._stt = WhisperSTT.from_env()
        self._classifier = ToneClassifier.from_env()
        self._device_index = device_index
        self._running = False
        # Bounded so a stalled consumer drops audio instead of growing memory.
        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=200)
        # Previous window's tone, used to derive shift_magnitude.
        self._prev_label: str = ""
        self._prev_confidence: float = 0.0

    async def stream(self) -> AsyncIterator[VoiceFrame]:  # type: ignore[override]
        """Yield one enriched VoiceFrame per non-silent 2-second window.

        Raises ImportError when sounddevice is not installed.
        """
        try:
            import sounddevice as sd
        except ImportError as exc:
            raise ImportError(
                "sounddevice is required for microphone capture. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        self._running = True
        self._stt.reset_session()

        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() here is deprecated in modern Python.
        loop = asyncio.get_running_loop()

        def _enqueue(pcm: bytes) -> None:
            # Runs on the event-loop thread, so QueueFull is raised — and can
            # be caught — here. (Catching it around call_soon_threadsafe in
            # the audio callback never fires: the exception surfaces inside
            # the scheduled callback, not at the scheduling site.)
            try:
                self._queue.put_nowait(pcm)
            except asyncio.QueueFull:
                logger.warning(
                    "Audio queue full — dropping chunk (inference too slow?)"
                )

        def _sd_callback(
            indata: np.ndarray, frames: int, time_info, status
        ) -> None:
            # Runs on sounddevice's audio thread — only thread-safe loop
            # scheduling is allowed here.
            if status:
                logger.debug("sounddevice status: %s", status)
            pcm = (indata[:, 0] * 32_767).astype(np.int16).tobytes()
            loop.call_soon_threadsafe(_enqueue, pcm)

        input_stream = sd.InputStream(
            samplerate=_SAMPLE_RATE,
            channels=1,
            dtype="float32",
            blocksize=_CHUNK_FRAMES,
            device=self._device_index,
            callback=_sd_callback,
        )

        window: list[bytes] = []
        session_start = time.monotonic()

        with input_stream:
            while self._running:
                try:
                    chunk = await asyncio.wait_for(self._queue.get(), timeout=0.5)
                except asyncio.TimeoutError:
                    # Periodic wake-up so stop() takes effect promptly.
                    continue

                window.append(chunk)

                if len(window) < _WINDOW_CHUNKS:
                    continue

                # 2-second window ready
                raw = b"".join(window)
                window.clear()

                audio = (
                    np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32_768.0
                )

                # Skip windows below the microphone noise floor
                rms = float(np.sqrt(np.mean(audio ** 2)))
                if rms < _SILENCE_RMS:
                    continue

                # Run STT and tone classification in parallel
                stt_task = asyncio.create_task(
                    self._stt.transcribe_chunk_async(raw)
                )
                tone_task = asyncio.create_task(
                    self._classifier.classify_async(audio)
                )
                stt_result, tone = await asyncio.gather(stt_task, tone_task)

                # Re-classify with the transcript as an extra signal now that
                # STT is done (cheap enough to do inline).
                if stt_result.text:
                    tone = await self._classifier.classify_async(
                        audio, stt_result.text
                    )

                shift = _compute_shift(
                    self._prev_label,
                    self._prev_confidence,
                    tone.label,
                    tone.confidence,
                )
                self._prev_label = tone.label
                self._prev_confidence = tone.confidence

                yield VoiceFrame(
                    label=tone.label,
                    confidence=tone.confidence,
                    speaker_id="speaker_a",  # Navigation v0.2.x: diarizer wired here
                    shift_magnitude=round(shift, 3),
                    timestamp=round(time.monotonic() - session_start, 2),
                )

    async def stop(self) -> None:
        """Signal stream() to exit; takes effect within one 0.5s poll cycle."""
        self._running = False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _compute_shift(
|
||||||
|
prev_label: str,
|
||||||
|
prev_confidence: float,
|
||||||
|
curr_label: str,
|
||||||
|
curr_confidence: float,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute shift_magnitude for a VoiceFrame.
|
||||||
|
|
||||||
|
0.0 when the label hasn't changed.
|
||||||
|
Higher when the label changes with high confidence in both directions.
|
||||||
|
Capped at 1.0.
|
||||||
|
"""
|
||||||
|
if not prev_label or curr_label == prev_label:
|
||||||
|
return 0.0
|
||||||
|
# A high-confidence transition in both frames = large shift
|
||||||
|
return min(1.0, (prev_confidence + curr_confidence) / 2.0)
|
||||||
245
cf_voice/classify.py
Normal file
245
cf_voice/classify.py
Normal file
|
|
@ -0,0 +1,245 @@
|
||||||
|
# cf_voice/classify.py — tone / affect classifier
|
||||||
|
#
|
||||||
|
# BSL 1.1: real inference. Requires [inference] extras.
|
||||||
|
# Stub behaviour: raises NotImplementedError if inference deps not installed.
|
||||||
|
#
|
||||||
|
# Pipeline: wav2vec2 SER (speech emotion recognition) + librosa prosody
|
||||||
|
# features → AFFECT_LABELS defined in cf_voice.events.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SAMPLE_RATE = 16_000
|
||||||
|
|
||||||
|
# Confidence floor — results below this are discarded by the caller
|
||||||
|
_DEFAULT_THRESHOLD = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55"))
|
||||||
|
|
||||||
|
# wav2vec2 SER model from HuggingFace
|
||||||
|
# ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
|
||||||
|
# Outputs 7 classes: angry, disgust, fear, happy, neutral, sadness, surprise
|
||||||
|
_SER_MODEL_ID = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
|
||||||
|
|
||||||
|
# ── Affect label mapping ──────────────────────────────────────────────────────
|
||||||
|
# Maps (emotion, prosody_profile) → affect label from cf_voice.events.AFFECT_LABELS
|
||||||
|
# Prosody profile is a tuple of flags present from _extract_prosody_flags().
|
||||||
|
|
||||||
|
_EMOTION_BASE: dict[str, str] = {
|
||||||
|
"angry": "frustrated",
|
||||||
|
"disgust": "dismissive",
|
||||||
|
"fear": "apologetic",
|
||||||
|
"happy": "warm",
|
||||||
|
"neutral": "neutral",
|
||||||
|
"sadness": "tired",
|
||||||
|
"surprise": "confused",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prosody-driven overrides: (base_affect, flag) → override affect
|
||||||
|
_PROSODY_OVERRIDES: dict[tuple[str, str], str] = {
|
||||||
|
("neutral", "fast_rate"): "genuine",
|
||||||
|
("neutral", "flat_pitch"): "scripted",
|
||||||
|
("neutral", "low_energy"): "tired",
|
||||||
|
("frustrated", "rising"): "urgent",
|
||||||
|
("warm", "rising"): "genuine",
|
||||||
|
("tired", "rising"): "optimistic",
|
||||||
|
("dismissive", "flat_pitch"): "condescending",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Affect → human-readable VoiceFrame label (reverse of events._label_to_affect)
|
||||||
|
_AFFECT_TO_LABEL: dict[str, str] = {
|
||||||
|
"neutral": "Calm and focused",
|
||||||
|
"warm": "Enthusiastic",
|
||||||
|
"frustrated": "Frustrated but contained",
|
||||||
|
"dismissive": "Politely dismissive",
|
||||||
|
"apologetic": "Nervous but cooperative",
|
||||||
|
"urgent": "Warmly impatient",
|
||||||
|
"condescending": "Politely dismissive",
|
||||||
|
"scripted": "Calm and focused", # scripted reads as neutral to the observer
|
||||||
|
"genuine": "Genuinely curious",
|
||||||
|
"confused": "Confused but engaged",
|
||||||
|
"tired": "Tired and compliant",
|
||||||
|
"optimistic": "Guardedly optimistic",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToneResult:
    """Outcome of a single ToneClassifier.classify() call."""

    label: str  # human-readable VoiceFrame label
    affect: str  # AFFECT_LABELS key
    confidence: float  # score of the winning SER class (model softmax output)
    prosody_flags: list[str] = field(default_factory=list)  # from _extract_prosody_flags()
|
||||||
|
|
||||||
|
|
||||||
|
class ToneClassifier:
    """
    Tone/affect classifier: wav2vec2 SER + librosa prosody.

    Loads the model lazily on first call to avoid import-time GPU allocation.
    Thread-safe for concurrent classify() calls — the pipeline is stateless
    per-call; session state lives in the caller (ContextClassifier).
    """

    def __init__(self, threshold: float = _DEFAULT_THRESHOLD) -> None:
        # Results below this confidence are intended to be discarded by callers.
        self._threshold = threshold
        self._pipeline = None  # lazy-loaded HF audio-classification pipeline

    @classmethod
    def from_env(cls) -> "ToneClassifier":
        """Alternate constructor reading CF_VOICE_CONFIDENCE_THRESHOLD (default 0.55)."""
        threshold = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55"))
        return cls(threshold=threshold)

    def _load_pipeline(self) -> None:
        """Load the SER pipeline once; no-op on subsequent calls.

        Raises ImportError with install guidance when transformers is missing.
        """
        if self._pipeline is not None:
            return
        try:
            from transformers import pipeline as hf_pipeline
        except ImportError as exc:
            raise ImportError(
                "transformers is required for tone classification. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        device = 0 if _cuda_available() else -1
        logger.info("Loading SER model %s on device %s", _SER_MODEL_ID, device)
        self._pipeline = hf_pipeline(
            "audio-classification",
            model=_SER_MODEL_ID,
            device=device,
        )

    def classify(self, audio_float32: np.ndarray, transcript: str = "") -> ToneResult:
        """
        Classify tone/affect from a float32 16kHz mono audio window.

        transcript is used as a weak signal for ambiguous cases (e.g. words
        like "unfortunately" bias toward apologetic even on a neutral voice).

        Raises TypeError when the audio buffer is not float32.
        """
        self._load_pipeline()

        # Explicit validation instead of `assert`: asserts are stripped
        # under `python -O`, which would let bad dtypes reach the model.
        if audio_float32.dtype != np.float32:
            raise TypeError("audio must be float32")

        # Run SER and keep the highest-scoring emotion class
        preds = self._pipeline({"raw": audio_float32, "sampling_rate": _SAMPLE_RATE})
        best = max(preds, key=lambda p: p["score"])
        emotion = best["label"].lower()
        confidence = float(best["score"])

        # Extract prosody features from raw audio
        prosody_flags = _extract_prosody_flags(audio_float32)

        # Resolve affect from base emotion + first matching prosody override
        affect = _EMOTION_BASE.get(emotion, "neutral")
        for flag in prosody_flags:
            override = _PROSODY_OVERRIDES.get((affect, flag))
            if override:
                affect = override
                break

        # Weak transcript signals (only adjust ambiguous affects)
        affect = _apply_transcript_hints(affect, transcript)

        label = _AFFECT_TO_LABEL.get(affect, "Calm and focused")
        return ToneResult(
            label=label,
            affect=affect,
            confidence=confidence,
            prosody_flags=prosody_flags,
        )

    async def classify_async(
        self, audio_float32: np.ndarray, transcript: str = ""
    ) -> ToneResult:
        """classify() without blocking the event loop."""
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() here is deprecated in modern Python.
        loop = asyncio.get_running_loop()
        fn = partial(self.classify, audio_float32, transcript)
        return await loop.run_in_executor(None, fn)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prosody helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _extract_prosody_flags(audio: np.ndarray) -> list[str]:
|
||||||
|
"""
|
||||||
|
Extract lightweight prosody flags from a float32 16kHz mono window.
|
||||||
|
Returns a list of string flags consumed by _PROSODY_OVERRIDES.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import librosa
|
||||||
|
except ImportError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
flags: list[str] = []
|
||||||
|
|
||||||
|
# Energy (RMS)
|
||||||
|
rms = float(np.sqrt(np.mean(audio ** 2)))
|
||||||
|
if rms < 0.02:
|
||||||
|
flags.append("low_energy")
|
||||||
|
elif rms > 0.15:
|
||||||
|
flags.append("high_energy")
|
||||||
|
|
||||||
|
# Speech rate approximation via zero-crossing rate
|
||||||
|
zcr = float(np.mean(librosa.feature.zero_crossing_rate(audio)))
|
||||||
|
if zcr > 0.12:
|
||||||
|
flags.append("fast_rate")
|
||||||
|
elif zcr < 0.04:
|
||||||
|
flags.append("slow_rate")
|
||||||
|
|
||||||
|
# Pitch contour via YIN
|
||||||
|
try:
|
||||||
|
f0 = librosa.yin(
|
||||||
|
audio,
|
||||||
|
fmin=librosa.note_to_hz("C2"),
|
||||||
|
fmax=librosa.note_to_hz("C7"),
|
||||||
|
sr=_SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
voiced = f0[f0 > 0]
|
||||||
|
if len(voiced) > 5:
|
||||||
|
# Rising: last quarter higher than first quarter
|
||||||
|
q = len(voiced) // 4
|
||||||
|
if q > 0 and np.mean(voiced[-q:]) > np.mean(voiced[:q]) * 1.15:
|
||||||
|
flags.append("rising")
|
||||||
|
# Flat: variance less than 15Hz
|
||||||
|
if np.std(voiced) < 15:
|
||||||
|
flags.append("flat_pitch")
|
||||||
|
except Exception:
|
||||||
|
pass # pitch extraction is best-effort
|
||||||
|
|
||||||
|
return flags
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_transcript_hints(affect: str, transcript: str) -> str:
|
||||||
|
"""
|
||||||
|
Apply weak keyword signals from transcript text to adjust affect.
|
||||||
|
Only overrides when affect is already ambiguous (neutral/tired).
|
||||||
|
"""
|
||||||
|
if not transcript or affect not in ("neutral", "tired"):
|
||||||
|
return affect
|
||||||
|
|
||||||
|
t = transcript.lower()
|
||||||
|
apologetic_words = {"sorry", "apologize", "unfortunately", "afraid", "regret"}
|
||||||
|
urgent_words = {"urgent", "immediately", "asap", "right now", "critical"}
|
||||||
|
dismissive_words = {"policy", "unable to", "cannot", "not possible", "outside"}
|
||||||
|
|
||||||
|
if any(w in t for w in apologetic_words):
|
||||||
|
return "apologetic"
|
||||||
|
if any(w in t for w in urgent_words):
|
||||||
|
return "urgent"
|
||||||
|
if any(w in t for w in dismissive_words):
|
||||||
|
return "dismissive"
|
||||||
|
|
||||||
|
return affect
|
||||||
|
|
||||||
|
|
||||||
|
def _cuda_available() -> bool:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
return torch.cuda.is_available()
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
@ -78,26 +78,26 @@ class ContextClassifier:
|
||||||
This is the request-response path used by the cf-orch endpoint.
|
This is the request-response path used by the cf-orch endpoint.
|
||||||
The streaming path (async generator) is for continuous consumers.
|
The streaming path (async generator) is for continuous consumers.
|
||||||
|
|
||||||
Stub: audio_b64 is ignored; returns synthetic events from the mock IO.
|
|
||||||
Real: decode audio, run YAMNet + SER + pyannote, return events.
|
|
||||||
|
|
||||||
elcor=True switches subtext format to Mass Effect Elcor prefix style.
|
elcor=True switches subtext format to Mass Effect Elcor prefix style.
|
||||||
Generic tone annotation is always present regardless of elcor flag.
|
Generic tone annotation is always present regardless of elcor flag.
|
||||||
"""
|
"""
|
||||||
if not isinstance(self._io, MockVoiceIO):
|
if isinstance(self._io, MockVoiceIO):
|
||||||
raise NotImplementedError(
|
return self._classify_chunk_mock(timestamp, prior_frames, elcor)
|
||||||
"classify_chunk() requires mock mode. "
|
|
||||||
"Real audio inference is not yet implemented."
|
return self._classify_chunk_real(audio_b64, timestamp, elcor)
|
||||||
)
|
|
||||||
# Generate a synthetic VoiceFrame to derive events from
|
def _classify_chunk_mock(
|
||||||
rng = self._io._rng
|
self, timestamp: float, prior_frames: int, elcor: bool
|
||||||
import time
|
) -> list[AudioEvent]:
|
||||||
label = rng.choice(self._io._labels)
|
"""Synthetic path — used in mock mode and CI."""
|
||||||
|
rng = self._io._rng # type: ignore[attr-defined]
|
||||||
|
import time as _time
|
||||||
|
label = rng.choice(self._io._labels) # type: ignore[attr-defined]
|
||||||
shift = rng.uniform(0.1, 0.7) if prior_frames > 0 else 0.0
|
shift = rng.uniform(0.1, 0.7) if prior_frames > 0 else 0.0
|
||||||
frame = VoiceFrame(
|
frame = VoiceFrame(
|
||||||
label=label,
|
label=label,
|
||||||
confidence=rng.uniform(0.6, 0.97),
|
confidence=rng.uniform(0.6, 0.97),
|
||||||
speaker_id=rng.choice(self._io._speakers),
|
speaker_id=rng.choice(self._io._speakers), # type: ignore[attr-defined]
|
||||||
shift_magnitude=round(shift, 3),
|
shift_magnitude=round(shift, 3),
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
|
|
@ -110,6 +110,38 @@ class ContextClassifier:
|
||||||
)
|
)
|
||||||
return [tone]
|
return [tone]
|
||||||
|
|
||||||
|
def _classify_chunk_real(
    self, audio_b64: str, timestamp: float, elcor: bool
) -> list[AudioEvent]:
    """Real inference path — used when CF_VOICE_MOCK is unset.

    Decodes base64 int16 PCM, classifies tone synchronously, and returns
    a single tone event derived from the resulting VoiceFrame.
    """
    # Imports kept local so mock mode never touches the inference stack.
    # (The previous version also imported asyncio here but never used it.)
    import base64

    import numpy as np

    from cf_voice.classify import ToneClassifier

    pcm = base64.b64decode(audio_b64)
    # assumes audio_b64 encodes 16kHz mono int16 PCM — TODO confirm with callers
    audio = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32_768.0

    # Reuse one ToneClassifier across calls. Instantiating per request would
    # re-load the SER model every time (the lazy pipeline cache lives on the
    # instance), making each endpoint call pay the full model-load cost.
    classifier = getattr(self, "_real_classifier", None)
    if classifier is None:
        classifier = ToneClassifier.from_env()
        self._real_classifier = classifier
    tone_result = classifier.classify(audio)

    frame = VoiceFrame(
        label=tone_result.label,
        confidence=tone_result.confidence,
        speaker_id="speaker_a",  # Navigation v0.2.x: diarizer wired here
        shift_magnitude=0.0,  # single-shot call: no previous frame to diff against
        timestamp=timestamp,
    )
    event = tone_event_from_voice_frame(
        frame_label=frame.label,
        frame_confidence=frame.confidence,
        shift_magnitude=frame.shift_magnitude,
        timestamp=frame.timestamp,
        elcor=elcor,
    )
    return [event]
|
||||||
|
|
||||||
def _enrich(self, frame: VoiceFrame) -> VoiceFrame:
|
def _enrich(self, frame: VoiceFrame) -> VoiceFrame:
|
||||||
"""
|
"""
|
||||||
Apply tone classification to a raw frame.
|
Apply tone classification to a raw frame.
|
||||||
|
|
|
||||||
146
cf_voice/diarize.py
Normal file
146
cf_voice/diarize.py
Normal file
|
|
@ -0,0 +1,146 @@
|
||||||
|
# cf_voice/diarize.py — async speaker diarization via pyannote.audio
|
||||||
|
#
|
||||||
|
# BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access.
|
||||||
|
# Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI.
|
||||||
|
#
|
||||||
|
# Model used: pyannote/speaker-diarization-3.1
|
||||||
|
# Requires accepting gated model terms at:
|
||||||
|
# https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||||
|
# https://huggingface.co/pyannote/segmentation-3.0
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
|
||||||
|
_SAMPLE_RATE = 16_000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SpeakerSegment:
    """A speaker-labelled time range within an audio window."""
    speaker_id: str  # ephemeral local label, e.g. "SPEAKER_00"
    start_s: float  # segment start in seconds (rounded to 3 decimals by Diarizer)
    end_s: float  # segment end in seconds (rounded to 3 decimals by Diarizer)

    @property
    def duration_s(self) -> float:
        """Length of the segment in seconds."""
        return self.end_s - self.start_s
|
||||||
|
|
||||||
|
|
||||||
|
class Diarizer:
    """
    Async wrapper around pyannote.audio speaker diarization pipeline.

    PyAnnote's pipeline is synchronous and GPU-bound. We run it in a thread
    pool executor to avoid blocking the asyncio event loop.

    The pipeline is loaded once at construction time (model download on first
    use, then cached by HuggingFace). CUDA is used automatically if available.

    Usage
    -----
    diarizer = Diarizer.from_env()
    segments = await diarizer.diarize_async(audio_float32)
    for seg in segments:
        print(seg.speaker_id, seg.start_s, seg.end_s)

    Navigation v0.2.x wires this into ContextClassifier so that each
    VoiceFrame carries the correct speaker_id from diarization output.
    """

    def __init__(self, hf_token: str) -> None:
        try:
            from pyannote.audio import Pipeline
        except ImportError as exc:
            raise ImportError(
                "pyannote.audio is required for speaker diarization. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL)
        self._pipeline = Pipeline.from_pretrained(
            _DIARIZATION_MODEL,
            use_auth_token=hf_token,
        )

        # Move to GPU if available; stay silently on CPU when torch is absent
        try:
            import torch
            if torch.cuda.is_available():
                self._pipeline.to(torch.device("cuda"))
                logger.info("Diarization pipeline on CUDA")
        except ImportError:
            pass

    @classmethod
    def from_env(cls) -> "Diarizer":
        """Construct from HF_TOKEN environment variable.

        Raises EnvironmentError when HF_TOKEN is unset or blank.
        """
        token = os.environ.get("HF_TOKEN", "").strip()
        if not token:
            raise EnvironmentError(
                "HF_TOKEN is required for speaker diarization. "
                "Set it in your .env or environment. "
                "See .env.example for setup instructions."
            )
        return cls(hf_token=token)

    def _diarize_sync(
        self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE
    ) -> list[SpeakerSegment]:
        """Synchronous diarization — always call via diarize_async."""
        import torch

        # pyannote expects (channels, samples) float32 tensor
        waveform = torch.from_numpy(audio_float32[np.newaxis, :].astype(np.float32))
        diarization = self._pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )

        segments: list[SpeakerSegment] = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                SpeakerSegment(
                    speaker_id=speaker,
                    start_s=round(turn.start, 3),
                    end_s=round(turn.end, 3),
                )
            )
        # Sort explicitly instead of relying on itertracks iteration order,
        # so the "ordered by start_s" contract below always holds.
        segments.sort(key=lambda seg: seg.start_s)
        return segments

    async def diarize_async(
        self,
        audio_float32: np.ndarray,
        sample_rate: int = _SAMPLE_RATE,
    ) -> list[SpeakerSegment]:
        """
        Diarize an audio window without blocking the event loop.

        audio_float32 should be 16kHz mono float32.
        Typical input is a 2-second window from MicVoiceIO (32000 samples).
        Returns segments ordered by start_s.
        """
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() here is deprecated in modern Python.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._diarize_sync, audio_float32, sample_rate
        )

    def speaker_at(
        self, segments: list[SpeakerSegment], timestamp_s: float
    ) -> str:
        """
        Return the speaker_id active at a given timestamp within the window.

        When segments overlap, the first covering segment in list order wins.
        Falls back to "speaker_a" if no segment covers the timestamp
        (e.g. during silence or at window boundaries).
        """
        for seg in segments:
            if seg.start_s <= timestamp_s <= seg.end_s:
                return seg.speaker_id
        return "speaker_a"
|
||||||
|
|
@ -106,17 +106,17 @@ class MockVoiceIO(VoiceIO):
|
||||||
def make_io(
|
def make_io(
|
||||||
mock: bool | None = None,
|
mock: bool | None = None,
|
||||||
interval_s: float = 2.5,
|
interval_s: float = 2.5,
|
||||||
|
device_index: int | None = None,
|
||||||
) -> VoiceIO:
|
) -> VoiceIO:
|
||||||
"""
|
"""
|
||||||
Factory: return a VoiceIO instance appropriate for the current environment.
|
Factory: return a VoiceIO instance appropriate for the current environment.
|
||||||
|
|
||||||
mock=True or CF_VOICE_MOCK=1 → MockVoiceIO (no audio hardware needed)
|
mock=True or CF_VOICE_MOCK=1 → MockVoiceIO (no GPU, mic, or HF token needed)
|
||||||
Otherwise → real audio capture (not yet implemented)
|
Otherwise → MicVoiceIO (requires [inference] extras)
|
||||||
"""
|
"""
|
||||||
use_mock = mock if mock is not None else os.environ.get("CF_VOICE_MOCK", "") == "1"
|
use_mock = mock if mock is not None else os.environ.get("CF_VOICE_MOCK", "") == "1"
|
||||||
if use_mock:
|
if use_mock:
|
||||||
return MockVoiceIO(interval_s=interval_s)
|
return MockVoiceIO(interval_s=interval_s)
|
||||||
raise NotImplementedError(
|
|
||||||
"Real audio capture is not yet implemented. "
|
from cf_voice.capture import MicVoiceIO
|
||||||
"Set CF_VOICE_MOCK=1 or pass mock=True to use synthetic frames."
|
return MicVoiceIO(device_index=device_index)
|
||||||
)
|
|
||||||
|
|
|
||||||
142
cf_voice/stt.py
Normal file
142
cf_voice/stt.py
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
# cf_voice/stt.py — faster-whisper STT wrapper
|
||||||
|
#
|
||||||
|
# BSL 1.1 when real inference models are integrated.
|
||||||
|
# Requires the [inference] extras: pip install cf-voice[inference]
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VRAM_ESTIMATES_MB: dict[str, int] = {
|
||||||
|
"tiny": 150, "base": 300, "small": 500,
|
||||||
|
"medium": 1500, "large": 3000, "large-v2": 3000, "large-v3": 3500,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Minimum audio duration in seconds before attempting transcription.
|
||||||
|
# Whisper hallucinates on very short clips.
|
||||||
|
_MIN_DURATION_S = 0.3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class STTResult:
    """Result of transcribing a single audio chunk."""

    # Transcribed text; "" when the chunk was too short to transcribe.
    text: str
    # Detected language code (defaults to "en" for skipped short chunks).
    language: str
    # Chunk duration in seconds (sample count / 16 kHz).
    duration_s: float
    # True when the chunk was long enough and language detection was confident.
    is_final: bool
||||||
|
|
||||||
|
class WhisperSTT:
    """
    Async wrapper around faster-whisper for real-time chunk transcription.

    Runs transcription in a thread pool executor so it never blocks the event
    loop. Maintains a rolling 50-word session prompt to improve context
    continuity across 2-second windows.

    Usage
    -----
    stt = WhisperSTT.from_env()
    result = await stt.transcribe_chunk_async(pcm_int16_bytes)
    print(result.text)
    """

    def __init__(
        self,
        model_name: str = "small",
        device: str = "auto",
        compute_type: str | None = None,
    ) -> None:
        """
        Load the faster-whisper model.

        Parameters
        ----------
        model_name:
            Whisper checkpoint name (e.g. "tiny", "small", "large-v3").
        device:
            "cuda", "cpu", or "auto" (probes torch for CUDA availability;
            falls back to "cpu" when torch is not installed).
        compute_type:
            ctranslate2 compute type; defaults to "float16" on CUDA and
            "int8" on CPU when not given.

        Raises
        ------
        ImportError
            If faster-whisper (the [inference] extras) is not installed.
        """
        try:
            from faster_whisper import WhisperModel
        except ImportError as exc:
            raise ImportError(
                "faster-whisper is required for real STT. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        if device == "auto":
            try:
                import torch

                device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                # torch absent entirely → CPU-only ctranslate2 install.
                device = "cpu"

        if compute_type is None:
            compute_type = "float16" if device == "cuda" else "int8"

        logger.info("Loading Whisper %s on %s (%s)", model_name, device, compute_type)
        self._model = WhisperModel(
            model_name, device=device, compute_type=compute_type
        )
        self._device = device
        self._model_name = model_name
        # Rolling tail of recent transcript text, fed back as initial_prompt
        # on the next chunk for cross-chunk context continuity.
        self._session_prompt: str = ""

    @classmethod
    def from_env(cls) -> "WhisperSTT":
        """Construct from CF_VOICE_WHISPER_MODEL and CF_VOICE_DEVICE env vars."""
        return cls(
            model_name=os.environ.get("CF_VOICE_WHISPER_MODEL", "small"),
            device=os.environ.get("CF_VOICE_DEVICE", "auto"),
        )

    @property
    def vram_mb(self) -> int:
        """Estimated VRAM usage in MB for this model/compute_type combination."""
        return _VRAM_ESTIMATES_MB.get(self._model_name, 1500)

    def _transcribe_sync(self, audio_float32: np.ndarray) -> STTResult:
        """Synchronous transcription — always call via transcribe_chunk_async."""
        duration = len(audio_float32) / 16_000.0

        # Whisper hallucinates on very short clips; skip them outright.
        if duration < _MIN_DURATION_S:
            return STTResult(
                text="", language="en", duration_s=duration, is_final=False
            )

        segments, info = self._model.transcribe(
            audio_float32,
            language=None,
            initial_prompt=self._session_prompt or None,
            vad_filter=False,  # silence gating happens upstream in MicVoiceIO
            word_timestamps=False,
            beam_size=3,
            temperature=0.0,
        )

        # segments is a lazy generator; joining drives the actual decode.
        text = " ".join(s.text.strip() for s in segments).strip()

        # Rolling context: keep last ~50 words so the next chunk has prior text
        if text:
            words = (self._session_prompt + " " + text).split()
            self._session_prompt = " ".join(words[-50:])

        return STTResult(
            text=text,
            language=info.language,
            duration_s=duration,
            is_final=duration >= 1.0 and info.language_probability > 0.5,
        )

    async def transcribe_chunk_async(self, pcm_int16: bytes) -> STTResult:
        """
        Transcribe a raw PCM Int16 chunk, non-blocking.

        pcm_int16 should be 16kHz mono bytes. Typical input is 20 × 100ms
        chunks accumulated by MicVoiceIO (2-second window = 64000 bytes).
        """
        audio = (
            np.frombuffer(pcm_int16, dtype=np.int16).astype(np.float32) / 32768.0
        )
        # get_running_loop(), not the deprecated get_event_loop(): this
        # coroutine is always awaited from inside a running loop, and
        # get_event_loop() warns (and changes behavior on 3.12+) there.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio)

    def reset_session(self) -> None:
        """Clear the rolling prompt. Call at the start of each new conversation."""
        self._session_prompt = ""
|
@ -18,12 +18,18 @@ dependencies = [
|
||||||
inference = [
|
inference = [
|
||||||
"torch>=2.0",
|
"torch>=2.0",
|
||||||
"torchaudio>=2.0",
|
"torchaudio>=2.0",
|
||||||
|
"numpy>=1.24",
|
||||||
|
"faster-whisper>=1.0",
|
||||||
|
"sounddevice>=0.4",
|
||||||
"transformers>=4.40",
|
"transformers>=4.40",
|
||||||
|
"librosa>=0.10",
|
||||||
"pyannote.audio>=3.1",
|
"pyannote.audio>=3.1",
|
||||||
|
"python-dotenv>=1.0",
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=8.0",
|
"pytest>=8.0",
|
||||||
"pytest-asyncio>=0.23",
|
"pytest-asyncio>=0.23",
|
||||||
|
"numpy>=1.24",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue