feat: cloud auth bypass, VRAM leasing, barcode EXIF fix, pipeline improvements
- cloud_session.py: CLOUD_AUTH_BYPASS_IPS with CIDR support; X-Real-IP for Docker bridge NAT-aware client IP resolution; local-dev DB path under CLOUD_DATA_ROOT for bypass sessions - compose.cloud.yml: thread CLOUD_AUTH_BYPASS_IPS from shell env; document Docker bridge CIDR requirement in .env.example - nginx.cloud.conf + nginx.conf: client_max_body_size 20m for barcode uploads - barcode_scanner.py: EXIF orientation correction (PIL ImageOps.exif_transpose) before cv2 decode; rotation coverage extended to [90, 180, 270, 45, 135] to catch sideways barcodes the 270° case was missing - llm_recipe.py: CF-core VRAM lease acquire/release wrapping LLMRouter calls - tasks/runner.py + config.py: COORDINATOR_URL + recipe_llm VRAM budget (4GB) - recipes.py: per-request Store creation inside asyncio.to_thread worker to avoid SQLite check_same_thread violations - download_datasets.py: HF_PARQUET_FILES strategy for repos without dataset builders (lishuyang/recipepairs direct parquet download) - derive_substitutions.py: use recipepairs_recipes.parquet for ingredient lookup; numpy array detection; JSON category parsing - test_build_flavorgraph_index.py: rewritten for CSV-based index format - pyproject.toml: add Pillow>=10.0 for EXIF rotation support
This commit is contained in:
parent
77627cec23
commit
33a5cdec37
14 changed files with 328 additions and 54 deletions
16
.env.example
16
.env.example
|
|
@ -11,6 +11,14 @@ DATA_DIR=./data
|
|||
# Database (defaults to DATA_DIR/kiwi.db)
|
||||
# DB_PATH=./data/kiwi.db
|
||||
|
||||
# Pipeline data directory for downloaded parquets (used by download_datasets.py)
|
||||
# Override to store large datasets on a separate drive or NAS
|
||||
# KIWI_PIPELINE_DATA_DIR=./data/pipeline
|
||||
|
||||
# CF-core resource coordinator (VRAM lease management)
|
||||
# Set to the coordinator URL when running alongside cf-core orchestration
|
||||
# COORDINATOR_URL=http://localhost:7700
|
||||
|
||||
# Processing
|
||||
USE_GPU=true
|
||||
GPU_MEMORY_LIMIT=6144
|
||||
|
|
@ -28,6 +36,14 @@ DEMO_MODE=false
|
|||
# Cloud mode (set in compose.cloud.yml; also set here for reference)
|
||||
# CLOUD_DATA_ROOT=/devl/kiwi-cloud-data
|
||||
# KIWI_DB=data/kiwi.db # local-mode DB path override
|
||||
# DEV ONLY: bypass JWT auth for these IPs/CIDRs (LAN testing without Caddy in the path).
|
||||
# NEVER set in production.
|
||||
# IMPORTANT: Docker port mapping NATs source IPs to the bridge gateway. When hitting
|
||||
# localhost:8515 (host → Docker → nginx → API), nginx sees 192.168.80.1, not 127.0.0.1.
|
||||
# Include the Docker bridge CIDR to allow localhost and LAN access through nginx.
|
||||
# Run: docker network inspect kiwi-cloud_kiwi-cloud-net | grep Subnet
|
||||
# Example: CLOUD_AUTH_BYPASS_IPS=10.1.10.0/24,127.0.0.1,::1,192.168.80.0/20
|
||||
# CLOUD_AUTH_BYPASS_IPS=
|
||||
|
||||
# Heimdall license server (required for cloud tier resolution)
|
||||
# HEIMDALL_URL=https://license.circuitforge.tech
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from app.cloud_session import CloudUser, get_session
|
||||
from app.db.session import get_store
|
||||
from app.db.store import Store
|
||||
from app.models.schemas.recipe import RecipeRequest, RecipeResult
|
||||
from app.services.recipe.recipe_engine import RecipeEngine
|
||||
|
|
@ -15,11 +15,25 @@ from app.tiers import can_use
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _suggest_in_thread(db_path: Path, req: RecipeRequest) -> RecipeResult:
|
||||
"""Run recipe suggestion in a worker thread with its own Store connection.
|
||||
|
||||
SQLite connections cannot be shared across threads. This function creates
|
||||
a fresh Store (and therefore a fresh sqlite3.Connection) in the same thread
|
||||
where it will be used, avoiding ProgrammingError: SQLite objects created in
|
||||
a thread can only be used in that same thread.
|
||||
"""
|
||||
store = Store(db_path)
|
||||
try:
|
||||
return RecipeEngine(store).suggest(req)
|
||||
finally:
|
||||
store.close()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=RecipeResult)
|
||||
async def suggest_recipes(
|
||||
req: RecipeRequest,
|
||||
session: CloudUser = Depends(get_session),
|
||||
store: Store = Depends(get_store),
|
||||
) -> RecipeResult:
|
||||
# Inject session-authoritative tier/byok immediately — client-supplied values are ignored.
|
||||
req = req.model_copy(update={"tier": session.tier, "has_byok": session.has_byok})
|
||||
|
|
@ -35,13 +49,19 @@ async def suggest_recipes(
|
|||
)
|
||||
if req.style_id and not can_use("style_picker", req.tier):
|
||||
raise HTTPException(status_code=403, detail="Style picker requires Paid tier.")
|
||||
engine = RecipeEngine(store)
|
||||
return await asyncio.to_thread(engine.suggest, req)
|
||||
return await asyncio.to_thread(_suggest_in_thread, session.db, req)
|
||||
|
||||
|
||||
@router.get("/{recipe_id}")
|
||||
async def get_recipe(recipe_id: int, store: Store = Depends(get_store)) -> dict:
|
||||
recipe = await asyncio.to_thread(store.get_recipe, recipe_id)
|
||||
async def get_recipe(recipe_id: int, session: CloudUser = Depends(get_session)) -> dict:
|
||||
def _get(db_path: Path, rid: int) -> dict | None:
|
||||
store = Store(db_path)
|
||||
try:
|
||||
return store.get_recipe(rid)
|
||||
finally:
|
||||
store.close()
|
||||
|
||||
recipe = await asyncio.to_thread(_get, session.db, recipe_id)
|
||||
if not recipe:
|
||||
raise HTTPException(status_code=404, detail="Recipe not found.")
|
||||
return recipe
|
||||
|
|
|
|||
|
|
@ -37,6 +37,43 @@ DIRECTUS_JWT_SECRET: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
|||
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech")
|
||||
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||
|
||||
# Dev bypass: comma-separated IPs or CIDR ranges that skip JWT auth.
|
||||
# NEVER set this in production. Intended only for LAN developer testing when
|
||||
# the request doesn't pass through Caddy (which normally injects X-CF-Session).
|
||||
# Example: CLOUD_AUTH_BYPASS_IPS=10.1.10.0/24,127.0.0.1
|
||||
import ipaddress as _ipaddress
|
||||
|
||||
_BYPASS_RAW: list[str] = [
|
||||
e.strip()
|
||||
for e in os.environ.get("CLOUD_AUTH_BYPASS_IPS", "").split(",")
|
||||
if e.strip()
|
||||
]
|
||||
|
||||
_BYPASS_NETS: list[_ipaddress.IPv4Network | _ipaddress.IPv6Network] = []
|
||||
_BYPASS_IPS: frozenset[str] = frozenset()
|
||||
|
||||
if _BYPASS_RAW:
|
||||
_nets, _ips = [], set()
|
||||
for entry in _BYPASS_RAW:
|
||||
try:
|
||||
_nets.append(_ipaddress.ip_network(entry, strict=False))
|
||||
except ValueError:
|
||||
_ips.add(entry) # treat non-parseable entries as bare IPs
|
||||
_BYPASS_NETS = _nets
|
||||
_BYPASS_IPS = frozenset(_ips)
|
||||
|
||||
|
||||
def _is_bypass_ip(ip: str) -> bool:
|
||||
if not ip:
|
||||
return False
|
||||
if ip in _BYPASS_IPS:
|
||||
return True
|
||||
try:
|
||||
addr = _ipaddress.ip_address(ip)
|
||||
return any(addr in net for net in _BYPASS_NETS)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
_LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db"))
|
||||
|
||||
_TIER_CACHE: dict[str, tuple[str, float]] = {}
|
||||
|
|
@ -153,12 +190,28 @@ def get_session(request: Request) -> CloudUser:
|
|||
|
||||
Local mode: fully-privileged "local" user pointing at local DB.
|
||||
Cloud mode: validates X-CF-Session JWT, provisions license, resolves tier.
|
||||
Dev bypass: if CLOUD_AUTH_BYPASS_IPS is set and the client IP matches,
|
||||
returns a "local" session without JWT validation (dev/LAN use only).
|
||||
"""
|
||||
has_byok = _detect_byok()
|
||||
|
||||
if not CLOUD_MODE:
|
||||
return CloudUser(user_id="local", tier="local", db=_LOCAL_KIWI_DB, has_byok=has_byok)
|
||||
|
||||
# Prefer X-Real-IP (set by nginx from the actual client address) over the
|
||||
# TCP peer address (which is nginx's container IP when behind the proxy).
|
||||
# Prefer X-Real-IP (set by nginx from the actual client address) over the
|
||||
# TCP peer address (which is nginx's container IP when behind the proxy).
|
||||
client_ip = (
|
||||
request.headers.get("x-real-ip", "")
|
||||
or (request.client.host if request.client else "")
|
||||
)
|
||||
if (_BYPASS_IPS or _BYPASS_NETS) and _is_bypass_ip(client_ip):
|
||||
log.debug("CLOUD_AUTH_BYPASS_IPS match for %s — returning local session", client_ip)
|
||||
# Use a dev DB under CLOUD_DATA_ROOT so the container has a writable path.
|
||||
dev_db = _user_db_path("local-dev")
|
||||
return CloudUser(user_id="local-dev", tier="local", db=dev_db, has_byok=has_byok)
|
||||
|
||||
raw_header = (
|
||||
request.headers.get("x-cf-session", "")
|
||||
or request.headers.get("cookie", "")
|
||||
|
|
|
|||
|
|
@ -43,6 +43,9 @@ class Settings:
|
|||
# Quality
|
||||
MIN_QUALITY_SCORE: float = float(os.environ.get("MIN_QUALITY_SCORE", "50.0"))
|
||||
|
||||
# CF-core resource coordinator (VRAM lease management)
|
||||
COORDINATOR_URL: str = os.environ.get("COORDINATOR_URL", "http://localhost:7700")
|
||||
|
||||
# Feature flags
|
||||
ENABLE_OCR: bool = os.environ.get("ENABLE_OCR", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ This module provides functionality to detect and decode barcodes
|
|||
from images (UPC, EAN, QR codes, etc.).
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pyzbar import pyzbar
|
||||
|
|
@ -12,6 +14,12 @@ from pathlib import Path
|
|||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
try:
|
||||
from PIL import Image as _PILImage
|
||||
_HAS_PIL = True
|
||||
except ImportError:
|
||||
_HAS_PIL = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -76,9 +84,7 @@ class BarcodeScanner:
|
|||
# 4. Try rotations if still no barcodes found (handles tilted/rotated barcodes)
|
||||
if not barcodes:
|
||||
logger.info("No barcodes found in standard orientation, trying rotations...")
|
||||
# Try incremental angles: 30°, 60°, 90° (covers 0-90° range)
|
||||
# 0° already tried, 180° is functionally same as 0°, 90°/270° are same axis
|
||||
for angle in [30, 60, 90]:
|
||||
for angle in [90, 180, 270, 45, 135]:
|
||||
rotated_gray = self._rotate_image(gray, angle)
|
||||
rotated_color = self._rotate_image(image, angle)
|
||||
detected = self._detect_barcodes(rotated_gray, rotated_color)
|
||||
|
|
@ -264,6 +270,26 @@ class BarcodeScanner:
|
|||
|
||||
return list(seen.values())
|
||||
|
||||
def _fix_exif_orientation(self, image_bytes: bytes) -> bytes:
|
||||
"""Apply EXIF orientation correction so cv2 sees an upright image.
|
||||
|
||||
Phone cameras embed rotation in EXIF; cv2.imdecode ignores it,
|
||||
so a photo taken in portrait may arrive physically sideways in memory.
|
||||
"""
|
||||
if not _HAS_PIL:
|
||||
return image_bytes
|
||||
try:
|
||||
pil = _PILImage.open(io.BytesIO(image_bytes))
|
||||
pil = _PILImage.fromarray(np.array(pil)) # strips EXIF but applies orientation via PIL
|
||||
# Use ImageOps.exif_transpose for proper EXIF-aware rotation
|
||||
import PIL.ImageOps
|
||||
pil = PIL.ImageOps.exif_transpose(pil)
|
||||
buf = io.BytesIO()
|
||||
pil.save(buf, format="JPEG")
|
||||
return buf.getvalue()
|
||||
except Exception:
|
||||
return image_bytes
|
||||
|
||||
def scan_from_bytes(self, image_bytes: bytes) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Scan barcodes from image bytes (uploaded file).
|
||||
|
|
@ -275,6 +301,10 @@ class BarcodeScanner:
|
|||
List of detected barcodes
|
||||
"""
|
||||
try:
|
||||
# Apply EXIF orientation correction first (phone cameras embed rotation in EXIF;
|
||||
# cv2.imdecode ignores it, causing sideways barcodes to appear rotated in memory).
|
||||
image_bytes = self._fix_exif_orientation(image_bytes)
|
||||
|
||||
# Convert bytes to numpy array
|
||||
nparr = np.frombuffer(image_bytes, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
|
@ -300,11 +330,12 @@ class BarcodeScanner:
|
|||
)
|
||||
barcodes.extend(self._detect_barcodes(thresh, image))
|
||||
|
||||
# 3. Try rotations if still no barcodes found
|
||||
# 3. Try all 90° rotations + common tilt angles
|
||||
# 90/270 catches truly sideways barcodes; 180 catches upside-down;
|
||||
# 45/135 catches tilted barcodes on flat surfaces.
|
||||
if not barcodes:
|
||||
logger.info("No barcodes found in uploaded image, trying rotations...")
|
||||
# Try incremental angles: 30°, 60°, 90° (covers 0-90° range)
|
||||
for angle in [30, 60, 90]:
|
||||
for angle in [90, 180, 270, 45, 135]:
|
||||
rotated_gray = self._rotate_image(gray, angle)
|
||||
rotated_color = self._rotate_image(image, angle)
|
||||
detected = self._detect_barcodes(rotated_gray, rotated_color)
|
||||
|
|
|
|||
|
|
@ -113,8 +113,66 @@ class LLMRecipeGenerator:
|
|||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _acquire_vram_lease(self) -> str | None:
|
||||
"""Request a VRAM lease from the CF-core coordinator. Best-effort — returns None if unavailable."""
|
||||
try:
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
from app.tasks.runner import VRAM_BUDGETS
|
||||
|
||||
budget_mb = int(VRAM_BUDGETS.get("recipe_llm", 4.0) * 1024)
|
||||
coordinator = settings.COORDINATOR_URL
|
||||
|
||||
nodes_resp = httpx.get(f"{coordinator}/api/nodes", timeout=2.0)
|
||||
if nodes_resp.status_code != 200:
|
||||
return None
|
||||
nodes = nodes_resp.json().get("nodes", [])
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
best_node = best_gpu = best_free = None
|
||||
for node in nodes:
|
||||
for gpu in node.get("gpus", []):
|
||||
free = gpu.get("vram_free_mb", 0)
|
||||
if best_free is None or free > best_free:
|
||||
best_node = node["node_id"]
|
||||
best_gpu = gpu["gpu_id"]
|
||||
best_free = free
|
||||
if best_node is None:
|
||||
return None
|
||||
|
||||
lease_resp = httpx.post(
|
||||
f"{coordinator}/api/leases",
|
||||
json={
|
||||
"node_id": best_node,
|
||||
"gpu_id": best_gpu,
|
||||
"mb": budget_mb,
|
||||
"service": "kiwi",
|
||||
"priority": 5,
|
||||
},
|
||||
timeout=3.0,
|
||||
)
|
||||
if lease_resp.status_code == 200:
|
||||
lease_id = lease_resp.json()["lease"]["lease_id"]
|
||||
logger.debug("Acquired VRAM lease %s for recipe_llm (%d MB)", lease_id, budget_mb)
|
||||
return lease_id
|
||||
except Exception as exc:
|
||||
logger.debug("VRAM lease acquire failed (non-fatal): %s", exc)
|
||||
return None
|
||||
|
||||
def _release_vram_lease(self, lease_id: str) -> None:
|
||||
"""Release a VRAM lease. Best-effort."""
|
||||
try:
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
httpx.delete(f"{settings.COORDINATOR_URL}/api/leases/{lease_id}", timeout=3.0)
|
||||
logger.debug("Released VRAM lease %s", lease_id)
|
||||
except Exception as exc:
|
||||
logger.debug("VRAM lease release failed (non-fatal): %s", exc)
|
||||
|
||||
def _call_llm(self, prompt: str) -> str:
|
||||
"""Call the LLM router and return the response text."""
|
||||
"""Call the LLM router with a VRAM lease held for the duration."""
|
||||
lease_id = self._acquire_vram_lease()
|
||||
try:
|
||||
from circuitforge_core.llm.router import LLMRouter
|
||||
router = LLMRouter()
|
||||
|
|
@ -122,6 +180,9 @@ class LLMRecipeGenerator:
|
|||
except Exception as exc:
|
||||
logger.error("LLM call failed: %s", exc)
|
||||
return ""
|
||||
finally:
|
||||
if lease_id:
|
||||
self._release_vram_lease(lease_id)
|
||||
|
||||
def _parse_response(self, response: str) -> dict[str, str | list[str]]:
|
||||
"""Parse LLM response text into structured recipe fields."""
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ LLM_TASK_TYPES: frozenset[str] = frozenset({"expiry_llm_fallback"})
|
|||
VRAM_BUDGETS: dict[str, float] = {
|
||||
# ExpirationPredictor uses a small LLM (16 tokens out, single pass).
|
||||
"expiry_llm_fallback": 2.0,
|
||||
# Recipe LLM (levels 3-4): full recipe generation, ~200-500 tokens out.
|
||||
# Budget assumes a quantized 7B-class model.
|
||||
"recipe_llm": 4.0,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ services:
|
|||
CLOUD_MODE: "true"
|
||||
CLOUD_DATA_ROOT: /devl/kiwi-cloud-data
|
||||
# DIRECTUS_JWT_SECRET, HEIMDALL_URL, HEIMDALL_ADMIN_TOKEN — set in .env
|
||||
# DEV ONLY: comma-separated IPs that bypass JWT auth (LAN testing without Caddy).
|
||||
# Production deployments must NOT set this. Leave blank or omit entirely.
|
||||
CLOUD_AUTH_BYPASS_IPS: ${CLOUD_AUTH_BYPASS_IPS:-}
|
||||
volumes:
|
||||
- /devl/kiwi-cloud-data:/devl/kiwi-cloud-data
|
||||
# LLM config — shared with other CF products; read-only in container
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ server {
|
|||
proxy_set_header X-Forwarded-Proto $http_x_forwarded_proto;
|
||||
# Forward the session header injected by Caddy from cf_session cookie.
|
||||
proxy_set_header X-CF-Session $http_x_cf_session;
|
||||
# Allow image uploads (barcode/receipt photos from phone cameras).
|
||||
client_max_body_size 20m;
|
||||
}
|
||||
|
||||
location = /index.html {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ server {
|
|||
proxy_pass http://172.17.0.1:8512;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
# Allow image uploads (barcode/receipt photos from phone cameras).
|
||||
client_max_body_size 20m;
|
||||
}
|
||||
|
||||
location = /index.html {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ dependencies = [
|
|||
"opencv-python>=4.8",
|
||||
"numpy>=1.25",
|
||||
"pyzbar>=0.1.9",
|
||||
"Pillow>=10.0",
|
||||
# HTTP client
|
||||
"httpx>=0.27",
|
||||
# CircuitForge shared scaffold
|
||||
|
|
|
|||
|
|
@ -3,24 +3,21 @@ Derive substitution pairs by diffing lishuyang/recipepairs.
|
|||
GPL-3.0 source -- derived annotations only, raw pairs not shipped.
|
||||
|
||||
Usage:
|
||||
conda run -n job-seeker python scripts/pipeline/derive_substitutions.py \
|
||||
PYTHONPATH=/path/to/kiwi conda run -n cf python scripts/pipeline/derive_substitutions.py \
|
||||
--db /path/to/kiwi.db \
|
||||
--recipepairs data/recipepairs.parquet
|
||||
--recipepairs data/pipeline/recipepairs.parquet \
|
||||
--recipepairs-recipes data/pipeline/recipepairs_recipes.parquet
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from scripts.pipeline.build_recipe_index import extract_ingredient_names
|
||||
|
||||
CONSTRAINT_COLS = ["vegan", "vegetarian", "dairy_free", "low_calorie",
|
||||
"low_carb", "low_fat", "low_sodium", "gluten_free"]
|
||||
|
||||
|
||||
def diff_ingredients(base: list[str], target: list[str]) -> tuple[list[str], list[str]]:
|
||||
base_set = set(base)
|
||||
|
|
@ -30,21 +27,44 @@ def diff_ingredients(base: list[str], target: list[str]) -> tuple[list[str], lis
|
|||
return removed, added
|
||||
|
||||
|
||||
def build(db_path: Path, recipepairs_path: Path) -> None:
|
||||
def _parse_categories(val: object) -> list[str]:
|
||||
"""Parse categories field which may be a list, str-repr list, or bare string."""
|
||||
if isinstance(val, list):
|
||||
return [str(v) for v in val]
|
||||
if isinstance(val, str):
|
||||
val = val.strip()
|
||||
if val.startswith("["):
|
||||
# parse list repr: ['a', 'b'] — use json after converting single quotes
|
||||
try:
|
||||
fixed = re.sub(r"'", '"', val)
|
||||
return json.loads(fixed)
|
||||
except Exception:
|
||||
pass
|
||||
return [val] if val else []
|
||||
return []
|
||||
|
||||
|
||||
def build(db_path: Path, recipepairs_path: Path, recipes_path: Path) -> None:
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
print("Loading recipe ingredient index...")
|
||||
# Load ingredient lists from the bundled recipepairs recipe corpus.
|
||||
# This is GPL-3.0 data — we only use it for diffing; raw data is not persisted.
|
||||
print("Loading recipe ingredient index from recipepairs corpus...")
|
||||
recipes_df = pd.read_parquet(recipes_path, columns=["id", "ingredients"])
|
||||
recipe_ingredients: dict[str, list[str]] = {}
|
||||
for row in conn.execute("SELECT external_id, ingredient_names FROM recipes"):
|
||||
recipe_ingredients[str(row[0])] = json.loads(row[1])
|
||||
for _, r in recipes_df.iterrows():
|
||||
ings = r["ingredients"]
|
||||
if ings is not None and hasattr(ings, "__iter__") and not isinstance(ings, str):
|
||||
recipe_ingredients[str(int(r["id"]))] = [str(i) for i in ings]
|
||||
print(f" {len(recipe_ingredients)} recipes loaded")
|
||||
|
||||
df = pd.read_parquet(recipepairs_path)
|
||||
pairs_df = pd.read_parquet(recipepairs_path)
|
||||
pair_counts: dict[tuple, dict] = defaultdict(lambda: {"count": 0})
|
||||
|
||||
print("Diffing recipe pairs...")
|
||||
for _, row in df.iterrows():
|
||||
base_id = str(row.get("base", ""))
|
||||
target_id = str(row.get("target", ""))
|
||||
for _, row in pairs_df.iterrows():
|
||||
base_id = str(int(row["base"]))
|
||||
target_id = str(int(row["target"]))
|
||||
base_ings = recipe_ingredients.get(base_id, [])
|
||||
target_ings = recipe_ingredients.get(target_id, [])
|
||||
if not base_ings or not target_ings:
|
||||
|
|
@ -56,7 +76,9 @@ def build(db_path: Path, recipepairs_path: Path) -> None:
|
|||
|
||||
original = removed[0]
|
||||
substitute = added[0]
|
||||
constraints = [c for c in CONSTRAINT_COLS if row.get(c, 0)]
|
||||
constraints = _parse_categories(row.get("categories", []))
|
||||
if not constraints:
|
||||
continue
|
||||
for constraint in constraints:
|
||||
key = (original, substitute, constraint)
|
||||
pair_counts[key]["count"] += 1
|
||||
|
|
@ -102,7 +124,11 @@ def build(db_path: Path, recipepairs_path: Path) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--db", required=True, type=Path)
|
||||
parser.add_argument("--recipepairs", required=True, type=Path)
|
||||
parser.add_argument("--db", required=True, type=Path)
|
||||
parser.add_argument("--recipepairs", required=True, type=Path,
|
||||
help="pairs.parquet from lishuyang/recipepairs")
|
||||
parser.add_argument("--recipepairs-recipes", required=True, type=Path,
|
||||
dest="recipepairs_recipes",
|
||||
help="recipes.parquet from lishuyang/recipepairs (ingredient lookup)")
|
||||
args = parser.parse_args()
|
||||
build(args.db, args.recipepairs)
|
||||
build(args.db, args.recipepairs, args.recipepairs_recipes)
|
||||
|
|
|
|||
|
|
@ -12,21 +12,33 @@ Downloads:
|
|||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
DATASETS = [
|
||||
("corbt/all-recipes", "train", "recipes_allrecipes.parquet"),
|
||||
("omid5/usda-fdc-foods-cleaned", "train", "usda_fdc_cleaned.parquet"),
|
||||
# Standard HuggingFace datasets: (hf_path, split, output_filename)
|
||||
HF_DATASETS = [
|
||||
("corbt/all-recipes", "train", "recipes_allrecipes.parquet"),
|
||||
("omid5/usda-fdc-foods-cleaned", "train", "usda_fdc_cleaned.parquet"),
|
||||
("jacktol/usda-branded-food-data","train", "usda_branded.parquet"),
|
||||
("lishuyang/recipepairs", "train", "recipepairs.parquet"),
|
||||
]
|
||||
|
||||
# Datasets that expose raw parquet files directly (no HF dataset builder)
|
||||
HF_PARQUET_FILES = [
|
||||
# (repo_id, repo_filename, output_filename)
|
||||
# lishuyang/recipepairs: GPL-3.0 ⚠ — derive only, don't ship
|
||||
("lishuyang/recipepairs", "pairs.parquet", "recipepairs.parquet"),
|
||||
]
|
||||
|
||||
|
||||
def download_all(data_dir: Path) -> None:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
for hf_path, split, filename in DATASETS:
|
||||
|
||||
for hf_path, split, filename in HF_DATASETS:
|
||||
out = data_dir / filename
|
||||
if out.exists():
|
||||
print(f" skip {filename} (already exists)")
|
||||
|
|
@ -36,9 +48,29 @@ def download_all(data_dir: Path) -> None:
|
|||
ds.to_parquet(str(out))
|
||||
print(f" saved → {out}")
|
||||
|
||||
for repo_id, repo_file, filename in HF_PARQUET_FILES:
|
||||
out = data_dir / filename
|
||||
if out.exists():
|
||||
print(f" skip {filename} (already exists)")
|
||||
continue
|
||||
print(f" downloading {repo_id}/{repo_file} ...")
|
||||
cached = hf_hub_download(repo_id=repo_id, filename=repo_file, repo_type="dataset")
|
||||
shutil.copy2(cached, out)
|
||||
print(f" saved → {out}")
|
||||
|
||||
|
||||
_DEFAULT_DATA_DIR = Path(
|
||||
os.environ.get("KIWI_PIPELINE_DATA_DIR", "data/pipeline")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", required=True, type=Path)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=Path,
|
||||
default=_DEFAULT_DATA_DIR,
|
||||
help="Directory for downloaded parquets (default: $KIWI_PIPELINE_DATA_DIR or data/pipeline)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
download_all(args.data_dir)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,39 @@
|
|||
import csv
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None:
|
||||
with open(path, "w", newline="") as f:
|
||||
w = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
w.writeheader()
|
||||
w.writerows(rows)
|
||||
|
||||
|
||||
def test_parse_flavorgraph_node():
|
||||
from scripts.pipeline.build_flavorgraph_index import parse_ingredient_nodes
|
||||
sample = {
|
||||
"nodes": [
|
||||
{"id": "I_beef", "type": "ingredient", "name": "beef"},
|
||||
{"id": "C_pyrazine", "type": "compound", "name": "pyrazine"},
|
||||
{"id": "I_mushroom", "type": "ingredient", "name": "mushroom"},
|
||||
],
|
||||
"links": [
|
||||
{"source": "I_beef", "target": "C_pyrazine"},
|
||||
{"source": "I_mushroom","target": "C_pyrazine"},
|
||||
]
|
||||
}
|
||||
result = parse_ingredient_nodes(sample)
|
||||
assert "beef" in result
|
||||
assert "C_pyrazine" in result["beef"]
|
||||
assert "mushroom" in result
|
||||
assert "C_pyrazine" in result["mushroom"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
nodes_path = Path(tmp) / "nodes.csv"
|
||||
edges_path = Path(tmp) / "edges.csv"
|
||||
|
||||
_write_csv(nodes_path, [
|
||||
{"node_id": "1", "name": "beef", "node_type": "ingredient"},
|
||||
{"node_id": "2", "name": "pyrazine", "node_type": "compound"},
|
||||
{"node_id": "3", "name": "mushroom", "node_type": "ingredient"},
|
||||
], ["node_id", "name", "node_type"])
|
||||
|
||||
_write_csv(edges_path, [
|
||||
{"id_1": "1", "id_2": "2", "score": "0.8"},
|
||||
{"id_1": "3", "id_2": "2", "score": "0.7"},
|
||||
], ["id_1", "id_2", "score"])
|
||||
|
||||
ingredient_to_compounds, compound_names = parse_ingredient_nodes(nodes_path, edges_path)
|
||||
|
||||
assert "beef" in ingredient_to_compounds
|
||||
assert "mushroom" in ingredient_to_compounds
|
||||
# compound node_id "2" maps to name "pyrazine"
|
||||
beef_compounds = ingredient_to_compounds["beef"]
|
||||
assert any(compound_names.get(c) == "pyrazine" for c in beef_compounds)
|
||||
mushroom_compounds = ingredient_to_compounds["mushroom"]
|
||||
assert any(compound_names.get(c) == "pyrazine" for c in mushroom_compounds)
|
||||
|
|
|
|||
Loading…
Reference in a new issue