cf-voice/cf_voice/diarize.py
pyr0ball fed6388b99 feat: real inference pipeline — STT, tone classifier, diarization, mic capture
- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor,
  rolling 50-word session prompt for cross-chunk context continuity)
- cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags
  (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS
- cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1;
  speaker_at() helper for Navigation v0.2.x wiring
- cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window
  accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta
- cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset
- cf_voice/context.py: classify_chunk() split into mock/real paths; real path
  decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint)
- pyproject.toml: inference extras expanded (faster-whisper, sounddevice,
  librosa, python-dotenv)
- .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK,
  CF_VOICE_CONFIDENCE_THRESHOLD

Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote
setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper
preprocessing and session prompt pattern).
2026-04-06 17:33:51 -07:00

146 lines
4.8 KiB
Python

# cf_voice/diarize.py — async speaker diarization via pyannote.audio
#
# BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access.
# Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI.
#
# Model used: pyannote/speaker-diarization-3.1
# Requires accepting gated model terms at:
# https://huggingface.co/pyannote/speaker-diarization-3.1
# https://huggingface.co/pyannote/segmentation-3.0
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
import numpy as np
# Module-level logger; configured by the embedding application.
logger = logging.getLogger(__name__)
# Gated HF model — access must be requested first (links in the header above).
_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
# Expected sample rate of input windows (16 kHz mono float32 from MicVoiceIO).
_SAMPLE_RATE = 16_000
@dataclass
class SpeakerSegment:
    """One speaker-labelled time range within a diarized audio window."""

    # Ephemeral local label assigned by pyannote, e.g. "SPEAKER_00".
    # Labels are only stable within a single window, not across windows.
    speaker_id: str
    # Window-relative segment boundaries, in seconds.
    start_s: float
    end_s: float

    @property
    def duration_s(self) -> float:
        """Length of this segment in seconds."""
        return self.end_s - self.start_s
class Diarizer:
    """
    Async wrapper around the pyannote.audio speaker-diarization pipeline.

    PyAnnote's pipeline is synchronous and GPU-bound, so inference runs in a
    thread-pool executor to avoid blocking the asyncio event loop.

    The pipeline is loaded once at construction time (model download on first
    use, then cached by HuggingFace). CUDA is used automatically if available.

    Usage
    -----
    diarizer = Diarizer.from_env()
    segments = await diarizer.diarize_async(audio_float32)
    for seg in segments:
        print(seg.speaker_id, seg.start_s, seg.end_s)

    Navigation v0.2.x wires this into ContextClassifier so that each
    VoiceFrame carries the correct speaker_id from diarization output.
    """

    def __init__(self, hf_token: str) -> None:
        """Load the gated pyannote pipeline.

        Raises
        ------
        ImportError
            If pyannote.audio (the `inference` extra) is not installed.
        """
        try:
            from pyannote.audio import Pipeline
        except ImportError as exc:
            raise ImportError(
                "pyannote.audio is required for speaker diarization. "
                "Install with: pip install cf-voice[inference]"
            ) from exc
        logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL)
        # NOTE(review): `use_auth_token` is deprecated in newer
        # huggingface_hub/pyannote releases in favour of `token=`; kept as-is
        # for compatibility with the pinned version — confirm before upgrading.
        self._pipeline = Pipeline.from_pretrained(
            _DIARIZATION_MODEL,
            use_auth_token=hf_token,
        )
        # Best-effort move to GPU. torch ships with pyannote.audio, so an
        # ImportError here just means the inference extras are incomplete;
        # the pipeline still works on CPU.
        try:
            import torch

            if torch.cuda.is_available():
                self._pipeline.to(torch.device("cuda"))
                logger.info("Diarization pipeline on CUDA")
        except ImportError:
            pass

    @classmethod
    def from_env(cls) -> "Diarizer":
        """Construct from the HF_TOKEN environment variable.

        Raises
        ------
        EnvironmentError
            If HF_TOKEN is unset or blank.
        """
        token = os.environ.get("HF_TOKEN", "").strip()
        if not token:
            raise EnvironmentError(
                "HF_TOKEN is required for speaker diarization. "
                "Set it in your .env or environment. "
                "See .env.example for setup instructions."
            )
        return cls(hf_token=token)

    def _diarize_sync(
        self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE
    ) -> list[SpeakerSegment]:
        """Synchronous diarization — always call via diarize_async."""
        import torch

        # pyannote expects a (channels, samples) float32 tensor; add the
        # channel axis to the mono input.
        waveform = torch.from_numpy(audio_float32[np.newaxis, :].astype(np.float32))
        diarization = self._pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )
        segments: list[SpeakerSegment] = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                SpeakerSegment(
                    speaker_id=speaker,
                    # Millisecond precision is plenty for 2-second windows.
                    start_s=round(turn.start, 3),
                    end_s=round(turn.end, 3),
                )
            )
        return segments

    async def diarize_async(
        self,
        audio_float32: np.ndarray,
        sample_rate: int = _SAMPLE_RATE,
    ) -> list[SpeakerSegment]:
        """
        Diarize an audio window without blocking the event loop.

        audio_float32 should be 16kHz mono float32.
        Typical input is a 2-second window from MicVoiceIO (32000 samples).
        Returns segments ordered by start_s.
        """
        # get_event_loop() is deprecated inside coroutines (Python 3.10+);
        # get_running_loop() is the documented call and is always correct here
        # since this method can only execute with a loop running.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._diarize_sync, audio_float32, sample_rate
        )

    def speaker_at(
        self, segments: list[SpeakerSegment], timestamp_s: float
    ) -> str:
        """
        Return the speaker_id active at a given timestamp within the window.

        If segments overlap, the first covering segment (by list order) wins.
        Falls back to "speaker_a" if no segment covers the timestamp
        (e.g. during silence or at window boundaries).
        """
        for seg in segments:
            if seg.start_s <= timestamp_s <= seg.end_s:
                return seg.speaker_id
        return "speaker_a"