feat: cross-encoder reranker for recipe suggestions (kiwi#117)
Integrates cf-core reranker into the L1/L2 recipe engine. Paid+ tier gets a BGE cross-encoder pass over the top-20 FTS candidates, scoring each recipe against the user's full context: pantry state, dietary constraints, allergies, expiry urgency, style preference, and effort preference. Free tier keeps the existing overlap sort unchanged. - New app/services/recipe/reranker.py: build_query, build_candidate_string, rerank_suggestions with tier gate (_RERANKER_TIERS) and graceful fallback - rerank_score field added to RecipeSuggestion (None on free tier, float on paid+) - recipe_engine.py: single call after candidate assembly, before final sort; hard_day_mode tier grouping preserved as primary sort when reranker active - Fix pre-existing circular import in app/services/__init__.py (eager import of ReceiptService triggered store.py → services → receipt_service → store) - 27 unit tests (mock backend, no model weights) + 2 engine-level tier tests; 325 tests passing, no regressions
This commit is contained in:
parent
91867f15f4
commit
b5eb8e4772
6 changed files with 520 additions and 7 deletions
|
|
@ -43,6 +43,7 @@ class RecipeSuggestion(BaseModel):
|
||||||
source_url: str | None = None
|
source_url: str | None = None
|
||||||
complexity: str | None = None # 'easy' | 'moderate' | 'involved'
|
complexity: str | None = None # 'easy' | 'moderate' | 'involved'
|
||||||
estimated_time_min: int | None = None # derived from step count + method signals
|
estimated_time_min: int | None = None # derived from step count + method signals
|
||||||
|
rerank_score: float | None = None # cross-encoder relevance score (paid+ only, None for free tier)
|
||||||
|
|
||||||
|
|
||||||
class GroceryLink(BaseModel):
|
class GroceryLink(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,11 @@
|
||||||
Business logic services for Kiwi.
|
Business logic services for Kiwi.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.services.receipt_service import ReceiptService
|
|
||||||
|
|
||||||
__all__ = ["ReceiptService"]
|
__all__ = ["ReceiptService"]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name == "ReceiptService":
|
||||||
|
from app.services.receipt_service import ReceiptService
|
||||||
|
return ReceiptService
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from app.services.recipe.grocery_links import GroceryLinkBuilder
|
||||||
from app.services.recipe.substitution_engine import SubstitutionEngine
|
from app.services.recipe.substitution_engine import SubstitutionEngine
|
||||||
from app.services.recipe.sensory import SensoryExclude, build_sensory_exclude, passes_sensory_filter
|
from app.services.recipe.sensory import SensoryExclude, build_sensory_exclude, passes_sensory_filter
|
||||||
from app.services.recipe.time_effort import parse_time_effort
|
from app.services.recipe.time_effort import parse_time_effort
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
|
||||||
_LEFTOVER_DAILY_MAX_FREE = 5
|
_LEFTOVER_DAILY_MAX_FREE = 5
|
||||||
|
|
||||||
|
|
@ -880,11 +881,21 @@ class RecipeEngine:
|
||||||
estimated_time_min=row_time_min,
|
estimated_time_min=row_time_min,
|
||||||
))
|
))
|
||||||
|
|
||||||
# Sort corpus results — assembly templates are now served from a dedicated tab.
|
# Sort corpus results.
|
||||||
# Hard day mode: primary sort by tier (0=premade, 1=simple, 2=moderate),
|
# Paid+ tier: cross-encoder reranker orders by full pantry + dietary fit.
|
||||||
# then by match_count within each tier.
|
# Free tier (or reranker failure): overlap sort with hard_day_mode tier grouping.
|
||||||
# Normal mode: sort by match_count descending.
|
reranked = rerank_suggestions(req, suggestions)
|
||||||
if req.hard_day_mode and hard_day_tier_map:
|
if reranked is not None:
|
||||||
|
# Reranker provided relevance order. In hard_day_mode, still respect
|
||||||
|
# tier grouping as primary sort; reranker order applies within each tier.
|
||||||
|
if req.hard_day_mode and hard_day_tier_map:
|
||||||
|
suggestions = sorted(
|
||||||
|
reranked,
|
||||||
|
key=lambda s: hard_day_tier_map.get(s.id, 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
suggestions = reranked
|
||||||
|
elif req.hard_day_mode and hard_day_tier_map:
|
||||||
suggestions = sorted(
|
suggestions = sorted(
|
||||||
suggestions,
|
suggestions,
|
||||||
key=lambda s: (hard_day_tier_map.get(s.id, 1), -s.match_count),
|
key=lambda s: (hard_day_tier_map.get(s.id, 1), -s.match_count),
|
||||||
|
|
|
||||||
175
app/services/recipe/reranker.py
Normal file
175
app/services/recipe/reranker.py
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
"""
|
||||||
|
Reranker integration for recipe suggestions.
|
||||||
|
|
||||||
|
Wraps circuitforge_core.reranker to score recipe candidates against a
|
||||||
|
natural-language query built from the user's pantry, constraints, and
|
||||||
|
preferences. Paid+ tier only; free tier returns None (caller keeps
|
||||||
|
existing sort). All exceptions are caught and logged — the reranker
|
||||||
|
must never break recipe suggestions.
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
CF_RERANKER_MOCK=1 — force mock backend (tests, no model required)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from app.models.schemas.recipe import RecipeRequest, RecipeSuggestion
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tiers that get reranker access.
|
||||||
|
_RERANKER_TIERS: frozenset[str] = frozenset({"paid", "premium", "local"})
|
||||||
|
|
||||||
|
# Minimum candidates worth reranking — below this the cross-encoder
|
||||||
|
# overhead is not justified and the overlap sort is fine.
|
||||||
|
_MIN_CANDIDATES: int = 3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RerankerInput:
|
||||||
|
"""Intermediate representation passed to the reranker."""
|
||||||
|
query: str
|
||||||
|
candidates: list[str]
|
||||||
|
suggestion_ids: list[int] # parallel to candidates, for re-mapping
|
||||||
|
|
||||||
|
|
||||||
|
# ── Query builder ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_query(req: RecipeRequest) -> str:
|
||||||
|
"""Build a natural-language query string from the recipe request.
|
||||||
|
|
||||||
|
Encodes the user's full context so the cross-encoder can score
|
||||||
|
relevance, dietary fit, and expiry urgency in a single pass.
|
||||||
|
Only non-empty segments are included.
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
if req.pantry_items:
|
||||||
|
parts.append(f"Recipe using: {', '.join(req.pantry_items)}")
|
||||||
|
|
||||||
|
if req.exclude_ingredients:
|
||||||
|
parts.append(f"Avoid: {', '.join(req.exclude_ingredients)}")
|
||||||
|
|
||||||
|
if req.allergies:
|
||||||
|
parts.append(f"Allergies: {', '.join(req.allergies)}")
|
||||||
|
|
||||||
|
if req.constraints:
|
||||||
|
parts.append(f"Dietary: {', '.join(req.constraints)}")
|
||||||
|
|
||||||
|
if req.category:
|
||||||
|
parts.append(f"Category: {req.category}")
|
||||||
|
|
||||||
|
if req.style_id:
|
||||||
|
parts.append(f"Style: {req.style_id}")
|
||||||
|
|
||||||
|
if req.complexity_filter:
|
||||||
|
parts.append(f"Prefer: {req.complexity_filter}")
|
||||||
|
|
||||||
|
if req.hard_day_mode:
|
||||||
|
parts.append("Prefer: easy, minimal effort")
|
||||||
|
|
||||||
|
# Secondary pantry items carry a state label (e.g. "stale", "overripe")
|
||||||
|
# that helps the reranker favour recipes suited to those specific states.
|
||||||
|
if req.secondary_pantry_items:
|
||||||
|
expiry_parts = [f"{name} ({state})" for name, state in req.secondary_pantry_items.items()]
|
||||||
|
parts.append(f"Use soon: {', '.join(expiry_parts)}")
|
||||||
|
elif req.expiry_first:
|
||||||
|
parts.append("Prefer: recipes that use expiring items first")
|
||||||
|
|
||||||
|
return ". ".join(parts) + "." if parts else "Recipe."
|
||||||
|
|
||||||
|
|
||||||
|
# ── Candidate builder ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_candidate_string(suggestion: RecipeSuggestion) -> str:
|
||||||
|
"""Build a candidate string for a single recipe suggestion.
|
||||||
|
|
||||||
|
Format: "{title}. Ingredients: {comma-joined ingredients}"
|
||||||
|
Matched ingredients appear before missing ones.
|
||||||
|
Directions excluded to stay within BGE's 512-token window.
|
||||||
|
"""
|
||||||
|
ingredients = suggestion.matched_ingredients + suggestion.missing_ingredients
|
||||||
|
if not ingredients:
|
||||||
|
return suggestion.title
|
||||||
|
return f"{suggestion.title}. Ingredients: {', '.join(ingredients)}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Input assembler ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_reranker_input(
|
||||||
|
req: RecipeRequest,
|
||||||
|
suggestions: list[RecipeSuggestion],
|
||||||
|
) -> RerankerInput:
|
||||||
|
"""Assemble query and candidate strings for the reranker."""
|
||||||
|
query = build_query(req)
|
||||||
|
candidates: list[str] = []
|
||||||
|
ids: list[int] = []
|
||||||
|
for s in suggestions:
|
||||||
|
candidates.append(build_candidate_string(s))
|
||||||
|
ids.append(s.id)
|
||||||
|
return RerankerInput(query=query, candidates=candidates, suggestion_ids=ids)
|
||||||
|
|
||||||
|
|
||||||
|
# ── cf-core seam (isolated for monkeypatching in tests) ──────────────────────
|
||||||
|
|
||||||
|
def _do_rerank(query: str, candidates: list[str], top_n: int = 0):
|
||||||
|
"""Thin wrapper around cf-core rerank(). Extracted so tests can patch it."""
|
||||||
|
from circuitforge_core.reranker import rerank
|
||||||
|
return rerank(query, candidates, top_n=top_n)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public entry point ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def rerank_suggestions(
|
||||||
|
req: RecipeRequest,
|
||||||
|
suggestions: list[RecipeSuggestion],
|
||||||
|
) -> list[RecipeSuggestion] | None:
|
||||||
|
"""Rerank suggestions using the cf-core cross-encoder.
|
||||||
|
|
||||||
|
Returns a reordered list with rerank_score populated, or None when:
|
||||||
|
- Tier is not paid+ (free tier keeps overlap sort)
|
||||||
|
- Fewer than _MIN_CANDIDATES suggestions (not worth the overhead)
|
||||||
|
- Any exception is raised (graceful fallback to existing sort)
|
||||||
|
|
||||||
|
The caller should treat None as "keep existing sort order".
|
||||||
|
Original suggestions are never mutated.
|
||||||
|
"""
|
||||||
|
if req.tier not in _RERANKER_TIERS:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(suggestions) < _MIN_CANDIDATES:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
rinput = build_reranker_input(req, suggestions)
|
||||||
|
results = _do_rerank(rinput.query, rinput.candidates, top_n=0)
|
||||||
|
|
||||||
|
# Map reranked results back to RecipeSuggestion objects using the
|
||||||
|
# candidate string as key (build_candidate_string is deterministic).
|
||||||
|
candidate_map: dict[str, RecipeSuggestion] = {
|
||||||
|
build_candidate_string(s): s for s in suggestions
|
||||||
|
}
|
||||||
|
|
||||||
|
reranked: list[RecipeSuggestion] = []
|
||||||
|
for rr in results:
|
||||||
|
suggestion = candidate_map.get(rr.candidate)
|
||||||
|
if suggestion is not None:
|
||||||
|
reranked.append(suggestion.model_copy(
|
||||||
|
update={"rerank_score": round(float(rr.score), 4)}
|
||||||
|
))
|
||||||
|
|
||||||
|
if len(reranked) < len(suggestions):
|
||||||
|
log.warning(
|
||||||
|
"Reranker lost %d/%d suggestions during mapping, falling back",
|
||||||
|
len(suggestions) - len(reranked),
|
||||||
|
len(suggestions),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return reranked
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
log.exception("Reranker failed, falling back to overlap sort")
|
||||||
|
return None
|
||||||
|
|
@ -170,3 +170,37 @@ def test_within_time_over_limit_fails():
|
||||||
from app.services.recipe.recipe_engine import _within_time
|
from app.services.recipe.recipe_engine import _within_time
|
||||||
steps = ["brown onions for 15 minutes", "simmer for 30 minutes"]
|
steps = ["brown onions for 15 minutes", "simmer for 30 minutes"]
|
||||||
assert _within_time(steps, max_total_min=30) is False
|
assert _within_time(steps, max_total_min=30) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Reranker tier-gating tests ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_paid_tier_suggest_populates_rerank_score(store_with_recipes, monkeypatch):
|
||||||
|
"""Paid tier: at least one suggestion should have rerank_score populated."""
|
||||||
|
monkeypatch.setenv("CF_RERANKER_MOCK", "1")
|
||||||
|
try:
|
||||||
|
from circuitforge_core.reranker import reset_reranker
|
||||||
|
reset_reranker()
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("cf-core reranker not installed")
|
||||||
|
|
||||||
|
from app.services.recipe.recipe_engine import RecipeEngine
|
||||||
|
from app.models.schemas.recipe import RecipeRequest
|
||||||
|
engine = RecipeEngine(store_with_recipes)
|
||||||
|
req = RecipeRequest(pantry_items=["butter", "parmesan", "pasta"], level=1, tier="paid")
|
||||||
|
result = engine.suggest(req)
|
||||||
|
# Need at least _MIN_CANDIDATES for reranker to fire
|
||||||
|
from app.services.recipe.reranker import _MIN_CANDIDATES
|
||||||
|
if len(result.suggestions) >= _MIN_CANDIDATES:
|
||||||
|
assert any(s.rerank_score is not None for s in result.suggestions)
|
||||||
|
|
||||||
|
reset_reranker()
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_tier_suggest_has_no_rerank_score(store_with_recipes):
|
||||||
|
"""Free tier: rerank_score must be None on all suggestions."""
|
||||||
|
from app.services.recipe.recipe_engine import RecipeEngine
|
||||||
|
from app.models.schemas.recipe import RecipeRequest
|
||||||
|
engine = RecipeEngine(store_with_recipes)
|
||||||
|
req = RecipeRequest(pantry_items=["butter", "parmesan"], level=1, tier="free")
|
||||||
|
result = engine.suggest(req)
|
||||||
|
assert all(s.rerank_score is None for s in result.suggestions)
|
||||||
|
|
|
||||||
287
tests/services/recipe/test_reranker.py
Normal file
287
tests/services/recipe/test_reranker.py
Normal file
|
|
@ -0,0 +1,287 @@
|
||||||
|
"""
|
||||||
|
Tests for app.services.recipe.reranker.
|
||||||
|
|
||||||
|
All tests use CF_RERANKER_MOCK=1 -- no model weights required.
|
||||||
|
The mock reranker scores by Jaccard similarity of query tokens vs candidate
|
||||||
|
tokens, which is deterministic and fast.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_reranker(monkeypatch):
|
||||||
|
"""Force mock backend and reset the cf-core singleton before/after each test."""
|
||||||
|
monkeypatch.setenv("CF_RERANKER_MOCK", "1")
|
||||||
|
try:
|
||||||
|
from circuitforge_core.reranker import reset_reranker
|
||||||
|
reset_reranker()
|
||||||
|
yield
|
||||||
|
reset_reranker()
|
||||||
|
except ImportError:
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request(**kwargs):
|
||||||
|
from app.models.schemas.recipe import RecipeRequest
|
||||||
|
defaults = dict(pantry_items=["chicken", "rice"], tier="paid")
|
||||||
|
defaults.update(kwargs)
|
||||||
|
return RecipeRequest(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_suggestion(id: int, title: str, matched: list[str], missing: list[str] | None = None, match_count: int | None = None):
|
||||||
|
from app.models.schemas.recipe import RecipeSuggestion
|
||||||
|
mi = missing or []
|
||||||
|
return RecipeSuggestion(
|
||||||
|
id=id,
|
||||||
|
title=title,
|
||||||
|
match_count=match_count if match_count is not None else len(matched),
|
||||||
|
matched_ingredients=matched,
|
||||||
|
missing_ingredients=mi,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBuildQuery ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBuildQuery:
|
||||||
|
def test_basic_pantry(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(pantry_items=["chicken", "rice", "broccoli"])
|
||||||
|
query = build_query(req)
|
||||||
|
assert "chicken" in query
|
||||||
|
assert "rice" in query
|
||||||
|
assert "broccoli" in query
|
||||||
|
|
||||||
|
def test_exclude_ingredients_included(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(exclude_ingredients=["cilantro", "fish sauce"])
|
||||||
|
query = build_query(req)
|
||||||
|
assert "cilantro" in query
|
||||||
|
assert "fish sauce" in query
|
||||||
|
|
||||||
|
def test_allergies_separate_from_exclude(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(allergies=["shellfish"], exclude_ingredients=["cilantro"])
|
||||||
|
query = build_query(req)
|
||||||
|
# Both should appear, and they should be in separate labeled segments
|
||||||
|
assert "shellfish" in query
|
||||||
|
assert "cilantro" in query
|
||||||
|
allergy_pos = query.index("shellfish")
|
||||||
|
exclude_pos = query.index("cilantro")
|
||||||
|
assert allergy_pos != exclude_pos
|
||||||
|
|
||||||
|
def test_allergies_labeled_separately(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(allergies=["peanuts"], exclude_ingredients=[])
|
||||||
|
query = build_query(req)
|
||||||
|
assert "Allergies" in query or "allerg" in query.lower()
|
||||||
|
|
||||||
|
def test_constraints_included(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(constraints=["gluten-free", "dairy-free"])
|
||||||
|
query = build_query(req)
|
||||||
|
assert "gluten-free" in query
|
||||||
|
assert "dairy-free" in query
|
||||||
|
|
||||||
|
def test_category_included(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(category="Soup")
|
||||||
|
query = build_query(req)
|
||||||
|
assert "Soup" in query
|
||||||
|
|
||||||
|
def test_complexity_filter_included(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(complexity_filter="easy")
|
||||||
|
query = build_query(req)
|
||||||
|
assert "easy" in query
|
||||||
|
|
||||||
|
def test_hard_day_mode_signal(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(hard_day_mode=True)
|
||||||
|
query = build_query(req)
|
||||||
|
assert "easy" in query.lower() or "minimal" in query.lower() or "effort" in query.lower()
|
||||||
|
|
||||||
|
def test_secondary_pantry_items_expiry(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(secondary_pantry_items={"bread": "stale", "banana": "overripe"})
|
||||||
|
query = build_query(req)
|
||||||
|
assert "bread" in query
|
||||||
|
assert "banana" in query
|
||||||
|
# State labels add specificity for the cross-encoder
|
||||||
|
assert "stale" in query or "overripe" in query
|
||||||
|
|
||||||
|
def test_expiry_first_without_secondary(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(expiry_first=True, secondary_pantry_items={})
|
||||||
|
query = build_query(req)
|
||||||
|
assert "expir" in query.lower()
|
||||||
|
|
||||||
|
def test_style_id_included(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(style_id="mediterranean")
|
||||||
|
query = build_query(req)
|
||||||
|
assert "mediterranean" in query.lower()
|
||||||
|
|
||||||
|
def test_empty_pantry_returns_fallback(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(pantry_items=[])
|
||||||
|
query = build_query(req)
|
||||||
|
assert len(query) > 0 # never empty string
|
||||||
|
|
||||||
|
def test_no_duplicate_separators(self):
|
||||||
|
from app.services.recipe.reranker import build_query
|
||||||
|
req = _make_request(
|
||||||
|
pantry_items=["egg"],
|
||||||
|
allergies=["nuts"],
|
||||||
|
constraints=["vegan"],
|
||||||
|
complexity_filter="easy",
|
||||||
|
)
|
||||||
|
query = build_query(req)
|
||||||
|
assert ".." not in query # no doubled periods from empty segments
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBuildCandidateString ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBuildCandidateString:
|
||||||
|
def test_title_and_ingredients(self):
|
||||||
|
from app.services.recipe.reranker import build_candidate_string
|
||||||
|
s = _make_suggestion(1, "Chicken Fried Rice", ["chicken", "rice"], ["soy sauce"])
|
||||||
|
candidate = build_candidate_string(s)
|
||||||
|
assert candidate.startswith("Chicken Fried Rice")
|
||||||
|
assert "chicken" in candidate
|
||||||
|
assert "rice" in candidate
|
||||||
|
assert "soy sauce" in candidate
|
||||||
|
|
||||||
|
def test_title_only_when_no_ingredients(self):
|
||||||
|
from app.services.recipe.reranker import build_candidate_string
|
||||||
|
s = _make_suggestion(2, "Mystery Dish", [], [])
|
||||||
|
candidate = build_candidate_string(s)
|
||||||
|
assert candidate == "Mystery Dish"
|
||||||
|
assert "Ingredients:" not in candidate
|
||||||
|
|
||||||
|
def test_matched_before_missing(self):
|
||||||
|
from app.services.recipe.reranker import build_candidate_string
|
||||||
|
s = _make_suggestion(3, "Pasta Dish", ["pasta", "butter"], ["parmesan", "cream"])
|
||||||
|
candidate = build_candidate_string(s)
|
||||||
|
pasta_pos = candidate.index("pasta")
|
||||||
|
parmesan_pos = candidate.index("parmesan")
|
||||||
|
assert pasta_pos < parmesan_pos
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBuildRerankerInput ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBuildRerankerInput:
|
||||||
|
def test_parallel_ids_and_candidates(self):
|
||||||
|
from app.services.recipe.reranker import build_reranker_input
|
||||||
|
req = _make_request()
|
||||||
|
suggestions = [
|
||||||
|
_make_suggestion(10, "Recipe A", ["chicken"]),
|
||||||
|
_make_suggestion(20, "Recipe B", ["rice"]),
|
||||||
|
_make_suggestion(30, "Recipe C", ["broccoli"]),
|
||||||
|
]
|
||||||
|
rinput = build_reranker_input(req, suggestions)
|
||||||
|
assert len(rinput.candidates) == 3
|
||||||
|
assert len(rinput.suggestion_ids) == 3
|
||||||
|
assert rinput.suggestion_ids == [10, 20, 30]
|
||||||
|
|
||||||
|
def test_query_matches_build_query(self):
|
||||||
|
from app.services.recipe.reranker import build_reranker_input, build_query
|
||||||
|
req = _make_request(pantry_items=["egg", "cheese"], constraints=["vegetarian"])
|
||||||
|
suggestions = [_make_suggestion(1, "Omelette", ["egg", "cheese"])]
|
||||||
|
rinput = build_reranker_input(req, suggestions)
|
||||||
|
assert rinput.query == build_query(req)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRerankSuggestions ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRerankSuggestions:
|
||||||
|
def test_free_tier_returns_none(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="free")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(5)]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_paid_tier_returns_reranked_list(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="paid", pantry_items=["chicken", "rice"])
|
||||||
|
suggestions = [
|
||||||
|
_make_suggestion(1, "Chicken Fried Rice", ["chicken", "rice"]),
|
||||||
|
_make_suggestion(2, "Chocolate Cake", ["flour", "sugar", "cocoa"]),
|
||||||
|
_make_suggestion(3, "Chicken Soup", ["chicken", "broth"]),
|
||||||
|
]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == len(suggestions)
|
||||||
|
|
||||||
|
def test_rerank_score_is_populated(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="paid")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is not None
|
||||||
|
assert all(s.rerank_score is not None for s in result)
|
||||||
|
assert all(isinstance(s.rerank_score, float) for s in result)
|
||||||
|
|
||||||
|
def test_too_few_candidates_returns_none(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions, _MIN_CANDIDATES
|
||||||
|
req = _make_request(tier="paid")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(_MIN_CANDIDATES - 1)]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_premium_tier_gets_reranker(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="premium")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_local_tier_gets_reranker(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="local")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_preserves_all_suggestion_fields(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="paid")
|
||||||
|
original = _make_suggestion(
|
||||||
|
id=42,
|
||||||
|
title="Garlic Butter Pasta",
|
||||||
|
matched=["pasta", "butter", "garlic"],
|
||||||
|
missing=["parmesan"],
|
||||||
|
match_count=3,
|
||||||
|
)
|
||||||
|
result = rerank_suggestions(req, [original, original, original, original])
|
||||||
|
assert result is not None
|
||||||
|
found = next((s for s in result if s.id == 42), None)
|
||||||
|
assert found is not None
|
||||||
|
assert found.title == "Garlic Butter Pasta"
|
||||||
|
assert found.matched_ingredients == ["pasta", "butter", "garlic"]
|
||||||
|
assert found.missing_ingredients == ["parmesan"]
|
||||||
|
assert found.match_count == 3
|
||||||
|
|
||||||
|
def test_graceful_fallback_on_exception(self, monkeypatch):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
# Simulate reranker raising at runtime
|
||||||
|
import app.services.recipe.reranker as reranker_mod
|
||||||
|
def _boom(query, candidates, top_n=0):
|
||||||
|
raise RuntimeError("model exploded")
|
||||||
|
monkeypatch.setattr(reranker_mod, "_do_rerank", _boom)
|
||||||
|
req = _make_request(tier="paid")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||||
|
result = rerank_suggestions(req, suggestions)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_original_suggestions_not_mutated(self):
|
||||||
|
from app.services.recipe.reranker import rerank_suggestions
|
||||||
|
req = _make_request(tier="paid")
|
||||||
|
suggestions = [_make_suggestion(i, f"Recipe {i}", ["chicken"]) for i in range(4)]
|
||||||
|
originals = [s.model_copy() for s in suggestions]
|
||||||
|
rerank_suggestions(req, suggestions)
|
||||||
|
for original, after in zip(originals, suggestions):
|
||||||
|
assert original.rerank_score == after.rerank_score # None == None (no mutation)
|
||||||
Loading…
Reference in a new issue