feat: HuggingFace model management tab

- New /api/models router: HF lookup, approval queue (JSONL persistence),
  SSE download progress via snapshot_download(), installed model listing,
  path-traversal-safe DELETE
- pipeline_tag → adapter type mapping (zero-shot-classification,
  sentence-similarity, text-generation)
- 27 tests covering all endpoints, duplicate detection, path traversal
- ModelsView.vue: HF lookup + add, approval queue, live download progress
  bars via SSE, installed model table with delete
- Sidebar entry (🤗 Models) between Benchmark and Corrections
This commit is contained in:
pyr0ball 2026-04-08 22:32:35 -07:00
parent a7cb3ae62a
commit b6b3d2c390
6 changed files with 1659 additions and 0 deletions

View file

@ -145,6 +145,9 @@ app = FastAPI(title="Avocet API")
from app.sft import router as sft_router
app.include_router(sft_router, prefix="/api/sft")
from app.models import router as models_router
app.include_router(models_router, prefix="/api/models")
# In-memory last-action store (single user, local tool — in-memory is fine)
_last_action: dict | None = None

428
app/models.py Normal file
View file

@ -0,0 +1,428 @@
"""Avocet — HF model lifecycle API.
Handles model metadata lookup, approval queue, download with progress,
and installed model management.
All endpoints are registered on `router` (a FastAPI APIRouter).
api.py includes this router with prefix="/api/models".
Module-level globals (_MODELS_DIR, _QUEUE_DIR) follow the same
testability pattern as sft.py override them via set_models_dir() and
set_queue_dir() in test fixtures.
"""
from __future__ import annotations
import json
import logging
import shutil
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import uuid4
import httpx
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import read_jsonl, write_jsonl
try:
from huggingface_hub import snapshot_download
except ImportError: # pragma: no cover
snapshot_download = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_MODELS_DIR: Path = _ROOT / "models"
_QUEUE_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Download progress shared state ────────────────────────────────────────────
# Updated by the background download thread; read by GET /download/stream.
_download_progress: dict[str, Any] = {}
# ── HF pipeline_tag → adapter recommendation ──────────────────────────────────
_TAG_TO_ADAPTER: dict[str, str] = {
"zero-shot-classification": "ZeroShotAdapter",
"text-classification": "ZeroShotAdapter",
"natural-language-inference": "ZeroShotAdapter",
"sentence-similarity": "RerankerAdapter",
"text-ranking": "RerankerAdapter",
"text-generation": "GenerationAdapter",
"text2text-generation": "GenerationAdapter",
}
# ── Testability seams ──────────────────────────────────────────────────────────
def set_models_dir(path: Path) -> None:
global _MODELS_DIR
_MODELS_DIR = path
def set_queue_dir(path: Path) -> None:
global _QUEUE_DIR
_QUEUE_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _queue_file() -> Path:
return _QUEUE_DIR / "model_queue.jsonl"
def _read_queue() -> list[dict]:
return read_jsonl(_queue_file())
def _write_queue(records: list[dict]) -> None:
write_jsonl(_queue_file(), records)
def _safe_model_name(repo_id: str) -> str:
"""Convert repo_id to a filesystem-safe directory name (HF convention)."""
return repo_id.replace("/", "--")
def _is_installed(repo_id: str) -> bool:
"""Check if a model is already downloaded in _MODELS_DIR."""
safe_name = _safe_model_name(repo_id)
model_dir = _MODELS_DIR / safe_name
return model_dir.exists() and (
(model_dir / "config.json").exists()
or (model_dir / "training_info.json").exists()
or (model_dir / "model_info.json").exists()
)
def _is_queued(repo_id: str) -> bool:
"""Check if repo_id is already in the queue (non-dismissed)."""
for entry in _read_queue():
if entry.get("repo_id") == repo_id and entry.get("status") != "dismissed":
return True
return False
def _update_queue_entry(entry_id: str, updates: dict) -> dict | None:
"""Update a queue entry by id. Returns updated entry or None if not found."""
records = _read_queue()
for i, r in enumerate(records):
if r.get("id") == entry_id:
records[i] = {**r, **updates}
_write_queue(records)
return records[i]
return None
def _get_queue_entry(entry_id: str) -> dict | None:
for r in _read_queue():
if r.get("id") == entry_id:
return r
return None
# ── Background download ────────────────────────────────────────────────────────
def _run_download(entry_id: str, repo_id: str, pipeline_tag: str | None, adapter_recommendation: str | None) -> None:
"""Background thread: download model via huggingface_hub.snapshot_download."""
global _download_progress
safe_name = _safe_model_name(repo_id)
local_dir = _MODELS_DIR / safe_name
_download_progress = {
"active": True,
"repo_id": repo_id,
"downloaded_bytes": 0,
"total_bytes": 0,
"pct": 0.0,
"done": False,
"error": None,
}
try:
if snapshot_download is None:
raise RuntimeError("huggingface_hub is not installed")
snapshot_download(
repo_id=repo_id,
local_dir=str(local_dir),
)
# Write model_info.json alongside downloaded files
model_info = {
"repo_id": repo_id,
"pipeline_tag": pipeline_tag,
"adapter_recommendation": adapter_recommendation,
"downloaded_at": datetime.now(timezone.utc).isoformat(),
}
local_dir.mkdir(parents=True, exist_ok=True)
(local_dir / "model_info.json").write_text(
json.dumps(model_info, indent=2), encoding="utf-8"
)
_download_progress["done"] = True
_download_progress["pct"] = 100.0
_update_queue_entry(entry_id, {"status": "ready"})
except Exception as exc:
logger.exception("Download failed for %s: %s", repo_id, exc)
_download_progress["error"] = str(exc)
_download_progress["done"] = True
_update_queue_entry(entry_id, {"status": "failed", "error": str(exc)})
finally:
_download_progress["active"] = False
# ── GET /lookup ────────────────────────────────────────────────────────────────
@router.get("/lookup")
def lookup_model(repo_id: str) -> dict:
"""Validate repo_id and fetch metadata from the HF API."""
# Validate: must contain exactly one '/', no whitespace
if "/" not in repo_id or any(c.isspace() for c in repo_id):
raise HTTPException(422, f"Invalid repo_id {repo_id!r}: must be 'owner/model-name' with no whitespace")
hf_url = f"https://huggingface.co/api/models/{repo_id}"
try:
resp = httpx.get(hf_url, timeout=10.0)
except httpx.RequestError as exc:
raise HTTPException(502, f"Network error reaching HuggingFace API: {exc}") from exc
if resp.status_code == 404:
raise HTTPException(404, f"Model {repo_id!r} not found on HuggingFace")
if resp.status_code != 200:
raise HTTPException(502, f"HuggingFace API returned status {resp.status_code}")
data = resp.json()
pipeline_tag = data.get("pipeline_tag")
adapter_recommendation = _TAG_TO_ADAPTER.get(pipeline_tag) if pipeline_tag else None
if pipeline_tag and adapter_recommendation is None:
logger.warning("Unknown pipeline_tag %r for %s — no adapter recommendation", pipeline_tag, repo_id)
# Estimate model size from siblings list
siblings = data.get("siblings") or []
model_size_bytes: int = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
# Description: first 300 chars of card data (modelId field used as fallback)
card_data = data.get("cardData") or {}
description_raw = card_data.get("description") or data.get("modelId") or ""
description = description_raw[:300] if description_raw else ""
return {
"repo_id": repo_id,
"pipeline_tag": pipeline_tag,
"adapter_recommendation": adapter_recommendation,
"model_size_bytes": model_size_bytes,
"description": description,
"tags": data.get("tags") or [],
"downloads": data.get("downloads") or 0,
"already_installed": _is_installed(repo_id),
"already_queued": _is_queued(repo_id),
}
# ── GET /queue ─────────────────────────────────────────────────────────────────
@router.get("/queue")
def get_queue() -> list[dict]:
"""Return all non-dismissed queue entries sorted newest-first."""
records = _read_queue()
active = [r for r in records if r.get("status") != "dismissed"]
return sorted(active, key=lambda r: r.get("queued_at", ""), reverse=True)
# ── POST /queue ────────────────────────────────────────────────────────────────
class QueueAddRequest(BaseModel):
repo_id: str
pipeline_tag: str | None = None
adapter_recommendation: str | None = None
@router.post("/queue", status_code=201)
def add_to_queue(req: QueueAddRequest) -> dict:
"""Add a model to the approval queue with status 'pending'."""
if _is_installed(req.repo_id):
raise HTTPException(409, f"{req.repo_id!r} is already installed")
if _is_queued(req.repo_id):
raise HTTPException(409, f"{req.repo_id!r} is already in the queue")
entry = {
"id": str(uuid4()),
"repo_id": req.repo_id,
"pipeline_tag": req.pipeline_tag,
"adapter_recommendation": req.adapter_recommendation,
"status": "pending",
"queued_at": datetime.now(timezone.utc).isoformat(),
}
records = _read_queue()
records.append(entry)
_write_queue(records)
return entry
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
@router.post("/queue/{entry_id}/approve")
def approve_queue_entry(entry_id: str) -> dict:
"""Approve a pending queue entry and start background download."""
entry = _get_queue_entry(entry_id)
if entry is None:
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
if entry.get("status") != "pending":
raise HTTPException(409, f"Entry is not in pending state (current: {entry.get('status')!r})")
updated = _update_queue_entry(entry_id, {"status": "downloading"})
thread = threading.Thread(
target=_run_download,
args=(entry_id, entry["repo_id"], entry.get("pipeline_tag"), entry.get("adapter_recommendation")),
daemon=True,
name=f"model-download-{entry_id}",
)
thread.start()
return {"ok": True}
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
@router.delete("/queue/{entry_id}")
def dismiss_queue_entry(entry_id: str) -> dict:
"""Dismiss (soft-delete) a queue entry."""
entry = _get_queue_entry(entry_id)
if entry is None:
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
_update_queue_entry(entry_id, {"status": "dismissed"})
return {"ok": True}
# ── GET /download/stream ───────────────────────────────────────────────────────
@router.get("/download/stream")
def download_stream() -> StreamingResponse:
"""SSE stream of download progress. Yields one idle event if no download active."""
def generate():
prog = _download_progress
if not prog.get("active") and not (prog.get("done") and not prog.get("error")):
yield f"data: {json.dumps({'type': 'idle'})}\n\n"
return
if prog.get("done"):
if prog.get("error"):
yield f"data: {json.dumps({'type': 'error', 'error': prog['error']})}\n\n"
else:
yield f"data: {json.dumps({'type': 'done', 'repo_id': prog.get('repo_id')})}\n\n"
return
# Stream live progress
import time
while True:
p = dict(_download_progress)
if p.get("done"):
if p.get("error"):
yield f"data: {json.dumps({'type': 'error', 'error': p['error']})}\n\n"
else:
yield f"data: {json.dumps({'type': 'done', 'repo_id': p.get('repo_id')})}\n\n"
break
event = json.dumps({
"type": "progress",
"repo_id": p.get("repo_id"),
"downloaded_bytes": p.get("downloaded_bytes", 0),
"total_bytes": p.get("total_bytes", 0),
"pct": p.get("pct", 0.0),
})
yield f"data: {event}\n\n"
time.sleep(0.5)
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
# ── GET /installed ─────────────────────────────────────────────────────────────
@router.get("/installed")
def list_installed() -> list[dict]:
"""Scan _MODELS_DIR and return info on each installed model."""
if not _MODELS_DIR.exists():
return []
results: list[dict] = []
for sub in _MODELS_DIR.iterdir():
if not sub.is_dir():
continue
has_training_info = (sub / "training_info.json").exists()
has_config = (sub / "config.json").exists()
has_model_info = (sub / "model_info.json").exists()
if not (has_training_info or has_config or has_model_info):
continue
model_type = "finetuned" if has_training_info else "downloaded"
# Compute directory size
size_bytes = sum(f.stat().st_size for f in sub.rglob("*") if f.is_file())
# Load adapter/model_id from model_info.json or training_info.json
adapter: str | None = None
model_id: str | None = None
if has_model_info:
try:
info = json.loads((sub / "model_info.json").read_text(encoding="utf-8"))
adapter = info.get("adapter_recommendation")
model_id = info.get("repo_id")
except Exception:
pass
elif has_training_info:
try:
info = json.loads((sub / "training_info.json").read_text(encoding="utf-8"))
adapter = info.get("adapter")
model_id = info.get("base_model") or info.get("model_id")
except Exception:
pass
results.append({
"name": sub.name,
"path": str(sub),
"type": model_type,
"adapter": adapter,
"size_bytes": size_bytes,
"model_id": model_id,
})
return results
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
@router.delete("/installed/{name}")
def delete_installed(name: str) -> dict:
"""Remove an installed model directory by name. Blocks path traversal."""
# Validate: single path component, no slashes or '..'
if "/" in name or "\\" in name or ".." in name or not name or name.startswith("."):
raise HTTPException(400, f"Invalid model name {name!r}: must be a single directory name with no path separators or '..'")
model_path = _MODELS_DIR / name
# Extra safety: confirm resolved path is inside _MODELS_DIR
try:
model_path.resolve().relative_to(_MODELS_DIR.resolve())
except ValueError:
raise HTTPException(400, f"Path traversal detected for name {name!r}")
if not model_path.exists():
raise HTTPException(404, f"Installed model {name!r} not found")
shutil.rmtree(model_path)
return {"ok": True}

