feat: cf_core.reranker — shared reranker module Phase 1 (#54)
Some checks failed
CI / test (push) Has been cancelled
Mirror / mirror (push) Has been cancelled

Trunk + text branch + BGE adapter:
- base.py: Reranker Protocol, RerankResult (frozen dataclass), TextReranker
  base class with rerank() / rerank_batch() built on _score_pairs()
- adapters/mock.py: MockTextReranker — Jaccard scoring, no deps, deterministic
- adapters/bge.py: BGETextReranker — FlagEmbedding cross-encoder, thread-safe,
  batched forward pass via rerank_batch(); graceful ImportError if dep missing
- __init__.py: rerank() singleton, make_reranker(), reset_reranker();
  CF_RERANKER_MODEL / CF_RERANKER_BACKEND / CF_RERANKER_MOCK env vars
- pyproject.toml: reranker-bge and reranker-qwen3 optional dep groups
- 20 tests, all passing

Architecture ready for Phase 2 (Qwen3TextReranker) and Phase 3 (cf-orch remote
backend). ImageReranker/AudioReranker branches stubbed in base.py docstring.
This commit is contained in:
pyr0ball 2026-04-21 12:25:01 -07:00
parent 3167ee8011
commit 82f0b4c3d0
10 changed files with 662 additions and 2 deletions

View file

@ -1,4 +1,4 @@
__version__ = "0.14.0"
__version__ = "0.15.0"
try:
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore

View file

@ -0,0 +1,145 @@
"""
circuitforge_core.reranker shared reranker module for RAG pipelines.
Provides a modality-aware scoring interface for ranking candidates against a
query. Built to handle text today and audio/image/video in future branches.
Architecture:
Reranker (Protocol / trunk)
TextReranker (branch)
MockTextReranker no deps, deterministic, for tests
BGETextReranker FlagEmbedding cross-encoder, MIT, Free tier
Qwen3TextReranker generative reranker, MIT/BSL, Paid tier (Phase 2)
Quick start (mock mode no model required):
import os; os.environ["CF_RERANKER_MOCK"] = "1"
from circuitforge_core.reranker import rerank
results = rerank("chicken soup", ["hearty chicken noodle", "chocolate cake", "tomato basil soup"])
for r in results:
print(r.rank, r.score, r.candidate[:40])
Real inference (BGE cross-encoder):
export CF_RERANKER_MODEL=BAAI/bge-reranker-base
from circuitforge_core.reranker import rerank
results = rerank(query, candidates, top_n=20)
Explicit backend (per-request or per-user):
from circuitforge_core.reranker import make_reranker
reranker = make_reranker("BAAI/bge-reranker-v2-m3", backend="bge")
results = reranker.rerank(query, candidates, top_n=10)
Batch scoring (efficient for large corpora):
from circuitforge_core.reranker import make_reranker
reranker = make_reranker("BAAI/bge-reranker-base")
batch = reranker.rerank_batch(queries, candidate_lists, top_n=10)
Environment variables:
CF_RERANKER_MODEL model ID or path (default: "BAAI/bge-reranker-base")
CF_RERANKER_BACKEND backend override: "bge" | "mock" (default: auto-detect)
CF_RERANKER_MOCK set to "1" to force mock backend (no model required)
cf-orch service profile (Phase 3 remote backend):
service_type: cf-reranker
max_mb: per-model (base 600, large 1400, 8B 8192)
shared: true
"""
from __future__ import annotations
import os
from typing import Sequence
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
from circuitforge_core.reranker.adapters.mock import MockTextReranker
# ── Process-level singleton ───────────────────────────────────────────────────
_reranker: TextReranker | None = None
_DEFAULT_MODEL = "BAAI/bge-reranker-base"
def _get_reranker() -> TextReranker:
global _reranker
if _reranker is None:
_reranker = make_reranker()
return _reranker
def make_reranker(
model_id: str | None = None,
backend: str | None = None,
mock: bool | None = None,
) -> TextReranker:
"""
Create a TextReranker for the given model.
Use this when you need an explicit reranker instance (e.g. per-service
with a specific model) rather than the process-level singleton.
model_id HuggingFace model ID or local path. Defaults to
CF_RERANKER_MODEL env var, then BAAI/bge-reranker-base.
backend "bge" | "mock". Auto-detected from model_id if omitted.
mock Force mock backend. Defaults to CF_RERANKER_MOCK env var.
"""
_mock = mock if mock is not None else os.environ.get("CF_RERANKER_MOCK", "") == "1"
if _mock:
return MockTextReranker()
_model_id = model_id or os.environ.get("CF_RERANKER_MODEL", _DEFAULT_MODEL)
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "bge")
if _backend == "mock":
return MockTextReranker()
if _backend == "bge":
from circuitforge_core.reranker.adapters.bge import BGETextReranker
return BGETextReranker(_model_id)
raise ValueError(
f"Unknown reranker backend {_backend!r}. "
"Valid options: 'bge', 'mock'. "
"(Qwen3 support is coming in Phase 2.)"
)
# ── Convenience functions (singleton path) ────────────────────────────────────
def rerank(
query: str,
candidates: Sequence[str],
top_n: int = 0,
) -> list[RerankResult]:
"""
Score and sort candidates against query using the process-level reranker.
Returns a list of RerankResult sorted by score descending (rank 0 first).
top_n=0 returns all candidates.
For large corpora, prefer rerank_batch() on an explicit reranker instance
to amortise model load time and batch the forward pass.
"""
return _get_reranker().rerank(query, candidates, top_n=top_n)
def reset_reranker() -> None:
"""Reset the process-level singleton. Test teardown only."""
global _reranker
_reranker = None
__all__ = [
"Reranker",
"TextReranker",
"RerankResult",
"MockTextReranker",
"make_reranker",
"rerank",
"reset_reranker",
]

