kiwi/tests/test_tasks/test_runner.py
pyr0ball 636bffda5a feat(tasks): add background task scheduler for LLM expiry fallback
Uses circuitforge_core.tasks.scheduler. Detects VRAM via cf-orch when
available, falling back to unlimited VRAM otherwise. Adds the
expiry_llm_fallback task type to predict expiry dates in the background
for items the LUT doesn't cover.
2026-03-31 09:25:48 -07:00

126 lines
4.4 KiB
Python

"""Tests for kiwi background task runner."""
from __future__ import annotations
import json
import sqlite3
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from app.tasks.runner import (
LLM_TASK_TYPES,
VRAM_BUDGETS,
insert_task,
run_task,
)
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
    """Create a throwaway SQLite database seeded for the runner tests.

    The schema has two tables:

    * ``background_tasks`` -- task rows whose ``status`` defaults to
      ``'queued'`` and whose ``job_id`` defaults to 0.
    * ``inventory_items`` -- seeded with one row, ``'mystery tofu'``, whose
      ``expiry_date`` is NULL. As the first row of an AUTOINCREMENT table it
      gets ``id = 1``; the ``run_task`` tests pass ``job_id=1`` and read back
      ``inventory_items WHERE id=1``, so they depend on that id.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        Path to the initialised ``kiwi.db`` file.
    """
    db = tmp_path / "kiwi.db"
    conn = sqlite3.connect(db)
    try:
        conn.executescript("""
            CREATE TABLE background_tasks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                task_type TEXT NOT NULL,
                job_id INTEGER NOT NULL DEFAULT 0,
                status TEXT NOT NULL DEFAULT 'queued',
                params TEXT,
                error TEXT,
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE inventory_items (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                expiry_date TEXT
            );
            INSERT INTO inventory_items (name, expiry_date) VALUES ('mystery tofu', NULL);
        """)
        conn.commit()
    finally:
        # Close even when the schema script fails, so the handle is not leaked.
        conn.close()
    return db


def test_llm_task_types_defined():
    """``expiry_llm_fallback`` is listed in ``LLM_TASK_TYPES``."""
    expected_type = "expiry_llm_fallback"
    assert expected_type in LLM_TASK_TYPES


def test_vram_budgets_defined():
    """``VRAM_BUDGETS`` has a strictly positive entry for ``expiry_llm_fallback``."""
    task_type = "expiry_llm_fallback"
    assert task_type in VRAM_BUDGETS
    budget = VRAM_BUDGETS[task_type]
    assert budget > 0


def test_insert_task_creates_row(tmp_db: Path):
    """A first ``insert_task`` call reports a new task stored as ``'queued'``."""
    task_id, is_new = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
    assert is_new is True

    conn = sqlite3.connect(tmp_db)
    try:
        row = conn.execute(
            "SELECT status FROM background_tasks WHERE id=?", (task_id,)
        ).fetchone()
    finally:
        conn.close()
    # Guard first: a missing row would otherwise surface as an opaque
    # "'NoneType' object is not subscriptable" TypeError.
    assert row is not None, f"no background_tasks row with id={task_id}"
    assert row[0] == "queued"


def test_insert_task_dedup(tmp_db: Path):
    """An identical second ``insert_task`` call returns the existing task.

    The first call must report the task as new; the repeat must return the
    same id and report it as not new.
    """
    (first_id, first_is_new), (second_id, second_is_new) = [
        insert_task(tmp_db, "expiry_llm_fallback", job_id=1) for _ in range(2)
    ]
    assert first_id == second_id
    assert first_is_new is True
    assert second_is_new is False


def test_run_task_expiry_success(tmp_db: Path):
    """A numeric LLM prediction completes the task and fills in ``expiry_date``.

    ``ExpirationPredictor`` is patched in ``app.tasks.runner`` so that
    ``_llm_predict_days`` returns 7. ``job_id=1`` refers to the single
    inventory row seeded by the ``tmp_db`` fixture.
    """
    params = json.dumps(
        {"product_name": "Tofu", "category": "protein", "location": "fridge"}
    )
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)

    with patch("app.tasks.runner.ExpirationPredictor") as MockPredictor:
        predictor = MockPredictor.return_value
        predictor._llm_predict_days.return_value = 7
        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)

    # Without this check the test would also pass if expiry_date were filled
    # by some path other than the (mocked) LLM prediction.
    assert predictor._llm_predict_days.called, "LLM prediction was never requested"

    conn = sqlite3.connect(tmp_db)
    try:
        item = conn.execute(
            "SELECT expiry_date FROM inventory_items WHERE id=1"
        ).fetchone()
        task = conn.execute(
            "SELECT status FROM background_tasks WHERE id=?", (task_id,)
        ).fetchone()
    finally:
        conn.close()
    assert item[0] is not None, "expiry_date should be set"
    assert task[0] == "completed"


def test_run_task_expiry_llm_returns_none(tmp_db: Path):
    """A ``None`` prediction from the (mocked) LLM is not an error.

    The task must still reach ``'completed'``, and the inventory row's
    ``expiry_date`` must stay NULL.
    """
    payload = json.dumps({"product_name": "Unknown widget", "location": "fridge"})
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=payload)

    with patch("app.tasks.runner.ExpirationPredictor") as predictor_cls:
        predictor_cls.return_value._llm_predict_days.return_value = None
        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, payload)

    conn = sqlite3.connect(tmp_db)
    expiry_row = conn.execute(
        "SELECT expiry_date FROM inventory_items WHERE id=1"
    ).fetchone()
    status_row = conn.execute(
        "SELECT status FROM background_tasks WHERE id=?", (task_id,)
    ).fetchone()
    conn.close()

    assert expiry_row[0] is None, "expiry_date should remain NULL when LLM returns None"
    assert status_row[0] == "completed"


def test_run_task_missing_product_name_marks_failed(tmp_db: Path):
    """Params without ``product_name`` mark the task failed with an error naming it."""
    params = json.dumps({})
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)

    # NOTE(review): patching is defensive. SOURCE does not show whether
    # run_task validates params before or after constructing a predictor,
    # and the test must not depend on a real one either way.
    with patch("app.tasks.runner.ExpirationPredictor"):
        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)

    conn = sqlite3.connect(tmp_db)
    try:
        status, error = conn.execute(
            "SELECT status, error FROM background_tasks WHERE id=?", (task_id,)
        ).fetchone()
    finally:
        conn.close()
    assert status == "failed"
    # Check that an error was recorded first; ``in`` on None raises TypeError.
    assert error is not None, "failed task should record an error message"
    assert "product_name" in error


def test_run_task_unknown_type_marks_failed(tmp_db: Path):
    """Running a task under an unrecognised type marks it ``'failed'``."""
    # The stored row is a valid expiry_llm_fallback task; only the type
    # handed to run_task is bogus, so this exercises run_task's dispatch.
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
    run_task(tmp_db, task_id, "unknown_type", 1, None)

    query = "SELECT status FROM background_tasks WHERE id=?"
    conn = sqlite3.connect(tmp_db)
    status = conn.execute(query, (task_id,)).fetchone()[0]
    conn.close()
    assert status == "failed"