feat(tasks): add background task scheduler for LLM expiry fallback
Uses circuitforge_core.tasks.scheduler. VRAM detection via cf-orch when available, falling back to unlimited. Adds expiry_llm_fallback task type to background-predict expiry dates for items the LUT doesn't cover.
This commit is contained in:
parent
8cbde774e5
commit
636bffda5a
7 changed files with 320 additions and 0 deletions
18
app/db/migrations/006_background_tasks.sql
Normal file
18
app/db/migrations/006_background_tasks.sql
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
-- 006_background_tasks.sql
|
||||
-- Shared background task queue used by the LLM task scheduler.
|
||||
-- Schema mirrors Peregrine's background_tasks for circuitforge-core compatibility.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS background_tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_type TEXT NOT NULL,
|
||||
job_id INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
params TEXT,
|
||||
error TEXT,
|
||||
stage TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status_type
|
||||
ON background_tasks (status, task_type);
|
||||
11
app/main.py
11
app/main.py
|
|
@ -17,7 +17,18 @@ logger = logging.getLogger(__name__)
|
|||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting Kiwi API...")
|
||||
settings.ensure_dirs()
|
||||
|
||||
# Start LLM background task scheduler
|
||||
from app.tasks.scheduler import get_scheduler
|
||||
get_scheduler(settings.DB_PATH)
|
||||
logger.info("Task scheduler started.")
|
||||
|
||||
yield
|
||||
|
||||
# Graceful scheduler shutdown
|
||||
from app.tasks.scheduler import get_scheduler, reset_scheduler
|
||||
get_scheduler(settings.DB_PATH).shutdown(timeout=10.0)
|
||||
reset_scheduler()
|
||||
logger.info("Kiwi API shutting down.")
|
||||
|
||||
|
||||
|
|
|
|||
0
app/tasks/__init__.py
Normal file
0
app/tasks/__init__.py
Normal file
142
app/tasks/runner.py
Normal file
142
app/tasks/runner.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
# app/tasks/runner.py
|
||||
"""Kiwi background task runner.
|
||||
|
||||
Implements the run_task_fn interface expected by circuitforge_core.tasks.scheduler.
|
||||
Each kiwi LLM task type has its own handler below.
|
||||
|
||||
Public API:
|
||||
LLM_TASK_TYPES — frozenset of task type strings to route through the scheduler
|
||||
VRAM_BUDGETS — VRAM GB estimates per task type
|
||||
insert_task() — deduplicating task insertion
|
||||
run_task() — called by the scheduler batch worker
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import date, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.expiration_predictor import ExpirationPredictor
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LLM_TASK_TYPES: frozenset[str] = frozenset({"expiry_llm_fallback"})
|
||||
|
||||
VRAM_BUDGETS: dict[str, float] = {
|
||||
# ExpirationPredictor uses a small LLM (16 tokens out, single pass).
|
||||
"expiry_llm_fallback": 2.0,
|
||||
}
|
||||
|
||||
|
||||
def insert_task(
|
||||
db_path: Path,
|
||||
task_type: str,
|
||||
job_id: int,
|
||||
*,
|
||||
params: str | None = None,
|
||||
) -> tuple[int, bool]:
|
||||
"""Insert a background task if no identical task is already in-flight.
|
||||
|
||||
Returns (task_id, True) if a new task was created.
|
||||
Returns (existing_id, False) if an identical task is already queued/running.
|
||||
"""
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
existing = conn.execute(
|
||||
"SELECT id FROM background_tasks "
|
||||
"WHERE task_type=? AND job_id=? AND status IN ('queued','running')",
|
||||
(task_type, job_id),
|
||||
).fetchone()
|
||||
if existing:
|
||||
conn.close()
|
||||
return existing["id"], False
|
||||
cursor = conn.execute(
|
||||
"INSERT INTO background_tasks (task_type, job_id, params) VALUES (?,?,?)",
|
||||
(task_type, job_id, params),
|
||||
)
|
||||
conn.commit()
|
||||
task_id = cursor.lastrowid
|
||||
conn.close()
|
||||
return task_id, True
|
||||
|
||||
|
||||
def _update_task_status(
|
||||
db_path: Path, task_id: int, status: str, *, error: str = ""
|
||||
) -> None:
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE background_tasks "
|
||||
"SET status=?, error=?, updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(status, error, task_id),
|
||||
)
|
||||
|
||||
|
||||
def run_task(
|
||||
db_path: Path,
|
||||
task_id: int,
|
||||
task_type: str,
|
||||
job_id: int,
|
||||
params: str | None = None,
|
||||
) -> None:
|
||||
"""Execute one background task. Called by the scheduler's batch worker."""
|
||||
_update_task_status(db_path, task_id, "running")
|
||||
try:
|
||||
if task_type == "expiry_llm_fallback":
|
||||
_run_expiry_llm_fallback(db_path, job_id, params)
|
||||
else:
|
||||
raise ValueError(f"Unknown kiwi task type: {task_type!r}")
|
||||
_update_task_status(db_path, task_id, "completed")
|
||||
except Exception as exc:
|
||||
log.exception("Task %d (%s) failed: %s", task_id, task_type, exc)
|
||||
_update_task_status(db_path, task_id, "failed", error=str(exc))
|
||||
|
||||
|
||||
def _run_expiry_llm_fallback(
|
||||
db_path: Path,
|
||||
item_id: int,
|
||||
params: str | None,
|
||||
) -> None:
|
||||
"""Predict expiry date via LLM for an inventory item and write result to DB.
|
||||
|
||||
params JSON keys:
|
||||
product_name (required) — e.g. "Trader Joe's Organic Tempeh"
|
||||
category (optional) — category hint for the predictor
|
||||
location (optional) — "fridge" | "freezer" | "pantry" (default: "fridge")
|
||||
"""
|
||||
p = json.loads(params or "{}")
|
||||
product_name = p.get("product_name", "")
|
||||
category = p.get("category")
|
||||
location = p.get("location", "fridge")
|
||||
|
||||
if not product_name:
|
||||
raise ValueError("expiry_llm_fallback: 'product_name' is required in params")
|
||||
|
||||
predictor = ExpirationPredictor()
|
||||
days = predictor._llm_predict_days(product_name, category, location)
|
||||
|
||||
if days is None:
|
||||
log.warning(
|
||||
"LLM expiry fallback returned None for item_id=%d product=%r — "
|
||||
"expiry_date will remain NULL",
|
||||
item_id,
|
||||
product_name,
|
||||
)
|
||||
return
|
||||
|
||||
expiry = (date.today() + timedelta(days=days)).isoformat()
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE inventory_items SET expiry_date=? WHERE id=?",
|
||||
(expiry, item_id),
|
||||
)
|
||||
|
||||
log.info(
|
||||
"LLM expiry fallback: item_id=%d %r → %s (%d days)",
|
||||
item_id,
|
||||
product_name,
|
||||
expiry,
|
||||
days,
|
||||
)
|
||||
23
app/tasks/scheduler.py
Normal file
23
app/tasks/scheduler.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# app/tasks/scheduler.py
|
||||
"""Kiwi LLM task scheduler — thin shim over circuitforge_core.tasks.scheduler."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from circuitforge_core.tasks.scheduler import (
|
||||
TaskScheduler,
|
||||
get_scheduler as _base_get_scheduler,
|
||||
reset_scheduler, # re-export for tests
|
||||
)
|
||||
|
||||
from app.tasks.runner import LLM_TASK_TYPES, VRAM_BUDGETS, run_task
|
||||
|
||||
|
||||
def get_scheduler(db_path: Path) -> TaskScheduler:
|
||||
"""Return the process-level TaskScheduler singleton for Kiwi."""
|
||||
return _base_get_scheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=run_task,
|
||||
task_types=LLM_TASK_TYPES,
|
||||
vram_budgets=VRAM_BUDGETS,
|
||||
)
|
||||
0
tests/test_tasks/__init__.py
Normal file
0
tests/test_tasks/__init__.py
Normal file
126
tests/test_tasks/test_runner.py
Normal file
126
tests/test_tasks/test_runner.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""Tests for kiwi background task runner."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.runner import (
|
||||
LLM_TASK_TYPES,
|
||||
VRAM_BUDGETS,
|
||||
insert_task,
|
||||
run_task,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_db(tmp_path: Path) -> Path:
|
||||
db = tmp_path / "kiwi.db"
|
||||
conn = sqlite3.connect(db)
|
||||
conn.executescript("""
|
||||
CREATE TABLE background_tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_type TEXT NOT NULL,
|
||||
job_id INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
params TEXT,
|
||||
error TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE TABLE inventory_items (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
expiry_date TEXT
|
||||
);
|
||||
INSERT INTO inventory_items (name, expiry_date) VALUES ('mystery tofu', NULL);
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db
|
||||
|
||||
|
||||
def test_llm_task_types_defined():
|
||||
assert "expiry_llm_fallback" in LLM_TASK_TYPES
|
||||
|
||||
|
||||
def test_vram_budgets_defined():
|
||||
assert "expiry_llm_fallback" in VRAM_BUDGETS
|
||||
assert VRAM_BUDGETS["expiry_llm_fallback"] > 0
|
||||
|
||||
|
||||
def test_insert_task_creates_row(tmp_db: Path):
|
||||
task_id, is_new = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
|
||||
assert is_new is True
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
row = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] == "queued"
|
||||
|
||||
|
||||
def test_insert_task_dedup(tmp_db: Path):
|
||||
id1, new1 = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
|
||||
id2, new2 = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
|
||||
assert id1 == id2
|
||||
assert new1 is True
|
||||
assert new2 is False
|
||||
|
||||
|
||||
def test_run_task_expiry_success(tmp_db: Path):
|
||||
params = json.dumps({"product_name": "Tofu", "category": "protein", "location": "fridge"})
|
||||
task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)
|
||||
|
||||
with patch("app.tasks.runner.ExpirationPredictor") as MockPredictor:
|
||||
instance = MockPredictor.return_value
|
||||
instance._llm_predict_days.return_value = 7
|
||||
run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)
|
||||
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
item = conn.execute("SELECT expiry_date FROM inventory_items WHERE id=1").fetchone()
|
||||
task = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
|
||||
conn.close()
|
||||
assert item[0] is not None, "expiry_date should be set"
|
||||
assert task[0] == "completed"
|
||||
|
||||
|
||||
def test_run_task_expiry_llm_returns_none(tmp_db: Path):
|
||||
"""If LLM returns None, task completes without writing expiry_date."""
|
||||
params = json.dumps({"product_name": "Unknown widget", "location": "fridge"})
|
||||
task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)
|
||||
|
||||
with patch("app.tasks.runner.ExpirationPredictor") as MockPredictor:
|
||||
instance = MockPredictor.return_value
|
||||
instance._llm_predict_days.return_value = None
|
||||
run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)
|
||||
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
item = conn.execute("SELECT expiry_date FROM inventory_items WHERE id=1").fetchone()
|
||||
task = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
|
||||
conn.close()
|
||||
assert item[0] is None, "expiry_date should remain NULL when LLM returns None"
|
||||
assert task[0] == "completed"
|
||||
|
||||
|
||||
def test_run_task_missing_product_name_marks_failed(tmp_db: Path):
|
||||
params = json.dumps({})
|
||||
task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)
|
||||
run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)
|
||||
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
task = conn.execute("SELECT status, error FROM background_tasks WHERE id=?", (task_id,)).fetchone()
|
||||
conn.close()
|
||||
assert task[0] == "failed"
|
||||
assert "product_name" in task[1]
|
||||
|
||||
|
||||
def test_run_task_unknown_type_marks_failed(tmp_db: Path):
|
||||
task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
|
||||
run_task(tmp_db, task_id, "unknown_type", 1, None)
|
||||
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
task = conn.execute("SELECT status FROM background_tasks WHERE id=?", (task_id,)).fetchone()
|
||||
conn.close()
|
||||
assert task[0] == "failed"
|
||||
Loading…
Reference in a new issue