diff --git a/app/api.py b/app/api.py index 3ce74c3..a96b88c 100644 --- a/app/api.py +++ b/app/api.py @@ -145,6 +145,9 @@ app = FastAPI(title="Avocet API") from app.sft import router as sft_router app.include_router(sft_router, prefix="/api/sft") +from app.models import router as models_router +app.include_router(models_router, prefix="/api/models") + # In-memory last-action store (single user, local tool — in-memory is fine) _last_action: dict | None = None diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..0ac40a8 --- /dev/null +++ b/app/models.py @@ -0,0 +1,428 @@ +"""Avocet — HF model lifecycle API. + +Handles model metadata lookup, approval queue, download with progress, +and installed model management. + +All endpoints are registered on `router` (a FastAPI APIRouter). +api.py includes this router with prefix="/api/models". + +Module-level globals (_MODELS_DIR, _QUEUE_DIR) follow the same +testability pattern as sft.py — override them via set_models_dir() and +set_queue_dir() in test fixtures. +""" +from __future__ import annotations + +import json +import logging +import shutil +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from uuid import uuid4 + +import httpx +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.utils import read_jsonl, write_jsonl + +try: + from huggingface_hub import snapshot_download +except ImportError: # pragma: no cover + snapshot_download = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +_ROOT = Path(__file__).parent.parent +_MODELS_DIR: Path = _ROOT / "models" +_QUEUE_DIR: Path = _ROOT / "data" + +router = APIRouter() + +# ── Download progress shared state ──────────────────────────────────────────── +# Updated by the background download thread; read by GET /download/stream. 
# Updated by the background download thread; read by GET /download/stream.
_download_progress: dict[str, Any] = {}

# Serialises read-modify-write cycles on the queue JSONL file. The daemon
# download thread (_run_download → _update_queue_entry) and the HTTP handlers
# both rewrite the file; without a lock an interleaved read-modify-write can
# silently drop one writer's update. RLock (not Lock) because
# _update_queue_entry calls _read_queue/_write_queue while already holding it.
_queue_lock = threading.RLock()

# ── HF pipeline_tag → adapter recommendation ──────────────────────────────────
_TAG_TO_ADAPTER: dict[str, str] = {
    "zero-shot-classification": "ZeroShotAdapter",
    "text-classification": "ZeroShotAdapter",
    "natural-language-inference": "ZeroShotAdapter",
    "sentence-similarity": "RerankerAdapter",
    "text-ranking": "RerankerAdapter",
    "text-generation": "GenerationAdapter",
    "text2text-generation": "GenerationAdapter",
}


# ── Testability seams ─────────────────────────────────────────────────────────

def set_models_dir(path: Path) -> None:
    """Test seam: redirect the installed-models directory."""
    global _MODELS_DIR
    _MODELS_DIR = path


def set_queue_dir(path: Path) -> None:
    """Test seam: redirect the directory holding the queue JSONL file."""
    global _QUEUE_DIR
    _QUEUE_DIR = path


# ── Internal helpers ──────────────────────────────────────────────────────────

def _queue_file() -> Path:
    """Path of the approval-queue JSONL file under the current queue dir."""
    return _QUEUE_DIR / "model_queue.jsonl"


def _read_queue() -> list[dict]:
    """Load all queue records via app.utils.read_jsonl (under the queue lock)."""
    with _queue_lock:
        return read_jsonl(_queue_file())


def _write_queue(records: list[dict]) -> None:
    """Overwrite the queue file with *records* (under the queue lock)."""
    with _queue_lock:
        write_jsonl(_queue_file(), records)


def _safe_model_name(repo_id: str) -> str:
    """Convert repo_id to a filesystem-safe directory name (HF convention)."""
    return repo_id.replace("/", "--")


