- New /api/models router: HF lookup, approval queue (JSONL persistence),
SSE download progress via snapshot_download(), installed model listing,
path-traversal-safe DELETE
- pipeline_tag → adapter type mapping (zero-shot-classification,
sentence-similarity, text-generation)
- 27 tests covering all endpoints, duplicate detection, path traversal
- ModelsView.vue: HF lookup + add, approval queue, live download progress
bars via SSE, installed model table with delete
- Sidebar entry (🤗 Models) between Benchmark and Corrections
428 lines
16 KiB
Python
"""Avocet — HF model lifecycle API.
|
|
|
|
Handles model metadata lookup, approval queue, download with progress,
|
|
and installed model management.
|
|
|
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
|
api.py includes this router with prefix="/api/models".
|
|
|
|
Module-level globals (_MODELS_DIR, _QUEUE_DIR) follow the same
|
|
testability pattern as sft.py — override them via set_models_dir() and
|
|
set_queue_dir() in test fixtures.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import shutil
|
|
import threading
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app.utils import read_jsonl, write_jsonl
|
|
|
|
try:
|
|
from huggingface_hub import snapshot_download
|
|
except ImportError: # pragma: no cover
|
|
snapshot_download = None # type: ignore[assignment]
|
|
|
|
logger = logging.getLogger(__name__)

# Repo root: this file lives two directory levels below it.
_ROOT = Path(__file__).parent.parent
# On-disk home for downloaded and fine-tuned models.
_MODELS_DIR: Path = _ROOT / "models"
# Directory holding the approval-queue JSONL file.
_QUEUE_DIR: Path = _ROOT / "data"

router = APIRouter()

# ── Download progress shared state ────────────────────────────────────────────
# Updated by the background download thread; read by GET /download/stream.
# NOTE(review): a single module-level dict tracks only ONE download at a
# time — approving a second entry while one is running overwrites this.
_download_progress: dict[str, Any] = {}

# ── HF pipeline_tag → adapter recommendation ──────────────────────────────────
# Maps a HuggingFace pipeline_tag onto the project adapter class name
# suggested to the user. Unknown tags yield no recommendation (see
# lookup_model, which logs a warning in that case).
_TAG_TO_ADAPTER: dict[str, str] = {
    "zero-shot-classification": "ZeroShotAdapter",
    "text-classification": "ZeroShotAdapter",
    "natural-language-inference": "ZeroShotAdapter",
    "sentence-similarity": "RerankerAdapter",
    "text-ranking": "RerankerAdapter",
    "text-generation": "GenerationAdapter",
    "text2text-generation": "GenerationAdapter",
}
|
|
|
|
|
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
|
|
|
def set_models_dir(path: Path) -> None:
    """Override the models directory (test seam, same pattern as sft.py)."""
    global _MODELS_DIR
    _MODELS_DIR = path
|
|
|
|
|
|
def set_queue_dir(path: Path) -> None:
    """Override the queue directory (test seam, same pattern as sft.py)."""
    global _QUEUE_DIR
    _QUEUE_DIR = path
|
|
|
|
|
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
|
|
|
def _queue_file() -> Path:
    """Path of the JSONL file that persists the approval queue."""
    return Path(_QUEUE_DIR, "model_queue.jsonl")
|
|
|
|
|
|
def _read_queue() -> list[dict]:
    """Load every queue record from disk (raw, including dismissed)."""
    path = _queue_file()
    return read_jsonl(path)
|
|
|
|
|
|
def _write_queue(records: list[dict]) -> None:
    """Persist the full queue, replacing the previous file contents."""
    target = _queue_file()
    write_jsonl(target, records)
|
|
|
|
|
|
def _safe_model_name(repo_id: str) -> str:
|
|
"""Convert repo_id to a filesystem-safe directory name (HF convention)."""
|
|
return repo_id.replace("/", "--")
|
|
|
|
|
|
def _is_installed(repo_id: str) -> bool:
    """True when `repo_id` is already downloaded into _MODELS_DIR.

    A directory counts as an installed model when it carries at least
    one of the marker files produced by download or fine-tuning.
    """
    model_dir = _MODELS_DIR / _safe_model_name(repo_id)
    if not model_dir.exists():
        return False
    markers = ("config.json", "training_info.json", "model_info.json")
    return any((model_dir / marker).exists() for marker in markers)
|
|
|
|
|
|
def _is_queued(repo_id: str) -> bool:
    """True when `repo_id` sits in the queue with a non-dismissed status."""
    return any(
        rec.get("repo_id") == repo_id and rec.get("status") != "dismissed"
        for rec in _read_queue()
    )
|
|
|
|
|
|
def _update_queue_entry(entry_id: str, updates: dict) -> dict | None:
    """Merge `updates` into the entry with the given id and persist.

    Returns the merged record, or None when no entry matches.
    """
    records = _read_queue()
    for idx, record in enumerate(records):
        if record.get("id") != entry_id:
            continue
        merged = {**record, **updates}
        records[idx] = merged
        _write_queue(records)
        return merged
    return None
|
|
|
|
|
|
def _get_queue_entry(entry_id: str) -> dict | None:
    """Return the queue entry with the given id, or None if absent."""
    matches = (rec for rec in _read_queue() if rec.get("id") == entry_id)
    return next(matches, None)
|
|
|
|
|
|
# ── Background download ────────────────────────────────────────────────────────
|
|
|
|
def _run_download(entry_id: str, repo_id: str, pipeline_tag: str | None, adapter_recommendation: str | None) -> None:
    """Background thread: download model via huggingface_hub.snapshot_download.

    Runs in a daemon thread started by approve_queue_entry. Publishes
    progress through the module-level `_download_progress` dict (read by
    GET /download/stream) and records the final status on the queue
    entry ('ready' or 'failed').
    """
    global _download_progress
    safe_name = _safe_model_name(repo_id)
    local_dir = _MODELS_DIR / safe_name

    # Reset the shared progress dict for this download. NOTE(review):
    # downloaded_bytes/total_bytes are never updated afterwards — the
    # snapshot_download() call below provides no progress callback — so
    # consumers only ever see 0 until pct jumps to 100 on completion.
    _download_progress = {
        "active": True,
        "repo_id": repo_id,
        "downloaded_bytes": 0,
        "total_bytes": 0,
        "pct": 0.0,
        "done": False,
        "error": None,
    }

    try:
        # huggingface_hub is an optional dependency (see import guard at
        # top of module); fail the entry cleanly if it is missing.
        if snapshot_download is None:
            raise RuntimeError("huggingface_hub is not installed")

        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
        )

        # Write model_info.json alongside downloaded files
        model_info = {
            "repo_id": repo_id,
            "pipeline_tag": pipeline_tag,
            "adapter_recommendation": adapter_recommendation,
            "downloaded_at": datetime.now(timezone.utc).isoformat(),
        }
        # mkdir is a belt-and-braces guard; snapshot_download normally
        # creates local_dir itself.
        local_dir.mkdir(parents=True, exist_ok=True)
        (local_dir / "model_info.json").write_text(
            json.dumps(model_info, indent=2), encoding="utf-8"
        )

        _download_progress["done"] = True
        _download_progress["pct"] = 100.0
        _update_queue_entry(entry_id, {"status": "ready"})

    except Exception as exc:
        # Broad catch is deliberate: any failure in a background thread
        # must be surfaced via the progress dict and queue status, never
        # allowed to die silently.
        logger.exception("Download failed for %s: %s", repo_id, exc)
        _download_progress["error"] = str(exc)
        _download_progress["done"] = True
        _update_queue_entry(entry_id, {"status": "failed", "error": str(exc)})
    finally:
        # Always clear 'active' so the SSE stream's polling loop exits.
        _download_progress["active"] = False
|
|
|
|
|
|
# ── GET /lookup ────────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/lookup")
def lookup_model(repo_id: str) -> dict:
    """Validate repo_id and fetch metadata from the HF API.

    Returns a dict with pipeline_tag, adapter recommendation, size
    estimate, short description, tags, download count, and whether the
    model is already installed or already queued.

    Raises:
        HTTPException 422: malformed repo_id.
        HTTPException 404: model does not exist on HuggingFace.
        HTTPException 502: network failure or unexpected HF status.
    """
    # Validate: exactly one '/', non-empty owner and model name, no
    # whitespace. (BUG FIX: the old check only required *a* slash, so
    # ids like 'a/b/c' or '/model' slipped through to the HF API.)
    parts = repo_id.split("/")
    if len(parts) != 2 or not all(parts) or any(c.isspace() for c in repo_id):
        raise HTTPException(422, f"Invalid repo_id {repo_id!r}: must be 'owner/model-name' with no whitespace")

    hf_url = f"https://huggingface.co/api/models/{repo_id}"
    try:
        resp = httpx.get(hf_url, timeout=10.0)
    except httpx.RequestError as exc:
        raise HTTPException(502, f"Network error reaching HuggingFace API: {exc}") from exc

    if resp.status_code == 404:
        raise HTTPException(404, f"Model {repo_id!r} not found on HuggingFace")
    if resp.status_code != 200:
        raise HTTPException(502, f"HuggingFace API returned status {resp.status_code}")

    data = resp.json()
    pipeline_tag = data.get("pipeline_tag")
    adapter_recommendation = _TAG_TO_ADAPTER.get(pipeline_tag) if pipeline_tag else None
    if pipeline_tag and adapter_recommendation is None:
        logger.warning("Unknown pipeline_tag %r for %s — no adapter recommendation", pipeline_tag, repo_id)

    # Estimate model size from the per-file siblings listing. The HF API
    # can return explicit null sizes, so coerce falsy values to 0 to
    # keep sum() from choking on None (BUG FIX).
    siblings = data.get("siblings") or []
    model_size_bytes: int = sum(
        s.get("size") or 0 for s in siblings if isinstance(s, dict)
    )

    # Description: first 300 chars of card data (modelId field used as fallback)
    card_data = data.get("cardData") or {}
    description_raw = card_data.get("description") or data.get("modelId") or ""
    description = description_raw[:300] if description_raw else ""

    return {
        "repo_id": repo_id,
        "pipeline_tag": pipeline_tag,
        "adapter_recommendation": adapter_recommendation,
        "model_size_bytes": model_size_bytes,
        "description": description,
        "tags": data.get("tags") or [],
        "downloads": data.get("downloads") or 0,
        "already_installed": _is_installed(repo_id),
        "already_queued": _is_queued(repo_id),
    }
|
|
|
|
|
|
# ── GET /queue ─────────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/queue")
def get_queue() -> list[dict]:
    """Return all non-dismissed queue entries sorted newest-first."""
    visible = [rec for rec in _read_queue() if rec.get("status") != "dismissed"]
    visible.sort(key=lambda rec: rec.get("queued_at", ""), reverse=True)
    return visible
|
|
|
|
|
|
# ── POST /queue ────────────────────────────────────────────────────────────────
|
|
|
|
class QueueAddRequest(BaseModel):
    """Payload for POST /queue — carries the lookup results forward."""

    # HF repo id, e.g. "owner/model-name".
    repo_id: str
    # As reported by lookup_model; stored for later display/use.
    pipeline_tag: str | None = None
    # Adapter class name suggested by lookup_model, if any.
    adapter_recommendation: str | None = None
|
|
|
|
|
|
@router.post("/queue", status_code=201)
def add_to_queue(req: QueueAddRequest) -> dict:
    """Add a model to the approval queue with status 'pending'.

    Raises HTTPException 409 when the model is already installed or
    already sitting in the queue (non-dismissed).
    """
    repo_id = req.repo_id
    if _is_installed(repo_id):
        raise HTTPException(409, f"{repo_id!r} is already installed")
    if _is_queued(repo_id):
        raise HTTPException(409, f"{repo_id!r} is already in the queue")

    new_entry = {
        "id": str(uuid4()),
        "repo_id": repo_id,
        "pipeline_tag": req.pipeline_tag,
        "adapter_recommendation": req.adapter_recommendation,
        "status": "pending",
        "queued_at": datetime.now(timezone.utc).isoformat(),
    }
    _write_queue([*_read_queue(), new_entry])
    return new_entry
|
|
|
|
|
|
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
|
|
|
|
@router.post("/queue/{entry_id}/approve")
def approve_queue_entry(entry_id: str) -> dict:
    """Approve a pending queue entry and start the background download.

    Raises:
        HTTPException 404: no entry with `entry_id`.
        HTTPException 409: entry exists but is not in 'pending' state.
    """
    entry = _get_queue_entry(entry_id)
    if entry is None:
        raise HTTPException(404, f"Queue entry {entry_id!r} not found")
    if entry.get("status") != "pending":
        raise HTTPException(409, f"Entry is not in pending state (current: {entry.get('status')!r})")

    # Flip to 'downloading' before the thread starts so the UI reflects
    # the transition immediately. (Dropped the unused `updated` local
    # that previously captured this call's return value.)
    _update_queue_entry(entry_id, {"status": "downloading"})

    threading.Thread(
        target=_run_download,
        args=(entry_id, entry["repo_id"], entry.get("pipeline_tag"), entry.get("adapter_recommendation")),
        daemon=True,
        name=f"model-download-{entry_id}",
    ).start()

    return {"ok": True}
|
|
|
|
|
|
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
|
|
|
|
@router.delete("/queue/{entry_id}")
def dismiss_queue_entry(entry_id: str) -> dict:
    """Soft-delete a queue entry by flipping its status to 'dismissed'."""
    if _get_queue_entry(entry_id) is None:
        raise HTTPException(404, f"Queue entry {entry_id!r} not found")
    _update_queue_entry(entry_id, {"status": "dismissed"})
    return {"ok": True}
|
|
|
|
|
|
# ── GET /download/stream ───────────────────────────────────────────────────────
|
|
|
|
@router.get("/download/stream")
def download_stream() -> StreamingResponse:
    """SSE stream of download progress.

    Events emitted (each as a `data:` JSON payload):
        {'type': 'idle'}           — no download has started.
        {'type': 'progress', ...}  — every 0.5 s while downloading.
        {'type': 'done', ...}      — terminal success, then stream ends.
        {'type': 'error', ...}     — terminal failure, then stream ends.
    """

    def generate():
        prog = _download_progress
        # Idle means no download has ever started (or state was reset):
        # neither active nor finished. BUG FIX: the previous condition
        # classified a *failed* finished download as idle, so clients
        # never received the error event.
        if not prog.get("active") and not prog.get("done"):
            yield f"data: {json.dumps({'type': 'idle'})}\n\n"
            return

        import time  # local import, as in the original module style

        # Single loop handles both the already-finished case (terminal
        # event on the first iteration) and live polling.
        while True:
            # Snapshot: the download thread mutates the shared dict.
            p = dict(_download_progress)
            if p.get("done"):
                if p.get("error"):
                    yield f"data: {json.dumps({'type': 'error', 'error': p['error']})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'done', 'repo_id': p.get('repo_id')})}\n\n"
                return
            event = json.dumps({
                "type": "progress",
                "repo_id": p.get("repo_id"),
                "downloaded_bytes": p.get("downloaded_bytes", 0),
                "total_bytes": p.get("total_bytes", 0),
                "pct": p.get("pct", 0.0),
            })
            yield f"data: {event}\n\n"
            time.sleep(0.5)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        # Disable proxy buffering so events reach the browser promptly.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
# ── GET /installed ─────────────────────────────────────────────────────────────
|
|
|
|
def _load_meta(path: Path) -> dict:
    """Best-effort JSON-object load; returns {} on any read/parse failure.

    Failures are logged at debug level instead of silently swallowed,
    and non-dict JSON payloads are treated as missing.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        logger.debug("Could not read model metadata from %s", path)
        return {}
    return data if isinstance(data, dict) else {}


