diff --git a/.env.example b/.env.example index f0f6415..b9a7ade 100644 --- a/.env.example +++ b/.env.example @@ -11,6 +11,14 @@ DATA_DIR=./data # Database (defaults to DATA_DIR/kiwi.db) # DB_PATH=./data/kiwi.db +# Pipeline data directory for downloaded parquets (used by download_datasets.py) +# Override to store large datasets on a separate drive or NAS +# KIWI_PIPELINE_DATA_DIR=./data/pipeline + +# CF-core resource coordinator (VRAM lease management) +# Set to the coordinator URL when running alongside cf-core orchestration +# COORDINATOR_URL=http://localhost:7700 + # Processing USE_GPU=true GPU_MEMORY_LIMIT=6144 @@ -28,6 +36,14 @@ DEMO_MODE=false # Cloud mode (set in compose.cloud.yml; also set here for reference) # CLOUD_DATA_ROOT=/devl/kiwi-cloud-data # KIWI_DB=data/kiwi.db # local-mode DB path override +# DEV ONLY: bypass JWT auth for these IPs/CIDRs (LAN testing without Caddy in the path). +# NEVER set in production. +# IMPORTANT: Docker port mapping NATs source IPs to the bridge gateway. When hitting +# localhost:8515 (host → Docker → nginx → API), nginx sees 192.168.80.1, not 127.0.0.1. +# Include the Docker bridge CIDR to allow localhost and LAN access through nginx. 
+# Run: docker network inspect kiwi-cloud_kiwi-cloud-net | grep Subnet +# Example: CLOUD_AUTH_BYPASS_IPS=10.1.10.0/24,127.0.0.1,::1,192.168.80.0/20 +# CLOUD_AUTH_BYPASS_IPS= # Heimdall license server (required for cloud tier resolution) # HEIMDALL_URL=https://license.circuitforge.tech diff --git a/app/api/endpoints/recipes.py b/app/api/endpoints/recipes.py index fedeabd..12db086 100644 --- a/app/api/endpoints/recipes.py +++ b/app/api/endpoints/recipes.py @@ -2,11 +2,11 @@ from __future__ import annotations import asyncio +from pathlib import Path from fastapi import APIRouter, Depends, HTTPException from app.cloud_session import CloudUser, get_session -from app.db.session import get_store from app.db.store import Store from app.models.schemas.recipe import RecipeRequest, RecipeResult from app.services.recipe.recipe_engine import RecipeEngine @@ -15,11 +15,25 @@ from app.tiers import can_use router = APIRouter() +def _suggest_in_thread(db_path: Path, req: RecipeRequest) -> RecipeResult: + """Run recipe suggestion in a worker thread with its own Store connection. + + SQLite connections cannot be shared across threads. This function creates + a fresh Store (and therefore a fresh sqlite3.Connection) in the same thread + where it will be used, avoiding ProgrammingError: SQLite objects created in + a thread can only be used in that same thread. + """ + store = Store(db_path) + try: + return RecipeEngine(store).suggest(req) + finally: + store.close() + + @router.post("/suggest", response_model=RecipeResult) async def suggest_recipes( req: RecipeRequest, session: CloudUser = Depends(get_session), - store: Store = Depends(get_store), ) -> RecipeResult: # Inject session-authoritative tier/byok immediately — client-supplied values are ignored. 
req = req.model_copy(update={"tier": session.tier, "has_byok": session.has_byok}) @@ -35,13 +49,19 @@ async def suggest_recipes( ) if req.style_id and not can_use("style_picker", req.tier): raise HTTPException(status_code=403, detail="Style picker requires Paid tier.") - engine = RecipeEngine(store) - return await asyncio.to_thread(engine.suggest, req) + return await asyncio.to_thread(_suggest_in_thread, session.db, req) @router.get("/{recipe_id}") -async def get_recipe(recipe_id: int, store: Store = Depends(get_store)) -> dict: - recipe = await asyncio.to_thread(store.get_recipe, recipe_id) +async def get_recipe(recipe_id: int, session: CloudUser = Depends(get_session)) -> dict: + def _get(db_path: Path, rid: int) -> dict | None: + store = Store(db_path) + try: + return store.get_recipe(rid) + finally: + store.close() + + recipe = await asyncio.to_thread(_get, session.db, recipe_id) if not recipe: raise HTTPException(status_code=404, detail="Recipe not found.") return recipe diff --git a/app/cloud_session.py b/app/cloud_session.py index 4cba4a1..4431da1 100644 --- a/app/cloud_session.py +++ b/app/cloud_session.py @@ -37,6 +37,43 @@ DIRECTUS_JWT_SECRET: str = os.environ.get("DIRECTUS_JWT_SECRET", "") HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech") HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "") +# Dev bypass: comma-separated IPs or CIDR ranges that skip JWT auth. +# NEVER set this in production. Intended only for LAN developer testing when +# the request doesn't pass through Caddy (which normally injects X-CF-Session). 
+# Example: CLOUD_AUTH_BYPASS_IPS=10.1.10.0/24,127.0.0.1 +import ipaddress as _ipaddress + +_BYPASS_RAW: list[str] = [ + e.strip() + for e in os.environ.get("CLOUD_AUTH_BYPASS_IPS", "").split(",") + if e.strip() +] + +_BYPASS_NETS: list[_ipaddress.IPv4Network | _ipaddress.IPv6Network] = [] +_BYPASS_IPS: frozenset[str] = frozenset() + +if _BYPASS_RAW: + _nets, _ips = [], set() + for entry in _BYPASS_RAW: + try: + _nets.append(_ipaddress.ip_network(entry, strict=False)) + except ValueError: + _ips.add(entry) # treat non-parseable entries as bare IPs + _BYPASS_NETS = _nets + _BYPASS_IPS = frozenset(_ips) + + +def _is_bypass_ip(ip: str) -> bool: + if not ip: + return False + if ip in _BYPASS_IPS: + return True + try: + addr = _ipaddress.ip_address(ip) + return any(addr in net for net in _BYPASS_NETS) + except ValueError: + return False + _LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db")) _TIER_CACHE: dict[str, tuple[str, float]] = {} @@ -153,12 +190,26 @@ def get_session(request: Request) -> CloudUser: Local mode: fully-privileged "local" user pointing at local DB. Cloud mode: validates X-CF-Session JWT, provisions license, resolves tier. + Dev bypass: if CLOUD_AUTH_BYPASS_IPS is set and the client IP matches, + returns a "local" session without JWT validation (dev/LAN use only). """ has_byok = _detect_byok() if not CLOUD_MODE: return CloudUser(user_id="local", tier="local", db=_LOCAL_KIWI_DB, has_byok=has_byok) + # Prefer X-Real-IP (set by nginx from the actual client address) over the + # TCP peer address (which is nginx's container IP when behind the proxy). 
+ client_ip = ( + request.headers.get("x-real-ip", "") + or (request.client.host if request.client else "") + ) + if (_BYPASS_IPS or _BYPASS_NETS) and _is_bypass_ip(client_ip): + log.debug("CLOUD_AUTH_BYPASS_IPS match for %s — returning local session", client_ip) + # Use a dev DB under CLOUD_DATA_ROOT so the container has a writable path. + dev_db = _user_db_path("local-dev") + return CloudUser(user_id="local-dev", tier="local", db=dev_db, has_byok=has_byok) + raw_header = ( request.headers.get("x-cf-session", "") or request.headers.get("cookie", "") diff --git a/app/core/config.py b/app/core/config.py index 1f8015d..0b06934 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -43,6 +43,9 @@ class Settings: # Quality MIN_QUALITY_SCORE: float = float(os.environ.get("MIN_QUALITY_SCORE", "50.0")) + # CF-core resource coordinator (VRAM lease management) + COORDINATOR_URL: str = os.environ.get("COORDINATOR_URL", "http://localhost:7700") + # Feature flags ENABLE_OCR: bool = os.environ.get("ENABLE_OCR", "false").lower() in ("1", "true", "yes") diff --git a/app/services/barcode_scanner.py b/app/services/barcode_scanner.py index 889e807..f5f667b 100644 --- a/app/services/barcode_scanner.py +++ b/app/services/barcode_scanner.py @@ -5,6 +5,8 @@ This module provides functionality to detect and decode barcodes from images (UPC, EAN, QR codes, etc.). """ +import io + import cv2 import numpy as np from pyzbar import pyzbar @@ -12,6 +14,12 @@ from pathlib import Path from typing import List, Dict, Any, Optional import logging +try: + from PIL import Image as _PILImage + _HAS_PIL = True +except ImportError: + _HAS_PIL = False + logger = logging.getLogger(__name__) @@ -76,9 +84,7 @@ class BarcodeScanner: # 4. 
Try rotations if still no barcodes found (handles tilted/rotated barcodes) if not barcodes: logger.info("No barcodes found in standard orientation, trying rotations...") - # Try incremental angles: 30°, 60°, 90° (covers 0-90° range) - # 0° already tried, 180° is functionally same as 0°, 90°/270° are same axis - for angle in [30, 60, 90]: + for angle in [90, 180, 270, 45, 135]: rotated_gray = self._rotate_image(gray, angle) rotated_color = self._rotate_image(image, angle) detected = self._detect_barcodes(rotated_gray, rotated_color) @@ -264,6 +270,26 @@ class BarcodeScanner: return list(seen.values()) + def _fix_exif_orientation(self, image_bytes: bytes) -> bytes: + """Apply EXIF orientation correction so cv2 sees an upright image. + + Phone cameras embed rotation in EXIF; cv2.imdecode ignores it, + so a photo taken in portrait may arrive physically sideways in memory. + """ + if not _HAS_PIL: + return image_bytes + try: + pil = _PILImage.open(io.BytesIO(image_bytes)) + # Use ImageOps.exif_transpose for proper EXIF-aware rotation + import PIL.ImageOps + pil = PIL.ImageOps.exif_transpose(pil) + pil = pil.convert("RGB") # JPEG cannot encode RGBA/palette modes + buf = io.BytesIO() + pil.save(buf, format="JPEG") + return buf.getvalue() + except Exception: + return image_bytes + def scan_from_bytes(self, image_bytes: bytes) -> List[Dict[str, Any]]: """ Scan barcodes from image bytes (uploaded file). @@ -275,6 +301,10 @@ class BarcodeScanner: List of detected barcodes """ try: + # Apply EXIF orientation correction first (phone cameras embed rotation in EXIF; + # cv2.imdecode ignores it, causing sideways barcodes to appear rotated in memory). + image_bytes = self._fix_exif_orientation(image_bytes) + # Convert bytes to numpy array nparr = np.frombuffer(image_bytes, np.uint8) image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) @@ -300,11 +330,12 @@ class BarcodeScanner: ) barcodes.extend(self._detect_barcodes(thresh, image)) - # 3. 
Try rotations if still no barcodes found + # 3. Try all 90° rotations + common tilt angles + # 90/270 catches truly sideways barcodes; 180 catches upside-down; + # 45/135 catches tilted barcodes on flat surfaces. if not barcodes: logger.info("No barcodes found in uploaded image, trying rotations...") - # Try incremental angles: 30°, 60°, 90° (covers 0-90° range) - for angle in [30, 60, 90]: + for angle in [90, 180, 270, 45, 135]: rotated_gray = self._rotate_image(gray, angle) rotated_color = self._rotate_image(image, angle) detected = self._detect_barcodes(rotated_gray, rotated_color) diff --git a/app/services/recipe/llm_recipe.py b/app/services/recipe/llm_recipe.py index 9081a8d..a01ca72 100644 --- a/app/services/recipe/llm_recipe.py +++ b/app/services/recipe/llm_recipe.py @@ -113,8 +113,66 @@ class LLMRecipeGenerator: return "\n".join(lines) + def _acquire_vram_lease(self) -> str | None: + """Request a VRAM lease from the CF-core coordinator. Best-effort — returns None if unavailable.""" + try: + import httpx + from app.core.config import settings + from app.tasks.runner import VRAM_BUDGETS + + budget_mb = int(VRAM_BUDGETS.get("recipe_llm", 4.0) * 1024) + coordinator = settings.COORDINATOR_URL + + nodes_resp = httpx.get(f"{coordinator}/api/nodes", timeout=2.0) + if nodes_resp.status_code != 200: + return None + nodes = nodes_resp.json().get("nodes", []) + if not nodes: + return None + + best_node = best_gpu = best_free = None + for node in nodes: + for gpu in node.get("gpus", []): + free = gpu.get("vram_free_mb", 0) + if best_free is None or free > best_free: + best_node = node["node_id"] + best_gpu = gpu["gpu_id"] + best_free = free + if best_node is None: + return None + + lease_resp = httpx.post( + f"{coordinator}/api/leases", + json={ + "node_id": best_node, + "gpu_id": best_gpu, + "mb": budget_mb, + "service": "kiwi", + "priority": 5, + }, + timeout=3.0, + ) + if lease_resp.status_code == 200: + lease_id = lease_resp.json()["lease"]["lease_id"] + 
logger.debug("Acquired VRAM lease %s for recipe_llm (%d MB)", lease_id, budget_mb) + return lease_id + except Exception as exc: + logger.debug("VRAM lease acquire failed (non-fatal): %s", exc) + return None + + def _release_vram_lease(self, lease_id: str) -> None: + """Release a VRAM lease. Best-effort.""" + try: + import httpx + from app.core.config import settings + httpx.delete(f"{settings.COORDINATOR_URL}/api/leases/{lease_id}", timeout=3.0) + logger.debug("Released VRAM lease %s", lease_id) + except Exception as exc: + logger.debug("VRAM lease release failed (non-fatal): %s", exc) + def _call_llm(self, prompt: str) -> str: - """Call the LLM router and return the response text.""" + """Call the LLM router with a VRAM lease held for the duration.""" + lease_id = self._acquire_vram_lease() try: from circuitforge_core.llm.router import LLMRouter router = LLMRouter() @@ -122,6 +180,9 @@ class LLMRecipeGenerator: except Exception as exc: logger.error("LLM call failed: %s", exc) return "" + finally: + if lease_id: + self._release_vram_lease(lease_id) def _parse_response(self, response: str) -> dict[str, str | list[str]]: """Parse LLM response text into structured recipe fields.""" diff --git a/app/tasks/runner.py b/app/tasks/runner.py index 99da8ee..f9315c3 100644 --- a/app/tasks/runner.py +++ b/app/tasks/runner.py @@ -27,6 +27,9 @@ LLM_TASK_TYPES: frozenset[str] = frozenset({"expiry_llm_fallback"}) VRAM_BUDGETS: dict[str, float] = { # ExpirationPredictor uses a small LLM (16 tokens out, single pass). "expiry_llm_fallback": 2.0, + # Recipe LLM (levels 3-4): full recipe generation, ~200-500 tokens out. + # Budget assumes a quantized 7B-class model. 
+ "recipe_llm": 4.0, } diff --git a/compose.cloud.yml b/compose.cloud.yml index 02c0efa..7c4fcfd 100644 --- a/compose.cloud.yml +++ b/compose.cloud.yml @@ -14,6 +14,9 @@ services: CLOUD_MODE: "true" CLOUD_DATA_ROOT: /devl/kiwi-cloud-data # DIRECTUS_JWT_SECRET, HEIMDALL_URL, HEIMDALL_ADMIN_TOKEN — set in .env + # DEV ONLY: comma-separated IPs that bypass JWT auth (LAN testing without Caddy). + # Production deployments must NOT set this. Leave blank or omit entirely. + CLOUD_AUTH_BYPASS_IPS: ${CLOUD_AUTH_BYPASS_IPS:-} volumes: - /devl/kiwi-cloud-data:/devl/kiwi-cloud-data # LLM config — shared with other CF products; read-only in container diff --git a/docker/web/nginx.cloud.conf b/docker/web/nginx.cloud.conf index ea8d37a..dcd5bc9 100644 --- a/docker/web/nginx.cloud.conf +++ b/docker/web/nginx.cloud.conf @@ -14,6 +14,8 @@ server { proxy_set_header X-Forwarded-Proto $http_x_forwarded_proto; # Forward the session header injected by Caddy from cf_session cookie. proxy_set_header X-CF-Session $http_x_cf_session; + # Allow image uploads (barcode/receipt photos from phone cameras). + client_max_body_size 20m; } location = /index.html { diff --git a/docker/web/nginx.conf b/docker/web/nginx.conf index a987d0f..e341ee1 100644 --- a/docker/web/nginx.conf +++ b/docker/web/nginx.conf @@ -9,6 +9,8 @@ server { proxy_pass http://172.17.0.1:8512; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; + # Allow image uploads (barcode/receipt photos from phone cameras). 
+ client_max_body_size 20m; } location = /index.html { diff --git a/pyproject.toml b/pyproject.toml index 29d3b06..f0386ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "opencv-python>=4.8", "numpy>=1.25", "pyzbar>=0.1.9", + "Pillow>=10.0", # HTTP client "httpx>=0.27", # CircuitForge shared scaffold diff --git a/scripts/pipeline/derive_substitutions.py b/scripts/pipeline/derive_substitutions.py index 72f0277..889bbf4 100644 --- a/scripts/pipeline/derive_substitutions.py +++ b/scripts/pipeline/derive_substitutions.py @@ -3,24 +3,21 @@ Derive substitution pairs by diffing lishuyang/recipepairs. GPL-3.0 source -- derived annotations only, raw pairs not shipped. Usage: - conda run -n job-seeker python scripts/pipeline/derive_substitutions.py \ + PYTHONPATH=/path/to/kiwi conda run -n cf python scripts/pipeline/derive_substitutions.py \ --db /path/to/kiwi.db \ - --recipepairs data/recipepairs.parquet + --recipepairs data/pipeline/recipepairs.parquet \ + --recipepairs-recipes data/pipeline/recipepairs_recipes.parquet """ from __future__ import annotations import argparse import json +import re import sqlite3 from collections import defaultdict from pathlib import Path import pandas as pd -from scripts.pipeline.build_recipe_index import extract_ingredient_names - -CONSTRAINT_COLS = ["vegan", "vegetarian", "dairy_free", "low_calorie", - "low_carb", "low_fat", "low_sodium", "gluten_free"] - def diff_ingredients(base: list[str], target: list[str]) -> tuple[list[str], list[str]]: base_set = set(base) @@ -30,21 +27,44 @@ def diff_ingredients(base: list[str], target: list[str]) -> tuple[list[str], lis return removed, added -def build(db_path: Path, recipepairs_path: Path) -> None: +def _parse_categories(val: object) -> list[str]: + """Parse categories field which may be a list, str-repr list, or bare string.""" + if isinstance(val, list): + return [str(v) for v in val] + if isinstance(val, str): + val = val.strip() + if val.startswith("["): 
+ # parse list repr: ['a', 'b'] — use json after converting single quotes + try: + fixed = re.sub(r"'", '"', val) + return json.loads(fixed) + except Exception: + pass + return [val] if val else [] + return [] + + +def build(db_path: Path, recipepairs_path: Path, recipes_path: Path) -> None: conn = sqlite3.connect(db_path) try: - print("Loading recipe ingredient index...") + # Load ingredient lists from the bundled recipepairs recipe corpus. + # This is GPL-3.0 data — we only use it for diffing; raw data is not persisted. + print("Loading recipe ingredient index from recipepairs corpus...") + recipes_df = pd.read_parquet(recipes_path, columns=["id", "ingredients"]) recipe_ingredients: dict[str, list[str]] = {} - for row in conn.execute("SELECT external_id, ingredient_names FROM recipes"): - recipe_ingredients[str(row[0])] = json.loads(row[1]) + for _, r in recipes_df.iterrows(): + ings = r["ingredients"] + if ings is not None and hasattr(ings, "__iter__") and not isinstance(ings, str): + recipe_ingredients[str(int(r["id"]))] = [str(i) for i in ings] + print(f" {len(recipe_ingredients)} recipes loaded") - df = pd.read_parquet(recipepairs_path) + pairs_df = pd.read_parquet(recipepairs_path) pair_counts: dict[tuple, dict] = defaultdict(lambda: {"count": 0}) print("Diffing recipe pairs...") - for _, row in df.iterrows(): - base_id = str(row.get("base", "")) - target_id = str(row.get("target", "")) + for _, row in pairs_df.iterrows(): + base_id = str(int(row["base"])) + target_id = str(int(row["target"])) base_ings = recipe_ingredients.get(base_id, []) target_ings = recipe_ingredients.get(target_id, []) if not base_ings or not target_ings: @@ -56,7 +76,9 @@ def build(db_path: Path, recipepairs_path: Path) -> None: original = removed[0] substitute = added[0] - constraints = [c for c in CONSTRAINT_COLS if row.get(c, 0)] + constraints = _parse_categories(row.get("categories", [])) + if not constraints: + continue for constraint in constraints: key = (original, substitute, 
constraint) pair_counts[key]["count"] += 1 @@ -102,7 +124,11 @@ def build(db_path: Path, recipepairs_path: Path) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--db", required=True, type=Path) - parser.add_argument("--recipepairs", required=True, type=Path) + parser.add_argument("--db", required=True, type=Path) + parser.add_argument("--recipepairs", required=True, type=Path, + help="pairs.parquet from lishuyang/recipepairs") + parser.add_argument("--recipepairs-recipes", required=True, type=Path, + dest="recipepairs_recipes", + help="recipes.parquet from lishuyang/recipepairs (ingredient lookup)") args = parser.parse_args() - build(args.db, args.recipepairs) + build(args.db, args.recipepairs, args.recipepairs_recipes) diff --git a/scripts/pipeline/download_datasets.py b/scripts/pipeline/download_datasets.py index ab2d733..d680712 100644 --- a/scripts/pipeline/download_datasets.py +++ b/scripts/pipeline/download_datasets.py @@ -12,21 +12,33 @@ Downloads: """ from __future__ import annotations import argparse +import os +import shutil from pathlib import Path + from datasets import load_dataset +from huggingface_hub import hf_hub_download -DATASETS = [ - ("corbt/all-recipes", "train", "recipes_allrecipes.parquet"), - ("omid5/usda-fdc-foods-cleaned", "train", "usda_fdc_cleaned.parquet"), +# Standard HuggingFace datasets: (hf_path, split, output_filename) +HF_DATASETS = [ + ("corbt/all-recipes", "train", "recipes_allrecipes.parquet"), + ("omid5/usda-fdc-foods-cleaned", "train", "usda_fdc_cleaned.parquet"), ("jacktol/usda-branded-food-data","train", "usda_branded.parquet"), - ("lishuyang/recipepairs", "train", "recipepairs.parquet"), +] + +# Datasets that expose raw parquet files directly (no HF dataset builder) +HF_PARQUET_FILES = [ + # (repo_id, repo_filename, output_filename) + # lishuyang/recipepairs: GPL-3.0 ⚠ — derive only, don't ship + ("lishuyang/recipepairs", "pairs.parquet", "recipepairs.parquet"), ] def 
download_all(data_dir: Path) -> None: data_dir.mkdir(parents=True, exist_ok=True) - for hf_path, split, filename in DATASETS: + + for hf_path, split, filename in HF_DATASETS: out = data_dir / filename if out.exists(): print(f" skip (unknown) (already exists)") @@ -36,9 +48,29 @@ def download_all(data_dir: Path) -> None: ds.to_parquet(str(out)) print(f" saved → {out}") + for repo_id, repo_file, filename in HF_PARQUET_FILES: + out = data_dir / filename + if out.exists(): + print(f" skip {filename} (already exists)") + continue + print(f" downloading {repo_id}/{repo_file} ...") + cached = hf_hub_download(repo_id=repo_id, filename=repo_file, repo_type="dataset") + shutil.copy2(cached, out) + print(f" saved → {out}") + + +_DEFAULT_DATA_DIR = Path( + os.environ.get("KIWI_PIPELINE_DATA_DIR", "data/pipeline") +) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--data-dir", required=True, type=Path) + parser.add_argument( + "--data-dir", + type=Path, + default=_DEFAULT_DATA_DIR, + help="Directory for downloaded parquets (default: $KIWI_PIPELINE_DATA_DIR or data/pipeline)", + ) args = parser.parse_args() download_all(args.data_dir) diff --git a/tests/pipeline/test_build_flavorgraph_index.py b/tests/pipeline/test_build_flavorgraph_index.py index febf381..b6949ab 100644 --- a/tests/pipeline/test_build_flavorgraph_index.py +++ b/tests/pipeline/test_build_flavorgraph_index.py @@ -1,18 +1,39 @@ +import csv +import tempfile +from pathlib import Path + + +def _write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None: + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + w.writerows(rows) + + def test_parse_flavorgraph_node(): from scripts.pipeline.build_flavorgraph_index import parse_ingredient_nodes - sample = { - "nodes": [ - {"id": "I_beef", "type": "ingredient", "name": "beef"}, - {"id": "C_pyrazine", "type": "compound", "name": "pyrazine"}, - {"id": "I_mushroom", "type": 
"ingredient", "name": "mushroom"}, - ], - "links": [ - {"source": "I_beef", "target": "C_pyrazine"}, - {"source": "I_mushroom","target": "C_pyrazine"}, - ] - } - result = parse_ingredient_nodes(sample) - assert "beef" in result - assert "C_pyrazine" in result["beef"] - assert "mushroom" in result - assert "C_pyrazine" in result["mushroom"] + + with tempfile.TemporaryDirectory() as tmp: + nodes_path = Path(tmp) / "nodes.csv" + edges_path = Path(tmp) / "edges.csv" + + _write_csv(nodes_path, [ + {"node_id": "1", "name": "beef", "node_type": "ingredient"}, + {"node_id": "2", "name": "pyrazine", "node_type": "compound"}, + {"node_id": "3", "name": "mushroom", "node_type": "ingredient"}, + ], ["node_id", "name", "node_type"]) + + _write_csv(edges_path, [ + {"id_1": "1", "id_2": "2", "score": "0.8"}, + {"id_1": "3", "id_2": "2", "score": "0.7"}, + ], ["id_1", "id_2", "score"]) + + ingredient_to_compounds, compound_names = parse_ingredient_nodes(nodes_path, edges_path) + + assert "beef" in ingredient_to_compounds + assert "mushroom" in ingredient_to_compounds + # compound node_id "2" maps to name "pyrazine" + beef_compounds = ingredient_to_compounds["beef"] + assert any(compound_names.get(c) == "pyrazine" for c in beef_compounds) + mushroom_compounds = ingredient_to_compounds["mushroom"] + assert any(compound_names.get(c) == "pyrazine" for c in mushroom_compounds)