Uses circuitforge_core.tasks.scheduler. VRAM detection via cf-orch when available, falling back to unlimited. Adds expiry_llm_fallback task type to background-predict expiry dates for items the LUT doesn't cover.
126 lines
4.4 KiB
Python
126 lines
4.4 KiB
Python
"""Tests for kiwi background task runner."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.tasks.runner import (
|
|
LLM_TASK_TYPES,
|
|
VRAM_BUDGETS,
|
|
insert_task,
|
|
run_task,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
    """Create a throwaway SQLite database for the task-runner tests.

    The file is written to ``tmp_path / "kiwi.db"`` and holds two tables:

    * ``background_tasks`` -- created empty.
    * ``inventory_items`` -- seeded with one row, ``'mystery tofu'``, whose
      ``expiry_date`` is NULL.  As the first row of a fresh AUTOINCREMENT
      table it receives ``id`` 1.

    Returns:
        Path to the database file.
    """
    db = tmp_path / "kiwi.db"
    conn = sqlite3.connect(db)
    try:
        conn.executescript("""
            CREATE TABLE background_tasks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                task_type TEXT NOT NULL,
                job_id INTEGER NOT NULL DEFAULT 0,
                status TEXT NOT NULL DEFAULT 'queued',
                params TEXT,
                error TEXT,
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE inventory_items (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                expiry_date TEXT
            );
            INSERT INTO inventory_items (name, expiry_date) VALUES ('mystery tofu', NULL);
        """)
        conn.commit()
    finally:
        # Close the handle even when schema creation fails, so a broken
        # fixture does not leak an open connection to the temp file.
        conn.close()
    return db
|
|
|
|
|
|
def test_llm_task_types_defined():
    """``expiry_llm_fallback`` must be a member of ``LLM_TASK_TYPES``."""
    task_type = "expiry_llm_fallback"
    assert task_type in LLM_TASK_TYPES
|
|
|
|
|
|
def test_vram_budgets_defined():
    """``VRAM_BUDGETS`` must hold a positive entry for ``expiry_llm_fallback``."""
    task_type = "expiry_llm_fallback"
    assert task_type in VRAM_BUDGETS
    budget = VRAM_BUDGETS[task_type]
    assert budget > 0
|
|
|
|
|
|
def test_insert_task_creates_row(tmp_db: Path):
    """A first insert reports a new task and stores it with status ``queued``."""
    new_id, created = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
    assert created is True

    # Read the row back directly to check what was persisted.
    query = "SELECT status FROM background_tasks WHERE id=?"
    connection = sqlite3.connect(tmp_db)
    stored = connection.execute(query, (new_id,)).fetchone()
    connection.close()

    assert stored[0] == "queued"
|
|
|
|
|
|
def test_insert_task_dedup(tmp_db: Path):
    """Inserting the same task type and job twice yields one task id.

    Only the first call may report that a new row was created.
    """
    task_type = "expiry_llm_fallback"
    first_id, first_is_new = insert_task(tmp_db, task_type, job_id=1)
    second_id, second_is_new = insert_task(tmp_db, task_type, job_id=1)

    assert first_id == second_id
    assert first_is_new is True
    assert second_is_new is False
|
|
|
|
|
|
def test_run_task_expiry_success(tmp_db: Path):
    """A successful LLM prediction sets the expiry date and completes the task.

    ``ExpirationPredictor`` is patched in ``app.tasks.runner`` so that
    ``_llm_predict_days`` returns 7.  Afterwards inventory item 1 must have a
    non-NULL ``expiry_date`` and the task status must be ``completed``.
    """
    payload = {"product_name": "Tofu", "category": "protein", "location": "fridge"}
    params = json.dumps(payload)
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)

    with patch("app.tasks.runner.ExpirationPredictor") as predictor_cls:
        # return_value is the object handed back when the patched class is called.
        predictor_cls.return_value._llm_predict_days.return_value = 7
        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)

    connection = sqlite3.connect(tmp_db)
    item_row = connection.execute(
        "SELECT expiry_date FROM inventory_items WHERE id=1"
    ).fetchone()
    task_row = connection.execute(
        "SELECT status FROM background_tasks WHERE id=?", (task_id,)
    ).fetchone()
    connection.close()

    assert item_row[0] is not None, "expiry_date should be set"
    assert task_row[0] == "completed"
|
|
|
|
|
|
def test_run_task_expiry_llm_returns_none(tmp_db: Path):
    """A ``None`` prediction leaves ``expiry_date`` NULL but still completes.

    ``ExpirationPredictor`` is patched in ``app.tasks.runner`` so that
    ``_llm_predict_days`` returns ``None``.
    """
    payload = {"product_name": "Unknown widget", "location": "fridge"}
    params = json.dumps(payload)
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1, params=params)

    with patch("app.tasks.runner.ExpirationPredictor") as predictor_cls:
        predictor_cls.return_value._llm_predict_days.return_value = None
        run_task(tmp_db, task_id, "expiry_llm_fallback", 1, params)

    connection = sqlite3.connect(tmp_db)
    item_row = connection.execute(
        "SELECT expiry_date FROM inventory_items WHERE id=1"
    ).fetchone()
    task_row = connection.execute(
        "SELECT status FROM background_tasks WHERE id=?", (task_id,)
    ).fetchone()
    connection.close()

    assert item_row[0] is None, "expiry_date should remain NULL when LLM returns None"
    assert task_row[0] == "completed"
|
|
|
|
|
|
def test_run_task_missing_product_name_marks_failed(tmp_db: Path):
    """Empty params mark the task ``failed`` with an error naming ``product_name``."""
    empty_params = json.dumps({})
    task_id, _ = insert_task(
        tmp_db, "expiry_llm_fallback", job_id=1, params=empty_params
    )
    run_task(tmp_db, task_id, "expiry_llm_fallback", 1, empty_params)

    query = "SELECT status, error FROM background_tasks WHERE id=?"
    connection = sqlite3.connect(tmp_db)
    task_row = connection.execute(query, (task_id,)).fetchone()
    connection.close()

    assert task_row[0] == "failed"
    assert "product_name" in task_row[1]
|
|
|
|
|
|
def test_run_task_unknown_type_marks_failed(tmp_db: Path):
    """Running a queued task under an unrecognised type marks it ``failed``."""
    task_id, _ = insert_task(tmp_db, "expiry_llm_fallback", job_id=1)
    # The row was inserted as a known type; only the type passed to run_task is bogus.
    run_task(tmp_db, task_id, "unknown_type", 1, None)

    query = "SELECT status FROM background_tasks WHERE id=?"
    connection = sqlite3.connect(tmp_db)
    task_row = connection.execute(query, (task_id,)).fetchone()
    connection.close()

    assert task_row[0] == "failed"
|