View file

@ -0,0 +1,137 @@
# circuitforge_core/reranker/adapters/bge.py — BGE cross-encoder reranker
#
# Requires: pip install circuitforge-core[reranker-bge]
# Tested with FlagEmbedding>=1.2 (BAAI/bge-reranker-* family).
#
# MIT licensed — local inference only, no tier gate.
from __future__ import annotations
import logging
import threading
from typing import Sequence
from circuitforge_core.reranker.base import TextReranker
logger = logging.getLogger(__name__)
# Lazy import sentinel — FlagEmbedding is an optional dep.
try:
from FlagEmbedding import FlagReranker as _FlagReranker # type: ignore[import]
except ImportError:
_FlagReranker = None # type: ignore[assignment]
def _cuda_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
class BGETextReranker(TextReranker):
"""
Cross-encoder reranker using the BAAI BGE reranker family.
Scores (query, candidate) pairs via FlagEmbedding.FlagReranker.
Thread-safe: a lock serialises concurrent _score_pairs calls since
FlagReranker is not guaranteed to be reentrant.
Recommended free-tier models:
BAAI/bge-reranker-base ~570MB VRAM, fast
BAAI/bge-reranker-v2-m3 ~570MB VRAM, multilingual
BAAI/bge-reranker-large ~1.3GB VRAM, higher quality
Usage:
reranker = BGETextReranker("BAAI/bge-reranker-base")
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
"""
def __init__(self, model_id: str = "BAAI/bge-reranker-base") -> None:
self._model_id = model_id
self._reranker: object | None = None
self._lock = threading.Lock()
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
"""Explicitly load model weights. Called automatically on first rerank()."""
if _FlagReranker is None:
raise ImportError(
"FlagEmbedding is not installed. "
"Run: pip install circuitforge-core[reranker-bge]"
)
with self._lock:
if self._reranker is None:
logger.info("Loading BGE reranker: %s (fp16=%s)", self._model_id, _cuda_available())
self._reranker = _FlagReranker(self._model_id, use_fp16=_cuda_available())
def unload(self) -> None:
"""Release model weights. Useful for VRAM management between tasks."""
with self._lock:
self._reranker = None
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
if self._reranker is None:
self.load()
pairs = [[query, c] for c in candidates]
with self._lock:
scores: list[float] = self._reranker.compute_score( # type: ignore[union-attr]
pairs, normalize=True
)
# compute_score may return a single float when given one pair.
if isinstance(scores, float):
scores = [scores]
return scores
def rerank_batch(
self,
queries: Sequence[str],
candidates: Sequence[Sequence[str]],
top_n: int = 0,
) -> list[list[object]]:
"""Batch all pairs into a single compute_score call for efficiency."""
from circuitforge_core.reranker.base import RerankResult
if self._reranker is None:
self.load()
# Flatten all pairs, recording group boundaries for reconstruction.
all_pairs: list[list[str]] = []
group_sizes: list[int] = []
for q, cs in zip(queries, candidates):
cands = list(cs)
group_sizes.append(len(cands))
all_pairs.extend([q, c] for c in cands)
if not all_pairs:
return [[] for _ in queries]
with self._lock:
all_scores: list[float] = self._reranker.compute_score( # type: ignore[union-attr]
all_pairs, normalize=True
)
if isinstance(all_scores, float):
all_scores = [all_scores]
# Reconstruct per-query result lists.
results: list[list[RerankResult]] = []
offset = 0
for (q, cs), size in zip(zip(queries, candidates), group_sizes):
cands = list(cs)
scores = all_scores[offset : offset + size]
offset += size
sorted_results = sorted(
(RerankResult(candidate=c, score=s, rank=0) for c, s in zip(cands, scores)),
key=lambda r: r.score,
reverse=True,
)
if top_n > 0:
sorted_results = sorted_results[:top_n]
results.append([
RerankResult(candidate=r.candidate, score=r.score, rank=i)
for i, r in enumerate(sorted_results)
])
return results

