# cf_voice/diarize.py — async speaker diarization via pyannote.audio # # BSL 1.1: real inference. Requires HF_TOKEN + pyannote model access. # Gate usage with CF_VOICE_MOCK=1 to skip entirely in dev/CI. # # Model used: pyannote/speaker-diarization-3.1 # Requires accepting gated model terms at: # https://huggingface.co/pyannote/speaker-diarization-3.1 # https://huggingface.co/pyannote/segmentation-3.0 # # Enable with: CF_VOICE_DIARIZE=1 (default off) # Requires: HF_TOKEN set in environment from __future__ import annotations import asyncio import logging import os import string from dataclasses import dataclass, field import numpy as np logger = logging.getLogger(__name__) _DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1" _SAMPLE_RATE = 16_000 # Label returned when two speakers overlap in the same window SPEAKER_MULTIPLE = "Multiple" # Label returned when no speaker segment covers the timestamp (silence / VAD miss) SPEAKER_UNKNOWN = "speaker_a" @dataclass class SpeakerSegment: """A speaker-labelled time range within an audio window.""" speaker_id: str # raw pyannote label, e.g. "SPEAKER_00" start_s: float end_s: float @property def duration_s(self) -> float: return self.end_s - self.start_s class SpeakerTracker: """ Maps ephemeral pyannote speaker IDs to stable per-session friendly labels. pyannote returns IDs like "SPEAKER_00", "SPEAKER_01" which are opaque and may differ across audio windows. SpeakerTracker assigns a consistent friendly label ("Speaker A", "Speaker B", ...) for the lifetime of one session, based on first-seen order. Speaker embeddings are never stored — only the raw_id → label string map, which contains no biometric information. Call reset() at session end to discard the map. For sessions with more than 26 speakers, labels wrap to "Speaker AA", "Speaker AB", etc. (unlikely in practice but handled gracefully). """ def __init__(self) -> None: self._map: dict[str, str] = {} self._counter: int = 0 def label(self, raw_id: str) -> str: """Return the friendly label for a pyannote speaker ID.""" if raw_id not in self._map: self._map[raw_id] = self._next_label() return self._map[raw_id] def reset(self) -> None: """Discard all label mappings. Call at session end.""" self._map.clear() self._counter = 0 def _next_label(self) -> str: idx = self._counter self._counter += 1 letters = string.ascii_uppercase n = len(letters) if idx < n: return f"Speaker {letters[idx]}" # Two-letter suffix for >26 speakers outer = idx // n inner = idx % n return f"Speaker {letters[outer - 1]}{letters[inner]}" class Diarizer: """ Async wrapper around pyannote.audio speaker diarization pipeline. PyAnnote's pipeline is synchronous and GPU-bound. We run it in a thread pool executor to avoid blocking the asyncio event loop. The pipeline is loaded once at construction time (model download on first use, then cached by HuggingFace). CUDA is used automatically if available. Usage ----- diarizer = Diarizer.from_env() tracker = SpeakerTracker() segments = await diarizer.diarize_async(audio_float32) label = diarizer.speaker_at(segments, timestamp_s=1.0, tracker=tracker) Navigation v0.2.x wires this into ContextClassifier so that each VoiceFrame carries the correct speaker_id from diarization output. """ def __init__(self, hf_token: str) -> None: try: from pyannote.audio import Pipeline except ImportError as exc: raise ImportError( "pyannote.audio is required for speaker diarization. " "Install with: pip install cf-voice[inference]" ) from exc logger.info("Loading diarization pipeline %s", _DIARIZATION_MODEL) self._pipeline = Pipeline.from_pretrained( _DIARIZATION_MODEL, token=hf_token, ) # Move to GPU if available try: import torch if torch.cuda.is_available(): self._pipeline.to(torch.device("cuda")) logger.info("Diarization pipeline on CUDA") except ImportError: pass @classmethod def from_env(cls) -> "Diarizer": """Construct from HF_TOKEN environment variable.""" token = os.environ.get("HF_TOKEN", "").strip() if not token: raise EnvironmentError( "HF_TOKEN is required for speaker diarization. " "Set it in your .env or environment. " "See .env.example for setup instructions." ) return cls(hf_token=token) def _diarize_sync( self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE, num_speakers: int | None = None, ) -> list[SpeakerSegment]: """Synchronous diarization — always call via diarize_async. num_speakers: when set, passed as min_speakers=max_speakers to pyannote, which skips the agglomeration heuristic and improves boundary accuracy for known-size conversations (e.g. 2-person call). """ import torch # pyannote expects (channels, samples) float32 tensor waveform = torch.from_numpy(audio_float32[np.newaxis, :].astype(np.float32)) pipeline_kwargs: dict = {"waveform": waveform, "sample_rate": sample_rate} if num_speakers and num_speakers > 0: pipeline_kwargs["min_speakers"] = num_speakers pipeline_kwargs["max_speakers"] = num_speakers output = self._pipeline(pipeline_kwargs) # pyannote >= 3.3 wraps results in DiarizeOutput; earlier versions return # Annotation directly. Normalise to Annotation before iterating. diarization = getattr(output, "speaker_diarization", output) segments: list[SpeakerSegment] = [] for turn, _, speaker in diarization.itertracks(yield_label=True): segments.append( SpeakerSegment( speaker_id=speaker, start_s=round(turn.start, 3), end_s=round(turn.end, 3), ) ) return segments async def diarize_async( self, audio_float32: np.ndarray, sample_rate: int = _SAMPLE_RATE, num_speakers: int | None = None, ) -> list[SpeakerSegment]: """ Diarize an audio window without blocking the event loop. audio_float32 should be 16kHz mono float32. Typical input is a 2-second window from MicVoiceIO (32000 samples). Returns segments ordered by start_s. num_speakers: passed through to pyannote as min_speakers=max_speakers when set and > 0. Improves accuracy for known speaker counts. """ from functools import partial loop = asyncio.get_running_loop() return await loop.run_in_executor( None, partial(self._diarize_sync, audio_float32, sample_rate, num_speakers), ) def speaker_at( self, segments: list[SpeakerSegment], timestamp_s: float, tracker: SpeakerTracker | None = None, window_s: float = 1.0, ) -> str: """ Return the friendly speaker label dominating a window around timestamp_s. Strategy (in order): 1. If segments directly cover timestamp_s: use majority rule among them. 2. If timestamp_s falls in a silence gap: use the speaker with the most total speaking time across the whole window [0, window_s]. This handles pauses between pyannote segments without falling back to "speaker_a". 3. No segments at all: SPEAKER_UNKNOWN. tracker is optional; if omitted, raw pyannote IDs are returned as-is. """ if not segments: return SPEAKER_UNKNOWN covering = [seg for seg in segments if seg.start_s <= timestamp_s <= seg.end_s] if len(covering) >= 2: return SPEAKER_MULTIPLE if len(covering) == 1: raw_id = covering[0].speaker_id return tracker.label(raw_id) if tracker else raw_id # Midpoint fell in a silence gap — find dominant speaker over the window. from collections import defaultdict duration_by_speaker: dict[str, float] = defaultdict(float) win_start = max(0.0, timestamp_s - window_s / 2) win_end = timestamp_s + window_s / 2 for seg in segments: overlap = min(seg.end_s, win_end) - max(seg.start_s, win_start) if overlap > 0: duration_by_speaker[seg.speaker_id] += overlap if not duration_by_speaker: return SPEAKER_UNKNOWN raw_id = max(duration_by_speaker, key=lambda k: duration_by_speaker[k]) return tracker.label(raw_id) if tracker else raw_id