From 636bffda5a461ebe341b302de7ee055e5777d9f6 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Tue, 31 Mar 2026 09:25:48 -0700
Subject: [PATCH] feat(tasks): add background task scheduler for LLM expiry fallback

Uses circuitforge_core.tasks.scheduler. VRAM detection via cf-orch when
available, falling back to unlimited. Adds expiry_llm_fallback task type
to background-predict expiry dates for items the LUT doesn't cover.
---
 app/db/migrations/006_background_tasks.sql |  18 +++
 app/main.py                                |  11 ++
 app/tasks/__init__.py                      |   0
 app/tasks/runner.py                        | 162 ++++++++++++++++++++++++
 app/tasks/scheduler.py                     |  23 ++++
 tests/test_tasks/__init__.py               |   0
 tests/test_tasks/test_runner.py            | 126 ++++++++++++++++++
 7 files changed, 340 insertions(+)
 create mode 100644 app/db/migrations/006_background_tasks.sql
 create mode 100644 app/tasks/__init__.py
 create mode 100644 app/tasks/runner.py
 create mode 100644 app/tasks/scheduler.py
 create mode 100644 tests/test_tasks/__init__.py
 create mode 100644 tests/test_tasks/test_runner.py

diff --git a/app/db/migrations/006_background_tasks.sql b/app/db/migrations/006_background_tasks.sql
new file mode 100644
index 0000000..d98450a
--- /dev/null
+++ b/app/db/migrations/006_background_tasks.sql
@@ -0,0 +1,18 @@
+-- 006_background_tasks.sql
+-- Shared background task queue used by the LLM task scheduler.
+-- Schema mirrors Peregrine's background_tasks for circuitforge-core compatibility.
+
+CREATE TABLE IF NOT EXISTS background_tasks (
+    id         INTEGER PRIMARY KEY AUTOINCREMENT,
+    task_type  TEXT NOT NULL,
+    job_id     INTEGER NOT NULL DEFAULT 0,
+    status     TEXT NOT NULL DEFAULT 'queued',
+    params     TEXT,
+    error      TEXT,
+    stage      TEXT,
+    created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE INDEX IF NOT EXISTS idx_bg_tasks_status_type
+    ON background_tasks (status, task_type);
diff --git a/app/main.py b/app/main.py
index 912660f..24bdfdb 100644
--- a/app/main.py
+++ b/app/main.py
@@ -17,7 +17,18 @@ logger = logging.getLogger(__name__)
 async def lifespan(app: FastAPI):
     logger.info("Starting Kiwi API...")
     settings.ensure_dirs()
+
+    # Start LLM background task scheduler
+    from app.tasks.scheduler import get_scheduler
+    get_scheduler(settings.DB_PATH)
+    logger.info("Task scheduler started.")
+
     yield
+
+    # Graceful scheduler shutdown
+    from app.tasks.scheduler import get_scheduler, reset_scheduler
+    get_scheduler(settings.DB_PATH).shutdown(timeout=10.0)
+    reset_scheduler()
     logger.info("Kiwi API shutting down.")
 
 
diff --git a/app/tasks/__init__.py b/app/tasks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/tasks/runner.py b/app/tasks/runner.py
new file mode 100644
index 0000000..99da8ee
--- /dev/null
+++ b/app/tasks/runner.py
@@ -0,0 +1,162 @@
+# app/tasks/runner.py
+"""Kiwi background task runner.
+
+Implements the run_task_fn interface expected by circuitforge_core.tasks.scheduler.
+Each kiwi LLM task type has its own handler below.
+
+Public API:
+    LLM_TASK_TYPES  — frozenset of task type strings to route through the scheduler
+    VRAM_BUDGETS    — VRAM GB estimates per task type
+    insert_task()   — deduplicating task insertion
+    run_task()      — called by the scheduler batch worker
+"""
+from __future__ import annotations
+
+import json
+import logging
+import sqlite3
+from contextlib import closing
+from datetime import date, timedelta
+from pathlib import Path
+
+from app.services.expiration_predictor import ExpirationPredictor
+
+log = logging.getLogger(__name__)
+
+LLM_TASK_TYPES: frozenset[str] = frozenset({"expiry_llm_fallback"})
+
+VRAM_BUDGETS: dict[str, float] = {
+    # ExpirationPredictor uses a small LLM (16 tokens out, single pass).
+    "expiry_llm_fallback": 2.0,
+}
+
+
+def insert_task(
+    db_path: Path,
+    task_type: str,
+    job_id: int,
+    *,
+    params: str | None = None,
+) -> tuple[int, bool]:
+    """Insert a background task if no identical task is already in-flight.
+
+    The duplicate check and the INSERT share one BEGIN IMMEDIATE transaction,
+    so two concurrent callers cannot both enqueue the same task.
+
+    Returns (task_id, True) if a new task was created.
+    Returns (existing_id, False) if an identical task is already queued/running.
+    """
+    # closing() releases the handle on every path; the inner `conn` context
+    # commits on success and rolls back on error.
+    with closing(sqlite3.connect(db_path)) as conn, conn:
+        conn.row_factory = sqlite3.Row
+        # Take the write lock before reading so check-then-insert is atomic.
+        conn.execute("BEGIN IMMEDIATE")
+        existing = conn.execute(
+            "SELECT id FROM background_tasks "
+            "WHERE task_type=? AND job_id=? AND status IN ('queued','running')",
+            (task_type, job_id),
+        ).fetchone()
+        if existing:
+            return existing["id"], False
+        cursor = conn.execute(
+            "INSERT INTO background_tasks (task_type, job_id, params) VALUES (?,?,?)",
+            (task_type, job_id, params),
+        )
+        return cursor.lastrowid, True
+
+
+def _update_task_status(
+    db_path: Path, task_id: int, status: str, *, error: str = ""
+) -> None:
+    """Set a task's status and error text and bump its updated_at."""
+    # sqlite3's own context manager only commits or rolls back; closing() is
+    # what actually releases the connection.
+    with closing(sqlite3.connect(db_path)) as conn, conn:
+        conn.execute(
+            "UPDATE background_tasks "
+            "SET status=?, error=?, updated_at=CURRENT_TIMESTAMP WHERE id=?",
+            (status, error, task_id),
+        )
+
+
+def run_task(
+    db_path: Path,
+    task_id: int,
+    task_type: str,
+    job_id: int,
+    params: str | None = None,
+) -> None:
+    """Execute one background task. Called by the scheduler's batch worker."""
+    _update_task_status(db_path, task_id, "running")
+    try:
+        if task_type == "expiry_llm_fallback":
+            _run_expiry_llm_fallback(db_path, job_id, params)
+        else:
+            raise ValueError(f"Unknown kiwi task type: {task_type!r}")
+        _update_task_status(db_path, task_id, "completed")
+    except Exception as exc:
+        log.exception("Task %d (%s) failed: %s", task_id, task_type, exc)
+        _update_task_status(db_path, task_id, "failed", error=str(exc))
+
+
+def _run_expiry_llm_fallback(
+    db_path: Path,
+    item_id: int,
+    params: str | None,
+) -> None:
+    """Predict expiry date via LLM for an inventory item and write result to DB.
+
+    params JSON keys:
+      product_name  (required) — e.g. "Trader Joe's Organic Tempeh"
+      category      (optional) — category hint for the predictor
+      location      (optional) — "fridge" | "freezer" | "pantry" (default: "fridge")
+
+    Raises ValueError if params is not a JSON object or lacks product_name.
+    """
+    p = json.loads(params or "{}")
+    if not isinstance(p, dict):
+        raise ValueError("expiry_llm_fallback: params must be a JSON object")
+    product_name = p.get("product_name", "")
+    category = p.get("category")
+    location = p.get("location", "fridge")
+
+    if not product_name:
+        raise ValueError("expiry_llm_fallback: 'product_name' is required in params")
+
+    predictor = ExpirationPredictor()
+    days = predictor._llm_predict_days(product_name, category, location)
+
+    if days is None:
+        log.warning(
+            "LLM expiry fallback returned None for item_id=%d product=%r — "
+            "expiry_date will remain NULL",
+            item_id,
+            product_name,
+        )
+        return
+
+    expiry = (date.today() + timedelta(days=days)).isoformat()
+
+    with closing(sqlite3.connect(db_path)) as conn, conn:
+        cursor = conn.execute(
+            "UPDATE inventory_items SET expiry_date=? WHERE id=?",
+            (expiry, item_id),
+        )
+        updated = cursor.rowcount
+
+    if not updated:
+        # No row matched (e.g. the item was deleted meanwhile); don't log success.
+        log.warning(
+            "LLM expiry fallback: item_id=%d not found — prediction discarded",
+            item_id,
+        )
+        return
+
+    log.info(
+        "LLM expiry fallback: item_id=%d %r → %s (%d days)",
+        item_id,
+        product_name,
+        expiry,
+        days,
+    )
diff --git a/app/tasks/scheduler.py b/app/tasks/scheduler.py
new file mode 100644
index 0000000..b916852
--- /dev/null
+++ b/app/tasks/scheduler.py
@@ -0,0 +1,23 @@
+# app/tasks/scheduler.py
+"""Kiwi LLM task scheduler — thin shim over circuitforge_core.tasks.scheduler."""
+from __future__ import annotations
+
+from pathlib import Path
+
+from circuitforge_core.tasks.scheduler import (
+    TaskScheduler,
+    get_scheduler as _base_get_scheduler,
+    reset_scheduler,  # re-export for tests
+)
+
+from app.tasks.runner import LLM_TASK_TYPES, VRAM_BUDGETS, run_task
+
+
+def get_scheduler(db_path: Path) -> TaskScheduler:
+    """Return the process-level TaskScheduler singleton for Kiwi."""
+    return _base_get_scheduler(
+        db_path=db_path,
+        run_task_fn=run_task,
+        task_types=LLM_TASK_TYPES,
+        vram_budgets=VRAM_BUDGETS,
+    )
diff --git a/tests/test_tasks/__init__.py b/tests/test_tasks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_tasks/test_runner.py b/tests/test_tasks/test_runner.py
new file mode 100644
index 0000000..f02913a
--- /dev/null
+++ b/tests/test_tasks/test_runner.py
@@ -0,0 +1,126 @@
+"""Tests for kiwi background task runner."""
+from __future__ import annotations
+
+import json
+import sqlite3
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from app.tasks.runner import (
+    LLM_TASK_TYPES,
+    VRAM_BUDGETS,
+    insert_task,
+    run_task,
+)
+
+
+@pytest.fixture
+def tmp_db(tmp_path: Path) -> Path:
+    db = tmp_path / "kiwi.db"
+    conn = sqlite3.connect(db)
+    conn.executescript("""
+        CREATE TABLE background_tasks (
+            id         INTEGER PRIMARY KEY AUTOINCREMENT,
+            task_type  TEXT NOT NULL,
+            job_id     INTEGER NOT NULL DEFAULT 0,
+            status     TEXT NOT NULL DEFAULT 'queued',
+            params     TEXT,
+            error      TEXT,
+            created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
+            updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
+        );
+        CREATE TABLE inventory_items (
+            id          INTEGER PRIMARY KEY AUTOINCREMENT,
+            name        TEXT NOT NULL,
+            expiry_date TEXT
+        );
+        INSERT INTO inventory_items (name, expiry_date) VALUES ('mystery tofu', NULL);
+    """)
+    conn.commit()
+    conn.close()
+    return db
+
+
+def test_llm_task_types_defined():
+    assert "expiry_llm_fallback" in LLM_TASK_TYPES
+
+
+def test_vram_budgets_defined():
+    assert "expiry_llm_fallback" in VRAM_BUDGETS
+    assert VRAM_BUDGETS["expiry_llm_fallback"] > 0
+
+
+def test_insert_task_creates_row(tmp_db: Path):
+    task_id, is_new = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
+    assert is_new is True
+    conn = sqlite3.connect(tmp_db)
+    row = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
+    conn.close()
+    assert row[0] == "queued"
+
+
+def test_insert_task_dedup(tmp_db: Path):
+    id1, new1 = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
+    id2, new2 = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
+    assert id1 == id2
+    assert new1 is True
+    assert new2 is False
+
+
+def test_run_task_expiry_success(tmp_db: Path):
+    params = json.dumps({"product_name": "Tofu", "category": "protein", "location": "fridge"})
+    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)
+
+    with patch("app.tasks.runner.ExpirationPredictor") as MockPredictor:
+        instance = MockPredictor.return_value
+        instance._llm_predict_days.return_value = 7
+        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)
+
+    conn = sqlite3.connect(tmp_db)
+    item = conn.execute("SELECT expiry_date FROM inventory_items WHERE id=1").fetchone()
+    task = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
+    conn.close()
+    assert item[0] is not None, "expiry_date should be set"
+    assert task[0] == "completed"
+
+
+def test_run_task_expiry_llm_returns_none(tmp_db: Path):
+    """If LLM returns None, task completes without writing expiry_date."""
+    params = json.dumps({"product_name": "Unknown widget", "location": "fridge"})
+    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)
+
+    with patch("app.tasks.runner.ExpirationPredictor") as MockPredictor:
+        instance = MockPredictor.return_value
+        instance._llm_predict_days.return_value = None
+        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)
+
+    conn = sqlite3.connect(tmp_db)
+    item = conn.execute("SELECT expiry_date FROM inventory_items WHERE id=1").fetchone()
+    task = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
+    conn.close()
+    assert item[0] is None, "expiry_date should remain NULL when LLM returns None"
+    assert task[0] == "completed"
+
+
+def test_run_task_missing_product_name_marks_failed(tmp_db: Path):
+    params = json.dumps({})
+    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)
+    run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)
+
+    conn = sqlite3.connect(tmp_db)
+    task = conn.execute("SELECT status, error FROM background_tasks WHERE id=?", (task_id,)).fetchone()
+    conn.close()
+    assert task[0] == "failed"
+    assert "product_name" in task[1]
+
+
+def test_run_task_unknown_type_marks_failed(tmp_db: Path):
+    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
+    run_task(tmp_db, task_id, "unknown_type", 1, None)
+
+    conn = sqlite3.connect(tmp_db)
+    task = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
+    conn.close()
+    assert task[0] == "failed"