cf-voice/cf_voice/classify.py
pyr0ball fed6388b99 feat: real inference pipeline — STT, tone classifier, diarization, mic capture
- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor,
  rolling 50-word session prompt for cross-chunk context continuity)
- cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags
  (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS
- cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1;
  speaker_at() helper for Navigation v0.2.x wiring
- cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window
  accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta
- cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset
- cf_voice/context.py: classify_chunk() split into mock/real paths; real path
  decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint)
- pyproject.toml: inference extras expanded (faster-whisper, sounddevice,
  librosa, python-dotenv)
- .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK,
  CF_VOICE_CONFIDENCE_THRESHOLD

Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote
setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper
preprocessing and session prompt pattern).
2026-04-06 17:33:51 -07:00

245 lines
8.4 KiB
Python

# cf_voice/classify.py — tone / affect classifier
#
# BSL 1.1: real inference. Requires [inference] extras.
# Stub behaviour: raises NotImplementedError if inference deps not installed.
#
# Pipeline: wav2vec2 SER (speech emotion recognition) + librosa prosody
# features → AFFECT_LABELS defined in cf_voice.events.
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass, field
from functools import partial
import numpy as np
logger = logging.getLogger(__name__)

# All audio entering this module is 16 kHz mono (matches MicVoiceIO capture).
_SAMPLE_RATE = 16_000

# Confidence floor — results below this are discarded by the caller.
_DEFAULT_THRESHOLD = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55"))

# wav2vec2 SER model from HuggingFace:
# ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
# Outputs 7 classes: angry, disgust, fear, happy, neutral, sadness, surprise
_SER_MODEL_ID = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
# ── Affect label mapping ──────────────────────────────────────────────────────
# Maps (emotion, prosody_profile) → affect label from cf_voice.events.AFFECT_LABELS
# Prosody profile is a tuple of flags present from _extract_prosody_flags().

# Base affect per SER emotion class; unknown classes fall back to "neutral"
# in classify().
_EMOTION_BASE: dict[str, str] = {
    "angry": "frustrated",
    "disgust": "dismissive",
    "fear": "apologetic",
    "happy": "warm",
    "neutral": "neutral",
    "sadness": "tired",
    "surprise": "confused",
}

# Prosody-driven overrides: (base_affect, flag) → override affect.
# classify() applies the FIRST matching override only (loop breaks on match).
_PROSODY_OVERRIDES: dict[tuple[str, str], str] = {
    ("neutral", "fast_rate"): "genuine",
    ("neutral", "flat_pitch"): "scripted",
    ("neutral", "low_energy"): "tired",
    ("frustrated", "rising"): "urgent",
    ("warm", "rising"): "genuine",
    ("tired", "rising"): "optimistic",
    ("dismissive", "flat_pitch"): "condescending",
}

# Affect → human-readable VoiceFrame label (reverse of events._label_to_affect).
# Several affects intentionally share a label (e.g. condescending/dismissive).
_AFFECT_TO_LABEL: dict[str, str] = {
    "neutral": "Calm and focused",
    "warm": "Enthusiastic",
    "frustrated": "Frustrated but contained",
    "dismissive": "Politely dismissive",
    "apologetic": "Nervous but cooperative",
    "urgent": "Warmly impatient",
    "condescending": "Politely dismissive",
    "scripted": "Calm and focused",  # scripted reads as neutral to the observer
    "genuine": "Genuinely curious",
    "confused": "Confused but engaged",
    "tired": "Tired and compliant",
    "optimistic": "Guardedly optimistic",
}
@dataclass
class ToneResult:
    """Result of classifying one audio window; produced by ToneClassifier.classify()."""

    label: str  # human-readable VoiceFrame label (a value of _AFFECT_TO_LABEL)
    affect: str  # AFFECT_LABELS key
    confidence: float  # score of the top SER prediction, in [0, 1]
    prosody_flags: list[str] = field(default_factory=list)  # from _extract_prosody_flags()
