fix: ElementClassifier — guard empty input, safe JSON decode, dedup heuristic elements, strengthen test assertions
This commit is contained in:
parent
727336fab9
commit
a03807951b
2 changed files with 21 additions and 5 deletions
|
|
@ -40,6 +40,17 @@ _HEURISTIC: list[tuple[list[str], str]] = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_json_list(val) -> list:
|
||||||
|
if isinstance(val, list):
|
||||||
|
return val
|
||||||
|
if isinstance(val, str):
|
||||||
|
try:
|
||||||
|
return json.loads(val)
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class IngredientProfile:
|
class IngredientProfile:
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -71,6 +82,8 @@ class ElementClassifier:
|
||||||
def classify(self, ingredient_name: str) -> IngredientProfile:
|
def classify(self, ingredient_name: str) -> IngredientProfile:
|
||||||
"""Return element profile for a single ingredient name."""
|
"""Return element profile for a single ingredient name."""
|
||||||
name = ingredient_name.lower().strip()
|
name = ingredient_name.lower().strip()
|
||||||
|
if not name:
|
||||||
|
return IngredientProfile(name="", elements=[], source="heuristic")
|
||||||
row = self._store._fetch_one(
|
row = self._store._fetch_one(
|
||||||
"SELECT * FROM ingredient_profiles WHERE name = ?", (name,)
|
"SELECT * FROM ingredient_profiles WHERE name = ?", (name,)
|
||||||
)
|
)
|
||||||
|
|
@ -91,7 +104,7 @@ class ElementClassifier:
|
||||||
def _row_to_profile(self, row: dict) -> IngredientProfile:
|
def _row_to_profile(self, row: dict) -> IngredientProfile:
|
||||||
return IngredientProfile(
|
return IngredientProfile(
|
||||||
name=row["name"],
|
name=row["name"],
|
||||||
elements=json.loads(row.get("elements") or "[]"),
|
elements=_safe_json_list(row.get("elements")),
|
||||||
fat_pct=row.get("fat_pct") or 0.0,
|
fat_pct=row.get("fat_pct") or 0.0,
|
||||||
fat_saturated_pct=row.get("fat_saturated_pct") or 0.0,
|
fat_saturated_pct=row.get("fat_saturated_pct") or 0.0,
|
||||||
moisture_pct=row.get("moisture_pct") or 0.0,
|
moisture_pct=row.get("moisture_pct") or 0.0,
|
||||||
|
|
@ -100,7 +113,7 @@ class ElementClassifier:
|
||||||
binding_score=row.get("binding_score") or 0,
|
binding_score=row.get("binding_score") or 0,
|
||||||
glutamate_mg=row.get("glutamate_mg") or 0.0,
|
glutamate_mg=row.get("glutamate_mg") or 0.0,
|
||||||
ph_estimate=row.get("ph_estimate"),
|
ph_estimate=row.get("ph_estimate"),
|
||||||
flavor_molecule_ids=json.loads(row.get("flavor_molecule_ids") or "[]"),
|
flavor_molecule_ids=_safe_json_list(row.get("flavor_molecule_ids")),
|
||||||
heat_stable=bool(row.get("heat_stable", 1)),
|
heat_stable=bool(row.get("heat_stable", 1)),
|
||||||
add_timing=row.get("add_timing") or "any",
|
add_timing=row.get("add_timing") or "any",
|
||||||
acid_type=row.get("acid_type"),
|
acid_type=row.get("acid_type"),
|
||||||
|
|
@ -113,8 +126,10 @@ class ElementClassifier:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _heuristic_profile(self, name: str) -> IngredientProfile:
    """Build a keyword-based profile for a name with no stored row.

    Walks the ``_HEURISTIC`` table of ``(keywords, element)`` pairs and
    keeps every element for which at least one keyword occurs as a
    substring of ``name``. Each element appears once, in the order it
    first matches.
    """
    # dict.fromkeys keeps first-seen order while discarding repeats.
    matched = dict.fromkeys(
        element
        for keywords, element in _HEURISTIC
        if any(kw in name for kw in keywords)
    )
    return IngredientProfile(name=name, elements=list(matched), source="heuristic")
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ def test_classify_known_ingredient(store_with_profiles):
|
||||||
assert "Richness" in profile.elements
|
assert "Richness" in profile.elements
|
||||||
assert profile.fat_pct == pytest.approx(81.0)
|
assert profile.fat_pct == pytest.approx(81.0)
|
||||||
assert profile.name == "butter"
|
assert profile.name == "butter"
|
||||||
|
assert profile.source == "db"
|
||||||
|
|
||||||
|
|
||||||
def test_classify_unknown_ingredient_uses_heuristic(store_with_profiles):
|
def test_classify_unknown_ingredient_uses_heuristic(store_with_profiles):
|
||||||
|
|
@ -42,7 +43,7 @@ def test_classify_unknown_ingredient_uses_heuristic(store_with_profiles):
|
||||||
clf = ElementClassifier(store_with_profiles)
|
clf = ElementClassifier(store_with_profiles)
|
||||||
profile = clf.classify("ghost pepper hot sauce")
|
profile = clf.classify("ghost pepper hot sauce")
|
||||||
# Heuristic should detect acid / aroma
|
# Heuristic should detect acid / aroma
|
||||||
assert len(profile.elements) > 0
|
assert "Aroma" in profile.elements # "pepper" in name matches Aroma heuristic
|
||||||
assert profile.name == "ghost pepper hot sauce"
|
assert profile.name == "ghost pepper hot sauce"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue