feat(reranker): full adapter suite + cf-orch auto-routing (closes #54)
Some checks failed
CI / test (push) Has been cancelled
Mirror / mirror (push) Has been cancelled
Release — PyPI / release (push) Has been cancelled

Five backends: BGE (FlagEmbedding), Qwen3 (generative yes/no logit scorer,
batched forward pass), CrossEncoder (sentence-transformers, covers mxbai-rerank
/ ms-marco / jina), Cohere (BYOK cloud), Remote (HTTP delegate to cf-reranker
service). Mock adapter for tests. 54 tests.

cf-reranker FastAPI service app (port 8011) — cf-orch manages as a process,
defaults to Qwen3-Reranker-0.6B.

make_reranker() auto-detects CF_ORCH_URL and routes to cf-orch cf-reranker
when set — cloud apps (Kiwi, Peregrine, Snipe) get remote Qwen3 reranking
with zero code changes. Local dev falls back to local BGE.

pyproject extras: reranker-bge, reranker-qwen3, reranker-cross-encoder,
reranker-cohere, reranker-service.
This commit is contained in:
pyr0ball 2026-04-26 09:04:39 -07:00
parent b21d6acc8e
commit 185057d8ca
12 changed files with 1277 additions and 4 deletions

View file

@ -6,6 +6,37 @@ Versions follow [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
---
## [0.17.0] — 2026-04-27
### Added
**`circuitforge_core.reranker`** — shared reranker module for RAG pipelines across the orchard (MIT, closes #54)
Five adapters covering local and cloud paths:
- `adapters/bge.py``BGETextReranker`: FlagEmbedding cross-encoder (`BAAI/bge-reranker-*`). Batches all pairs in a single `compute_score()` call via `rerank_batch()`. Thread-safe with internal lock. Free tier.
- `adapters/qwen3.py``Qwen3TextReranker`: generative reranker using `AutoModelForCausalLM`. Scores by reading yes/no token logits at the last input position after pre-filling the assistant `<think>\n\n</think>` block — one forward pass per batch, no generation loop. Left-pads for consistent last-token position across batch. Free / Paid tier.
- `adapters/cross_encoder.py``CrossEncoderTextReranker`: sentence-transformers `CrossEncoder`. Broader model coverage: `mxbai-rerank-*`, `ms-marco-MiniLM-*`, `jina-reranker-*`. Free tier.
- `adapters/cohere.py``CohereTextReranker`: Cohere Rerank API (BYOK cloud path). Reads `COHERE_API_KEY` from env or explicit `api_key=` arg. Restores original candidate order from Cohere's score-sorted response. Paid / BYOK.
- `adapters/remote.py``RemoteTextReranker`: HTTP delegate to a cf-reranker service endpoint. `from_cf_orch()` classmethod allocates via cf-orch on demand. `release()` method returns the lease.
- `adapters/mock.py``MockTextReranker`: Jaccard-similarity scorer, no model required. Used in tests and `CF_RERANKER_MOCK=1` mode.
`app.py``cf-reranker` FastAPI service (port 8011). Managed by cf-orch as a process-type service. Exposes `GET /health` and `POST /rerank`. Defaults to `Qwen3-Reranker-0.6B`.
**Auto cf-orch routing:** `make_reranker()` checks `CF_ORCH_URL` at construction time. When set (cloud deployments), it automatically allocates a `cf-reranker` service via cf-orch and returns a `RemoteTextReranker` — no code changes needed in Kiwi, Peregrine, or Snipe. Local dev (no `CF_ORCH_URL`) falls back to local BGE inference.
**Public API:**
- `rerank(query, candidates, top_n)` — process-level singleton, mock-safe
- `make_reranker(model_id, backend, mock)` — explicit instance
- `reset_reranker()` — test teardown only
- `RerankResult(candidate, score, rank)` — frozen dataclass result type
**`pyproject.toml` extras:** `reranker-bge`, `reranker-qwen3`, `reranker-cross-encoder`, `reranker-cohere`, `reranker-service`
54 tests across all adapters.
---
## [0.14.0] — 2026-04-20
### Added

View file

@ -92,7 +92,21 @@ def make_reranker(
return MockTextReranker()
_model_id = model_id or os.environ.get("CF_RERANKER_MODEL", _DEFAULT_MODEL)
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "bge")
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "")
# Auto-route to cf-orch when CF_ORCH_URL is set and no explicit backend override.
# Cloud deployments set CF_ORCH_URL; local dev leaves it unset → local inference.
if not _backend:
orch_url = os.environ.get("CF_ORCH_URL", "")
if orch_url:
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
logger.info("[reranker] CF_ORCH_URL set — using remote cf-reranker via cf-orch")
return RemoteTextReranker.from_cf_orch(
orch_url=orch_url,
service="cf-reranker",
ttl_s=float(os.environ.get("CF_RERANKER_TTL", "3600")),
)
_backend = "bge" # local default
if _backend == "mock":
return MockTextReranker()
@ -101,10 +115,25 @@ def make_reranker(
from circuitforge_core.reranker.adapters.bge import BGETextReranker
return BGETextReranker(_model_id)
if _backend == "qwen3":
from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker
return Qwen3TextReranker(_model_id)
if _backend == "cross-encoder":
from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker
return CrossEncoderTextReranker(_model_id)
if _backend == "cohere":
from circuitforge_core.reranker.adapters.cohere import CohereTextReranker
return CohereTextReranker(model=_model_id)
if _backend == "remote":
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
return RemoteTextReranker(_model_id)
raise ValueError(
f"Unknown reranker backend {_backend!r}. "
"Valid options: 'bge', 'mock'. "
"(Qwen3 support is coming in Phase 2.)"
"Valid options: 'bge', 'qwen3', 'cross-encoder', 'cohere', 'remote', 'mock'."
)

View file

@ -0,0 +1,94 @@
# circuitforge_core/reranker/adapters/cohere.py — Cohere Rerank API (BYOK cloud)
#
# Requires: pip install circuitforge-core[reranker-cohere]
# API key: set COHERE_API_KEY env var, or pass api_key= explicitly.
#
# Models (as of 2026):
# rerank-english-v3.0 English-only, highest quality
# rerank-multilingual-v3.0 Multilingual
# rerank-english-v2.0 Legacy, lower cost
#
# BYOK unlock path: free-tier users who supply their own Cohere key get cloud
# reranking without needing a cf-orch node. Same pattern as the Anthropic
# backend in LLMRouter.
#
# MIT licensed.
from __future__ import annotations
import logging
import os
from typing import Sequence
from circuitforge_core.reranker.base import TextReranker
logger = logging.getLogger(__name__)
try:
import cohere as _cohere # type: ignore[import]
except ImportError:
_cohere = None # type: ignore[assignment]
_DEFAULT_MODEL = "rerank-english-v3.0"
class CohereTextReranker(TextReranker):
"""
Cloud reranker backed by the Cohere Rerank API.
BYOK (bring your own key): pass api_key= or set COHERE_API_KEY in the
environment. No model weights loaded locally.
Usage:
reranker = CohereTextReranker() # reads COHERE_API_KEY from env
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
With an explicit key and model:
reranker = CohereTextReranker(
api_key="co-...",
model="rerank-multilingual-v3.0",
)
"""
def __init__(
self,
api_key: str | None = None,
model: str = _DEFAULT_MODEL,
max_chunks_per_doc: int = 1,
) -> None:
self._api_key_arg = api_key
self._model = model
self._max_chunks_per_doc = max_chunks_per_doc
@property
def model_id(self) -> str:
return f"cohere:{self._model}"
def _get_client(self) -> object:
if _cohere is None:
raise ImportError(
"cohere is not installed. "
"Run: pip install circuitforge-core[reranker-cohere]"
)
api_key = self._api_key_arg or os.environ.get("COHERE_API_KEY", "")
if not api_key:
raise RuntimeError(
"Cohere API key is not set. "
"Pass api_key= to CohereTextReranker or set COHERE_API_KEY."
)
return _cohere.Client(api_key=api_key)
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
client = self._get_client()
response = client.rerank( # type: ignore[union-attr]
query=query,
documents=candidates,
model=self._model,
top_n=len(candidates),
max_chunks_per_doc=self._max_chunks_per_doc,
)
# response.results is sorted by relevance_score desc; rebuild
# in original candidate order so TextReranker.rerank() re-sorts correctly.
score_map: dict[int, float] = {
r.index: r.relevance_score for r in response.results
}
return [score_map.get(i, 0.0) for i in range(len(candidates))]

View file

@ -0,0 +1,96 @@
# circuitforge_core/reranker/adapters/cross_encoder.py — sentence-transformers CrossEncoder
#
# Requires: pip install circuitforge-core[reranker-cross-encoder]
#
# Covers models not in the FlagEmbedding ecosystem:
# mixedbread-ai/mxbai-rerank-base-v1 ~570MB VRAM, strong general-purpose
# mixedbread-ai/mxbai-rerank-large-v1 ~1.3GB VRAM, higher quality
# cross-encoder/ms-marco-MiniLM-L-6-v2 ~90MB, fast, English-only
# cross-encoder/ms-marco-MiniLM-L-12-v2 ~130MB, balanced
# jinaai/jina-reranker-v2-base-multilingual ~280MB, multilingual
#
# MIT licensed.
from __future__ import annotations
import logging
import threading
from typing import Sequence
from circuitforge_core.reranker.base import TextReranker
logger = logging.getLogger(__name__)
try:
from sentence_transformers import CrossEncoder as _CrossEncoder # type: ignore[import]
except ImportError:
_CrossEncoder = None # type: ignore[assignment]
def _cuda_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
class CrossEncoderTextReranker(TextReranker):
"""
Cross-encoder reranker using the sentence-transformers CrossEncoder class.
Broader model compatibility than BGETextReranker any HuggingFace model
with a sequence-classification head works here. Particularly well-suited
for the mxbai-rerank and ms-marco families.
Usage:
reranker = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1")
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
"""
def __init__(
self,
model_id: str = "mixedbread-ai/mxbai-rerank-base-v1",
max_length: int = 512,
) -> None:
self._model_id = model_id
self._max_length = max_length
self._model: object | None = None
self._lock = threading.Lock()
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
"""Explicitly load model weights. Called automatically on first rerank()."""
if _CrossEncoder is None:
raise ImportError(
"sentence-transformers is not installed. "
"Run: pip install circuitforge-core[reranker-cross-encoder]"
)
with self._lock:
if self._model is not None:
return
device = "cuda" if _cuda_available() else "cpu"
logger.info(
"Loading CrossEncoder reranker: %s (device=%s)", self._model_id, device
)
self._model = _CrossEncoder(
self._model_id,
max_length=self._max_length,
device=device,
)
def unload(self) -> None:
"""Release model weights."""
with self._lock:
self._model = None
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
if self._model is None:
self.load()
pairs = [(query, c) for c in candidates]
with self._lock:
raw = self._model.predict(pairs) # type: ignore[union-attr]
# predict() returns a numpy array or list; normalise to plain floats.
return [float(s) for s in raw]

View file

@ -0,0 +1,239 @@
# circuitforge_core/reranker/adapters/qwen3.py — Qwen3-Reranker adapter
#
# Requires: pip install circuitforge-core[reranker-qwen3]
# Tested with: Qwen/Qwen3-Reranker-0.6B, -1.5B, -8B
#
# Scoring mechanism (generative reranker):
# Rather than generating a full response, we pre-fill the assistant turn with
# the <think>\n\n</think>\n block and read the logits at the last input token
# position. The softmax probability of "yes" vs "no" at that position is the
# relevance score — one forward pass per batch, no generation loop.
#
# Prompt format (Qwen3 chat template):
# system: "Judge whether the Document meets the requirements based on the
# Query and the Instruct. Note that the answer can only be 'yes'
# or 'no'."
# user: "<Instruct>: {task}\n<Query>: {query}\n<Document>: {doc}"
# assistant (pre-filled): "<think>\n\n</think>\n\n"
#
# MIT licensed.
from __future__ import annotations
import logging
import threading
from typing import Sequence
from circuitforge_core.reranker.base import TextReranker
logger = logging.getLogger(__name__)
# Optional heavy deps — lazy-imported at load() time.
try:
import torch as _torch # type: ignore[import]
except ImportError:
_torch = None # type: ignore[assignment]
try:
from transformers import AutoModelForCausalLM as _AutoModel # type: ignore[import]
from transformers import AutoTokenizer as _AutoTokenizer # type: ignore[import]
except ImportError:
_AutoModel = None # type: ignore[assignment]
_AutoTokenizer = None # type: ignore[assignment]
# System prompt used for all reranking tasks.
_SYSTEM_PROMPT = (
"Judge whether the Document meets the requirements based on the Query and "
'the Instruct. Note that the answer can only be "yes" or "no".'
)
# Default task instruction — products can override via make_reranker(task=...).
_DEFAULT_TASK = "Given a query, retrieve the most relevant document that answers the query."
# The pre-filled assistant turn that puts the model past its thinking block
# so the very next token position scores "yes" vs "no".
_ASSISTANT_PREFILL = "<think>\n\n</think>\n\n"
def _requires_deps() -> None:
if _torch is None:
raise ImportError(
"torch is not installed. Run: pip install circuitforge-core[reranker-qwen3]"
)
if _AutoModel is None:
raise ImportError(
"transformers is not installed. Run: pip install circuitforge-core[reranker-qwen3]"
)
class Qwen3TextReranker(TextReranker):
"""
Generative reranker using the Qwen3-Reranker model family.
Scores candidates by reading yes/no token logits at the last input position
after pre-filling the assistant thinking block. One forward pass covers an
entire batch efficient for ranking large candidate lists.
Model options (by tier):
Free: Qwen/Qwen3-Reranker-0.6B (~1.2 GB VRAM fp16)
Qwen/Qwen3-Reranker-1.5B (~3.0 GB VRAM fp16)
Paid: Qwen/Qwen3-Reranker-8B (~16 GB VRAM fp16, or ~9 GB int8)
Usage:
reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
With a custom task instruction:
reranker = Qwen3TextReranker(
"Qwen/Qwen3-Reranker-1.5B",
task="Given a job description, retrieve the most relevant resume.",
)
"""
def __init__(
self,
model_id: str = "Qwen/Qwen3-Reranker-0.6B",
task: str = _DEFAULT_TASK,
device: str | None = None,
dtype: str = "float16",
batch_size: int = 32,
) -> None:
self._model_id = model_id
self._task = task
self._device = device # None = auto-detect at load time
self._dtype_str = dtype
self._batch_size = batch_size
self._model: object | None = None
self._tokenizer: object | None = None
self._yes_id: int | None = None
self._no_id: int | None = None
self._lock = threading.Lock()
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
"""Explicitly load model weights. Called automatically on first rerank()."""
_requires_deps()
with self._lock:
if self._model is not None:
return
device = self._device or ("cuda" if _torch.cuda.is_available() else "cpu")
dtype_map: dict[str, object] = {
"float16": _torch.float16,
"bfloat16": _torch.bfloat16,
"float32": _torch.float32,
}
torch_dtype = dtype_map.get(self._dtype_str, _torch.float16)
logger.info(
"Loading Qwen3 reranker: %s (device=%s dtype=%s)",
self._model_id, device, self._dtype_str,
)
tokenizer = _AutoTokenizer.from_pretrained(
self._model_id, trust_remote_code=True
)
model = _AutoModel.from_pretrained(
self._model_id,
torch_dtype=torch_dtype,
device_map=device,
trust_remote_code=True,
)
model.eval()
# Resolve the token IDs for "yes" and "no" once at load time.
# Qwen tokenizers encode single-word tokens without a leading space.
yes_ids: list[int] = tokenizer.encode("yes", add_special_tokens=False)
no_ids: list[int] = tokenizer.encode("no", add_special_tokens=False)
if not yes_ids or not no_ids:
raise RuntimeError(
f"Could not resolve 'yes'/'no' token IDs from tokenizer {self._model_id!r}. "
"This model may not be a Qwen3-Reranker variant."
)
self._tokenizer = tokenizer
self._model = model
self._yes_id = yes_ids[0]
self._no_id = no_ids[0]
def unload(self) -> None:
"""Release model weights. Useful for VRAM management between tasks."""
with self._lock:
self._model = None
self._tokenizer = None
self._yes_id = None
self._no_id = None
def _build_prompt(self, query: str, document: str) -> str:
"""Format a single (query, document) pair as a chat-template prompt."""
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{
"role": "user",
"content": (
f"<Instruct>: {self._task}\n"
f"<Query>: {query}\n"
f"<Document>: {document}"
),
},
]
# apply_chat_template without tokenization so we can append the prefill.
text: str = self._tokenizer.apply_chat_template( # type: ignore[union-attr]
messages,
tokenize=False,
add_generation_prompt=True,
)
return text + _ASSISTANT_PREFILL
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
if self._model is None:
self.load()
return self._score_in_batches(query, candidates)
def _score_in_batches(self, query: str, candidates: list[str]) -> list[float]:
"""Score all (query, candidate) pairs, splitting into sub-batches."""
all_scores: list[float] = []
for start in range(0, len(candidates), self._batch_size):
batch = candidates[start : start + self._batch_size]
all_scores.extend(self._score_batch(query, batch))
return all_scores
def _score_batch(self, query: str, candidates: list[str]) -> list[float]:
"""Single forward pass for one sub-batch. Returns a score per candidate."""
prompts = [self._build_prompt(query, c) for c in candidates]
# Left-pad so the last token position is consistent across all sequences.
tokenizer = self._tokenizer # type: ignore[union-attr]
original_side = getattr(tokenizer, "padding_side", "right")
tokenizer.padding_side = "left"
try:
encoded = tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=4096,
)
finally:
tokenizer.padding_side = original_side
model = self._model # type: ignore[union-attr]
device = next(model.parameters()).device # type: ignore[union-attr]
input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)
with self._lock:
with _torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
# logits shape: (batch, seq_len, vocab_size)
# Last position [-1] is the token the model would output next.
last_logits = outputs.logits[:, -1, :] # (batch, vocab)
yes_logits = last_logits[:, self._yes_id] # (batch,)
no_logits = last_logits[:, self._no_id] # (batch,)
# Softmax over yes/no only — score = P(yes | query, doc).
stacked = _torch.stack([yes_logits, no_logits], dim=-1) # (batch, 2)
probs = _torch.softmax(stacked, dim=-1)
scores: list[float] = probs[:, 0].float().cpu().tolist()
return scores