@router.get("/installed")
def list_installed() -> list[dict]:
    """Scan _MODELS_DIR and return info on each installed model.

    A subdirectory counts as a model when it has at least one marker
    file (config.json, training_info.json, or model_info.json). Each
    result carries name, path, type ('finetuned' vs 'downloaded'),
    adapter, total size in bytes, and the originating model id.
    """
    if not _MODELS_DIR.exists():
        return []

    results: list[dict] = []
    for sub in _MODELS_DIR.iterdir():
        if not sub.is_dir():
            continue

        has_training_info = (sub / "training_info.json").exists()
        has_config = (sub / "config.json").exists()
        has_model_info = (sub / "model_info.json").exists()

        if not (has_training_info or has_config or has_model_info):
            continue

        # training_info.json is only written by the fine-tuning flow.
        model_type = "finetuned" if has_training_info else "downloaded"

        # Total on-disk footprint of the model directory.
        size_bytes = sum(f.stat().st_size for f in sub.rglob("*") if f.is_file())

        # Pull adapter/model_id from whichever metadata file is present;
        # downloads and fine-tunes use different key names.
        adapter: str | None = None
        model_id: str | None = None
        if has_model_info:
            info = _load_meta(sub / "model_info.json")
            adapter = info.get("adapter_recommendation")
            model_id = info.get("repo_id")
        elif has_training_info:
            info = _load_meta(sub / "training_info.json")
            adapter = info.get("adapter")
            model_id = info.get("base_model") or info.get("model_id")

        results.append({
            "name": sub.name,
            "path": str(sub),
            "type": model_type,
            "adapter": adapter,
            "size_bytes": size_bytes,
            "model_id": model_id,
        })

    return results
