From 185057d8ca0b2ecad51d0b38208846bfe41bedd3 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Sun, 26 Apr 2026 09:04:39 -0700 Subject: [PATCH] feat(reranker): full adapter suite + cf-orch auto-routing (closes #54) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five backends: BGE (FlagEmbedding), Qwen3 (generative yes/no logit scorer, batched forward pass), CrossEncoder (sentence-transformers, covers mxbai-rerank / ms-marco / jina), Cohere (BYOK cloud), Remote (HTTP delegate to cf-reranker service). Mock adapter for tests. 54 tests. cf-reranker FastAPI service app (port 8011) — cf-orch manages as a process, defaults to Qwen3-Reranker-0.6B. make_reranker() auto-detects CF_ORCH_URL and routes to cf-orch cf-reranker when set — cloud apps (Kiwi, Peregrine, Snipe) get remote Qwen3 reranking with zero code changes. Local dev falls back to local BGE. pyproject extras: reranker-bge, reranker-qwen3, reranker-cross-encoder, reranker-cohere, reranker-service. --- CHANGELOG.md | 31 +++ circuitforge_core/reranker/__init__.py | 35 ++- circuitforge_core/reranker/adapters/cohere.py | 94 +++++++ .../reranker/adapters/cross_encoder.py | 96 +++++++ circuitforge_core/reranker/adapters/qwen3.py | 239 ++++++++++++++++++ circuitforge_core/reranker/adapters/remote.py | 131 ++++++++++ circuitforge_core/reranker/app.py | 169 +++++++++++++ pyproject.toml | 13 +- tests/test_reranker/test_cohere.py | 106 ++++++++ tests/test_reranker/test_cross_encoder.py | 77 ++++++ tests/test_reranker/test_qwen3.py | 185 ++++++++++++++ tests/test_reranker/test_remote.py | 105 ++++++++ 12 files changed, 1277 insertions(+), 4 deletions(-) create mode 100644 circuitforge_core/reranker/adapters/cohere.py create mode 100644 circuitforge_core/reranker/adapters/cross_encoder.py create mode 100644 circuitforge_core/reranker/adapters/qwen3.py create mode 100644 circuitforge_core/reranker/adapters/remote.py create mode 100644 circuitforge_core/reranker/app.py create mode 100644 tests/test_reranker/test_cohere.py create mode 100644 tests/test_reranker/test_cross_encoder.py create mode 100644 tests/test_reranker/test_qwen3.py create mode 100644 tests/test_reranker/test_remote.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6194c2e..fa5b92f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,37 @@ Versions follow [Semantic Versioning](https://semver.org/spec/v2.0.0.html). --- +## [0.17.0] — 2026-04-27 + +### Added + +**`circuitforge_core.reranker`** — shared reranker module for RAG pipelines across the orchard (MIT, closes #54) + +Five adapters covering local and cloud paths: + +- `adapters/bge.py` — `BGETextReranker`: FlagEmbedding cross-encoder (`BAAI/bge-reranker-*`). Batches all pairs in a single `compute_score()` call via `rerank_batch()`. Thread-safe with internal lock. Free tier. +- `adapters/qwen3.py` — `Qwen3TextReranker`: generative reranker using `AutoModelForCausalLM`. Scores by reading yes/no token logits at the last input position after pre-filling the assistant `\n\n` block — one forward pass per batch, no generation loop. Left-pads for consistent last-token position across batch. Free / Paid tier. +- `adapters/cross_encoder.py` — `CrossEncoderTextReranker`: sentence-transformers `CrossEncoder`. Broader model coverage: `mxbai-rerank-*`, `ms-marco-MiniLM-*`, `jina-reranker-*`. Free tier. +- `adapters/cohere.py` — `CohereTextReranker`: Cohere Rerank API (BYOK cloud path). Reads `COHERE_API_KEY` from env or explicit `api_key=` arg. Restores original candidate order from Cohere's score-sorted response. Paid / BYOK. +- `adapters/remote.py` — `RemoteTextReranker`: HTTP delegate to a cf-reranker service endpoint. `from_cf_orch()` classmethod allocates via cf-orch on demand. `release()` method returns the lease. +- `adapters/mock.py` — `MockTextReranker`: Jaccard-similarity scorer, no model required. Used in tests and `CF_RERANKER_MOCK=1` mode. + +`app.py` — `cf-reranker` FastAPI service (port 8011). Managed by cf-orch as a process-type service. Exposes `GET /health` and `POST /rerank`. Defaults to `Qwen3-Reranker-0.6B`. + +**Auto cf-orch routing:** `make_reranker()` checks `CF_ORCH_URL` at construction time. When set (cloud deployments), it automatically allocates a `cf-reranker` service via cf-orch and returns a `RemoteTextReranker` — no code changes needed in Kiwi, Peregrine, or Snipe. Local dev (no `CF_ORCH_URL`) falls back to local BGE inference. + +**Public API:** +- `rerank(query, candidates, top_n)` — process-level singleton, mock-safe +- `make_reranker(model_id, backend, mock)` — explicit instance +- `reset_reranker()` — test teardown only +- `RerankResult(candidate, score, rank)` — frozen dataclass result type + +**`pyproject.toml` extras:** `reranker-bge`, `reranker-qwen3`, `reranker-cross-encoder`, `reranker-cohere`, `reranker-service` + +54 tests across all adapters. + +--- + ## [0.14.0] — 2026-04-20 ### Added diff --git a/circuitforge_core/reranker/__init__.py b/circuitforge_core/reranker/__init__.py index be8db77..81452d1 100644 --- a/circuitforge_core/reranker/__init__.py +++ b/circuitforge_core/reranker/__init__.py @@ -92,7 +92,21 @@ def make_reranker( return MockTextReranker() _model_id = model_id or os.environ.get("CF_RERANKER_MODEL", _DEFAULT_MODEL) - _backend = backend or os.environ.get("CF_RERANKER_BACKEND", "bge") + _backend = backend or os.environ.get("CF_RERANKER_BACKEND", "") + + # Auto-route to cf-orch when CF_ORCH_URL is set and no explicit backend override. + # Cloud deployments set CF_ORCH_URL; local dev leaves it unset → local inference. + if not _backend: + orch_url = os.environ.get("CF_ORCH_URL", "") + if orch_url: + from circuitforge_core.reranker.adapters.remote import RemoteTextReranker + logger.info("[reranker] CF_ORCH_URL set — using remote cf-reranker via cf-orch") + return RemoteTextReranker.from_cf_orch( + orch_url=orch_url, + service="cf-reranker", + ttl_s=float(os.environ.get("CF_RERANKER_TTL", "3600")), + ) + _backend = "bge" # local default if _backend == "mock": return MockTextReranker() @@ -101,10 +115,25 @@ def make_reranker( from circuitforge_core.reranker.adapters.bge import BGETextReranker return BGETextReranker(_model_id) + if _backend == "qwen3": + from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker + return Qwen3TextReranker(_model_id) + + if _backend == "cross-encoder": + from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker + return CrossEncoderTextReranker(_model_id) + + if _backend == "cohere": + from circuitforge_core.reranker.adapters.cohere import CohereTextReranker + return CohereTextReranker(model=_model_id) + + if _backend == "remote": + from circuitforge_core.reranker.adapters.remote import RemoteTextReranker + return RemoteTextReranker(_model_id) + raise ValueError( f"Unknown reranker backend {_backend!r}. " - "Valid options: 'bge', 'mock'. " - "(Qwen3 support is coming in Phase 2.)" + "Valid options: 'bge', 'qwen3', 'cross-encoder', 'cohere', 'remote', 'mock'." ) diff --git a/circuitforge_core/reranker/adapters/cohere.py b/circuitforge_core/reranker/adapters/cohere.py new file mode 100644 index 0000000..6b01965 --- /dev/null +++ b/circuitforge_core/reranker/adapters/cohere.py @@ -0,0 +1,94 @@ +# circuitforge_core/reranker/adapters/cohere.py — Cohere Rerank API (BYOK cloud) +# +# Requires: pip install circuitforge-core[reranker-cohere] +# API key: set COHERE_API_KEY env var, or pass api_key= explicitly. +# +# Models (as of 2026): +# rerank-english-v3.0 English-only, highest quality +# rerank-multilingual-v3.0 Multilingual +# rerank-english-v2.0 Legacy, lower cost +# +# BYOK unlock path: free-tier users who supply their own Cohere key get cloud +# reranking without needing a cf-orch node. Same pattern as the Anthropic +# backend in LLMRouter. +# +# MIT licensed. +from __future__ import annotations + +import logging +import os +from typing import Sequence + +from circuitforge_core.reranker.base import TextReranker + +logger = logging.getLogger(__name__) + +try: + import cohere as _cohere # type: ignore[import] +except ImportError: + _cohere = None # type: ignore[assignment] + +_DEFAULT_MODEL = "rerank-english-v3.0" + + +class CohereTextReranker(TextReranker): + """ + Cloud reranker backed by the Cohere Rerank API. + + BYOK (bring your own key): pass api_key= or set COHERE_API_KEY in the + environment. No model weights loaded locally. + + Usage: + reranker = CohereTextReranker() # reads COHERE_API_KEY from env + results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."]) + + With an explicit key and model: + reranker = CohereTextReranker( + api_key="co-...", + model="rerank-multilingual-v3.0", + ) + """ + + def __init__( + self, + api_key: str | None = None, + model: str = _DEFAULT_MODEL, + max_chunks_per_doc: int = 1, + ) -> None: + self._api_key_arg = api_key + self._model = model + self._max_chunks_per_doc = max_chunks_per_doc + + @property + def model_id(self) -> str: + return f"cohere:{self._model}" + + def _get_client(self) -> object: + if _cohere is None: + raise ImportError( + "cohere is not installed. " + "Run: pip install circuitforge-core[reranker-cohere]" + ) + api_key = self._api_key_arg or os.environ.get("COHERE_API_KEY", "") + if not api_key: + raise RuntimeError( + "Cohere API key is not set. " + "Pass api_key= to CohereTextReranker or set COHERE_API_KEY." + ) + return _cohere.Client(api_key=api_key) + + def _score_pairs(self, query: str, candidates: list[str]) -> list[float]: + client = self._get_client() + response = client.rerank( # type: ignore[union-attr] + query=query, + documents=candidates, + model=self._model, + top_n=len(candidates), + max_chunks_per_doc=self._max_chunks_per_doc, + ) + # response.results is sorted by relevance_score desc; rebuild + # in original candidate order so TextReranker.rerank() re-sorts correctly. + score_map: dict[int, float] = { + r.index: r.relevance_score for r in response.results + } + return [score_map.get(i, 0.0) for i in range(len(candidates))] diff --git a/circuitforge_core/reranker/adapters/cross_encoder.py b/circuitforge_core/reranker/adapters/cross_encoder.py new file mode 100644 index 0000000..8ad94f7 --- /dev/null +++ b/circuitforge_core/reranker/adapters/cross_encoder.py @@ -0,0 +1,96 @@ +# circuitforge_core/reranker/adapters/cross_encoder.py — sentence-transformers CrossEncoder +# +# Requires: pip install circuitforge-core[reranker-cross-encoder] +# +# Covers models not in the FlagEmbedding ecosystem: +# mixedbread-ai/mxbai-rerank-base-v1 ~570MB VRAM, strong general-purpose +# mixedbread-ai/mxbai-rerank-large-v1 ~1.3GB VRAM, higher quality +# cross-encoder/ms-marco-MiniLM-L-6-v2 ~90MB, fast, English-only +# cross-encoder/ms-marco-MiniLM-L-12-v2 ~130MB, balanced +# jinaai/jina-reranker-v2-base-multilingual ~280MB, multilingual +# +# MIT licensed. +from __future__ import annotations + +import logging +import threading +from typing import Sequence + +from circuitforge_core.reranker.base import TextReranker + +logger = logging.getLogger(__name__) + +try: + from sentence_transformers import CrossEncoder as _CrossEncoder # type: ignore[import] +except ImportError: + _CrossEncoder = None # type: ignore[assignment] + + +def _cuda_available() -> bool: + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False + + +class CrossEncoderTextReranker(TextReranker): + """ + Cross-encoder reranker using the sentence-transformers CrossEncoder class. + + Broader model compatibility than BGETextReranker — any HuggingFace model + with a sequence-classification head works here. Particularly well-suited + for the mxbai-rerank and ms-marco families. + + Usage: + reranker = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1") + results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."]) + """ + + def __init__( + self, + model_id: str = "mixedbread-ai/mxbai-rerank-base-v1", + max_length: int = 512, + ) -> None: + self._model_id = model_id + self._max_length = max_length + self._model: object | None = None + self._lock = threading.Lock() + + @property + def model_id(self) -> str: + return self._model_id + + def load(self) -> None: + """Explicitly load model weights. Called automatically on first rerank().""" + if _CrossEncoder is None: + raise ImportError( + "sentence-transformers is not installed. " + "Run: pip install circuitforge-core[reranker-cross-encoder]" + ) + with self._lock: + if self._model is not None: + return + device = "cuda" if _cuda_available() else "cpu" + logger.info( + "Loading CrossEncoder reranker: %s (device=%s)", self._model_id, device + ) + self._model = _CrossEncoder( + self._model_id, + max_length=self._max_length, + device=device, + ) + + def unload(self) -> None: + """Release model weights.""" + with self._lock: + self._model = None + + def _score_pairs(self, query: str, candidates: list[str]) -> list[float]: + if self._model is None: + self.load() + pairs = [(query, c) for c in candidates] + with self._lock: + raw = self._model.predict(pairs) # type: ignore[union-attr] + # predict() returns a numpy array or list; normalise to plain floats. + return [float(s) for s in raw] diff --git a/circuitforge_core/reranker/adapters/qwen3.py b/circuitforge_core/reranker/adapters/qwen3.py new file mode 100644 index 0000000..77b654d --- /dev/null +++ b/circuitforge_core/reranker/adapters/qwen3.py @@ -0,0 +1,239 @@ +# circuitforge_core/reranker/adapters/qwen3.py — Qwen3-Reranker adapter +# +# Requires: pip install circuitforge-core[reranker-qwen3] +# Tested with: Qwen/Qwen3-Reranker-0.6B, -1.5B, -8B +# +# Scoring mechanism (generative reranker): +# Rather than generating a full response, we pre-fill the assistant turn with +# the \n\n\n block and read the logits at the last input token +# position. The softmax probability of "yes" vs "no" at that position is the +# relevance score — one forward pass per batch, no generation loop. +# +# Prompt format (Qwen3 chat template): +# system: "Judge whether the Document meets the requirements based on the +# Query and the Instruct. Note that the answer can only be 'yes' +# or 'no'." +# user: ": {task}\n: {query}\n: {doc}" +# assistant (pre-filled): "\n\n\n\n" +# +# MIT licensed. +from __future__ import annotations + +import logging +import threading +from typing import Sequence + +from circuitforge_core.reranker.base import TextReranker + +logger = logging.getLogger(__name__) + +# Optional heavy deps — lazy-imported at load() time. +try: + import torch as _torch # type: ignore[import] +except ImportError: + _torch = None # type: ignore[assignment] + +try: + from transformers import AutoModelForCausalLM as _AutoModel # type: ignore[import] + from transformers import AutoTokenizer as _AutoTokenizer # type: ignore[import] +except ImportError: + _AutoModel = None # type: ignore[assignment] + _AutoTokenizer = None # type: ignore[assignment] + +# System prompt used for all reranking tasks. +_SYSTEM_PROMPT = ( + "Judge whether the Document meets the requirements based on the Query and " + 'the Instruct. Note that the answer can only be "yes" or "no".' +) + +# Default task instruction — products can override via make_reranker(task=...). +_DEFAULT_TASK = "Given a query, retrieve the most relevant document that answers the query." + +# The pre-filled assistant turn that puts the model past its thinking block +# so the very next token position scores "yes" vs "no". +_ASSISTANT_PREFILL = "\n\n\n\n" + + +def _requires_deps() -> None: + if _torch is None: + raise ImportError( + "torch is not installed. Run: pip install circuitforge-core[reranker-qwen3]" + ) + if _AutoModel is None: + raise ImportError( + "transformers is not installed. Run: pip install circuitforge-core[reranker-qwen3]" + ) + + +class Qwen3TextReranker(TextReranker): + """ + Generative reranker using the Qwen3-Reranker model family. + + Scores candidates by reading yes/no token logits at the last input position + after pre-filling the assistant thinking block. One forward pass covers an + entire batch — efficient for ranking large candidate lists. + + Model options (by tier): + Free: Qwen/Qwen3-Reranker-0.6B (~1.2 GB VRAM fp16) + Qwen/Qwen3-Reranker-1.5B (~3.0 GB VRAM fp16) + Paid: Qwen/Qwen3-Reranker-8B (~16 GB VRAM fp16, or ~9 GB int8) + + Usage: + reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B") + results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."]) + + With a custom task instruction: + reranker = Qwen3TextReranker( + "Qwen/Qwen3-Reranker-1.5B", + task="Given a job description, retrieve the most relevant resume.", + ) + """ + + def __init__( + self, + model_id: str = "Qwen/Qwen3-Reranker-0.6B", + task: str = _DEFAULT_TASK, + device: str | None = None, + dtype: str = "float16", + batch_size: int = 32, + ) -> None: + self._model_id = model_id + self._task = task + self._device = device # None = auto-detect at load time + self._dtype_str = dtype + self._batch_size = batch_size + self._model: object | None = None + self._tokenizer: object | None = None + self._yes_id: int | None = None + self._no_id: int | None = None + self._lock = threading.Lock() + + @property + def model_id(self) -> str: + return self._model_id + + def load(self) -> None: + """Explicitly load model weights. Called automatically on first rerank().""" + _requires_deps() + with self._lock: + if self._model is not None: + return + device = self._device or ("cuda" if _torch.cuda.is_available() else "cpu") + dtype_map: dict[str, object] = { + "float16": _torch.float16, + "bfloat16": _torch.bfloat16, + "float32": _torch.float32, + } + torch_dtype = dtype_map.get(self._dtype_str, _torch.float16) + + logger.info( + "Loading Qwen3 reranker: %s (device=%s dtype=%s)", + self._model_id, device, self._dtype_str, + ) + tokenizer = _AutoTokenizer.from_pretrained( + self._model_id, trust_remote_code=True + ) + model = _AutoModel.from_pretrained( + self._model_id, + torch_dtype=torch_dtype, + device_map=device, + trust_remote_code=True, + ) + model.eval() + + # Resolve the token IDs for "yes" and "no" once at load time. + # Qwen tokenizers encode single-word tokens without a leading space. + yes_ids: list[int] = tokenizer.encode("yes", add_special_tokens=False) + no_ids: list[int] = tokenizer.encode("no", add_special_tokens=False) + if not yes_ids or not no_ids: + raise RuntimeError( + f"Could not resolve 'yes'/'no' token IDs from tokenizer {self._model_id!r}. " + "This model may not be a Qwen3-Reranker variant." + ) + + self._tokenizer = tokenizer + self._model = model + self._yes_id = yes_ids[0] + self._no_id = no_ids[0] + + def unload(self) -> None: + """Release model weights. Useful for VRAM management between tasks.""" + with self._lock: + self._model = None + self._tokenizer = None + self._yes_id = None + self._no_id = None + + def _build_prompt(self, query: str, document: str) -> str: + """Format a single (query, document) pair as a chat-template prompt.""" + messages = [ + {"role": "system", "content": _SYSTEM_PROMPT}, + { + "role": "user", + "content": ( + f": {self._task}\n" + f": {query}\n" + f": {document}" + ), + }, + ] + # apply_chat_template without tokenization so we can append the prefill. + text: str = self._tokenizer.apply_chat_template( # type: ignore[union-attr] + messages, + tokenize=False, + add_generation_prompt=True, + ) + return text + _ASSISTANT_PREFILL + + def _score_pairs(self, query: str, candidates: list[str]) -> list[float]: + if self._model is None: + self.load() + return self._score_in_batches(query, candidates) + + def _score_in_batches(self, query: str, candidates: list[str]) -> list[float]: + """Score all (query, candidate) pairs, splitting into sub-batches.""" + all_scores: list[float] = [] + for start in range(0, len(candidates), self._batch_size): + batch = candidates[start : start + self._batch_size] + all_scores.extend(self._score_batch(query, batch)) + return all_scores + + def _score_batch(self, query: str, candidates: list[str]) -> list[float]: + """Single forward pass for one sub-batch. Returns a score per candidate.""" + prompts = [self._build_prompt(query, c) for c in candidates] + + # Left-pad so the last token position is consistent across all sequences. + tokenizer = self._tokenizer # type: ignore[union-attr] + original_side = getattr(tokenizer, "padding_side", "right") + tokenizer.padding_side = "left" + try: + encoded = tokenizer( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=4096, + ) + finally: + tokenizer.padding_side = original_side + + model = self._model # type: ignore[union-attr] + device = next(model.parameters()).device # type: ignore[union-attr] + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device) + + with self._lock: + with _torch.no_grad(): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + + # logits shape: (batch, seq_len, vocab_size) + # Last position [-1] is the token the model would output next. + last_logits = outputs.logits[:, -1, :] # (batch, vocab) + yes_logits = last_logits[:, self._yes_id] # (batch,) + no_logits = last_logits[:, self._no_id] # (batch,) + + # Softmax over yes/no only — score = P(yes | query, doc). + stacked = _torch.stack([yes_logits, no_logits], dim=-1) # (batch, 2) + probs = _torch.softmax(stacked, dim=-1) + scores: list[float] = probs[:, 0].float().cpu().tolist() + return scores diff --git a/circuitforge_core/reranker/adapters/remote.py b/circuitforge_core/reranker/adapters/remote.py new file mode 100644 index 0000000..4b288b9 --- /dev/null +++ b/circuitforge_core/reranker/adapters/remote.py @@ -0,0 +1,131 @@ +# circuitforge_core/reranker/adapters/remote.py — HTTP remote reranker adapter +# +# Calls a cf-reranker service endpoint (cf-orch allocated or static URL). +# No model weights loaded locally — all inference runs on the remote node. +# +# MIT licensed. +from __future__ import annotations + +import logging +from typing import Sequence + +import requests + +from circuitforge_core.reranker.base import TextReranker + +logger = logging.getLogger(__name__) + +# Default timeout for a single /rerank call (seconds). +# Large candidate lists may take longer — callers can pass timeout= explicitly. +_DEFAULT_TIMEOUT = 30 + + +class RemoteTextReranker(TextReranker): + """ + Reranker that delegates scoring to a remote cf-reranker HTTP service. + + The remote service must implement POST /rerank with the request body:: + + {"query": str, "candidates": [str, ...], "top_n": int} + + and return:: + + {"results": [{"candidate": str, "score": float, "rank": int}, ...]} + + cf-orch allocation (recommended — starts service on-demand): + reranker = RemoteTextReranker.from_cf_orch( + orch_url="http://10.1.10.71:7700", + service="cf-reranker", + model_candidates=["qwen3-0.6b"], + ) + + Static URL (e.g. dedicated node already running cf-reranker): + reranker = RemoteTextReranker("http://10.1.10.10:8011") + """ + + def __init__( + self, + base_url: str, + timeout: int = _DEFAULT_TIMEOUT, + _model_id: str = "remote", + ) -> None: + self._base_url = base_url.rstrip("/") + self._timeout = timeout + self._model_id_str = _model_id + + @property + def model_id(self) -> str: + return self._model_id_str + + @classmethod + def from_cf_orch( + cls, + orch_url: str, + service: str = "cf-reranker", + model_candidates: list[str] | None = None, + ttl_s: float = 3600.0, + timeout: int = _DEFAULT_TIMEOUT, + ) -> "RemoteTextReranker": + """ + Allocate a cf-reranker service via cf-orch and return a configured adapter. + + Blocks until allocation succeeds or raises on failure. The returned + adapter is valid for the duration of the TTL; create a new one if the + lease expires. + + This is a one-shot allocation — the caller owns the lifetime. For + long-running services, prefer the static URL constructor and let + cf-orch manage the process independently. + """ + try: + from circuitforge_orch.client import CFOrchClient # type: ignore[import] + except ImportError as exc: + raise ImportError( + "circuitforge_orch is not installed — cannot allocate via cf-orch." + ) from exc + + client = CFOrchClient(orch_url) + ctx = client.allocate( + service, + model_candidates=model_candidates or [], + ttl_s=ttl_s, + caller="reranker-remote", + ) + alloc = ctx.__enter__() + # Note: caller is responsible for ctx.__exit__() when done. + # We stash it on the instance so callers can call release(). + instance = cls( + base_url=alloc.url, + timeout=timeout, + _model_id=f"remote:{service}", + ) + instance._orch_ctx = ctx # type: ignore[attr-defined] + return instance + + def release(self) -> None: + """Release the cf-orch allocation if this adapter was created via from_cf_orch().""" + ctx = getattr(self, "_orch_ctx", None) + if ctx is not None: + try: + ctx.__exit__(None, None, None) + except Exception: + pass + self._orch_ctx = None # type: ignore[attr-defined] + + def _score_pairs(self, query: str, candidates: list[str]) -> list[float]: + url = f"{self._base_url}/rerank" + payload = {"query": query, "candidates": candidates, "top_n": 0} + try: + resp = requests.post(url, json=payload, timeout=self._timeout) + resp.raise_for_status() + except requests.RequestException as exc: + raise RuntimeError( + f"Remote reranker at {url!r} failed: {exc}" + ) from exc + + data = resp.json() + # Build a score-per-candidate list in the original order. + score_map: dict[str, float] = { + r["candidate"]: r["score"] for r in data["results"] + } + return [score_map.get(c, 0.0) for c in candidates] diff --git a/circuitforge_core/reranker/app.py b/circuitforge_core/reranker/app.py new file mode 100644 index 0000000..ec693ea --- /dev/null +++ b/circuitforge_core/reranker/app.py @@ -0,0 +1,169 @@ +""" +circuitforge_core.reranker.app — cf-reranker FastAPI service. + +Managed by cf-orch as a process-type service. cf-orch starts this via: + + python -m circuitforge_core.reranker.app \ + --model BAAI/bge-reranker-base \ + --backend bge \ + --port 8011 \ + --gpu-id 0 + +Or with Qwen3: + + python -m circuitforge_core.reranker.app \ + --model Qwen/Qwen3-Reranker-0.6B \ + --backend qwen3 \ + --port 8011 \ + --gpu-id 0 \ + --dtype float16 + +Endpoints: + GET /health → {"status": "ok", "model": "...", "backend": "...", "vram_mb": n} + POST /rerank → RerankResponse +""" +from __future__ import annotations + +import argparse +import logging +import os + +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +# ── Request / response models ───────────────────────────────────────────────── + +class RerankRequest(BaseModel): + query: str + candidates: list[str] + top_n: int = 0 + + +class RerankResultItem(BaseModel): + candidate: str + score: float + rank: int + + +class RerankResponse(BaseModel): + results: list[RerankResultItem] + model: str + + +class HealthResponse(BaseModel): + status: str + model: str + backend: str + vram_mb: int + + +# ── VRAM estimates by backend/model family ──────────────────────────────────── + +_VRAM_TABLE: dict[str, int] = { + "bge-reranker-base": 570, + "bge-reranker-large": 1300, + "bge-reranker-v2-m3": 570, + "mxbai-rerank-base-v1": 570, + "mxbai-rerank-large-v1": 1300, + "ms-marco-MiniLM-L-6-v2": 90, + "ms-marco-MiniLM-L-12-v2": 130, + "Qwen3-Reranker-0.6B": 1200, + "Qwen3-Reranker-1.5B": 3000, + "Qwen3-Reranker-8B": 16000, +} + +def _estimate_vram(model_id: str) -> int: + for key, mb in _VRAM_TABLE.items(): + if key in model_id: + return mb + return 1024 # safe default + + +# ── App factory ─────────────────────────────────────────────────────────────── + +def create_app(model_id: str, backend: str, dtype: str, mock: bool) -> FastAPI: + from circuitforge_core.reranker import make_reranker + + app = FastAPI(title="cf-reranker", version="0.1.0") + _reranker = make_reranker(model_id=model_id, backend=backend, mock=mock) + _vram_mb = _estimate_vram(model_id) + + logger.info("cf-reranker ready: model=%r backend=%r vram=%dMB", model_id, backend, _vram_mb) + + @app.get("/health", response_model=HealthResponse) + async def health() -> HealthResponse: + return HealthResponse( + status="ok", + model=_reranker.model_id, + backend=backend, + vram_mb=_vram_mb, + ) + + @app.post("/rerank", response_model=RerankResponse) + async def rerank(req: RerankRequest) -> RerankResponse: + if not req.candidates: + raise HTTPException(status_code=400, detail="candidates must not be empty") + try: + results = _reranker.rerank(req.query, req.candidates, top_n=req.top_n) + except Exception as exc: + logger.exception("rerank failed") + raise HTTPException(status_code=500, detail=str(exc)) from exc + return RerankResponse( + results=[ + RerankResultItem(candidate=r.candidate, score=r.score, rank=r.rank) + for r in results + ], + model=_reranker.model_id, + ) + + return app + + +# ── CLI entry point ─────────────────────────────────────────────────────────── + +def main() -> None: + parser = argparse.ArgumentParser(description="cf-reranker — CircuitForge reranker service") + parser.add_argument( + "--model", default="BAAI/bge-reranker-base", + help="HuggingFace model ID or local path", + ) + parser.add_argument( + "--backend", default="bge", + choices=["bge", "qwen3", "cross-encoder", "mock"], + help="Reranker backend", + ) + parser.add_argument("--port", type=int, default=8011) + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--gpu-id", type=int, default=0) + parser.add_argument( + "--dtype", default="float16", + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument("--mock", action="store_true", + help="Run with mock backend (no GPU, for testing)") + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + if args.backend != "mock" and not args.mock: + os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id)) + + mock = args.mock or os.environ.get("CF_RERANKER_MOCK", "") == "1" + app = create_app( + model_id=args.model, + backend=args.backend, + dtype=args.dtype, + mock=mock, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index af9f1a0..e1cd70e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "circuitforge-core" -version = "0.15.0" +version = "0.17.0" description = "Shared scaffold for CircuitForge products (MIT)" requires-python = ">=3.11" dependencies = [ @@ -91,6 +91,17 @@ reranker-qwen3 = [ "transformers>=4.40", "accelerate>=0.27", ] +reranker-cross-encoder = [ + "sentence-transformers>=3.0", +] +reranker-cohere = [ + "cohere>=5.0", +] +reranker-service = [ + "circuitforge-core[reranker-qwen3]", + "fastapi>=0.110", + "uvicorn[standard]>=0.29", +] dev = [ "circuitforge-core[manage]", "pytest>=8.0", diff --git a/tests/test_reranker/test_cohere.py b/tests/test_reranker/test_cohere.py new file mode 100644 index 0000000..d2f3635 --- /dev/null +++ b/tests/test_reranker/test_cohere.py @@ -0,0 +1,106 @@ +"""Tests for CohereTextReranker with mocked cohere client.""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from circuitforge_core.reranker.adapters.cohere import CohereTextReranker + + +def _make_cohere_result(index: int, score: float) -> MagicMock: + r = MagicMock() + r.index = index + r.relevance_score = score + return r + + +def _make_mock_client(results: list[MagicMock]) -> MagicMock: + response = MagicMock() + response.results = results + client = MagicMock() + client.rerank.return_value = response + return client + + +def test_model_id_includes_model_name(): + r = CohereTextReranker(model="rerank-multilingual-v3.0") + assert r.model_id == "cohere:rerank-multilingual-v3.0" + + +def test_raises_without_cohere_package(): + reranker = CohereTextReranker(api_key="co-test") + with patch("circuitforge_core.reranker.adapters.cohere._cohere", None): + with pytest.raises(ImportError, match="cohere"): + reranker._score_pairs("q", ["doc"]) + + +def test_raises_without_api_key(monkeypatch): + monkeypatch.delenv("COHERE_API_KEY", raising=False) + reranker = CohereTextReranker() # no api_key arg + mock_cohere = MagicMock() + with patch("circuitforge_core.reranker.adapters.cohere._cohere", mock_cohere): + with pytest.raises(RuntimeError, match="API key"): + reranker._get_client() + + +def test_reads_api_key_from_env(monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "co-fromenv") + mock_cohere = MagicMock() + with patch("circuitforge_core.reranker.adapters.cohere._cohere", mock_cohere): + reranker = CohereTextReranker() + reranker._get_client() + mock_cohere.Client.assert_called_once_with(api_key="co-fromenv") + + +def test_score_pairs_returns_original_order(): + """Cohere returns results sorted by score; we must restore original order.""" + reranker = CohereTextReranker(api_key="co-test") + # Cohere returns candidates ranked: index 2 (0.9), index 0 (0.6), index 1 (0.1) + mock_client = _make_mock_client([ + _make_cohere_result(index=2, score=0.9), + _make_cohere_result(index=0, score=0.6), + _make_cohere_result(index=1, score=0.1), + ]) + with patch.object(reranker, "_get_client", return_value=mock_client): + scores = reranker._score_pairs("query", ["a", "b", "c"]) + # Original order: a=0.6, b=0.1, c=0.9 + assert scores == [pytest.approx(0.6), pytest.approx(0.1), pytest.approx(0.9)] + + +def test_rerank_sorts_correctly(): + reranker = CohereTextReranker(api_key="co-test") + mock_client = _make_mock_client([ + _make_cohere_result(index=1, score=0.95), + _make_cohere_result(index=0, score=0.3), + ]) + with patch.object(reranker, "_get_client", return_value=mock_client): + results = reranker.rerank("query", ["less relevant", "more relevant"]) + assert results[0].candidate == "more relevant" + assert results[0].rank == 0 + + +def test_rerank_top_n(): + reranker = CohereTextReranker(api_key="co-test") + mock_client = _make_mock_client([ + _make_cohere_result(index=0, score=0.9), + _make_cohere_result(index=1, score=0.5), + _make_cohere_result(index=2, score=0.1), + ]) + with patch.object(reranker, "_get_client", return_value=mock_client): + results = reranker.rerank("q", ["a", "b", "c"], top_n=2) + assert len(results) == 2 + + +def test_rerank_calls_cohere_with_correct_args(): + reranker = CohereTextReranker(api_key="co-test", model="rerank-english-v3.0") + mock_client = _make_mock_client([_make_cohere_result(index=0, score=0.8)]) + with patch.object(reranker, "_get_client", return_value=mock_client): + reranker.rerank("my query", ["only doc"]) + mock_client.rerank.assert_called_once_with( + query="my query", + documents=["only doc"], + model="rerank-english-v3.0", + top_n=1, + max_chunks_per_doc=1, + ) diff --git a/tests/test_reranker/test_cross_encoder.py b/tests/test_reranker/test_cross_encoder.py new file mode 100644 index 0000000..586cb9d --- /dev/null +++ b/tests/test_reranker/test_cross_encoder.py @@ -0,0 +1,77 @@ +"""Tests for CrossEncoderTextReranker with mocked sentence-transformers.""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker + + +def _make_mock_cross_encoder(scores: list[float]) -> MagicMock: + m = MagicMock() + m.predict.return_value = scores + return m + + +def test_model_id(): + assert ( + CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1").model_id + == "mixedbread-ai/mxbai-rerank-base-v1" + ) + + +def test_load_raises_without_sentence_transformers(): + reranker = CrossEncoderTextReranker() + with patch("circuitforge_core.reranker.adapters.cross_encoder._CrossEncoder", None): + with pytest.raises(ImportError, match="sentence-transformers"): + reranker.load() + + +def test_rerank_scores_and_sorts(): + reranker = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1") + reranker._model = _make_mock_cross_encoder([0.2, 0.9, 0.5]) + + results = reranker.rerank("query", ["a", "b", "c"]) + assert results[0].candidate == "b" + assert results[0].rank == 0 + assert results[2].candidate == "a" + + +def test_rerank_top_n(): + reranker = CrossEncoderTextReranker() + reranker._model = _make_mock_cross_encoder([0.1, 0.8, 0.5]) + results = reranker.rerank("q", ["a", "b", "c"], top_n=2) + assert len(results) == 2 + assert results[0].candidate == "b" + + +def test_predict_called_with_pairs(): + reranker = CrossEncoderTextReranker() + mock_model = _make_mock_cross_encoder([0.7, 0.3]) + reranker._model = mock_model + + reranker.rerank("chicken soup", ["recipe one", "recipe two"]) + pairs = mock_model.predict.call_args[0][0] + assert pairs == [("chicken soup", "recipe one"), ("chicken soup", "recipe two")] + + +def test_numpy_scores_coerced_to_float(): + """predict() may return numpy floats — verify they're converted cleanly.""" + try: + import numpy as np + numpy_scores = np.array([0.8, 0.2]) + except ImportError: + pytest.skip("numpy not installed") + + reranker = CrossEncoderTextReranker() + reranker._model = _make_mock_cross_encoder(numpy_scores) # type: ignore[arg-type] + results = reranker.rerank("q", ["a", "b"]) + assert isinstance(results[0].score, float) + + +def test_unload_clears_model(): + reranker = CrossEncoderTextReranker() + reranker._model = MagicMock() + reranker.unload() + assert reranker._model is None diff --git a/tests/test_reranker/test_qwen3.py b/tests/test_reranker/test_qwen3.py new file mode 100644 index 0000000..91dc81b --- /dev/null +++ b/tests/test_reranker/test_qwen3.py @@ -0,0 +1,185 @@ +"""Tests for Qwen3TextReranker with mocked transformers.""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + +from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker, _ASSISTANT_PREFILL + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _make_mock_model(yes_logit: float = 5.0, no_logit: float = 1.0, batch_size: int = 1): + """Return a mock AutoModelForCausalLM that outputs fixed yes/no logits.""" + import torch + + model = MagicMock() + # Simulate logits: (batch, seq_len=1, vocab_size=32000) + # yes token id = 9693, no token id = 2201 (Qwen tokenizer typical values) + vocab_size = 32000 + logits = torch.zeros(batch_size, 1, vocab_size) + logits[:, :, 9693] = yes_logit # "yes" token + logits[:, :, 2201] = no_logit # "no" token + output = MagicMock() + output.logits = logits + model.return_value = output + + # next(model.parameters()).device + param = MagicMock() + param.device = torch.device("cpu") + model.parameters.return_value = iter([param]) + return model + + +def _make_mock_tokenizer(yes_id: int = 9693, no_id: int = 2201): + """Return a mock AutoTokenizer.""" + import torch + + tokenizer = MagicMock() + tokenizer.encode.side_effect = lambda text, **kw: ( + [yes_id] if text == "yes" else [no_id] + ) + tokenizer.apply_chat_template.return_value = "" + tokenizer.padding_side = "right" + + # Return simple fixed-length tensors from __call__ + tokenizer.return_value = { + "input_ids": torch.zeros(1, 10, dtype=torch.long), + "attention_mask": torch.ones(1, 10, dtype=torch.long), + } + return tokenizer + + +# ── Unit tests ──────────────────────────────────────────────────────────────── + +def test_load_raises_without_torch(): + reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B") + with patch("circuitforge_core.reranker.adapters.qwen3._torch", None): + with pytest.raises(ImportError, match="torch"): + reranker.load() + + +def test_load_raises_without_transformers(): + reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B") + with patch("circuitforge_core.reranker.adapters.qwen3._AutoModel", None): + with pytest.raises(ImportError, match="transformers"): + reranker.load() + + +def test_model_id(): + assert Qwen3TextReranker("Qwen/Qwen3-Reranker-1.5B").model_id == "Qwen/Qwen3-Reranker-1.5B" + + +def test_unload_clears_state(): + reranker = Qwen3TextReranker() + reranker._model = MagicMock() + reranker._tokenizer = MagicMock() + reranker._yes_id = 1 + reranker._no_id = 2 + reranker.unload() + assert reranker._model is None + assert reranker._tokenizer is None + assert reranker._yes_id is None + assert reranker._no_id is None + + +def test_build_prompt_includes_prefill(): + reranker = Qwen3TextReranker() + reranker._tokenizer = _make_mock_tokenizer() + prompt = reranker._build_prompt("what is chicken soup", "a hearty recipe") + assert _ASSISTANT_PREFILL in prompt + + +def test_score_batch_returns_yes_probability(): + """Higher yes_logit → score closer to 1.0.""" + import torch + + reranker = Qwen3TextReranker() + reranker._tokenizer = _make_mock_tokenizer() + reranker._model = _make_mock_model(yes_logit=10.0, no_logit=0.0) + reranker._yes_id = 9693 + reranker._no_id = 2201 + + scores = reranker._score_batch("query", ["candidate"]) + assert len(scores) == 1 + assert scores[0] > 0.99 # softmax(10, 0)[0] ≈ 0.9999 + + +def test_score_batch_low_yes_logit(): + """Lower yes_logit → score closer to 0.0.""" + reranker = Qwen3TextReranker() + reranker._tokenizer = _make_mock_tokenizer() + reranker._model = _make_mock_model(yes_logit=0.0, no_logit=10.0) + reranker._yes_id = 9693 + reranker._no_id = 2201 + + scores = reranker._score_batch("query", ["irrelevant candidate"]) + assert scores[0] < 0.01 + + +def test_rerank_sorts_by_score(): + """Integration through rerank() — highest yes-logit candidate should rank first.""" + import torch + + reranker = Qwen3TextReranker(batch_size=10) + tokenizer = _make_mock_tokenizer() + # Return different-length tensors per call to simulate multi-candidate batch + call_count = [0] + + def tokenize_side_effect(prompts, **kw): + n = len(prompts) + return { + "input_ids": torch.zeros(n, 10, dtype=torch.long), + "attention_mask": torch.ones(n, 10, dtype=torch.long), + } + + tokenizer.side_effect = tokenize_side_effect + tokenizer.return_value = None # disable default return + reranker._tokenizer = tokenizer + + # Simulate two candidates: first gets low yes logit, second gets high + import torch as _torch + vocab_size = 32000 + batch_logits = _torch.zeros(2, 1, vocab_size) + batch_logits[0, 0, 9693] = 1.0 # candidate 0: low relevance + batch_logits[0, 0, 2201] = 5.0 + batch_logits[1, 0, 9693] = 5.0 # candidate 1: high relevance + batch_logits[1, 0, 2201] = 1.0 + + output = MagicMock() + output.logits = batch_logits + model = MagicMock() + model.return_value = output + param = MagicMock() + param.device = _torch.device("cpu") + model.parameters.return_value = iter([param]) + + reranker._model = model + reranker._yes_id = 9693 + reranker._no_id = 2201 + + results = reranker.rerank("query", ["low relevance doc", "high relevance doc"]) + assert results[0].candidate == "high relevance doc" + assert results[0].rank == 0 + + +def test_score_in_batches_splits_correctly(): + """Verify that large candidate lists are split into sub-batches.""" + reranker = Qwen3TextReranker(batch_size=2) + reranker._tokenizer = _make_mock_tokenizer() + reranker._yes_id = 9693 + reranker._no_id = 2201 + + batch_results: list[list[float]] = [] + + def fake_score_batch(query, cands): + batch_results.append(cands) + return [0.5] * len(cands) + + reranker._score_batch = fake_score_batch # type: ignore[method-assign] + scores = reranker._score_in_batches("q", ["a", "b", "c", "d", "e"]) + assert len(scores) == 5 + # 5 candidates with batch_size=2 → 3 sub-batches: [a,b], [c,d], [e] + assert len(batch_results) == 3 + assert batch_results[2] == ["e"] diff --git a/tests/test_reranker/test_remote.py b/tests/test_reranker/test_remote.py new file mode 100644 index 0000000..e4d5fec --- /dev/null +++ b/tests/test_reranker/test_remote.py @@ -0,0 +1,105 @@ +"""Tests for RemoteTextReranker.""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from circuitforge_core.reranker.adapters.remote import RemoteTextReranker + + +def _make_mock_response(results: list[dict]) -> MagicMock: + resp = MagicMock() + resp.json.return_value = {"results": results} + resp.raise_for_status = MagicMock() + return resp + + +def test_model_id(): + assert RemoteTextReranker("http://10.0.0.1:8011").model_id == "remote" + + +def test_score_pairs_posts_to_rerank_endpoint(): + reranker = RemoteTextReranker("http://10.0.0.1:8011") + mock_resp = _make_mock_response([ + {"candidate": "doc a", "score": 0.9, "rank": 0}, + {"candidate": "doc b", "score": 0.3, "rank": 1}, + ]) + with patch("requests.post", return_value=mock_resp) as mock_post: + scores = reranker._score_pairs("query", ["doc a", "doc b"]) + + mock_post.assert_called_once_with( + "http://10.0.0.1:8011/rerank", + json={"query": "query", "candidates": ["doc a", "doc b"], "top_n": 0}, + timeout=30, + ) + assert scores == [pytest.approx(0.9), pytest.approx(0.3)] + + +def test_score_pairs_restores_original_order(): + """Remote may return results in any order — scores must align with input.""" + reranker = RemoteTextReranker("http://10.0.0.1:8011") + # Remote returned c first (highest score), then a, then b + mock_resp = _make_mock_response([ + {"candidate": "c", "score": 0.95, "rank": 0}, + {"candidate": "a", "score": 0.6, "rank": 1}, + {"candidate": "b", "score": 0.1, "rank": 2}, + ]) + with patch("requests.post", return_value=mock_resp): + scores = reranker._score_pairs("q", ["a", "b", "c"]) + assert scores == [pytest.approx(0.6), pytest.approx(0.1), pytest.approx(0.95)] + + +def test_score_pairs_raises_on_http_error(): + import requests as req + reranker = RemoteTextReranker("http://10.0.0.1:8011") + with patch("requests.post", side_effect=req.ConnectionError("refused")): + with pytest.raises(RuntimeError, match="Remote reranker"): + reranker._score_pairs("q", ["doc"]) + + +def test_rerank_end_to_end(): + reranker = RemoteTextReranker("http://10.0.0.1:8011") + mock_resp = _make_mock_response([ + {"candidate": "irrelevant", "score": 0.2, "rank": 0}, + {"candidate": "very relevant", "score": 0.9, "rank": 1}, + ]) + with patch("requests.post", return_value=mock_resp): + results = reranker.rerank("q", ["irrelevant", "very relevant"]) + assert results[0].candidate == "very relevant" + assert results[0].rank == 0 + + +# ── make_reranker wiring ────────────────────────────────────────────────────── + +def test_make_reranker_qwen3(): + from circuitforge_core.reranker import make_reranker + from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker + r = make_reranker("Qwen/Qwen3-Reranker-0.6B", backend="qwen3") + assert isinstance(r, Qwen3TextReranker) + + +def test_make_reranker_cross_encoder(): + from circuitforge_core.reranker import make_reranker + from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker + r = make_reranker("mixedbread-ai/mxbai-rerank-base-v1", backend="cross-encoder") + assert isinstance(r, CrossEncoderTextReranker) + + +def test_make_reranker_cohere(): + from circuitforge_core.reranker import make_reranker + from circuitforge_core.reranker.adapters.cohere import CohereTextReranker + r = make_reranker("rerank-english-v3.0", backend="cohere") + assert isinstance(r, CohereTextReranker) + + +def test_make_reranker_remote(): + from circuitforge_core.reranker import make_reranker + r = make_reranker("http://10.0.0.1:8011", backend="remote") + assert isinstance(r, RemoteTextReranker) + + +def test_make_reranker_unknown_raises(): + from circuitforge_core.reranker import make_reranker + with pytest.raises(ValueError, match="cross-encoder"): + make_reranker(backend="unknown-backend")