# cf_voice/diarize.py — async speaker diarization via pyannote.audio
#
# BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access.
# Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI.
#
# Model used: pyannote/speaker-diarization-3.1
# Requires accepting gated model terms at:
#   https://huggingface.co/pyannote/speaker-diarization-3.1
#   https://huggingface.co/pyannote/segmentation-3.0
from __future__ import annotations

import asyncio
import logging
import os
from dataclasses import dataclass

import numpy as np

logger = logging.getLogger(__name__)

_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
_SAMPLE_RATE = 16_000

# Label returned by Diarizer.speaker_at when no segment covers a timestamp.
_FALLBACK_SPEAKER = "speaker_a"


@dataclass
class SpeakerSegment:
    """A speaker-labelled time range within an audio window."""

    speaker_id: str  # ephemeral local label, e.g. "SPEAKER_00"
    start_s: float  # seconds, relative to the start of the diarized window
    end_s: float  # seconds, relative to the start of the diarized window

    @property
    def duration_s(self) -> float:
        """Length of the segment in seconds (end_s - start_s)."""
        return self.end_s - self.start_s


class Diarizer:
    """
    Async wrapper around pyannote.audio speaker diarization pipeline.

    PyAnnote's pipeline is synchronous and GPU-bound. We run it in a thread
    pool executor to avoid blocking the asyncio event loop.

    The pipeline is loaded once at construction time (model download on
    first use, then cached by HuggingFace). CUDA is used automatically if
    available.

    Usage
    -----
        diarizer = Diarizer.from_env()
        segments = await diarizer.diarize_async(audio_float32)
        for seg in segments:
            print(seg.speaker_id, seg.start_s, seg.end_s)

    Navigation v0.2.x wires this into ContextClassifier so that each
    VoiceFrame carries the correct speaker_id from diarization output.
    """

    def __init__(self, hf_token: str) -> None:
        """
        Load the diarization pipeline and move it to CUDA when available.

        Parameters
        ----------
        hf_token:
            HuggingFace access token with access to the gated pyannote models.

        Raises
        ------
        ImportError
            pyannote.audio is not installed.
        RuntimeError
            The pipeline could not be loaded (pyannote returned None).
        """
        try:
            from pyannote.audio import Pipeline
        except ImportError as exc:
            raise ImportError(
                "pyannote.audio is required for speaker diarization. "
                "Install with: pip install cf-voice[inference]"
            ) from exc

        logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL)
        pipeline = Pipeline.from_pretrained(
            _DIARIZATION_MODEL,
            use_auth_token=hf_token,
        )
        # pyannote.audio 3.x returns None (after printing a hint) instead of
        # raising when the checkpoint cannot be fetched, e.g. an invalid token
        # or gated-model terms not yet accepted. Fail loudly here rather than
        # with "'NoneType' object is not callable" on the first diarization.
        if pipeline is None:
            raise RuntimeError(
                f"Could not load diarization pipeline {_DIARIZATION_MODEL!r}. "
                "Check that HF_TOKEN is valid and that the gated model terms "
                "have been accepted (see the URLs at the top of this module)."
            )
        self._pipeline = pipeline

        # Move to GPU if available; best effort, CPU inference still works.
        try:
            import torch
        except ImportError:
            pass
        else:
            if torch.cuda.is_available():
                self._pipeline.to(torch.device("cuda"))
                logger.info("Diarization pipeline on CUDA")

    @classmethod
    def from_env(cls) -> "Diarizer":
        """
        Construct from HF_TOKEN environment variable.

        Raises
        ------
        EnvironmentError
            HF_TOKEN is unset or contains only whitespace.
        """
        token = os.environ.get("HF_TOKEN", "").strip()
        if not token:
            raise EnvironmentError(
                "HF_TOKEN is required for speaker diarization. "
                "Set it in your .env or environment. "
                "See .env.example for setup instructions."
            )
        return cls(hf_token=token)

    def _diarize_sync(
        self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE
    ) -> list[SpeakerSegment]:
        """
        Synchronous diarization — always call via diarize_async.

        Returns segments sorted by (start_s, end_s), with times rounded to
        milliseconds. An empty input yields [] without invoking the pipeline.

        Raises
        ------
        ValueError
            audio_float32 is not one-dimensional (mono, shape (samples,)).
        """
        mono = np.asarray(audio_float32)
        if mono.ndim != 1:
            raise ValueError(
                "expected mono audio with shape (samples,), "
                f"got shape {mono.shape}"
            )
        if mono.size == 0:
            # Nothing to diarize; skip the pipeline call entirely.
            return []

        import torch

        # pyannote expects a (channels, samples) float32 tensor. astype()
        # always copies, so the tensor does not share memory with the
        # caller's array while the pipeline runs on a worker thread.
        waveform = torch.from_numpy(mono[np.newaxis, :].astype(np.float32))
        diarization = self._pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )

        segments = [
            SpeakerSegment(
                speaker_id=speaker,
                start_s=round(turn.start, 3),
                end_s=round(turn.end, 3),
            )
            for turn, _track, speaker in diarization.itertracks(yield_label=True)
        ]
        # Enforce the chronological order promised by diarize_async instead of
        # relying on the iteration order of the pipeline's output.
        segments.sort(key=lambda seg: (seg.start_s, seg.end_s))
        return segments

    async def diarize_async(
        self,
        audio_float32: np.ndarray,
        sample_rate: int = _SAMPLE_RATE,
    ) -> list[SpeakerSegment]:
        """
        Diarize an audio window without blocking the event loop.

        audio_float32 should be 16kHz mono float32.
        Typical input is a 2-second window from MicVoiceIO (32000 samples).

        Returns segments ordered by start_s (ties broken by end_s).

        Raises
        ------
        ValueError
            audio_float32 is not one-dimensional.
        """
        # get_running_loop() is the non-deprecated way to reach the current
        # loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._diarize_sync, audio_float32, sample_rate
        )

    def speaker_at(
        self,
        segments: list[SpeakerSegment],
        timestamp_s: float,
        *,
        default: str = _FALLBACK_SPEAKER,
    ) -> str:
        """
        Return the speaker_id active at a given timestamp within the window.

        Segment bounds are inclusive on both ends and the first matching
        segment in *segments* wins, so at a shared boundary or during
        overlapping speech the earlier entry is returned.

        Falls back to *default* ("speaker_a" unless overridden) if no segment
        covers the timestamp (e.g. during silence or at window boundaries).
        """
        for seg in segments:
            if seg.start_s <= timestamp_s <= seg.end_s:
                return seg.speaker_id
        return default