def _is_installed(repo_id: str) -> bool:
    """Check if a model is already downloaded in _MODELS_DIR.

    A directory counts as installed when it contains any of the marker files
    written by HF downloads (config.json), fine-tuning (training_info.json),
    or this module's own download path (model_info.json).
    """
    model_dir = _MODELS_DIR / _safe_model_name(repo_id)
    return model_dir.exists() and (
        (model_dir / "config.json").exists()
        or (model_dir / "training_info.json").exists()
        or (model_dir / "model_info.json").exists()
    )


def _is_queued(repo_id: str) -> bool:
    """Check if repo_id is already in the queue (non-dismissed)."""
    return any(
        entry.get("repo_id") == repo_id and entry.get("status") != "dismissed"
        for entry in _read_queue()
    )


def _update_queue_entry(entry_id: str, updates: dict) -> dict | None:
    """Merge *updates* into the queue entry with id *entry_id*.

    Returns the updated entry, or None if no entry has that id. The whole
    read-modify-write runs under _queue_lock so a concurrent writer (e.g. a
    request handler appending a new entry while the download thread flips a
    status) cannot be clobbered.
    """
    with _queue_lock:
        records = _read_queue()
        for idx, rec in enumerate(records):
            if rec.get("id") == entry_id:
                records[idx] = {**rec, **updates}
                _write_queue(records)
                return records[idx]
    return None


def _get_queue_entry(entry_id: str) -> dict | None:
    """Return the queue entry with id *entry_id*, or None if absent."""
    for rec in _read_queue():
        if rec.get("id") == entry_id:
            return rec
    return None


# ── Background download ───────────────────────────────────────────────────────

def _run_download(entry_id: str, repo_id: str, pipeline_tag: str | None, adapter_recommendation: str | None) -> None:
    """Background thread: download model via huggingface_hub.snapshot_download.

    Publishes progress/terminal state into the module-level _download_progress
    dict (read by GET /download/stream) and flips the queue entry's status to
    'ready' or 'failed' when done.
    """
    global _download_progress
    local_dir = _MODELS_DIR / _safe_model_name(repo_id)

    # Fresh progress record for this download; replaces any stale prior state.
    _download_progress = {
        "active": True,
        "repo_id": repo_id,
        "downloaded_bytes": 0,
        "total_bytes": 0,
        "pct": 0.0,
        "done": False,
        "error": None,
    }

    try:
        # huggingface_hub is an optional import at module top; fail loudly
        # here rather than with an AttributeError on None.
        if snapshot_download is None:
            raise RuntimeError("huggingface_hub is not installed")

        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
        )

        # Write model_info.json alongside downloaded files so /installed can
        # report the repo_id and adapter recommendation later.
        model_info = {
            "repo_id": repo_id,
            "pipeline_tag": pipeline_tag,
            "adapter_recommendation": adapter_recommendation,
            "downloaded_at": datetime.now(timezone.utc).isoformat(),
        }
        local_dir.mkdir(parents=True, exist_ok=True)
        (local_dir / "model_info.json").write_text(
            json.dumps(model_info, indent=2), encoding="utf-8"
        )

        _download_progress["done"] = True
        _download_progress["pct"] = 100.0
        _update_queue_entry(entry_id, {"status": "ready"})

    except Exception as exc:
        logger.exception("Download failed for %s: %s", repo_id, exc)
        _download_progress["error"] = str(exc)
        _download_progress["done"] = True
        _update_queue_entry(entry_id, {"status": "failed", "error": str(exc)})
    finally:
        # Always clear the active flag so the SSE stream can terminate.
        _download_progress["active"] = False