View file

@ -0,0 +1,131 @@
# circuitforge_core/reranker/adapters/remote.py — HTTP remote reranker adapter
#
# Calls a cf-reranker service endpoint (cf-orch allocated or static URL).
# No model weights loaded locally — all inference runs on the remote node.
#
# MIT licensed.
from __future__ import annotations
import logging
from typing import Sequence
import requests
from circuitforge_core.reranker.base import TextReranker
logger = logging.getLogger(__name__)
# Default timeout for a single /rerank call (seconds).
# Large candidate lists may take longer — callers can pass timeout= explicitly.
_DEFAULT_TIMEOUT = 30
class RemoteTextReranker(TextReranker):
"""
Reranker that delegates scoring to a remote cf-reranker HTTP service.
The remote service must implement POST /rerank with the request body::
{"query": str, "candidates": [str, ...], "top_n": int}
and return::
{"results": [{"candidate": str, "score": float, "rank": int}, ...]}
cf-orch allocation (recommended starts service on-demand):
reranker = RemoteTextReranker.from_cf_orch(
orch_url="http://10.1.10.71:7700",
service="cf-reranker",
model_candidates=["qwen3-0.6b"],
)
Static URL (e.g. dedicated node already running cf-reranker):
reranker = RemoteTextReranker("http://10.1.10.10:8011")
"""
def __init__(
self,
base_url: str,
timeout: int = _DEFAULT_TIMEOUT,
_model_id: str = "remote",
) -> None:
self._base_url = base_url.rstrip("/")
self._timeout = timeout
self._model_id_str = _model_id
@property
def model_id(self) -> str:
return self._model_id_str
@classmethod
def from_cf_orch(
cls,
orch_url: str,
service: str = "cf-reranker",
model_candidates: list[str] | None = None,
ttl_s: float = 3600.0,
timeout: int = _DEFAULT_TIMEOUT,
) -> "RemoteTextReranker":
"""
Allocate a cf-reranker service via cf-orch and return a configured adapter.
Blocks until allocation succeeds or raises on failure. The returned
adapter is valid for the duration of the TTL; create a new one if the
lease expires.
This is a one-shot allocation the caller owns the lifetime. For
long-running services, prefer the static URL constructor and let
cf-orch manage the process independently.
"""
try:
from circuitforge_orch.client import CFOrchClient # type: ignore[import]
except ImportError as exc:
raise ImportError(
"circuitforge_orch is not installed — cannot allocate via cf-orch."
) from exc
client = CFOrchClient(orch_url)
ctx = client.allocate(
service,
model_candidates=model_candidates or [],
ttl_s=ttl_s,
caller="reranker-remote",
)
alloc = ctx.__enter__()
# Note: caller is responsible for ctx.__exit__() when done.
# We stash it on the instance so callers can call release().
instance = cls(
base_url=alloc.url,
timeout=timeout,
_model_id=f"remote:{service}",
)
instance._orch_ctx = ctx # type: ignore[attr-defined]
return instance
def release(self) -> None:
"""Release the cf-orch allocation if this adapter was created via from_cf_orch()."""
ctx = getattr(self, "_orch_ctx", None)
if ctx is not None:
try:
ctx.__exit__(None, None, None)
except Exception:
pass
self._orch_ctx = None # type: ignore[attr-defined]
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
url = f"{self._base_url}/rerank"
payload = {"query": query, "candidates": candidates, "top_n": 0}
try:
resp = requests.post(url, json=payload, timeout=self._timeout)
resp.raise_for_status()
except requests.RequestException as exc:
raise RuntimeError(
f"Remote reranker at {url!r} failed: {exc}"
) from exc
data = resp.json()
# Build a score-per-candidate list in the original order.
score_map: dict[str, float] = {
r["candidate"]: r["score"] for r in data["results"]
}
return [score_map.get(c, 0.0) for c in candidates]

