feat: cf_core.reranker — shared reranker module Phase 1 (#54)
Trunk + text branch + BGE adapter: - base.py: Reranker Protocol, RerankResult (frozen dataclass), TextReranker base class with rerank() / rerank_batch() built on _score_pairs() - adapters/mock.py: MockTextReranker — Jaccard scoring, no deps, deterministic - adapters/bge.py: BGETextReranker — FlagEmbedding cross-encoder, thread-safe, batched forward pass via rerank_batch(); graceful ImportError if dep missing - __init__.py: rerank() singleton, make_reranker(), reset_reranker(); CF_RERANKER_MODEL / CF_RERANKER_BACKEND / CF_RERANKER_MOCK env vars - pyproject.toml: reranker-bge and reranker-qwen3 optional dep groups - 20 tests, all passing Architecture ready for Phase 2 (Qwen3TextReranker) and Phase 3 (cf-orch remote backend). ImageReranker/AudioReranker branches stubbed in base.py docstring.
This commit is contained in:
parent
3167ee8011
commit
82f0b4c3d0
10 changed files with 662 additions and 2 deletions
|
|
@ -1,4 +1,4 @@
|
|||
__version__ = "0.14.0"
|
||||
__version__ = "0.15.0"
|
||||
|
||||
try:
|
||||
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
|
||||
|
|
|
|||
145
circuitforge_core/reranker/__init__.py
Normal file
145
circuitforge_core/reranker/__init__.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
circuitforge_core.reranker — shared reranker module for RAG pipelines.
|
||||
|
||||
Provides a modality-aware scoring interface for ranking candidates against a
|
||||
query. Built to handle text today and audio/image/video in future branches.
|
||||
|
||||
Architecture:
|
||||
|
||||
Reranker (Protocol / trunk)
|
||||
└── TextReranker (branch)
|
||||
├── MockTextReranker — no deps, deterministic, for tests
|
||||
├── BGETextReranker — FlagEmbedding cross-encoder, MIT, Free tier
|
||||
└── Qwen3TextReranker — generative reranker, MIT/BSL, Paid tier (Phase 2)
|
||||
|
||||
Quick start (mock mode — no model required):
|
||||
|
||||
import os; os.environ["CF_RERANKER_MOCK"] = "1"
|
||||
from circuitforge_core.reranker import rerank
|
||||
|
||||
results = rerank("chicken soup", ["hearty chicken noodle", "chocolate cake", "tomato basil soup"])
|
||||
for r in results:
|
||||
print(r.rank, r.score, r.candidate[:40])
|
||||
|
||||
Real inference (BGE cross-encoder):
|
||||
|
||||
export CF_RERANKER_MODEL=BAAI/bge-reranker-base
|
||||
from circuitforge_core.reranker import rerank
|
||||
results = rerank(query, candidates, top_n=20)
|
||||
|
||||
Explicit backend (per-request or per-user):
|
||||
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
reranker = make_reranker("BAAI/bge-reranker-v2-m3", backend="bge")
|
||||
results = reranker.rerank(query, candidates, top_n=10)
|
||||
|
||||
Batch scoring (efficient for large corpora):
|
||||
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
reranker = make_reranker("BAAI/bge-reranker-base")
|
||||
batch = reranker.rerank_batch(queries, candidate_lists, top_n=10)
|
||||
|
||||
Environment variables:
|
||||
CF_RERANKER_MODEL model ID or path (default: "BAAI/bge-reranker-base")
|
||||
CF_RERANKER_BACKEND backend override: "bge" | "mock" (default: auto-detect)
|
||||
CF_RERANKER_MOCK set to "1" to force mock backend (no model required)
|
||||
|
||||
cf-orch service profile (Phase 3 — remote backend):
|
||||
service_type: cf-reranker
|
||||
max_mb: per-model (base ≈ 600, large ≈ 1400, 8B ≈ 8192)
|
||||
shared: true
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
|
||||
from circuitforge_core.reranker.adapters.mock import MockTextReranker
|
||||
|
||||
# ── Process-level singleton ───────────────────────────────────────────────────
|
||||
|
||||
_reranker: TextReranker | None = None
|
||||
|
||||
_DEFAULT_MODEL = "BAAI/bge-reranker-base"
|
||||
|
||||
|
||||
def _get_reranker() -> TextReranker:
|
||||
global _reranker
|
||||
if _reranker is None:
|
||||
_reranker = make_reranker()
|
||||
return _reranker
|
||||
|
||||
|
||||
def make_reranker(
|
||||
model_id: str | None = None,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
) -> TextReranker:
|
||||
"""
|
||||
Create a TextReranker for the given model.
|
||||
|
||||
Use this when you need an explicit reranker instance (e.g. per-service
|
||||
with a specific model) rather than the process-level singleton.
|
||||
|
||||
model_id — HuggingFace model ID or local path. Defaults to
|
||||
CF_RERANKER_MODEL env var, then BAAI/bge-reranker-base.
|
||||
backend — "bge" | "mock". Auto-detected from model_id if omitted.
|
||||
mock — Force mock backend. Defaults to CF_RERANKER_MOCK env var.
|
||||
"""
|
||||
_mock = mock if mock is not None else os.environ.get("CF_RERANKER_MOCK", "") == "1"
|
||||
if _mock:
|
||||
return MockTextReranker()
|
||||
|
||||
_model_id = model_id or os.environ.get("CF_RERANKER_MODEL", _DEFAULT_MODEL)
|
||||
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "bge")
|
||||
|
||||
if _backend == "mock":
|
||||
return MockTextReranker()
|
||||
|
||||
if _backend == "bge":
|
||||
from circuitforge_core.reranker.adapters.bge import BGETextReranker
|
||||
return BGETextReranker(_model_id)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reranker backend {_backend!r}. "
|
||||
"Valid options: 'bge', 'mock'. "
|
||||
"(Qwen3 support is coming in Phase 2.)"
|
||||
)
|
||||
|
||||
|
||||
# ── Convenience functions (singleton path) ────────────────────────────────────
|
||||
|
||||
|
||||
def rerank(
|
||||
query: str,
|
||||
candidates: Sequence[str],
|
||||
top_n: int = 0,
|
||||
) -> list[RerankResult]:
|
||||
"""
|
||||
Score and sort candidates against query using the process-level reranker.
|
||||
|
||||
Returns a list of RerankResult sorted by score descending (rank 0 first).
|
||||
top_n=0 returns all candidates.
|
||||
|
||||
For large corpora, prefer rerank_batch() on an explicit reranker instance
|
||||
to amortise model load time and batch the forward pass.
|
||||
"""
|
||||
return _get_reranker().rerank(query, candidates, top_n=top_n)
|
||||
|
||||
|
||||
def reset_reranker() -> None:
|
||||
"""Reset the process-level singleton. Test teardown only."""
|
||||
global _reranker
|
||||
_reranker = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
"TextReranker",
|
||||
"RerankResult",
|
||||
"MockTextReranker",
|
||||
"make_reranker",
|
||||
"rerank",
|
||||
"reset_reranker",
|
||||
]
|
||||
0
circuitforge_core/reranker/adapters/__init__.py
Normal file
0
circuitforge_core/reranker/adapters/__init__.py
Normal file
137
circuitforge_core/reranker/adapters/bge.py
Normal file
137
circuitforge_core/reranker/adapters/bge.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
# circuitforge_core/reranker/adapters/bge.py — BGE cross-encoder reranker
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-bge]
|
||||
# Tested with FlagEmbedding>=1.2 (BAAI/bge-reranker-* family).
|
||||
#
|
||||
# MIT licensed — local inference only, no tier gate.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy import sentinel — FlagEmbedding is an optional dep.
|
||||
try:
|
||||
from FlagEmbedding import FlagReranker as _FlagReranker # type: ignore[import]
|
||||
except ImportError:
|
||||
_FlagReranker = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class BGETextReranker(TextReranker):
|
||||
"""
|
||||
Cross-encoder reranker using the BAAI BGE reranker family.
|
||||
|
||||
Scores (query, candidate) pairs via FlagEmbedding.FlagReranker.
|
||||
Thread-safe: a lock serialises concurrent _score_pairs calls since
|
||||
FlagReranker is not guaranteed to be reentrant.
|
||||
|
||||
Recommended free-tier models:
|
||||
BAAI/bge-reranker-base ~570MB VRAM, fast
|
||||
BAAI/bge-reranker-v2-m3 ~570MB VRAM, multilingual
|
||||
BAAI/bge-reranker-large ~1.3GB VRAM, higher quality
|
||||
|
||||
Usage:
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
|
||||
"""
|
||||
|
||||
def __init__(self, model_id: str = "BAAI/bge-reranker-base") -> None:
|
||||
self._model_id = model_id
|
||||
self._reranker: object | None = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
"""Explicitly load model weights. Called automatically on first rerank()."""
|
||||
if _FlagReranker is None:
|
||||
raise ImportError(
|
||||
"FlagEmbedding is not installed. "
|
||||
"Run: pip install circuitforge-core[reranker-bge]"
|
||||
)
|
||||
with self._lock:
|
||||
if self._reranker is None:
|
||||
logger.info("Loading BGE reranker: %s (fp16=%s)", self._model_id, _cuda_available())
|
||||
self._reranker = _FlagReranker(self._model_id, use_fp16=_cuda_available())
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Release model weights. Useful for VRAM management between tasks."""
|
||||
with self._lock:
|
||||
self._reranker = None
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
if self._reranker is None:
|
||||
self.load()
|
||||
pairs = [[query, c] for c in candidates]
|
||||
with self._lock:
|
||||
scores: list[float] = self._reranker.compute_score( # type: ignore[union-attr]
|
||||
pairs, normalize=True
|
||||
)
|
||||
# compute_score may return a single float when given one pair.
|
||||
if isinstance(scores, float):
|
||||
scores = [scores]
|
||||
return scores
|
||||
|
||||
def rerank_batch(
|
||||
self,
|
||||
queries: Sequence[str],
|
||||
candidates: Sequence[Sequence[str]],
|
||||
top_n: int = 0,
|
||||
) -> list[list[object]]:
|
||||
"""Batch all pairs into a single compute_score call for efficiency."""
|
||||
from circuitforge_core.reranker.base import RerankResult
|
||||
|
||||
if self._reranker is None:
|
||||
self.load()
|
||||
|
||||
# Flatten all pairs, recording group boundaries for reconstruction.
|
||||
all_pairs: list[list[str]] = []
|
||||
group_sizes: list[int] = []
|
||||
for q, cs in zip(queries, candidates):
|
||||
cands = list(cs)
|
||||
group_sizes.append(len(cands))
|
||||
all_pairs.extend([q, c] for c in cands)
|
||||
|
||||
if not all_pairs:
|
||||
return [[] for _ in queries]
|
||||
|
||||
with self._lock:
|
||||
all_scores: list[float] = self._reranker.compute_score( # type: ignore[union-attr]
|
||||
all_pairs, normalize=True
|
||||
)
|
||||
if isinstance(all_scores, float):
|
||||
all_scores = [all_scores]
|
||||
|
||||
# Reconstruct per-query result lists.
|
||||
results: list[list[RerankResult]] = []
|
||||
offset = 0
|
||||
for (q, cs), size in zip(zip(queries, candidates), group_sizes):
|
||||
cands = list(cs)
|
||||
scores = all_scores[offset : offset + size]
|
||||
offset += size
|
||||
sorted_results = sorted(
|
||||
(RerankResult(candidate=c, score=s, rank=0) for c, s in zip(cands, scores)),
|
||||
key=lambda r: r.score,
|
||||
reverse=True,
|
||||
)
|
||||
if top_n > 0:
|
||||
sorted_results = sorted_results[:top_n]
|
||||
results.append([
|
||||
RerankResult(candidate=r.candidate, score=r.score, rank=i)
|
||||
for i, r in enumerate(sorted_results)
|
||||
])
|
||||
return results
|
||||
37
circuitforge_core/reranker/adapters/mock.py
Normal file
37
circuitforge_core/reranker/adapters/mock.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# circuitforge_core/reranker/adapters/mock.py — deterministic mock reranker
|
||||
#
|
||||
# Always importable, no optional deps. Used in tests and CF_RERANKER_MOCK=1 mode.
|
||||
# Scores by descending overlap of query tokens with candidate tokens so results
|
||||
# are deterministic and meaningful enough to exercise product code paths.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import RerankResult, TextReranker
|
||||
|
||||
|
||||
class MockTextReranker(TextReranker):
|
||||
"""Deterministic reranker for tests. No model weights required.
|
||||
|
||||
Scoring: Jaccard similarity between query token set and candidate token set.
|
||||
Ties broken by candidate length (shorter wins) then lexicographic order,
|
||||
so test assertions can be written against a stable ordering.
|
||||
"""
|
||||
|
||||
_MODEL_ID = "mock"
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._MODEL_ID
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
q_tokens = set(query.lower().split())
|
||||
scores: list[float] = []
|
||||
for candidate in candidates:
|
||||
c_tokens = set(candidate.lower().split())
|
||||
union = q_tokens | c_tokens
|
||||
if not union:
|
||||
scores.append(0.0)
|
||||
else:
|
||||
scores.append(len(q_tokens & c_tokens) / len(union))
|
||||
return scores
|
||||
135
circuitforge_core/reranker/base.py
Normal file
135
circuitforge_core/reranker/base.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
# circuitforge_core/reranker/base.py — Reranker Protocol + modality branches
|
||||
#
|
||||
# MIT licensed. The Protocol and RerankResult are always importable.
|
||||
# Adapter implementations (BGE, Qwen3, cf-orch remote) require optional extras.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, Sequence, runtime_checkable
|
||||
|
||||
|
||||
# ── Result type ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerankResult:
|
||||
"""A single scored candidate returned by a reranker.
|
||||
|
||||
rank is 0-based (0 = highest score).
|
||||
candidate preserves the original object — text, Path, or any other type
|
||||
passed in by the caller, so products don't need to re-index the input list.
|
||||
"""
|
||||
candidate: Any
|
||||
score: float
|
||||
rank: int
|
||||
|
||||
|
||||
# ── Trunk: generic Reranker Protocol ─────────────────────────────────────────
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Reranker(Protocol):
|
||||
"""
|
||||
Abstract interface for all reranker adapters.
|
||||
|
||||
Implementations must be safe to construct once and call concurrently;
|
||||
internal state (loaded model weights) should be guarded by a lock if
|
||||
the backend is not thread-safe.
|
||||
|
||||
query — the reference item to rank against (typically a text query)
|
||||
candidates — ordered collection of items to score; ordering is preserved
|
||||
in the returned list, which is sorted by score descending
|
||||
top_n — return at most this many results; 0 means return all
|
||||
|
||||
Returns a list of RerankResult sorted by score descending (rank 0 first).
|
||||
"""
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
candidates: Sequence[Any],
|
||||
top_n: int = 0,
|
||||
) -> list[RerankResult]:
|
||||
...
|
||||
|
||||
def rerank_batch(
|
||||
self,
|
||||
queries: Sequence[str],
|
||||
candidates: Sequence[Sequence[Any]],
|
||||
top_n: int = 0,
|
||||
) -> list[list[RerankResult]]:
|
||||
"""Score multiple (query, candidates) pairs in one call.
|
||||
|
||||
Default implementation loops over rerank(); adapters may override
|
||||
with a true batched forward pass for efficiency.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
"""Identifier for the loaded model (name, path, or URL)."""
|
||||
...
|
||||
|
||||
|
||||
# ── Branch: text-specific reranker ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TextReranker:
|
||||
"""
|
||||
Base class for text-to-text rerankers.
|
||||
|
||||
Subclasses implement _score_pairs(query, candidates) and get rerank()
|
||||
and rerank_batch() for free. The default rerank_batch() loops over
|
||||
_score_pairs; override it in adapters that support native batching.
|
||||
|
||||
candidates must be strings. query is always a string.
|
||||
"""
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _score_pairs(
|
||||
self,
|
||||
query: str,
|
||||
candidates: list[str],
|
||||
) -> list[float]:
|
||||
"""Return a score per candidate (higher = more relevant).
|
||||
|
||||
Called by rerank() and rerank_batch(). Must return a list of the
|
||||
same length as candidates, in the same order.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
candidates: Sequence[str],
|
||||
top_n: int = 0,
|
||||
) -> list[RerankResult]:
|
||||
cands = list(candidates)
|
||||
if not cands:
|
||||
return []
|
||||
scores = self._score_pairs(query, cands)
|
||||
results = sorted(
|
||||
(RerankResult(candidate=c, score=s, rank=0) for c, s in zip(cands, scores)),
|
||||
key=lambda r: r.score,
|
||||
reverse=True,
|
||||
)
|
||||
if top_n > 0:
|
||||
results = results[:top_n]
|
||||
return [
|
||||
RerankResult(candidate=r.candidate, score=r.score, rank=i)
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
|
||||
def rerank_batch(
|
||||
self,
|
||||
queries: Sequence[str],
|
||||
candidates: Sequence[Sequence[str]],
|
||||
top_n: int = 0,
|
||||
) -> list[list[RerankResult]]:
|
||||
return [
|
||||
self.rerank(q, cs, top_n)
|
||||
for q, cs in zip(queries, candidates)
|
||||
]
|
||||
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[project]
|
||||
name = "circuitforge-core"
|
||||
version = "0.14.0"
|
||||
version = "0.15.0"
|
||||
description = "Shared scaffold for CircuitForge products (MIT)"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
|
|
@ -83,6 +83,14 @@ vision-service = [
|
|||
"uvicorn[standard]>=0.29",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
reranker-bge = [
|
||||
"FlagEmbedding>=1.2",
|
||||
]
|
||||
reranker-qwen3 = [
|
||||
"torch>=2.0",
|
||||
"transformers>=4.40",
|
||||
"accelerate>=0.27",
|
||||
]
|
||||
dev = [
|
||||
"circuitforge-core[manage]",
|
||||
"pytest>=8.0",
|
||||
|
|
|
|||
0
tests/test_reranker/__init__.py
Normal file
0
tests/test_reranker/__init__.py
Normal file
102
tests/test_reranker/test_base.py
Normal file
102
tests/test_reranker/test_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Tests for RerankResult, TextReranker base class, and the public rerank() API."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.reranker.base import RerankResult, TextReranker
|
||||
from circuitforge_core.reranker.adapters.mock import MockTextReranker
|
||||
|
||||
|
||||
# ── RerankResult ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_rerank_result_fields():
|
||||
r = RerankResult(candidate="recipe text", score=0.9, rank=0)
|
||||
assert r.candidate == "recipe text"
|
||||
assert r.score == 0.9
|
||||
assert r.rank == 0
|
||||
|
||||
|
||||
def test_rerank_result_is_frozen():
|
||||
r = RerankResult(candidate="x", score=0.5, rank=0)
|
||||
with pytest.raises(Exception):
|
||||
r.score = 0.1 # type: ignore[misc]
|
||||
|
||||
|
||||
# ── MockTextReranker ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_mock_rerank_returns_sorted_results():
|
||||
reranker = MockTextReranker()
|
||||
results = reranker.rerank(
|
||||
"chicken soup recipe",
|
||||
["chocolate cake recipe", "chicken noodle soup", "tomato basil pasta"],
|
||||
)
|
||||
assert len(results) == 3
|
||||
# "chicken noodle soup" shares more tokens with query → should rank first
|
||||
assert results[0].candidate == "chicken noodle soup"
|
||||
assert results[0].rank == 0
|
||||
assert results[1].rank == 1
|
||||
assert results[2].rank == 2
|
||||
|
||||
|
||||
def test_mock_rerank_top_n():
|
||||
reranker = MockTextReranker()
|
||||
results = reranker.rerank("chicken", ["a", "b chicken", "c chicken soup"], top_n=2)
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
def test_mock_rerank_empty_candidates():
|
||||
reranker = MockTextReranker()
|
||||
assert reranker.rerank("query", []) == []
|
||||
|
||||
|
||||
def test_mock_rerank_scores_descending():
|
||||
reranker = MockTextReranker()
|
||||
results = reranker.rerank("apple pie dessert", ["apple pie", "beef stew", "apple crumble dessert"])
|
||||
scores = [r.score for r in results]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
|
||||
def test_mock_rerank_batch():
|
||||
reranker = MockTextReranker()
|
||||
batch = reranker.rerank_batch(
|
||||
["chicken soup", "chocolate cake"],
|
||||
[["chicken noodle", "beef stew"], ["chocolate mousse", "caesar salad"]],
|
||||
top_n=1,
|
||||
)
|
||||
assert len(batch) == 2
|
||||
assert batch[0][0].candidate == "chicken noodle"
|
||||
assert batch[1][0].candidate == "chocolate mousse"
|
||||
|
||||
|
||||
def test_mock_model_id():
|
||||
assert MockTextReranker().model_id == "mock"
|
||||
|
||||
|
||||
# ── Public API singleton ──────────────────────────────────────────────────────
|
||||
|
||||
def test_rerank_function_mock_mode(monkeypatch):
|
||||
monkeypatch.setenv("CF_RERANKER_MOCK", "1")
|
||||
from circuitforge_core.reranker import rerank, reset_reranker
|
||||
reset_reranker()
|
||||
results = rerank("chicken soup", ["beef stew", "chicken noodle soup", "cake"])
|
||||
assert results[0].candidate == "chicken noodle soup"
|
||||
reset_reranker()
|
||||
|
||||
|
||||
def test_make_reranker_mock_explicit():
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
r = make_reranker(mock=True)
|
||||
assert isinstance(r, MockTextReranker)
|
||||
|
||||
|
||||
def test_make_reranker_unknown_backend_raises():
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
with pytest.raises(ValueError, match="Unknown reranker backend"):
|
||||
make_reranker(backend="nonexistent")
|
||||
|
||||
|
||||
def test_reranker_protocol_conformance():
|
||||
"""MockTextReranker satisfies the Reranker Protocol."""
|
||||
from circuitforge_core.reranker.base import Reranker
|
||||
assert isinstance(MockTextReranker(), Reranker)
|
||||
96
tests/test_reranker/test_bge.py
Normal file
96
tests/test_reranker/test_bge.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Tests for BGETextReranker with mocked FlagEmbedding."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.reranker.adapters.bge import BGETextReranker
|
||||
from circuitforge_core.reranker.base import RerankResult
|
||||
|
||||
|
||||
def _make_mock_flag_reranker(scores: list[float]) -> MagicMock:
|
||||
"""Return a mock FlagReranker that yields the given scores."""
|
||||
m = MagicMock()
|
||||
m.compute_score.return_value = scores
|
||||
return m
|
||||
|
||||
|
||||
# ── BGETextReranker unit tests ────────────────────────────────────────────────
|
||||
|
||||
def test_bge_rerank_scores_and_sorts():
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
mock_fr = _make_mock_flag_reranker([0.2, 0.9, 0.5])
|
||||
reranker._reranker = mock_fr
|
||||
|
||||
results = reranker.rerank("query", ["a", "b", "c"])
|
||||
assert len(results) == 3
|
||||
assert results[0].candidate == "b" # highest score 0.9
|
||||
assert results[0].score == pytest.approx(0.9)
|
||||
assert results[0].rank == 0
|
||||
assert results[1].candidate == "c"
|
||||
assert results[2].candidate == "a"
|
||||
|
||||
|
||||
def test_bge_rerank_top_n():
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
reranker._reranker = _make_mock_flag_reranker([0.1, 0.8, 0.5])
|
||||
results = reranker.rerank("q", ["a", "b", "c"], top_n=2)
|
||||
assert len(results) == 2
|
||||
assert results[0].candidate == "b"
|
||||
|
||||
|
||||
def test_bge_rerank_single_candidate_float_return():
|
||||
"""compute_score returns a float (not list) for a single pair."""
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
mock_fr = MagicMock()
|
||||
mock_fr.compute_score.return_value = 0.75 # single float
|
||||
reranker._reranker = mock_fr
|
||||
results = reranker.rerank("q", ["only candidate"])
|
||||
assert len(results) == 1
|
||||
assert results[0].score == pytest.approx(0.75)
|
||||
|
||||
|
||||
def test_bge_rerank_batch_flattens_pairs():
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
mock_fr = _make_mock_flag_reranker([0.9, 0.1, 0.3, 0.8])
|
||||
reranker._reranker = mock_fr
|
||||
|
||||
batch = reranker.rerank_batch(
|
||||
["q1", "q2"],
|
||||
[["a1", "a2"], ["b1", "b2"]],
|
||||
)
|
||||
assert len(batch) == 2
|
||||
# q1: scores [0.9, 0.1] → a1 first
|
||||
assert batch[0][0].candidate == "a1"
|
||||
# q2: scores [0.3, 0.8] → b2 first
|
||||
assert batch[1][0].candidate == "b2"
|
||||
|
||||
# All pairs were sent in a single compute_score call
|
||||
all_pairs = mock_fr.compute_score.call_args[0][0]
|
||||
assert len(all_pairs) == 4
|
||||
|
||||
|
||||
def test_bge_rerank_empty_batch():
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
reranker._reranker = MagicMock()
|
||||
result = reranker.rerank_batch([], [], top_n=5)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_bge_load_raises_without_flagembedding():
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
with patch("circuitforge_core.reranker.adapters.bge._FlagReranker", None):
|
||||
with pytest.raises(ImportError, match="FlagEmbedding"):
|
||||
reranker.load()
|
||||
|
||||
|
||||
def test_bge_model_id():
|
||||
assert BGETextReranker("BAAI/bge-reranker-v2-m3").model_id == "BAAI/bge-reranker-v2-m3"
|
||||
|
||||
|
||||
def test_bge_unload_clears_reranker():
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
reranker._reranker = MagicMock()
|
||||
reranker.unload()
|
||||
assert reranker._reranker is None
|
||||
Loading…
Reference in a new issue