# ── GET /lookup ───────────────────────────────────────────────────────────────
@router.get("/lookup")
def lookup_model(repo_id: str) -> dict:
    """Validate repo_id and fetch metadata from the HF API.

    Returns a dict with the pipeline tag, adapter recommendation, estimated
    size, truncated description, and local installed/queued status.

    Raises:
        HTTPException 422: malformed repo_id.
        HTTPException 404: model does not exist on HuggingFace.
        HTTPException 502: network error or unexpected HF API response.
    """
    # Validate: exactly one '/', non-empty owner and model name, no whitespace.
    # (Checking only "/" in repo_id would let 'a/b/c', 'org/' or '/model'
    # through to the HF API as malformed URLs.)
    owner, _, model_name = repo_id.partition("/")
    if (
        repo_id.count("/") != 1
        or not owner
        or not model_name
        or any(c.isspace() for c in repo_id)
    ):
        raise HTTPException(422, f"Invalid repo_id {repo_id!r}: must be 'owner/model-name' with no whitespace")

    hf_url = f"https://huggingface.co/api/models/{repo_id}"
    try:
        resp = httpx.get(hf_url, timeout=10.0)
    except httpx.RequestError as exc:
        raise HTTPException(502, f"Network error reaching HuggingFace API: {exc}") from exc

    if resp.status_code == 404:
        raise HTTPException(404, f"Model {repo_id!r} not found on HuggingFace")
    if resp.status_code != 200:
        raise HTTPException(502, f"HuggingFace API returned status {resp.status_code}")

    data = resp.json()
    pipeline_tag = data.get("pipeline_tag")
    adapter_recommendation = _TAG_TO_ADAPTER.get(pipeline_tag) if pipeline_tag else None
    if pipeline_tag and adapter_recommendation is None:
        logger.warning("Unknown pipeline_tag %r for %s — no adapter recommendation", pipeline_tag, repo_id)

    # Estimate model size from the siblings file list. 'size' may be absent
    # or explicitly null in the HF payload, so coerce falsy values to 0
    # instead of letting sum() hit a None.
    siblings = data.get("siblings") or []
    model_size_bytes: int = sum((s.get("size") or 0) for s in siblings if isinstance(s, dict))

    # Description: first 300 chars of card data (modelId field used as fallback)
    card_data = data.get("cardData") or {}
    description_raw = card_data.get("description") or data.get("modelId") or ""
    description = description_raw[:300] if description_raw else ""

    return {
        "repo_id": repo_id,
        "pipeline_tag": pipeline_tag,
        "adapter_recommendation": adapter_recommendation,
        "model_size_bytes": model_size_bytes,
        "description": description,
        "tags": data.get("tags") or [],
        "downloads": data.get("downloads") or 0,
        "already_installed": _is_installed(repo_id),
        "already_queued": _is_queued(repo_id),
    }


