avocet/app/models.py
pyr0ball b6b3d2c390 feat: HuggingFace model management tab
- New /api/models router: HF lookup, approval queue (JSONL persistence),
  SSE download progress via snapshot_download(), installed model listing,
  path-traversal-safe DELETE
- pipeline_tag → adapter type mapping (zero-shot-classification,
  sentence-similarity, text-generation)
- 27 tests covering all endpoints, duplicate detection, path traversal
- ModelsView.vue: HF lookup + add, approval queue, live download progress
  bars via SSE, installed model table with delete
- Sidebar entry (🤗 Models) between Benchmark and Corrections
2026-04-08 22:32:35 -07:00

428 lines
16 KiB
Python

"""Avocet — HF model lifecycle API.
Handles model metadata lookup, approval queue, download with progress,
and installed model management.
All endpoints are registered on `router` (a FastAPI APIRouter).
api.py includes this router with prefix="/api/models".
Module-level globals (_MODELS_DIR, _QUEUE_DIR) follow the same
testability pattern as sft.py — override them via set_models_dir() and
set_queue_dir() in test fixtures.
"""
from __future__ import annotations

import json
import logging
import shutil
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import uuid4

import httpx
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.utils import read_jsonl, write_jsonl

try:
    from huggingface_hub import snapshot_download
except ImportError:  # pragma: no cover
    snapshot_download = None  # type: ignore[assignment]
logger = logging.getLogger(__name__)

# Repo layout: <root>/app/models.py → <root>/models and <root>/data.
_ROOT = Path(__file__).parent.parent
_MODELS_DIR: Path = _ROOT / "models"
_QUEUE_DIR: Path = _ROOT / "data"

router = APIRouter()

# ── Download progress shared state ────────────────────────────────────────────
# Updated by the background download thread; read by GET /download/stream.
_download_progress: dict[str, Any] = {}

# ── HF pipeline_tag → adapter recommendation ──────────────────────────────────
_TAG_TO_ADAPTER: dict[str, str] = {
    "zero-shot-classification": "ZeroShotAdapter",
    "text-classification": "ZeroShotAdapter",
    "natural-language-inference": "ZeroShotAdapter",
    "sentence-similarity": "RerankerAdapter",
    "text-ranking": "RerankerAdapter",
    "text-generation": "GenerationAdapter",
    "text2text-generation": "GenerationAdapter",
}
# ── Testability seams ──────────────────────────────────────────────────────────
def set_models_dir(path: Path) -> None:
    """Test seam: redirect the installed-models directory to *path*."""
    global _MODELS_DIR
    _MODELS_DIR = path
def set_queue_dir(path: Path) -> None:
    """Test seam: redirect the queue-file directory to *path*."""
    global _QUEUE_DIR
    _QUEUE_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _queue_file() -> Path:
    """Path of the JSONL file that persists the approval queue."""
    return _QUEUE_DIR / "model_queue.jsonl"
def _read_queue() -> list[dict]:
    """Load all queue records from disk (empty list when file is absent)."""
    return read_jsonl(_queue_file())
def _write_queue(records: list[dict]) -> None:
    """Persist *records* to the queue file, replacing its contents."""
    write_jsonl(_queue_file(), records)
def _safe_model_name(repo_id: str) -> str:
"""Convert repo_id to a filesystem-safe directory name (HF convention)."""
return repo_id.replace("/", "--")
def _is_installed(repo_id: str) -> bool:
    """Check if a model is already downloaded in _MODELS_DIR.

    A directory counts as installed when it contains any of the marker
    files written by a download or a fine-tuning run.
    """
    model_dir = _MODELS_DIR / _safe_model_name(repo_id)
    if not model_dir.exists():
        return False
    markers = ("config.json", "training_info.json", "model_info.json")
    return any((model_dir / marker).exists() for marker in markers)
def _is_queued(repo_id: str) -> bool:
    """Check if repo_id is already in the queue (non-dismissed)."""
    return any(
        rec.get("repo_id") == repo_id and rec.get("status") != "dismissed"
        for rec in _read_queue()
    )
def _update_queue_entry(entry_id: str, updates: dict) -> dict | None:
    """Merge *updates* into the queue entry with id *entry_id*.

    Persists the whole queue on a match and returns the merged entry;
    returns None when no entry has that id.
    """
    records = _read_queue()
    for idx, record in enumerate(records):
        if record.get("id") != entry_id:
            continue
        merged = {**record, **updates}
        records[idx] = merged
        _write_queue(records)
        return merged
    return None