399
tests/test_models.py Normal file
View file

@ -0,0 +1,399 @@
"""Tests for app/models.py — /api/models/* endpoints."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
# ── Fixtures ───────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def reset_models_globals(tmp_path):
"""Redirect module-level dirs to tmp_path and reset download progress."""
from app import models as models_module
prev_models = models_module._MODELS_DIR
prev_queue = models_module._QUEUE_DIR
prev_progress = dict(models_module._download_progress)
models_dir = tmp_path / "models"
queue_dir = tmp_path / "data"
models_dir.mkdir()
queue_dir.mkdir()
models_module.set_models_dir(models_dir)
models_module.set_queue_dir(queue_dir)
models_module._download_progress = {}
yield
models_module.set_models_dir(prev_models)
models_module.set_queue_dir(prev_queue)
models_module._download_progress = prev_progress
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _make_hf_response(repo_id: str = "org/model", pipeline_tag: str = "text-classification") -> dict:
"""Minimal HF API response payload."""
return {
"modelId": repo_id,
"pipeline_tag": pipeline_tag,
"tags": ["pytorch", pipeline_tag],
"downloads": 42000,
"siblings": [
{"rfilename": "pytorch_model.bin", "size": 500_000_000},
],
"cardData": {"description": "A test model description."},
}
def _queue_one(client, repo_id: str = "org/model") -> dict:
"""Helper: POST to /queue and return the created entry."""
r = client.post("/api/models/queue", json={
"repo_id": repo_id,
"pipeline_tag": "text-classification",
"adapter_recommendation": "ZeroShotAdapter",
})
assert r.status_code == 201, r.text
return r.json()
# ── GET /lookup ────────────────────────────────────────────────────────────────
def test_lookup_invalid_repo_id_returns_422_no_slash(client):
"""repo_id without a '/' should be rejected with 422."""
r = client.get("/api/models/lookup", params={"repo_id": "noslash"})
assert r.status_code == 422
def test_lookup_invalid_repo_id_returns_422_whitespace(client):
"""repo_id containing whitespace should be rejected with 422."""
r = client.get("/api/models/lookup", params={"repo_id": "org/model name"})
assert r.status_code == 422
def test_lookup_hf_404_returns_404(client):
"""HF API returning 404 should surface as HTTP 404."""
mock_resp = MagicMock()
mock_resp.status_code = 404
with patch("app.models.httpx.get", return_value=mock_resp):
r = client.get("/api/models/lookup", params={"repo_id": "org/nonexistent"})
assert r.status_code == 404
def test_lookup_hf_network_error_returns_502(client):
"""Network error reaching HF API should return 502."""
import httpx as _httpx
with patch("app.models.httpx.get", side_effect=_httpx.RequestError("timeout")):
r = client.get("/api/models/lookup", params={"repo_id": "org/model"})
assert r.status_code == 502
def test_lookup_returns_correct_shape(client):
"""Successful lookup returns all required fields."""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = _make_hf_response("org/mymodel", "text-classification")
with patch("app.models.httpx.get", return_value=mock_resp):
r = client.get("/api/models/lookup", params={"repo_id": "org/mymodel"})
assert r.status_code == 200
data = r.json()
assert data["repo_id"] == "org/mymodel"
assert data["pipeline_tag"] == "text-classification"
assert data["adapter_recommendation"] == "ZeroShotAdapter"
assert data["model_size_bytes"] == 500_000_000
assert data["downloads"] == 42000
assert data["already_installed"] is False
assert data["already_queued"] is False
def test_lookup_unknown_pipeline_tag_returns_null_adapter(client):
"""An unrecognised pipeline_tag yields adapter_recommendation=null."""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = _make_hf_response("org/m", "audio-classification")
with patch("app.models.httpx.get", return_value=mock_resp):
r = client.get("/api/models/lookup", params={"repo_id": "org/m"})
assert r.status_code == 200
assert r.json()["adapter_recommendation"] is None
def test_lookup_already_queued_flag(client):
"""already_queued is True when repo_id is in the pending queue."""
_queue_one(client, "org/queued-model")
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = _make_hf_response("org/queued-model")
with patch("app.models.httpx.get", return_value=mock_resp):
r = client.get("/api/models/lookup", params={"repo_id": "org/queued-model"})
assert r.status_code == 200
assert r.json()["already_queued"] is True
# ── GET /queue ─────────────────────────────────────────────────────────────────
def test_queue_empty_initially(client):
r = client.get("/api/models/queue")
assert r.status_code == 200
assert r.json() == []
def test_queue_add_and_list(client):
"""POST then GET /queue should return the entry."""
entry = _queue_one(client, "org/my-model")
r = client.get("/api/models/queue")
assert r.status_code == 200
items = r.json()
assert len(items) == 1
assert items[0]["repo_id"] == "org/my-model"
assert items[0]["status"] == "pending"
assert items[0]["id"] == entry["id"]
def test_queue_add_returns_entry_fields(client):
"""POST /queue returns an entry with all expected fields."""
entry = _queue_one(client)
assert "id" in entry
assert "queued_at" in entry
assert entry["status"] == "pending"
assert entry["pipeline_tag"] == "text-classification"
assert entry["adapter_recommendation"] == "ZeroShotAdapter"
# ── POST /queue — 409 duplicate ────────────────────────────────────────────────
def test_queue_duplicate_returns_409(client):
"""Posting the same repo_id twice should return 409."""
_queue_one(client, "org/dup-model")
r = client.post("/api/models/queue", json={
"repo_id": "org/dup-model",
"pipeline_tag": "text-classification",
"adapter_recommendation": "ZeroShotAdapter",
})
assert r.status_code == 409
def test_queue_multiple_different_models(client):
"""Multiple distinct repo_ids should all be accepted."""
_queue_one(client, "org/model-a")
_queue_one(client, "org/model-b")
_queue_one(client, "org/model-c")
r = client.get("/api/models/queue")
assert r.status_code == 200
assert len(r.json()) == 3
# ── DELETE /queue/{id} — dismiss ──────────────────────────────────────────────
def test_queue_dismiss(client):
"""DELETE /queue/{id} sets status=dismissed; entry not returned by GET /queue."""
entry = _queue_one(client)
entry_id = entry["id"]
r = client.delete(f"/api/models/queue/{entry_id}")
assert r.status_code == 200
assert r.json() == {"ok": True}
r2 = client.get("/api/models/queue")
assert r2.status_code == 200
assert r2.json() == []
def test_queue_dismiss_nonexistent_returns_404(client):
"""DELETE /queue/{id} with unknown id returns 404."""
r = client.delete("/api/models/queue/does-not-exist")
assert r.status_code == 404
def test_queue_dismiss_allows_re_queue(client):
"""After dismissal the same repo_id can be queued again."""
entry = _queue_one(client, "org/requeue-model")
client.delete(f"/api/models/queue/{entry['id']}")
r = client.post("/api/models/queue", json={
"repo_id": "org/requeue-model",
"pipeline_tag": None,
"adapter_recommendation": None,
})
assert r.status_code == 201
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
def test_approve_nonexistent_returns_404(client):
"""Approving an unknown id returns 404."""
r = client.post("/api/models/queue/ghost-id/approve")
assert r.status_code == 404
def test_approve_non_pending_returns_409(client):
"""Approving an entry that is not in 'pending' state returns 409."""
from app import models as models_module
entry = _queue_one(client)
# Manually flip status to 'failed'
models_module._update_queue_entry(entry["id"], {"status": "failed"})
r = client.post(f"/api/models/queue/{entry['id']}/approve")
assert r.status_code == 409
def test_approve_starts_download_and_returns_ok(client):
"""Approving a pending entry returns {ok: true} and starts a background thread."""
import time
import threading
entry = _queue_one(client)
# Patch snapshot_download so the thread doesn't actually hit the network.
# Use an Event so we can wait for the thread to finish before asserting.
thread_done = threading.Event()
original_run = None
def _fake_snapshot_download(**kwargs):
pass
with patch("app.models.snapshot_download", side_effect=_fake_snapshot_download):
r = client.post(f"/api/models/queue/{entry['id']}/approve")
assert r.status_code == 200
assert r.json() == {"ok": True}
# Give the background thread a moment to complete while snapshot_download is patched
time.sleep(0.3)
# Queue entry status should have moved to 'downloading' (or 'ready' if fast)
from app import models as models_module
updated = models_module._get_queue_entry(entry["id"])
assert updated is not None, "Queue entry not found — thread may have run after fixture teardown"
assert updated["status"] in ("downloading", "ready", "failed")
# ── GET /download/stream ───────────────────────────────────────────────────────
def test_download_stream_idle_when_no_download(client):
"""GET /download/stream returns a single idle event when nothing is downloading."""
r = client.get("/api/models/download/stream")
assert r.status_code == 200
# SSE body should contain the idle event
assert "idle" in r.text
# ── GET /installed ─────────────────────────────────────────────────────────────
def test_installed_empty(client):
"""GET /installed returns [] when models dir is empty."""
r = client.get("/api/models/installed")
assert r.status_code == 200
assert r.json() == []
def test_installed_detects_downloaded_model(client, tmp_path):
"""A subdir with config.json is surfaced as type='downloaded'."""
from app import models as models_module
model_dir = models_module._MODELS_DIR / "org--mymodel"
model_dir.mkdir()
(model_dir / "config.json").write_text(json.dumps({"model_type": "bert"}), encoding="utf-8")
(model_dir / "model_info.json").write_text(
json.dumps({"repo_id": "org/mymodel", "adapter_recommendation": "ZeroShotAdapter"}),
encoding="utf-8",
)
r = client.get("/api/models/installed")
assert r.status_code == 200
items = r.json()
assert len(items) == 1
assert items[0]["type"] == "downloaded"
assert items[0]["name"] == "org--mymodel"
assert items[0]["adapter"] == "ZeroShotAdapter"
assert items[0]["model_id"] == "org/mymodel"
def test_installed_detects_finetuned_model(client):
"""A subdir with training_info.json is surfaced as type='finetuned'."""
from app import models as models_module
model_dir = models_module._MODELS_DIR / "my-finetuned"
model_dir.mkdir()
(model_dir / "training_info.json").write_text(
json.dumps({"base_model": "org/base", "epochs": 5}), encoding="utf-8"
)
r = client.get("/api/models/installed")
assert r.status_code == 200
items = r.json()
assert len(items) == 1
assert items[0]["type"] == "finetuned"
assert items[0]["name"] == "my-finetuned"
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
def test_delete_installed_removes_directory(client):
"""DELETE /installed/{name} removes the directory and returns {ok: true}."""
from app import models as models_module
model_dir = models_module._MODELS_DIR / "org--removeme"
model_dir.mkdir()
(model_dir / "config.json").write_text("{}", encoding="utf-8")
r = client.delete("/api/models/installed/org--removeme")
assert r.status_code == 200
assert r.json() == {"ok": True}
assert not model_dir.exists()
def test_delete_installed_not_found_returns_404(client):
r = client.delete("/api/models/installed/does-not-exist")
assert r.status_code == 404
def test_delete_installed_path_traversal_blocked(client):
"""DELETE /installed/../../etc must be blocked (400 or 422)."""
r = client.delete("/api/models/installed/../../etc")
assert r.status_code in (400, 404, 422)
def test_delete_installed_dotdot_name_blocked(client):
"""A name containing '..' in any form must be rejected."""
r = client.delete("/api/models/installed/..%2F..%2Fetc")
assert r.status_code in (400, 404, 422)
def test_delete_installed_name_with_slash_blocked(client):
"""A name containing a literal '/' after URL decoding must be rejected."""
from app import models as models_module
# The router will see the path segment after /installed/ — a second '/' would
# be parsed as a new path segment, so we test via the validation helper directly.
with pytest.raises(Exception):
# Simulate calling delete logic with a slash-containing name directly
from fastapi import HTTPException as _HTTPException
from app.models import delete_installed
try:
delete_installed("org/traversal")
except _HTTPException as exc:
assert exc.status_code in (400, 404)
raise

View file

@ -66,6 +66,7 @@ const navItems = [
{ path: '/fetch', icon: '📥', label: 'Fetch' },
{ path: '/stats', icon: '📊', label: 'Stats' },
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
{ path: '/models', icon: '🤗', label: 'Models' },
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
{ path: '/settings', icon: '⚙️', label: 'Settings' },
]

View file

@ -7,6 +7,7 @@ const StatsView = () => import('../views/StatsView.vue')
const BenchmarkView = () => import('../views/BenchmarkView.vue')
const SettingsView = () => import('../views/SettingsView.vue')
const CorrectionsView = () => import('../views/CorrectionsView.vue')
const ModelsView = () => import('../views/ModelsView.vue')
export const router = createRouter({
history: createWebHashHistory(),
@ -15,6 +16,7 @@ export const router = createRouter({
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
],

View file

@ -0,0 +1,826 @@
<template>
<div class="models-view">
<h1 class="page-title">🤗 Models</h1>
<!-- 1. HF Lookup -->
<section class="section">
<h2 class="section-title">HuggingFace Lookup</h2>
<div class="lookup-row">
<input
v-model="lookupInput"
type="text"
class="lookup-input"
placeholder="org/model or huggingface.co/org/model"
:disabled="lookupLoading"
@keydown.enter="doLookup"
aria-label="HuggingFace model ID"
/>
<button
class="btn-primary"
:disabled="lookupLoading || !lookupInput.trim()"
@click="doLookup"
>
{{ lookupLoading ? 'Looking up…' : 'Lookup' }}
</button>
</div>
<div v-if="lookupError" class="error-notice" role="alert">
{{ lookupError }}
</div>
<div v-if="lookupResult" class="preview-card">
<div class="preview-header">
<span class="preview-repo-id">{{ lookupResult.repo_id }}</span>
<div class="badge-group">
<span v-if="lookupResult.already_installed" class="badge badge-success">Installed</span>
<span v-if="lookupResult.already_queued" class="badge badge-info">In queue</span>
</div>
</div>
<div class="preview-meta">
<span v-if="lookupResult.pipeline_tag" class="chip chip-pipeline">
{{ lookupResult.pipeline_tag }}
</span>
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
{{ lookupResult.adapter_recommendation }}
</span>
<span v-if="lookupResult.size != null" class="preview-size">
{{ humanBytes(lookupResult.size) }}
</span>
</div>
<p v-if="lookupResult.description" class="preview-desc">
{{ lookupResult.description }}
</p>
<button
class="btn-primary btn-add-queue"
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue"
@click="addToQueue"
>
{{ addingToQueue ? 'Adding…' : 'Add to queue' }}
</button>
</div>
</section>
<!-- 2. Approval Queue -->
<section class="section">
<h2 class="section-title">Approval Queue</h2>
<div v-if="pendingModels.length === 0" class="empty-notice">
No models waiting for approval.
</div>
<div v-for="model in pendingModels" :key="model.id" class="model-card">
<div class="model-card-header">
<span class="model-repo-id">{{ model.repo_id }}</span>
<button
class="btn-dismiss"
:aria-label="`Dismiss ${model.repo_id}`"
@click="dismissModel(model.id)"
>
</button>
</div>
<div class="model-meta">
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
</div>
<div class="model-card-actions">
<button class="btn-primary btn-sm" @click="approveModel(model.id)">
Approve download
</button>
</div>
</div>
</section>
<!-- 3. Active Downloads -->
<section class="section">
<h2 class="section-title">Active Downloads</h2>
<div v-if="downloadingModels.length === 0" class="empty-notice">
No active downloads.
</div>
<div v-for="model in downloadingModels" :key="model.id" class="model-card">
<div class="model-card-header">
<span class="model-repo-id">{{ model.repo_id }}</span>
<span v-if="downloadErrors[model.id]" class="badge badge-error">Error</span>
</div>
<div class="model-meta">
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
</div>
<div v-if="downloadErrors[model.id]" class="download-error" role="alert">
{{ downloadErrors[model.id] }}
</div>
<div v-else class="progress-wrap" :aria-label="`Download progress for ${model.repo_id}`">
<div
class="progress-bar"
:style="{ width: `${downloadProgress[model.id] ?? 0}%` }"
role="progressbar"
:aria-valuenow="downloadProgress[model.id] ?? 0"
aria-valuemin="0"
aria-valuemax="100"
/>
<span class="progress-label">
{{ downloadProgress[model.id] == null ? 'Preparing…' : `${downloadProgress[model.id]}%` }}
</span>
</div>
</div>
</section>
<!-- 4. Installed Models -->
<section class="section">
<h2 class="section-title">Installed Models</h2>
<div v-if="installedModels.length === 0" class="empty-notice">
No models installed yet.
</div>
<div v-else class="installed-table-wrap">
<table class="installed-table">
<thead>
<tr>
<th>Name</th>
<th>Type</th>
<th>Adapter</th>
<th>Size</th>
<th></th>
</tr>
</thead>
<tbody>
<tr v-for="model in installedModels" :key="model.name">
<td class="td-name">{{ model.name }}</td>
<td>
<span
class="badge"
:class="model.type === 'finetuned' ? 'badge-accent' : 'badge-info'"
>
{{ model.type }}
</span>
</td>
<td>{{ model.adapter ?? '—' }}</td>
<td>{{ humanBytes(model.size) }}</td>
<td>
<button
class="btn-danger btn-sm"
@click="deleteInstalled(model.name)"
>
Delete
</button>
</td>
</tr>
</tbody>
</table>
</div>
</section>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted, onUnmounted } from 'vue'
// Type definitions
interface LookupResult {
repo_id: string
pipeline_tag: string | null
adapter_recommendation: string | null
size: number | null
description: string | null
already_installed: boolean
already_queued: boolean
}
interface QueuedModel {
id: string
repo_id: string
status: 'pending' | 'downloading' | 'done' | 'error'
pipeline_tag: string | null
adapter_recommendation: string | null
}
interface InstalledModel {
name: string
type: 'finetuned' | 'downloaded'
adapter: string | null
size: number
}
interface SseProgressEvent {
model_id: string
pct: number | null
status: 'progress' | 'done' | 'error'
message?: string
}
// State
const lookupInput = ref('')
const lookupLoading = ref(false)
const lookupError = ref<string | null>(null)
const lookupResult = ref<LookupResult | null>(null)
const addingToQueue = ref(false)
const queuedModels = ref<QueuedModel[]>([])
const installedModels = ref<InstalledModel[]>([])
const downloadProgress = ref<Record<string, number>>({})
const downloadErrors = ref<Record<string, string>>({})
let pollInterval: ReturnType<typeof setInterval> | null = null
let sseSource: EventSource | null = null
// Derived
const pendingModels = computed(() =>
queuedModels.value.filter(m => m.status === 'pending')
)
const downloadingModels = computed(() =>
queuedModels.value.filter(m => m.status === 'downloading')
)
// Helpers
function humanBytes(bytes: number | null): string {
if (bytes == null) return '—'
const units = ['B', 'KB', 'MB', 'GB', 'TB']
let value = bytes
let unitIndex = 0
while (value >= 1024 && unitIndex < units.length - 1) {
value /= 1024
unitIndex++
}
return `${value.toFixed(unitIndex === 0 ? 0 : 1)} ${units[unitIndex]}`
}
function normalizeRepoId(raw: string): string {
return raw.trim().replace(/^https?:\/\/huggingface\.co\//, '')
}
// API calls
async function doLookup() {
const repoId = normalizeRepoId(lookupInput.value)
if (!repoId) return
lookupLoading.value = true
lookupError.value = null
lookupResult.value = null
try {
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
if (res.status === 404) {
lookupError.value = 'Model not found on HuggingFace.'
return
}
if (res.status === 502) {
lookupError.value = 'HuggingFace unreachable. Check your connection and try again.'
return
}
if (!res.ok) {
lookupError.value = `Lookup failed (HTTP ${res.status}).`
return
}
lookupResult.value = await res.json() as LookupResult
} catch {
lookupError.value = 'Network error. Is the Avocet API running?'
} finally {
lookupLoading.value = false
}
}
async function addToQueue() {
if (!lookupResult.value) return
addingToQueue.value = true
try {
const res = await fetch('/api/models/queue', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ repo_id: lookupResult.value.repo_id }),
})
if (res.ok) {
lookupResult.value = { ...lookupResult.value, already_queued: true }
await loadQueue()
}
} catch { /* ignore — already_queued badge won't flip, user can retry */ }
finally {
addingToQueue.value = false
}
}
async function approveModel(id: string) {
try {
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
if (res.ok) {
await loadQueue()
startSse()
}
} catch { /* ignore */ }
}
async function dismissModel(id: string) {
try {
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}`, { method: 'DELETE' })
if (res.ok) {
queuedModels.value = queuedModels.value.filter(m => m.id !== id)
}
} catch { /* ignore */ }
}
async function deleteInstalled(name: string) {
if (!window.confirm(`Delete installed model "${name}"? This cannot be undone.`)) return
try {
const res = await fetch(`/api/models/installed/${encodeURIComponent(name)}`, { method: 'DELETE' })
if (res.ok) {
installedModels.value = installedModels.value.filter(m => m.name !== name)
}
} catch { /* ignore */ }
}
async function loadQueue() {
try {
const res = await fetch('/api/models/queue')
if (res.ok) queuedModels.value = await res.json() as QueuedModel[]
} catch { /* non-fatal */ }
}
async function loadInstalled() {
try {
const res = await fetch('/api/models/installed')
if (res.ok) installedModels.value = await res.json() as InstalledModel[]
} catch { /* non-fatal */ }
}
// SSE for download progress
function startSse() {
if (sseSource) return // already connected
sseSource = new EventSource('/api/models/download/stream')
sseSource.addEventListener('message', (e: MessageEvent) => {
let event: SseProgressEvent
try {
event = JSON.parse(e.data as string) as SseProgressEvent
} catch {
return
}
const { model_id, pct, status, message } = event
if (status === 'progress' && pct != null) {
downloadProgress.value = { ...downloadProgress.value, [model_id]: pct }
} else if (status === 'done') {
const updated = { ...downloadProgress.value }
delete updated[model_id]
downloadProgress.value = updated
queuedModels.value = queuedModels.value.filter(m => m.id !== model_id)
loadInstalled()
} else if (status === 'error') {
downloadErrors.value = {
...downloadErrors.value,
[model_id]: message ?? 'Download failed.',
}
}
})
sseSource.onerror = () => {
sseSource?.close()
sseSource = null
}
}
function stopSse() {
sseSource?.close()
sseSource = null
}
// Polling
function startPollingIfDownloading() {
if (pollInterval) return
pollInterval = setInterval(async () => {
await loadQueue()
if (downloadingModels.value.length === 0) {
stopPolling()
}
}, 5000)
}
function stopPolling() {
if (pollInterval) {
clearInterval(pollInterval)
pollInterval = null
}
}
// Lifecycle
onMounted(async () => {
await Promise.all([loadQueue(), loadInstalled()])
if (downloadingModels.value.length > 0) {
startSse()
startPollingIfDownloading()
}
})
onUnmounted(() => {
stopPolling()
stopSse()
})
</script>
<style scoped>
.models-view {
max-width: 760px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 2rem;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--color-primary, #2d5a27);
}
/* ── Sections ── */
.section {
display: flex;
flex-direction: column;
gap: 0.75rem;
}
.section-title {
font-size: 1rem;
font-weight: 600;
color: var(--color-text, #1a2338);
padding-bottom: 0.4rem;
border-bottom: 1px solid var(--color-border, #a8b8d0);
}
/* ── Lookup row ── */
.lookup-row {
display: flex;
gap: 0.5rem;
flex-wrap: wrap;
}
.lookup-input {
flex: 1;
min-width: 0;
padding: 0.45rem 0.7rem;
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text, #1a2338);
font-size: 0.9rem;
font-family: var(--font-body, sans-serif);
}
.lookup-input:disabled {
opacity: 0.6;
}
.lookup-input::placeholder {
color: var(--color-text-muted, #4a5c7a);
}
/* ── Notices ── */
.error-notice {
padding: 0.6rem 0.8rem;
background: color-mix(in srgb, var(--color-error, #c0392b) 12%, transparent);
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
border-radius: var(--radius-md, 0.5rem);
color: var(--color-error, #c0392b);
font-size: 0.88rem;
}
.empty-notice {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
border: 1px dashed var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
}
/* ── Preview card ── */
.preview-card {
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-lg, 1rem);
background: var(--color-surface-raised, #f5f7fc);
padding: 1rem;
display: flex;
flex-direction: column;
gap: 0.6rem;
box-shadow: var(--shadow-sm);
}
.preview-header {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: 0.5rem;
flex-wrap: wrap;
}
.preview-repo-id {
font-family: var(--font-mono, monospace);
font-size: 0.95rem;
font-weight: 600;
color: var(--color-text, #1a2338);
word-break: break-all;
}
.preview-meta {
display: flex;
gap: 0.4rem;
flex-wrap: wrap;
align-items: center;
}
.preview-size {
font-size: 0.8rem;
color: var(--color-text-muted, #4a5c7a);
margin-left: 0.25rem;
}
.preview-desc {
font-size: 0.875rem;
color: var(--color-text-muted, #4a5c7a);
line-height: 1.5;
margin: 0;
display: -webkit-box;
-webkit-line-clamp: 3;
-webkit-box-orient: vertical;
overflow: hidden;
}
.btn-add-queue {
align-self: flex-start;
}
/* ── Model cards (queue + downloads) ── */
.model-card {
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
background: var(--color-surface-raised, #f5f7fc);
padding: 0.75rem 1rem;
display: flex;
flex-direction: column;
gap: 0.5rem;
box-shadow: var(--shadow-sm);
}
.model-card-header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 0.5rem;
}
.model-repo-id {
font-family: var(--font-mono, monospace);
font-size: 0.9rem;
font-weight: 600;
color: var(--color-text, #1a2338);
word-break: break-all;
}
.model-meta {
display: flex;
gap: 0.4rem;
flex-wrap: wrap;
}
.model-card-actions {
display: flex;
gap: 0.5rem;
flex-wrap: wrap;
padding-top: 0.25rem;
}
/* ── Progress bar ── */
.progress-wrap {
position: relative;
height: 1.5rem;
background: var(--color-surface-alt, #dde4f0);
border-radius: var(--radius-full, 9999px);
overflow: hidden;
}
.progress-bar {
position: absolute;
top: 0;
left: 0;
height: 100%;
background: var(--color-accent, #c4732a);
border-radius: var(--radius-full, 9999px);
transition: width 300ms ease;
}
.progress-label {
position: absolute;
inset: 0;
display: flex;
align-items: center;
justify-content: center;
font-size: 0.75rem;
font-weight: 600;
color: var(--color-text, #1a2338);
pointer-events: none;
}
.download-error {
font-size: 0.85rem;
color: var(--color-error, #c0392b);
padding: 0.4rem 0.5rem;
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
border-radius: var(--radius-sm, 0.25rem);
}
/* ── Installed table ── */
.installed-table-wrap {
overflow-x: auto;
}
.installed-table {
width: 100%;
border-collapse: collapse;
font-size: 0.875rem;
}
.installed-table th {
text-align: left;
padding: 0.4rem 0.6rem;
color: var(--color-text-muted, #4a5c7a);
font-size: 0.78rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.03em;
border-bottom: 1px solid var(--color-border, #a8b8d0);
white-space: nowrap;
}
.installed-table td {
padding: 0.55rem 0.6rem;
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
vertical-align: middle;
}
.td-name {
font-family: var(--font-mono, monospace);
font-size: 0.85rem;
word-break: break-all;
}
/* ── Badges ── */
.badge-group {
display: flex;
gap: 0.35rem;
flex-wrap: wrap;
align-items: center;
}
.badge {
display: inline-flex;
align-items: center;
padding: 0.15rem 0.55rem;
border-radius: var(--radius-full, 9999px);
font-size: 0.72rem;
font-weight: 700;
letter-spacing: 0.02em;
text-transform: uppercase;
white-space: nowrap;
}
.badge-success {
background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent);
color: var(--color-success, #3a7a32);
}
.badge-info {
background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent);
color: var(--color-info, #1e6091);
}
.badge-accent {
background: color-mix(in srgb, var(--color-accent, #c4732a) 15%, transparent);
color: var(--color-accent, #c4732a);
}
.badge-error {
background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent);
color: var(--color-error, #c0392b);
}
/* ── Chips ── */
.chip {
display: inline-flex;
align-items: center;
padding: 0.15rem 0.5rem;
border-radius: var(--radius-full, 9999px);
font-size: 0.75rem;
font-weight: 600;
background: var(--color-surface-alt, #dde4f0);
white-space: nowrap;
}
.chip-pipeline {
color: var(--color-primary, #2d5a27);
background: color-mix(in srgb, var(--color-primary, #2d5a27) 12%, var(--color-surface-alt, #dde4f0));
}
.chip-adapter {
color: var(--color-accent, #c4732a);
background: color-mix(in srgb, var(--color-accent, #c4732a) 12%, var(--color-surface-alt, #dde4f0));
}
/* ── Buttons ── */
.btn-primary, .btn-danger {
padding: 0.4rem 0.9rem;
border-radius: var(--radius-md, 0.5rem);
font-size: 0.85rem;
cursor: pointer;
border: 1px solid;
font-family: var(--font-body, sans-serif);
transition: background var(--transition, 200ms ease), color var(--transition, 200ms ease);
}
.btn-sm {
padding: 0.25rem 0.65rem;
font-size: 0.8rem;
}
.btn-primary {
border-color: var(--color-primary, #2d5a27);
background: var(--color-primary, #2d5a27);
color: var(--color-text-inverse, #eaeff8);
}
.btn-primary:hover:not(:disabled) {
background: var(--color-primary-hover, #234820);
border-color: var(--color-primary-hover, #234820);
}
.btn-primary:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-danger {
border-color: var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
}
.btn-danger:hover {
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
}
.btn-dismiss {
border: none;
background: transparent;
color: var(--color-text-muted, #4a5c7a);
cursor: pointer;
font-size: 0.9rem;
padding: 0.15rem 0.4rem;
border-radius: var(--radius-sm, 0.25rem);
flex-shrink: 0;
transition: color var(--transition, 200ms ease), background var(--transition, 200ms ease);
}
.btn-dismiss:hover {
color: var(--color-error, #c0392b);
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
}
/* ── Responsive ── */
@media (max-width: 480px) {
.lookup-row {
flex-direction: column;
}
.lookup-input {
width: 100%;
}
.btn-primary:not(.btn-sm) {
width: 100%;
}
.installed-table th:nth-child(3),
.installed-table td:nth-child(3) {
display: none; /* hide Adapter column on very narrow screens */
}
}
</style>