feat(reranker): full adapter suite + cf-orch auto-routing (closes #54)
Five backends: BGE (FlagEmbedding), Qwen3 (generative yes/no logit scorer, batched forward pass), CrossEncoder (sentence-transformers, covers mxbai-rerank / ms-marco / jina), Cohere (BYOK cloud), Remote (HTTP delegate to cf-reranker service). Mock adapter for tests. 54 tests. cf-reranker FastAPI service app (port 8011) — cf-orch manages as a process, defaults to Qwen3-Reranker-0.6B. make_reranker() auto-detects CF_ORCH_URL and routes to cf-orch cf-reranker when set — cloud apps (Kiwi, Peregrine, Snipe) get remote Qwen3 reranking with zero code changes. Local dev falls back to local BGE. pyproject extras: reranker-bge, reranker-qwen3, reranker-cross-encoder, reranker-cohere, reranker-service.
This commit is contained in:
parent
b21d6acc8e
commit
185057d8ca
12 changed files with 1277 additions and 4 deletions
31
CHANGELOG.md
31
CHANGELOG.md
|
|
@ -6,6 +6,37 @@ Versions follow [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|||
|
||||
---
|
||||
|
||||
## [0.17.0] — 2026-04-27
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.reranker`** — shared reranker module for RAG pipelines across the orchard (MIT, closes #54)
|
||||
|
||||
Five backend adapters covering local and cloud paths, plus a mock adapter for tests:
|
||||
|
||||
- `adapters/bge.py` — `BGETextReranker`: FlagEmbedding cross-encoder (`BAAI/bge-reranker-*`). Batches all pairs in a single `compute_score()` call via `rerank_batch()`. Thread-safe with internal lock. Free tier.
|
||||
- `adapters/qwen3.py` — `Qwen3TextReranker`: generative reranker using `AutoModelForCausalLM`. Scores by reading yes/no token logits at the last input position after pre-filling the assistant `<think>\n\n</think>` block — one forward pass per batch, no generation loop. Left-pads for consistent last-token position across batch. Free / Paid tier.
|
||||
- `adapters/cross_encoder.py` — `CrossEncoderTextReranker`: sentence-transformers `CrossEncoder`. Broader model coverage: `mxbai-rerank-*`, `ms-marco-MiniLM-*`, `jina-reranker-*`. Free tier.
|
||||
- `adapters/cohere.py` — `CohereTextReranker`: Cohere Rerank API (BYOK cloud path). Reads `COHERE_API_KEY` from env or explicit `api_key=` arg. Restores original candidate order from Cohere's score-sorted response. Paid / BYOK.
|
||||
- `adapters/remote.py` — `RemoteTextReranker`: HTTP delegate to a cf-reranker service endpoint. `from_cf_orch()` classmethod allocates via cf-orch on demand. `release()` method returns the lease.
|
||||
- `adapters/mock.py` — `MockTextReranker`: Jaccard-similarity scorer, no model required. Used in tests and `CF_RERANKER_MOCK=1` mode.
|
||||
|
||||
`app.py` — `cf-reranker` FastAPI service (port 8011). Managed by cf-orch as a process-type service. Exposes `GET /health` and `POST /rerank`. Defaults to `Qwen3-Reranker-0.6B`.
|
||||
|
||||
**Auto cf-orch routing:** `make_reranker()` checks `CF_ORCH_URL` at construction time. When set (cloud deployments), it automatically allocates a `cf-reranker` service via cf-orch and returns a `RemoteTextReranker` — no code changes needed in Kiwi, Peregrine, or Snipe. Local dev (no `CF_ORCH_URL`) falls back to local BGE inference.
|
||||
|
||||
**Public API:**
|
||||
- `rerank(query, candidates, top_n)` — process-level singleton, mock-safe
|
||||
- `make_reranker(model_id, backend, mock)` — explicit instance
|
||||
- `reset_reranker()` — test teardown only
|
||||
- `RerankResult(candidate, score, rank)` — frozen dataclass result type
|
||||
|
||||
**`pyproject.toml` extras:** `reranker-bge`, `reranker-qwen3`, `reranker-cross-encoder`, `reranker-cohere`, `reranker-service`
|
||||
|
||||
54 tests across all adapters.
|
||||
|
||||
---
|
||||
|
||||
## [0.14.0] — 2026-04-20
|
||||
|
||||
### Added
|
||||
|
|
|
|||
|
|
@ -92,7 +92,21 @@ def make_reranker(
|
|||
return MockTextReranker()
|
||||
|
||||
_model_id = model_id or os.environ.get("CF_RERANKER_MODEL", _DEFAULT_MODEL)
|
||||
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "bge")
|
||||
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "")
|
||||
|
||||
# Auto-route to cf-orch when CF_ORCH_URL is set and no explicit backend override.
|
||||
# Cloud deployments set CF_ORCH_URL; local dev leaves it unset → local inference.
|
||||
if not _backend:
|
||||
orch_url = os.environ.get("CF_ORCH_URL", "")
|
||||
if orch_url:
|
||||
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
|
||||
logger.info("[reranker] CF_ORCH_URL set — using remote cf-reranker via cf-orch")
|
||||
return RemoteTextReranker.from_cf_orch(
|
||||
orch_url=orch_url,
|
||||
service="cf-reranker",
|
||||
ttl_s=float(os.environ.get("CF_RERANKER_TTL", "3600")),
|
||||
)
|
||||
_backend = "bge" # local default
|
||||
|
||||
if _backend == "mock":
|
||||
return MockTextReranker()
|
||||
|
|
@ -101,10 +115,25 @@ def make_reranker(
|
|||
from circuitforge_core.reranker.adapters.bge import BGETextReranker
|
||||
return BGETextReranker(_model_id)
|
||||
|
||||
if _backend == "qwen3":
|
||||
from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker
|
||||
return Qwen3TextReranker(_model_id)
|
||||
|
||||
if _backend == "cross-encoder":
|
||||
from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker
|
||||
return CrossEncoderTextReranker(_model_id)
|
||||
|
||||
if _backend == "cohere":
|
||||
from circuitforge_core.reranker.adapters.cohere import CohereTextReranker
|
||||
return CohereTextReranker(model=_model_id)
|
||||
|
||||
if _backend == "remote":
|
||||
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
|
||||
return RemoteTextReranker(_model_id)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reranker backend {_backend!r}. "
|
||||
"Valid options: 'bge', 'mock'. "
|
||||
"(Qwen3 support is coming in Phase 2.)"
|
||||
"Valid options: 'bge', 'qwen3', 'cross-encoder', 'cohere', 'remote', 'mock'."
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
94
circuitforge_core/reranker/adapters/cohere.py
Normal file
94
circuitforge_core/reranker/adapters/cohere.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
# circuitforge_core/reranker/adapters/cohere.py — Cohere Rerank API (BYOK cloud)
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-cohere]
|
||||
# API key: set COHERE_API_KEY env var, or pass api_key= explicitly.
|
||||
#
|
||||
# Models (as of 2026):
|
||||
# rerank-english-v3.0 English-only, highest quality
|
||||
# rerank-multilingual-v3.0 Multilingual
|
||||
# rerank-english-v2.0 Legacy, lower cost
|
||||
#
|
||||
# BYOK unlock path: free-tier users who supply their own Cohere key get cloud
|
||||
# reranking without needing a cf-orch node. Same pattern as the Anthropic
|
||||
# backend in LLMRouter.
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import cohere as _cohere # type: ignore[import]
|
||||
except ImportError:
|
||||
_cohere = None # type: ignore[assignment]
|
||||
|
||||
_DEFAULT_MODEL = "rerank-english-v3.0"
|
||||
|
||||
|
||||
class CohereTextReranker(TextReranker):
    """
    Cloud reranker backed by the Cohere Rerank API.

    BYOK (bring your own key): pass api_key= or set COHERE_API_KEY in the
    environment.  No model weights loaded locally.

    Usage:
        reranker = CohereTextReranker()  # reads COHERE_API_KEY from env
        results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])

    With an explicit key and model:
        reranker = CohereTextReranker(
            api_key="co-...",
            model="rerank-multilingual-v3.0",
        )
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str = _DEFAULT_MODEL,
        max_chunks_per_doc: int = 1,
    ) -> None:
        """
        Store configuration only; no client is built and no key is validated here.

        Args:
            api_key: Explicit Cohere key.  When None/empty, COHERE_API_KEY is
                read from the environment at call time.
            model: Cohere rerank model name.
            max_chunks_per_doc: Forwarded unchanged to the Cohere rerank call.
        """
        self._api_key_arg = api_key
        self._model = model
        self._max_chunks_per_doc = max_chunks_per_doc

    @property
    def model_id(self) -> str:
        """Identifier of the form ``cohere:<model>``."""
        return f"cohere:{self._model}"

    def _get_client(self) -> object:
        """
        Build a Cohere client for the current call.

        The key is resolved on every call (explicit argument first, then the
        COHERE_API_KEY environment variable), so a key exported after
        construction is still picked up.

        Raises:
            ImportError: the ``cohere`` package is not installed.
            RuntimeError: no API key is available.
        """
        if _cohere is None:
            raise ImportError(
                "cohere is not installed. "
                "Run: pip install circuitforge-core[reranker-cohere]"
            )
        api_key = self._api_key_arg or os.environ.get("COHERE_API_KEY", "")
        if not api_key:
            raise RuntimeError(
                "Cohere API key is not set. "
                "Pass api_key= to CohereTextReranker or set COHERE_API_KEY."
            )
        return _cohere.Client(api_key=api_key)

    def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
        """
        Score every candidate against ``query`` with one Cohere rerank call.

        Returns:
            One score per candidate, in the same order as ``candidates``.
            An empty candidate list yields ``[]`` without touching the API.
        """
        # Nothing to score: skip the network round-trip (and the top_n=0
        # request it would otherwise produce).
        if not candidates:
            return []

        client = self._get_client()
        response = client.rerank(  # type: ignore[union-attr]
            query=query,
            documents=candidates,
            model=self._model,
            top_n=len(candidates),
            max_chunks_per_doc=self._max_chunks_per_doc,
        )
        # response.results is sorted by relevance_score desc; rebuild
        # in original candidate order so TextReranker.rerank() re-sorts correctly.
        score_map: dict[int, float] = {
            r.index: r.relevance_score for r in response.results
        }
        # Any index missing from the response falls back to 0.0.
        return [score_map.get(i, 0.0) for i in range(len(candidates))]
|
||||
96
circuitforge_core/reranker/adapters/cross_encoder.py
Normal file
96
circuitforge_core/reranker/adapters/cross_encoder.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
# circuitforge_core/reranker/adapters/cross_encoder.py — sentence-transformers CrossEncoder
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-cross-encoder]
|
||||
#
|
||||
# Covers models not in the FlagEmbedding ecosystem:
|
||||
# mixedbread-ai/mxbai-rerank-base-v1 ~570MB VRAM, strong general-purpose
|
||||
# mixedbread-ai/mxbai-rerank-large-v1 ~1.3GB VRAM, higher quality
|
||||
# cross-encoder/ms-marco-MiniLM-L-6-v2 ~90MB, fast, English-only
|
||||
# cross-encoder/ms-marco-MiniLM-L-12-v2 ~130MB, balanced
|
||||
# jinaai/jina-reranker-v2-base-multilingual ~280MB, multilingual
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder as _CrossEncoder # type: ignore[import]
|
||||
except ImportError:
|
||||
_CrossEncoder = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class CrossEncoderTextReranker(TextReranker):
    """
    Cross-encoder reranker using the sentence-transformers CrossEncoder class.

    Broader model compatibility than BGETextReranker — any HuggingFace model
    with a sequence-classification head works here.  Particularly well-suited
    for the mxbai-rerank and ms-marco families.

    Usage:
        reranker = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1")
        results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
    """

    def __init__(
        self,
        model_id: str = "mixedbread-ai/mxbai-rerank-base-v1",
        max_length: int = 512,
    ) -> None:
        """
        Record configuration; weights are loaded lazily on first use.

        Args:
            model_id: HuggingFace model ID or local path passed to CrossEncoder.
            max_length: ``max_length`` forwarded to the CrossEncoder constructor.
        """
        self._model_id = model_id
        self._max_length = max_length
        self._model: object | None = None
        # Guards both (un)loading and inference on the shared model object.
        self._lock = threading.Lock()

    @property
    def model_id(self) -> str:
        """The model ID this adapter was configured with."""
        return self._model_id

    @staticmethod
    def _check_installed() -> None:
        """Raise ImportError when sentence-transformers is unavailable."""
        if _CrossEncoder is None:
            raise ImportError(
                "sentence-transformers is not installed. "
                "Run: pip install circuitforge-core[reranker-cross-encoder]"
            )

    def _load_locked(self) -> object:
        """
        Load the model if needed and return it.  Caller must hold ``self._lock``.
        """
        if self._model is None:
            self._check_installed()
            device = "cuda" if _cuda_available() else "cpu"
            logger.info(
                "Loading CrossEncoder reranker: %s (device=%s)", self._model_id, device
            )
            self._model = _CrossEncoder(
                self._model_id,
                max_length=self._max_length,
                device=device,
            )
        return self._model

    def load(self) -> None:
        """Explicitly load model weights.  Called automatically on first rerank()."""
        self._check_installed()
        with self._lock:
            self._load_locked()

    def unload(self) -> None:
        """Release model weights."""
        with self._lock:
            self._model = None

    def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
        """
        Score each (query, candidate) pair with a single ``predict()`` call.

        Returns:
            One float per candidate, in input order; ``[]`` for no candidates.
        """
        if not candidates:
            return []
        pairs = [(query, c) for c in candidates]
        with self._lock:
            # Load and predict under the same lock acquisition so a concurrent
            # unload() cannot null the model between the check and the call.
            model = self._load_locked()
            raw = model.predict(pairs)  # type: ignore[attr-defined]
        # predict() returns a numpy array or list; normalise to plain floats.
        return [float(s) for s in raw]
|
||||
239
circuitforge_core/reranker/adapters/qwen3.py
Normal file
239
circuitforge_core/reranker/adapters/qwen3.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
# circuitforge_core/reranker/adapters/qwen3.py — Qwen3-Reranker adapter
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-qwen3]
|
||||
# Tested with: Qwen/Qwen3-Reranker-0.6B, -1.5B, -8B
|
||||
#
|
||||
# Scoring mechanism (generative reranker):
|
||||
# Rather than generating a full response, we pre-fill the assistant turn with
|
||||
# the <think>\n\n</think>\n block and read the logits at the last input token
|
||||
# position. The softmax probability of "yes" vs "no" at that position is the
|
||||
# relevance score — one forward pass per batch, no generation loop.
|
||||
#
|
||||
# Prompt format (Qwen3 chat template):
|
||||
# system: "Judge whether the Document meets the requirements based on the
|
||||
# Query and the Instruct. Note that the answer can only be 'yes'
|
||||
# or 'no'."
|
||||
# user: "<Instruct>: {task}\n<Query>: {query}\n<Document>: {doc}"
|
||||
# assistant (pre-filled): "<think>\n\n</think>\n\n"
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Optional heavy deps — imported here when available; a missing package is only reported when load() runs.
|
||||
try:
|
||||
import torch as _torch # type: ignore[import]
|
||||
except ImportError:
|
||||
_torch = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM as _AutoModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer as _AutoTokenizer # type: ignore[import]
|
||||
except ImportError:
|
||||
_AutoModel = None # type: ignore[assignment]
|
||||
_AutoTokenizer = None # type: ignore[assignment]
|
||||
|
||||
# System prompt used for all reranking tasks.
|
||||
_SYSTEM_PROMPT = (
|
||||
"Judge whether the Document meets the requirements based on the Query and "
|
||||
'the Instruct. Note that the answer can only be "yes" or "no".'
|
||||
)
|
||||
|
||||
# Default task instruction — products can override via Qwen3TextReranker(task=...).
|
||||
_DEFAULT_TASK = "Given a query, retrieve the most relevant document that answers the query."
|
||||
|
||||
# The pre-filled assistant turn that puts the model past its thinking block
|
||||
# so the very next token position scores "yes" vs "no".
|
||||
_ASSISTANT_PREFILL = "<think>\n\n</think>\n\n"
|
||||
|
||||
|
||||
def _requires_deps() -> None:
    """Raise ImportError naming the first optional dependency that failed to import."""
    # Checked in this order so a missing torch is reported before transformers.
    for dependency, package in ((_torch, "torch"), (_AutoModel, "transformers")):
        if dependency is None:
            raise ImportError(
                f"{package} is not installed. Run: pip install circuitforge-core[reranker-qwen3]"
            )
|
||||
|
||||
|
||||
class Qwen3TextReranker(TextReranker):
    """
    Generative reranker using the Qwen3-Reranker model family.

    Scores candidates by reading yes/no token logits at the last input position
    after pre-filling the assistant thinking block.  One forward pass covers an
    entire batch — efficient for ranking large candidate lists.

    Model options (by tier):
        Free:  Qwen/Qwen3-Reranker-0.6B  (~1.2 GB VRAM fp16)
               Qwen/Qwen3-Reranker-1.5B  (~3.0 GB VRAM fp16)
        Paid:  Qwen/Qwen3-Reranker-8B    (~16 GB VRAM fp16, or ~9 GB int8)

    Usage:
        reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
        results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])

    With a custom task instruction:
        reranker = Qwen3TextReranker(
            "Qwen/Qwen3-Reranker-1.5B",
            task="Given a job description, retrieve the most relevant resume.",
        )
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-Reranker-0.6B",
        task: str = _DEFAULT_TASK,
        device: str | None = None,
        dtype: str = "float16",
        batch_size: int = 32,
    ) -> None:
        """
        Record configuration; tokenizer and weights are loaded lazily.

        Args:
            model_id: HuggingFace model ID or local path.
            task: Instruction inserted into the ``<Instruct>`` slot of the prompt.
            device: Target device; None selects cuda when available, else cpu.
            dtype: One of "float16", "bfloat16", "float32".  Unknown values
                fall back to float16 (with a warning at load time).
            batch_size: Maximum number of candidates per forward pass.

        Raises:
            ValueError: ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # range(..., step=0) would raise later and a negative step would
            # silently score nothing; fail early with a clear message instead.
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self._model_id = model_id
        self._task = task
        self._device = device  # None = auto-detect at load time
        self._dtype_str = dtype
        self._batch_size = batch_size
        self._model: object | None = None
        self._tokenizer: object | None = None
        self._yes_id: int | None = None
        self._no_id: int | None = None
        # Guards load/unload, tokenizer state changes, and the forward pass.
        self._lock = threading.Lock()

    @property
    def model_id(self) -> str:
        """The model ID this adapter was configured with."""
        return self._model_id

    def load(self) -> None:
        """Explicitly load model weights.  Called automatically on first rerank()."""
        _requires_deps()
        with self._lock:
            self._load_locked()

    def _load_locked(self) -> None:
        """
        Load tokenizer, weights and yes/no token IDs if not already loaded.

        Caller must hold ``self._lock``.

        Raises:
            ImportError: torch or transformers is missing.
            RuntimeError: the tokenizer yields no IDs for "yes" or "no".
        """
        if self._model is not None:
            return
        _requires_deps()

        device = self._device or ("cuda" if _torch.cuda.is_available() else "cpu")
        dtype_map: dict[str, object] = {
            "float16": _torch.float16,
            "bfloat16": _torch.bfloat16,
            "float32": _torch.float32,
        }
        if self._dtype_str not in dtype_map:
            logger.warning(
                "Unknown dtype %r for Qwen3 reranker; falling back to float16",
                self._dtype_str,
            )
        torch_dtype = dtype_map.get(self._dtype_str, _torch.float16)

        logger.info(
            "Loading Qwen3 reranker: %s (device=%s dtype=%s)",
            self._model_id, device, self._dtype_str,
        )
        tokenizer = _AutoTokenizer.from_pretrained(
            self._model_id, trust_remote_code=True
        )
        model = _AutoModel.from_pretrained(
            self._model_id,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )
        model.eval()

        # Resolve the token IDs for "yes" and "no" once at load time.
        # Qwen tokenizers encode single-word tokens without a leading space.
        yes_ids: list[int] = tokenizer.encode("yes", add_special_tokens=False)
        no_ids: list[int] = tokenizer.encode("no", add_special_tokens=False)
        if not yes_ids or not no_ids:
            raise RuntimeError(
                f"Could not resolve 'yes'/'no' token IDs from tokenizer {self._model_id!r}. "
                "This model may not be a Qwen3-Reranker variant."
            )

        # Publish state only after everything succeeded, so a failed load
        # leaves the adapter cleanly unloaded.
        self._tokenizer = tokenizer
        self._model = model
        self._yes_id = yes_ids[0]
        self._no_id = no_ids[0]

    def unload(self) -> None:
        """Release model weights.  Useful for VRAM management between tasks."""
        with self._lock:
            self._model = None
            self._tokenizer = None
            self._yes_id = None
            self._no_id = None

    def _build_prompt(self, query: str, document: str) -> str:
        """
        Format a single (query, document) pair as a chat-template prompt.

        Requires ``self._tokenizer`` to be loaded.  The returned text ends with
        the pre-filled assistant thinking block.
        """
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {
                "role": "user",
                "content": (
                    f"<Instruct>: {self._task}\n"
                    f"<Query>: {query}\n"
                    f"<Document>: {document}"
                ),
            },
        ]
        # apply_chat_template without tokenization so we can append the prefill.
        text: str = self._tokenizer.apply_chat_template(  # type: ignore[union-attr]
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return text + _ASSISTANT_PREFILL

    def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
        """Return P(yes) for every candidate, in input order (lazy-loads the model)."""
        return self._score_in_batches(query, candidates)

    def _score_in_batches(self, query: str, candidates: list[str]) -> list[float]:
        """Score all (query, candidate) pairs, splitting into sub-batches."""
        all_scores: list[float] = []
        for start in range(0, len(candidates), self._batch_size):
            batch = candidates[start : start + self._batch_size]
            all_scores.extend(self._score_batch(query, batch))
        return all_scores

    def _score_batch(self, query: str, candidates: list[str]) -> list[float]:
        """
        Single forward pass for one sub-batch.  Returns a score per candidate.

        Tokenization and inference both run under ``self._lock``: the tokenizer's
        ``padding_side`` is temporarily switched to "left", and doing that
        unsynchronised lets a concurrent caller tokenize with the wrong side
        (and therefore read logits from a padding position).
        """
        if not candidates:
            return []

        with self._lock:
            # Ensure the model is present for the whole critical section so a
            # concurrent unload() cannot null it between check and use.
            self._load_locked()
            tokenizer = self._tokenizer
            model = self._model
            yes_id = self._yes_id
            no_id = self._no_id

            prompts = [self._build_prompt(query, c) for c in candidates]

            # Left-pad so the last token position is consistent across all sequences.
            original_side = getattr(tokenizer, "padding_side", "right")
            tokenizer.padding_side = "left"  # type: ignore[union-attr]
            try:
                # NOTE(review): truncation presumably trims from the right by
                # default, which would drop the assistant prefill for prompts
                # longer than max_length — TODO confirm for very long documents.
                encoded = tokenizer(  # type: ignore[operator]
                    prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=4096,
                )
            finally:
                tokenizer.padding_side = original_side  # type: ignore[union-attr]

            device = next(model.parameters()).device  # type: ignore[union-attr]
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)

            with _torch.no_grad():
                outputs = model(  # type: ignore[operator]
                    input_ids=input_ids, attention_mask=attention_mask
                )

        # logits shape: (batch, seq_len, vocab_size)
        # Last position [-1] is the token the model would output next.
        last_logits = outputs.logits[:, -1, :]  # (batch, vocab)
        yes_logits = last_logits[:, yes_id]  # (batch,)
        no_logits = last_logits[:, no_id]  # (batch,)

        # Softmax over yes/no only — score = P(yes | query, doc).
        stacked = _torch.stack([yes_logits, no_logits], dim=-1)  # (batch, 2)
        probs = _torch.softmax(stacked, dim=-1)
        scores: list[float] = probs[:, 0].float().cpu().tolist()
        return scores
|
||||
131
circuitforge_core/reranker/adapters/remote.py
Normal file
131
circuitforge_core/reranker/adapters/remote.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
# circuitforge_core/reranker/adapters/remote.py — HTTP remote reranker adapter
|
||||
#
|
||||
# Calls a cf-reranker service endpoint (cf-orch allocated or static URL).
|
||||
# No model weights loaded locally — all inference runs on the remote node.
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
import requests
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default timeout for a single /rerank call (seconds).
|
||||
# Large candidate lists may take longer — callers can pass timeout= explicitly.
|
||||
_DEFAULT_TIMEOUT = 30
|
||||
|
||||
|
||||
class RemoteTextReranker(TextReranker):
    """
    Reranker that delegates scoring to a remote cf-reranker HTTP service.

    The remote service must implement POST /rerank with the request body::

        {"query": str, "candidates": [str, ...], "top_n": int}

    and return::

        {"results": [{"candidate": str, "score": float, "rank": int}, ...]}

    cf-orch allocation (recommended — starts service on-demand):
        reranker = RemoteTextReranker.from_cf_orch(
            orch_url="http://10.1.10.71:7700",
            service="cf-reranker",
            model_candidates=["qwen3-0.6b"],
        )

    Static URL (e.g. dedicated node already running cf-reranker):
        reranker = RemoteTextReranker("http://10.1.10.10:8011")
    """

    def __init__(
        self,
        base_url: str,
        timeout: int = _DEFAULT_TIMEOUT,
        _model_id: str = "remote",
    ) -> None:
        """
        Args:
            base_url: Root URL of the cf-reranker service (trailing "/" is stripped).
            timeout: Timeout in seconds for each POST /rerank request.
            _model_id: Value reported by ``model_id``; set by ``from_cf_orch()``.
        """
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._model_id_str = _model_id
        # Entered cf-orch allocation context, if created via from_cf_orch().
        self._orch_ctx: object | None = None

    @property
    def model_id(self) -> str:
        """Identifier for this adapter ("remote" or "remote:<service>")."""
        return self._model_id_str

    @classmethod
    def from_cf_orch(
        cls,
        orch_url: str,
        service: str = "cf-reranker",
        model_candidates: list[str] | None = None,
        ttl_s: float = 3600.0,
        timeout: int = _DEFAULT_TIMEOUT,
    ) -> "RemoteTextReranker":
        """
        Allocate a cf-reranker service via cf-orch and return a configured adapter.

        Blocks until allocation succeeds or raises on failure.  The returned
        adapter is valid for the duration of the TTL; create a new one if the
        lease expires.

        This is a one-shot allocation — the caller owns the lifetime.  For
        long-running services, prefer the static URL constructor and let
        cf-orch manage the process independently.

        Raises:
            ImportError: ``circuitforge_orch`` is not installed.
        """
        try:
            from circuitforge_orch.client import CFOrchClient  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "circuitforge_orch is not installed — cannot allocate via cf-orch."
            ) from exc

        client = CFOrchClient(orch_url)
        ctx = client.allocate(
            service,
            model_candidates=model_candidates or [],
            ttl_s=ttl_s,
            caller="reranker-remote",
        )
        alloc = ctx.__enter__()
        try:
            instance = cls(
                base_url=alloc.url,
                timeout=timeout,
                _model_id=f"remote:{service}",
            )
        except Exception as exc:
            # Do not leak the lease if the adapter cannot be constructed.
            ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise
        # The caller is responsible for ending the allocation; the entered
        # context is stashed on the instance so release() can exit it.
        instance._orch_ctx = ctx
        return instance

    def release(self) -> None:
        """
        Release the cf-orch allocation if this adapter was created via from_cf_orch().

        Best-effort: a failure while exiting the allocation is logged, not raised.
        Calling release() more than once is a no-op after the first call.
        """
        ctx = getattr(self, "_orch_ctx", None)
        if ctx is None:
            return
        # Clear first so a failing __exit__ is never retried.
        self._orch_ctx = None
        try:
            ctx.__exit__(None, None, None)
        except Exception:
            logger.warning(
                "Failed to release cf-orch allocation for %s",
                self._model_id_str,
                exc_info=True,
            )

    def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
        """
        POST the query and candidates to ``<base_url>/rerank`` and map scores back.

        Returns:
            One score per candidate in input order; candidates absent from the
            response get 0.0.  An empty candidate list returns ``[]`` without a
            request.

        Raises:
            RuntimeError: the request failed, returned an error status, or the
                response body was not the expected JSON shape.
        """
        if not candidates:
            return []

        url = f"{self._base_url}/rerank"
        payload = {"query": query, "candidates": candidates, "top_n": 0}
        try:
            resp = requests.post(url, json=payload, timeout=self._timeout)
            resp.raise_for_status()
        except requests.RequestException as exc:
            raise RuntimeError(
                f"Remote reranker at {url!r} failed: {exc}"
            ) from exc

        try:
            data = resp.json()
            # Build a score-per-candidate list in the original order.
            score_map: dict[str, float] = {
                r["candidate"]: r["score"] for r in data["results"]
            }
        except (ValueError, KeyError, TypeError) as exc:
            raise RuntimeError(
                f"Remote reranker at {url!r} returned a malformed response: {exc!r}"
            ) from exc
        return [score_map.get(c, 0.0) for c in candidates]
|
||||
169
circuitforge_core/reranker/app.py
Normal file
169
circuitforge_core/reranker/app.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
circuitforge_core.reranker.app — cf-reranker FastAPI service.
|
||||
|
||||
Managed by cf-orch as a process-type service. cf-orch starts this via:
|
||||
|
||||
python -m circuitforge_core.reranker.app \
|
||||
--model BAAI/bge-reranker-base \
|
||||
--backend bge \
|
||||
--port 8011 \
|
||||
--gpu-id 0
|
||||
|
||||
Or with Qwen3:
|
||||
|
||||
python -m circuitforge_core.reranker.app \
|
||||
--model Qwen/Qwen3-Reranker-0.6B \
|
||||
--backend qwen3 \
|
||||
--port 8011 \
|
||||
--gpu-id 0 \
|
||||
--dtype float16
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": "...", "backend": "...", "vram_mb": n}
|
||||
POST /rerank → RerankResponse
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Request / response models ─────────────────────────────────────────────────
|
||||
|
||||
class RerankRequest(BaseModel):
    # Request body for POST /rerank.
    query: str
    candidates: list[str]
    # Forwarded unchanged to the reranker's rerank(); presumably 0 means
    # "return every candidate" — confirm against TextReranker.rerank.
    top_n: int = 0
|
||||
|
||||
|
||||
class RerankResultItem(BaseModel):
    # One scored candidate inside a RerankResponse.
    candidate: str
    score: float
    rank: int
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
    # Response body for POST /rerank.
    results: list[RerankResultItem]
    # model_id of the reranker that produced the results.
    model: str
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    # Response body for GET /health.
    status: str
    model: str
    backend: str
    # Static estimate taken from the VRAM table, not a live measurement.
    vram_mb: int
|
||||
|
||||
|
||||
# ── VRAM estimates by backend/model family ────────────────────────────────────
|
||||
|
||||
_VRAM_TABLE: dict[str, int] = {
|
||||
"bge-reranker-base": 570,
|
||||
"bge-reranker-large": 1300,
|
||||
"bge-reranker-v2-m3": 570,
|
||||
"mxbai-rerank-base-v1": 570,
|
||||
"mxbai-rerank-large-v1": 1300,
|
||||
"ms-marco-MiniLM-L-6-v2": 90,
|
||||
"ms-marco-MiniLM-L-12-v2": 130,
|
||||
"Qwen3-Reranker-0.6B": 1200,
|
||||
"Qwen3-Reranker-1.5B": 3000,
|
||||
"Qwen3-Reranker-8B": 16000,
|
||||
}
|
||||
|
||||
def _estimate_vram(model_id: str) -> int:
|
||||
for key, mb in _VRAM_TABLE.items():
|
||||
if key in model_id:
|
||||
return mb
|
||||
return 1024 # safe default
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
def create_app(model_id: str, backend: str, dtype: str, mock: bool) -> FastAPI:
    """
    Build the cf-reranker FastAPI application around one reranker instance.

    The reranker is created once here via make_reranker() and shared by both
    routes.

    Args:
        model_id: Model identifier handed to make_reranker() and used for the
            VRAM estimate.
        backend: Backend name handed to make_reranker() and echoed by /health.
        dtype: Accepted for CLI symmetry.
            NOTE(review): not forwarded to make_reranker() in this function.
        mock: Passed through to make_reranker().

    Returns:
        The configured FastAPI app exposing GET /health and POST /rerank.
    """
    from circuitforge_core.reranker import make_reranker

    service = FastAPI(title="cf-reranker", version="0.1.0")
    engine = make_reranker(model_id=model_id, backend=backend, mock=mock)
    vram_estimate = _estimate_vram(model_id)

    logger.info(
        "cf-reranker ready: model=%r backend=%r vram=%dMB",
        model_id,
        backend,
        vram_estimate,
    )

    # Route handlers are documented with comments rather than docstrings so the
    # generated OpenAPI descriptions stay unchanged.

    @service.get("/health", response_model=HealthResponse)
    async def health() -> HealthResponse:
        # Static status report; performs no inference.
        return HealthResponse(
            status="ok",
            model=engine.model_id,
            backend=backend,
            vram_mb=vram_estimate,
        )

    @service.post("/rerank", response_model=RerankResponse)
    async def rerank(req: RerankRequest) -> RerankResponse:
        # An empty candidate list is a client error, not an empty result.
        if len(req.candidates) == 0:
            raise HTTPException(status_code=400, detail="candidates must not be empty")

        try:
            ranked = engine.rerank(req.query, req.candidates, top_n=req.top_n)
        except Exception as exc:
            # Top-level boundary: log the traceback, surface the message as a 500.
            logger.exception("rerank failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        items = []
        for hit in ranked:
            items.append(
                RerankResultItem(candidate=hit.candidate, score=hit.score, rank=hit.rank)
            )
        return RerankResponse(results=items, model=engine.model_id)

    return service
|
||||
|
||||
|
||||
# ── CLI entry point ───────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, configure logging, serve the app.

    Mock mode is enabled by any of ``--mock``, ``--backend mock`` or the
    environment variable ``CF_RERANKER_MOCK=1``. It is resolved once, before
    the GPU is pinned, so every mock path skips the ``CUDA_VISIBLE_DEVICES``
    default and hands ``mock=True`` to ``create_app``.
    """
    parser = argparse.ArgumentParser(description="cf-reranker — CircuitForge reranker service")
    parser.add_argument(
        "--model", default="BAAI/bge-reranker-base",
        help="HuggingFace model ID or local path",
    )
    parser.add_argument(
        "--backend", default="bge",
        choices=["bge", "qwen3", "cross-encoder", "mock"],
        help="Reranker backend",
    )
    parser.add_argument("--port", type=int, default=8011)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    parser.add_argument(
        "--dtype", default="float16",
        choices=["float16", "bfloat16", "float32"],
    )
    parser.add_argument("--mock", action="store_true",
                        help="Run with mock backend (no GPU, for testing)")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Resolve mock mode from every source *before* touching the GPU
    # environment; previously the env-var path still pinned a GPU and
    # ``--backend mock`` did not set the mock flag.
    mock = (
        args.mock
        or args.backend == "mock"
        or os.environ.get("CF_RERANKER_MOCK", "") == "1"
    )

    if not mock:
        # setdefault: an already-exported CUDA_VISIBLE_DEVICES wins over --gpu-id.
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))

    app = create_app(
        model_id=args.model,
        backend=args.backend,
        dtype=args.dtype,
        mock=mock,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[project]
|
||||
name = "circuitforge-core"
|
||||
version = "0.15.0"
|
||||
version = "0.17.0"
|
||||
description = "Shared scaffold for CircuitForge products (MIT)"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
|
|
@ -91,6 +91,17 @@ reranker-qwen3 = [
|
|||
"transformers>=4.40",
|
||||
"accelerate>=0.27",
|
||||
]
|
||||
reranker-cross-encoder = [
|
||||
"sentence-transformers>=3.0",
|
||||
]
|
||||
reranker-cohere = [
|
||||
"cohere>=5.0",
|
||||
]
|
||||
reranker-service = [
|
||||
"circuitforge-core[reranker-qwen3]",
|
||||
"fastapi>=0.110",
|
||||
"uvicorn[standard]>=0.29",
|
||||
]
|
||||
dev = [
|
||||
"circuitforge-core[manage]",
|
||||
"pytest>=8.0",
|
||||
|
|
|
|||
106
tests/test_reranker/test_cohere.py
Normal file
106
tests/test_reranker/test_cohere.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""Tests for CohereTextReranker with mocked cohere client."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.reranker.adapters.cohere import CohereTextReranker
|
||||
|
||||
|
||||
def _make_cohere_result(index: int, score: float) -> MagicMock:
    """Build a stand-in for one Cohere result row (``index`` + ``relevance_score``)."""
    return MagicMock(index=index, relevance_score=score)
|
||||
|
||||
|
||||
def _make_mock_client(results: list[MagicMock]) -> MagicMock:
    """Build a fake Cohere client whose ``rerank()`` response carries *results*."""
    fake_client = MagicMock()
    fake_client.rerank.return_value = MagicMock(results=results)
    return fake_client
|
||||
|
||||
|
||||
def test_model_id_includes_model_name():
    """model_id is the configured model name prefixed with ``cohere:``."""
    subject = CohereTextReranker(model="rerank-multilingual-v3.0")
    expected = "cohere:rerank-multilingual-v3.0"
    assert subject.model_id == expected
|
||||
|
||||
|
||||
def test_raises_without_cohere_package():
    """Scoring with the module-level ``_cohere`` set to None raises ImportError."""
    subject = CohereTextReranker(api_key="co-test")
    package_missing = patch("circuitforge_core.reranker.adapters.cohere._cohere", None)
    with package_missing, pytest.raises(ImportError, match="cohere"):
        subject._score_pairs("q", ["doc"])
|
||||
|
||||
|
||||
def test_raises_without_api_key(monkeypatch):
    """Without an api_key argument or COHERE_API_KEY, client creation fails."""
    monkeypatch.delenv("COHERE_API_KEY", raising=False)
    subject = CohereTextReranker()  # no api_key arg
    fake_package = MagicMock()
    package_present = patch(
        "circuitforge_core.reranker.adapters.cohere._cohere", fake_package
    )
    with package_present, pytest.raises(RuntimeError, match="API key"):
        subject._get_client()
|
||||
|
||||
|
||||
def test_reads_api_key_from_env(monkeypatch):
    """A key exported as COHERE_API_KEY is the one handed to ``cohere.Client``."""
    monkeypatch.setenv("COHERE_API_KEY", "co-fromenv")
    fake_package = MagicMock()
    with patch("circuitforge_core.reranker.adapters.cohere._cohere", fake_package):
        subject = CohereTextReranker()
        subject._get_client()
    fake_package.Client.assert_called_once_with(api_key="co-fromenv")
|
||||
|
||||
|
||||
def test_score_pairs_returns_original_order():
    """Cohere returns results sorted by score; we must restore original order."""
    subject = CohereTextReranker(api_key="co-test")
    # (index, score) in the order Cohere ranks them: c first, then a, then b.
    cohere_order = [(2, 0.9), (0, 0.6), (1, 0.1)]
    fake_client = _make_mock_client(
        [_make_cohere_result(index=idx, score=val) for idx, val in cohere_order]
    )
    with patch.object(subject, "_get_client", return_value=fake_client):
        scores = subject._score_pairs("query", ["a", "b", "c"])
    # Original order: a=0.6, b=0.1, c=0.9
    expected = [pytest.approx(val) for val in (0.6, 0.1, 0.9)]
    assert scores == expected
|
||||
|
||||
|
||||
def test_rerank_sorts_correctly():
    """rerank() places the candidate with the highest Cohere score at rank 0."""
    subject = CohereTextReranker(api_key="co-test")
    rows = [
        _make_cohere_result(index=1, score=0.95),
        _make_cohere_result(index=0, score=0.3),
    ]
    fake_client = _make_mock_client(rows)
    with patch.object(subject, "_get_client", return_value=fake_client):
        ranked = subject.rerank("query", ["less relevant", "more relevant"])
    best = ranked[0]
    assert best.candidate == "more relevant"
    assert best.rank == 0
|
||||
|
||||
|
||||
def test_rerank_top_n():
    """top_n=2 trims a three-candidate result down to two entries."""
    subject = CohereTextReranker(api_key="co-test")
    rows = [
        _make_cohere_result(index=idx, score=val)
        for idx, val in ((0, 0.9), (1, 0.5), (2, 0.1))
    ]
    fake_client = _make_mock_client(rows)
    with patch.object(subject, "_get_client", return_value=fake_client):
        ranked = subject.rerank("q", ["a", "b", "c"], top_n=2)
    assert len(ranked) == 2
|
||||
|
||||
|
||||
def test_rerank_calls_cohere_with_correct_args():
    """rerank() forwards query, documents, model and limits to the Cohere client."""
    subject = CohereTextReranker(api_key="co-test", model="rerank-english-v3.0")
    only_row = _make_cohere_result(index=0, score=0.8)
    fake_client = _make_mock_client([only_row])
    with patch.object(subject, "_get_client", return_value=fake_client):
        subject.rerank("my query", ["only doc"])
    expected_kwargs = {
        "query": "my query",
        "documents": ["only doc"],
        "model": "rerank-english-v3.0",
        "top_n": 1,
        "max_chunks_per_doc": 1,
    }
    fake_client.rerank.assert_called_once_with(**expected_kwargs)
|
||||
77
tests/test_reranker/test_cross_encoder.py
Normal file
77
tests/test_reranker/test_cross_encoder.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Tests for CrossEncoderTextReranker with mocked sentence-transformers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker
|
||||
|
||||
|
||||
def _make_mock_cross_encoder(scores: list[float]) -> MagicMock:
    """Build a fake CrossEncoder whose ``predict()`` returns *scores*."""
    return MagicMock(**{"predict.return_value": scores})
|
||||
|
||||
|
||||
def test_model_id():
    """model_id echoes the model name the reranker was constructed with."""
    subject = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1")
    assert subject.model_id == "mixedbread-ai/mxbai-rerank-base-v1"
|
||||
|
||||
|
||||
def test_load_raises_without_sentence_transformers():
    """load() raises ImportError when the module-level ``_CrossEncoder`` is None."""
    subject = CrossEncoderTextReranker()
    package_missing = patch(
        "circuitforge_core.reranker.adapters.cross_encoder._CrossEncoder", None
    )
    with package_missing, pytest.raises(ImportError, match="sentence-transformers"):
        subject.load()
|
||||
|
||||
|
||||
def test_rerank_scores_and_sorts():
    """Candidates come back ordered by descending predicted score."""
    subject = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1")
    subject._model = _make_mock_cross_encoder([0.2, 0.9, 0.5])

    ranked = subject.rerank("query", ["a", "b", "c"])

    best = ranked[0]
    assert best.candidate == "b"
    assert best.rank == 0
    assert ranked[2].candidate == "a"
|
||||
|
||||
|
||||
def test_rerank_top_n():
    """top_n=2 keeps only the two best candidates, best first."""
    subject = CrossEncoderTextReranker()
    subject._model = _make_mock_cross_encoder([0.1, 0.8, 0.5])
    ranked = subject.rerank("q", ["a", "b", "c"], top_n=2)
    assert len(ranked) == 2
    assert ranked[0].candidate == "b"
|
||||
|
||||
|
||||
def test_predict_called_with_pairs():
    """predict() receives one (query, candidate) tuple per candidate."""
    fake_model = _make_mock_cross_encoder([0.7, 0.3])
    subject = CrossEncoderTextReranker()
    subject._model = fake_model

    subject.rerank("chicken soup", ["recipe one", "recipe two"])

    # First positional argument of the recorded predict() call.
    pairs = fake_model.predict.call_args.args[0]
    expected = [("chicken soup", "recipe one"), ("chicken soup", "recipe two")]
    assert pairs == expected
|
||||
|
||||
|
||||
def test_numpy_scores_coerced_to_float():
    """predict() may return numpy floats — verify they're converted cleanly."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not installed")
    numpy_scores = np.array([0.8, 0.2])

    subject = CrossEncoderTextReranker()
    subject._model = _make_mock_cross_encoder(numpy_scores)  # type: ignore[arg-type]
    ranked = subject.rerank("q", ["a", "b"])
    assert isinstance(ranked[0].score, float)
|
||||
|
||||
|
||||
def test_unload_clears_model():
    """unload() resets the cached ``_model`` attribute to None."""
    subject = CrossEncoderTextReranker()
    subject._model = MagicMock()
    subject.unload()
    assert subject._model is None
|
||||
185
tests/test_reranker/test_qwen3.py
Normal file
185
tests/test_reranker/test_qwen3.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""Tests for Qwen3TextReranker with mocked transformers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker, _ASSISTANT_PREFILL
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_mock_model(yes_logit: float = 5.0, no_logit: float = 1.0, batch_size: int = 1):
    """Return a mock AutoModelForCausalLM that outputs fixed yes/no logits.

    Args:
        yes_logit: Value written at token id 9693 for every batch row.
        no_logit: Value written at token id 2201 for every batch row.
        batch_size: Number of rows in the returned logits tensor.

    Returns:
        A MagicMock that, when called, yields an object whose ``logits`` is a
        ``(batch_size, 1, 32000)`` tensor, and whose ``parameters()`` yields a
        single parameter located on the CPU.
    """
    import torch

    # Logits are (batch, seq_len=1, vocab_size=32000); 9693 / 2201 are the
    # yes / no token ids the tests in this file configure on the reranker.
    vocab_size = 32000
    logits = torch.zeros(batch_size, 1, vocab_size)
    logits[:, :, 9693] = yes_logit  # "yes" token
    logits[:, :, 2201] = no_logit  # "no" token

    output = MagicMock()
    output.logits = logits

    model = MagicMock()
    model.return_value = output

    # Supports next(model.parameters()).device.  A side_effect hands out a
    # fresh iterator on every call; a shared ``return_value`` iterator would
    # be exhausted after the first call and raise StopIteration afterwards.
    param = MagicMock()
    param.device = torch.device("cpu")
    model.parameters.side_effect = lambda *args, **kwargs: iter([param])
    return model
|
||||
|
||||
|
||||
def _make_mock_tokenizer(yes_id: int = 9693, no_id: int = 2201):
    """Return a mock AutoTokenizer.

    ``encode("yes")`` gives ``[yes_id]`` and any other text gives ``[no_id]``;
    ``apply_chat_template`` always returns ``"<prompt>"``; calling the
    tokenizer returns a fixed 1x10 ``input_ids`` / ``attention_mask`` pair.
    """
    import torch

    def _encode(text, **kw):
        if text == "yes":
            return [yes_id]
        return [no_id]

    # Simple fixed-length tensors returned from __call__.
    fixed_encoding = {
        "input_ids": torch.zeros(1, 10, dtype=torch.long),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }

    fake_tokenizer = MagicMock()
    fake_tokenizer.encode.side_effect = _encode
    fake_tokenizer.apply_chat_template.return_value = "<prompt>"
    fake_tokenizer.padding_side = "right"
    fake_tokenizer.return_value = fixed_encoding
    return fake_tokenizer
|
||||
|
||||
|
||||
# ── Unit tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_load_raises_without_torch():
    """load() raises ImportError when the module-level ``_torch`` is None."""
    subject = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
    torch_missing = patch("circuitforge_core.reranker.adapters.qwen3._torch", None)
    with torch_missing, pytest.raises(ImportError, match="torch"):
        subject.load()
|
||||
|
||||
|
||||
def test_load_raises_without_transformers():
    """load() raises ImportError when the module-level ``_AutoModel`` is None."""
    subject = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
    transformers_missing = patch(
        "circuitforge_core.reranker.adapters.qwen3._AutoModel", None
    )
    with transformers_missing, pytest.raises(ImportError, match="transformers"):
        subject.load()
|
||||
|
||||
|
||||
def test_model_id():
    """model_id echoes the model name the reranker was constructed with."""
    subject = Qwen3TextReranker("Qwen/Qwen3-Reranker-1.5B")
    assert subject.model_id == "Qwen/Qwen3-Reranker-1.5B"
|
||||
|
||||
|
||||
def test_unload_clears_state():
    """unload() resets the model, tokenizer and both cached token ids to None."""
    subject = Qwen3TextReranker()
    subject._model = MagicMock()
    subject._tokenizer = MagicMock()
    subject._yes_id = 1
    subject._no_id = 2

    subject.unload()

    for attr in ("_model", "_tokenizer", "_yes_id", "_no_id"):
        assert getattr(subject, attr) is None
|
||||
|
||||
|
||||
def test_build_prompt_includes_prefill():
    """The prompt built for a (query, candidate) pair contains the assistant prefill."""
    subject = Qwen3TextReranker()
    subject._tokenizer = _make_mock_tokenizer()
    built = subject._build_prompt("what is chicken soup", "a hearty recipe")
    assert _ASSISTANT_PREFILL in built
|
||||
|
||||
|
||||
def test_score_batch_returns_yes_probability():
    """Higher yes_logit → score closer to 1.0."""
    # torch is imported by the mock helpers; this test does not use it directly.
    reranker = Qwen3TextReranker()
    reranker._tokenizer = _make_mock_tokenizer()
    reranker._model = _make_mock_model(yes_logit=10.0, no_logit=0.0)
    reranker._yes_id = 9693
    reranker._no_id = 2201

    scores = reranker._score_batch("query", ["candidate"])
    assert len(scores) == 1
    assert scores[0] > 0.99  # softmax(10, 0)[0] ≈ 0.9999
|
||||
|
||||
|
||||
def test_score_batch_low_yes_logit():
    """Lower yes_logit → score closer to 0.0."""
    subject = Qwen3TextReranker()
    subject._yes_id = 9693
    subject._no_id = 2201
    subject._tokenizer = _make_mock_tokenizer()
    subject._model = _make_mock_model(yes_logit=0.0, no_logit=10.0)

    (only_score,) = subject._score_batch("query", ["irrelevant candidate"])[:1]
    assert only_score < 0.01
|
||||
|
||||
|
||||
def test_rerank_sorts_by_score():
    """Integration through rerank() — highest yes-logit candidate should rank first."""
    import torch

    reranker = Qwen3TextReranker(batch_size=10)
    tokenizer = _make_mock_tokenizer()

    def tokenize_side_effect(prompts, **kw):
        # One fixed-length (10-token) row per prompt, so the batch dimension
        # of the encoding matches the number of candidates.
        n = len(prompts)
        return {
            "input_ids": torch.zeros(n, 10, dtype=torch.long),
            "attention_mask": torch.ones(n, 10, dtype=torch.long),
        }

    tokenizer.side_effect = tokenize_side_effect
    tokenizer.return_value = None  # disable default return
    reranker._tokenizer = tokenizer

    # Simulate two candidates: first gets low yes logit, second gets high
    vocab_size = 32000
    batch_logits = torch.zeros(2, 1, vocab_size)
    batch_logits[0, 0, 9693] = 1.0  # candidate 0: low relevance
    batch_logits[0, 0, 2201] = 5.0
    batch_logits[1, 0, 9693] = 5.0  # candidate 1: high relevance
    batch_logits[1, 0, 2201] = 1.0

    output = MagicMock()
    output.logits = batch_logits
    model = MagicMock()
    model.return_value = output
    param = MagicMock()
    param.device = torch.device("cpu")
    # Fresh iterator per parameters() call, so repeated lookups keep working.
    model.parameters.side_effect = lambda *args, **kwargs: iter([param])

    reranker._model = model
    reranker._yes_id = 9693
    reranker._no_id = 2201

    results = reranker.rerank("query", ["low relevance doc", "high relevance doc"])
    assert results[0].candidate == "high relevance doc"
    assert results[0].rank == 0
|
||||
|
||||
|
||||
def test_score_in_batches_splits_correctly():
    """Verify that large candidate lists are split into sub-batches."""
    subject = Qwen3TextReranker(batch_size=2)
    subject._tokenizer = _make_mock_tokenizer()
    subject._yes_id = 9693
    subject._no_id = 2201

    # Every candidate sub-list handed to _score_batch, in call order.
    seen_batches: list[list[str]] = []

    def fake_score_batch(query, cands):
        seen_batches.append(cands)
        return [0.5 for _ in cands]

    subject._score_batch = fake_score_batch  # type: ignore[method-assign]
    scores = subject._score_in_batches("q", ["a", "b", "c", "d", "e"])

    assert len(scores) == 5
    # 5 candidates with batch_size=2 → 3 sub-batches: [a,b], [c,d], [e]
    assert len(seen_batches) == 3
    assert seen_batches[2] == ["e"]
|
||||
105
tests/test_reranker/test_remote.py
Normal file
105
tests/test_reranker/test_remote.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Tests for RemoteTextReranker."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
|
||||
|
||||
|
||||
def _make_mock_response(results: list[dict]) -> MagicMock:
    """Build a fake HTTP response whose ``json()`` is ``{"results": results}``."""
    fake_response = MagicMock(**{"json.return_value": {"results": results}})
    fake_response.raise_for_status = MagicMock()
    return fake_response
|
||||
|
||||
|
||||
def test_model_id():
    """A remote reranker reports the fixed model_id ``"remote"``."""
    subject = RemoteTextReranker("http://10.0.0.1:8011")
    assert subject.model_id == "remote"
|
||||
|
||||
|
||||
def test_score_pairs_posts_to_rerank_endpoint():
    """_score_pairs() POSTs query and candidates to ``<base>/rerank``."""
    subject = RemoteTextReranker("http://10.0.0.1:8011")
    remote_rows = [
        {"candidate": "doc a", "score": 0.9, "rank": 0},
        {"candidate": "doc b", "score": 0.3, "rank": 1},
    ]
    fake_response = _make_mock_response(remote_rows)
    with patch("requests.post", return_value=fake_response) as mock_post:
        scores = subject._score_pairs("query", ["doc a", "doc b"])

    expected_body = {"query": "query", "candidates": ["doc a", "doc b"], "top_n": 0}
    mock_post.assert_called_once_with(
        "http://10.0.0.1:8011/rerank",
        json=expected_body,
        timeout=30,
    )
    assert scores == [pytest.approx(0.9), pytest.approx(0.3)]
|
||||
|
||||
|
||||
def test_score_pairs_restores_original_order():
    """Remote may return results in any order — scores must align with input."""
    subject = RemoteTextReranker("http://10.0.0.1:8011")
    # Remote returned c first (highest score), then a, then b
    remote_order = [("c", 0.95), ("a", 0.6), ("b", 0.1)]
    fake_response = _make_mock_response(
        [
            {"candidate": text, "score": value, "rank": position}
            for position, (text, value) in enumerate(remote_order)
        ]
    )
    with patch("requests.post", return_value=fake_response):
        scores = subject._score_pairs("q", ["a", "b", "c"])
    expected = [pytest.approx(value) for value in (0.6, 0.1, 0.95)]
    assert scores == expected
|
||||
|
||||
|
||||
def test_score_pairs_raises_on_http_error():
    """A connection failure from requests surfaces as RuntimeError."""
    import requests as req

    subject = RemoteTextReranker("http://10.0.0.1:8011")
    refused = patch("requests.post", side_effect=req.ConnectionError("refused"))
    with refused, pytest.raises(RuntimeError, match="Remote reranker"):
        subject._score_pairs("q", ["doc"])
|
||||
|
||||
|
||||
def test_rerank_end_to_end():
    """rerank() puts the higher-scored candidate first even when the remote lists it second."""
    subject = RemoteTextReranker("http://10.0.0.1:8011")
    remote_rows = [
        {"candidate": "irrelevant", "score": 0.2, "rank": 0},
        {"candidate": "very relevant", "score": 0.9, "rank": 1},
    ]
    fake_response = _make_mock_response(remote_rows)
    with patch("requests.post", return_value=fake_response):
        ranked = subject.rerank("q", ["irrelevant", "very relevant"])
    best = ranked[0]
    assert best.candidate == "very relevant"
    assert best.rank == 0
|
||||
|
||||
|
||||
# ── make_reranker wiring ──────────────────────────────────────────────────────
|
||||
|
||||
def test_make_reranker_qwen3():
    """backend="qwen3" yields a Qwen3TextReranker."""
    from circuitforge_core.reranker import make_reranker
    from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker

    built = make_reranker("Qwen/Qwen3-Reranker-0.6B", backend="qwen3")
    assert isinstance(built, Qwen3TextReranker)
|
||||
|
||||
|
||||
def test_make_reranker_cross_encoder():
    """backend="cross-encoder" yields a CrossEncoderTextReranker."""
    from circuitforge_core.reranker import make_reranker
    from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker

    built = make_reranker("mixedbread-ai/mxbai-rerank-base-v1", backend="cross-encoder")
    assert isinstance(built, CrossEncoderTextReranker)
|
||||
|
||||
|
||||
def test_make_reranker_cohere():
    """backend="cohere" yields a CohereTextReranker."""
    from circuitforge_core.reranker import make_reranker
    from circuitforge_core.reranker.adapters.cohere import CohereTextReranker

    built = make_reranker("rerank-english-v3.0", backend="cohere")
    assert isinstance(built, CohereTextReranker)
|
||||
|
||||
|
||||
def test_make_reranker_remote():
    """backend="remote" yields a RemoteTextReranker."""
    from circuitforge_core.reranker import make_reranker

    built = make_reranker("http://10.0.0.1:8011", backend="remote")
    assert isinstance(built, RemoteTextReranker)
|
||||
|
||||
|
||||
def test_make_reranker_unknown_raises():
    """An unrecognised backend raises ValueError whose message mentions cross-encoder."""
    from circuitforge_core.reranker import make_reranker

    expect_error = pytest.raises(ValueError, match="cross-encoder")
    with expect_error:
        make_reranker(backend="unknown-backend")
|
||||
Loading…
Reference in a new issue