def _get_queue_entry(entry_id: str) -> dict | None:
    """Return the queue entry with id *entry_id*, or None when absent."""
    return next((rec for rec in _read_queue() if rec.get("id") == entry_id), None)
# ── Background download ────────────────────────────────────────────────────────
def _run_download(entry_id: str, repo_id: str, pipeline_tag: str | None, adapter_recommendation: str | None) -> None:
    """Background thread: download model via huggingface_hub.snapshot_download.

    Publishes progress/terminal state into the module-level
    ``_download_progress`` dict (read by GET /download/stream) and flips the
    queue entry to 'ready' or 'failed' when finished.
    """
    global _download_progress
    target_dir = _MODELS_DIR / _safe_model_name(repo_id)
    # Fresh state dict for this download; the SSE endpoint snapshots it.
    _download_progress = {
        "active": True,
        "repo_id": repo_id,
        "downloaded_bytes": 0,
        "total_bytes": 0,
        "pct": 0.0,
        "done": False,
        "error": None,
    }
    try:
        if snapshot_download is None:
            raise RuntimeError("huggingface_hub is not installed")
        snapshot_download(repo_id=repo_id, local_dir=str(target_dir))
        # Record provenance next to the downloaded files so /installed can
        # later report the adapter recommendation and source repo.
        provenance = {
            "repo_id": repo_id,
            "pipeline_tag": pipeline_tag,
            "adapter_recommendation": adapter_recommendation,
            "downloaded_at": datetime.now(timezone.utc).isoformat(),
        }
        target_dir.mkdir(parents=True, exist_ok=True)
        (target_dir / "model_info.json").write_text(
            json.dumps(provenance, indent=2), encoding="utf-8"
        )
        _download_progress["done"] = True
        _download_progress["pct"] = 100.0
        _update_queue_entry(entry_id, {"status": "ready"})
    except Exception as exc:
        logger.exception("Download failed for %s: %s", repo_id, exc)
        _download_progress["error"] = str(exc)
        _download_progress["done"] = True
        _update_queue_entry(entry_id, {"status": "failed", "error": str(exc)})
    finally:
        _download_progress["active"] = False
# ── GET /lookup ────────────────────────────────────────────────────────────────
@router.get("/lookup")
def lookup_model(repo_id: str) -> dict:
    """Validate repo_id and fetch metadata from the HF API.

    Returns a summary dict: pipeline tag, adapter recommendation, estimated
    size, short description, tags, download count, and whether the model is
    already installed or queued.

    Raises:
        HTTPException 422: malformed repo_id.
        HTTPException 404: model not found on the Hub.
        HTTPException 502: network failure or unexpected HF response.
    """
    # Validate: must contain exactly one '/', no whitespace.
    # (Original only checked "/" in repo_id, so multi-slash ids like
    # 'a/../b' slipped through into the URL below.)
    if repo_id.count("/") != 1 or any(c.isspace() for c in repo_id):
        raise HTTPException(422, f"Invalid repo_id {repo_id!r}: must be 'owner/model-name' with no whitespace")
    hf_url = f"https://huggingface.co/api/models/{repo_id}"
    try:
        resp = httpx.get(hf_url, timeout=10.0)
    except httpx.RequestError as exc:
        raise HTTPException(502, f"Network error reaching HuggingFace API: {exc}") from exc
    if resp.status_code == 404:
        raise HTTPException(404, f"Model {repo_id!r} not found on HuggingFace")
    if resp.status_code != 200:
        raise HTTPException(502, f"HuggingFace API returned status {resp.status_code}")
    try:
        data = resp.json()
    except ValueError as exc:  # e.g. an HTML error page from a proxy
        raise HTTPException(502, "HuggingFace API returned a non-JSON response") from exc
    pipeline_tag = data.get("pipeline_tag")
    adapter_recommendation = _TAG_TO_ADAPTER.get(pipeline_tag) if pipeline_tag else None
    if pipeline_tag and adapter_recommendation is None:
        logger.warning("Unknown pipeline_tag %r for %s — no adapter recommendation", pipeline_tag, repo_id)
    # Estimate model size from siblings list. The HF API reports
    # "size": null for some files, so coalesce None → 0 before summing
    # (plain s.get("size", 0) would return None and break sum()).
    siblings = data.get("siblings") or []
    model_size_bytes: int = sum(
        (s.get("size") or 0) for s in siblings if isinstance(s, dict)
    )
    # Description: first 300 chars of card data (modelId field used as fallback)
    card_data = data.get("cardData") or {}
    description_raw = card_data.get("description") or data.get("modelId") or ""
    description = description_raw[:300] if description_raw else ""
    return {
        "repo_id": repo_id,
        "pipeline_tag": pipeline_tag,
        "adapter_recommendation": adapter_recommendation,
        "model_size_bytes": model_size_bytes,
        "description": description,
        "tags": data.get("tags") or [],
        "downloads": data.get("downloads") or 0,
        "already_installed": _is_installed(repo_id),
        "already_queued": _is_queued(repo_id),
    }