View file

@ -0,0 +1,37 @@
# circuitforge_core/reranker/adapters/mock.py — deterministic mock reranker
#
# Always importable, no optional deps. Used in tests and CF_RERANKER_MOCK=1 mode.
# Scores by descending overlap of query tokens with candidate tokens so results
# are deterministic and meaningful enough to exercise product code paths.
from __future__ import annotations
from typing import Sequence
from circuitforge_core.reranker.base import RerankResult, TextReranker
class MockTextReranker(TextReranker):
"""Deterministic reranker for tests. No model weights required.
Scoring: Jaccard similarity between query token set and candidate token set.
Ties broken by candidate length (shorter wins) then lexicographic order,
so test assertions can be written against a stable ordering.
"""
_MODEL_ID = "mock"
@property
def model_id(self) -> str:
return self._MODEL_ID
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
q_tokens = set(query.lower().split())
scores: list[float] = []
for candidate in candidates:
c_tokens = set(candidate.lower().split())
union = q_tokens | c_tokens
if not union:
scores.append(0.0)
else:
scores.append(len(q_tokens & c_tokens) / len(union))
return scores

View file

@ -0,0 +1,135 @@
# circuitforge_core/reranker/base.py — Reranker Protocol + modality branches
#
# MIT licensed. The Protocol and RerankResult are always importable.
# Adapter implementations (BGE, Qwen3, cf-orch remote) require optional extras.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Protocol, Sequence, runtime_checkable
# ── Result type ───────────────────────────────────────────────────────────────
@dataclass(frozen=True)
class RerankResult:
"""A single scored candidate returned by a reranker.
rank is 0-based (0 = highest score).
candidate preserves the original object text, Path, or any other type
passed in by the caller, so products don't need to re-index the input list.
"""
candidate: Any
score: float
rank: int
# ── Trunk: generic Reranker Protocol ─────────────────────────────────────────
@runtime_checkable
class Reranker(Protocol):
"""
Abstract interface for all reranker adapters.
Implementations must be safe to construct once and call concurrently;
internal state (loaded model weights) should be guarded by a lock if
the backend is not thread-safe.
query the reference item to rank against (typically a text query)
candidates ordered collection of items to score; ordering is preserved
in the returned list, which is sorted by score descending
top_n return at most this many results; 0 means return all
Returns a list of RerankResult sorted by score descending (rank 0 first).
"""
def rerank(
self,
query: str,
candidates: Sequence[Any],
top_n: int = 0,
) -> list[RerankResult]:
...
def rerank_batch(
self,
queries: Sequence[str],
candidates: Sequence[Sequence[Any]],
top_n: int = 0,
) -> list[list[RerankResult]]:
"""Score multiple (query, candidates) pairs in one call.
Default implementation loops over rerank(); adapters may override
with a true batched forward pass for efficiency.
"""
...
@property
def model_id(self) -> str:
"""Identifier for the loaded model (name, path, or URL)."""
...
# ── Branch: text-specific reranker ───────────────────────────────────────────
class TextReranker:
"""
Base class for text-to-text rerankers.
Subclasses implement _score_pairs(query, candidates) and get rerank()
and rerank_batch() for free. The default rerank_batch() loops over
_score_pairs; override it in adapters that support native batching.
candidates must be strings. query is always a string.
"""
@property
def model_id(self) -> str:
raise NotImplementedError
def _score_pairs(
self,
query: str,
candidates: list[str],
) -> list[float]:
"""Return a score per candidate (higher = more relevant).
Called by rerank() and rerank_batch(). Must return a list of the
same length as candidates, in the same order.
"""
raise NotImplementedError
def rerank(
self,
query: str,
candidates: Sequence[str],
top_n: int = 0,
) -> list[RerankResult]:
cands = list(candidates)
if not cands:
return []
scores = self._score_pairs(query, cands)
results = sorted(
(RerankResult(candidate=c, score=s, rank=0) for c, s in zip(cands, scores)),
key=lambda r: r.score,
reverse=True,
)
if top_n > 0:
results = results[:top_n]
return [
RerankResult(candidate=r.candidate, score=r.score, rank=i)
for i, r in enumerate(results)
]
def rerank_batch(
self,
queries: Sequence[str],
candidates: Sequence[Sequence[str]],
top_n: int = 0,
) -> list[list[RerankResult]]:
return [
self.rerank(q, cs, top_n)
for q, cs in zip(queries, candidates)
]

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "circuitforge-core"
version = "0.14.0"
version = "0.15.0"
description = "Shared scaffold for CircuitForge products (MIT)"
requires-python = ">=3.11"
dependencies = [
@ -83,6 +83,14 @@ vision-service = [
"uvicorn[standard]>=0.29",
"python-multipart>=0.0.9",
]
reranker-bge = [
"FlagEmbedding>=1.2",
]
reranker-qwen3 = [
"torch>=2.0",
"transformers>=4.40",
"accelerate>=0.27",
]
dev = [
"circuitforge-core[manage]",
"pytest>=8.0",

View file

View file

@ -0,0 +1,102 @@
"""Tests for RerankResult, TextReranker base class, and the public rerank() API."""
from __future__ import annotations
import os
import pytest
from circuitforge_core.reranker.base import RerankResult, TextReranker
from circuitforge_core.reranker.adapters.mock import MockTextReranker
# ── RerankResult ──────────────────────────────────────────────────────────────
def test_rerank_result_fields():
r = RerankResult(candidate="recipe text", score=0.9, rank=0)
assert r.candidate == "recipe text"
assert r.score == 0.9
assert r.rank == 0
def test_rerank_result_is_frozen():
r = RerankResult(candidate="x", score=0.5, rank=0)
with pytest.raises(Exception):
r.score = 0.1 # type: ignore[misc]
# ── MockTextReranker ──────────────────────────────────────────────────────────
def test_mock_rerank_returns_sorted_results():
reranker = MockTextReranker()
results = reranker.rerank(
"chicken soup recipe",
["chocolate cake recipe", "chicken noodle soup", "tomato basil pasta"],
)
assert len(results) == 3
# "chicken noodle soup" shares more tokens with query → should rank first
assert results[0].candidate == "chicken noodle soup"
assert results[0].rank == 0
assert results[1].rank == 1
assert results[2].rank == 2
def test_mock_rerank_top_n():
reranker = MockTextReranker()
results = reranker.rerank("chicken", ["a", "b chicken", "c chicken soup"], top_n=2)
assert len(results) == 2
def test_mock_rerank_empty_candidates():
reranker = MockTextReranker()
assert reranker.rerank("query", []) == []
def test_mock_rerank_scores_descending():
reranker = MockTextReranker()
results = reranker.rerank("apple pie dessert", ["apple pie", "beef stew", "apple crumble dessert"])
scores = [r.score for r in results]
assert scores == sorted(scores, reverse=True)
def test_mock_rerank_batch():
reranker = MockTextReranker()
batch = reranker.rerank_batch(
["chicken soup", "chocolate cake"],
[["chicken noodle", "beef stew"], ["chocolate mousse", "caesar salad"]],
top_n=1,
)
assert len(batch) == 2
assert batch[0][0].candidate == "chicken noodle"
assert batch[1][0].candidate == "chocolate mousse"
def test_mock_model_id():
assert MockTextReranker().model_id == "mock"
# ── Public API singleton ──────────────────────────────────────────────────────
def test_rerank_function_mock_mode(monkeypatch):
monkeypatch.setenv("CF_RERANKER_MOCK", "1")
from circuitforge_core.reranker import rerank, reset_reranker
reset_reranker()
results = rerank("chicken soup", ["beef stew", "chicken noodle soup", "cake"])
assert results[0].candidate == "chicken noodle soup"
reset_reranker()
def test_make_reranker_mock_explicit():
from circuitforge_core.reranker import make_reranker
r = make_reranker(mock=True)
assert isinstance(r, MockTextReranker)
def test_make_reranker_unknown_backend_raises():
from circuitforge_core.reranker import make_reranker
with pytest.raises(ValueError, match="Unknown reranker backend"):
make_reranker(backend="nonexistent")
def test_reranker_protocol_conformance():
"""MockTextReranker satisfies the Reranker Protocol."""
from circuitforge_core.reranker.base import Reranker
assert isinstance(MockTextReranker(), Reranker)

View file

@ -0,0 +1,96 @@
"""Tests for BGETextReranker with mocked FlagEmbedding."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.reranker.adapters.bge import BGETextReranker
from circuitforge_core.reranker.base import RerankResult
def _make_mock_flag_reranker(scores: list[float]) -> MagicMock:
"""Return a mock FlagReranker that yields the given scores."""
m = MagicMock()
m.compute_score.return_value = scores
return m
# ── BGETextReranker unit tests ────────────────────────────────────────────────
def test_bge_rerank_scores_and_sorts():
reranker = BGETextReranker("BAAI/bge-reranker-base")
mock_fr = _make_mock_flag_reranker([0.2, 0.9, 0.5])
reranker._reranker = mock_fr
results = reranker.rerank("query", ["a", "b", "c"])
assert len(results) == 3
assert results[0].candidate == "b" # highest score 0.9
assert results[0].score == pytest.approx(0.9)
assert results[0].rank == 0
assert results[1].candidate == "c"
assert results[2].candidate == "a"
def test_bge_rerank_top_n():
reranker = BGETextReranker("BAAI/bge-reranker-base")
reranker._reranker = _make_mock_flag_reranker([0.1, 0.8, 0.5])
results = reranker.rerank("q", ["a", "b", "c"], top_n=2)
assert len(results) == 2
assert results[0].candidate == "b"
def test_bge_rerank_single_candidate_float_return():
"""compute_score returns a float (not list) for a single pair."""
reranker = BGETextReranker("BAAI/bge-reranker-base")
mock_fr = MagicMock()
mock_fr.compute_score.return_value = 0.75 # single float
reranker._reranker = mock_fr
results = reranker.rerank("q", ["only candidate"])
assert len(results) == 1
assert results[0].score == pytest.approx(0.75)
def test_bge_rerank_batch_flattens_pairs():
reranker = BGETextReranker("BAAI/bge-reranker-base")
mock_fr = _make_mock_flag_reranker([0.9, 0.1, 0.3, 0.8])
reranker._reranker = mock_fr
batch = reranker.rerank_batch(
["q1", "q2"],
[["a1", "a2"], ["b1", "b2"]],
)
assert len(batch) == 2
# q1: scores [0.9, 0.1] → a1 first
assert batch[0][0].candidate == "a1"
# q2: scores [0.3, 0.8] → b2 first
assert batch[1][0].candidate == "b2"
# All pairs were sent in a single compute_score call
all_pairs = mock_fr.compute_score.call_args[0][0]
assert len(all_pairs) == 4
def test_bge_rerank_empty_batch():
reranker = BGETextReranker("BAAI/bge-reranker-base")
reranker._reranker = MagicMock()
result = reranker.rerank_batch([], [], top_n=5)
assert result == []
def test_bge_load_raises_without_flagembedding():
reranker = BGETextReranker("BAAI/bge-reranker-base")
with patch("circuitforge_core.reranker.adapters.bge._FlagReranker", None):
with pytest.raises(ImportError, match="FlagEmbedding"):
reranker.load()
def test_bge_model_id():
assert BGETextReranker("BAAI/bge-reranker-v2-m3").model_id == "BAAI/bge-reranker-v2-m3"
def test_bge_unload_clears_reranker():
reranker = BGETextReranker("BAAI/bge-reranker-base")
reranker._reranker = MagicMock()
reranker.unload()
assert reranker._reranker is None