- cf_voice/stt.py: WhisperSTT async wrapper (faster-whisper, thread-pool executor, rolling 50-word session prompt for cross-chunk context continuity) - cf_voice/classify.py: ToneClassifier — wav2vec2 SER + librosa prosody flags (energy, ZCR speech rate, YIN pitch contour) mapped to AFFECT_LABELS - cf_voice/diarize.py: Diarizer async wrapper around pyannote/speaker-diarization-3.1; speaker_at() helper for Navigation v0.2.x wiring - cf_voice/capture.py: MicVoiceIO — sounddevice 16kHz mono capture, 2s window accumulation, parallel STT+classify tasks, shift_magnitude from confidence delta - cf_voice/io.py: make_io() now returns MicVoiceIO when CF_VOICE_MOCK is unset - cf_voice/context.py: classify_chunk() split into mock/real paths; real path decodes base64 PCM and runs ToneClassifier synchronously (cf-orch endpoint) - pyproject.toml: inference extras expanded (faster-whisper, sounddevice, librosa, python-dotenv) - .env.example: HF_TOKEN, CF_VOICE_WHISPER_MODEL, CF_VOICE_DEVICE, CF_VOICE_MOCK, CF_VOICE_CONFIDENCE_THRESHOLD Prior art ported from: Plex-Scripts/transcription/diarization.py (pyannote setup), devl/ogma/backend/speech/transcription_engine.py (faster-whisper preprocessing and session prompt pattern).
146 lines
4.8 KiB
Python
# cf_voice/diarize.py — async speaker diarization via pyannote.audio
#
# BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access.
# Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI.
#
# Model used: pyannote/speaker-diarization-3.1
# Requires accepting gated model terms at:
#   https://huggingface.co/pyannote/speaker-diarization-3.1
#   https://huggingface.co/pyannote/segmentation-3.0
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
|
|
logger = logging.getLogger(__name__)

# Gated HuggingFace model id — requires accepting the model terms (links in
# the file header) and a valid HF_TOKEN at load time.
_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"

# Default audio sample rate (Hz). Matches the 16kHz mono windows produced by
# MicVoiceIO per the diarize_async docstring.
_SAMPLE_RATE = 16_000
|
|
|
|
|
|
@dataclass
class SpeakerSegment:
    """One contiguous stretch of audio attributed to a single speaker."""

    # Ephemeral per-window label assigned by the diarization pipeline
    # (e.g. "SPEAKER_00"); not a stable identity across windows.
    speaker_id: str
    start_s: float  # segment start, in seconds from the window start
    end_s: float    # segment end, in seconds from the window start

    @property
    def duration_s(self) -> float:
        """Length of this segment in seconds."""
        return self.end_s - self.start_s
|
|
|
|
|
|
class Diarizer:
    """
    Async wrapper around the pyannote.audio speaker diarization pipeline.

    PyAnnote's pipeline is synchronous and GPU-bound. We run it in a thread
    pool executor to avoid blocking the asyncio event loop.

    The pipeline is loaded once at construction time (model download on first
    use, then cached by HuggingFace). CUDA is used automatically if available.

    Usage
    -----
        diarizer = Diarizer.from_env()
        segments = await diarizer.diarize_async(audio_float32)
        for seg in segments:
            print(seg.speaker_id, seg.start_s, seg.end_s)

    Navigation v0.2.x wires this into ContextClassifier so that each
    VoiceFrame carries the correct speaker_id from diarization output.
    """

    def __init__(self, hf_token: str) -> None:
        """Load the pyannote pipeline (blocking; may download the model).

        Parameters
        ----------
        hf_token:
            HuggingFace access token granted access to the gated pyannote
            models (see the terms links in the module header).

        Raises
        ------
        ImportError
            If pyannote.audio is not installed.
        """
        try:
            from pyannote.audio import Pipeline
        except ImportError as exc:
            raise ImportError(
                "pyannote.audio is required for speaker diarization. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL)
        # NOTE(review): `use_auth_token` is deprecated in newer
        # huggingface_hub releases in favour of `token=`; kept as-is to
        # match the currently pinned pyannote — confirm before upgrading.
        self._pipeline = Pipeline.from_pretrained(
            _DIARIZATION_MODEL,
            use_auth_token=hf_token,
        )

        # Best-effort GPU placement. torch ships with pyannote, but guard
        # the import anyway so a stripped-down environment degrades to CPU
        # instead of crashing here.
        try:
            import torch

            if torch.cuda.is_available():
                self._pipeline.to(torch.device("cuda"))
                logger.info("Diarization pipeline on CUDA")
        except ImportError:
            pass

    @classmethod
    def from_env(cls) -> "Diarizer":
        """Construct from the HF_TOKEN environment variable.

        Raises
        ------
        EnvironmentError
            If HF_TOKEN is unset or blank.
        """
        token = os.environ.get("HF_TOKEN", "").strip()
        if not token:
            raise EnvironmentError(
                "HF_TOKEN is required for speaker diarization. "
                "Set it in your .env or environment. "
                "See .env.example for setup instructions."
            )
        return cls(hf_token=token)

    def _diarize_sync(
        self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE
    ) -> list[SpeakerSegment]:
        """Synchronous diarization — always call via diarize_async."""
        import torch

        # pyannote expects a (channels, samples) float32 tensor; prepend a
        # mono channel axis to the 1-D input.
        waveform = torch.from_numpy(audio_float32[np.newaxis, :].astype(np.float32))
        diarization = self._pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )

        segments = [
            SpeakerSegment(
                speaker_id=speaker,
                start_s=round(turn.start, 3),
                end_s=round(turn.end, 3),
            )
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        # itertracks normally yields chronologically, but sort explicitly so
        # the "ordered by start_s" contract in diarize_async always holds.
        segments.sort(key=lambda seg: seg.start_s)
        return segments

    async def diarize_async(
        self,
        audio_float32: np.ndarray,
        sample_rate: int = _SAMPLE_RATE,
    ) -> list[SpeakerSegment]:
        """
        Diarize an audio window without blocking the event loop.

        audio_float32 should be 16kHz mono float32.
        Typical input is a 2-second window from MicVoiceIO (32000 samples).
        Returns segments ordered by start_s.
        """
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._diarize_sync, audio_float32, sample_rate
        )

    def speaker_at(
        self, segments: list[SpeakerSegment], timestamp_s: float
    ) -> str:
        """
        Return the speaker_id active at a given timestamp within the window.

        If segments overlap, the first (earliest-starting) match wins.
        Falls back to "speaker_a" if no segment covers the timestamp
        (e.g. during silence or at window boundaries).
        """
        for seg in segments:
            if seg.start_s <= timestamp_s <= seg.end_s:
                return seg.speaker_id
        return "speaker_a"
|