# ── GET /queue ─────────────────────────────────────────────────────────────────
@router.get("/queue")
def get_queue() -> list[dict]:
    """Return all non-dismissed queue entries sorted newest-first."""
    visible = (rec for rec in _read_queue() if rec.get("status") != "dismissed")
    return sorted(visible, key=lambda rec: rec.get("queued_at", ""), reverse=True)
# ── POST /queue ────────────────────────────────────────────────────────────────
class QueueAddRequest(BaseModel):
    """Payload for POST /queue: model to place on the approval queue."""

    repo_id: str
    pipeline_tag: str | None = None
    adapter_recommendation: str | None = None
@router.post("/queue", status_code=201)
def add_to_queue(req: QueueAddRequest) -> dict:
    """Add a model to the approval queue with status 'pending'.

    Raises HTTPException 409 when the model is already installed or
    already queued (duplicate protection).
    """
    if _is_installed(req.repo_id):
        raise HTTPException(409, f"{req.repo_id!r} is already installed")
    if _is_queued(req.repo_id):
        raise HTTPException(409, f"{req.repo_id!r} is already in the queue")
    entry = {
        "id": str(uuid4()),
        "repo_id": req.repo_id,
        "pipeline_tag": req.pipeline_tag,
        "adapter_recommendation": req.adapter_recommendation,
        "status": "pending",
        "queued_at": datetime.now(timezone.utc).isoformat(),
    }
    _write_queue([*_read_queue(), entry])
    return entry
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
@router.post("/queue/{entry_id}/approve")
def approve_queue_entry(entry_id: str) -> dict:
    """Approve a pending queue entry and start background download.

    Raises:
        HTTPException 404: no queue entry with *entry_id*.
        HTTPException 409: entry exists but is not in the 'pending' state.
    """
    entry = _get_queue_entry(entry_id)
    if entry is None:
        raise HTTPException(404, f"Queue entry {entry_id!r} not found")
    if entry.get("status") != "pending":
        raise HTTPException(409, f"Entry is not in pending state (current: {entry.get('status')!r})")
    # Flip to 'downloading' before spawning the worker so the UI sees the
    # transition even if the thread takes a moment to start.
    # (Dropped the unused `updated` local the original assigned here.)
    _update_queue_entry(entry_id, {"status": "downloading"})
    threading.Thread(
        target=_run_download,
        args=(entry_id, entry["repo_id"], entry.get("pipeline_tag"), entry.get("adapter_recommendation")),
        daemon=True,
        name=f"model-download-{entry_id}",
    ).start()
    return {"ok": True}
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
@router.delete("/queue/{entry_id}")
def dismiss_queue_entry(entry_id: str) -> dict:
    """Dismiss (soft-delete) a queue entry.

    The record is kept on disk with status 'dismissed' rather than removed.
    Raises HTTPException 404 when no entry has *entry_id*.
    """
    if _get_queue_entry(entry_id) is None:
        raise HTTPException(404, f"Queue entry {entry_id!r} not found")
    _update_queue_entry(entry_id, {"status": "dismissed"})
    return {"ok": True}