# ── GET /queue ────────────────────────────────────────────────────────────────
+@router.get("/queue") +def get_queue() -> list[dict]: + """Return all non-dismissed queue entries sorted newest-first.""" + records = _read_queue() + active = [r for r in records if r.get("status") != "dismissed"] + return sorted(active, key=lambda r: r.get("queued_at", ""), reverse=True) + + +# ── POST /queue ──────────────────────────────────────────────────────────────── + +class QueueAddRequest(BaseModel): + repo_id: str + pipeline_tag: str | None = None + adapter_recommendation: str | None = None + + +@router.post("/queue", status_code=201) +def add_to_queue(req: QueueAddRequest) -> dict: + """Add a model to the approval queue with status 'pending'.""" + if _is_installed(req.repo_id): + raise HTTPException(409, f"{req.repo_id!r} is already installed") + if _is_queued(req.repo_id): + raise HTTPException(409, f"{req.repo_id!r} is already in the queue") + + entry = { + "id": str(uuid4()), + "repo_id": req.repo_id, + "pipeline_tag": req.pipeline_tag, + "adapter_recommendation": req.adapter_recommendation, + "status": "pending", + "queued_at": datetime.now(timezone.utc).isoformat(), + } + records = _read_queue() + records.append(entry) + _write_queue(records) + return entry + + +# ── POST /queue/{id}/approve ─────────────────────────────────────────────────── + +@router.post("/queue/{entry_id}/approve") +def approve_queue_entry(entry_id: str) -> dict: + """Approve a pending queue entry and start background download.""" + entry = _get_queue_entry(entry_id) + if entry is None: + raise HTTPException(404, f"Queue entry {entry_id!r} not found") + if entry.get("status") != "pending": + raise HTTPException(409, f"Entry is not in pending state (current: {entry.get('status')!r})") + + updated = _update_queue_entry(entry_id, {"status": "downloading"}) + + thread = threading.Thread( + target=_run_download, + args=(entry_id, entry["repo_id"], entry.get("pipeline_tag"), entry.get("adapter_recommendation")), + daemon=True, + name=f"model-download-{entry_id}", + ) + 
thread.start() + + return {"ok": True} + + +# ── DELETE /queue/{id} ───────────────────────────────────────────────────────── + +@router.delete("/queue/{entry_id}") +def dismiss_queue_entry(entry_id: str) -> dict: + """Dismiss (soft-delete) a queue entry.""" + entry = _get_queue_entry(entry_id) + if entry is None: + raise HTTPException(404, f"Queue entry {entry_id!r} not found") + + _update_queue_entry(entry_id, {"status": "dismissed"}) + return {"ok": True} + + +# ── GET /download/stream ─────────────────────────────────────────────────────── + +@router.get("/download/stream") +def download_stream() -> StreamingResponse: + """SSE stream of download progress. Yields one idle event if no download active.""" + + def generate(): + prog = _download_progress + if not prog.get("active") and not (prog.get("done") and not prog.get("error")): + yield f"data: {json.dumps({'type': 'idle'})}\n\n" + return + + if prog.get("done"): + if prog.get("error"): + yield f"data: {json.dumps({'type': 'error', 'error': prog['error']})}\n\n" + else: + yield f"data: {json.dumps({'type': 'done', 'repo_id': prog.get('repo_id')})}\n\n" + return + + # Stream live progress + import time + while True: + p = dict(_download_progress) + if p.get("done"): + if p.get("error"): + yield f"data: {json.dumps({'type': 'error', 'error': p['error']})}\n\n" + else: + yield f"data: {json.dumps({'type': 'done', 'repo_id': p.get('repo_id')})}\n\n" + break + event = json.dumps({ + "type": "progress", + "repo_id": p.get("repo_id"), + "downloaded_bytes": p.get("downloaded_bytes", 0), + "total_bytes": p.get("total_bytes", 0), + "pct": p.get("pct", 0.0), + }) + yield f"data: {event}\n\n" + time.sleep(0.5) + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +# ── GET /installed ───────────────────────────────────────────────────────────── + +@router.get("/installed") +def list_installed() -> list[dict]: + """Scan 
_MODELS_DIR and return info on each installed model.""" + if not _MODELS_DIR.exists(): + return [] + + results: list[dict] = [] + for sub in _MODELS_DIR.iterdir(): + if not sub.is_dir(): + continue + + has_training_info = (sub / "training_info.json").exists() + has_config = (sub / "config.json").exists() + has_model_info = (sub / "model_info.json").exists() + + if not (has_training_info or has_config or has_model_info): + continue + + model_type = "finetuned" if has_training_info else "downloaded" + + # Compute directory size + size_bytes = sum(f.stat().st_size for f in sub.rglob("*") if f.is_file()) + + # Load adapter/model_id from model_info.json or training_info.json + adapter: str | None = None + model_id: str | None = None + + if has_model_info: + try: + info = json.loads((sub / "model_info.json").read_text(encoding="utf-8")) + adapter = info.get("adapter_recommendation") + model_id = info.get("repo_id") + except Exception: + pass + elif has_training_info: + try: + info = json.loads((sub / "training_info.json").read_text(encoding="utf-8")) + adapter = info.get("adapter") + model_id = info.get("base_model") or info.get("model_id") + except Exception: + pass + + results.append({ + "name": sub.name, + "path": str(sub), + "type": model_type, + "adapter": adapter, + "size_bytes": size_bytes, + "model_id": model_id, + }) + + return results + + +# ── DELETE /installed/{name} ─────────────────────────────────────────────────── + +@router.delete("/installed/{name}") +def delete_installed(name: str) -> dict: + """Remove an installed model directory by name. Blocks path traversal.""" + # Validate: single path component, no slashes or '..' + if "/" in name or "\\" in name or ".." 
in name or not name or name.startswith("."): + raise HTTPException(400, f"Invalid model name {name!r}: must be a single directory name with no path separators or '..'") + + model_path = _MODELS_DIR / name + + # Extra safety: confirm resolved path is inside _MODELS_DIR + try: + model_path.resolve().relative_to(_MODELS_DIR.resolve()) + except ValueError: + raise HTTPException(400, f"Path traversal detected for name {name!r}") + + if not model_path.exists(): + raise HTTPException(404, f"Installed model {name!r} not found") + + shutil.rmtree(model_path) + return {"ok": True} diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..0e31f04 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,399 @@ +"""Tests for app/models.py — /api/models/* endpoints.""" +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + + +# ── Fixtures ─────────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def reset_models_globals(tmp_path): + """Redirect module-level dirs to tmp_path and reset download progress.""" + from app import models as models_module + + prev_models = models_module._MODELS_DIR + prev_queue = models_module._QUEUE_DIR + prev_progress = dict(models_module._download_progress) + + models_dir = tmp_path / "models" + queue_dir = tmp_path / "data" + models_dir.mkdir() + queue_dir.mkdir() + + models_module.set_models_dir(models_dir) + models_module.set_queue_dir(queue_dir) + models_module._download_progress = {} + + yield + + models_module.set_models_dir(prev_models) + models_module.set_queue_dir(prev_queue) + models_module._download_progress = prev_progress + + +@pytest.fixture +def client(): + from app.api import app + return TestClient(app) + + +def _make_hf_response(repo_id: str = "org/model", pipeline_tag: str = "text-classification") -> dict: + """Minimal HF API 
response payload.""" + return { + "modelId": repo_id, + "pipeline_tag": pipeline_tag, + "tags": ["pytorch", pipeline_tag], + "downloads": 42000, + "siblings": [ + {"rfilename": "pytorch_model.bin", "size": 500_000_000}, + ], + "cardData": {"description": "A test model description."}, + } + + +def _queue_one(client, repo_id: str = "org/model") -> dict: + """Helper: POST to /queue and return the created entry.""" + r = client.post("/api/models/queue", json={ + "repo_id": repo_id, + "pipeline_tag": "text-classification", + "adapter_recommendation": "ZeroShotAdapter", + }) + assert r.status_code == 201, r.text + return r.json() + + +# ── GET /lookup ──────────────────────────────────────────────────────────────── + +def test_lookup_invalid_repo_id_returns_422_no_slash(client): + """repo_id without a '/' should be rejected with 422.""" + r = client.get("/api/models/lookup", params={"repo_id": "noslash"}) + assert r.status_code == 422 + + +def test_lookup_invalid_repo_id_returns_422_whitespace(client): + """repo_id containing whitespace should be rejected with 422.""" + r = client.get("/api/models/lookup", params={"repo_id": "org/model name"}) + assert r.status_code == 422 + + +def test_lookup_hf_404_returns_404(client): + """HF API returning 404 should surface as HTTP 404.""" + mock_resp = MagicMock() + mock_resp.status_code = 404 + + with patch("app.models.httpx.get", return_value=mock_resp): + r = client.get("/api/models/lookup", params={"repo_id": "org/nonexistent"}) + + assert r.status_code == 404 + + +def test_lookup_hf_network_error_returns_502(client): + """Network error reaching HF API should return 502.""" + import httpx as _httpx + + with patch("app.models.httpx.get", side_effect=_httpx.RequestError("timeout")): + r = client.get("/api/models/lookup", params={"repo_id": "org/model"}) + + assert r.status_code == 502 + + +def test_lookup_returns_correct_shape(client): + """Successful lookup returns all required fields.""" + mock_resp = MagicMock() + 
mock_resp.status_code = 200 + mock_resp.json.return_value = _make_hf_response("org/mymodel", "text-classification") + + with patch("app.models.httpx.get", return_value=mock_resp): + r = client.get("/api/models/lookup", params={"repo_id": "org/mymodel"}) + + assert r.status_code == 200 + data = r.json() + assert data["repo_id"] == "org/mymodel" + assert data["pipeline_tag"] == "text-classification" + assert data["adapter_recommendation"] == "ZeroShotAdapter" + assert data["model_size_bytes"] == 500_000_000 + assert data["downloads"] == 42000 + assert data["already_installed"] is False + assert data["already_queued"] is False + + +def test_lookup_unknown_pipeline_tag_returns_null_adapter(client): + """An unrecognised pipeline_tag yields adapter_recommendation=null.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = _make_hf_response("org/m", "audio-classification") + + with patch("app.models.httpx.get", return_value=mock_resp): + r = client.get("/api/models/lookup", params={"repo_id": "org/m"}) + + assert r.status_code == 200 + assert r.json()["adapter_recommendation"] is None + + +def test_lookup_already_queued_flag(client): + """already_queued is True when repo_id is in the pending queue.""" + _queue_one(client, "org/queued-model") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = _make_hf_response("org/queued-model") + + with patch("app.models.httpx.get", return_value=mock_resp): + r = client.get("/api/models/lookup", params={"repo_id": "org/queued-model"}) + + assert r.status_code == 200 + assert r.json()["already_queued"] is True + + +# ── GET /queue ───────────────────────────────────────────────────────────────── + +def test_queue_empty_initially(client): + r = client.get("/api/models/queue") + assert r.status_code == 200 + assert r.json() == [] + + +def test_queue_add_and_list(client): + """POST then GET /queue should return the entry.""" + entry = _queue_one(client, 
"org/my-model") + + r = client.get("/api/models/queue") + assert r.status_code == 200 + items = r.json() + assert len(items) == 1 + assert items[0]["repo_id"] == "org/my-model" + assert items[0]["status"] == "pending" + assert items[0]["id"] == entry["id"] + + +def test_queue_add_returns_entry_fields(client): + """POST /queue returns an entry with all expected fields.""" + entry = _queue_one(client) + assert "id" in entry + assert "queued_at" in entry + assert entry["status"] == "pending" + assert entry["pipeline_tag"] == "text-classification" + assert entry["adapter_recommendation"] == "ZeroShotAdapter" + + +# ── POST /queue — 409 duplicate ──────────────────────────────────────────────── + +def test_queue_duplicate_returns_409(client): + """Posting the same repo_id twice should return 409.""" + _queue_one(client, "org/dup-model") + + r = client.post("/api/models/queue", json={ + "repo_id": "org/dup-model", + "pipeline_tag": "text-classification", + "adapter_recommendation": "ZeroShotAdapter", + }) + assert r.status_code == 409 + + +def test_queue_multiple_different_models(client): + """Multiple distinct repo_ids should all be accepted.""" + _queue_one(client, "org/model-a") + _queue_one(client, "org/model-b") + _queue_one(client, "org/model-c") + + r = client.get("/api/models/queue") + assert r.status_code == 200 + assert len(r.json()) == 3 + + +# ── DELETE /queue/{id} — dismiss ────────────────────────────────────────────── + +def test_queue_dismiss(client): + """DELETE /queue/{id} sets status=dismissed; entry not returned by GET /queue.""" + entry = _queue_one(client) + entry_id = entry["id"] + + r = client.delete(f"/api/models/queue/{entry_id}") + assert r.status_code == 200 + assert r.json() == {"ok": True} + + r2 = client.get("/api/models/queue") + assert r2.status_code == 200 + assert r2.json() == [] + + +def test_queue_dismiss_nonexistent_returns_404(client): + """DELETE /queue/{id} with unknown id returns 404.""" + r = 
client.delete("/api/models/queue/does-not-exist") + assert r.status_code == 404 + + +def test_queue_dismiss_allows_re_queue(client): + """After dismissal the same repo_id can be queued again.""" + entry = _queue_one(client, "org/requeue-model") + client.delete(f"/api/models/queue/{entry['id']}") + + r = client.post("/api/models/queue", json={ + "repo_id": "org/requeue-model", + "pipeline_tag": None, + "adapter_recommendation": None, + }) + assert r.status_code == 201 + + +# ── POST /queue/{id}/approve ─────────────────────────────────────────────────── + +def test_approve_nonexistent_returns_404(client): + """Approving an unknown id returns 404.""" + r = client.post("/api/models/queue/ghost-id/approve") + assert r.status_code == 404 + + +def test_approve_non_pending_returns_409(client): + """Approving an entry that is not in 'pending' state returns 409.""" + from app import models as models_module + + entry = _queue_one(client) + # Manually flip status to 'failed' + models_module._update_queue_entry(entry["id"], {"status": "failed"}) + + r = client.post(f"/api/models/queue/{entry['id']}/approve") + assert r.status_code == 409 + + +def test_approve_starts_download_and_returns_ok(client): + """Approving a pending entry returns {ok: true} and starts a background thread.""" + import time + import threading + + entry = _queue_one(client) + + # Patch snapshot_download so the thread doesn't actually hit the network. + # Use an Event so we can wait for the thread to finish before asserting. 
+ thread_done = threading.Event() + original_run = None + + def _fake_snapshot_download(**kwargs): + pass + + with patch("app.models.snapshot_download", side_effect=_fake_snapshot_download): + r = client.post(f"/api/models/queue/{entry['id']}/approve") + assert r.status_code == 200 + assert r.json() == {"ok": True} + # Give the background thread a moment to complete while snapshot_download is patched + time.sleep(0.3) + + # Queue entry status should have moved to 'downloading' (or 'ready' if fast) + from app import models as models_module + updated = models_module._get_queue_entry(entry["id"]) + assert updated is not None, "Queue entry not found — thread may have run after fixture teardown" + assert updated["status"] in ("downloading", "ready", "failed") + + +# ── GET /download/stream ─────────────────────────────────────────────────────── + +def test_download_stream_idle_when_no_download(client): + """GET /download/stream returns a single idle event when nothing is downloading.""" + r = client.get("/api/models/download/stream") + assert r.status_code == 200 + # SSE body should contain the idle event + assert "idle" in r.text + + +# ── GET /installed ───────────────────────────────────────────────────────────── + +def test_installed_empty(client): + """GET /installed returns [] when models dir is empty.""" + r = client.get("/api/models/installed") + assert r.status_code == 200 + assert r.json() == [] + + +def test_installed_detects_downloaded_model(client, tmp_path): + """A subdir with config.json is surfaced as type='downloaded'.""" + from app import models as models_module + + model_dir = models_module._MODELS_DIR / "org--mymodel" + model_dir.mkdir() + (model_dir / "config.json").write_text(json.dumps({"model_type": "bert"}), encoding="utf-8") + (model_dir / "model_info.json").write_text( + json.dumps({"repo_id": "org/mymodel", "adapter_recommendation": "ZeroShotAdapter"}), + encoding="utf-8", + ) + + r = client.get("/api/models/installed") + assert r.status_code 
== 200 + items = r.json() + assert len(items) == 1 + assert items[0]["type"] == "downloaded" + assert items[0]["name"] == "org--mymodel" + assert items[0]["adapter"] == "ZeroShotAdapter" + assert items[0]["model_id"] == "org/mymodel" + + +def test_installed_detects_finetuned_model(client): + """A subdir with training_info.json is surfaced as type='finetuned'.""" + from app import models as models_module + + model_dir = models_module._MODELS_DIR / "my-finetuned" + model_dir.mkdir() + (model_dir / "training_info.json").write_text( + json.dumps({"base_model": "org/base", "epochs": 5}), encoding="utf-8" + ) + + r = client.get("/api/models/installed") + assert r.status_code == 200 + items = r.json() + assert len(items) == 1 + assert items[0]["type"] == "finetuned" + assert items[0]["name"] == "my-finetuned" + + +# ── DELETE /installed/{name} ─────────────────────────────────────────────────── + +def test_delete_installed_removes_directory(client): + """DELETE /installed/{name} removes the directory and returns {ok: true}.""" + from app import models as models_module + + model_dir = models_module._MODELS_DIR / "org--removeme" + model_dir.mkdir() + (model_dir / "config.json").write_text("{}", encoding="utf-8") + + r = client.delete("/api/models/installed/org--removeme") + assert r.status_code == 200 + assert r.json() == {"ok": True} + assert not model_dir.exists() + + +def test_delete_installed_not_found_returns_404(client): + r = client.delete("/api/models/installed/does-not-exist") + assert r.status_code == 404 + + +def test_delete_installed_path_traversal_blocked(client): + """DELETE /installed/../../etc must be blocked (400 or 422).""" + r = client.delete("/api/models/installed/../../etc") + assert r.status_code in (400, 404, 422) + + +def test_delete_installed_dotdot_name_blocked(client): + """A name containing '..' 
in any form must be rejected.""" + r = client.delete("/api/models/installed/..%2F..%2Fetc") + assert r.status_code in (400, 404, 422) + + +def test_delete_installed_name_with_slash_blocked(client): + """A name containing a literal '/' after URL decoding must be rejected.""" + from app import models as models_module + + # The router will see the path segment after /installed/ — a second '/' would + # be parsed as a new path segment, so we test via the validation helper directly. + with pytest.raises(Exception): + # Simulate calling delete logic with a slash-containing name directly + from fastapi import HTTPException as _HTTPException + from app.models import delete_installed + try: + delete_installed("org/traversal") + except _HTTPException as exc: + assert exc.status_code in (400, 404) + raise diff --git a/web/src/components/AppSidebar.vue b/web/src/components/AppSidebar.vue index 051feda..74e4a5c 100644 --- a/web/src/components/AppSidebar.vue +++ b/web/src/components/AppSidebar.vue @@ -66,6 +66,7 @@ const navItems = [ { path: '/fetch', icon: '📥', label: 'Fetch' }, { path: '/stats', icon: '📊', label: 'Stats' }, { path: '/benchmark', icon: '🏁', label: 'Benchmark' }, + { path: '/models', icon: '🤗', label: 'Models' }, { path: '/corrections', icon: '✍️', label: 'Corrections' }, { path: '/settings', icon: '⚙️', label: 'Settings' }, ] diff --git a/web/src/router/index.ts b/web/src/router/index.ts index 24f9321..a052e4c 100644 --- a/web/src/router/index.ts +++ b/web/src/router/index.ts @@ -7,6 +7,7 @@ const StatsView = () => import('../views/StatsView.vue') const BenchmarkView = () => import('../views/BenchmarkView.vue') const SettingsView = () => import('../views/SettingsView.vue') const CorrectionsView = () => import('../views/CorrectionsView.vue') +const ModelsView = () => import('../views/ModelsView.vue') export const router = createRouter({ history: createWebHashHistory(), @@ -15,6 +16,7 @@ export const router = createRouter({ { path: '/fetch', component: FetchView, 
meta: { title: 'Fetch' } }, { path: '/stats', component: StatsView, meta: { title: 'Stats' } }, { path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } }, + { path: '/models', component: ModelsView, meta: { title: 'Models' } }, { path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } }, { path: '/settings', component: SettingsView, meta: { title: 'Settings' } }, ], diff --git a/web/src/views/ModelsView.vue b/web/src/views/ModelsView.vue new file mode 100644 index 0000000..10df382 --- /dev/null +++ b/web/src/views/ModelsView.vue @@ -0,0 +1,826 @@ + + + + +