avocet/tests/test_models.py
pyr0ball b6b3d2c390 feat: HuggingFace model management tab
- New /api/models router: HF lookup, approval queue (JSONL persistence),
  SSE download progress via snapshot_download(), installed model listing,
  path-traversal-safe DELETE
- pipeline_tag → adapter type mapping (zero-shot-classification,
  sentence-similarity, text-generation)
- 27 tests covering all endpoints, duplicate detection, path traversal
- ModelsView.vue: HF lookup + add, approval queue, live download progress
  bars via SSE, installed model table with delete
- Sidebar entry (🤗 Models) between Benchmark and Corrections
2026-04-08 22:32:35 -07:00

399 lines
15 KiB
Python

"""Tests for app/models.py — /api/models/* endpoints."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
# ── Fixtures ───────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def reset_models_globals(tmp_path):
    """Redirect module-level dirs to tmp_path and reset download progress.

    Each test gets fresh, empty ``models`` and ``data`` directories under
    ``tmp_path`` and an empty ``_download_progress`` dict. The previous values
    are put back inside ``finally``. Even if the redirection fails part-way
    (e.g. ``set_queue_dir`` raising after ``set_models_dir`` succeeded), the
    module is left as it was found instead of pointing at a temp dir.
    """
    from app import models as models_module

    prev_models = models_module._MODELS_DIR
    prev_queue = models_module._QUEUE_DIR
    prev_progress = dict(models_module._download_progress)

    models_dir = tmp_path / "models"
    queue_dir = tmp_path / "data"
    models_dir.mkdir()
    queue_dir.mkdir()

    try:
        models_module.set_models_dir(models_dir)
        models_module.set_queue_dir(queue_dir)
        models_module._download_progress = {}
        yield
    finally:
        models_module.set_models_dir(prev_models)
        models_module.set_queue_dir(prev_queue)
        models_module._download_progress = prev_progress
@pytest.fixture
def client():
    """Provide a TestClient wired to the FastAPI application in ``app.api``."""
    from app.api import app as fastapi_app

    test_client = TestClient(fastapi_app)
    return test_client
def _make_hf_response(repo_id: str = "org/model", pipeline_tag: str = "text-classification") -> dict:
"""Minimal HF API response payload."""
return {
"modelId": repo_id,
"pipeline_tag": pipeline_tag,
"tags": ["pytorch", pipeline_tag],
"downloads": 42000,
"siblings": [
{"rfilename": "pytorch_model.bin", "size": 500_000_000},
],
"cardData": {"description": "A test model description."},
}
def _queue_one(client, repo_id: str = "org/model") -> dict:
    """Queue *repo_id* via POST /api/models/queue and return the created entry.

    Fails the calling test (showing the response body) unless the server
    answers 201.
    """
    payload = {
        "repo_id": repo_id,
        "pipeline_tag": "text-classification",
        "adapter_recommendation": "ZeroShotAdapter",
    }
    response = client.post("/api/models/queue", json=payload)
    assert response.status_code == 201, response.text
    return response.json()
# ── GET /lookup ────────────────────────────────────────────────────────────────
def test_lookup_invalid_repo_id_returns_422_no_slash(client):
    """A repo_id without the 'org/name' separator is refused with 422."""
    bad_id = "noslash"
    response = client.get("/api/models/lookup", params={"repo_id": bad_id})
    assert response.status_code == 422
def test_lookup_invalid_repo_id_returns_422_whitespace(client):
    """A repo_id containing a space is refused with 422."""
    bad_id = "org/model name"
    response = client.get("/api/models/lookup", params={"repo_id": bad_id})
    assert response.status_code == 422
def test_lookup_hf_404_returns_404(client):
    """An upstream HF 404 is passed through to the caller as HTTP 404."""
    not_found = MagicMock(status_code=404)
    with patch("app.models.httpx.get", return_value=not_found):
        response = client.get("/api/models/lookup", params={"repo_id": "org/nonexistent"})
    assert response.status_code == 404
def test_lookup_hf_network_error_returns_502(client):
    """A transport failure while contacting HF is reported as 502 Bad Gateway."""
    import httpx

    failure = httpx.RequestError("timeout")
    with patch("app.models.httpx.get", side_effect=failure):
        response = client.get("/api/models/lookup", params={"repo_id": "org/model"})
    assert response.status_code == 502
def test_lookup_returns_correct_shape(client):
    """A successful lookup exposes every field the UI relies on."""
    upstream = MagicMock(status_code=200)
    upstream.json.return_value = _make_hf_response("org/mymodel", "text-classification")
    with patch("app.models.httpx.get", return_value=upstream):
        response = client.get("/api/models/lookup", params={"repo_id": "org/mymodel"})
    assert response.status_code == 200

    body = response.json()
    expected = {
        "repo_id": "org/mymodel",
        "pipeline_tag": "text-classification",
        "adapter_recommendation": "ZeroShotAdapter",
        "model_size_bytes": 500_000_000,
        "downloads": 42000,
    }
    for field, value in expected.items():
        assert body[field] == value, field
    # The flags must be real booleans, not merely falsy values.
    assert body["already_installed"] is False
    assert body["already_queued"] is False
def test_lookup_unknown_pipeline_tag_returns_null_adapter(client):
    """A pipeline_tag with no adapter mapping comes back as adapter_recommendation=null."""
    upstream = MagicMock(status_code=200)
    upstream.json.return_value = _make_hf_response("org/m", "audio-classification")
    with patch("app.models.httpx.get", return_value=upstream):
        response = client.get("/api/models/lookup", params={"repo_id": "org/m"})
    assert response.status_code == 200
    recommendation = response.json()["adapter_recommendation"]
    assert recommendation is None
def test_lookup_already_queued_flag(client):
    """Lookup reports already_queued=True for a repo_id sitting in the pending queue."""
    repo_id = "org/queued-model"
    _queue_one(client, repo_id)

    upstream = MagicMock(status_code=200)
    upstream.json.return_value = _make_hf_response(repo_id)
    with patch("app.models.httpx.get", return_value=upstream):
        response = client.get("/api/models/lookup", params={"repo_id": repo_id})
    assert response.status_code == 200
    assert response.json()["already_queued"] is True
# ── GET /queue ─────────────────────────────────────────────────────────────────
def test_queue_empty_initially(client):
    """With nothing queued, GET /queue yields an empty JSON list."""
    response = client.get("/api/models/queue")
    assert response.status_code == 200
    assert response.json() == []
def test_queue_add_and_list(client):
    """An entry created via POST /queue is listed by GET /queue as pending."""
    created = _queue_one(client, "org/my-model")

    response = client.get("/api/models/queue")
    assert response.status_code == 200
    queued = response.json()
    assert len(queued) == 1

    first = queued[0]
    assert first["repo_id"] == "org/my-model"
    assert first["status"] == "pending"
    assert first["id"] == created["id"]
def test_queue_add_returns_entry_fields(client):
    """The body returned by POST /queue carries the complete entry record."""
    created = _queue_one(client)
    for key in ("id", "queued_at"):
        assert key in created
    assert created["status"] == "pending"
    assert created["pipeline_tag"] == "text-classification"
    assert created["adapter_recommendation"] == "ZeroShotAdapter"
# ── POST /queue — 409 duplicate ────────────────────────────────────────────────
def test_queue_duplicate_returns_409(client):
    """Queueing a repo_id that is already pending is refused with 409."""
    repo_id = "org/dup-model"
    _queue_one(client, repo_id)

    duplicate = {
        "repo_id": repo_id,
        "pipeline_tag": "text-classification",
        "adapter_recommendation": "ZeroShotAdapter",
    }
    response = client.post("/api/models/queue", json=duplicate)
    assert response.status_code == 409
def test_queue_multiple_different_models(client):
    """Distinct repo_ids are all accepted and all listed."""
    for suffix in ("a", "b", "c"):
        _queue_one(client, f"org/model-{suffix}")

    response = client.get("/api/models/queue")
    assert response.status_code == 200
    assert len(response.json()) == 3
# ── DELETE /queue/{id} — dismiss ──────────────────────────────────────────────
def test_queue_dismiss(client):
    """DELETE /queue/{id} dismisses the entry so GET /queue no longer returns it."""
    entry_id = _queue_one(client)["id"]

    dismissed = client.delete(f"/api/models/queue/{entry_id}")
    assert dismissed.status_code == 200
    assert dismissed.json() == {"ok": True}

    listing = client.get("/api/models/queue")
    assert listing.status_code == 200
    assert listing.json() == []
def test_queue_dismiss_nonexistent_returns_404(client):
    """Dismissing an id that was never queued yields 404."""
    unknown_id = "does-not-exist"
    response = client.delete(f"/api/models/queue/{unknown_id}")
    assert response.status_code == 404
def test_queue_dismiss_allows_re_queue(client):
    """After dismissal the same repo_id can be queued again."""
    entry = _queue_one(client, "org/requeue-model")

    # Verify the dismissal itself succeeded. Otherwise a broken DELETE would
    # surface below as a confusing 409 on the re-queue rather than here.
    dismissed = client.delete(f"/api/models/queue/{entry['id']}")
    assert dismissed.status_code == 200, dismissed.text

    r = client.post("/api/models/queue", json={
        "repo_id": "org/requeue-model",
        "pipeline_tag": None,
        "adapter_recommendation": None,
    })
    assert r.status_code == 201, r.text
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
def test_approve_nonexistent_returns_404(client):
    """Approving an id that is not in the queue yields 404."""
    unknown_id = "ghost-id"
    response = client.post(f"/api/models/queue/{unknown_id}/approve")
    assert response.status_code == 404
def test_approve_non_pending_returns_409(client):
    """Only 'pending' entries may be approved; any other status yields 409."""
    from app import models as models_module

    entry_id = _queue_one(client)["id"]
    # Push the entry out of 'pending' by writing its status directly.
    models_module._update_queue_entry(entry_id, {"status": "failed"})

    response = client.post(f"/api/models/queue/{entry_id}/approve")
    assert response.status_code == 409
def test_approve_starts_download_and_returns_ok(client):
    """Approving a pending entry returns {ok: true} and starts a background thread.

    ``snapshot_download`` is patched so the worker never touches the network.
    The test does not sleep for a fixed time. Instead it joins (with a bounded
    timeout) every thread that appeared during the approve call, so the worker
    cannot outlive the patch or write into the directories the autouse fixture
    restores after the test.
    """
    import threading

    from app import models as models_module

    entry = _queue_one(client)
    threads_before = set(threading.enumerate())

    # return_value=None gives a no-op download that accepts any call
    # signature (positional or keyword arguments).
    with patch("app.models.snapshot_download", return_value=None):
        r = client.post(f"/api/models/queue/{entry['id']}/approve")
        assert r.status_code == 200
        assert r.json() == {"ok": True}

        # Wait for the download worker (and any other thread the request
        # spawned) while snapshot_download is still patched.
        for worker in set(threading.enumerate()) - threads_before:
            worker.join(timeout=5)

    updated = models_module._get_queue_entry(entry["id"])
    assert updated is not None, "Queue entry not found — thread may have run after fixture teardown"
    # Status should have moved to 'downloading' (or 'ready'/'failed' once the worker finished).
    assert updated["status"] in ("downloading", "ready", "failed")
# ── GET /download/stream ───────────────────────────────────────────────────────
def test_download_stream_idle_when_no_download(client):
    """With no active download, the SSE stream emits an idle event."""
    response = client.get("/api/models/download/stream")
    assert response.status_code == 200
    body = response.text
    assert "idle" in body
# ── GET /installed ─────────────────────────────────────────────────────────────
def test_installed_empty(client):
    """An empty models directory lists no installed models."""
    response = client.get("/api/models/installed")
    assert response.status_code == 200
    assert response.json() == []
def test_installed_detects_downloaded_model(client, tmp_path):
    """A directory holding config.json is listed with type='downloaded'.

    model_info.json supplies the adapter and original repo id shown in the listing.
    """
    from app import models as models_module

    target = models_module._MODELS_DIR / "org--mymodel"
    target.mkdir()
    metadata_files = {
        "config.json": {"model_type": "bert"},
        "model_info.json": {"repo_id": "org/mymodel", "adapter_recommendation": "ZeroShotAdapter"},
    }
    for filename, content in metadata_files.items():
        (target / filename).write_text(json.dumps(content), encoding="utf-8")

    response = client.get("/api/models/installed")
    assert response.status_code == 200
    listed = response.json()
    assert len(listed) == 1

    model = listed[0]
    assert model["type"] == "downloaded"
    assert model["name"] == "org--mymodel"
    assert model["adapter"] == "ZeroShotAdapter"
    assert model["model_id"] == "org/mymodel"
def test_installed_detects_finetuned_model(client):
    """A directory holding training_info.json is listed with type='finetuned'."""
    from app import models as models_module

    finetuned_dir = models_module._MODELS_DIR / "my-finetuned"
    finetuned_dir.mkdir()
    training_info = {"base_model": "org/base", "epochs": 5}
    (finetuned_dir / "training_info.json").write_text(json.dumps(training_info), encoding="utf-8")

    response = client.get("/api/models/installed")
    assert response.status_code == 200
    listed = response.json()
    assert len(listed) == 1

    model = listed[0]
    assert model["type"] == "finetuned"
    assert model["name"] == "my-finetuned"
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
def test_delete_installed_removes_directory(client):
    """DELETE /installed/{name} wipes the model directory and answers {ok: true}."""
    from app import models as models_module

    doomed = models_module._MODELS_DIR / "org--removeme"
    doomed.mkdir()
    (doomed / "config.json").write_text("{}", encoding="utf-8")

    response = client.delete(f"/api/models/installed/{doomed.name}")
    assert response.status_code == 200
    assert response.json() == {"ok": True}
    assert not doomed.exists()
def test_delete_installed_not_found_returns_404(client):
    """Deleting a model directory that does not exist yields 404."""
    missing = "does-not-exist"
    response = client.delete(f"/api/models/installed/{missing}")
    assert response.status_code == 404
def test_delete_installed_path_traversal_blocked(client):
    """DELETE /installed/ with '..' segments is rejected and deletes nothing.

    The test client may normalise literal '..' segments before sending, and
    the router may refuse a decoded '/', so these requests might be answered
    by routing (404) rather than by the handler (400/422). Either way the
    status must be an error. A directory beside the models dir, exactly the
    kind a successful traversal would target, must survive.
    """
    from app import models as models_module

    r = client.delete("/api/models/installed/../../etc")
    assert r.status_code in (400, 404, 422)

    # Sentinel sits next to the models dir, i.e. at "<models>/../traversal-sentinel".
    sentinel = models_module._MODELS_DIR.parent / "traversal-sentinel"
    sentinel.mkdir()
    r = client.delete("/api/models/installed/..%2Ftraversal-sentinel")
    assert r.status_code in (400, 404, 422)
    assert sentinel.is_dir(), "path traversal removed a directory outside the models dir"
def test_delete_installed_dotdot_name_blocked(client):
    """Percent-encoded '..' traversal in the name is refused as well."""
    encoded_traversal = "..%2F..%2Fetc"
    response = client.delete(f"/api/models/installed/{encoded_traversal}")
    assert response.status_code in (400, 404, 422)
def test_delete_installed_name_with_slash_blocked(client):
    """A name containing a literal '/' after URL decoding must be rejected.

    The router parses a second '/' as a new path segment, so such a name
    cannot reach the handler over HTTP. The test calls the handler directly.
    Only an HTTPException is accepted, so a crash with any other exception
    type, or a wrong status code, fails the test instead of being swallowed.
    """
    from fastapi import HTTPException
    from app.models import delete_installed

    with pytest.raises(HTTPException) as exc_info:
        delete_installed("org/traversal")
    assert exc_info.value.status_code in (400, 404)