# ── GET /download/stream ───────────────────────────────────────────────────────
@router.get("/download/stream")
def download_stream() -> StreamingResponse:
    """SSE stream of download progress.

    Emits one 'idle' event when no download has run, a single terminal
    'done'/'error' event for a finished download, or a stream of
    'progress' events (every 0.5s) while a download is active.
    """
    def generate():
        prog = _download_progress
        # Idle only when nothing is running AND nothing has finished.
        # The original condition `not (done and not error)` sent 'idle'
        # for a *failed* finished download, making the error branch
        # below unreachable; a finished download must report its
        # terminal event instead.
        if not prog.get("active") and not prog.get("done"):
            yield f"data: {json.dumps({'type': 'idle'})}\n\n"
            return
        if prog.get("done"):
            if prog.get("error"):
                yield f"data: {json.dumps({'type': 'error', 'error': prog['error']})}\n\n"
            else:
                yield f"data: {json.dumps({'type': 'done', 'repo_id': prog.get('repo_id')})}\n\n"
            return
        # Stream live progress until the worker thread flags completion.
        while True:
            p = dict(_download_progress)  # snapshot: worker mutates concurrently
            if p.get("done"):
                if p.get("error"):
                    yield f"data: {json.dumps({'type': 'error', 'error': p['error']})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'done', 'repo_id': p.get('repo_id')})}\n\n"
                break
            event = json.dumps({
                "type": "progress",
                "repo_id": p.get("repo_id"),
                "downloaded_bytes": p.get("downloaded_bytes", 0),
                "total_bytes": p.get("total_bytes", 0),
                "pct": p.get("pct", 0.0),
            })
            yield f"data: {event}\n\n"
            time.sleep(0.5)
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
# ── GET /installed ─────────────────────────────────────────────────────────────
@router.get("/installed")
def list_installed() -> list[dict]:
    """Scan _MODELS_DIR and return info on each installed model.

    A subdirectory counts as a model when it holds any of the marker
    files; metadata reads are best-effort (corrupt JSON leaves the
    adapter/model_id fields as None).
    """
    if not _MODELS_DIR.exists():
        return []
    results: list[dict] = []
    for model_dir in _MODELS_DIR.iterdir():
        if not model_dir.is_dir():
            continue
        markers = {
            name: (model_dir / name).exists()
            for name in ("training_info.json", "config.json", "model_info.json")
        }
        if not any(markers.values()):
            continue
        # Finetuned checkpoints are identified by training_info.json.
        kind = "finetuned" if markers["training_info.json"] else "downloaded"
        # Total on-disk footprint of the model directory.
        total_bytes = sum(f.stat().st_size for f in model_dir.rglob("*") if f.is_file())
        adapter: str | None = None
        model_id: str | None = None
        if markers["model_info.json"]:
            try:
                info = json.loads((model_dir / "model_info.json").read_text(encoding="utf-8"))
                adapter = info.get("adapter_recommendation")
                model_id = info.get("repo_id")
            except Exception:
                pass  # best-effort: unreadable metadata is not fatal
        elif markers["training_info.json"]:
            try:
                info = json.loads((model_dir / "training_info.json").read_text(encoding="utf-8"))
                adapter = info.get("adapter")
                model_id = info.get("base_model") or info.get("model_id")
            except Exception:
                pass  # best-effort: unreadable metadata is not fatal
        results.append({
            "name": model_dir.name,
            "path": str(model_dir),
            "type": kind,
            "adapter": adapter,
            "size_bytes": total_bytes,
            "model_id": model_id,
        })
    return results
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
@router.delete("/installed/{name}")
def delete_installed(name: str) -> dict:
    """Remove an installed model directory by name. Blocks path traversal.

    Raises:
        HTTPException 400: invalid name, or resolved path escapes _MODELS_DIR.
        HTTPException 404: no installed model with that name.
    """
    # Validate: single path component, no slashes or '..'; also rejects
    # empty and dot-prefixed names so hidden entries can't be removed.
    if "/" in name or "\\" in name or ".." in name or not name or name.startswith("."):
        raise HTTPException(400, f"Invalid model name {name!r}: must be a single directory name with no path separators or '..'")
    model_path = _MODELS_DIR / name
    # Extra safety: confirm resolved path is inside _MODELS_DIR (catches
    # symlinks that point outside the models directory).
    try:
        model_path.resolve().relative_to(_MODELS_DIR.resolve())
    except ValueError as exc:
        # Chain the cause (B904) so tracebacks keep the resolver context.
        raise HTTPException(400, f"Path traversal detected for name {name!r}") from exc
    if not model_path.exists():
        raise HTTPException(404, f"Installed model {name!r} not found")
    shutil.rmtree(model_path)
    return {"ok": True}