View file

@ -0,0 +1,169 @@
"""
circuitforge_core.reranker.app cf-reranker FastAPI service.
Managed by cf-orch as a process-type service. cf-orch starts this via:
python -m circuitforge_core.reranker.app \
--model BAAI/bge-reranker-base \
--backend bge \
--port 8011 \
--gpu-id 0
Or with Qwen3:
python -m circuitforge_core.reranker.app \
--model Qwen/Qwen3-Reranker-0.6B \
--backend qwen3 \
--port 8011 \
--gpu-id 0 \
--dtype float16
Endpoints:
GET /health {"status": "ok", "model": "...", "backend": "...", "vram_mb": n}
POST /rerank RerankResponse
"""
from __future__ import annotations
import argparse
import logging
import os
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
# ── Request / response models ─────────────────────────────────────────────────
class RerankRequest(BaseModel):
query: str
candidates: list[str]
top_n: int = 0
class RerankResultItem(BaseModel):
candidate: str
score: float
rank: int
class RerankResponse(BaseModel):
results: list[RerankResultItem]
model: str
class HealthResponse(BaseModel):
status: str
model: str
backend: str
vram_mb: int
# ── VRAM estimates by backend/model family ────────────────────────────────────
_VRAM_TABLE: dict[str, int] = {
"bge-reranker-base": 570,
"bge-reranker-large": 1300,
"bge-reranker-v2-m3": 570,
"mxbai-rerank-base-v1": 570,
"mxbai-rerank-large-v1": 1300,
"ms-marco-MiniLM-L-6-v2": 90,
"ms-marco-MiniLM-L-12-v2": 130,
"Qwen3-Reranker-0.6B": 1200,
"Qwen3-Reranker-1.5B": 3000,
"Qwen3-Reranker-8B": 16000,
}
def _estimate_vram(model_id: str) -> int:
for key, mb in _VRAM_TABLE.items():
if key in model_id:
return mb
return 1024 # safe default
# ── App factory ───────────────────────────────────────────────────────────────
def create_app(model_id: str, backend: str, dtype: str, mock: bool) -> FastAPI:
from circuitforge_core.reranker import make_reranker
app = FastAPI(title="cf-reranker", version="0.1.0")
_reranker = make_reranker(model_id=model_id, backend=backend, mock=mock)
_vram_mb = _estimate_vram(model_id)
logger.info("cf-reranker ready: model=%r backend=%r vram=%dMB", model_id, backend, _vram_mb)
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(
status="ok",
model=_reranker.model_id,
backend=backend,
vram_mb=_vram_mb,
)
@app.post("/rerank", response_model=RerankResponse)
async def rerank(req: RerankRequest) -> RerankResponse:
if not req.candidates:
raise HTTPException(status_code=400, detail="candidates must not be empty")
try:
results = _reranker.rerank(req.query, req.candidates, top_n=req.top_n)
except Exception as exc:
logger.exception("rerank failed")
raise HTTPException(status_code=500, detail=str(exc)) from exc
return RerankResponse(
results=[
RerankResultItem(candidate=r.candidate, score=r.score, rank=r.rank)
for r in results
],
model=_reranker.model_id,
)
return app
# ── CLI entry point ───────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(description="cf-reranker — CircuitForge reranker service")
parser.add_argument(
"--model", default="BAAI/bge-reranker-base",
help="HuggingFace model ID or local path",
)
parser.add_argument(
"--backend", default="bge",
choices=["bge", "qwen3", "cross-encoder", "mock"],
help="Reranker backend",
)
parser.add_argument("--port", type=int, default=8011)
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--gpu-id", type=int, default=0)
parser.add_argument(
"--dtype", default="float16",
choices=["float16", "bfloat16", "float32"],
)
parser.add_argument("--mock", action="store_true",
help="Run with mock backend (no GPU, for testing)")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
if args.backend != "mock" and not args.mock:
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))
mock = args.mock or os.environ.get("CF_RERANKER_MOCK", "") == "1"
app = create_app(
model_id=args.model,
backend=args.backend,
dtype=args.dtype,
mock=mock,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
if __name__ == "__main__":
main()

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "circuitforge-core"
version = "0.15.0"
version = "0.17.0"
description = "Shared scaffold for CircuitForge products (MIT)"
requires-python = ">=3.11"
dependencies = [
@ -91,6 +91,17 @@ reranker-qwen3 = [
"transformers>=4.40",
"accelerate>=0.27",
]
reranker-cross-encoder = [
"sentence-transformers>=3.0",
]
reranker-cohere = [
"cohere>=5.0",
]
reranker-service = [
"circuitforge-core[reranker-qwen3]",
"fastapi>=0.110",
"uvicorn[standard]>=0.29",
]
dev = [
"circuitforge-core[manage]",
"pytest>=8.0",

View file

@ -0,0 +1,106 @@
"""Tests for CohereTextReranker with mocked cohere client."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.reranker.adapters.cohere import CohereTextReranker
def _make_cohere_result(index: int, score: float) -> MagicMock:
r = MagicMock()
r.index = index
r.relevance_score = score
return r
def _make_mock_client(results: list[MagicMock]) -> MagicMock:
response = MagicMock()
response.results = results
client = MagicMock()
client.rerank.return_value = response
return client
def test_model_id_includes_model_name():
r = CohereTextReranker(model="rerank-multilingual-v3.0")
assert r.model_id == "cohere:rerank-multilingual-v3.0"
def test_raises_without_cohere_package():
reranker = CohereTextReranker(api_key="co-test")
with patch("circuitforge_core.reranker.adapters.cohere._cohere", None):
with pytest.raises(ImportError, match="cohere"):
reranker._score_pairs("q", ["doc"])
def test_raises_without_api_key(monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
reranker = CohereTextReranker() # no api_key arg
mock_cohere = MagicMock()
with patch("circuitforge_core.reranker.adapters.cohere._cohere", mock_cohere):
with pytest.raises(RuntimeError, match="API key"):
reranker._get_client()
def test_reads_api_key_from_env(monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "co-fromenv")
mock_cohere = MagicMock()
with patch("circuitforge_core.reranker.adapters.cohere._cohere", mock_cohere):
reranker = CohereTextReranker()
reranker._get_client()
mock_cohere.Client.assert_called_once_with(api_key="co-fromenv")
def test_score_pairs_returns_original_order():
"""Cohere returns results sorted by score; we must restore original order."""
reranker = CohereTextReranker(api_key="co-test")
# Cohere returns candidates ranked: index 2 (0.9), index 0 (0.6), index 1 (0.1)
mock_client = _make_mock_client([
_make_cohere_result(index=2, score=0.9),
_make_cohere_result(index=0, score=0.6),
_make_cohere_result(index=1, score=0.1),
])
with patch.object(reranker, "_get_client", return_value=mock_client):
scores = reranker._score_pairs("query", ["a", "b", "c"])
# Original order: a=0.6, b=0.1, c=0.9
assert scores == [pytest.approx(0.6), pytest.approx(0.1), pytest.approx(0.9)]
def test_rerank_sorts_correctly():
reranker = CohereTextReranker(api_key="co-test")
mock_client = _make_mock_client([
_make_cohere_result(index=1, score=0.95),
_make_cohere_result(index=0, score=0.3),
])
with patch.object(reranker, "_get_client", return_value=mock_client):
results = reranker.rerank("query", ["less relevant", "more relevant"])
assert results[0].candidate == "more relevant"
assert results[0].rank == 0
def test_rerank_top_n():
reranker = CohereTextReranker(api_key="co-test")
mock_client = _make_mock_client([
_make_cohere_result(index=0, score=0.9),
_make_cohere_result(index=1, score=0.5),
_make_cohere_result(index=2, score=0.1),
])
with patch.object(reranker, "_get_client", return_value=mock_client):
results = reranker.rerank("q", ["a", "b", "c"], top_n=2)
assert len(results) == 2
def test_rerank_calls_cohere_with_correct_args():
reranker = CohereTextReranker(api_key="co-test", model="rerank-english-v3.0")
mock_client = _make_mock_client([_make_cohere_result(index=0, score=0.8)])
with patch.object(reranker, "_get_client", return_value=mock_client):
reranker.rerank("my query", ["only doc"])
mock_client.rerank.assert_called_once_with(
query="my query",
documents=["only doc"],
model="rerank-english-v3.0",
top_n=1,
max_chunks_per_doc=1,
)

View file

@ -0,0 +1,77 @@
"""Tests for CrossEncoderTextReranker with mocked sentence-transformers."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker
def _make_mock_cross_encoder(scores: list[float]) -> MagicMock:
m = MagicMock()
m.predict.return_value = scores
return m
def test_model_id():
assert (
CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1").model_id
== "mixedbread-ai/mxbai-rerank-base-v1"
)
def test_load_raises_without_sentence_transformers():
reranker = CrossEncoderTextReranker()
with patch("circuitforge_core.reranker.adapters.cross_encoder._CrossEncoder", None):
with pytest.raises(ImportError, match="sentence-transformers"):
reranker.load()
def test_rerank_scores_and_sorts():
reranker = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1")
reranker._model = _make_mock_cross_encoder([0.2, 0.9, 0.5])
results = reranker.rerank("query", ["a", "b", "c"])
assert results[0].candidate == "b"
assert results[0].rank == 0
assert results[2].candidate == "a"
def test_rerank_top_n():
reranker = CrossEncoderTextReranker()
reranker._model = _make_mock_cross_encoder([0.1, 0.8, 0.5])
results = reranker.rerank("q", ["a", "b", "c"], top_n=2)
assert len(results) == 2
assert results[0].candidate == "b"
def test_predict_called_with_pairs():
reranker = CrossEncoderTextReranker()
mock_model = _make_mock_cross_encoder([0.7, 0.3])
reranker._model = mock_model
reranker.rerank("chicken soup", ["recipe one", "recipe two"])
pairs = mock_model.predict.call_args[0][0]
assert pairs == [("chicken soup", "recipe one"), ("chicken soup", "recipe two")]
def test_numpy_scores_coerced_to_float():
"""predict() may return numpy floats — verify they're converted cleanly."""
try:
import numpy as np
numpy_scores = np.array([0.8, 0.2])
except ImportError:
pytest.skip("numpy not installed")
reranker = CrossEncoderTextReranker()
reranker._model = _make_mock_cross_encoder(numpy_scores) # type: ignore[arg-type]
results = reranker.rerank("q", ["a", "b"])
assert isinstance(results[0].score, float)
def test_unload_clears_model():
reranker = CrossEncoderTextReranker()
reranker._model = MagicMock()
reranker.unload()
assert reranker._model is None

View file

@ -0,0 +1,185 @@
"""Tests for Qwen3TextReranker with mocked transformers."""
from __future__ import annotations
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker, _ASSISTANT_PREFILL
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_mock_model(yes_logit: float = 5.0, no_logit: float = 1.0, batch_size: int = 1):
"""Return a mock AutoModelForCausalLM that outputs fixed yes/no logits."""
import torch
model = MagicMock()
# Simulate logits: (batch, seq_len=1, vocab_size=32000)
# yes token id = 9693, no token id = 2201 (Qwen tokenizer typical values)
vocab_size = 32000
logits = torch.zeros(batch_size, 1, vocab_size)
logits[:, :, 9693] = yes_logit # "yes" token
logits[:, :, 2201] = no_logit # "no" token
output = MagicMock()
output.logits = logits
model.return_value = output
# next(model.parameters()).device
param = MagicMock()
param.device = torch.device("cpu")
model.parameters.return_value = iter([param])
return model
def _make_mock_tokenizer(yes_id: int = 9693, no_id: int = 2201):
"""Return a mock AutoTokenizer."""
import torch
tokenizer = MagicMock()
tokenizer.encode.side_effect = lambda text, **kw: (
[yes_id] if text == "yes" else [no_id]
)
tokenizer.apply_chat_template.return_value = "<prompt>"
tokenizer.padding_side = "right"
# Return simple fixed-length tensors from __call__
tokenizer.return_value = {
"input_ids": torch.zeros(1, 10, dtype=torch.long),
"attention_mask": torch.ones(1, 10, dtype=torch.long),
}
return tokenizer
# ── Unit tests ────────────────────────────────────────────────────────────────
def test_load_raises_without_torch():
reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
with patch("circuitforge_core.reranker.adapters.qwen3._torch", None):
with pytest.raises(ImportError, match="torch"):
reranker.load()
def test_load_raises_without_transformers():
reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
with patch("circuitforge_core.reranker.adapters.qwen3._AutoModel", None):
with pytest.raises(ImportError, match="transformers"):
reranker.load()
def test_model_id():
assert Qwen3TextReranker("Qwen/Qwen3-Reranker-1.5B").model_id == "Qwen/Qwen3-Reranker-1.5B"
def test_unload_clears_state():
reranker = Qwen3TextReranker()
reranker._model = MagicMock()
reranker._tokenizer = MagicMock()
reranker._yes_id = 1
reranker._no_id = 2
reranker.unload()
assert reranker._model is None
assert reranker._tokenizer is None
assert reranker._yes_id is None
assert reranker._no_id is None
def test_build_prompt_includes_prefill():
reranker = Qwen3TextReranker()
reranker._tokenizer = _make_mock_tokenizer()
prompt = reranker._build_prompt("what is chicken soup", "a hearty recipe")
assert _ASSISTANT_PREFILL in prompt
def test_score_batch_returns_yes_probability():
"""Higher yes_logit → score closer to 1.0."""
import torch
reranker = Qwen3TextReranker()
reranker._tokenizer = _make_mock_tokenizer()
reranker._model = _make_mock_model(yes_logit=10.0, no_logit=0.0)
reranker._yes_id = 9693
reranker._no_id = 2201
scores = reranker._score_batch("query", ["candidate"])
assert len(scores) == 1
assert scores[0] > 0.99 # softmax(10, 0)[0] ≈ 0.9999
def test_score_batch_low_yes_logit():
"""Lower yes_logit → score closer to 0.0."""
reranker = Qwen3TextReranker()
reranker._tokenizer = _make_mock_tokenizer()
reranker._model = _make_mock_model(yes_logit=0.0, no_logit=10.0)
reranker._yes_id = 9693
reranker._no_id = 2201
scores = reranker._score_batch("query", ["irrelevant candidate"])
assert scores[0] < 0.01
def test_rerank_sorts_by_score():
"""Integration through rerank() — highest yes-logit candidate should rank first."""
import torch
reranker = Qwen3TextReranker(batch_size=10)
tokenizer = _make_mock_tokenizer()
# Return different-length tensors per call to simulate multi-candidate batch
call_count = [0]
def tokenize_side_effect(prompts, **kw):
n = len(prompts)
return {
"input_ids": torch.zeros(n, 10, dtype=torch.long),
"attention_mask": torch.ones(n, 10, dtype=torch.long),
}
tokenizer.side_effect = tokenize_side_effect
tokenizer.return_value = None # disable default return
reranker._tokenizer = tokenizer
# Simulate two candidates: first gets low yes logit, second gets high
import torch as _torch
vocab_size = 32000
batch_logits = _torch.zeros(2, 1, vocab_size)
batch_logits[0, 0, 9693] = 1.0 # candidate 0: low relevance
batch_logits[0, 0, 2201] = 5.0
batch_logits[1, 0, 9693] = 5.0 # candidate 1: high relevance
batch_logits[1, 0, 2201] = 1.0
output = MagicMock()
output.logits = batch_logits
model = MagicMock()
model.return_value = output
param = MagicMock()
param.device = _torch.device("cpu")
model.parameters.return_value = iter([param])
reranker._model = model
reranker._yes_id = 9693
reranker._no_id = 2201
results = reranker.rerank("query", ["low relevance doc", "high relevance doc"])
assert results[0].candidate == "high relevance doc"
assert results[0].rank == 0
def test_score_in_batches_splits_correctly():
"""Verify that large candidate lists are split into sub-batches."""
reranker = Qwen3TextReranker(batch_size=2)
reranker._tokenizer = _make_mock_tokenizer()
reranker._yes_id = 9693
reranker._no_id = 2201
batch_results: list[list[float]] = []
def fake_score_batch(query, cands):
batch_results.append(cands)
return [0.5] * len(cands)
reranker._score_batch = fake_score_batch # type: ignore[method-assign]
scores = reranker._score_in_batches("q", ["a", "b", "c", "d", "e"])
assert len(scores) == 5
# 5 candidates with batch_size=2 → 3 sub-batches: [a,b], [c,d], [e]
assert len(batch_results) == 3
assert batch_results[2] == ["e"]

View file

@ -0,0 +1,105 @@
"""Tests for RemoteTextReranker."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
def _make_mock_response(results: list[dict]) -> MagicMock:
resp = MagicMock()
resp.json.return_value = {"results": results}
resp.raise_for_status = MagicMock()
return resp
def test_model_id():
assert RemoteTextReranker("http://10.0.0.1:8011").model_id == "remote"
def test_score_pairs_posts_to_rerank_endpoint():
reranker = RemoteTextReranker("http://10.0.0.1:8011")
mock_resp = _make_mock_response([
{"candidate": "doc a", "score": 0.9, "rank": 0},
{"candidate": "doc b", "score": 0.3, "rank": 1},
])
with patch("requests.post", return_value=mock_resp) as mock_post:
scores = reranker._score_pairs("query", ["doc a", "doc b"])
mock_post.assert_called_once_with(
"http://10.0.0.1:8011/rerank",
json={"query": "query", "candidates": ["doc a", "doc b"], "top_n": 0},
timeout=30,
)
assert scores == [pytest.approx(0.9), pytest.approx(0.3)]
def test_score_pairs_restores_original_order():
"""Remote may return results in any order — scores must align with input."""
reranker = RemoteTextReranker("http://10.0.0.1:8011")
# Remote returned c first (highest score), then a, then b
mock_resp = _make_mock_response([
{"candidate": "c", "score": 0.95, "rank": 0},
{"candidate": "a", "score": 0.6, "rank": 1},
{"candidate": "b", "score": 0.1, "rank": 2},
])
with patch("requests.post", return_value=mock_resp):
scores = reranker._score_pairs("q", ["a", "b", "c"])
assert scores == [pytest.approx(0.6), pytest.approx(0.1), pytest.approx(0.95)]
def test_score_pairs_raises_on_http_error():
import requests as req
reranker = RemoteTextReranker("http://10.0.0.1:8011")
with patch("requests.post", side_effect=req.ConnectionError("refused")):
with pytest.raises(RuntimeError, match="Remote reranker"):
reranker._score_pairs("q", ["doc"])
def test_rerank_end_to_end():
reranker = RemoteTextReranker("http://10.0.0.1:8011")
mock_resp = _make_mock_response([
{"candidate": "irrelevant", "score": 0.2, "rank": 0},
{"candidate": "very relevant", "score": 0.9, "rank": 1},
])
with patch("requests.post", return_value=mock_resp):
results = reranker.rerank("q", ["irrelevant", "very relevant"])
assert results[0].candidate == "very relevant"
assert results[0].rank == 0
# ── make_reranker wiring ──────────────────────────────────────────────────────
def test_make_reranker_qwen3():
from circuitforge_core.reranker import make_reranker
from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker
r = make_reranker("Qwen/Qwen3-Reranker-0.6B", backend="qwen3")
assert isinstance(r, Qwen3TextReranker)
def test_make_reranker_cross_encoder():
from circuitforge_core.reranker import make_reranker
from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker
r = make_reranker("mixedbread-ai/mxbai-rerank-base-v1", backend="cross-encoder")
assert isinstance(r, CrossEncoderTextReranker)
def test_make_reranker_cohere():
from circuitforge_core.reranker import make_reranker
from circuitforge_core.reranker.adapters.cohere import CohereTextReranker
r = make_reranker("rerank-english-v3.0", backend="cohere")
assert isinstance(r, CohereTextReranker)
def test_make_reranker_remote():
from circuitforge_core.reranker import make_reranker
r = make_reranker("http://10.0.0.1:8011", backend="remote")
assert isinstance(r, RemoteTextReranker)
def test_make_reranker_unknown_raises():
from circuitforge_core.reranker import make_reranker
with pytest.raises(ValueError, match="cross-encoder"):
make_reranker(backend="unknown-backend")