feat: real inference pipeline — STT, tone classifier, diarization, mic capture
- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor, rolling 50-word session prompt for cross-chunk context continuity) - cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS - cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1; speaker_at() helper for Navigation v0.2.x wiring - cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta - cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset - cf_voice/context.py: classify_chunk() split into mock/real paths; real path decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint) - pyproject.toml: inference extras expanded (faster-whisper, sounddevice, librosa, python-dotenv) - .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK, CF_VOICE_CONFIDENCE_THRESHOLD Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper preprocessing and session prompt pattern).
This commit is contained in:
parent
6e17da9e93
commit
fed6388b99
8 changed files with 813 additions and 19 deletions
31
.env.example
Normal file
31
.env.example
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# cf-voice environment — copy to .env and fill in values
|
||||
# cf-voice itself does not auto-load .env; consumers (Linnet, Osprey, etc.)
|
||||
# load it via python-dotenv in their own startup. For standalone cf-voice
|
||||
# dev/testing, source this file manually or install python-dotenv.
|
||||
|
||||
# ── HuggingFace ───────────────────────────────────────────────────────────────
|
||||
# Required for pyannote.audio speaker diarization model download.
|
||||
# Get a free token at https://huggingface.co/settings/tokens
|
||||
# Also accept the gated model terms at:
|
||||
# https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||
# https://huggingface.co/pyannote/segmentation-3.0
|
||||
HF_TOKEN=
|
||||
|
||||
# ── Whisper STT ───────────────────────────────────────────────────────────────
|
||||
# Model size: tiny | base | small | medium | large-v2 | large-v3
|
||||
# Smaller = faster / less VRAM; larger = more accurate.
|
||||
# Recommended: small (500MB VRAM) for real-time use.
|
||||
CF_VOICE_WHISPER_MODEL=small
|
||||
|
||||
# ── Compute ───────────────────────────────────────────────────────────────────
|
||||
# auto (detect GPU), cuda, cpu
|
||||
CF_VOICE_DEVICE=auto
|
||||
|
||||
# ── Mock mode ─────────────────────────────────────────────────────────────────
|
||||
# Set to 1 to use synthetic VoiceFrames — no GPU, mic, or HF token required.
|
||||
# Unset or 0 for real audio capture.
|
||||
CF_VOICE_MOCK=
|
||||
|
||||
# ── Tone classifier ───────────────────────────────────────────────────────────
|
||||
# Minimum confidence to emit a VoiceFrame (below this = frame skipped).
|
||||
CF_VOICE_CONFIDENCE_THRESHOLD=0.55
|
||||
192
cf_voice/capture.py
Normal file
192
cf_voice/capture.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
# cf_voice/capture.py — real microphone capture
|
||||
#
|
||||
# MIT licensed. This layer handles audio I/O and buffering only.
|
||||
# Inference (STT, classify) runs here but is imported lazily so that mock
|
||||
# mode works without the [inference] extras installed.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cf_voice.models import VoiceFrame
|
||||
from cf_voice.io import VoiceIO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SAMPLE_RATE = 16_000
|
||||
_CHUNK_FRAMES = 1_600 # 100ms of audio at 16kHz
|
||||
_WINDOW_CHUNKS = 20 # 20 × 100ms = 2s STT window
|
||||
|
||||
# Skip windows whose RMS is below this (microphone noise floor)
|
||||
_SILENCE_RMS = 0.008
|
||||
|
||||
|
||||
class MicVoiceIO(VoiceIO):
    """
    Real microphone capture producing enriched VoiceFrames.

    Capture loop (sounddevice callback → asyncio.Queue):
      the sounddevice InputStream delivers 100ms PCM Int16 chunks into
      ``_queue``; ``stream()`` accumulates 20 chunks (a 2s window), then runs
      ``WhisperSTT.transcribe_chunk_async()`` and
      ``ToneClassifier.classify_async()`` in parallel (both in the default
      thread pool) and combines the results into a
      ``VoiceFrame(label, confidence, speaker_id, shift_magnitude, ...)``.

    speaker_id is always "speaker_a" until Navigation v0.2.x wires in the
    Diarizer. That integration happens in ContextClassifier, not here.

    Requires: pip install cf-voice[inference]
    """

    def __init__(
        self,
        device_index: int | None = None,
    ) -> None:
        # Lazy import so that importing MicVoiceIO doesn't blow up without
        # inference deps — only instantiation requires them.
        try:
            from cf_voice.stt import WhisperSTT
            from cf_voice.classify import ToneClassifier
        except ImportError as exc:
            raise ImportError(
                "Real audio capture requires the [inference] extras. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        self._stt = WhisperSTT.from_env()
        self._classifier = ToneClassifier.from_env()
        self._device_index = device_index
        self._running = False
        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=200)
        self._prev_label: str = ""
        self._prev_confidence: float = 0.0

    def _enqueue_pcm(self, pcm: bytes) -> None:
        """Enqueue one PCM chunk; always runs on the event-loop thread.

        BUG FIX: QueueFull must be caught *here*. The previous code wrapped
        ``loop.call_soon_threadsafe(...)`` in ``try/except asyncio.QueueFull``,
        but ``put_nowait`` only executes later on the loop thread, so that
        except clause could never fire — a full queue raised inside the loop
        callback instead of being handled.
        """
        try:
            self._queue.put_nowait(pcm)
        except asyncio.QueueFull:
            logger.warning("Audio queue full — dropping chunk (inference too slow?)")

    async def stream(self) -> AsyncIterator[VoiceFrame]:  # type: ignore[override]
        """Yield VoiceFrames from live microphone audio until stop() is called.

        Raises ImportError when sounddevice (part of the [inference] extras)
        is not installed.
        """
        try:
            import sounddevice as sd
        except ImportError as exc:
            raise ImportError(
                "sounddevice is required for microphone capture. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        self._running = True
        self._stt.reset_session()

        # stream() is a coroutine, so a running loop is guaranteed;
        # asyncio.get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()

        def _sd_callback(
            indata: np.ndarray, frames: int, time_info, status
        ) -> None:
            # Runs on sounddevice's audio thread — never touch asyncio
            # objects directly here; hand off to the loop thread instead.
            if status:
                logger.debug("sounddevice status: %s", status)
            pcm = (indata[:, 0] * 32_767).astype(np.int16).tobytes()
            loop.call_soon_threadsafe(self._enqueue_pcm, pcm)

        input_stream = sd.InputStream(
            samplerate=_SAMPLE_RATE,
            channels=1,
            dtype="float32",
            blocksize=_CHUNK_FRAMES,
            device=self._device_index,
            callback=_sd_callback,
        )

        window: list[bytes] = []
        session_start = time.monotonic()

        with input_stream:
            while self._running:
                try:
                    # Short timeout so stop() is noticed promptly even during
                    # silence (no chunks arriving).
                    chunk = await asyncio.wait_for(self._queue.get(), timeout=0.5)
                except asyncio.TimeoutError:
                    continue

                window.append(chunk)

                if len(window) < _WINDOW_CHUNKS:
                    continue

                # 2-second window ready
                raw = b"".join(window)
                window.clear()

                audio = (
                    np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32_768.0
                )

                # Skip windows below the microphone noise floor
                rms = float(np.sqrt(np.mean(audio ** 2)))
                if rms < _SILENCE_RMS:
                    continue

                # Run STT and tone classification in parallel; gather wraps
                # both coroutines in tasks for us.
                stt_result, tone = await asyncio.gather(
                    self._stt.transcribe_chunk_async(raw),
                    self._classifier.classify_async(audio),
                )

                # Re-classify with the transcript as a weak extra signal now
                # that STT is done (cheap enough to do inline).
                if stt_result.text:
                    tone = await self._classifier.classify_async(
                        audio, stt_result.text
                    )

                shift = _compute_shift(
                    self._prev_label,
                    self._prev_confidence,
                    tone.label,
                    tone.confidence,
                )
                self._prev_label = tone.label
                self._prev_confidence = tone.confidence

                yield VoiceFrame(
                    label=tone.label,
                    confidence=tone.confidence,
                    speaker_id="speaker_a",  # Navigation v0.2.x: diarizer wired here
                    shift_magnitude=round(shift, 3),
                    timestamp=round(time.monotonic() - session_start, 2),
                )

    async def stop(self) -> None:
        """Signal stream() to exit; takes effect within one queue timeout (0.5s)."""
        self._running = False
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _compute_shift(
|
||||
prev_label: str,
|
||||
prev_confidence: float,
|
||||
curr_label: str,
|
||||
curr_confidence: float,
|
||||
) -> float:
|
||||
"""
|
||||
Compute shift_magnitude for a VoiceFrame.
|
||||
|
||||
0.0 when the label hasn't changed.
|
||||
Higher when the label changes with high confidence in both directions.
|
||||
Capped at 1.0.
|
||||
"""
|
||||
if not prev_label or curr_label == prev_label:
|
||||
return 0.0
|
||||
# A high-confidence transition in both frames = large shift
|
||||
return min(1.0, (prev_confidence + curr_confidence) / 2.0)
|
||||
245
cf_voice/classify.py
Normal file
245
cf_voice/classify.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
# cf_voice/classify.py — tone / affect classifier
|
||||
#
|
||||
# BSL 1.1: real inference. Requires [inference] extras.
|
||||
# Stub behaviour: raises NotImplementedError if inference deps not installed.
|
||||
#
|
||||
# Pipeline: wav2vec2 SER (speech emotion recognition) + librosa prosody
|
||||
# features → AFFECT_LABELS defined in cf_voice.events.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SAMPLE_RATE = 16_000
|
||||
|
||||
# Confidence floor — results below this are discarded by the caller
|
||||
_DEFAULT_THRESHOLD = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55"))
|
||||
|
||||
# wav2vec2 SER model from HuggingFace
|
||||
# ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
|
||||
# Outputs 7 classes: angry, disgust, fear, happy, neutral, sadness, surprise
|
||||
_SER_MODEL_ID = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
|
||||
|
||||
# ── Affect label mapping ──────────────────────────────────────────────────────
|
||||
# Maps (emotion, prosody_profile) → affect label from cf_voice.events.AFFECT_LABELS
|
||||
# Prosody profile is a tuple of flags present from _extract_prosody_flags().
|
||||
|
||||
_EMOTION_BASE: dict[str, str] = {
|
||||
"angry": "frustrated",
|
||||
"disgust": "dismissive",
|
||||
"fear": "apologetic",
|
||||
"happy": "warm",
|
||||
"neutral": "neutral",
|
||||
"sadness": "tired",
|
||||
"surprise": "confused",
|
||||
}
|
||||
|
||||
# Prosody-driven overrides: (base_affect, flag) → override affect
|
||||
_PROSODY_OVERRIDES: dict[tuple[str, str], str] = {
|
||||
("neutral", "fast_rate"): "genuine",
|
||||
("neutral", "flat_pitch"): "scripted",
|
||||
("neutral", "low_energy"): "tired",
|
||||
("frustrated", "rising"): "urgent",
|
||||
("warm", "rising"): "genuine",
|
||||
("tired", "rising"): "optimistic",
|
||||
("dismissive", "flat_pitch"): "condescending",
|
||||
}
|
||||
|
||||
# Affect → human-readable VoiceFrame label (reverse of events._label_to_affect)
|
||||
_AFFECT_TO_LABEL: dict[str, str] = {
|
||||
"neutral": "Calm and focused",
|
||||
"warm": "Enthusiastic",
|
||||
"frustrated": "Frustrated but contained",
|
||||
"dismissive": "Politely dismissive",
|
||||
"apologetic": "Nervous but cooperative",
|
||||
"urgent": "Warmly impatient",
|
||||
"condescending": "Politely dismissive",
|
||||
"scripted": "Calm and focused", # scripted reads as neutral to the observer
|
||||
"genuine": "Genuinely curious",
|
||||
"confused": "Confused but engaged",
|
||||
"tired": "Tired and compliant",
|
||||
"optimistic": "Guardedly optimistic",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class ToneResult:
    """Result of one tone/affect classification over a single audio window."""

    label: str  # human-readable VoiceFrame label
    affect: str  # AFFECT_LABELS key
    confidence: float  # score of the winning SER class (best["score"] in classify())
    prosody_flags: list[str] = field(default_factory=list)  # e.g. "rising", "low_energy"
||||
|
||||
|
||||
class ToneClassifier:
    """
    Tone/affect classifier: wav2vec2 SER + librosa prosody.

    Loads the model lazily on first call to avoid import-time GPU allocation.
    Thread-safe for concurrent classify() calls — the pipeline is stateless
    per-call; session state lives in the caller (ContextClassifier).
    """

    def __init__(self, threshold: float = _DEFAULT_THRESHOLD) -> None:
        # NOTE(review): threshold is stored but never applied in this class —
        # callers appear expected to drop results below it. Confirm.
        self._threshold = threshold
        self._pipeline = None  # lazy-loaded by _load_pipeline()

    @classmethod
    def from_env(cls) -> "ToneClassifier":
        """Construct from CF_VOICE_CONFIDENCE_THRESHOLD (default 0.55)."""
        threshold = float(os.environ.get("CF_VOICE_CONFIDENCE_THRESHOLD", "0.55"))
        return cls(threshold=threshold)

    def _load_pipeline(self) -> None:
        """Idempotently load the HF audio-classification pipeline.

        Raises ImportError when the [inference] extras are not installed.
        """
        if self._pipeline is not None:
            return
        try:
            from transformers import pipeline as hf_pipeline
        except ImportError as exc:
            raise ImportError(
                "transformers is required for tone classification. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        device = 0 if _cuda_available() else -1
        logger.info("Loading SER model %s on device %s", _SER_MODEL_ID, device)
        self._pipeline = hf_pipeline(
            "audio-classification",
            model=_SER_MODEL_ID,
            device=device,
        )

    def classify(self, audio_float32: np.ndarray, transcript: str = "") -> ToneResult:
        """
        Classify tone/affect from a float32 16kHz mono audio window.

        transcript is used as a weak signal for ambiguous cases (e.g. words
        like "unfortunately" bias toward apologetic even on a neutral voice).

        Raises ValueError for non-float32 input. (BUG FIX: this was an
        ``assert``, which is stripped under ``python -O`` — input validation
        must raise explicitly.)
        """
        self._load_pipeline()

        # The model must see float32 at the expected rate.
        if audio_float32.dtype != np.float32:
            raise ValueError(f"audio must be float32, got {audio_float32.dtype}")

        # Run SER and take the top-scoring emotion class
        preds = self._pipeline({"raw": audio_float32, "sampling_rate": _SAMPLE_RATE})
        best = max(preds, key=lambda p: p["score"])
        emotion = best["label"].lower()
        confidence = float(best["score"])

        # Extract prosody features from raw audio
        prosody_flags = _extract_prosody_flags(audio_float32)

        # Resolve affect: base emotion, then first matching prosody override
        affect = _EMOTION_BASE.get(emotion, "neutral")
        for flag in prosody_flags:
            override = _PROSODY_OVERRIDES.get((affect, flag))
            if override:
                affect = override
                break

        # Weak transcript signals (only adjust ambiguous affects)
        affect = _apply_transcript_hints(affect, transcript)

        label = _AFFECT_TO_LABEL.get(affect, "Calm and focused")
        return ToneResult(
            label=label,
            affect=affect,
            confidence=confidence,
            prosody_flags=prosody_flags,
        )

    async def classify_async(
        self, audio_float32: np.ndarray, transcript: str = ""
    ) -> ToneResult:
        """classify() without blocking the event loop.

        Uses asyncio.get_running_loop(): this is always awaited from a
        coroutine, where get_event_loop() is deprecated.
        """
        loop = asyncio.get_running_loop()
        fn = partial(self.classify, audio_float32, transcript)
        return await loop.run_in_executor(None, fn)
||||
|
||||
|
||||
# ── Prosody helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_prosody_flags(audio: np.ndarray) -> list[str]:
|
||||
"""
|
||||
Extract lightweight prosody flags from a float32 16kHz mono window.
|
||||
Returns a list of string flags consumed by _PROSODY_OVERRIDES.
|
||||
"""
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
return []
|
||||
|
||||
flags: list[str] = []
|
||||
|
||||
# Energy (RMS)
|
||||
rms = float(np.sqrt(np.mean(audio ** 2)))
|
||||
if rms < 0.02:
|
||||
flags.append("low_energy")
|
||||
elif rms > 0.15:
|
||||
flags.append("high_energy")
|
||||
|
||||
# Speech rate approximation via zero-crossing rate
|
||||
zcr = float(np.mean(librosa.feature.zero_crossing_rate(audio)))
|
||||
if zcr > 0.12:
|
||||
flags.append("fast_rate")
|
||||
elif zcr < 0.04:
|
||||
flags.append("slow_rate")
|
||||
|
||||
# Pitch contour via YIN
|
||||
try:
|
||||
f0 = librosa.yin(
|
||||
audio,
|
||||
fmin=librosa.note_to_hz("C2"),
|
||||
fmax=librosa.note_to_hz("C7"),
|
||||
sr=_SAMPLE_RATE,
|
||||
)
|
||||
voiced = f0[f0 > 0]
|
||||
if len(voiced) > 5:
|
||||
# Rising: last quarter higher than first quarter
|
||||
q = len(voiced) // 4
|
||||
if q > 0 and np.mean(voiced[-q:]) > np.mean(voiced[:q]) * 1.15:
|
||||
flags.append("rising")
|
||||
# Flat: variance less than 15Hz
|
||||
if np.std(voiced) < 15:
|
||||
flags.append("flat_pitch")
|
||||
except Exception:
|
||||
pass # pitch extraction is best-effort
|
||||
|
||||
return flags
|
||||
|
||||
|
||||
def _apply_transcript_hints(affect: str, transcript: str) -> str:
|
||||
"""
|
||||
Apply weak keyword signals from transcript text to adjust affect.
|
||||
Only overrides when affect is already ambiguous (neutral/tired).
|
||||
"""
|
||||
if not transcript or affect not in ("neutral", "tired"):
|
||||
return affect
|
||||
|
||||
t = transcript.lower()
|
||||
apologetic_words = {"sorry", "apologize", "unfortunately", "afraid", "regret"}
|
||||
urgent_words = {"urgent", "immediately", "asap", "right now", "critical"}
|
||||
dismissive_words = {"policy", "unable to", "cannot", "not possible", "outside"}
|
||||
|
||||
if any(w in t for w in apologetic_words):
|
||||
return "apologetic"
|
||||
if any(w in t for w in urgent_words):
|
||||
return "urgent"
|
||||
if any(w in t for w in dismissive_words):
|
||||
return "dismissive"
|
||||
|
||||
return affect
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
|
@ -78,26 +78,26 @@ class ContextClassifier:
|
|||
This is the request-response path used by the cf-orch endpoint.
|
||||
The streaming path (async generator) is for continuous consumers.
|
||||
|
||||
Stub: audio_b64 is ignored; returns synthetic events from the mock IO.
|
||||
Real: decode audio, run YAMNet + SER + pyannote, return events.
|
||||
|
||||
elcor=True switches subtext format to Mass Effect Elcor prefix style.
|
||||
Generic tone annotation is always present regardless of elcor flag.
|
||||
"""
|
||||
if not isinstance(self._io, MockVoiceIO):
|
||||
raise NotImplementedError(
|
||||
"classify_chunk() requires mock mode. "
|
||||
"Real audio inference is not yet implemented."
|
||||
)
|
||||
# Generate a synthetic VoiceFrame to derive events from
|
||||
rng = self._io._rng
|
||||
import time
|
||||
label = rng.choice(self._io._labels)
|
||||
if isinstance(self._io, MockVoiceIO):
|
||||
return self._classify_chunk_mock(timestamp, prior_frames, elcor)
|
||||
|
||||
return self._classify_chunk_real(audio_b64, timestamp, elcor)
|
||||
|
||||
def _classify_chunk_mock(
|
||||
self, timestamp: float, prior_frames: int, elcor: bool
|
||||
) -> list[AudioEvent]:
|
||||
"""Synthetic path — used in mock mode and CI."""
|
||||
rng = self._io._rng # type: ignore[attr-defined]
|
||||
import time as _time
|
||||
label = rng.choice(self._io._labels) # type: ignore[attr-defined]
|
||||
shift = rng.uniform(0.1, 0.7) if prior_frames > 0 else 0.0
|
||||
frame = VoiceFrame(
|
||||
label=label,
|
||||
confidence=rng.uniform(0.6, 0.97),
|
||||
speaker_id=rng.choice(self._io._speakers),
|
||||
speaker_id=rng.choice(self._io._speakers), # type: ignore[attr-defined]
|
||||
shift_magnitude=round(shift, 3),
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
|
@ -110,6 +110,38 @@ class ContextClassifier:
|
|||
)
|
||||
return [tone]
|
||||
|
||||
    def _classify_chunk_real(
        self, audio_b64: str, timestamp: float, elcor: bool
    ) -> list[AudioEvent]:
        """Real inference path — used when CF_VOICE_MOCK is unset.

        Decodes base64-encoded PCM int16 (assumes 16kHz mono — TODO confirm
        against the cf-orch caller), runs the tone classifier synchronously,
        and returns a single tone AudioEvent built from the result.
        """
        import asyncio  # NOTE(review): unused in this method — confirm before removing
        import base64
        import numpy as np
        from cf_voice.classify import ToneClassifier

        # int16 PCM → float32 samples in [-1.0, 1.0)
        pcm = base64.b64decode(audio_b64)
        audio = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32_768.0

        # ToneClassifier is stateless per-call, safe to instantiate inline
        classifier = ToneClassifier.from_env()
        tone_result = classifier.classify(audio)

        # Single-shot request: no prior frame is available here, so
        # speaker_id stays the default and shift_magnitude is 0.0.
        frame = VoiceFrame(
            label=tone_result.label,
            confidence=tone_result.confidence,
            speaker_id="speaker_a",
            shift_magnitude=0.0,
            timestamp=timestamp,
        )
        event = tone_event_from_voice_frame(
            frame_label=frame.label,
            frame_confidence=frame.confidence,
            shift_magnitude=frame.shift_magnitude,
            timestamp=frame.timestamp,
            elcor=elcor,
        )
        return [event]
|
||||
|
||||
def _enrich(self, frame: VoiceFrame) -> VoiceFrame:
|
||||
"""
|
||||
Apply tone classification to a raw frame.
|
||||
|
|
|
|||
146
cf_voice/diarize.py
Normal file
146
cf_voice/diarize.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
# cf_voice/diarize.py — async speaker diarization via pyannote.audio
|
||||
#
|
||||
# BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access.
|
||||
# Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI.
|
||||
#
|
||||
# Model used: pyannote/speaker-diarization-3.1
|
||||
# Requires accepting gated model terms at:
|
||||
# https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||
# https://huggingface.co/pyannote/segmentation-3.0
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
|
||||
_SAMPLE_RATE = 16_000
|
||||
|
||||
|
||||
@dataclass
class SpeakerSegment:
    """A speaker-labelled time range within an audio window."""

    speaker_id: str  # ephemeral local label, e.g. "SPEAKER_00"
    start_s: float  # segment start in seconds, relative to the window
    end_s: float  # segment end in seconds, relative to the window

    @property
    def duration_s(self) -> float:
        """Length of the segment in seconds."""
        return self.end_s - self.start_s
||||
|
||||
|
||||
class Diarizer:
    """
    Async wrapper around pyannote.audio speaker diarization pipeline.

    PyAnnote's pipeline is synchronous and GPU-bound. We run it in a thread
    pool executor to avoid blocking the asyncio event loop.

    The pipeline is loaded once at construction time (model download on first
    use, then cached by HuggingFace). CUDA is used automatically if available.

    Usage
    -----
        diarizer = Diarizer.from_env()
        segments = await diarizer.diarize_async(audio_float32)
        for seg in segments:
            print(seg.speaker_id, seg.start_s, seg.end_s)

    Navigation v0.2.x wires this into ContextClassifier so that each
    VoiceFrame carries the correct speaker_id from diarization output.
    """

    def __init__(self, hf_token: str) -> None:
        # Import here so that mock-mode users never need pyannote installed.
        try:
            from pyannote.audio import Pipeline
        except ImportError as exc:
            raise ImportError(
                "pyannote.audio is required for speaker diarization. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL)
        self._pipeline = Pipeline.from_pretrained(
            _DIARIZATION_MODEL,
            use_auth_token=hf_token,
        )

        # Move to GPU if available; CPU is fine, just slower.
        try:
            import torch
            if torch.cuda.is_available():
                self._pipeline.to(torch.device("cuda"))
                logger.info("Diarization pipeline on CUDA")
        except ImportError:
            pass

    @classmethod
    def from_env(cls) -> "Diarizer":
        """Construct from HF_TOKEN environment variable.

        Raises EnvironmentError when HF_TOKEN is missing/empty.
        """
        token = os.environ.get("HF_TOKEN", "").strip()
        if not token:
            raise EnvironmentError(
                "HF_TOKEN is required for speaker diarization. "
                "Set it in your .env or environment. "
                "See .env.example for setup instructions."
            )
        return cls(hf_token=token)

    def _diarize_sync(
        self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE
    ) -> list[SpeakerSegment]:
        """Synchronous diarization — always call via diarize_async."""
        import torch

        # pyannote expects a (channels, samples) float32 tensor
        waveform = torch.from_numpy(audio_float32[np.newaxis, :].astype(np.float32))
        diarization = self._pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )

        segments: list[SpeakerSegment] = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                SpeakerSegment(
                    speaker_id=speaker,
                    start_s=round(turn.start, 3),
                    end_s=round(turn.end, 3),
                )
            )
        return segments

    async def diarize_async(
        self,
        audio_float32: np.ndarray,
        sample_rate: int = _SAMPLE_RATE,
    ) -> list[SpeakerSegment]:
        """
        Diarize an audio window without blocking the event loop.

        audio_float32 should be 16kHz mono float32.
        Typical input is a 2-second window from MicVoiceIO (32000 samples).
        Returns segments ordered by start_s.
        """
        # BUG FIX: get_running_loop() — get_event_loop() is deprecated when
        # called from inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._diarize_sync, audio_float32, sample_rate
        )

    def speaker_at(
        self, segments: list[SpeakerSegment], timestamp_s: float
    ) -> str:
        """
        Return the speaker_id active at a given timestamp within the window.

        Falls back to "speaker_a" if no segment covers the timestamp
        (e.g. during silence or at window boundaries). When segments overlap,
        the first covering segment in list order wins.
        """
        for seg in segments:
            if seg.start_s <= timestamp_s <= seg.end_s:
                return seg.speaker_id
        return "speaker_a"
|
||||
|
|
@ -106,17 +106,17 @@ class MockVoiceIO(VoiceIO):
|
|||
def make_io(
    mock: bool | None = None,
    interval_s: float = 2.5,
    device_index: int | None = None,
) -> VoiceIO:
    """
    Factory: return a VoiceIO instance appropriate for the current environment.

    mock=True or CF_VOICE_MOCK=1 → MockVoiceIO (no GPU, mic, or HF token needed)
    Otherwise → MicVoiceIO (requires [inference] extras)

    interval_s is only used by MockVoiceIO (synthetic frame cadence);
    device_index is only used by MicVoiceIO (sounddevice input device).
    """
    # Explicit argument wins; otherwise fall back to the CF_VOICE_MOCK env var.
    use_mock = mock if mock is not None else os.environ.get("CF_VOICE_MOCK", "") == "1"
    if use_mock:
        return MockVoiceIO(interval_s=interval_s)

    # Imported lazily: MicVoiceIO pulls in the [inference] stack at init time.
    from cf_voice.capture import MicVoiceIO
    return MicVoiceIO(device_index=device_index)
|
||||
|
|
|
|||
142
cf_voice/stt.py
Normal file
142
cf_voice/stt.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
# cf_voice/stt.py — faster-whisper STT wrapper
|
||||
#
|
||||
# BSL 1.1 when real inference models are integrated.
|
||||
# Requires the [inference] extras: pip install cf-voice[inference]
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VRAM_ESTIMATES_MB: dict[str, int] = {
|
||||
"tiny": 150, "base": 300, "small": 500,
|
||||
"medium": 1500, "large": 3000, "large-v2": 3000, "large-v3": 3500,
|
||||
}
|
||||
|
||||
# Minimum audio duration in seconds before attempting transcription.
|
||||
# Whisper hallucinates on very short clips.
|
||||
_MIN_DURATION_S = 0.3
|
||||
|
||||
|
||||
@dataclass
class STTResult:
    """Transcription result for a single audio chunk."""

    text: str  # joined, stripped segment text; "" when the chunk was too short
    language: str  # language detected by the model ("en" fallback for short chunks)
    duration_s: float  # chunk duration in seconds (samples / 16000)
    is_final: bool  # True when duration >= 1.0s and language probability > 0.5
||||
|
||||
|
||||
class WhisperSTT:
    """
    Async wrapper around faster-whisper for real-time chunk transcription.

    Runs transcription in a thread pool executor so it never blocks the event
    loop. Maintains a rolling 50-word session prompt to improve context
    continuity across 2-second windows.

    Usage
    -----
        stt = WhisperSTT.from_env()
        result = await stt.transcribe_chunk_async(pcm_int16_bytes)
        print(result.text)
    """

    def __init__(
        self,
        model_name: str = "small",
        device: str = "auto",
        compute_type: str | None = None,
    ) -> None:
        """Load the faster-whisper model.

        device "auto" resolves to "cuda" when torch reports CUDA, else "cpu".
        compute_type defaults to float16 on CUDA, int8 on CPU.
        Raises ImportError when faster-whisper is not installed.
        """
        try:
            from faster_whisper import WhisperModel
        except ImportError as exc:
            raise ImportError(
                "faster-whisper is required for real STT. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        if device == "auto":
            try:
                import torch
                device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                device = "cpu"

        if compute_type is None:
            compute_type = "float16" if device == "cuda" else "int8"

        logger.info("Loading Whisper %s on %s (%s)", model_name, device, compute_type)
        self._model = WhisperModel(
            model_name, device=device, compute_type=compute_type
        )
        self._device = device
        self._model_name = model_name
        self._session_prompt: str = ""  # rolling ~50-word cross-chunk context

    @classmethod
    def from_env(cls) -> "WhisperSTT":
        """Construct from CF_VOICE_WHISPER_MODEL and CF_VOICE_DEVICE env vars."""
        return cls(
            model_name=os.environ.get("CF_VOICE_WHISPER_MODEL", "small"),
            device=os.environ.get("CF_VOICE_DEVICE", "auto"),
        )

    @property
    def vram_mb(self) -> int:
        """Estimated VRAM usage in MB for this model (1500 for unknown names)."""
        return _VRAM_ESTIMATES_MB.get(self._model_name, 1500)

    def _transcribe_sync(self, audio_float32: np.ndarray) -> STTResult:
        """Synchronous transcription — always call via transcribe_chunk_async."""
        duration = len(audio_float32) / 16_000.0

        # Whisper hallucinates on very short clips — return empty instead.
        if duration < _MIN_DURATION_S:
            return STTResult(
                text="", language="en", duration_s=duration, is_final=False
            )

        segments, info = self._model.transcribe(
            audio_float32,
            language=None,  # auto-detect
            initial_prompt=self._session_prompt or None,
            vad_filter=False,  # silence gating happens upstream in MicVoiceIO
            word_timestamps=False,
            beam_size=3,
            temperature=0.0,
        )

        text = " ".join(s.text.strip() for s in segments).strip()

        # Rolling context: keep last ~50 words so the next chunk has prior text
        if text:
            words = (self._session_prompt + " " + text).split()
            self._session_prompt = " ".join(words[-50:])

        return STTResult(
            text=text,
            language=info.language,
            duration_s=duration,
            is_final=duration >= 1.0 and info.language_probability > 0.5,
        )

    async def transcribe_chunk_async(self, pcm_int16: bytes) -> STTResult:
        """
        Transcribe a raw PCM Int16 chunk, non-blocking.

        pcm_int16 should be 16kHz mono bytes. Typical input is 20 × 100ms
        chunks accumulated by MicVoiceIO (2-second window = 64000 bytes).
        """
        audio = (
            np.frombuffer(pcm_int16, dtype=np.int16).astype(np.float32) / 32768.0
        )
        # BUG FIX: get_running_loop() — get_event_loop() is deprecated when
        # called from inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio)

    def reset_session(self) -> None:
        """Clear the rolling prompt. Call at the start of each new conversation."""
        self._session_prompt = ""
|
||||
|
|
@ -18,12 +18,18 @@ dependencies = [
|
|||
inference = [
|
||||
"torch>=2.0",
|
||||
"torchaudio>=2.0",
|
||||
"numpy>=1.24",
|
||||
"faster-whisper>=1.0",
|
||||
"sounddevice>=0.4",
|
||||
"transformers>=4.40",
|
||||
"librosa>=0.10",
|
||||
"pyannote.audio>=3.1",
|
||||
"python-dotenv>=1.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"numpy>=1.24",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
Loading…
Reference in a new issue