feat: cross-encoder reranker for recipe suggestions (kiwi#117)
Integrates cf-core reranker into the L1/L2 recipe engine. Paid+ tier gets a BGE cross-encoder pass over the top-20 FTS candidates, scoring each recipe against the user's full context: pantry state, dietary constraints, allergies, expiry urgency, style preference, and effort preference. Free tier keeps the existing overlap sort unchanged. - New app/services/recipe/reranker.py: build_query, build_candidate_string, rerank_suggestions with tier gate (_RERANKER_TIERS) and graceful fallback - rerank_score field added to RecipeSuggestion (None on free tier, float on paid+) - recipe_engine.py: single call after candidate assembly, before final sort; hard_day_mode tier grouping preserved as primary sort when reranker active - Fix pre-existing circular import in app/services/__init__.py (eager import of ReceiptService triggered store.py → services → receipt_service → store) - 27 unit tests (mock backend, no model weights) + 2 engine-level tier tests; 325 tests passing, no regressions
This commit is contained in:
parent
91867f15f4
commit
b5eb8e4772
6 changed files with 520 additions and 7 deletions
|
|
@ -43,6 +43,7 @@ class RecipeSuggestion(BaseModel):
|
|||
source_url: str | None = None
|
||||
complexity: str | None = None # 'easy' | 'moderate' | 'involved'
|
||||
estimated_time_min: int | None = None # derived from step count + method signals
|
||||
rerank_score: float | None = None # cross-encoder relevance score (paid+ only, None for free tier)
|
||||
|
||||
|
||||
class GroceryLink(BaseModel):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,11 @@
|
|||
Business logic services for Kiwi.
|
||||
"""
|
||||
|
||||
from app.services.receipt_service import ReceiptService
|
||||
__all__ = ["ReceiptService"]
|
||||
|
||||
__all__ = ["ReceiptService"]
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name == "ReceiptService":
|
||||
from app.services.receipt_service import ReceiptService
|
||||
return ReceiptService
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from app.services.recipe.grocery_links import GroceryLinkBuilder
|
|||
from app.services.recipe.substitution_engine import SubstitutionEngine
|
||||
from app.services.recipe.sensory import SensoryExclude, build_sensory_exclude, passes_sensory_filter
|
||||
from app.services.recipe.time_effort import parse_time_effort
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
|
||||
_LEFTOVER_DAILY_MAX_FREE = 5
|
||||
|
||||
|
|
@ -880,11 +881,21 @@ class RecipeEngine:
|
|||
estimated_time_min=row_time_min,
|
||||
))
|
||||
|
||||
# Sort corpus results — assembly templates are now served from a dedicated tab.
|
||||
# Hard day mode: primary sort by tier (0=premade, 1=simple, 2=moderate),
|
||||
# then by match_count within each tier.
|
||||
# Normal mode: sort by match_count descending.
|
||||
if req.hard_day_mode and hard_day_tier_map:
|
||||
# Sort corpus results.
|
||||
# Paid+ tier: cross-encoder reranker orders by full pantry + dietary fit.
|
||||
# Free tier (or reranker failure): overlap sort with hard_day_mode tier grouping.
|
||||
reranked = rerank_suggestions(req, suggestions)
|
||||
if reranked is not None:
|
||||
# Reranker provided relevance order. In hard_day_mode, still respect
|
||||
# tier grouping as primary sort; reranker order applies within each tier.
|
||||
if req.hard_day_mode and hard_day_tier_map:
|
||||
suggestions = sorted(
|
||||
reranked,
|
||||
key=lambda s: hard_day_tier_map.get(s.id, 1),
|
||||
)
|
||||
else:
|
||||
suggestions = reranked
|
||||
elif req.hard_day_mode and hard_day_tier_map:
|
||||
suggestions = sorted(
|
||||
suggestions,
|
||||
key=lambda s: (hard_day_tier_map.get(s.id, 1), -s.match_count),
|
||||
|
|
|
|||
175
app/services/recipe/reranker.py
Normal file
175
app/services/recipe/reranker.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""
|
||||
Reranker integration for recipe suggestions.
|
||||
|
||||
Wraps circuitforge_core.reranker to score recipe candidates against a
|
||||
natural-language query built from the user's pantry, constraints, and
|
||||
preferences. Paid+ tier only; free tier returns None (caller keeps
|
||||
existing sort). All exceptions are caught and logged — the reranker
|
||||
must never break recipe suggestions.
|
||||
|
||||
Environment:
|
||||
CF_RERANKER_MOCK=1 — force mock backend (tests, no model required)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app.models.schemas.recipe import RecipeRequest, RecipeSuggestion
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Tiers that get reranker access.
|
||||
_RERANKER_TIERS: frozenset[str] = frozenset({"paid", "premium", "local"})
|
||||
|
||||
# Minimum candidates worth reranking — below this the cross-encoder
|
||||
# overhead is not justified and the overlap sort is fine.
|
||||
_MIN_CANDIDATES: int = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerankerInput:
|
||||
"""Intermediate representation passed to the reranker."""
|
||||
query: str
|
||||
candidates: list[str]
|
||||
suggestion_ids: list[int] # parallel to candidates, for re-mapping
|
||||
|
||||
|
||||
# ── Query builder ─────────────────────────────────────────────────────────────
|
||||
|
||||
def build_query(req: RecipeRequest) -> str:
|
||||
"""Build a natural-language query string from the recipe request.
|
||||
|
||||
Encodes the user's full context so the cross-encoder can score
|
||||
relevance, dietary fit, and expiry urgency in a single pass.
|
||||
Only non-empty segments are included.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
if req.pantry_items:
|
||||
parts.append(f"Recipe using: {', '.join(req.pantry_items)}")
|
||||
|
||||
if req.exclude_ingredients:
|
||||
parts.append(f"Avoid: {', '.join(req.exclude_ingredients)}")
|
||||
|
||||
if req.allergies:
|
||||
parts.append(f"Allergies: {', '.join(req.allergies)}")
|
||||
|
||||
if req.constraints:
|
||||
parts.append(f"Dietary: {', '.join(req.constraints)}")
|
||||
|
||||
if req.category:
|
||||
parts.append(f"Category: {req.category}")
|
||||
|
||||
if req.style_id:
|
||||
parts.append(f"Style: {req.style_id}")
|
||||
|
||||
if req.complexity_filter:
|
||||
parts.append(f"Prefer: {req.complexity_filter}")
|
||||
|
||||
if req.hard_day_mode:
|
||||
parts.append("Prefer: easy, minimal effort")
|
||||
|
||||
# Secondary pantry items carry a state label (e.g. "stale", "overripe")
|
||||
# that helps the reranker favour recipes suited to those specific states.
|
||||
if req.secondary_pantry_items:
|
||||
expiry_parts = [f"{name} ({state})" for name, state in req.secondary_pantry_items.items()]
|
||||
parts.append(f"Use soon: {', '.join(expiry_parts)}")
|
||||
elif req.expiry_first:
|
||||
parts.append("Prefer: recipes that use expiring items first")
|
||||
|
||||
return ". ".join(parts) + "." if parts else "Recipe."
|
||||
|
||||
|
||||
# ── Candidate builder ─────────────────────────────────────────────────────────
|
||||
|
||||
def build_candidate_string(suggestion: RecipeSuggestion) -> str:
|
||||
"""Build a candidate string for a single recipe suggestion.
|
||||
|
||||
Format: "{title}. Ingredients: {comma-joined ingredients}"
|
||||
Matched ingredients appear before missing ones.
|
||||
Directions excluded to stay within BGE's 512-token window.
|
||||
"""
|
||||
ingredients = suggestion.matched_ingredients + suggestion.missing_ingredients
|
||||
if not ingredients:
|
||||
return suggestion.title
|
||||
return f"{suggestion.title}. Ingredients: {', '.join(ingredients)}"
|
||||
|
||||
|
||||
# ── Input assembler ───────────────────────────────────────────────────────────
|
||||
|
||||
def build_reranker_input(
|
||||
req: RecipeRequest,
|
||||
suggestions: list[RecipeSuggestion],
|
||||
) -> RerankerInput:
|
||||
"""Assemble query and candidate strings for the reranker."""
|
||||
query = build_query(req)
|
||||
candidates: list[str] = []
|
||||
ids: list[int] = []
|
||||
for s in suggestions:
|
||||
candidates.append(build_candidate_string(s))
|
||||
ids.append(s.id)
|
||||
return RerankerInput(query=query, candidates=candidates, suggestion_ids=ids)
|
||||
|
||||
|
||||
# ── cf-core seam (isolated for monkeypatching in tests) ──────────────────────
|
||||
|
||||
def _do_rerank(query: str, candidates: list[str], top_n: int = 0):
|
||||
"""Thin wrapper around cf-core rerank(). Extracted so tests can patch it."""
|
||||
from circuitforge_core.reranker import rerank
|
||||
return rerank(query, candidates, top_n=top_n)
|
||||
|
||||
|
||||
# ── Public entry point ────────────────────────────────────────────────────────
|
||||
|
||||
def rerank_suggestions(
|
||||
req: RecipeRequest,
|
||||
suggestions: list[RecipeSuggestion],
|
||||
) -> list[RecipeSuggestion] | None:
|
||||
"""Rerank suggestions using the cf-core cross-encoder.
|
||||
|
||||
Returns a reordered list with rerank_score populated, or None when:
|
||||
- Tier is not paid+ (free tier keeps overlap sort)
|
||||
- Fewer than _MIN_CANDIDATES suggestions (not worth the overhead)
|
||||
- Any exception is raised (graceful fallback to existing sort)
|
||||
|
||||
The caller should treat None as "keep existing sort order".
|
||||
Original suggestions are never mutated.
|
||||
"""
|
||||
if req.tier not in _RERANKER_TIERS:
|
||||
return None
|
||||
|
||||
if len(suggestions) < _MIN_CANDIDATES:
|
||||
return None
|
||||
|
||||
try:
|
||||
rinput = build_reranker_input(req, suggestions)
|
||||
results = _do_rerank(rinput.query, rinput.candidates, top_n=0)
|
||||
|
||||
# Map reranked results back to RecipeSuggestion objects using the
|
||||
# candidate string as key (build_candidate_string is deterministic).
|
||||
candidate_map: dict[str, RecipeSuggestion] = {
|
||||
build_candidate_string(s): s for s in suggestions
|
||||
}
|
||||
|
||||
reranked: list[RecipeSuggestion] = []
|
||||
for rr in results:
|
||||
suggestion = candidate_map.get(rr.candidate)
|
||||
if suggestion is not None:
|
||||
reranked.append(suggestion.model_copy(
|
||||
update={"rerank_score": round(float(rr.score), 4)}
|
||||
))
|
||||
|
||||
if len(reranked) < len(suggestions):
|
||||
log.warning(
|
||||
"Reranker lost %d/%d suggestions during mapping, falling back",
|
||||
len(suggestions) - len(reranked),
|
||||
len(suggestions),
|
||||
)
|
||||
return None
|
||||
|
||||
return reranked
|
||||
|
||||
except Exception:
|
||||
log.exception("Reranker failed, falling back to overlap sort")
|
||||
return None
|
||||
|
|
@ -170,3 +170,37 @@ def test_within_time_over_limit_fails():
|
|||
from app.services.recipe.recipe_engine import _within_time
|
||||
steps = ["brown onions for 15 minutes", "simmer for 30 minutes"]
|
||||
assert _within_time(steps, max_total_min=30) is False
|
||||
|
||||
|
||||
# ── Reranker tier-gating tests ────────────────────────────────────────────────
|
||||
|
||||
def test_paid_tier_suggest_populates_rerank_score(store_with_recipes, monkeypatch):
|
||||
"""Paid tier: at least one suggestion should have rerank_score populated."""
|
||||
monkeypatch.setenv("CF_RERANKER_MOCK", "1")
|
||||
try:
|
||||
from circuitforge_core.reranker import reset_reranker
|
||||
reset_reranker()
|
||||
except ImportError:
|
||||
pytest.skip("cf-core reranker not installed")
|
||||
|
||||
from app.services.recipe.recipe_engine import RecipeEngine
|
||||
from app.models.schemas.recipe import RecipeRequest
|
||||
engine = RecipeEngine(store_with_recipes)
|
||||
req = RecipeRequest(pantry_items=["butter", "parmesan", "pasta"], level=1, tier="paid")
|
||||
result = engine.suggest(req)
|
||||
# Need at least _MIN_CANDIDATES for reranker to fire
|
||||
from app.services.recipe.reranker import _MIN_CANDIDATES
|
||||
if len(result.suggestions) >= _MIN_CANDIDATES:
|
||||
assert any(s.rerank_score is not None for s in result.suggestions)
|
||||
|
||||
reset_reranker()
|
||||
|
||||
|
||||
def test_free_tier_suggest_has_no_rerank_score(store_with_recipes):
|
||||
"""Free tier: rerank_score must be None on all suggestions."""
|
||||
from app.services.recipe.recipe_engine import RecipeEngine
|
||||
from app.models.schemas.recipe import RecipeRequest
|
||||
engine = RecipeEngine(store_with_recipes)
|
||||
req = RecipeRequest(pantry_items=["butter", "parmesan"], level=1, tier="free")
|
||||
result = engine.suggest(req)
|
||||
assert all(s.rerank_score is None for s in result.suggestions)
|
||||
|
|
|
|||
287
tests/services/recipe/test_reranker.py
Normal file
287
tests/services/recipe/test_reranker.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
"""
|
||||
Tests for app.services.recipe.reranker.
|
||||
|
||||
All tests use CF_RERANKER_MOCK=1 -- no model weights required.
|
||||
The mock reranker scores by Jaccard similarity of query tokens vs candidate
|
||||
tokens, which is deterministic and fast.
|
||||
"""
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_reranker(monkeypatch):
|
||||
"""Force mock backend and reset the cf-core singleton before/after each test."""
|
||||
monkeypatch.setenv("CF_RERANKER_MOCK", "1")
|
||||
try:
|
||||
from circuitforge_core.reranker import reset_reranker
|
||||
reset_reranker()
|
||||
yield
|
||||
reset_reranker()
|
||||
except ImportError:
|
||||
yield
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
from app.models.schemas.recipe import RecipeRequest
|
||||
defaults = dict(pantry_items=["chicken", "rice"], tier="paid")
|
||||
defaults.update(kwargs)
|
||||
return RecipeRequest(**defaults)
|
||||
|
||||
|
||||
def _make_suggestion(id: int, title: str, matched: list[str], missing: list[str] | None = None, match_count: int | None = None):
|
||||
from app.models.schemas.recipe import RecipeSuggestion
|
||||
mi = missing or []
|
||||
return RecipeSuggestion(
|
||||
id=id,
|
||||
title=title,
|
||||
match_count=match_count if match_count is not None else len(matched),
|
||||
matched_ingredients=matched,
|
||||
missing_ingredients=mi,
|
||||
)
|
||||
|
||||
|
||||
# ── TestBuildQuery ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestBuildQuery:
|
||||
def test_basic_pantry(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(pantry_items=["chicken", "rice", "broccoli"])
|
||||
query = build_query(req)
|
||||
assert "chicken" in query
|
||||
assert "rice" in query
|
||||
assert "broccoli" in query
|
||||
|
||||
def test_exclude_ingredients_included(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(exclude_ingredients=["cilantro", "fish sauce"])
|
||||
query = build_query(req)
|
||||
assert "cilantro" in query
|
||||
assert "fish sauce" in query
|
||||
|
||||
def test_allergies_separate_from_exclude(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(allergies=["shellfish"], exclude_ingredients=["cilantro"])
|
||||
query = build_query(req)
|
||||
# Both should appear, and they should be in separate labeled segments
|
||||
assert "shellfish" in query
|
||||
assert "cilantro" in query
|
||||
allergy_pos = query.index("shellfish")
|
||||
exclude_pos = query.index("cilantro")
|
||||
assert allergy_pos != exclude_pos
|
||||
|
||||
def test_allergies_labeled_separately(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(allergies=["peanuts"], exclude_ingredients=[])
|
||||
query = build_query(req)
|
||||
assert "Allergies" in query or "allerg" in query.lower()
|
||||
|
||||
def test_constraints_included(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(constraints=["gluten-free", "dairy-free"])
|
||||
query = build_query(req)
|
||||
assert "gluten-free" in query
|
||||
assert "dairy-free" in query
|
||||
|
||||
def test_category_included(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(category="Soup")
|
||||
query = build_query(req)
|
||||
assert "Soup" in query
|
||||
|
||||
def test_complexity_filter_included(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(complexity_filter="easy")
|
||||
query = build_query(req)
|
||||
assert "easy" in query
|
||||
|
||||
def test_hard_day_mode_signal(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(hard_day_mode=True)
|
||||
query = build_query(req)
|
||||
assert "easy" in query.lower() or "minimal" in query.lower() or "effort" in query.lower()
|
||||
|
||||
def test_secondary_pantry_items_expiry(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(secondary_pantry_items={"bread": "stale", "banana": "overripe"})
|
||||
query = build_query(req)
|
||||
assert "bread" in query
|
||||
assert "banana" in query
|
||||
# State labels add specificity for the cross-encoder
|
||||
assert "stale" in query or "overripe" in query
|
||||
|
||||
def test_expiry_first_without_secondary(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(expiry_first=True, secondary_pantry_items={})
|
||||
query = build_query(req)
|
||||
assert "expir" in query.lower()
|
||||
|
||||
def test_style_id_included(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(style_id="mediterranean")
|
||||
query = build_query(req)
|
||||
assert "mediterranean" in query.lower()
|
||||
|
||||
def test_empty_pantry_returns_fallback(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(pantry_items=[])
|
||||
query = build_query(req)
|
||||
assert len(query) > 0 # never empty string
|
||||
|
||||
def test_no_duplicate_separators(self):
|
||||
from app.services.recipe.reranker import build_query
|
||||
req = _make_request(
|
||||
pantry_items=["egg"],
|
||||
allergies=["nuts"],
|
||||
constraints=["vegan"],
|
||||
complexity_filter="easy",
|
||||
)
|
||||
query = build_query(req)
|
||||
assert ".." not in query # no doubled periods from empty segments
|
||||
|
||||
|
||||
# ── TestBuildCandidateString ──────────────────────────────────────────────────
|
||||
|
||||
class TestBuildCandidateString:
|
||||
def test_title_and_ingredients(self):
|
||||
from app.services.recipe.reranker import build_candidate_string
|
||||
s = _make_suggestion(1, "Chicken Fried Rice", ["chicken", "rice"], ["soy sauce"])
|
||||
candidate = build_candidate_string(s)
|
||||
assert candidate.startswith("Chicken Fried Rice")
|
||||
assert "chicken" in candidate
|
||||
assert "rice" in candidate
|
||||
assert "soy sauce" in candidate
|
||||
|
||||
def test_title_only_when_no_ingredients(self):
|
||||
from app.services.recipe.reranker import build_candidate_string
|
||||
s = _make_suggestion(2, "Mystery Dish", [], [])
|
||||
candidate = build_candidate_string(s)
|
||||
assert candidate == "Mystery Dish"
|
||||
assert "Ingredients:" not in candidate
|
||||
|
||||
def test_matched_before_missing(self):
|
||||
from app.services.recipe.reranker import build_candidate_string
|
||||
s = _make_suggestion(3, "Pasta Dish", ["pasta", "butter"], ["parmesan", "cream"])
|
||||
candidate = build_candidate_string(s)
|
||||
pasta_pos = candidate.index("pasta")
|
||||
parmesan_pos = candidate.index("parmesan")
|
||||
assert pasta_pos < parmesan_pos
|
||||
|
||||
|
||||
# ── TestBuildRerankerInput ────────────────────────────────────────────────────
|
||||
|
||||
class TestBuildRerankerInput:
|
||||
def test_parallel_ids_and_candidates(self):
|
||||
from app.services.recipe.reranker import build_reranker_input
|
||||
req = _make_request()
|
||||
suggestions = [
|
||||
_make_suggestion(10, "Recipe A", ["chicken"]),
|
||||
_make_suggestion(20, "Recipe B", ["rice"]),
|
||||
_make_suggestion(30, "Recipe C", ["broccoli"]),
|
||||
]
|
||||
rinput = build_reranker_input(req, suggestions)
|
||||
assert len(rinput.candidates) == 3
|
||||
assert len(rinput.suggestion_ids) == 3
|
||||
assert rinput.suggestion_ids == [10, 20, 30]
|
||||
|
||||
def test_query_matches_build_query(self):
|
||||
from app.services.recipe.reranker import build_reranker_input, build_query
|
||||
req = _make_request(pantry_items=["egg", "cheese"], constraints=["vegetarian"])
|
||||
suggestions = [_make_suggestion(1, "Omelette", ["egg", "cheese"])]
|
||||
rinput = build_reranker_input(req, suggestions)
|
||||
assert rinput.query == build_query(req)
|
||||
|
||||
|
||||
# ── TestRerankSuggestions ─────────────────────────────────────────────────────
|
||||
|
||||
class TestRerankSuggestions:
|
||||
def test_free_tier_returns_none(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="free")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(5)]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is None
|
||||
|
||||
def test_paid_tier_returns_reranked_list(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="paid", pantry_items=["chicken", "rice"])
|
||||
suggestions = [
|
||||
_make_suggestion(1, "Chicken Fried Rice", ["chicken", "rice"]),
|
||||
_make_suggestion(2, "Chocolate Cake", ["flour", "sugar", "cocoa"]),
|
||||
_make_suggestion(3, "Chicken Soup", ["chicken", "broth"]),
|
||||
]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is not None
|
||||
assert len(result) == len(suggestions)
|
||||
|
||||
def test_rerank_score_is_populated(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="paid")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is not None
|
||||
assert all(s.rerank_score is not None for s in result)
|
||||
assert all(isinstance(s.rerank_score, float) for s in result)
|
||||
|
||||
def test_too_few_candidates_returns_none(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions, _MIN_CANDIDATES
|
||||
req = _make_request(tier="paid")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(_MIN_CANDIDATES - 1)]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is None
|
||||
|
||||
def test_premium_tier_gets_reranker(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="premium")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is not None
|
||||
|
||||
def test_local_tier_gets_reranker(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="local")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is not None
|
||||
|
||||
def test_preserves_all_suggestion_fields(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="paid")
|
||||
original = _make_suggestion(
|
||||
id=42,
|
||||
title="Garlic Butter Pasta",
|
||||
matched=["pasta", "butter", "garlic"],
|
||||
missing=["parmesan"],
|
||||
match_count=3,
|
||||
)
|
||||
result = rerank_suggestions(req, [original, original, original, original])
|
||||
assert result is not None
|
||||
found = next((s for s in result if s.id == 42), None)
|
||||
assert found is not None
|
||||
assert found.title == "Garlic Butter Pasta"
|
||||
assert found.matched_ingredients == ["pasta", "butter", "garlic"]
|
||||
assert found.missing_ingredients == ["parmesan"]
|
||||
assert found.match_count == 3
|
||||
|
||||
def test_graceful_fallback_on_exception(self, monkeypatch):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
# Simulate reranker raising at runtime
|
||||
import app.services.recipe.reranker as reranker_mod
|
||||
def _boom(query, candidates, top_n=0):
|
||||
raise RuntimeError("model exploded")
|
||||
monkeypatch.setattr(reranker_mod, "_do_rerank", _boom)
|
||||
req = _make_request(tier="paid")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||
result = rerank_suggestions(req, suggestions)
|
||||
assert result is None
|
||||
|
||||
def test_original_suggestions_not_mutated(self):
|
||||
from app.services.recipe.reranker import rerank_suggestions
|
||||
req = _make_request(tier="paid")
|
||||
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||
originals = [s.model_copy() for s in suggestions]
|
||||
rerank_suggestions(req, suggestions)
|
||||
for original, after in zip(originals, suggestions):
|
||||
assert original.rerank_score == after.rerank_score # None == None (no mutation)
|
||||
Loading…
Reference in a new issue