avocet/tests/test_imitate.py
pyr0ball 3299c0e23a feat: Imitate tab — pull CF product samples, compare LLM responses
Backend (app/imitate.py):
- GET /api/imitate/products — reads imitate: config, checks online status
- GET /api/imitate/products/{id}/sample — fetches real item from product API
- GET /api/imitate/run (SSE) — streams ollama responses for selected models
- POST /api/imitate/push-corrections — queues results in SFT corrections JSONL

Frontend (ImitateView.vue):
- Step 1: product picker grid (online/offline status, icon from config)
- Step 2: raw sample preview + editable prompt textarea
- Step 3: ollama model multi-select, temperature slider, SSE run with live log
- Step 4: response cards side by side, push to Corrections button

Wiring:
- app/api.py: include imitate_router at /api/imitate
- web/src/router: /imitate route + lazy import
- AppSidebar: Imitate nav entry (mirror icon)
- config/label_tool.yaml.example: imitate: section with peregrine example
- 16 unit tests (100% passing)

Also: BenchmarkView.vue Compare panel — side-by-side run diff for bench results
2026-04-09 20:12:57 -07:00

242 lines
9 KiB
Python

"""Tests for app/imitate.py — product registry, sample extraction, corrections push."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.api import app
from app import imitate as _imitate_module
# ── Fixtures ───────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def reset_module_globals(tmp_path):
    """Restore the imitate module's directory globals once each test finishes.

    The values of ``_CONFIG_DIR`` and ``_DATA_DIR`` are captured before the
    test body runs and written back afterwards, so overrides made during one
    test (presumably via ``set_config_dir`` / ``set_data_dir``) do not leak
    into the next one.
    """
    guarded_attrs = ("_CONFIG_DIR", "_DATA_DIR")
    snapshot = {attr: getattr(_imitate_module, attr) for attr in guarded_attrs}
    yield
    for attr in guarded_attrs:
        setattr(_imitate_module, attr, snapshot[attr])
@pytest.fixture()
def config_dir(tmp_path) -> Path:
    """Register the per-test temp directory as the imitate config dir.

    Returns the same directory so tests can create files inside it.
    """
    directory = tmp_path
    _imitate_module.set_config_dir(directory)
    return directory
@pytest.fixture()
def data_dir(tmp_path) -> Path:
    """Register the per-test temp directory as the imitate data dir.

    Returns the same directory so tests can inspect files written to it.
    """
    directory = tmp_path
    _imitate_module.set_data_dir(directory)
    return directory
@pytest.fixture()
def cfg_with_products(config_dir: Path) -> Path:
    """Write a label_tool.yaml with two products (``peregrine`` and ``kiwi``).

    Returns:
        The config directory that now contains ``label_tool.yaml``.
    """
    # NOTE(review): nesting reconstructed so that ``ollama_url`` and
    # ``products`` live under the ``imitate:`` section and each product's
    # fields belong to its list item — confirm against
    # config/label_tool.yaml.example.
    yaml_text = """
imitate:
  ollama_url: http://localhost:11434
  products:
    - id: peregrine
      name: Peregrine
      icon: "🦅"
      description: Job search assistant
      base_url: http://peregrine.local
      sample_endpoint: /api/jobs
      text_fields: [title, description]
      prompt_template: "Analyze: {text}"
    - id: kiwi
      name: Kiwi
      icon: "🥝"
      description: Pantry tracker
      base_url: http://kiwi.local
      sample_endpoint: /api/inventory
      text_fields: [name, notes]
      prompt_template: "Describe: {text}"