|
|
|
|
|
|
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
|
|
|
|
@router.delete("/installed/{name}")
def delete_installed(name: str) -> dict:
    """Remove an installed model directory by name. Blocks path traversal.

    Two defence layers: a lexical check on the name, then a resolved-path
    containment check against _MODELS_DIR.

    Raises:
        HTTPException 400: invalid name or traversal attempt.
        HTTPException 404: no installed model directory with that name.
    """
    # Lexical check: single path component, not hidden, no '..'.
    if "/" in name or "\\" in name or ".." in name or not name or name.startswith("."):
        raise HTTPException(400, f"Invalid model name {name!r}: must be a single directory name with no path separators or '..'")

    model_path = _MODELS_DIR / name

    # Defence in depth: the resolved path (symlinks followed) must still
    # live inside the models directory.
    try:
        model_path.resolve().relative_to(_MODELS_DIR.resolve())
    except ValueError as exc:
        # Chain the cause (B904) so the traceback shows the failed check.
        raise HTTPException(400, f"Path traversal detected for name {name!r}") from exc

    # is_dir() instead of exists(): a stray plain file with this name
    # would make shutil.rmtree raise (a 500) rather than a clean 404.
    if not model_path.is_dir():
        raise HTTPException(404, f"Installed model {name!r} not found")

    shutil.rmtree(model_path)
    return {"ok": True}
|