class ToneClassifier:
    """
    Tone/affect classifier: wav2vec2 SER + librosa prosody.

    Loads the model lazily on first call to avoid import-time GPU allocation.
    The pipeline is stateless per-call; session state lives in the caller
    (ContextClassifier).

    NOTE(review): lazy loading in _load_pipeline() is unguarded — two threads
    racing into the first classify() may each load the model (wasteful but not
    incorrect). Confirm callers serialize the first call before relying on the
    thread-safety claim.
    """

    def __init__(self, threshold: float = _DEFAULT_THRESHOLD) -> None:
        # threshold: confidence floor; enforcement is the caller's job.
        self._threshold = threshold
        self._pipeline = None  # lazy-loaded HF audio-classification pipeline

    @classmethod
    def from_env(cls) -> "ToneClassifier":
        """Alternate constructor reading CF_VOICE_CONFIDENCE_THRESHOLD from the env."""
        threshold = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55"))
        return cls(threshold=threshold)

    def _load_pipeline(self) -> None:
        """Load the SER pipeline on first use; no-op on subsequent calls.

        Raises:
            ImportError: if transformers ([inference] extras) is not installed.
        """
        if self._pipeline is not None:
            return
        try:
            from transformers import pipeline as hf_pipeline
        except ImportError as exc:
            raise ImportError(
                "transformers is required for tone classification. "
                "Install with: pip install cf-voice[inference]"
            ) from exc
        # HF pipeline convention: device 0 = first CUDA GPU, -1 = CPU.
        device = 0 if _cuda_available() else -1
        logger.info("Loading SER model %s on device %s", _SER_MODEL_ID, device)
        self._pipeline = hf_pipeline(
            "audio-classification",
            model=_SER_MODEL_ID,
            device=device,
        )

    def classify(self, audio_float32: np.ndarray, transcript: str = "") -> ToneResult:
        """
        Classify tone/affect from a float32 16kHz mono audio window.

        transcript is used as a weak signal for ambiguous cases (e.g. words
        like "unfortunately" bias toward apologetic even on a neutral voice).

        Raises:
            ValueError: if audio is not float32. (Explicit raise rather than
                ``assert``, which is stripped under ``python -O``.)
            ImportError: if the [inference] extras are not installed.
        """
        self._load_pipeline()
        # The model expects float32 PCM at _SAMPLE_RATE; fail loudly otherwise.
        if audio_float32.dtype != np.float32:
            raise ValueError(f"audio must be float32, got {audio_float32.dtype}")
        # Run SER and keep the single highest-scoring emotion class.
        preds = self._pipeline({"raw": audio_float32, "sampling_rate": _SAMPLE_RATE})
        best = max(preds, key=lambda p: p["score"])
        emotion = best["label"].lower()
        confidence = float(best["score"])
        # Lightweight prosody flags (energy / rate / pitch) from the raw audio.
        prosody_flags = _extract_prosody_flags(audio_float32)
        # Base affect from emotion; first matching prosody override wins.
        affect = _EMOTION_BASE.get(emotion, "neutral")
        for flag in prosody_flags:
            override = _PROSODY_OVERRIDES.get((affect, flag))
            if override:
                affect = override
                break
        # Weak transcript keyword signals (only adjusts ambiguous affects).
        affect = _apply_transcript_hints(affect, transcript)
        label = _AFFECT_TO_LABEL.get(affect, "Calm and focused")
        return ToneResult(
            label=label,
            affect=affect,
            confidence=confidence,
            prosody_flags=prosody_flags,
        )

    async def classify_async(
        self, audio_float32: np.ndarray, transcript: str = ""
    ) -> ToneResult:
        """classify() without blocking the event loop."""
        # get_running_loop() instead of the deprecated get_event_loop():
        # this coroutine is only ever awaited from inside a running loop,
        # and get_event_loop() is deprecated in that context since 3.10.
        loop = asyncio.get_running_loop()
        fn = partial(self.classify, audio_float32, transcript)
        return await loop.run_in_executor(None, fn)
# ── Prosody helpers ───────────────────────────────────────────────────────────
def _extract_prosody_flags(audio: np.ndarray) -> list[str]:
"""
Extract lightweight prosody flags from a float32 16kHz mono window.
Returns a list of string flags consumed by _PROSODY_OVERRIDES.
"""
try:
import librosa
except ImportError:
return []
flags: list[str] = []
# Energy (RMS)
rms = float(np.sqrt(np.mean(audio ** 2)))
if rms < 0.02:
flags.append("low_energy")
elif rms > 0.15:
flags.append("high_energy")
# Speech rate approximation via zero-crossing rate
zcr = float(np.mean(librosa.feature.zero_crossing_rate(audio)))
if zcr > 0.12:
flags.append("fast_rate")
elif zcr < 0.04:
flags.append("slow_rate")
# Pitch contour via YIN
try:
f0 = librosa.yin(
audio,
fmin=librosa.note_to_hz("C2"),
fmax=librosa.note_to_hz("C7"),
sr=_SAMPLE_RATE,
)
voiced = f0[f0 > 0]
if len(voiced) > 5:
# Rising: last quarter higher than first quarter
q = len(voiced) // 4
if q > 0 and np.mean(voiced[-q:]) > np.mean(voiced[:q]) * 1.15:
flags.append("rising")
# Flat: variance less than 15Hz
if np.std(voiced) < 15:
flags.append("flat_pitch")
except Exception:
pass # pitch extraction is best-effort
return flags
def _apply_transcript_hints(affect: str, transcript: str) -> str:
"""
Apply weak keyword signals from transcript text to adjust affect.
Only overrides when affect is already ambiguous (neutral/tired).
"""
if not transcript or affect not in ("neutral", "tired"):
return affect
t = transcript.lower()
apologetic_words = {"sorry", "apologize", "unfortunately", "afraid", "regret"}
urgent_words = {"urgent", "immediately", "asap", "right now", "critical"}
dismissive_words = {"policy", "unable to", "cannot", "not possible", "outside"}
if any(w in t for w in apologetic_words):
return "apologetic"
if any(w in t for w in urgent_words):
return "urgent"
if any(w in t for w in dismissive_words):
return "dismissive"
return affect
def _cuda_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False