"""
    # Explicit UTF-8: the icon values are emoji, which the platform default
    # encoding (e.g. cp1252 on Windows) cannot encode.
    (config_dir / "label_tool.yaml").write_text(yaml_text, encoding="utf-8")
    return config_dir
@pytest.fixture()
def client() -> TestClient:
    """Build a TestClient for the app, constructed with ``raise_server_exceptions=True``."""
    test_client = TestClient(app, raise_server_exceptions=True)
    return test_client
# ── GET /products ──────────────────────────────────────────────────────────────
def test_products_empty_when_no_config(config_dir, client):
    """A label_tool.yaml without an imitate section yields an empty product list."""
    yaml_path = config_dir / "label_tool.yaml"
    yaml_path.write_text("accounts: []\n")

    response = client.get("/api/imitate/products")

    assert response.status_code == 200
    body = response.json()
    assert body["products"] == []
def test_products_listed(cfg_with_products, client):
    """Every configured product is returned, carrying the expected fields."""
    online_stub = patch.object(_imitate_module, "_is_online", return_value=True)
    with online_stub:
        response = client.get("/api/imitate/products")

    assert response.status_code == 200
    listed = response.json()["products"]
    assert len(listed) == 2
    assert {entry["id"] for entry in listed} == {"peregrine", "kiwi"}

    # Spot-check the fields of one product in detail.
    peregrine = next(filter(lambda entry: entry["id"] == "peregrine", listed))
    assert peregrine["name"] == "Peregrine"
    assert peregrine["icon"] == "🦅"
    assert peregrine["online"] is True
def test_products_offline_when_unreachable(cfg_with_products, client):
    """Products with unreachable base_url are marked offline."""
    with patch.object(_imitate_module, "_is_online", return_value=False):
        resp = client.get("/api/imitate/products")
    assert resp.status_code == 200
    products = resp.json()["products"]
    # all() over an empty list is True, so make sure both configured products
    # were actually returned before checking their status.
    assert len(products) == 2
    assert all(not p["online"] for p in products)
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
def test_sample_unknown_product(cfg_with_products, client):
    """Requesting a sample for an unconfigured product id gives 404."""
    url = "/api/imitate/products/nonexistent/sample"
    response = client.get(url)
    assert response.status_code == 404
def test_sample_fetched_from_list(cfg_with_products, client):
    """When the product API returns a list, the first item becomes the sample."""
    first_item = {"title": "Engineer", "description": "Build things"}
    second_item = {"title": "Other", "description": "Ignore me"}
    upstream_payload = [first_item, second_item]

    fetch_stub = patch.object(
        _imitate_module, "_http_get_json", return_value=upstream_payload
    )
    with fetch_stub:
        response = client.get("/api/imitate/products/peregrine/sample")

    assert response.status_code == 200
    sample = response.json()
    assert "Engineer" in sample["text"]
    assert "Build things" in sample["text"]
    # The peregrine prompt_template starts with "Analyze:".
    assert "Analyze:" in sample["prompt"]
def test_sample_fetched_from_dict_with_items_key(cfg_with_products, client):
    """A wrapper dict holding its records under ``items`` is unwrapped."""
    wrapped_item = {"title": "Wrapped Job", "description": "In a wrapper"}
    upstream_payload = {"items": [wrapped_item]}

    fetch_stub = patch.object(
        _imitate_module, "_http_get_json", return_value=upstream_payload
    )
    with fetch_stub:
        response = client.get("/api/imitate/products/peregrine/sample")

    assert response.status_code == 200
    sample_text = response.json()["text"]
    assert "Wrapped Job" in sample_text
def test_sample_503_when_api_unreachable(cfg_with_products, client):
    """A URLError raised while fetching from the product API maps to 503."""
    from urllib.error import URLError

    unreachable = URLError("refused")
    failing_fetch = patch.object(
        _imitate_module, "_http_get_json", side_effect=unreachable
    )
    with failing_fetch:
        response = client.get("/api/imitate/products/peregrine/sample")

    assert response.status_code == 503
def test_sample_404_on_empty_list(cfg_with_products, client):
    """An empty list from the product API results in 404."""
    empty_fetch = patch.object(_imitate_module, "_http_get_json", return_value=[])
    with empty_fetch:
        response = client.get("/api/imitate/products/peregrine/sample")

    assert response.status_code == 404
# ── POST /push-corrections ─────────────────────────────────────────────────────
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
    """A successful push writes one record per result to sft_candidates.jsonl."""
    model_results = [
        {"model": "qwen2.5:0.5b", "response": "It's a good job.", "elapsed_ms": 800, "error": None},
        {"model": "llama3.1:8b", "response": "Strong candidate.", "elapsed_ms": 1500, "error": None},
    ]
    request_body = {
        "product_id": "peregrine",
        "prompt": "Analyze this job:",
        "results": model_results,
    }

    response = client.post("/api/imitate/push-corrections", json=request_body)

    assert response.status_code == 200
    assert response.json()["pushed"] == 2

    jsonl_path = data_dir / "sft_candidates.jsonl"
    lines = jsonl_path.read_text().splitlines()
    assert len(lines) == 2

    # Each line is one JSON record tagged with its origin and review status.
    for raw_line in lines:
        entry = json.loads(raw_line)
        assert entry["source"] == "imitate"
        assert entry["product_id"] == "peregrine"
        assert entry["status"] == "pending"
        assert entry["prompt_messages"][0]["role"] == "user"
def test_push_corrections_skips_errors(cfg_with_products, data_dir, client):
    """Results with errors are not written to the corrections file."""
    payload = {
        "product_id": "peregrine",
        "prompt": "Analyze:",
        "results": [
            {"model": "good-model", "response": "Good answer.", "elapsed_ms": 500, "error": None},
            {"model": "bad-model", "response": "", "elapsed_ms": 0, "error": "connection refused"},
        ],
    }
    resp = client.post("/api/imitate/push-corrections", json=payload)
    assert resp.status_code == 200
    assert resp.json()["pushed"] == 1

    # The reported count alone does not prove what reached disk: check that
    # exactly one record was written and the failed model left no trace.
    written = (data_dir / "sft_candidates.jsonl").read_text()
    assert len(written.splitlines()) == 1
    assert "bad-model" not in written
def test_push_corrections_empty_prompt_422(cfg_with_products, data_dir, client):
    """A whitespace-only prompt is rejected with 422."""
    single_result = {"model": "m", "response": "r", "elapsed_ms": 1, "error": None}
    request_body = {
        "product_id": "peregrine",
        "prompt": " ",
        "results": [single_result],
    }

    response = client.post("/api/imitate/push-corrections", json=request_body)

    assert response.status_code == 422
def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
    """If every result carries an error there is nothing to push, so expect 422."""
    failed_result = {"model": "m", "response": "", "elapsed_ms": 0, "error": "timed out"}
    request_body = {
        "product_id": "peregrine",
        "prompt": "Analyze:",
        "results": [failed_result],
    }

    response = client.post("/api/imitate/push-corrections", json=request_body)

    assert response.status_code == 422
# ── _extract_sample helper ─────────────────────────────────────────────────────
def test_extract_sample_list():
    """The sample text contains the value of every requested field."""
    records = [{"title": "A", "description": "B"}]
    sample = _imitate_module._extract_sample(
        records,
        text_fields=["title", "description"],
    )
    text = sample["text"]
    assert "A" in text
    assert "B" in text
def test_extract_sample_empty_list():
    """An empty input list produces an empty dict."""
    no_records = []
    sample = _imitate_module._extract_sample(no_records, text_fields=["title"])
    assert sample == {}
def test_extract_sample_respects_index():
    """``sample_index`` selects which item the sample is built from."""
    records = [{"title": "First"}, {"title": "Second"}]
    sample = _imitate_module._extract_sample(records, ["title"], sample_index=1)
    text = sample["text"]
    assert "Second" in text
def test_extract_sample_clamps_index():
    """An out-of-range ``sample_index`` still yields the only available item."""
    records = [{"title": "Only"}]
    sample = _imitate_module._extract_sample(records, ["title"], sample_index=99)
    text = sample["text"]
    assert "Only" in text