Compare commits
17 commits
feat/sft-c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| e6b64d6efe | |||
| fee0cdb4a8 | |||
| 3299c0e23a | |||
| dc246df42d | |||
| 7a392df492 | |||
| 891142570b | |||
| a271278dc9 | |||
| dffb1d0d7a | |||
| ce12b29c94 | |||
| 49ec85706c | |||
| 478a47f6e0 | |||
| 7c304ebc45 | |||
| b6b3d2c390 | |||
| a7cb3ae62a | |||
| c5eaacc767 | |||
| 9633d9a535 | |||
| cfc09b4731 |
23 changed files with 5792 additions and 44 deletions
19
.env.example
Normal file
19
.env.example
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Avocet — environment variable configuration
|
||||||
|
# Copy to .env and fill in values. All keys are optional.
|
||||||
|
# label_tool.yaml takes precedence over env vars where both exist.
|
||||||
|
|
||||||
|
# ── Local inference (Ollama) ───────────────────────────────────────────────────
|
||||||
|
# OLLAMA_HOST defaults to http://localhost:11434 if unset.
|
||||||
|
OLLAMA_HOST=http://localhost:11434
|
||||||
|
OLLAMA_MODEL=llama3.2:3b
|
||||||
|
|
||||||
|
# ── cf-orch coordinator (paid/premium tiers) ───────────────────────────────────
|
||||||
|
# Required for multi-GPU LLM benchmarking via the cf-orch benchmark harness.
|
||||||
|
# Free-tier users can leave these unset and use Ollama only.
|
||||||
|
CF_ORCH_URL=http://localhost:7700
|
||||||
|
CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx
|
||||||
|
|
||||||
|
# ── Cloud LLM backends (optional — paid/premium) ──────────────────────────────
|
||||||
|
# Set one of these to use a cloud LLM instead of a local model.
|
||||||
|
# ANTHROPIC_API_KEY=sk-ant-...
|
||||||
|
# OPENAI_API_KEY=sk-...
|
||||||
54
app/api.py
54
app/api.py
|
|
@ -145,6 +145,16 @@ app = FastAPI(title="Avocet API")
|
||||||
from app.sft import router as sft_router
|
from app.sft import router as sft_router
|
||||||
app.include_router(sft_router, prefix="/api/sft")
|
app.include_router(sft_router, prefix="/api/sft")
|
||||||
|
|
||||||
|
from app.models import router as models_router
|
||||||
|
import app.models as _models_module
|
||||||
|
app.include_router(models_router, prefix="/api/models")
|
||||||
|
|
||||||
|
from app.cforch import router as cforch_router
|
||||||
|
app.include_router(cforch_router, prefix="/api/cforch")
|
||||||
|
|
||||||
|
from app.imitate import router as imitate_router
|
||||||
|
app.include_router(imitate_router, prefix="/api/imitate")
|
||||||
|
|
||||||
# In-memory last-action store (single user, local tool — in-memory is fine)
|
# In-memory last-action store (single user, local tool — in-memory is fine)
|
||||||
_last_action: dict | None = None
|
_last_action: dict | None = None
|
||||||
|
|
||||||
|
|
@ -298,10 +308,18 @@ def get_stats():
|
||||||
lbl = r.get("label", "")
|
lbl = r.get("label", "")
|
||||||
if lbl:
|
if lbl:
|
||||||
counts[lbl] = counts.get(lbl, 0) + 1
|
counts[lbl] = counts.get(lbl, 0) + 1
|
||||||
|
benchmark_results: dict = {}
|
||||||
|
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||||
|
if benchmark_path.exists():
|
||||||
|
try:
|
||||||
|
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return {
|
return {
|
||||||
"total": len(records),
|
"total": len(records),
|
||||||
"counts": counts,
|
"counts": counts,
|
||||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||||
|
"benchmark_results": benchmark_results,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -336,6 +354,36 @@ from fastapi.responses import StreamingResponse
|
||||||
# Benchmark endpoints
|
# Benchmark endpoints
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@app.get("/api/benchmark/models")
def get_benchmark_models() -> dict:
    """Return installed models grouped by adapter_type category.

    Scans the models directory for subdirectories, reads each one's
    model_info.json (when present) for adapter_type/repo_id, and buckets
    unrecognized adapter types under "Unknown".
    """
    models_dir: Path = _models_module._MODELS_DIR
    categories: dict[str, list[dict]] = {
        "ZeroShotAdapter": [],
        "RerankerAdapter": [],
        "GenerationAdapter": [],
        "Unknown": [],
    }
    if not models_dir.exists():
        return {"categories": categories}

    for sub in models_dir.iterdir():
        if not sub.is_dir():
            continue
        adapter_type = "Unknown"
        repo_id: str | None = None
        info_path = sub / "model_info.json"
        if info_path.exists():
            try:
                info = json.loads(info_path.read_text(encoding="utf-8"))
                adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown"
                repo_id = info.get("repo_id")
            except Exception:
                # Corrupt or unreadable model_info.json: keep the "Unknown" defaults.
                pass
        bucket = adapter_type if adapter_type in categories else "Unknown"
        categories[bucket].append(
            {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type}
        )
    return {"categories": categories}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/benchmark/results")
|
@app.get("/api/benchmark/results")
|
||||||
def get_benchmark_results():
|
def get_benchmark_results():
|
||||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
"""Return the most recently saved benchmark results, or an empty envelope."""
|
||||||
|
|
@ -346,13 +394,17 @@ def get_benchmark_results():
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/benchmark/run")
|
@app.get("/api/benchmark/run")
|
||||||
def run_benchmark(include_slow: bool = False):
|
def run_benchmark(include_slow: bool = False, model_names: str = ""):
|
||||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
||||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
||||||
cmd = [python_bin, script, "--score", "--save"]
|
cmd = [python_bin, script, "--score", "--save"]
|
||||||
if include_slow:
|
if include_slow:
|
||||||
cmd.append("--include-slow")
|
cmd.append("--include-slow")
|
||||||
|
if model_names:
|
||||||
|
names = [n.strip() for n in model_names.split(",") if n.strip()]
|
||||||
|
if names:
|
||||||
|
cmd.extend(["--models"] + names)
|
||||||
|
|
||||||
def generate():
|
def generate():
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
337
app/cforch.py
Normal file
337
app/cforch.py
Normal file
|
|
@ -0,0 +1,337 @@
|
||||||
|
"""Avocet — cf-orch benchmark integration API.
|
||||||
|
|
||||||
|
Wraps cf-orch's benchmark.py script and exposes it via the Avocet API.
|
||||||
|
Config is read from label_tool.yaml under the `cforch:` key.
|
||||||
|
|
||||||
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||||
|
api.py includes this router with prefix="/api/cforch".
|
||||||
|
|
||||||
|
Module-level globals (_CONFIG_DIR, _BENCH_RUNNING, _bench_proc) follow the
|
||||||
|
same testability pattern as sft.py — override _CONFIG_DIR via set_config_dir()
|
||||||
|
in test fixtures.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess as _subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).parent.parent
|
||||||
|
_CONFIG_DIR: Path | None = None # override in tests
|
||||||
|
_BENCH_RUNNING: bool = False
|
||||||
|
_bench_proc: Any = None # live Popen object while benchmark runs
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def set_config_dir(path: Path | None) -> None:
    """Test seam: redirect config lookups to *path* (None restores the default)."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _config_file() -> Path:
    """Return the path to label_tool.yaml, honoring the set_config_dir() override."""
    base = _CONFIG_DIR if _CONFIG_DIR is not None else _ROOT / "config"
    return base / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cforch_config() -> dict:
    """Read label_tool.yaml cforch section, falling back to environment variables.

    Priority (highest to lowest):
    1. label_tool.yaml cforch: key
    2. Environment variables (CF_ORCH_URL, CF_LICENSE_KEY, OLLAMA_HOST, OLLAMA_MODEL)
    """
    cfg_path = _config_file()
    file_cfg: dict = {}
    if cfg_path.exists():
        try:
            parsed = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
            file_cfg = parsed.get("cforch", {}) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse cforch config %s: %s", cfg_path, exc)

    def _from_file_or_env(yaml_key: str, env_key: str) -> str:
        # Env var is used only when the yaml value is absent/empty.
        return file_cfg.get(yaml_key, "") or os.environ.get(env_key, "")

    return {
        **file_cfg,
        "coordinator_url": _from_file_or_env("coordinator_url", "CF_ORCH_URL"),
        "license_key": _from_file_or_env("license_key", "CF_LICENSE_KEY"),
        "ollama_url": _from_file_or_env("ollama_url", "OLLAMA_HOST"),
        "ollama_model": _from_file_or_env("ollama_model", "OLLAMA_MODEL"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_ansi(text: str) -> str:
|
||||||
|
"""Remove ANSI escape codes from a string."""
|
||||||
|
return re.sub(r'\x1b\[[0-9;]*m', '', text)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_latest_summary(results_dir: str | None) -> Path | None:
|
||||||
|
"""Find the newest summary.json under results_dir, or None if not found."""
|
||||||
|
if not results_dir:
|
||||||
|
return None
|
||||||
|
rdir = Path(results_dir)
|
||||||
|
if not rdir.exists():
|
||||||
|
return None
|
||||||
|
# Subdirs are named YYYY-MM-DD-HHMMSS; sort lexicographically for chronological order
|
||||||
|
subdirs = sorted(
|
||||||
|
[d for d in rdir.iterdir() if d.is_dir()],
|
||||||
|
key=lambda d: d.name,
|
||||||
|
)
|
||||||
|
for subdir in reversed(subdirs):
|
||||||
|
summary = subdir / "summary.json"
|
||||||
|
if summary.exists():
|
||||||
|
return summary
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /tasks ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/tasks")
def get_tasks() -> dict:
    """Return task list from bench_tasks.yaml."""
    empty = {"tasks": [], "types": []}
    cfg = _load_cforch_config()
    tasks_path = cfg.get("bench_tasks", "")
    if not tasks_path:
        return empty

    path = Path(tasks_path)
    if not path.exists():
        return empty

    try:
        doc = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse bench_tasks.yaml %s: %s", path, exc)
        return empty

    tasks = [
        {
            "id": entry.get("id", ""),
            "name": entry.get("name", ""),
            "type": entry.get("type", ""),
            "prompt": (entry.get("prompt") or "").strip(),
            "system": (entry.get("system") or "").strip(),
        }
        for entry in (doc.get("tasks", []) or [])
        if isinstance(entry, dict)
    ]
    # Distinct task types, preserving first-seen order.
    types = list(dict.fromkeys(t["type"] for t in tasks if t["type"]))
    return {"tasks": tasks, "types": types}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/models")
def get_models() -> dict:
    """Return model list from bench_models.yaml."""
    cfg = _load_cforch_config()
    models_path = cfg.get("bench_models", "")
    if not models_path:
        return {"models": []}

    path = Path(models_path)
    if not path.exists():
        return {"models": []}

    try:
        doc = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse bench_models.yaml %s: %s", path, exc)
        return {"models": []}

    models = [
        {
            "name": entry.get("name", ""),
            "id": entry.get("id", ""),
            "service": entry.get("service", "ollama"),
            "tags": entry.get("tags", []) or [],
            "vram_estimate_mb": entry.get("vram_estimate_mb", 0),
        }
        for entry in (doc.get("models", []) or [])
        if isinstance(entry, dict)
    ]
    return {"models": models}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/run")
def run_benchmark(
    task_ids: str = "",
    model_tags: str = "",
    coordinator_url: str = "",
    ollama_url: str = "",
) -> StreamingResponse:
    """Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
    global _BENCH_RUNNING, _bench_proc

    if _BENCH_RUNNING:
        raise HTTPException(409, "A benchmark is already running")

    cfg = _load_cforch_config()
    bench_script = cfg.get("bench_script", "")
    bench_tasks = cfg.get("bench_tasks", "")
    bench_models = cfg.get("bench_models", "")
    results_dir = cfg.get("results_dir", "")
    python_bin = cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")
    cfg_coordinator = cfg.get("coordinator_url", "")
    cfg_ollama = cfg.get("ollama_url", "")
    cfg_license_key = cfg.get("license_key", "")

    def _frame(payload: dict) -> str:
        # Serialize one SSE data frame.
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        global _BENCH_RUNNING, _bench_proc

        if not bench_script or not Path(bench_script).exists():
            yield _frame({'type': 'error', 'message': 'bench_script not configured or not found'})
            return

        cmd = [
            python_bin,
            bench_script,
            "--tasks", bench_tasks,
            "--models", bench_models,
            "--output", results_dir,
        ]
        if task_ids:
            cmd.extend(["--filter-tasks"] + task_ids.split(","))
        if model_tags:
            cmd.extend(["--filter-tags"] + model_tags.split(","))

        # Query param overrides config; config already folds in env vars
        # (resolved by _load_cforch_config).
        chosen_coordinator = coordinator_url or cfg_coordinator
        chosen_ollama = ollama_url or cfg_ollama
        if chosen_coordinator:
            cmd.extend(["--coordinator", chosen_coordinator])
        if chosen_ollama:
            cmd.extend(["--ollama-url", chosen_ollama])

        # The license key travels via the environment so the subprocess can
        # authenticate with cf-orch.
        child_env = {**os.environ}
        if cfg_license_key:
            child_env["CF_LICENSE_KEY"] = cfg_license_key

        _BENCH_RUNNING = True
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,
                text=True,
                bufsize=1,
                env=child_env,
            )
            _bench_proc = proc
            try:
                for raw_line in proc.stdout:
                    cleaned = _strip_ansi(raw_line.rstrip())
                    if cleaned:
                        yield _frame({'type': 'progress', 'message': cleaned})
                proc.wait()
                if proc.returncode == 0:
                    summary_path = _find_latest_summary(results_dir)
                    if summary_path is not None:
                        try:
                            summary = json.loads(summary_path.read_text(encoding="utf-8"))
                            yield _frame({'type': 'result', 'summary': summary})
                        except Exception as exc:
                            logger.warning("Failed to read summary.json: %s", exc)
                    yield _frame({'type': 'complete'})
                else:
                    yield _frame({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})
            finally:
                _bench_proc = None
        except Exception as exc:
            yield _frame({'type': 'error', 'message': str(exc)})
        finally:
            _BENCH_RUNNING = False

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /config ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/config")
def get_cforch_config() -> dict:
    """Return resolved cf-orch connection config (env vars merged with yaml).

    Redacts license_key — only returns whether it is set, not the value.
    Used by the Settings UI to show current connection state.
    """
    cfg = _load_cforch_config()
    has_yaml = _config_file().exists()
    return {
        "coordinator_url": cfg.get("coordinator_url", ""),
        "ollama_url": cfg.get("ollama_url", ""),
        "ollama_model": cfg.get("ollama_model", ""),
        "license_key_set": bool(cfg.get("license_key", "")),
        "source": "yaml+env" if has_yaml else "env",
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/results")
def get_results() -> dict:
    """Return the latest benchmark summary.json from results_dir."""
    cfg = _load_cforch_config()
    summary_path = _find_latest_summary(cfg.get("results_dir", ""))
    if summary_path is None:
        raise HTTPException(404, "No benchmark results found")
    try:
        return json.loads(summary_path.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read summary.json: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/cancel")
def cancel_benchmark() -> dict:
    """Kill the running benchmark subprocess."""
    global _BENCH_RUNNING, _bench_proc

    if not _BENCH_RUNNING:
        raise HTTPException(404, "No benchmark is currently running")

    proc = _bench_proc
    if proc is not None:
        try:
            proc.terminate()
        except Exception as exc:
            # Best-effort kill: the flag is cleared below regardless.
            logger.warning("Failed to terminate benchmark process: %s", exc)

    _BENCH_RUNNING = False
    _bench_proc = None
    return {"status": "cancelled"}
|
||||||
352
app/imitate.py
Normal file
352
app/imitate.py
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
"""Avocet — Imitate tab API.
|
||||||
|
|
||||||
|
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||||
|
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||||
|
pushed into the SFT corrections queue for human review.
|
||||||
|
|
||||||
|
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||||
|
|
||||||
|
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||||
|
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from urllib.error import URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.utils import append_jsonl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).parent.parent
|
||||||
|
_CONFIG_DIR: Path | None = None
|
||||||
|
_DATA_DIR: Path = _ROOT / "data"
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def set_config_dir(path: Path | None) -> None:
    """Test seam: redirect config lookups to *path* (None restores the default)."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
def set_data_dir(path: Path) -> None:
    """Test seam: redirect data-file writes (sft_candidates.jsonl) to *path*."""
    global _DATA_DIR
    _DATA_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _config_file() -> Path:
    """Return the path to label_tool.yaml, honoring the set_config_dir() override."""
    base = _CONFIG_DIR if _CONFIG_DIR is not None else _ROOT / "config"
    return base / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_imitate_config() -> dict:
    """Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
    cfg_path = _config_file()
    if not cfg_path.exists():
        return {}
    try:
        parsed = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse imitate config %s: %s", cfg_path, exc)
        return {}
    return parsed.get("imitate", {}) or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cforch_config() -> dict:
    """Read the cforch section of label_tool.yaml (for ollama_url fallback).

    Returns {} when the file is missing or unparsable. Parse failures are
    logged — previously the exception was caught and silently discarded,
    inconsistent with the sibling _load_imitate_config().
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse cforch config %s: %s", f, exc)
        return {}
    return raw.get("cforch", {}) or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _ollama_url(cfg: dict) -> str:
    """Resolve the ollama base URL: imitate config → cforch config → default."""
    fallback = _load_cforch_config().get("ollama_url")
    return cfg.get("ollama_url") or fallback or "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def _http_get_json(url: str, timeout: int = 5) -> Any:
    """Fetch JSON from url; raise URLError on failure."""
    request = Request(url, headers={"Accept": "application/json"})
    with urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
    """Return True if the product's health endpoint responds OK."""
    health_url = f"{base_url.rstrip('/')}{health_path}"
    try:
        # Short timeout: this is polled per-product on every /products call.
        return bool(_http_get_json(health_url, timeout=2))
    except Exception:
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_sample(
|
||||||
|
raw: Any, text_fields: list[str], sample_index: int = 0
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Pull one item from a list or dict response and extract text_fields."""
|
||||||
|
item: dict[str, Any]
|
||||||
|
if isinstance(raw, list):
|
||||||
|
if not raw:
|
||||||
|
return {}
|
||||||
|
item = raw[min(sample_index, len(raw) - 1)]
|
||||||
|
elif isinstance(raw, dict):
|
||||||
|
# may be {items: [...]} or the item itself
|
||||||
|
for key in ("items", "results", "data", "jobs", "listings", "pantry",
|
||||||
|
"saved_searches", "entries", "calls", "records"):
|
||||||
|
if key in raw and isinstance(raw[key], list):
|
||||||
|
lst = raw[key]
|
||||||
|
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
item = raw
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for field in text_fields:
|
||||||
|
val = item.get(field)
|
||||||
|
if val and str(val).strip():
|
||||||
|
parts.append(f"**{field}**: {val}")
|
||||||
|
return {"item": item, "text": "\n\n".join(parts)}
|
||||||
|
|
||||||
|
|
||||||
|
def _candidates_file() -> Path:
    """JSONL file where imitate results queue up for SFT human review."""
    return _DATA_DIR / "sft_candidates.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def _sse(data: dict) -> str:
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ollama_streaming(
    ollama_base: str,
    model_id: str,
    prompt: str,
    temperature: float,
) -> tuple[str, int]:
    """Call ollama /api/generate and return (full_response, elapsed_ms).

    Despite the name, this issues a single *non-streaming* request
    (``"stream": False``) and blocks until the model finishes; per-model SSE
    streaming is handled by the generator in run_imitate(). (The previous
    docstring incorrectly claimed stream=True.)

    Raises:
        RuntimeError: chained from any underlying HTTP/JSON failure.
    """
    url = f"{ollama_base.rstrip('/')}/api/generate"
    payload = json.dumps({
        "model": model_id,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature},
    }).encode("utf-8")
    req = Request(url, data=payload, method="POST",
                  headers={"Content-Type": "application/json"})
    t0 = time.time()
    try:
        with urlopen(req, timeout=120) as resp:
            body = json.loads(resp.read().decode("utf-8"))
        elapsed = int((time.time() - t0) * 1000)
        return body.get("response", ""), elapsed
    except Exception as exc:
        # No elapsed computation here: the original assigned it and then
        # discarded it, since the raise below never returns it.
        raise RuntimeError(str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/products")
def get_products() -> dict:
    """List configured CF products with live online status."""
    cfg = _load_imitate_config()
    listing = []
    for entry in cfg.get("products", []) or []:
        if not isinstance(entry, dict):
            continue
        base_url = entry.get("base_url", "")
        # Only probe health when a base_url is configured.
        online = (
            _is_online(base_url, entry.get("health_path", "/api/health"))
            if base_url else False
        )
        listing.append({
            "id": entry.get("id", ""),
            "name": entry.get("name", ""),
            "icon": entry.get("icon", "📦"),
            "description": entry.get("description", ""),
            "base_url": base_url,
            "online": online,
        })
    return {"products": listing}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
    """Fetch a real sample from the given product's API."""
    cfg = _load_imitate_config()
    entries = cfg.get("products", []) or []
    product = next(
        (p for p in entries if isinstance(p, dict) and p.get("id") == product_id),
        None,
    )
    if product is None:
        raise HTTPException(404, f"Product '{product_id}' not in config")

    base_url = product.get("base_url", "").rstrip("/")
    endpoint = product.get("sample_endpoint", "")
    if not base_url or not endpoint:
        raise HTTPException(422, "Product missing base_url or sample_endpoint")

    try:
        raw = _http_get_json(f"{base_url}{endpoint}", timeout=5)
    except URLError as exc:
        raise HTTPException(503, f"Product API unreachable: {exc}") from exc
    except Exception as exc:
        raise HTTPException(502, f"Bad response from product API: {exc}") from exc

    extracted = _extract_sample(raw, product.get("text_fields", []) or [], index)
    if not extracted:
        raise HTTPException(404, "No sample items returned by product API")

    # "{text}" in the template is replaced with the extracted sample text.
    template = product.get("prompt_template", "{text}")
    return {
        "product_id": product_id,
        "sample_index": index,
        "text": extracted["text"],
        "prompt": template.replace("{text}", extracted["text"]),
        "raw_item": extracted.get("item", {}),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/run")
def run_imitate(
    prompt: str = "",
    model_ids: str = "",  # comma-separated ollama model IDs
    temperature: float = 0.7,
    product_id: str = "",
) -> StreamingResponse:
    """Run a prompt through selected ollama models and stream results as SSE."""
    if not prompt.strip():
        raise HTTPException(422, "prompt is required")

    selected = [m.strip() for m in model_ids.split(",") if m.strip()]
    if not selected:
        raise HTTPException(422, "model_ids is required")

    ollama_base = _ollama_url(_load_imitate_config())

    def generate():
        completed: list[dict] = []
        yield _sse({"type": "start", "total_models": len(selected)})

        # Models run sequentially; each emits a model_start/model_done pair.
        for model_id in selected:
            yield _sse({"type": "model_start", "model": model_id})
            try:
                text, elapsed_ms = _run_ollama_streaming(
                    ollama_base, model_id, prompt, temperature
                )
                outcome = {
                    "model": model_id,
                    "response": text,
                    "elapsed_ms": elapsed_ms,
                    "error": None,
                }
            except Exception as exc:
                outcome = {
                    "model": model_id,
                    "response": "",
                    "elapsed_ms": 0,
                    "error": str(exc),
                }
            completed.append(outcome)
            yield _sse({"type": "model_done", **outcome})

        yield _sse({"type": "complete", "results": completed})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ImitateResult(BaseModel):
    """One model's output from an Imitate run, echoed back by the UI."""

    # ollama model ID that produced the response
    model: str
    # full generated text ("" when the call errored)
    response: str
    # wall-clock generation time in milliseconds (0 on error)
    elapsed_ms: int
    # error message when the model call failed; None on success
    error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PushCorrectionsRequest(BaseModel):
    """Body for POST /push-corrections: an imitate run to persist for review."""
    product_id: str  # which CF product the prompt was sampled from
    prompt: str  # the prompt that was sent to every model
    results: list[ImitateResult]  # per-model outputs; error entries are skipped
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/push-corrections")
|
||||||
|
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
||||||
|
"""Append imitate results to sft_candidates.jsonl for human review."""
|
||||||
|
if not req.prompt.strip():
|
||||||
|
raise HTTPException(422, "prompt is required")
|
||||||
|
if not req.results:
|
||||||
|
raise HTTPException(422, "results list is empty")
|
||||||
|
|
||||||
|
ts = datetime.now(timezone.utc).isoformat()
|
||||||
|
records = []
|
||||||
|
for r in req.results:
|
||||||
|
if r.error or not r.response.strip():
|
||||||
|
continue
|
||||||
|
records.append({
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"source": "imitate",
|
||||||
|
"product_id": req.product_id,
|
||||||
|
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||||
|
"model_response": r.response,
|
||||||
|
"model_id": r.model,
|
||||||
|
"elapsed_ms": r.elapsed_ms,
|
||||||
|
"status": "pending",
|
||||||
|
"created_at": ts,
|
||||||
|
})
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
raise HTTPException(422, "No non-error results to push")
|
||||||
|
|
||||||
|
dest = _candidates_file()
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for record in records:
|
||||||
|
append_jsonl(dest, record)
|
||||||
|
|
||||||
|
return {"pushed": len(records)}
|
||||||
448
app/models.py
Normal file
448
app/models.py
Normal file
|
|
@ -0,0 +1,448 @@
|
||||||
|
"""Avocet — HF model lifecycle API.
|
||||||
|
|
||||||
|
Handles model metadata lookup, approval queue, download with progress,
|
||||||
|
and installed model management.
|
||||||
|
|
||||||
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||||
|
api.py includes this router with prefix="/api/models".
|
||||||
|
|
||||||
|
Module-level globals (_MODELS_DIR, _QUEUE_DIR) follow the same
|
||||||
|
testability pattern as sft.py — override them via set_models_dir() and
|
||||||
|
set_queue_dir() in test fixtures.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.utils import read_jsonl, write_jsonl
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
snapshot_download = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).parent.parent
|
||||||
|
_MODELS_DIR: Path = _ROOT / "models"
|
||||||
|
_QUEUE_DIR: Path = _ROOT / "data"
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# ── Download progress shared state ────────────────────────────────────────────
|
||||||
|
# Updated by the background download thread; read by GET /download/stream.
|
||||||
|
_download_progress: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# ── HF pipeline_tag → adapter recommendation ──────────────────────────────────
|
||||||
|
_TAG_TO_ADAPTER: dict[str, str] = {
|
||||||
|
"zero-shot-classification": "ZeroShotAdapter",
|
||||||
|
"text-classification": "ZeroShotAdapter",
|
||||||
|
"natural-language-inference": "ZeroShotAdapter",
|
||||||
|
"sentence-similarity": "RerankerAdapter",
|
||||||
|
"text-ranking": "RerankerAdapter",
|
||||||
|
"text-generation": "GenerationAdapter",
|
||||||
|
"text2text-generation": "GenerationAdapter",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def set_models_dir(path: Path) -> None:
    """Test seam: redirect the installed-models directory to *path*."""
    global _MODELS_DIR
    _MODELS_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
def set_queue_dir(path: Path) -> None:
    """Test seam: redirect the approval-queue data directory to *path*."""
    global _QUEUE_DIR
    _QUEUE_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _queue_file() -> Path:
    """Path of the JSONL file backing the model approval queue."""
    return _QUEUE_DIR.joinpath("model_queue.jsonl")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_queue() -> list[dict]:
    """Load all queue entries from disk (empty list when the file is absent)."""
    return read_jsonl(_queue_file())
|
||||||
|
|
||||||
|
|
||||||
|
def _write_queue(records: list[dict]) -> None:
    """Atomically replace the on-disk queue with *records*."""
    write_jsonl(_queue_file(), records)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_model_name(repo_id: str) -> str:
|
||||||
|
"""Convert repo_id to a filesystem-safe directory name (HF convention)."""
|
||||||
|
return repo_id.replace("/", "--")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_installed(repo_id: str) -> bool:
    """True when *repo_id* has already been downloaded into _MODELS_DIR.

    A directory counts as a model when it carries any of the marker files
    written by the download or fine-tuning pipelines.
    """
    target = _MODELS_DIR / _safe_model_name(repo_id)
    if not target.exists():
        return False
    markers = ("config.json", "training_info.json", "model_info.json")
    return any((target / marker).exists() for marker in markers)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_queued(repo_id: str) -> bool:
    """True when *repo_id* already has a non-dismissed entry in the queue."""
    return any(
        e.get("repo_id") == repo_id and e.get("status") != "dismissed"
        for e in _read_queue()
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _update_queue_entry(entry_id: str, updates: dict) -> dict | None:
    """Merge *updates* into the queue entry with id *entry_id*.

    Persists the whole queue back to disk on a match and returns the
    merged entry; returns None when no entry has that id.
    """
    records = _read_queue()
    for idx, rec in enumerate(records):
        if rec.get("id") != entry_id:
            continue
        merged = {**rec, **updates}
        records[idx] = merged
        _write_queue(records)
        return merged
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_queue_entry(entry_id: str) -> dict | None:
    """Return the queue entry with id *entry_id*, or None if absent."""
    return next((r for r in _read_queue() if r.get("id") == entry_id), None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Background download ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _run_download(entry_id: str, repo_id: str, pipeline_tag: str | None, adapter_recommendation: str | None) -> None:
    """Background thread: download model via huggingface_hub.snapshot_download.

    Rebinds the module-level `_download_progress` dict at the start so a
    previous run's state cannot leak in, then mutates it in place as the
    download proceeds. GET /download/stream reads this shared state.
    On success writes model_info.json into the model directory and marks
    the queue entry "ready"; on failure marks it "failed".

    NOTE(review): total_bytes/downloaded_bytes are initialized to 0 and
    never updated during the download, so pct stays 0.0 until completion
    — snapshot_download offers no per-byte callback here. Confirm whether
    the UI expects a live percentage.
    """
    global _download_progress
    safe_name = _safe_model_name(repo_id)
    local_dir = _MODELS_DIR / safe_name

    # Fresh state dict: replaces any leftover progress from a prior run.
    _download_progress = {
        "active": True,
        "repo_id": repo_id,
        "downloaded_bytes": 0,
        "total_bytes": 0,
        "pct": 0.0,
        "done": False,
        "error": None,
    }

    try:
        if snapshot_download is None:
            # huggingface_hub import failed at module load; surface it here
            # so the queue entry is marked "failed" like any other error.
            raise RuntimeError("huggingface_hub is not installed")

        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
        )

        # Write model_info.json alongside downloaded files
        model_info = {
            "repo_id": repo_id,
            "pipeline_tag": pipeline_tag,
            "adapter_recommendation": adapter_recommendation,
            "downloaded_at": datetime.now(timezone.utc).isoformat(),
        }
        local_dir.mkdir(parents=True, exist_ok=True)
        (local_dir / "model_info.json").write_text(
            json.dumps(model_info, indent=2), encoding="utf-8"
        )

        _download_progress["done"] = True
        _download_progress["pct"] = 100.0
        _update_queue_entry(entry_id, {"status": "ready"})

    except Exception as exc:
        logger.exception("Download failed for %s: %s", repo_id, exc)
        _download_progress["error"] = str(exc)
        _download_progress["done"] = True
        _update_queue_entry(entry_id, {"status": "failed", "error": str(exc)})
    finally:
        # Always clear the active flag so the SSE stream can terminate.
        _download_progress["active"] = False
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /lookup ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/lookup")
|
||||||
|
def lookup_model(repo_id: str) -> dict:
|
||||||
|
"""Validate repo_id and fetch metadata from the HF API."""
|
||||||
|
# Validate: must contain exactly one '/', no whitespace
|
||||||
|
if "/" not in repo_id or any(c.isspace() for c in repo_id):
|
||||||
|
raise HTTPException(422, f"Invalid repo_id {repo_id!r}: must be 'owner/model-name' with no whitespace")
|
||||||
|
|
||||||
|
hf_url = f"https://huggingface.co/api/models/{repo_id}"
|
||||||
|
try:
|
||||||
|
resp = httpx.get(hf_url, timeout=10.0)
|
||||||
|
except httpx.RequestError as exc:
|
||||||
|
raise HTTPException(502, f"Network error reaching HuggingFace API: {exc}") from exc
|
||||||
|
|
||||||
|
if resp.status_code == 404:
|
||||||
|
raise HTTPException(404, f"Model {repo_id!r} not found on HuggingFace")
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise HTTPException(502, f"HuggingFace API returned status {resp.status_code}")
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
pipeline_tag = data.get("pipeline_tag")
|
||||||
|
adapter_recommendation = _TAG_TO_ADAPTER.get(pipeline_tag) if pipeline_tag else None
|
||||||
|
|
||||||
|
# Determine compatibility and surface a human-readable warning
|
||||||
|
_supported = ", ".join(sorted(_TAG_TO_ADAPTER.keys()))
|
||||||
|
if adapter_recommendation is not None:
|
||||||
|
compatible = True
|
||||||
|
warning: str | None = None
|
||||||
|
elif pipeline_tag is None:
|
||||||
|
compatible = False
|
||||||
|
warning = (
|
||||||
|
"This model has no task tag on HuggingFace — adapter type is unknown. "
|
||||||
|
"It may not work with Avocet's email classification pipeline."
|
||||||
|
)
|
||||||
|
logger.warning("No pipeline_tag for %s — no adapter recommendation", repo_id)
|
||||||
|
else:
|
||||||
|
compatible = False
|
||||||
|
warning = (
|
||||||
|
f"\"{pipeline_tag}\" models are not supported by Avocet's email classification adapters. "
|
||||||
|
f"Supported task types: {_supported}."
|
||||||
|
)
|
||||||
|
logger.warning("Unsupported pipeline_tag %r for %s", pipeline_tag, repo_id)
|
||||||
|
|
||||||
|
# Estimate model size from siblings list
|
||||||
|
siblings = data.get("siblings") or []
|
||||||
|
model_size_bytes: int = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
|
||||||
|
|
||||||
|
# Description: first 300 chars of card data (modelId field used as fallback)
|
||||||
|
card_data = data.get("cardData") or {}
|
||||||
|
description_raw = card_data.get("description") or data.get("modelId") or ""
|
||||||
|
description = description_raw[:300] if description_raw else ""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"repo_id": repo_id,
|
||||||
|
"pipeline_tag": pipeline_tag,
|
||||||
|
"adapter_recommendation": adapter_recommendation,
|
||||||
|
"compatible": compatible,
|
||||||
|
"warning": warning,
|
||||||
|
"model_size_bytes": model_size_bytes,
|
||||||
|
"description": description,
|
||||||
|
"tags": data.get("tags") or [],
|
||||||
|
"downloads": data.get("downloads") or 0,
|
||||||
|
"already_installed": _is_installed(repo_id),
|
||||||
|
"already_queued": _is_queued(repo_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /queue ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/queue")
|
||||||
|
def get_queue() -> list[dict]:
|
||||||
|
"""Return all non-dismissed queue entries sorted newest-first."""
|
||||||
|
records = _read_queue()
|
||||||
|
active = [r for r in records if r.get("status") != "dismissed"]
|
||||||
|
return sorted(active, key=lambda r: r.get("queued_at", ""), reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /queue ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class QueueAddRequest(BaseModel):
    """Body for POST /queue: a model to add to the approval queue."""
    repo_id: str  # HF repo in 'owner/model-name' form
    pipeline_tag: str | None = None  # HF task tag, as returned by /lookup
    adapter_recommendation: str | None = None  # adapter class name from /lookup
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/queue", status_code=201)
|
||||||
|
def add_to_queue(req: QueueAddRequest) -> dict:
|
||||||
|
"""Add a model to the approval queue with status 'pending'."""
|
||||||
|
if _is_installed(req.repo_id):
|
||||||
|
raise HTTPException(409, f"{req.repo_id!r} is already installed")
|
||||||
|
if _is_queued(req.repo_id):
|
||||||
|
raise HTTPException(409, f"{req.repo_id!r} is already in the queue")
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"id": str(uuid4()),
|
||||||
|
"repo_id": req.repo_id,
|
||||||
|
"pipeline_tag": req.pipeline_tag,
|
||||||
|
"adapter_recommendation": req.adapter_recommendation,
|
||||||
|
"status": "pending",
|
||||||
|
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
records = _read_queue()
|
||||||
|
records.append(entry)
|
||||||
|
_write_queue(records)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/queue/{entry_id}/approve")
|
||||||
|
def approve_queue_entry(entry_id: str) -> dict:
|
||||||
|
"""Approve a pending queue entry and start background download."""
|
||||||
|
entry = _get_queue_entry(entry_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
|
||||||
|
if entry.get("status") != "pending":
|
||||||
|
raise HTTPException(409, f"Entry is not in pending state (current: {entry.get('status')!r})")
|
||||||
|
|
||||||
|
updated = _update_queue_entry(entry_id, {"status": "downloading"})
|
||||||
|
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=_run_download,
|
||||||
|
args=(entry_id, entry["repo_id"], entry.get("pipeline_tag"), entry.get("adapter_recommendation")),
|
||||||
|
daemon=True,
|
||||||
|
name=f"model-download-{entry_id}",
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.delete("/queue/{entry_id}")
|
||||||
|
def dismiss_queue_entry(entry_id: str) -> dict:
|
||||||
|
"""Dismiss (soft-delete) a queue entry."""
|
||||||
|
entry = _get_queue_entry(entry_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
|
||||||
|
|
||||||
|
_update_queue_entry(entry_id, {"status": "dismissed"})
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /download/stream ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/download/stream")
|
||||||
|
def download_stream() -> StreamingResponse:
|
||||||
|
"""SSE stream of download progress. Yields one idle event if no download active."""
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
prog = _download_progress
|
||||||
|
if not prog.get("active") and not (prog.get("done") and not prog.get("error")):
|
||||||
|
yield f"data: {json.dumps({'type': 'idle'})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
if prog.get("done"):
|
||||||
|
if prog.get("error"):
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'error': prog['error']})}\n\n"
|
||||||
|
else:
|
||||||
|
yield f"data: {json.dumps({'type': 'done', 'repo_id': prog.get('repo_id')})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stream live progress
|
||||||
|
import time
|
||||||
|
while True:
|
||||||
|
p = dict(_download_progress)
|
||||||
|
if p.get("done"):
|
||||||
|
if p.get("error"):
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'error': p['error']})}\n\n"
|
||||||
|
else:
|
||||||
|
yield f"data: {json.dumps({'type': 'done', 'repo_id': p.get('repo_id')})}\n\n"
|
||||||
|
break
|
||||||
|
event = json.dumps({
|
||||||
|
"type": "progress",
|
||||||
|
"repo_id": p.get("repo_id"),
|
||||||
|
"downloaded_bytes": p.get("downloaded_bytes", 0),
|
||||||
|
"total_bytes": p.get("total_bytes", 0),
|
||||||
|
"pct": p.get("pct", 0.0),
|
||||||
|
})
|
||||||
|
yield f"data: {event}\n\n"
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /installed ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/installed")
|
||||||
|
def list_installed() -> list[dict]:
|
||||||
|
"""Scan _MODELS_DIR and return info on each installed model."""
|
||||||
|
if not _MODELS_DIR.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
results: list[dict] = []
|
||||||
|
for sub in _MODELS_DIR.iterdir():
|
||||||
|
if not sub.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
has_training_info = (sub / "training_info.json").exists()
|
||||||
|
has_config = (sub / "config.json").exists()
|
||||||
|
has_model_info = (sub / "model_info.json").exists()
|
||||||
|
|
||||||
|
if not (has_training_info or has_config or has_model_info):
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_type = "finetuned" if has_training_info else "downloaded"
|
||||||
|
|
||||||
|
# Compute directory size
|
||||||
|
size_bytes = sum(f.stat().st_size for f in sub.rglob("*") if f.is_file())
|
||||||
|
|
||||||
|
# Load adapter/model_id from model_info.json or training_info.json
|
||||||
|
adapter: str | None = None
|
||||||
|
model_id: str | None = None
|
||||||
|
|
||||||
|
if has_model_info:
|
||||||
|
try:
|
||||||
|
info = json.loads((sub / "model_info.json").read_text(encoding="utf-8"))
|
||||||
|
adapter = info.get("adapter_recommendation")
|
||||||
|
model_id = info.get("repo_id")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif has_training_info:
|
||||||
|
try:
|
||||||
|
info = json.loads((sub / "training_info.json").read_text(encoding="utf-8"))
|
||||||
|
adapter = info.get("adapter")
|
||||||
|
model_id = info.get("base_model") or info.get("model_id")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"name": sub.name,
|
||||||
|
"path": str(sub),
|
||||||
|
"type": model_type,
|
||||||
|
"adapter": adapter,
|
||||||
|
"size_bytes": size_bytes,
|
||||||
|
"model_id": model_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.delete("/installed/{name}")
|
||||||
|
def delete_installed(name: str) -> dict:
|
||||||
|
"""Remove an installed model directory by name. Blocks path traversal."""
|
||||||
|
# Validate: single path component, no slashes or '..'
|
||||||
|
if "/" in name or "\\" in name or ".." in name or not name or name.startswith("."):
|
||||||
|
raise HTTPException(400, f"Invalid model name {name!r}: must be a single directory name with no path separators or '..'")
|
||||||
|
|
||||||
|
model_path = _MODELS_DIR / name
|
||||||
|
|
||||||
|
# Extra safety: confirm resolved path is inside _MODELS_DIR
|
||||||
|
try:
|
||||||
|
model_path.resolve().relative_to(_MODELS_DIR.resolve())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(400, f"Path traversal detected for name {name!r}")
|
||||||
|
|
||||||
|
if not model_path.exists():
|
||||||
|
raise HTTPException(404, f"Installed model {name!r} not found")
|
||||||
|
|
||||||
|
shutil.rmtree(model_path)
|
||||||
|
return {"ok": True}
|
||||||
37
app/sft.py
37
app/sft.py
|
|
@ -51,17 +51,26 @@ def _config_file() -> Path:
|
||||||
return _ROOT / "config" / "label_tool.yaml"
|
return _ROOT / "config" / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_bench_results_dir(path: str) -> None:
|
||||||
|
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
|
||||||
|
global _DEFAULT_BENCH_RESULTS_DIR
|
||||||
|
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||||
|
|
||||||
|
|
||||||
def _get_bench_results_dir() -> Path:
|
def _get_bench_results_dir() -> Path:
|
||||||
f = _config_file()
|
f = _config_file()
|
||||||
if not f.exists():
|
if f.exists():
|
||||||
return Path("/nonexistent-bench-results")
|
|
||||||
try:
|
try:
|
||||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||||
|
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||||
|
if d:
|
||||||
|
return Path(d)
|
||||||
except yaml.YAMLError as exc:
|
except yaml.YAMLError as exc:
|
||||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||||
return Path("/nonexistent-bench-results")
|
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
|
||||||
return Path(d) if d else Path("/nonexistent-bench-results")
|
|
||||||
|
|
||||||
|
|
||||||
def _candidates_file() -> Path:
|
def _candidates_file() -> Path:
|
||||||
|
|
@ -151,10 +160,21 @@ def get_queue(page: int = 1, per_page: int = 20):
|
||||||
|
|
||||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
# ── POST /submit ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
FailureCategory = Literal[
|
||||||
|
"scoring_artifact",
|
||||||
|
"style_violation",
|
||||||
|
"partial_answer",
|
||||||
|
"wrong_answer",
|
||||||
|
"format_error",
|
||||||
|
"hallucination",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SubmitRequest(BaseModel):
|
class SubmitRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
action: Literal["correct", "discard", "flag"]
|
action: Literal["correct", "discard", "flag"]
|
||||||
corrected_response: str | None = None
|
corrected_response: str | None = None
|
||||||
|
failure_category: FailureCategory | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/submit")
|
@router.post("/submit")
|
||||||
|
|
@ -174,7 +194,12 @@ def post_submit(req: SubmitRequest):
|
||||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||||
|
|
||||||
if req.action == "correct":
|
if req.action == "correct":
|
||||||
records[idx] = {**record, "status": "approved", "corrected_response": req.corrected_response}
|
records[idx] = {
|
||||||
|
**record,
|
||||||
|
"status": "approved",
|
||||||
|
"corrected_response": req.corrected_response,
|
||||||
|
"failure_category": req.failure_category,
|
||||||
|
}
|
||||||
_write_candidates(records)
|
_write_candidates(records)
|
||||||
append_jsonl(_approved_file(), records[idx])
|
append_jsonl(_approved_file(), records[idx])
|
||||||
elif req.action == "discard":
|
elif req.action == "discard":
|
||||||
|
|
|
||||||
|
|
@ -26,3 +26,66 @@ max_per_account: 500
|
||||||
# produced by circuitforge-orch's benchmark harness.
|
# produced by circuitforge-orch's benchmark harness.
|
||||||
sft:
|
sft:
|
||||||
bench_results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
bench_results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
||||||
|
|
||||||
|
# cf-orch integration — LLM benchmark harness via cf-orch coordinator.
|
||||||
|
# All keys here override the corresponding environment variables.
|
||||||
|
# Omit any key to fall back to the env var (see .env.example).
|
||||||
|
cforch:
|
||||||
|
# Path to cf-orch's benchmark.py script
|
||||||
|
bench_script: /path/to/circuitforge-orch/scripts/benchmark.py
|
||||||
|
# Task and model definition files (yaml)
|
||||||
|
bench_tasks: /path/to/circuitforge-orch/scripts/bench_tasks.yaml
|
||||||
|
bench_models: /path/to/circuitforge-orch/scripts/bench_models.yaml
|
||||||
|
# Where benchmark results are written (also used for SFT candidate discovery)
|
||||||
|
results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
||||||
|
# Python interpreter with cf-orch installed
|
||||||
|
python_bin: /devl/miniconda3/envs/cf/bin/python
|
||||||
|
|
||||||
|
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST
|
||||||
|
# coordinator_url: http://localhost:7700
|
||||||
|
# license_key: CFG-AVCT-xxxx-xxxx-xxxx
|
||||||
|
# ollama_url: http://localhost:11434
|
||||||
|
# ollama_model: llama3.2:3b
|
||||||
|
|
||||||
|
# Imitate tab — pull real samples from sibling CF product APIs and run them
|
||||||
|
# through local LLMs to build a corrections dataset.
|
||||||
|
# ollama_url defaults to cforch.ollama_url if omitted here.
|
||||||
|
imitate:
|
||||||
|
ollama_url: http://localhost:11434 # optional — falls back to cforch.ollama_url
|
||||||
|
|
||||||
|
products:
|
||||||
|
- id: peregrine
|
||||||
|
name: Peregrine
|
||||||
|
icon: "🦅"
|
||||||
|
description: Job search assistant
|
||||||
|
base_url: http://localhost:8502
|
||||||
|
sample_endpoint: /api/jobs
|
||||||
|
text_fields: [title, description]
|
||||||
|
prompt_template: "Analyze this job listing and identify key requirements:\n\n{text}"
|
||||||
|
|
||||||
|
- id: kiwi
|
||||||
|
name: Kiwi
|
||||||
|
icon: "🥝"
|
||||||
|
description: Pantry tracker
|
||||||
|
base_url: http://localhost:8511
|
||||||
|
sample_endpoint: /api/inventory
|
||||||
|
text_fields: [name, category, notes]
|
||||||
|
prompt_template: "Describe this pantry item and estimate how best to use it:\n\n{text}"
|
||||||
|
|
||||||
|
- id: snipe
|
||||||
|
name: Snipe
|
||||||
|
icon: "🎯"
|
||||||
|
description: eBay trust scoring
|
||||||
|
base_url: http://localhost:8509
|
||||||
|
sample_endpoint: /api/listings
|
||||||
|
text_fields: [title, description, seller_info]
|
||||||
|
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
|
||||||
|
|
||||||
|
- id: osprey
|
||||||
|
name: Osprey
|
||||||
|
icon: "📞"
|
||||||
|
description: Gov't hold-line automation
|
||||||
|
base_url: http://localhost:8520
|
||||||
|
sample_endpoint: /api/calls/recent
|
||||||
|
text_fields: [agency, issue, notes]
|
||||||
|
prompt_template: "Draft a concise summary of this government call record:\n\n{text}"
|
||||||
|
|
|
||||||
|
|
@ -22,5 +22,8 @@ dependencies:
|
||||||
# Optional: BGE reranker adapter
|
# Optional: BGE reranker adapter
|
||||||
# - FlagEmbedding
|
# - FlagEmbedding
|
||||||
|
|
||||||
|
# CircuitForge shared core (LLM router, tier system, config)
|
||||||
|
- circuitforge-core>=0.9.0
|
||||||
|
|
||||||
# Dev
|
# Dev
|
||||||
- pytest>=8.0
|
- pytest>=8.0
|
||||||
|
|
|
||||||
369
tests/test_cforch.py
Normal file
369
tests/test_cforch.py
Normal file
|
|
@ -0,0 +1,369 @@
|
||||||
|
"""Tests for app/cforch.py — /api/cforch/* endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def reset_cforch_globals(tmp_path):
    """Redirect _CONFIG_DIR to tmp_path and reset running-state globals.

    Autouse: every test in this module runs against an isolated config
    dir and a clean "no benchmark running" state. Previous values are
    saved and restored afterwards so tests cannot leak state into each
    other or into the real module.
    """
    from app import cforch as cforch_module

    # Save current module state so teardown can restore it exactly.
    prev_config_dir = cforch_module._CONFIG_DIR
    prev_running = cforch_module._BENCH_RUNNING
    prev_proc = cforch_module._bench_proc

    cforch_module.set_config_dir(tmp_path)
    cforch_module._BENCH_RUNNING = False
    cforch_module._bench_proc = None

    yield tmp_path

    # Teardown: put the module back the way we found it.
    cforch_module.set_config_dir(prev_config_dir)
    cforch_module._BENCH_RUNNING = prev_running
    cforch_module._bench_proc = prev_proc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """FastAPI TestClient bound to the full Avocet app (all routers mounted)."""
    from app.api import app
    return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config_dir(reset_cforch_globals):
    """Return the tmp config dir (already set as _CONFIG_DIR by the autouse fixture)."""
    return reset_cforch_globals
|
||||||
|
|
||||||
|
|
||||||
|
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
    """Write a label_tool.yaml with the given cforch block into config_dir."""
    target = config_dir / "label_tool.yaml"
    target.write_text(yaml.dump({"cforch": cforch_cfg}), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _write_tasks_yaml(path: Path, tasks: list[dict]) -> None:
    """Serialize *tasks* under a top-level 'tasks' key into *path*."""
    payload = {"tasks": tasks}
    path.write_text(yaml.dump(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _write_models_yaml(path: Path, models: list[dict]) -> None:
    """Serialize *models* under a top-level 'models' key into *path*."""
    payload = {"models": models}
    path.write_text(yaml.dump(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /tasks ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_tasks_returns_empty_when_not_configured(client):
    """No config file present — endpoint returns empty lists, not an error."""
    r = client.get("/api/cforch/tasks")
    assert r.status_code == 200
    data = r.json()
    assert data == {"tasks": [], "types": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tasks_parses_yaml(client, config_dir, tmp_path):
    """Tasks yaml referenced by the config is parsed into tasks + type list."""
    tasks_file = tmp_path / "bench_tasks.yaml"
    _write_tasks_yaml(tasks_file, [
        {"id": "t1", "name": "Task One", "type": "instruction"},
        {"id": "t2", "name": "Task Two", "type": "reasoning"},
    ])
    _write_config(config_dir, {"bench_tasks": str(tasks_file)})

    r = client.get("/api/cforch/tasks")
    assert r.status_code == 200
    data = r.json()
    assert len(data["tasks"]) == 2
    # TaskEntry now includes optional prompt/system fields (default "")
    t1 = data["tasks"][0]
    assert t1["id"] == "t1" and t1["name"] == "Task One" and t1["type"] == "instruction"
    t2 = data["tasks"][1]
    assert t2["id"] == "t2" and t2["name"] == "Task Two" and t2["type"] == "reasoning"
    assert "instruction" in data["types"]
    assert "reasoning" in data["types"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tasks_returns_types_deduplicated(client, config_dir, tmp_path):
    """Multiple tasks sharing a type — types list must not duplicate."""
    tasks_file = tmp_path / "bench_tasks.yaml"
    _write_tasks_yaml(tasks_file, [
        {"id": "t1", "name": "A", "type": "instruction"},
        {"id": "t2", "name": "B", "type": "instruction"},
        {"id": "t3", "name": "C", "type": "reasoning"},
    ])
    _write_config(config_dir, {"bench_tasks": str(tasks_file)})

    r = client.get("/api/cforch/tasks")
    data = r.json()
    # "instruction" appears in two tasks but must be listed once.
    assert data["types"].count("instruction") == 1
    assert len(data["types"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_models_returns_empty_when_not_configured(client):
    """No config file present — endpoint returns empty model list, not an error."""
    r = client.get("/api/cforch/models")
    assert r.status_code == 200
    assert r.json() == {"models": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
    """A configured bench_models.yaml is parsed and returned field-for-field."""
    models_file = tmp_path / "bench_models.yaml"
    model_row = {
        "name": "llama3",
        "id": "llama3:8b",
        "service": "ollama",
        "tags": ["fast", "small"],
        "vram_estimate_mb": 6000,
    }
    _write_models_yaml(models_file, [model_row])
    _write_config(config_dir, {"bench_models": str(models_file)})

    resp = client.get("/api/cforch/models")
    assert resp.status_code == 200
    payload = resp.json()
    assert len(payload["models"]) == 1

    # Every field from the YAML entry should round-trip unchanged.
    entry = payload["models"][0]
    assert entry["name"] == "llama3"
    assert entry["id"] == "llama3:8b"
    assert entry["service"] == "ollama"
    assert entry["tags"] == ["fast", "small"]
    assert entry["vram_estimate_mb"] == 6000
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_run_returns_409_when_already_running(client):
    """GET /run is rejected with 409 while _BENCH_RUNNING is set."""
    from app import cforch as cforch_module

    cforch_module._BENCH_RUNNING = True
    resp = client.get("/api/cforch/run")
    assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_returns_error_when_bench_script_not_configured(client):
    """With no configuration at all the SSE stream carries an error event."""
    resp = client.get("/api/cforch/run")
    assert resp.status_code == 200
    body = resp.text
    assert '"type": "error"' in body
    assert "bench_script not configured" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_streams_progress_events(client, config_dir, tmp_path):
    """With Popen mocked, stdout lines surface as SSE progress events."""
    script = tmp_path / "fake_benchmark.py"
    script.write_text("# fake", encoding="utf-8")

    tasks_path = tmp_path / "bench_tasks.yaml"
    tasks_path.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
    models_path = tmp_path / "bench_models.yaml"
    models_path.write_text(yaml.dump({"models": []}), encoding="utf-8")
    out_dir = tmp_path / "results"
    out_dir.mkdir()

    _write_config(config_dir, {
        "bench_script": str(script),
        "bench_tasks": str(tasks_path),
        "bench_models": str(models_path),
        "results_dir": str(out_dir),
        "python_bin": "/usr/bin/python3",
    })

    fake_proc = MagicMock()
    fake_proc.stdout = iter(["Running task 1\n", "Running task 2\n"])
    # Non-zero exit code so the handler does not look for summary.json.
    fake_proc.returncode = 1
    fake_proc.wait = lambda: None  # no-op wait

    with patch("app.cforch._subprocess.Popen", return_value=fake_proc):
        resp = client.get("/api/cforch/run")

    assert resp.status_code == 200
    assert '"type": "progress"' in resp.text
    assert "Running task 1" in resp.text
    assert "Running task 2" in resp.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_emits_result_on_success(client, config_dir, tmp_path):
    """On exit code 0 the stream emits the summary.json payload as a result event."""
    script = tmp_path / "fake_benchmark.py"
    script.write_text("# fake", encoding="utf-8")

    tasks_path = tmp_path / "bench_tasks.yaml"
    tasks_path.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
    models_path = tmp_path / "bench_models.yaml"
    models_path.write_text(yaml.dump({"models": []}), encoding="utf-8")

    # Pre-create the run directory with a summary the endpoint should pick up.
    out_dir = tmp_path / "results"
    run_dir = out_dir / "2026-04-08-120000"
    run_dir.mkdir(parents=True)
    (run_dir / "summary.json").write_text(
        json.dumps({"score": 0.92, "models_evaluated": 3}), encoding="utf-8"
    )

    _write_config(config_dir, {
        "bench_script": str(script),
        "bench_tasks": str(tasks_path),
        "bench_models": str(models_path),
        "results_dir": str(out_dir),
        "python_bin": "/usr/bin/python3",
    })

    fake_proc = MagicMock()
    fake_proc.stdout = iter([])
    fake_proc.returncode = 0
    fake_proc.wait = MagicMock()

    with patch("app.cforch._subprocess.Popen", return_value=fake_proc):
        resp = client.get("/api/cforch/run")

    assert resp.status_code == 200
    assert '"type": "result"' in resp.text
    assert '"score": 0.92' in resp.text
    assert '"type": "complete"' in resp.text
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_results_returns_404_when_no_results(client):
    """With no results_dir configured the endpoint answers 404."""
    resp = client.get("/api/cforch/results")
    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_results_returns_latest_summary(client, config_dir, tmp_path):
    """The latest run's summary.json is returned verbatim."""
    out_dir = tmp_path / "results"
    run_dir = out_dir / "2026-04-08-150000"
    run_dir.mkdir(parents=True)
    (run_dir / "summary.json").write_text(
        json.dumps({"score": 0.88, "run": "test"}), encoding="utf-8"
    )

    _write_config(config_dir, {"results_dir": str(out_dir)})

    resp = client.get("/api/cforch/results")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["score"] == 0.88
    assert payload["run"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_cancel_returns_404_when_not_running(client):
    """Cancelling when nothing is running yields 404."""
    resp = client.post("/api/cforch/cancel")
    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_terminates_running_benchmark(client):
    """Cancel terminates the live process and clears the running flags."""
    from app import cforch as cforch_module

    fake_proc = MagicMock()
    cforch_module._BENCH_RUNNING = True
    cforch_module._bench_proc = fake_proc

    resp = client.post("/api/cforch/cancel")
    assert resp.status_code == 200
    assert resp.json() == {"status": "cancelled"}

    # The process was asked to stop and module state was reset.
    fake_proc.terminate.assert_called_once()
    assert cforch_module._BENCH_RUNNING is False
    assert cforch_module._bench_proc is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /config ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_config_returns_empty_when_no_yaml_no_env(client, monkeypatch):
    """With neither yaml nor env vars, all fields are empty and no key is set."""
    # Clear every env var the endpoint might fall back to.
    for var in ("CF_ORCH_URL", "CF_LICENSE_KEY", "OLLAMA_HOST", "OLLAMA_MODEL"):
        monkeypatch.delenv(var, raising=False)

    resp = client.get("/api/cforch/config")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["coordinator_url"] == ""
    assert payload["ollama_url"] == ""
    assert payload["license_key_set"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_reads_env_vars_when_no_yaml(client, monkeypatch):
    """With no cforch section in label_tool.yaml, env vars populate the config."""
    env_values = {
        "CF_ORCH_URL": "http://orch.example.com:7700",
        "CF_LICENSE_KEY": "CFG-AVCT-TEST-TEST-TEST",
        "OLLAMA_HOST": "http://ollama.local:11434",
        "OLLAMA_MODEL": "mistral:7b",
    }
    for var, value in env_values.items():
        monkeypatch.setenv(var, value)

    resp = client.get("/api/cforch/config")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["coordinator_url"] == "http://orch.example.com:7700"
    assert payload["ollama_url"] == "http://ollama.local:11434"
    assert payload["ollama_model"] == "mistral:7b"
    # The key's presence is reported, but never its value.
    assert payload["license_key_set"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_yaml_overrides_env(client, config_dir, monkeypatch):
    """Values from label_tool.yaml win over the corresponding env vars."""
    monkeypatch.setenv("CF_ORCH_URL", "http://env-orch:7700")
    monkeypatch.setenv("OLLAMA_HOST", "http://env-ollama:11434")

    _write_config(config_dir, {
        "coordinator_url": "http://yaml-orch:7700",
        "ollama_url": "http://yaml-ollama:11434",
    })

    resp = client.get("/api/cforch/config")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["coordinator_url"] == "http://yaml-orch:7700"
    assert payload["ollama_url"] == "http://yaml-ollama:11434"
    assert payload["source"] == "yaml+env"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path, monkeypatch):
    """CF_LICENSE_KEY must be forwarded to the benchmark subprocess env.

    Popen is replaced with a capturing fake; whatever `env=` mapping the
    endpoint passes is recorded and inspected after the request.
    """
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-AVCT-ENV-ONLY-KEY")

    bench_script = tmp_path / "benchmark.py"
    bench_script.write_text("# stub", encoding="utf-8")
    tasks_file = tmp_path / "bench_tasks.yaml"
    tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
    models_file = tmp_path / "bench_models.yaml"
    models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")

    _write_config(config_dir, {
        "bench_script": str(bench_script),
        "bench_tasks": str(tasks_file),
        "bench_models": str(models_file),
        "results_dir": str(tmp_path / "results"),
        "python_bin": "/usr/bin/python3",
    })

    captured_env: dict = {}

    def fake_popen(cmd, **kwargs):
        # `kwargs.get("env", {})` would return None if the endpoint passed
        # env=None explicitly (meaning "inherit os.environ"), crashing
        # dict.update with a TypeError. Guard with `or {}` so that case fails
        # on the assertion below instead of erroring inside the fake.
        captured_env.update(kwargs.get("env") or {})
        mock = MagicMock()
        mock.stdout = iter([])
        mock.returncode = 0
        mock.wait = MagicMock()
        return mock

    with patch("app.cforch._subprocess.Popen", side_effect=fake_popen):
        client.get("/api/cforch/run")

    assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
|
||||||
242
tests/test_imitate.py
Normal file
242
tests/test_imitate.py
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
"""Tests for app/imitate.py — product registry, sample extraction, corrections push."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.api import app
|
||||||
|
from app import imitate as _imitate_module
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def reset_module_globals(tmp_path):
    """Snapshot and restore the imitate module's config/data dir globals."""
    saved_cfg = _imitate_module._CONFIG_DIR
    saved_data = _imitate_module._DATA_DIR
    try:
        yield
    finally:
        _imitate_module._CONFIG_DIR = saved_cfg
        _imitate_module._DATA_DIR = saved_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def config_dir(tmp_path) -> Path:
    """Point the imitate module's config directory at a temp dir."""
    _imitate_module.set_config_dir(tmp_path)
    return tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def data_dir(tmp_path) -> Path:
    """Point the imitate module's data directory at a temp dir."""
    _imitate_module.set_data_dir(tmp_path)
    return tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def cfg_with_products(config_dir: Path) -> Path:
    """Write a label_tool.yaml defining two imitate products."""
    # NOTE(review): nesting of this YAML body was reconstructed from a
    # flattened source — verify keys sit under `imitate:` as in the original.
    yaml_body = """
imitate:
  ollama_url: http://localhost:11434
  products:
    - id: peregrine
      name: Peregrine
      icon: "🦅"
      description: Job search assistant
      base_url: http://peregrine.local
      sample_endpoint: /api/jobs
      text_fields: [title, description]
      prompt_template: "Analyze: {text}"
    - id: kiwi
      name: Kiwi
      icon: "🥝"
      description: Pantry tracker
      base_url: http://kiwi.local
      sample_endpoint: /api/inventory
      text_fields: [name, notes]
      prompt_template: "Describe: {text}"
"""
    (config_dir / "label_tool.yaml").write_text(yaml_body)
    return config_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def client() -> TestClient:
    """A TestClient that propagates server exceptions to the test."""
    return TestClient(app, raise_server_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_products_empty_when_no_config(config_dir, client):
    """A label_tool.yaml without an imitate section yields no products."""
    (config_dir / "label_tool.yaml").write_text("accounts: []\n")
    resp = client.get("/api/imitate/products")
    assert resp.status_code == 200
    assert resp.json()["products"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_listed(cfg_with_products, client):
    """Every configured product comes back with its expected fields."""
    with patch.object(_imitate_module, "_is_online", return_value=True):
        resp = client.get("/api/imitate/products")

    assert resp.status_code == 200
    listed = resp.json()["products"]
    assert len(listed) == 2
    assert {entry["id"] for entry in listed} == {"peregrine", "kiwi"}

    first = next(entry for entry in listed if entry["id"] == "peregrine")
    assert first["name"] == "Peregrine"
    assert first["icon"] == "🦅"
    assert first["online"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_offline_when_unreachable(cfg_with_products, client):
    """Every product reports offline when its base_url check fails."""
    with patch.object(_imitate_module, "_is_online", return_value=False):
        resp = client.get("/api/imitate/products")
    for entry in resp.json()["products"]:
        assert not entry["online"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_sample_unknown_product(cfg_with_products, client):
    """An unconfigured product id yields 404."""
    resp = client.get("/api/imitate/products/nonexistent/sample")
    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_fetched_from_list(cfg_with_products, client):
    """The first element of a list API response becomes the sample."""
    upstream = [
        {"title": "Engineer", "description": "Build things"},
        {"title": "Other", "description": "Ignore me"},
    ]
    with patch.object(_imitate_module, "_http_get_json", return_value=upstream):
        resp = client.get("/api/imitate/products/peregrine/sample")

    assert resp.status_code == 200
    payload = resp.json()
    # Text comes from the first item; the prompt uses the product template.
    assert "Engineer" in payload["text"]
    assert "Build things" in payload["text"]
    assert "Analyze:" in payload["prompt"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_fetched_from_dict_with_items_key(cfg_with_products, client):
    """A wrapper dict with a known list key is unwrapped before sampling."""
    upstream = {"items": [{"title": "Wrapped Job", "description": "In a wrapper"}]}
    with patch.object(_imitate_module, "_http_get_json", return_value=upstream):
        resp = client.get("/api/imitate/products/peregrine/sample")
    assert resp.status_code == 200
    assert "Wrapped Job" in resp.json()["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_503_when_api_unreachable(cfg_with_products, client):
    """An unreachable product API surfaces as 503."""
    from urllib.error import URLError

    with patch.object(_imitate_module, "_http_get_json", side_effect=URLError("refused")):
        resp = client.get("/api/imitate/products/peregrine/sample")
    assert resp.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_404_on_empty_list(cfg_with_products, client):
    """An empty list from the product API yields 404."""
    with patch.object(_imitate_module, "_http_get_json", return_value=[]):
        resp = client.get("/api/imitate/products/peregrine/sample")
    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
    """A successful push appends one record per result to sft_candidates.jsonl."""
    body = {
        "product_id": "peregrine",
        "prompt": "Analyze this job:",
        "results": [
            {"model": "qwen2.5:0.5b", "response": "It's a good job.", "elapsed_ms": 800, "error": None},
            {"model": "llama3.1:8b", "response": "Strong candidate.", "elapsed_ms": 1500, "error": None},
        ],
    }
    resp = client.post("/api/imitate/push-corrections", json=body)
    assert resp.status_code == 200
    assert resp.json()["pushed"] == 2

    lines = (data_dir / "sft_candidates.jsonl").read_text().splitlines()
    assert len(lines) == 2
    for raw in lines:
        rec = json.loads(raw)
        assert rec["source"] == "imitate"
        assert rec["product_id"] == "peregrine"
        assert rec["status"] == "pending"
        assert rec["prompt_messages"][0]["role"] == "user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_push_corrections_skips_errors(cfg_with_products, data_dir, client):
    """Errored results are excluded from the corrections file."""
    body = {
        "product_id": "peregrine",
        "prompt": "Analyze:",
        "results": [
            {"model": "good-model", "response": "Good answer.", "elapsed_ms": 500, "error": None},
            {"model": "bad-model", "response": "", "elapsed_ms": 0, "error": "connection refused"},
        ],
    }
    resp = client.post("/api/imitate/push-corrections", json=body)
    assert resp.status_code == 200
    # Only the error-free result is persisted.
    assert resp.json()["pushed"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_push_corrections_empty_prompt_422(cfg_with_products, data_dir, client):
    """A whitespace-only prompt is rejected with 422."""
    body = {
        "product_id": "peregrine",
        "prompt": " ",
        "results": [{"model": "m", "response": "r", "elapsed_ms": 1, "error": None}],
    }
    resp = client.post("/api/imitate/push-corrections", json=body)
    assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
    """If every result carries an error there is nothing to push — 422."""
    body = {
        "product_id": "peregrine",
        "prompt": "Analyze:",
        "results": [
            {"model": "m", "response": "", "elapsed_ms": 0, "error": "timed out"},
        ],
    }
    resp = client.post("/api/imitate/push-corrections", json=body)
    assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ── _extract_sample helper ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_extract_sample_list():
    """Extracting from a list joins the configured text fields."""
    extracted = _imitate_module._extract_sample(
        [{"title": "A", "description": "B"}],
        text_fields=["title", "description"],
    )
    text = extracted["text"]
    assert "A" in text
    assert "B" in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_sample_empty_list():
    """An empty input list extracts to an empty dict."""
    assert _imitate_module._extract_sample([], text_fields=["title"]) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_sample_respects_index():
    """sample_index selects which list element is extracted."""
    rows = [{"title": "First"}, {"title": "Second"}]
    extracted = _imitate_module._extract_sample(rows, ["title"], sample_index=1)
    assert "Second" in extracted["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_sample_clamps_index():
    """An out-of-range sample_index is clamped to the last element."""
    rows = [{"title": "Only"}]
    extracted = _imitate_module._extract_sample(rows, ["title"], sample_index=99)
    assert "Only" in extracted["text"]
|
||||||
402
tests/test_models.py
Normal file
402
tests/test_models.py
Normal file
|
|
@ -0,0 +1,402 @@
|
||||||
|
"""Tests for app/models.py — /api/models/* endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def reset_models_globals(tmp_path):
    """Point module dirs at tmp_path, clear download progress, restore after."""
    from app import models as models_module

    # Snapshot current module state so teardown can restore it exactly.
    saved_models_dir = models_module._MODELS_DIR
    saved_queue_dir = models_module._QUEUE_DIR
    saved_progress = dict(models_module._download_progress)

    models_dir = tmp_path / "models"
    models_dir.mkdir()
    queue_dir = tmp_path / "data"
    queue_dir.mkdir()

    models_module.set_models_dir(models_dir)
    models_module.set_queue_dir(queue_dir)
    models_module._download_progress = {}

    yield

    models_module.set_models_dir(saved_models_dir)
    models_module.set_queue_dir(saved_queue_dir)
    models_module._download_progress = saved_progress
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """A TestClient bound to the application under test."""
    from app.api import app

    return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hf_response(repo_id: str = "org/model", pipeline_tag: str = "text-classification") -> dict:
|
||||||
|
"""Minimal HF API response payload."""
|
||||||
|
return {
|
||||||
|
"modelId": repo_id,
|
||||||
|
"pipeline_tag": pipeline_tag,
|
||||||
|
"tags": ["pytorch", pipeline_tag],
|
||||||
|
"downloads": 42000,
|
||||||
|
"siblings": [
|
||||||
|
{"rfilename": "pytorch_model.bin", "size": 500_000_000},
|
||||||
|
],
|
||||||
|
"cardData": {"description": "A test model description."},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _queue_one(client, repo_id: str = "org/model") -> dict:
    """Helper: queue one model via POST /queue and return the created entry."""
    body = {
        "repo_id": repo_id,
        "pipeline_tag": "text-classification",
        "adapter_recommendation": "ZeroShotAdapter",
    }
    resp = client.post("/api/models/queue", json=body)
    assert resp.status_code == 201, resp.text
    return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /lookup ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_lookup_invalid_repo_id_returns_422_no_slash(client):
    """A repo_id missing the org/name slash is rejected with 422."""
    resp = client.get("/api/models/lookup", params={"repo_id": "noslash"})
    assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_invalid_repo_id_returns_422_whitespace(client):
    """A repo_id containing whitespace is rejected with 422."""
    resp = client.get("/api/models/lookup", params={"repo_id": "org/model name"})
    assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_hf_404_returns_404(client):
    """An HF API 404 propagates as an HTTP 404 from the endpoint."""
    fake_resp = MagicMock()
    fake_resp.status_code = 404

    with patch("app.models.httpx.get", return_value=fake_resp):
        resp = client.get("/api/models/lookup", params={"repo_id": "org/nonexistent"})

    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_hf_network_error_returns_502(client):
    """A transport error reaching the HF API surfaces as 502."""
    import httpx as _httpx

    with patch("app.models.httpx.get", side_effect=_httpx.RequestError("timeout")):
        resp = client.get("/api/models/lookup", params={"repo_id": "org/model"})

    assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_returns_correct_shape(client):
    """A successful lookup carries all required fields."""
    fake_resp = MagicMock()
    fake_resp.status_code = 200
    fake_resp.json.return_value = _make_hf_response("org/mymodel", "text-classification")

    with patch("app.models.httpx.get", return_value=fake_resp):
        resp = client.get("/api/models/lookup", params={"repo_id": "org/mymodel"})

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["repo_id"] == "org/mymodel"
    assert payload["pipeline_tag"] == "text-classification"
    assert payload["adapter_recommendation"] == "ZeroShotAdapter"
    assert payload["model_size_bytes"] == 500_000_000
    assert payload["downloads"] == 42000
    # A fresh lookup is neither installed nor queued yet.
    assert payload["already_installed"] is False
    assert payload["already_queued"] is False
|
||||||
|
|
||||||
|
def test_lookup_unknown_pipeline_tag_returns_null_adapter(client):
    """Unrecognised pipeline tags yield adapter_recommendation=null."""
    fake_resp = MagicMock()
    fake_resp.status_code = 200
    fake_resp.json.return_value = _make_hf_response("org/m", "audio-classification")

    with patch("app.models.httpx.get", return_value=fake_resp):
        resp = client.get("/api/models/lookup", params={"repo_id": "org/m"})

    assert resp.status_code == 200
    assert resp.json()["adapter_recommendation"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_already_queued_flag(client):
    """A repo_id already sitting in the pending queue sets already_queued."""
    _queue_one(client, "org/queued-model")

    fake_resp = MagicMock()
    fake_resp.status_code = 200
    fake_resp.json.return_value = _make_hf_response("org/queued-model")

    with patch("app.models.httpx.get", return_value=fake_resp):
        resp = client.get("/api/models/lookup", params={"repo_id": "org/queued-model"})

    assert resp.status_code == 200
    assert resp.json()["already_queued"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /queue ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_queue_empty_initially(client):
    """A fresh queue lists nothing."""
    resp = client.get("/api/models/queue")
    assert resp.status_code == 200
    assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_add_and_list(client):
    """A queued entry is returned by a subsequent GET /queue."""
    created = _queue_one(client, "org/my-model")

    resp = client.get("/api/models/queue")
    assert resp.status_code == 200
    listed = resp.json()
    assert len(listed) == 1
    first = listed[0]
    assert first["repo_id"] == "org/my-model"
    assert first["status"] == "pending"
    assert first["id"] == created["id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_add_returns_entry_fields(client):
    """POST /queue responds with a fully-populated entry."""
    created = _queue_one(client)
    assert "id" in created
    assert "queued_at" in created
    assert created["status"] == "pending"
    assert created["pipeline_tag"] == "text-classification"
    assert created["adapter_recommendation"] == "ZeroShotAdapter"
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /queue — 409 duplicate ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_queue_duplicate_returns_409(client):
    """Queuing the same repo_id a second time is rejected with 409."""
    _queue_one(client, "org/dup-model")

    duplicate = {
        "repo_id": "org/dup-model",
        "pipeline_tag": "text-classification",
        "adapter_recommendation": "ZeroShotAdapter",
    }
    resp = client.post("/api/models/queue", json=duplicate)
    assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_multiple_different_models(client):
    """Distinct repo_ids are all accepted into the queue."""
    for repo in ("org/model-a", "org/model-b", "org/model-c"):
        _queue_one(client, repo)

    resp = client.get("/api/models/queue")
    assert resp.status_code == 200
    assert len(resp.json()) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── DELETE /queue/{id} — dismiss ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_queue_dismiss(client):
    """Dismissing an entry removes it from subsequent GET /queue listings."""
    created = _queue_one(client)

    resp = client.delete(f"/api/models/queue/{created['id']}")
    assert resp.status_code == 200
    assert resp.json() == {"ok": True}

    listing = client.get("/api/models/queue")
    assert listing.status_code == 200
    assert listing.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_dismiss_nonexistent_returns_404(client):
    """Dismissing an unknown id yields 404."""
    resp = client.delete("/api/models/queue/does-not-exist")
    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_dismiss_allows_re_queue(client):
    """A dismissed repo_id can be queued again without a duplicate conflict."""
    created = _queue_one(client, "org/requeue-model")
    client.delete(f"/api/models/queue/{created['id']}")

    resp = client.post("/api/models/queue", json={
        "repo_id": "org/requeue-model",
        "pipeline_tag": None,
        "adapter_recommendation": None,
    })
    assert resp.status_code == 201
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_approve_nonexistent_returns_404(client):
    """Approving an unknown queue id yields 404."""
    resp = client.post("/api/models/queue/ghost-id/approve")
    assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_non_pending_returns_409(client):
    """Only 'pending' entries can be approved; anything else yields 409."""
    from app import models as models_module

    created = _queue_one(client)
    # Force the entry out of the pending state before approving.
    models_module._update_queue_entry(created["id"], {"status": "failed"})

    resp = client.post(f"/api/models/queue/{created['id']}/approve")
    assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_starts_download_and_returns_ok(client):
|
||||||
|
"""Approving a pending entry returns {ok: true} and starts a background thread."""
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
|
entry = _queue_one(client)
|
||||||
|
|
||||||
|
# Patch snapshot_download so the thread doesn't actually hit the network.
|
||||||
|
# Use an Event so we can wait for the thread to finish before asserting.
|
||||||
|
thread_done = threading.Event()
|
||||||
|
original_run = None
|
||||||
|
|
||||||
|
def _fake_snapshot_download(**kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with patch("app.models.snapshot_download", side_effect=_fake_snapshot_download):
|
||||||
|
r = client.post(f"/api/models/queue/{entry['id']}/approve")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"ok": True}
|
||||||
|
# Give the background thread a moment to complete while snapshot_download is patched
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
# Queue entry status should have moved to 'downloading' (or 'ready' if fast)
|
||||||
|
from app import models as models_module
|
||||||
|
updated = models_module._get_queue_entry(entry["id"])
|
||||||
|
assert updated is not None, "Queue entry not found — thread may have run after fixture teardown"
|
||||||
|
assert updated["status"] in ("downloading", "ready", "failed")
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /download/stream ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_download_stream_idle_when_no_download(client):
|
||||||
|
"""GET /download/stream returns a single idle event when nothing is downloading."""
|
||||||
|
r = client.get("/api/models/download/stream")
|
||||||
|
assert r.status_code == 200
|
||||||
|
# SSE body should contain the idle event
|
||||||
|
assert "idle" in r.text
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /installed ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_installed_empty(client):
|
||||||
|
"""GET /installed returns [] when models dir is empty."""
|
||||||
|
r = client.get("/api/models/installed")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_installed_detects_downloaded_model(client, tmp_path):
|
||||||
|
"""A subdir with config.json is surfaced as type='downloaded'."""
|
||||||
|
from app import models as models_module
|
||||||
|
|
||||||
|
model_dir = models_module._MODELS_DIR / "org--mymodel"
|
||||||
|
model_dir.mkdir()
|
||||||
|
(model_dir / "config.json").write_text(json.dumps({"model_type": "bert"}), encoding="utf-8")
|
||||||
|
(model_dir / "model_info.json").write_text(
|
||||||
|
json.dumps({"repo_id": "org/mymodel", "adapter_recommendation": "ZeroShotAdapter"}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
r = client.get("/api/models/installed")
|
||||||
|
assert r.status_code == 200
|
||||||
|
items = r.json()
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["type"] == "downloaded"
|
||||||
|
assert items[0]["name"] == "org--mymodel"
|
||||||
|
assert items[0]["adapter"] == "ZeroShotAdapter"
|
||||||
|
assert items[0]["model_id"] == "org/mymodel"
|
||||||
|
|
||||||
|
|
||||||
|
def test_installed_detects_finetuned_model(client):
|
||||||
|
"""A subdir with training_info.json is surfaced as type='finetuned'."""
|
||||||
|
from app import models as models_module
|
||||||
|
|
||||||
|
model_dir = models_module._MODELS_DIR / "my-finetuned"
|
||||||
|
model_dir.mkdir()
|
||||||
|
(model_dir / "training_info.json").write_text(
|
||||||
|
json.dumps({"base_model": "org/base", "epochs": 5}), encoding="utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
r = client.get("/api/models/installed")
|
||||||
|
assert r.status_code == 200
|
||||||
|
items = r.json()
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["type"] == "finetuned"
|
||||||
|
assert items[0]["name"] == "my-finetuned"
|
||||||
|
|
||||||
|
|
||||||
|
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_delete_installed_removes_directory(client):
|
||||||
|
"""DELETE /installed/{name} removes the directory and returns {ok: true}."""
|
||||||
|
from app import models as models_module
|
||||||
|
|
||||||
|
model_dir = models_module._MODELS_DIR / "org--removeme"
|
||||||
|
model_dir.mkdir()
|
||||||
|
(model_dir / "config.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
r = client.delete("/api/models/installed/org--removeme")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"ok": True}
|
||||||
|
assert not model_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_installed_not_found_returns_404(client):
|
||||||
|
r = client.delete("/api/models/installed/does-not-exist")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_installed_path_traversal_blocked(client):
|
||||||
|
"""DELETE /installed/../../etc must be blocked.
|
||||||
|
Path traversal normalises to a different URL (/api/etc); if web/dist exists
|
||||||
|
the StaticFiles mount intercepts it and returns 405 (GET/HEAD only).
|
||||||
|
"""
|
||||||
|
r = client.delete("/api/models/installed/../../etc")
|
||||||
|
assert r.status_code in (400, 404, 405, 422)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_installed_dotdot_name_blocked(client):
|
||||||
|
"""A name containing '..' in any form must be rejected."""
|
||||||
|
r = client.delete("/api/models/installed/..%2F..%2Fetc")
|
||||||
|
assert r.status_code in (400, 404, 405, 422)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_installed_name_with_slash_blocked(client):
|
||||||
|
"""A name containing a literal '/' after URL decoding must be rejected."""
|
||||||
|
from app import models as models_module
|
||||||
|
|
||||||
|
# The router will see the path segment after /installed/ — a second '/' would
|
||||||
|
# be parsed as a new path segment, so we test via the validation helper directly.
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
# Simulate calling delete logic with a slash-containing name directly
|
||||||
|
from fastapi import HTTPException as _HTTPException
|
||||||
|
from app.models import delete_installed
|
||||||
|
try:
|
||||||
|
delete_installed("org/traversal")
|
||||||
|
except _HTTPException as exc:
|
||||||
|
assert exc.status_code in (400, 404)
|
||||||
|
raise
|
||||||
|
|
@ -10,11 +10,14 @@ def reset_sft_globals(tmp_path):
|
||||||
from app import sft as sft_module
|
from app import sft as sft_module
|
||||||
_prev_data = sft_module._SFT_DATA_DIR
|
_prev_data = sft_module._SFT_DATA_DIR
|
||||||
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
||||||
|
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR
|
||||||
sft_module.set_sft_data_dir(tmp_path)
|
sft_module.set_sft_data_dir(tmp_path)
|
||||||
sft_module.set_sft_config_dir(tmp_path)
|
sft_module.set_sft_config_dir(tmp_path)
|
||||||
|
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||||
yield
|
yield
|
||||||
sft_module.set_sft_data_dir(_prev_data)
|
sft_module.set_sft_data_dir(_prev_data)
|
||||||
sft_module.set_sft_config_dir(_prev_cfg)
|
sft_module.set_sft_config_dir(_prev_cfg)
|
||||||
|
sft_module.set_default_bench_results_dir(_prev_default)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -232,6 +235,41 @@ def test_submit_already_approved_returns_409(client, tmp_path):
|
||||||
assert r.status_code == 409
|
assert r.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_submit_correct_stores_failure_category(client, tmp_path):
|
||||||
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
|
r = client.post("/api/sft/submit", json={
|
||||||
|
"id": "a", "action": "correct",
|
||||||
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
|
"failure_category": "style_violation",
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
from app import sft as sft_module
|
||||||
|
records = sft_module._read_candidates()
|
||||||
|
assert records[0]["failure_category"] == "style_violation"
|
||||||
|
|
||||||
|
|
||||||
|
def test_submit_correct_null_failure_category(client, tmp_path):
|
||||||
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
|
r = client.post("/api/sft/submit", json={
|
||||||
|
"id": "a", "action": "correct",
|
||||||
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
from app import sft as sft_module
|
||||||
|
records = sft_module._read_candidates()
|
||||||
|
assert records[0]["failure_category"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
||||||
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
|
r = client.post("/api/sft/submit", json={
|
||||||
|
"id": "a", "action": "correct",
|
||||||
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
|
"failure_category": "nonsense",
|
||||||
|
})
|
||||||
|
assert r.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,8 @@ const navItems = [
|
||||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
||||||
|
{ path: '/models', icon: '🤗', label: 'Models' },
|
||||||
|
{ path: '/imitate', icon: '🪞', label: 'Imitate' },
|
||||||
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
|
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
|
||||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ const LOW_QUALITY_ITEM: SftQueueItem = {
|
||||||
model_response: 'def add(a, b): return a - b',
|
model_response: 'def add(a, b): return a - b',
|
||||||
corrected_response: null, quality_score: 0.2,
|
corrected_response: null, quality_score: 0.2,
|
||||||
failure_reason: 'pattern_match: 0/2 matched',
|
failure_reason: 'pattern_match: 0/2 matched',
|
||||||
|
failure_category: null,
|
||||||
task_id: 'code-fn', task_type: 'code', task_name: 'Code: Write a function',
|
task_id: 'code-fn', task_type: 'code', task_name: 'Code: Write a function',
|
||||||
model_id: 'Qwen/Qwen2.5-3B', model_name: 'Qwen2.5-3B',
|
model_id: 'Qwen/Qwen2.5-3B', model_name: 'Qwen2.5-3B',
|
||||||
node_id: 'heimdall', gpu_id: 0, tokens_per_sec: 38.4,
|
node_id: 'heimdall', gpu_id: 0, tokens_per_sec: 38.4,
|
||||||
|
|
@ -68,15 +69,17 @@ describe('SftCard', () => {
|
||||||
expect(w.emitted('correct')).toBeTruthy()
|
expect(w.emitted('correct')).toBeTruthy()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('clicking Discard button emits discard', async () => {
|
it('clicking Discard button then confirming emits discard', async () => {
|
||||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
await w.find('[data-testid="discard-btn"]').trigger('click')
|
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||||
|
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||||
expect(w.emitted('discard')).toBeTruthy()
|
expect(w.emitted('discard')).toBeTruthy()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('clicking Flag Model button emits flag', async () => {
|
it('clicking Flag Model button then confirming emits flag', async () => {
|
||||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
await w.find('[data-testid="flag-btn"]').trigger('click')
|
await w.find('[data-testid="flag-btn"]').trigger('click')
|
||||||
|
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||||
expect(w.emitted('flag')).toBeTruthy()
|
expect(w.emitted('flag')).toBeTruthy()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -95,4 +98,82 @@ describe('SftCard', () => {
|
||||||
const w = mount(SftCard, { props: { item } })
|
const w = mount(SftCard, { props: { item } })
|
||||||
expect(w.find('.failure-reason').exists()).toBe(false)
|
expect(w.find('.failure-reason').exists()).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// ── Failure category chip-group ───────────────────────────────────
|
||||||
|
it('failure category section hidden when not correcting and no pending action', () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
|
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('failure category section shown when correcting prop is true', () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||||
|
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders all six category chips when correcting', () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||||
|
const chips = w.findAll('.category-chip')
|
||||||
|
expect(chips).toHaveLength(6)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clicking a category chip selects it (adds active class)', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||||
|
const chip = w.find('[data-testid="category-chip-wrong_answer"]')
|
||||||
|
await chip.trigger('click')
|
||||||
|
expect(chip.classes()).toContain('category-chip--active')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clicking the active chip again deselects it', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||||
|
const chip = w.find('[data-testid="category-chip-hallucination"]')
|
||||||
|
await chip.trigger('click')
|
||||||
|
expect(chip.classes()).toContain('category-chip--active')
|
||||||
|
await chip.trigger('click')
|
||||||
|
expect(chip.classes()).not.toContain('category-chip--active')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('only one chip can be active at a time', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||||
|
await w.find('[data-testid="category-chip-wrong_answer"]').trigger('click')
|
||||||
|
await w.find('[data-testid="category-chip-hallucination"]').trigger('click')
|
||||||
|
const active = w.findAll('.category-chip--active')
|
||||||
|
expect(active).toHaveLength(1)
|
||||||
|
expect(active[0].attributes('data-testid')).toBe('category-chip-hallucination')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clicking Discard shows pending action row with category section', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
|
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||||
|
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(true)
|
||||||
|
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clicking Flag shows pending action row', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
|
await w.find('[data-testid="flag-btn"]').trigger('click')
|
||||||
|
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('confirming discard emits discard with null when no category selected', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
|
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||||
|
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||||
|
expect(w.emitted('discard')).toBeTruthy()
|
||||||
|
expect(w.emitted('discard')![0]).toEqual([null])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('confirming discard emits discard with selected category', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
|
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||||
|
await w.find('[data-testid="category-chip-scoring_artifact"]').trigger('click')
|
||||||
|
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||||
|
expect(w.emitted('discard')![0]).toEqual(['scoring_artifact'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('cancelling pending action hides the pending row', async () => {
|
||||||
|
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||||
|
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||||
|
await w.find('[data-testid="cancel-pending-btn"]').trigger('click')
|
||||||
|
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(false)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -57,21 +57,52 @@
|
||||||
<button
|
<button
|
||||||
data-testid="discard-btn"
|
data-testid="discard-btn"
|
||||||
class="btn-discard"
|
class="btn-discard"
|
||||||
@click="$emit('discard')"
|
@click="emitWithCategory('discard')"
|
||||||
>✕ Discard</button>
|
>✕ Discard</button>
|
||||||
<button
|
<button
|
||||||
data-testid="flag-btn"
|
data-testid="flag-btn"
|
||||||
class="btn-flag"
|
class="btn-flag"
|
||||||
@click="$emit('flag')"
|
@click="emitWithCategory('flag')"
|
||||||
>⚑ Flag Model</button>
|
>⚑ Flag Model</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Failure category selector (shown when correcting or acting) -->
|
||||||
|
<div
|
||||||
|
v-if="correcting || pendingAction"
|
||||||
|
class="failure-category-section"
|
||||||
|
data-testid="failure-category-section"
|
||||||
|
>
|
||||||
|
<p class="section-label">Failure category <span class="optional-label">(optional)</span></p>
|
||||||
|
<div class="category-chips" role="group" aria-label="Failure category">
|
||||||
|
<button
|
||||||
|
v-for="cat in FAILURE_CATEGORIES"
|
||||||
|
:key="cat.value"
|
||||||
|
type="button"
|
||||||
|
class="category-chip"
|
||||||
|
:class="{ 'category-chip--active': selectedCategory === cat.value }"
|
||||||
|
:aria-pressed="selectedCategory === cat.value || undefined"
|
||||||
|
:data-testid="'category-chip-' + cat.value"
|
||||||
|
@click="toggleCategory(cat.value)"
|
||||||
|
>{{ cat.label }}</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Pending discard/flag confirm row -->
|
||||||
|
<div v-if="pendingAction" class="pending-action-row" data-testid="pending-action-row">
|
||||||
|
<button class="btn-confirm" @click="confirmPendingAction" data-testid="confirm-pending-btn">
|
||||||
|
Confirm {{ pendingAction }}
|
||||||
|
</button>
|
||||||
|
<button class="btn-cancel-pending" @click="cancelPendingAction" data-testid="cancel-pending-btn">
|
||||||
|
Cancel
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Correction area (shown when correcting = true) -->
|
<!-- Correction area (shown when correcting = true) -->
|
||||||
<div v-if="correcting" data-testid="correction-area">
|
<div v-if="correcting" data-testid="correction-area">
|
||||||
<SftCorrectionArea
|
<SftCorrectionArea
|
||||||
ref="correctionAreaEl"
|
ref="correctionAreaEl"
|
||||||
:described-by="'sft-failure-' + item.id"
|
:described-by="'sft-failure-' + item.id"
|
||||||
@submit="$emit('submit-correction', $event)"
|
@submit="handleSubmitCorrection"
|
||||||
@cancel="$emit('cancel-correction')"
|
@cancel="$emit('cancel-correction')"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -80,21 +111,32 @@
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed } from 'vue'
|
import { ref, computed } from 'vue'
|
||||||
import type { SftQueueItem } from '../stores/sft'
|
import type { SftQueueItem, SftFailureCategory } from '../stores/sft'
|
||||||
import SftCorrectionArea from './SftCorrectionArea.vue'
|
import SftCorrectionArea from './SftCorrectionArea.vue'
|
||||||
|
|
||||||
const props = defineProps<{ item: SftQueueItem; correcting?: boolean }>()
|
const props = defineProps<{ item: SftQueueItem; correcting?: boolean }>()
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
correct: []
|
correct: []
|
||||||
discard: []
|
discard: [category: SftFailureCategory | null]
|
||||||
flag: []
|
flag: [category: SftFailureCategory | null]
|
||||||
'submit-correction': [text: string]
|
'submit-correction': [text: string, category: SftFailureCategory | null]
|
||||||
'cancel-correction': []
|
'cancel-correction': []
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
|
const FAILURE_CATEGORIES: { value: SftFailureCategory; label: string }[] = [
|
||||||
|
{ value: 'scoring_artifact', label: 'Scoring artifact' },
|
||||||
|
{ value: 'style_violation', label: 'Style violation' },
|
||||||
|
{ value: 'partial_answer', label: 'Partial answer' },
|
||||||
|
{ value: 'wrong_answer', label: 'Wrong answer' },
|
||||||
|
{ value: 'format_error', label: 'Format error' },
|
||||||
|
{ value: 'hallucination', label: 'Hallucination' },
|
||||||
|
]
|
||||||
|
|
||||||
const promptExpanded = ref(false)
|
const promptExpanded = ref(false)
|
||||||
const correctionAreaEl = ref<InstanceType<typeof SftCorrectionArea> | null>(null)
|
const correctionAreaEl = ref<InstanceType<typeof SftCorrectionArea> | null>(null)
|
||||||
|
const selectedCategory = ref<SftFailureCategory | null>(null)
|
||||||
|
const pendingAction = ref<'discard' | 'flag' | null>(null)
|
||||||
|
|
||||||
const qualityClass = computed(() => {
|
const qualityClass = computed(() => {
|
||||||
const s = props.item.quality_score
|
const s = props.item.quality_score
|
||||||
|
|
@ -110,8 +152,34 @@ const qualityLabel = computed(() => {
|
||||||
return 'acceptable'
|
return 'acceptable'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
function toggleCategory(cat: SftFailureCategory) {
|
||||||
|
selectedCategory.value = selectedCategory.value === cat ? null : cat
|
||||||
|
}
|
||||||
|
|
||||||
|
function emitWithCategory(action: 'discard' | 'flag') {
|
||||||
|
pendingAction.value = action
|
||||||
|
}
|
||||||
|
|
||||||
|
function confirmPendingAction() {
|
||||||
|
if (!pendingAction.value) return
|
||||||
|
emit(pendingAction.value, selectedCategory.value)
|
||||||
|
pendingAction.value = null
|
||||||
|
selectedCategory.value = null
|
||||||
|
}
|
||||||
|
|
||||||
|
function cancelPendingAction() {
|
||||||
|
pendingAction.value = null
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleSubmitCorrection(text: string) {
|
||||||
|
emit('submit-correction', text, selectedCategory.value)
|
||||||
|
selectedCategory.value = null
|
||||||
|
}
|
||||||
|
|
||||||
function resetCorrection() {
|
function resetCorrection() {
|
||||||
correctionAreaEl.value?.reset()
|
correctionAreaEl.value?.reset()
|
||||||
|
selectedCategory.value = null
|
||||||
|
pendingAction.value = null
|
||||||
}
|
}
|
||||||
|
|
||||||
defineExpose({ resetCorrection })
|
defineExpose({ resetCorrection })
|
||||||
|
|
@ -243,4 +311,83 @@ defineExpose({ resetCorrection })
|
||||||
|
|
||||||
.btn-flag { border-color: var(--color-warning); color: var(--color-warning); }
|
.btn-flag { border-color: var(--color-warning); color: var(--color-warning); }
|
||||||
.btn-flag:hover { background: color-mix(in srgb, var(--color-warning) 10%, transparent); }
|
.btn-flag:hover { background: color-mix(in srgb, var(--color-warning) 10%, transparent); }
|
||||||
|
|
||||||
|
/* ── Failure category selector ─────────────────── */
|
||||||
|
.failure-category-section {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.optional-label {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
font-weight: 400;
|
||||||
|
color: var(--color-text-muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.category-chips {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: var(--space-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.category-chip {
|
||||||
|
padding: var(--space-1) var(--space-3);
|
||||||
|
border-radius: var(--radius-full);
|
||||||
|
border: 1px solid var(--color-border);
|
||||||
|
background: var(--color-surface-alt);
|
||||||
|
color: var(--color-text-muted);
|
||||||
|
font-size: 0.78rem;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background var(--transition), color var(--transition), border-color var(--transition);
|
||||||
|
}
|
||||||
|
|
||||||
|
.category-chip:hover {
|
||||||
|
border-color: var(--color-accent);
|
||||||
|
color: var(--color-accent);
|
||||||
|
background: var(--color-accent-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.category-chip--active {
|
||||||
|
background: var(--color-accent-light);
|
||||||
|
border-color: var(--color-accent);
|
||||||
|
color: var(--color-accent);
|
||||||
|
font-weight: 700;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pending-action-row {
|
||||||
|
display: flex;
|
||||||
|
gap: var(--space-2);
|
||||||
|
margin-top: var(--space-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-confirm {
|
||||||
|
padding: var(--space-1) var(--space-3);
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
border: 1px solid var(--color-accent);
|
||||||
|
background: var(--color-accent-light);
|
||||||
|
color: var(--color-accent);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-confirm:hover {
|
||||||
|
background: color-mix(in srgb, var(--color-accent) 15%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-cancel-pending {
|
||||||
|
padding: var(--space-1) var(--space-3);
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
border: 1px solid var(--color-border);
|
||||||
|
background: none;
|
||||||
|
color: var(--color-text-muted);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-cancel-pending:hover {
|
||||||
|
background: var(--color-surface-alt);
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ const StatsView = () => import('../views/StatsView.vue')
|
||||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||||
const SettingsView = () => import('../views/SettingsView.vue')
|
const SettingsView = () => import('../views/SettingsView.vue')
|
||||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||||
|
const ModelsView = () => import('../views/ModelsView.vue')
|
||||||
|
const ImitateView = () => import('../views/ImitateView.vue')
|
||||||
|
|
||||||
export const router = createRouter({
|
export const router = createRouter({
|
||||||
history: createWebHashHistory(),
|
history: createWebHashHistory(),
|
||||||
|
|
@ -15,6 +17,8 @@ export const router = createRouter({
|
||||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||||
|
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
|
||||||
|
{ path: '/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||||
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,14 @@
|
||||||
import { defineStore } from 'pinia'
|
import { defineStore } from 'pinia'
|
||||||
import { computed, ref } from 'vue'
|
import { computed, ref } from 'vue'
|
||||||
|
|
||||||
|
export type SftFailureCategory =
|
||||||
|
| 'scoring_artifact'
|
||||||
|
| 'style_violation'
|
||||||
|
| 'partial_answer'
|
||||||
|
| 'wrong_answer'
|
||||||
|
| 'format_error'
|
||||||
|
| 'hallucination'
|
||||||
|
|
||||||
export interface SftQueueItem {
|
export interface SftQueueItem {
|
||||||
id: string
|
id: string
|
||||||
source: 'cf-orch-benchmark'
|
source: 'cf-orch-benchmark'
|
||||||
|
|
@ -13,6 +21,7 @@ export interface SftQueueItem {
|
||||||
corrected_response: string | null
|
corrected_response: string | null
|
||||||
quality_score: number // 0.0 to 1.0
|
quality_score: number // 0.0 to 1.0
|
||||||
failure_reason: string | null
|
failure_reason: string | null
|
||||||
|
failure_category: SftFailureCategory | null
|
||||||
task_id: string
|
task_id: string
|
||||||
task_type: string
|
task_type: string
|
||||||
task_name: string
|
task_name: string
|
||||||
|
|
@ -26,6 +35,7 @@ export interface SftQueueItem {
|
||||||
export interface SftLastAction {
|
export interface SftLastAction {
|
||||||
type: 'correct' | 'discard' | 'flag'
|
type: 'correct' | 'discard' | 'flag'
|
||||||
item: SftQueueItem
|
item: SftQueueItem
|
||||||
|
failure_category?: SftFailureCategory | null
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useSftStore = defineStore('sft', () => {
|
export const useSftStore = defineStore('sft', () => {
|
||||||
|
|
@ -39,8 +49,12 @@ export const useSftStore = defineStore('sft', () => {
|
||||||
queue.value.shift()
|
queue.value.shift()
|
||||||
}
|
}
|
||||||
|
|
||||||
function setLastAction(type: SftLastAction['type'], item: SftQueueItem) {
|
function setLastAction(
|
||||||
lastAction.value = { type, item }
|
type: SftLastAction['type'],
|
||||||
|
item: SftQueueItem,
|
||||||
|
failure_category?: SftFailureCategory | null,
|
||||||
|
) {
|
||||||
|
lastAction.value = { type, item, failure_category }
|
||||||
}
|
}
|
||||||
|
|
||||||
function clearLastAction() {
|
function clearLastAction() {
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -36,6 +36,7 @@
|
||||||
@flag="handleFlag"
|
@flag="handleFlag"
|
||||||
@submit-correction="handleCorrect"
|
@submit-correction="handleCorrect"
|
||||||
@cancel-correction="correcting = false"
|
@cancel-correction="correcting = false"
|
||||||
|
ref="sftCardEl"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
@ -67,6 +68,7 @@
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted } from 'vue'
|
import { ref, onMounted } from 'vue'
|
||||||
import { useSftStore } from '../stores/sft'
|
import { useSftStore } from '../stores/sft'
|
||||||
|
import type { SftFailureCategory } from '../stores/sft'
|
||||||
import { useSftKeyboard } from '../composables/useSftKeyboard'
|
import { useSftKeyboard } from '../composables/useSftKeyboard'
|
||||||
import SftCard from '../components/SftCard.vue'
|
import SftCard from '../components/SftCard.vue'
|
||||||
|
|
||||||
|
|
@ -76,6 +78,7 @@ const apiError = ref(false)
|
||||||
const correcting = ref(false)
|
const correcting = ref(false)
|
||||||
const stats = ref<Record<string, any> | null>(null)
|
const stats = ref<Record<string, any> | null>(null)
|
||||||
const exportUrl = '/api/sft/export'
|
const exportUrl = '/api/sft/export'
|
||||||
|
const sftCardEl = ref<InstanceType<typeof SftCard> | null>(null)
|
||||||
|
|
||||||
useSftKeyboard({
|
useSftKeyboard({
|
||||||
onCorrect: () => { if (store.current && !correcting.value) correcting.value = true },
|
onCorrect: () => { if (store.current && !correcting.value) correcting.value = true },
|
||||||
|
|
@ -113,19 +116,21 @@ function startCorrection() {
|
||||||
correcting.value = true
|
correcting.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleCorrect(text: string) {
|
async function handleCorrect(text: string, category: SftFailureCategory | null = null) {
|
||||||
if (!store.current) return
|
if (!store.current) return
|
||||||
const item = store.current
|
const item = store.current
|
||||||
correcting.value = false
|
correcting.value = false
|
||||||
try {
|
try {
|
||||||
|
const body: Record<string, unknown> = { id: item.id, action: 'correct', corrected_response: text }
|
||||||
|
if (category != null) body.failure_category = category
|
||||||
const res = await fetch('/api/sft/submit', {
|
const res = await fetch('/api/sft/submit', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ id: item.id, action: 'correct', corrected_response: text }),
|
body: JSON.stringify(body),
|
||||||
})
|
})
|
||||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||||
store.removeCurrentFromQueue()
|
store.removeCurrentFromQueue()
|
||||||
store.setLastAction('correct', item)
|
store.setLastAction('correct', item, category)
|
||||||
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
||||||
fetchStats()
|
fetchStats()
|
||||||
if (store.queue.length < 5) fetchBatch()
|
if (store.queue.length < 5) fetchBatch()
|
||||||
|
|
@ -134,18 +139,20 @@ async function handleCorrect(text: string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleDiscard() {
|
async function handleDiscard(category: SftFailureCategory | null = null) {
|
||||||
if (!store.current) return
|
if (!store.current) return
|
||||||
const item = store.current
|
const item = store.current
|
||||||
try {
|
try {
|
||||||
|
const body: Record<string, unknown> = { id: item.id, action: 'discard' }
|
||||||
|
if (category != null) body.failure_category = category
|
||||||
const res = await fetch('/api/sft/submit', {
|
const res = await fetch('/api/sft/submit', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ id: item.id, action: 'discard' }),
|
body: JSON.stringify(body),
|
||||||
})
|
})
|
||||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||||
store.removeCurrentFromQueue()
|
store.removeCurrentFromQueue()
|
||||||
store.setLastAction('discard', item)
|
store.setLastAction('discard', item, category)
|
||||||
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
||||||
fetchStats()
|
fetchStats()
|
||||||
if (store.queue.length < 5) fetchBatch()
|
if (store.queue.length < 5) fetchBatch()
|
||||||
|
|
@ -154,18 +161,20 @@ async function handleDiscard() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleFlag() {
|
async function handleFlag(category: SftFailureCategory | null = null) {
|
||||||
if (!store.current) return
|
if (!store.current) return
|
||||||
const item = store.current
|
const item = store.current
|
||||||
try {
|
try {
|
||||||
|
const body: Record<string, unknown> = { id: item.id, action: 'flag' }
|
||||||
|
if (category != null) body.failure_category = category
|
||||||
const res = await fetch('/api/sft/submit', {
|
const res = await fetch('/api/sft/submit', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ id: item.id, action: 'flag' }),
|
body: JSON.stringify(body),
|
||||||
})
|
})
|
||||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||||
store.removeCurrentFromQueue()
|
store.removeCurrentFromQueue()
|
||||||
store.setLastAction('flag', item)
|
store.setLastAction('flag', item, category)
|
||||||
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
||||||
fetchStats()
|
fetchStats()
|
||||||
if (store.queue.length < 5) fetchBatch()
|
if (store.queue.length < 5) fetchBatch()
|
||||||
|
|
|
||||||
898
web/src/views/ImitateView.vue
Normal file
898
web/src/views/ImitateView.vue
Normal file
|
|
@ -0,0 +1,898 @@
|
||||||
|
<template>
|
||||||
|
<div class="imitate-view">
|
||||||
|
<header class="bench-header">
|
||||||
|
<h1 class="page-title">🪞 Imitate</h1>
|
||||||
|
<p class="page-subtitle">Pull real samples from CF product APIs and compare LLM responses</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<!-- ── Step 1: Product selection ──────────────────────────────── -->
|
||||||
|
<section class="step-section">
|
||||||
|
<h2 class="step-heading">1. Select Product</h2>
|
||||||
|
<div v-if="productsLoading" class="picker-loading">Loading products…</div>
|
||||||
|
<div v-else-if="products.length === 0" class="picker-empty">
|
||||||
|
No products configured — add an <code>imitate:</code> section to
|
||||||
|
<code>config/label_tool.yaml</code>.
|
||||||
|
</div>
|
||||||
|
<div v-else class="product-grid">
|
||||||
|
<button
|
||||||
|
v-for="p in products"
|
||||||
|
:key="p.id"
|
||||||
|
class="product-card"
|
||||||
|
:class="{
|
||||||
|
selected: selectedProduct?.id === p.id,
|
||||||
|
offline: !p.online,
|
||||||
|
}"
|
||||||
|
:disabled="!p.online"
|
||||||
|
:title="p.online ? p.description : `${p.name} is offline`"
|
||||||
|
@click="selectProduct(p)"
|
||||||
|
>
|
||||||
|
<span class="product-icon">{{ p.icon }}</span>
|
||||||
|
<span class="product-name">{{ p.name }}</span>
|
||||||
|
<span class="product-status" :class="p.online ? 'status-on' : 'status-off'">
|
||||||
|
{{ p.online ? 'online' : 'offline' }}
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<!-- ── Step 2: Sample + Prompt ────────────────────────────────── -->
|
||||||
|
<section v-if="selectedProduct" class="step-section">
|
||||||
|
<h2 class="step-heading">2. Sample & Prompt</h2>
|
||||||
|
<div class="sample-toolbar">
|
||||||
|
<span class="sample-product-label">{{ selectedProduct.icon }} {{ selectedProduct.name }}</span>
|
||||||
|
<button class="btn-refresh" :disabled="sampleLoading" @click="fetchSample">
|
||||||
|
{{ sampleLoading ? '⏳ Fetching…' : '🔄 Refresh Sample' }}
|
||||||
|
</button>
|
||||||
|
<span v-if="sampleError" class="sample-error">{{ sampleError }}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="sampleLoading" class="picker-loading">Fetching sample from API…</div>
|
||||||
|
|
||||||
|
<template v-else-if="rawSample">
|
||||||
|
<!-- Fetched text preview -->
|
||||||
|
<details class="sample-preview" open>
|
||||||
|
<summary class="sample-preview-toggle">Raw sample text</summary>
|
||||||
|
<pre class="sample-text">{{ rawSample.text }}</pre>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<!-- Prompt editor -->
|
||||||
|
<label class="prompt-label" for="prompt-editor">Prompt sent to models</label>
|
||||||
|
<textarea
|
||||||
|
id="prompt-editor"
|
||||||
|
class="prompt-editor"
|
||||||
|
v-model="editedPrompt"
|
||||||
|
rows="8"
|
||||||
|
/>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<div v-else-if="!sampleLoading && selectedProduct" class="picker-empty">
|
||||||
|
Click "Refresh Sample" to fetch a real sample from {{ selectedProduct.name }}.
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<!-- ── Step 3: Models + Run ───────────────────────────────────── -->
|
||||||
|
<section v-if="editedPrompt" class="step-section">
|
||||||
|
<h2 class="step-heading">3. Models & Run</h2>
|
||||||
|
|
||||||
|
<!-- Ollama model picker -->
|
||||||
|
<details class="model-picker" open>
|
||||||
|
<summary class="picker-summary">
|
||||||
|
<span class="picker-title">🤖 Ollama Models</span>
|
||||||
|
<span class="picker-badge">{{ selectedModels.size }} / {{ ollamaModels.length }}</span>
|
||||||
|
</summary>
|
||||||
|
<div class="picker-body">
|
||||||
|
<div v-if="modelsLoading" class="picker-loading">Loading models…</div>
|
||||||
|
<div v-else-if="ollamaModels.length === 0" class="picker-empty">
|
||||||
|
No ollama models in bench_models.yaml — add models with <code>service: ollama</code>.
|
||||||
|
</div>
|
||||||
|
<template v-else>
|
||||||
|
<label class="picker-cat-header">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="selectedModels.size === ollamaModels.length"
|
||||||
|
:indeterminate="selectedModels.size > 0 && selectedModels.size < ollamaModels.length"
|
||||||
|
@change="toggleAllModels(($event.target as HTMLInputElement).checked)"
|
||||||
|
/>
|
||||||
|
<span class="picker-cat-name">All ollama models</span>
|
||||||
|
</label>
|
||||||
|
<div class="picker-model-list">
|
||||||
|
<label v-for="m in ollamaModels" :key="m.id" class="picker-model-row">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="selectedModels.has(m.id)"
|
||||||
|
@change="toggleModel(m.id, ($event.target as HTMLInputElement).checked)"
|
||||||
|
/>
|
||||||
|
<span class="picker-model-name" :title="m.name">{{ m.name }}</span>
|
||||||
|
<span class="picker-model-tags">
|
||||||
|
<span v-for="tag in m.tags.slice(0, 3)" :key="tag" class="tag">{{ tag }}</span>
|
||||||
|
</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<!-- Temperature -->
|
||||||
|
<div class="temp-row">
|
||||||
|
<label for="temp-slider" class="temp-label">Temperature: <strong>{{ temperature.toFixed(1) }}</strong></label>
|
||||||
|
<input
|
||||||
|
id="temp-slider"
|
||||||
|
type="range" min="0" max="1" step="0.1"
|
||||||
|
:value="temperature"
|
||||||
|
@input="temperature = parseFloat(($event.target as HTMLInputElement).value)"
|
||||||
|
class="temp-slider"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Run controls -->
|
||||||
|
<div class="run-row">
|
||||||
|
<button
|
||||||
|
class="btn-run"
|
||||||
|
:disabled="running || selectedModels.size === 0"
|
||||||
|
@click="startRun"
|
||||||
|
>
|
||||||
|
{{ running ? '⏳ Running…' : '▶ Run' }}
|
||||||
|
</button>
|
||||||
|
<button v-if="running" class="btn-cancel" @click="cancelRun">✕ Cancel</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Progress log -->
|
||||||
|
<div v-if="runLog.length > 0" class="run-log" aria-live="polite">
|
||||||
|
<div v-for="(line, i) in runLog" :key="i" class="log-line">{{ line }}</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<!-- ── Step 4: Results ────────────────────────────────────────── -->
|
||||||
|
<section v-if="results.length > 0" class="step-section">
|
||||||
|
<h2 class="step-heading">4. Results</h2>
|
||||||
|
|
||||||
|
<div class="results-grid">
|
||||||
|
<div
|
||||||
|
v-for="r in results"
|
||||||
|
:key="r.model"
|
||||||
|
class="result-card"
|
||||||
|
:class="{ 'result-error': !!r.error }"
|
||||||
|
>
|
||||||
|
<div class="result-header">
|
||||||
|
<span class="result-model">{{ r.model }}</span>
|
||||||
|
<span class="result-meta">
|
||||||
|
<template v-if="r.error">
|
||||||
|
<span class="result-err-badge">error</span>
|
||||||
|
</template>
|
||||||
|
<template v-else>
|
||||||
|
{{ (r.elapsed_ms / 1000).toFixed(1) }}s
|
||||||
|
</template>
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<pre v-if="r.error" class="result-error-text">{{ r.error }}</pre>
|
||||||
|
<pre v-else class="result-response">{{ r.response }}</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="corrections-row">
|
||||||
|
<button
|
||||||
|
class="btn-corrections"
|
||||||
|
:disabled="pushingCorrections || !selectedProduct || successfulResults.length === 0"
|
||||||
|
@click="pushCorrections"
|
||||||
|
>
|
||||||
|
{{ pushingCorrections ? '⏳ Pushing…' : `✍ Send ${successfulResults.length} to Corrections` }}
|
||||||
|
</button>
|
||||||
|
<span v-if="correctionsPushMsg" class="corrections-msg" :class="correctionsPushOk ? 'msg-ok' : 'msg-err'">
|
||||||
|
{{ correctionsPushMsg }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed, onMounted } from 'vue'
|
||||||
|
|
||||||
|
// ── Types ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
interface Product {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
icon: string
|
||||||
|
description: string
|
||||||
|
base_url: string
|
||||||
|
online: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Sample {
|
||||||
|
product_id: string
|
||||||
|
sample_index: number
|
||||||
|
text: string
|
||||||
|
prompt: string
|
||||||
|
raw_item: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ModelEntry {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
service: string
|
||||||
|
tags: string[]
|
||||||
|
vram_estimate_mb: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface RunResult {
|
||||||
|
model: string
|
||||||
|
response: string
|
||||||
|
elapsed_ms: number
|
||||||
|
error: string | null
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── State ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const productsLoading = ref(false)
|
||||||
|
const products = ref<Product[]>([])
|
||||||
|
const selectedProduct = ref<Product | null>(null)
|
||||||
|
|
||||||
|
const sampleLoading = ref(false)
|
||||||
|
const sampleError = ref<string | null>(null)
|
||||||
|
const rawSample = ref<Sample | null>(null)
|
||||||
|
const editedPrompt = ref('')
|
||||||
|
|
||||||
|
const modelsLoading = ref(false)
|
||||||
|
const allModels = ref<ModelEntry[]>([])
|
||||||
|
const selectedModels = ref<Set<string>>(new Set())
|
||||||
|
|
||||||
|
const temperature = ref(0.7)
|
||||||
|
|
||||||
|
const running = ref(false)
|
||||||
|
const eventSource = ref<EventSource | null>(null)
|
||||||
|
const runLog = ref<string[]>([])
|
||||||
|
const results = ref<RunResult[]>([])
|
||||||
|
|
||||||
|
const pushingCorrections = ref(false)
|
||||||
|
const correctionsPushMsg = ref<string | null>(null)
|
||||||
|
const correctionsPushOk = ref(false)
|
||||||
|
|
||||||
|
// ── Computed ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const ollamaModels = computed(() =>
|
||||||
|
allModels.value.filter(m => m.service === 'ollama')
|
||||||
|
)
|
||||||
|
|
||||||
|
const successfulResults = computed(() =>
|
||||||
|
results.value.filter(r => !r.error && r.response.trim())
|
||||||
|
)
|
||||||
|
|
||||||
|
// ── Lifecycle ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
onMounted(async () => {
|
||||||
|
await Promise.all([loadProducts(), loadModels()])
|
||||||
|
})
|
||||||
|
|
||||||
|
// ── Methods ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async function loadProducts() {
|
||||||
|
productsLoading.value = true
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/api/imitate/products')
|
||||||
|
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||||
|
const data = await resp.json()
|
||||||
|
products.value = data.products ?? []
|
||||||
|
} catch {
|
||||||
|
products.value = []
|
||||||
|
} finally {
|
||||||
|
productsLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadModels() {
|
||||||
|
modelsLoading.value = true
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/api/cforch/models')
|
||||||
|
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||||
|
const data = await resp.json()
|
||||||
|
allModels.value = data.models ?? []
|
||||||
|
// Select all ollama models by default
|
||||||
|
for (const m of allModels.value) {
|
||||||
|
if (m.service === 'ollama') selectedModels.value.add(m.id)
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
allModels.value = []
|
||||||
|
} finally {
|
||||||
|
modelsLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function selectProduct(p: Product) {
|
||||||
|
selectedProduct.value = p
|
||||||
|
rawSample.value = null
|
||||||
|
editedPrompt.value = ''
|
||||||
|
sampleError.value = null
|
||||||
|
results.value = []
|
||||||
|
runLog.value = []
|
||||||
|
await fetchSample()
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchSample() {
|
||||||
|
if (!selectedProduct.value) return
|
||||||
|
sampleLoading.value = true
|
||||||
|
sampleError.value = null
|
||||||
|
try {
|
||||||
|
const resp = await fetch(`/api/imitate/products/${selectedProduct.value.id}/sample`)
|
||||||
|
if (!resp.ok) {
|
||||||
|
const body = await resp.json().catch(() => ({ detail: 'Unknown error' }))
|
||||||
|
throw new Error(body.detail ?? `HTTP ${resp.status}`)
|
||||||
|
}
|
||||||
|
const data: Sample = await resp.json()
|
||||||
|
rawSample.value = data
|
||||||
|
editedPrompt.value = data.prompt
|
||||||
|
} catch (err: unknown) {
|
||||||
|
sampleError.value = err instanceof Error ? err.message : String(err)
|
||||||
|
} finally {
|
||||||
|
sampleLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleModel(id: string, checked: boolean) {
|
||||||
|
const next = new Set(selectedModels.value)
|
||||||
|
checked ? next.add(id) : next.delete(id)
|
||||||
|
selectedModels.value = next
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleAllModels(checked: boolean) {
|
||||||
|
selectedModels.value = checked
|
||||||
|
? new Set(ollamaModels.value.map(m => m.id))
|
||||||
|
: new Set()
|
||||||
|
}
|
||||||
|
|
||||||
|
function startRun() {
|
||||||
|
if (running.value || !editedPrompt.value.trim() || selectedModels.value.size === 0) return
|
||||||
|
|
||||||
|
running.value = true
|
||||||
|
results.value = []
|
||||||
|
runLog.value = []
|
||||||
|
correctionsPushMsg.value = null
|
||||||
|
|
||||||
|
const params = new URLSearchParams({
|
||||||
|
prompt: editedPrompt.value,
|
||||||
|
model_ids: [...selectedModels.value].join(','),
|
||||||
|
temperature: temperature.value.toString(),
|
||||||
|
product_id: selectedProduct.value?.id ?? '',
|
||||||
|
})
|
||||||
|
|
||||||
|
const es = new EventSource(`/api/imitate/run?${params}`)
|
||||||
|
eventSource.value = es
|
||||||
|
|
||||||
|
es.onmessage = (event: MessageEvent) => {
|
||||||
|
try {
|
||||||
|
const msg = JSON.parse(event.data)
|
||||||
|
if (msg.type === 'start') {
|
||||||
|
runLog.value.push(`Running ${msg.total_models} model(s)…`)
|
||||||
|
} else if (msg.type === 'model_start') {
|
||||||
|
runLog.value.push(`→ ${msg.model}…`)
|
||||||
|
} else if (msg.type === 'model_done') {
|
||||||
|
const status = msg.error
|
||||||
|
? `✕ error: ${msg.error}`
|
||||||
|
: `✓ done (${(msg.elapsed_ms / 1000).toFixed(1)}s)`
|
||||||
|
runLog.value.push(` ${msg.model}: ${status}`)
|
||||||
|
results.value.push({
|
||||||
|
model: msg.model,
|
||||||
|
response: msg.response,
|
||||||
|
elapsed_ms: msg.elapsed_ms,
|
||||||
|
error: msg.error ?? null,
|
||||||
|
})
|
||||||
|
} else if (msg.type === 'complete') {
|
||||||
|
runLog.value.push(`Complete. ${results.value.length} responses.`)
|
||||||
|
running.value = false
|
||||||
|
es.close()
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// ignore malformed SSE frames
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
es.onerror = () => {
|
||||||
|
runLog.value.push('Connection error — run may be incomplete.')
|
||||||
|
running.value = false
|
||||||
|
es.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function cancelRun() {
|
||||||
|
eventSource.value?.close()
|
||||||
|
eventSource.value = null
|
||||||
|
running.value = false
|
||||||
|
runLog.value.push('Cancelled.')
|
||||||
|
}
|
||||||
|
|
||||||
|
async function pushCorrections() {
|
||||||
|
if (!selectedProduct.value || successfulResults.value.length === 0) return
|
||||||
|
|
||||||
|
pushingCorrections.value = true
|
||||||
|
correctionsPushMsg.value = null
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/api/imitate/push-corrections', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
product_id: selectedProduct.value.id,
|
||||||
|
prompt: editedPrompt.value,
|
||||||
|
results: successfulResults.value,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if (!resp.ok) {
|
||||||
|
const body = await resp.json().catch(() => ({ detail: 'Unknown error' }))
|
||||||
|
throw new Error(body.detail ?? `HTTP ${resp.status}`)
|
||||||
|
}
|
||||||
|
const data = await resp.json()
|
||||||
|
correctionsPushMsg.value = `${data.pushed} record(s) added to Corrections queue.`
|
||||||
|
correctionsPushOk.value = true
|
||||||
|
} catch (err: unknown) {
|
||||||
|
correctionsPushMsg.value = err instanceof Error ? err.message : String(err)
|
||||||
|
correctionsPushOk.value = false
|
||||||
|
} finally {
|
||||||
|
pushingCorrections.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.imitate-view {
|
||||||
|
max-width: 1100px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 1.5rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bench-header {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.page-title {
|
||||||
|
font-size: 1.6rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
.page-subtitle {
|
||||||
|
font-size: 0.9rem;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Steps */
|
||||||
|
.step-section {
|
||||||
|
background: var(--color-surface-raised, #e4ebf5);
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.5rem;
|
||||||
|
padding: 1.25rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.step-heading {
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
padding-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Product grid */
|
||||||
|
.product-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fill, minmax(160px, 1fr));
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-card {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.35rem;
|
||||||
|
padding: 1rem 0.75rem;
|
||||||
|
border: 2px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.5rem;
|
||||||
|
background: var(--color-surface, #f0f4fc);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: border-color 0.15s, background 0.15s;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-card:hover:not(:disabled) {
|
||||||
|
border-color: var(--app-primary, #2A6080);
|
||||||
|
background: color-mix(in srgb, var(--app-primary, #2A6080) 6%, var(--color-surface, #f0f4fc));
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-card.selected {
|
||||||
|
border-color: var(--app-primary, #2A6080);
|
||||||
|
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, var(--color-surface, #f0f4fc));
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-card.offline {
|
||||||
|
opacity: 0.45;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-icon {
|
||||||
|
font-size: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-name {
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
.product-status {
|
||||||
|
font-size: 0.72rem;
|
||||||
|
padding: 0.1rem 0.45rem;
|
||||||
|
border-radius: 9999px;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-on {
|
||||||
|
background: #d1fae5;
|
||||||
|
color: #065f46;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-off {
|
||||||
|
background: #fee2e2;
|
||||||
|
color: #991b1b;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Sample panel */
|
||||||
|
.sample-toolbar {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sample-product-label {
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--app-primary, #2A6080);
|
||||||
|
}
|
||||||
|
|
||||||
|
.sample-error {
|
||||||
|
color: #b91c1c;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sample-preview {
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sample-preview-toggle {
|
||||||
|
padding: 0.5rem 0.75rem;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
background: var(--color-surface, #f0f4fc);
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sample-text {
|
||||||
|
padding: 0.75rem;
|
||||||
|
font-size: 0.82rem;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-break: break-word;
|
||||||
|
max-height: 180px;
|
||||||
|
overflow-y: auto;
|
||||||
|
background: var(--color-bg, #f0f4fc);
|
||||||
|
margin: 0;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
.prompt-label {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
}
|
||||||
|
|
||||||
|
.prompt-editor {
|
||||||
|
width: 100%;
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
background: var(--color-surface, #f0f4fc);
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
resize: vertical;
|
||||||
|
line-height: 1.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.prompt-editor:focus {
|
||||||
|
outline: 2px solid var(--app-primary, #2A6080);
|
||||||
|
outline-offset: -1px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Model picker — reuse bench-view classes */
|
||||||
|
.model-picker {
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.5rem;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-summary {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
background: var(--color-surface, #f0f4fc);
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.95rem;
|
||||||
|
font-weight: 600;
|
||||||
|
user-select: none;
|
||||||
|
list-style: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-title { flex: 1; }
|
||||||
|
|
||||||
|
.picker-badge {
|
||||||
|
font-size: 0.8rem;
|
||||||
|
background: var(--app-primary, #2A6080);
|
||||||
|
color: #fff;
|
||||||
|
border-radius: 9999px;
|
||||||
|
padding: 0.15rem 0.6rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-body {
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-loading, .picker-empty {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
padding: 0.5rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-cat-header {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
padding: 0.35rem 0;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-model-list {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 0.25rem;
|
||||||
|
padding-left: 1.25rem;
|
||||||
|
padding-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-model-row {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.4rem;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
cursor: pointer;
|
||||||
|
padding: 0.2rem 0.5rem;
|
||||||
|
border-radius: 0.25rem;
|
||||||
|
min-width: 220px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-model-row:hover {
|
||||||
|
background: color-mix(in srgb, var(--app-primary, #2A6080) 8%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-model-name {
|
||||||
|
flex: 1;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.picker-model-tags {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.2rem;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag {
|
||||||
|
font-size: 0.68rem;
|
||||||
|
background: var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 9999px;
|
||||||
|
padding: 0.05rem 0.4rem;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Temperature */
|
||||||
|
.temp-row {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.temp-label {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
white-space: nowrap;
|
||||||
|
min-width: 160px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.temp-slider {
|
||||||
|
flex: 1;
|
||||||
|
accent-color: var(--app-primary, #2A6080);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Run controls */
|
||||||
|
.run-row {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-run {
|
||||||
|
background: var(--app-primary, #2A6080);
|
||||||
|
color: #fff;
|
||||||
|
border: none;
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
padding: 0.55rem 1.25rem;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: opacity 0.15s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-run:disabled {
|
||||||
|
opacity: 0.4;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-cancel {
|
||||||
|
background: transparent;
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
padding: 0.5rem 0.9rem;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
cursor: pointer;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-refresh {
|
||||||
|
background: transparent;
|
||||||
|
border: 1px solid var(--app-primary, #2A6080);
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
padding: 0.35rem 0.8rem;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
color: var(--app-primary, #2A6080);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background 0.15s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-refresh:hover:not(:disabled) {
|
||||||
|
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-refresh:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||||
|
|
||||||
|
/* Run log */
|
||||||
|
.run-log {
|
||||||
|
background: var(--color-bg, #f0f4fc);
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
padding: 0.75rem;
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-size: 0.8rem;
|
||||||
|
max-height: 140px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-line {
|
||||||
|
padding: 0.05rem 0;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Results */
|
||||||
|
.results-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-card {
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.5rem;
|
||||||
|
overflow: hidden;
|
||||||
|
background: var(--color-surface, #f0f4fc);
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-card.result-error {
|
||||||
|
border-color: #fca5a5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-header {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
padding: 0.5rem 0.75rem;
|
||||||
|
background: var(--color-surface-raised, #e4ebf5);
|
||||||
|
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-model {
|
||||||
|
font-size: 0.82rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-meta {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
flex-shrink: 0;
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-err-badge {
|
||||||
|
background: #fee2e2;
|
||||||
|
color: #991b1b;
|
||||||
|
border-radius: 9999px;
|
||||||
|
padding: 0.1rem 0.45rem;
|
||||||
|
font-size: 0.7rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-response, .result-error-text {
|
||||||
|
padding: 0.75rem;
|
||||||
|
font-size: 0.82rem;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-break: break-word;
|
||||||
|
max-height: 280px;
|
||||||
|
overflow-y: auto;
|
||||||
|
margin: 0;
|
||||||
|
flex: 1;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-error-text {
|
||||||
|
color: #b91c1c;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Corrections */
|
||||||
|
.corrections-row {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-corrections {
|
||||||
|
background: var(--color-accent-warm, #b45309);
|
||||||
|
color: #fff;
|
||||||
|
border: none;
|
||||||
|
border-radius: 0.375rem;
|
||||||
|
padding: 0.55rem 1.25rem;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: opacity 0.15s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-corrections:disabled {
|
||||||
|
opacity: 0.4;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.corrections-msg {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.msg-ok { color: #065f46; }
|
||||||
|
.msg-err { color: #b91c1c; }
|
||||||
|
</style>
|
||||||
858
web/src/views/ModelsView.vue
Normal file
858
web/src/views/ModelsView.vue
Normal file
|
|
@ -0,0 +1,858 @@
|
||||||
|
<template>
|
||||||
|
<div class="models-view">
|
||||||
|
<h1 class="page-title">🤗 Models</h1>
|
||||||
|
|
||||||
|
<!-- ── 1. HF Lookup ───────────────────────────────── -->
|
||||||
|
<section class="section">
|
||||||
|
<h2 class="section-title">HuggingFace Lookup</h2>
|
||||||
|
|
||||||
|
<div class="lookup-row">
|
||||||
|
<input
|
||||||
|
v-model="lookupInput"
|
||||||
|
type="text"
|
||||||
|
class="lookup-input"
|
||||||
|
placeholder="org/model or huggingface.co/org/model"
|
||||||
|
:disabled="lookupLoading"
|
||||||
|
@keydown.enter="doLookup"
|
||||||
|
aria-label="HuggingFace model ID"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
class="btn-primary"
|
||||||
|
:disabled="lookupLoading || !lookupInput.trim()"
|
||||||
|
@click="doLookup"
|
||||||
|
>
|
||||||
|
{{ lookupLoading ? 'Looking up…' : 'Lookup' }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="lookupError" class="error-notice" role="alert">
|
||||||
|
{{ lookupError }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="lookupResult" class="preview-card">
|
||||||
|
<div class="preview-header">
|
||||||
|
<span class="preview-repo-id">{{ lookupResult.repo_id }}</span>
|
||||||
|
<div class="badge-group">
|
||||||
|
<span v-if="lookupResult.already_installed" class="badge badge-success">Installed</span>
|
||||||
|
<span v-if="lookupResult.already_queued" class="badge badge-info">In queue</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="preview-meta">
|
||||||
|
<span v-if="lookupResult.pipeline_tag" class="chip chip-pipeline">
|
||||||
|
{{ lookupResult.pipeline_tag }}
|
||||||
|
</span>
|
||||||
|
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
|
||||||
|
{{ lookupResult.adapter_recommendation }}
|
||||||
|
</span>
|
||||||
|
<span v-if="lookupResult.size != null" class="preview-size">
|
||||||
|
{{ humanBytes(lookupResult.size) }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p v-if="lookupResult.description" class="preview-desc">
|
||||||
|
{{ lookupResult.description }}
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div v-if="lookupResult.warning" class="compat-warning" role="alert">
|
||||||
|
<span class="compat-warning-icon">⚠️</span>
|
||||||
|
<span>{{ lookupResult.warning }}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button
|
||||||
|
class="btn-primary btn-add-queue"
|
||||||
|
:class="{ 'btn-add-queue-warn': !lookupResult.compatible }"
|
||||||
|
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue"
|
||||||
|
@click="addToQueue"
|
||||||
|
>
|
||||||
|
{{ addingToQueue ? 'Adding…' : lookupResult.compatible ? 'Add to queue' : 'Add anyway' }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<!-- ── 2. Approval Queue ──────────────────────────── -->
|
||||||
|
<section class="section">
|
||||||
|
<h2 class="section-title">Approval Queue</h2>
|
||||||
|
|
||||||
|
<div v-if="pendingModels.length === 0" class="empty-notice">
|
||||||
|
No models waiting for approval.
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-for="model in pendingModels" :key="model.id" class="model-card">
|
||||||
|
<div class="model-card-header">
|
||||||
|
<span class="model-repo-id">{{ model.repo_id }}</span>
|
||||||
|
<button
|
||||||
|
class="btn-dismiss"
|
||||||
|
:aria-label="`Dismiss ${model.repo_id}`"
|
||||||
|
@click="dismissModel(model.id)"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div class="model-meta">
|
||||||
|
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
|
||||||
|
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
|
||||||
|
</div>
|
||||||
|
<div class="model-card-actions">
|
||||||
|
<button class="btn-primary btn-sm" @click="approveModel(model.id)">
|
||||||
|
Approve download
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<!-- ── 3. Active Downloads ────────────────────────── -->
|
||||||
|
<section class="section">
|
||||||
|
<h2 class="section-title">Active Downloads</h2>
|
||||||
|
|
||||||
|
<div v-if="downloadingModels.length === 0" class="empty-notice">
|
||||||
|
No active downloads.
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-for="model in downloadingModels" :key="model.id" class="model-card">
|
||||||
|
<div class="model-card-header">
|
||||||
|
<span class="model-repo-id">{{ model.repo_id }}</span>
|
||||||
|
<span v-if="downloadErrors[model.id]" class="badge badge-error">Error</span>
|
||||||
|
</div>
|
||||||
|
<div class="model-meta">
|
||||||
|
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="downloadErrors[model.id]" class="download-error" role="alert">
|
||||||
|
{{ downloadErrors[model.id] }}
|
||||||
|
</div>
|
||||||
|
<div v-else class="progress-wrap" :aria-label="`Download progress for ${model.repo_id}`">
|
||||||
|
<div
|
||||||
|
class="progress-bar"
|
||||||
|
:style="{ width: `${downloadProgress[model.id] ?? 0}%` }"
|
||||||
|
role="progressbar"
|
||||||
|
:aria-valuenow="downloadProgress[model.id] ?? 0"
|
||||||
|
aria-valuemin="0"
|
||||||
|
aria-valuemax="100"
|
||||||
|
/>
|
||||||
|
<span class="progress-label">
|
||||||
|
{{ downloadProgress[model.id] == null ? 'Preparing…' : `${downloadProgress[model.id]}%` }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<!-- ── 4. Installed Models ────────────────────────── -->
|
||||||
|
<section class="section">
|
||||||
|
<h2 class="section-title">Installed Models</h2>
|
||||||
|
|
||||||
|
<div v-if="installedModels.length === 0" class="empty-notice">
|
||||||
|
No models installed yet.
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else class="installed-table-wrap">
|
||||||
|
<table class="installed-table">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>Name</th>
|
||||||
|
<th>Type</th>
|
||||||
|
<th>Adapter</th>
|
||||||
|
<th>Size</th>
|
||||||
|
<th></th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr v-for="model in installedModels" :key="model.name">
|
||||||
|
<td class="td-name">{{ model.name }}</td>
|
||||||
|
<td>
|
||||||
|
<span
|
||||||
|
class="badge"
|
||||||
|
:class="model.type === 'finetuned' ? 'badge-accent' : 'badge-info'"
|
||||||
|
>
|
||||||
|
{{ model.type }}
|
||||||
|
</span>
|
||||||
|
</td>
|
||||||
|
<td>{{ model.adapter ?? '—' }}</td>
|
||||||
|
<td>{{ humanBytes(model.size) }}</td>
|
||||||
|
<td>
|
||||||
|
<button
|
||||||
|
class="btn-danger btn-sm"
|
||||||
|
@click="deleteInstalled(model.name)"
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</button>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed, onMounted, onUnmounted } from 'vue'
|
||||||
|
|
||||||
|
// ── Type definitions ──────────────────────────────────
|
||||||
|
|
||||||
|
interface LookupResult {
|
||||||
|
repo_id: string
|
||||||
|
pipeline_tag: string | null
|
||||||
|
adapter_recommendation: string | null
|
||||||
|
compatible: boolean
|
||||||
|
warning: string | null
|
||||||
|
size: number | null
|
||||||
|
description: string | null
|
||||||
|
already_installed: boolean
|
||||||
|
already_queued: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
interface QueuedModel {
|
||||||
|
id: string
|
||||||
|
repo_id: string
|
||||||
|
status: 'pending' | 'downloading' | 'done' | 'error'
|
||||||
|
pipeline_tag: string | null
|
||||||
|
adapter_recommendation: string | null
|
||||||
|
}
|
||||||
|
|
||||||
|
interface InstalledModel {
|
||||||
|
name: string
|
||||||
|
type: 'finetuned' | 'downloaded'
|
||||||
|
adapter: string | null
|
||||||
|
size: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SseProgressEvent {
|
||||||
|
model_id: string
|
||||||
|
pct: number | null
|
||||||
|
status: 'progress' | 'done' | 'error'
|
||||||
|
message?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── State ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
const lookupInput = ref('')
|
||||||
|
const lookupLoading = ref(false)
|
||||||
|
const lookupError = ref<string | null>(null)
|
||||||
|
const lookupResult = ref<LookupResult | null>(null)
|
||||||
|
const addingToQueue = ref(false)
|
||||||
|
|
||||||
|
const queuedModels = ref<QueuedModel[]>([])
|
||||||
|
const installedModels = ref<InstalledModel[]>([])
|
||||||
|
|
||||||
|
const downloadProgress = ref<Record<string, number>>({})
|
||||||
|
const downloadErrors = ref<Record<string, string>>({})
|
||||||
|
|
||||||
|
let pollInterval: ReturnType<typeof setInterval> | null = null
|
||||||
|
let sseSource: EventSource | null = null
|
||||||
|
|
||||||
|
// ── Derived ───────────────────────────────────────────
|
||||||
|
|
||||||
|
const pendingModels = computed(() =>
|
||||||
|
queuedModels.value.filter(m => m.status === 'pending')
|
||||||
|
)
|
||||||
|
|
||||||
|
const downloadingModels = computed(() =>
|
||||||
|
queuedModels.value.filter(m => m.status === 'downloading')
|
||||||
|
)
|
||||||
|
|
||||||
|
// ── Helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
|
function humanBytes(bytes: number | null): string {
|
||||||
|
if (bytes == null) return '—'
|
||||||
|
const units = ['B', 'KB', 'MB', 'GB', 'TB']
|
||||||
|
let value = bytes
|
||||||
|
let unitIndex = 0
|
||||||
|
while (value >= 1024 && unitIndex < units.length - 1) {
|
||||||
|
value /= 1024
|
||||||
|
unitIndex++
|
||||||
|
}
|
||||||
|
return `${value.toFixed(unitIndex === 0 ? 0 : 1)} ${units[unitIndex]}`
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeRepoId(raw: string): string {
|
||||||
|
return raw.trim().replace(/^https?:\/\/huggingface\.co\//, '')
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── API calls ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async function doLookup() {
|
||||||
|
const repoId = normalizeRepoId(lookupInput.value)
|
||||||
|
if (!repoId) return
|
||||||
|
|
||||||
|
lookupLoading.value = true
|
||||||
|
lookupError.value = null
|
||||||
|
lookupResult.value = null
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
|
||||||
|
if (res.status === 404) {
|
||||||
|
lookupError.value = 'Model not found on HuggingFace.'
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (res.status === 502) {
|
||||||
|
lookupError.value = 'HuggingFace unreachable. Check your connection and try again.'
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (!res.ok) {
|
||||||
|
lookupError.value = `Lookup failed (HTTP ${res.status}).`
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lookupResult.value = await res.json() as LookupResult
|
||||||
|
} catch {
|
||||||
|
lookupError.value = 'Network error. Is the Avocet API running?'
|
||||||
|
} finally {
|
||||||
|
lookupLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function addToQueue() {
|
||||||
|
if (!lookupResult.value) return
|
||||||
|
addingToQueue.value = true
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/models/queue', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ repo_id: lookupResult.value.repo_id }),
|
||||||
|
})
|
||||||
|
if (res.ok) {
|
||||||
|
lookupResult.value = { ...lookupResult.value, already_queued: true }
|
||||||
|
await loadQueue()
|
||||||
|
}
|
||||||
|
} catch { /* ignore — already_queued badge won't flip, user can retry */ }
|
||||||
|
finally {
|
||||||
|
addingToQueue.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function approveModel(id: string) {
|
||||||
|
try {
|
||||||
|
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
|
||||||
|
if (res.ok) {
|
||||||
|
await loadQueue()
|
||||||
|
startSse()
|
||||||
|
}
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
async function dismissModel(id: string) {
|
||||||
|
try {
|
||||||
|
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}`, { method: 'DELETE' })
|
||||||
|
if (res.ok) {
|
||||||
|
queuedModels.value = queuedModels.value.filter(m => m.id !== id)
|
||||||
|
}
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
async function deleteInstalled(name: string) {
|
||||||
|
if (!window.confirm(`Delete installed model "${name}"? This cannot be undone.`)) return
|
||||||
|
try {
|
||||||
|
const res = await fetch(`/api/models/installed/${encodeURIComponent(name)}`, { method: 'DELETE' })
|
||||||
|
if (res.ok) {
|
||||||
|
installedModels.value = installedModels.value.filter(m => m.name !== name)
|
||||||
|
}
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadQueue() {
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/models/queue')
|
||||||
|
if (res.ok) queuedModels.value = await res.json() as QueuedModel[]
|
||||||
|
} catch { /* non-fatal */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadInstalled() {
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/models/installed')
|
||||||
|
if (res.ok) installedModels.value = await res.json() as InstalledModel[]
|
||||||
|
} catch { /* non-fatal */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── SSE for download progress ─────────────────────────
|
||||||
|
|
||||||
|
function startSse() {
|
||||||
|
if (sseSource) return // already connected
|
||||||
|
|
||||||
|
sseSource = new EventSource('/api/models/download/stream')
|
||||||
|
|
||||||
|
sseSource.addEventListener('message', (e: MessageEvent) => {
|
||||||
|
let event: SseProgressEvent
|
||||||
|
try {
|
||||||
|
event = JSON.parse(e.data as string) as SseProgressEvent
|
||||||
|
} catch {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const { model_id, pct, status, message } = event
|
||||||
|
|
||||||
|
if (status === 'progress' && pct != null) {
|
||||||
|
downloadProgress.value = { ...downloadProgress.value, [model_id]: pct }
|
||||||
|
} else if (status === 'done') {
|
||||||
|
const updated = { ...downloadProgress.value }
|
||||||
|
delete updated[model_id]
|
||||||
|
downloadProgress.value = updated
|
||||||
|
|
||||||
|
queuedModels.value = queuedModels.value.filter(m => m.id !== model_id)
|
||||||
|
loadInstalled()
|
||||||
|
} else if (status === 'error') {
|
||||||
|
downloadErrors.value = {
|
||||||
|
...downloadErrors.value,
|
||||||
|
[model_id]: message ?? 'Download failed.',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
sseSource.onerror = () => {
|
||||||
|
sseSource?.close()
|
||||||
|
sseSource = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopSse() {
|
||||||
|
sseSource?.close()
|
||||||
|
sseSource = null
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Polling ───────────────────────────────────────────
|
||||||
|
|
||||||
|
function startPollingIfDownloading() {
|
||||||
|
if (pollInterval) return
|
||||||
|
pollInterval = setInterval(async () => {
|
||||||
|
await loadQueue()
|
||||||
|
if (downloadingModels.value.length === 0) {
|
||||||
|
stopPolling()
|
||||||
|
}
|
||||||
|
}, 5000)
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopPolling() {
|
||||||
|
if (pollInterval) {
|
||||||
|
clearInterval(pollInterval)
|
||||||
|
pollInterval = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Lifecycle ─────────────────────────────────────────
|
||||||
|
|
||||||
|
onMounted(async () => {
|
||||||
|
await Promise.all([loadQueue(), loadInstalled()])
|
||||||
|
|
||||||
|
if (downloadingModels.value.length > 0) {
|
||||||
|
startSse()
|
||||||
|
startPollingIfDownloading()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
stopPolling()
|
||||||
|
stopSse()
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.models-view {
|
||||||
|
max-width: 760px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 1.5rem 1rem 4rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.page-title {
|
||||||
|
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||||
|
font-size: 1.4rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--color-primary, #2d5a27);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Sections ── */
|
||||||
|
.section {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.section-title {
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
padding-bottom: 0.4rem;
|
||||||
|
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Lookup row ── */
|
||||||
|
.lookup-row {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lookup-input {
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
padding: 0.45rem 0.7rem;
|
||||||
|
border: 1px solid var(--color-border, #a8b8d0);
|
||||||
|
border-radius: var(--radius-md, 0.5rem);
|
||||||
|
background: var(--color-surface-raised, #f5f7fc);
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
font-size: 0.9rem;
|
||||||
|
font-family: var(--font-body, sans-serif);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lookup-input:disabled {
|
||||||
|
opacity: 0.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lookup-input::placeholder {
|
||||||
|
color: var(--color-text-muted, #4a5c7a);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Notices ── */
|
||||||
|
.error-notice {
|
||||||
|
padding: 0.6rem 0.8rem;
|
||||||
|
background: color-mix(in srgb, var(--color-error, #c0392b) 12%, transparent);
|
||||||
|
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||||
|
border-radius: var(--radius-md, 0.5rem);
|
||||||
|
color: var(--color-error, #c0392b);
|
||||||
|
font-size: 0.88rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.empty-notice {
|
||||||
|
color: var(--color-text-muted, #4a5c7a);
|
||||||
|
font-size: 0.9rem;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border: 1px dashed var(--color-border, #a8b8d0);
|
||||||
|
border-radius: var(--radius-md, 0.5rem);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Preview card ── */
|
||||||
|
.preview-card {
|
||||||
|
border: 1px solid var(--color-border, #a8b8d0);
|
||||||
|
border-radius: var(--radius-lg, 1rem);
|
||||||
|
background: var(--color-surface-raised, #f5f7fc);
|
||||||
|
padding: 1rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.6rem;
|
||||||
|
box-shadow: var(--shadow-sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-header {
|
||||||
|
display: flex;
|
||||||
|
align-items: flex-start;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 0.5rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-repo-id {
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-size: 0.95rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-meta {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.4rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-size {
|
||||||
|
font-size: 0.8rem;
|
||||||
|
color: var(--color-text-muted, #4a5c7a);
|
||||||
|
margin-left: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-desc {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: var(--color-text-muted, #4a5c7a);
|
||||||
|
line-height: 1.5;
|
||||||
|
margin: 0;
|
||||||
|
display: -webkit-box;
|
||||||
|
-webkit-line-clamp: 3;
|
||||||
|
-webkit-box-orient: vertical;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.compat-warning {
|
||||||
|
display: flex;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 0.5rem;
|
||||||
|
padding: 0.6rem 0.75rem;
|
||||||
|
border-radius: var(--radius-sm, 0.25rem);
|
||||||
|
background: color-mix(in srgb, var(--color-warning, #f59e0b) 12%, transparent);
|
||||||
|
border: 1px solid color-mix(in srgb, var(--color-warning, #f59e0b) 40%, transparent);
|
||||||
|
font-size: 0.82rem;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
line-height: 1.45;
|
||||||
|
}
|
||||||
|
|
||||||
|
.compat-warning-icon {
|
||||||
|
flex-shrink: 0;
|
||||||
|
line-height: 1.45;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-add-queue {
|
||||||
|
align-self: flex-start;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-add-queue-warn {
|
||||||
|
background: var(--color-surface-raised, #e4ebf5);
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Model cards (queue + downloads) ── */
|
||||||
|
.model-card {
|
||||||
|
border: 1px solid var(--color-border, #a8b8d0);
|
||||||
|
border-radius: var(--radius-md, 0.5rem);
|
||||||
|
background: var(--color-surface-raised, #f5f7fc);
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
box-shadow: var(--shadow-sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
.model-card-header {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model-repo-id {
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-size: 0.9rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model-meta {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.4rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model-card-actions {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
padding-top: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Progress bar ── */
|
||||||
|
.progress-wrap {
|
||||||
|
position: relative;
|
||||||
|
height: 1.5rem;
|
||||||
|
background: var(--color-surface-alt, #dde4f0);
|
||||||
|
border-radius: var(--radius-full, 9999px);
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-bar {
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
height: 100%;
|
||||||
|
background: var(--color-accent, #c4732a);
|
||||||
|
border-radius: var(--radius-full, 9999px);
|
||||||
|
transition: width 300ms ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-label {
|
||||||
|
position: absolute;
|
||||||
|
inset: 0;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.download-error {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
color: var(--color-error, #c0392b);
|
||||||
|
padding: 0.4rem 0.5rem;
|
||||||
|
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||||
|
border-radius: var(--radius-sm, 0.25rem);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Installed table ── */
|
||||||
|
.installed-table-wrap {
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.installed-table {
|
||||||
|
width: 100%;
|
||||||
|
border-collapse: collapse;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.installed-table th {
|
||||||
|
text-align: left;
|
||||||
|
padding: 0.4rem 0.6rem;
|
||||||
|
color: var(--color-text-muted, #4a5c7a);
|
||||||
|
font-size: 0.78rem;
|
||||||
|
font-weight: 600;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.03em;
|
||||||
|
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.installed-table td {
|
||||||
|
padding: 0.55rem 0.6rem;
|
||||||
|
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
.td-name {
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Badges ── */
|
||||||
|
.badge-group {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.35rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
padding: 0.15rem 0.55rem;
|
||||||
|
border-radius: var(--radius-full, 9999px);
|
||||||
|
font-size: 0.72rem;
|
||||||
|
font-weight: 700;
|
||||||
|
letter-spacing: 0.02em;
|
||||||
|
text-transform: uppercase;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge-success {
|
||||||
|
background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent);
|
||||||
|
color: var(--color-success, #3a7a32);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge-info {
|
||||||
|
background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent);
|
||||||
|
color: var(--color-info, #1e6091);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge-accent {
|
||||||
|
background: color-mix(in srgb, var(--color-accent, #c4732a) 15%, transparent);
|
||||||
|
color: var(--color-accent, #c4732a);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge-error {
|
||||||
|
background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent);
|
||||||
|
color: var(--color-error, #c0392b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Chips ── */
|
||||||
|
.chip {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
padding: 0.15rem 0.5rem;
|
||||||
|
border-radius: var(--radius-full, 9999px);
|
||||||
|
font-size: 0.75rem;
|
||||||
|
font-weight: 600;
|
||||||
|
background: var(--color-surface-alt, #dde4f0);
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chip-pipeline {
|
||||||
|
color: var(--color-primary, #2d5a27);
|
||||||
|
background: color-mix(in srgb, var(--color-primary, #2d5a27) 12%, var(--color-surface-alt, #dde4f0));
|
||||||
|
}
|
||||||
|
|
||||||
|
.chip-adapter {
|
||||||
|
color: var(--color-accent, #c4732a);
|
||||||
|
background: color-mix(in srgb, var(--color-accent, #c4732a) 12%, var(--color-surface-alt, #dde4f0));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Buttons ── */
|
||||||
|
.btn-primary, .btn-danger {
|
||||||
|
padding: 0.4rem 0.9rem;
|
||||||
|
border-radius: var(--radius-md, 0.5rem);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
cursor: pointer;
|
||||||
|
border: 1px solid;
|
||||||
|
font-family: var(--font-body, sans-serif);
|
||||||
|
transition: background var(--transition, 200ms ease), color var(--transition, 200ms ease);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-sm {
|
||||||
|
padding: 0.25rem 0.65rem;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary {
|
||||||
|
border-color: var(--color-primary, #2d5a27);
|
||||||
|
background: var(--color-primary, #2d5a27);
|
||||||
|
color: var(--color-text-inverse, #eaeff8);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:hover:not(:disabled) {
|
||||||
|
background: var(--color-primary-hover, #234820);
|
||||||
|
border-color: var(--color-primary-hover, #234820);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-danger {
|
||||||
|
border-color: var(--color-error, #c0392b);
|
||||||
|
background: transparent;
|
||||||
|
color: var(--color-error, #c0392b);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-danger:hover {
|
||||||
|
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-dismiss {
|
||||||
|
border: none;
|
||||||
|
background: transparent;
|
||||||
|
color: var(--color-text-muted, #4a5c7a);
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
padding: 0.15rem 0.4rem;
|
||||||
|
border-radius: var(--radius-sm, 0.25rem);
|
||||||
|
flex-shrink: 0;
|
||||||
|
transition: color var(--transition, 200ms ease), background var(--transition, 200ms ease);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-dismiss:hover {
|
||||||
|
color: var(--color-error, #c0392b);
|
||||||
|
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Responsive ── */
|
||||||
|
@media (max-width: 480px) {
|
||||||
|
.lookup-row {
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lookup-input {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:not(.btn-sm) {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.installed-table th:nth-child(3),
|
||||||
|
.installed-table td:nth-child(3) {
|
||||||
|
display: none; /* hide Adapter column on very narrow screens */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|
@ -115,8 +115,18 @@
|
||||||
<h2 class="section-title">cf-orch Integration</h2>
|
<h2 class="section-title">cf-orch Integration</h2>
|
||||||
<p class="section-desc">
|
<p class="section-desc">
|
||||||
Import SFT (supervised fine-tuning) candidates from cf-orch benchmark runs.
|
Import SFT (supervised fine-tuning) candidates from cf-orch benchmark runs.
|
||||||
|
Connection settings fall back to environment variables
|
||||||
|
(<code>CF_ORCH_URL</code>, <code>CF_LICENSE_KEY</code>, <code>OLLAMA_HOST</code>)
|
||||||
|
when not set here.
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<!-- Connection status pill -->
|
||||||
|
<div v-if="orchConfig" class="orch-status-row">
|
||||||
|
<span class="orch-status-pill" :class="orchStatusClass">{{ orchStatusLabel }}</span>
|
||||||
|
<span v-if="orchConfig.source === 'env'" class="orch-source-note">via env vars</span>
|
||||||
|
<span v-else class="orch-source-note">via label_tool.yaml</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="field-row">
|
<div class="field-row">
|
||||||
<label class="field field-grow">
|
<label class="field field-grow">
|
||||||
<span>bench_results_dir</span>
|
<span>bench_results_dir</span>
|
||||||
|
|
@ -181,7 +191,7 @@
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { useApiFetch } from '../composables/useApi'
|
import { useApiFetch } from '../composables/useApi'
|
||||||
|
|
||||||
interface Account {
|
interface Account {
|
||||||
|
|
@ -199,12 +209,27 @@ const saveOk = ref(true)
|
||||||
const richMotion = ref(localStorage.getItem('cf-avocet-rich-motion') !== 'false')
|
const richMotion = ref(localStorage.getItem('cf-avocet-rich-motion') !== 'false')
|
||||||
const keyHints = ref(localStorage.getItem('cf-avocet-key-hints') !== 'false')
|
const keyHints = ref(localStorage.getItem('cf-avocet-key-hints') !== 'false')
|
||||||
|
|
||||||
// SFT integration state
|
// SFT / cf-orch integration state
|
||||||
const benchResultsDir = ref('')
|
const benchResultsDir = ref('')
|
||||||
const runs = ref<Array<{ run_id: string; timestamp: string; candidate_count: number; already_imported: boolean }>>([])
|
const runs = ref<Array<{ run_id: string; timestamp: string; candidate_count: number; already_imported: boolean }>>([])
|
||||||
const importingRunId = ref<string | null>(null)
|
const importingRunId = ref<string | null>(null)
|
||||||
const importResult = ref<{ imported: number; skipped: number } | null>(null)
|
const importResult = ref<{ imported: number; skipped: number } | null>(null)
|
||||||
const saveStatus = ref('')
|
const saveStatus = ref('')
|
||||||
|
const orchConfig = ref<{ coordinator_url: string; ollama_url: string; ollama_model: string; license_key_set: boolean; source: string } | null>(null)
|
||||||
|
|
||||||
|
const orchStatusClass = computed(() => {
|
||||||
|
if (!orchConfig.value) return 'status-unknown'
|
||||||
|
if (orchConfig.value.coordinator_url) return 'status-connected'
|
||||||
|
if (orchConfig.value.ollama_url) return 'status-local'
|
||||||
|
return 'status-unconfigured'
|
||||||
|
})
|
||||||
|
|
||||||
|
const orchStatusLabel = computed(() => {
|
||||||
|
if (!orchConfig.value) return 'Unknown'
|
||||||
|
if (orchConfig.value.coordinator_url) return '● cf-orch coordinator'
|
||||||
|
if (orchConfig.value.ollama_url) return '● Ollama (local)'
|
||||||
|
return '○ Not configured'
|
||||||
|
})
|
||||||
|
|
||||||
async function loadSftConfig() {
|
async function loadSftConfig() {
|
||||||
try {
|
try {
|
||||||
|
|
@ -218,6 +243,15 @@ async function loadSftConfig() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function loadOrchConfig() {
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/cforch/config')
|
||||||
|
if (res.ok) orchConfig.value = await res.json()
|
||||||
|
} catch {
|
||||||
|
// non-fatal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function saveSftConfig() {
|
async function saveSftConfig() {
|
||||||
saveStatus.value = 'Saving…'
|
saveStatus.value = 'Saving…'
|
||||||
try {
|
try {
|
||||||
|
|
@ -337,6 +371,7 @@ function onKeyHintsChange() {
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
reload()
|
reload()
|
||||||
loadSftConfig()
|
loadSftConfig()
|
||||||
|
loadOrchConfig()
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
|
@ -564,6 +599,31 @@ onMounted(() => {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.orch-status-row {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: var(--space-2);
|
||||||
|
margin-bottom: var(--space-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.orch-status-pill {
|
||||||
|
font-size: 0.8rem;
|
||||||
|
font-weight: 600;
|
||||||
|
padding: var(--space-1) var(--space-3);
|
||||||
|
border-radius: var(--radius-full);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-connected { background: color-mix(in srgb, var(--color-success, #3a7a32) 12%, transparent); color: var(--color-success, #3a7a32); }
|
||||||
|
.status-local { background: color-mix(in srgb, var(--color-primary) 12%, transparent); color: var(--color-primary); }
|
||||||
|
.status-unconfigured { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||||
|
.status-unknown { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||||
|
|
||||||
|
.orch-source-note {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--color-text-muted);
|
||||||
|
font-style: italic;
|
||||||
|
}
|
||||||
|
|
||||||
.runs-table {
|
.runs-table {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
border-collapse: collapse;
|
border-collapse: collapse;
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,77 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Benchmark Results -->
|
||||||
|
<template v-if="benchRows.length > 0">
|
||||||
|
<h2 class="section-title">🏁 Benchmark Results</h2>
|
||||||
|
<div class="bench-table-wrap">
|
||||||
|
<table class="bench-table">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th class="bt-model-col">Model</th>
|
||||||
|
<th
|
||||||
|
v-for="m in BENCH_METRICS"
|
||||||
|
:key="m.key as string"
|
||||||
|
class="bt-metric-col"
|
||||||
|
>{{ m.label }}</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr v-for="row in benchRows" :key="row.name">
|
||||||
|
<td class="bt-model-cell" :title="row.name">{{ row.name }}</td>
|
||||||
|
<td
|
||||||
|
v-for="m in BENCH_METRICS"
|
||||||
|
:key="m.key as string"
|
||||||
|
class="bt-metric-cell"
|
||||||
|
:class="{ 'bt-best': bestByMetric[m.key as string] === row.name }"
|
||||||
|
>
|
||||||
|
{{ formatMetric(row.result[m.key]) }}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
<p class="bench-hint">Highlighted cells are the best-scoring model per metric.</p>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<!-- LLM Benchmark Results -->
|
||||||
|
<template v-if="llmResults.length > 0">
|
||||||
|
<h2 class="section-title">🤖 LLM Benchmark</h2>
|
||||||
|
<div class="bench-table-wrap">
|
||||||
|
<table class="bench-table">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th class="bt-model-col">Model</th>
|
||||||
|
<th class="bt-metric-col">overall</th>
|
||||||
|
<th
|
||||||
|
v-for="col in llmTaskTypeCols"
|
||||||
|
:key="col"
|
||||||
|
class="bt-metric-col"
|
||||||
|
>{{ col }}</th>
|
||||||
|
<th class="bt-metric-col">tok/s</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr v-for="row in llmResults" :key="row.model_id">
|
||||||
|
<td class="bt-model-cell" :title="row.model_id">{{ row.model_name }}</td>
|
||||||
|
<td
|
||||||
|
class="bt-metric-cell"
|
||||||
|
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
|
||||||
|
>{{ llmPct(row.avg_quality_score) }}</td>
|
||||||
|
<td
|
||||||
|
v-for="col in llmTaskTypeCols"
|
||||||
|
:key="col"
|
||||||
|
class="bt-metric-cell"
|
||||||
|
:class="{ 'bt-best': llmBestByCol[col] === row.model_id }"
|
||||||
|
>{{ row.quality_by_task_type[col] != null ? llmPct(row.quality_by_task_type[col]) : '—' }}</td>
|
||||||
|
<td class="bt-metric-cell">{{ row.avg_tokens_per_sec.toFixed(1) }}</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
<p class="bench-hint">Run LLM Eval on the Benchmark tab to refresh. Highlighted = best per column.</p>
|
||||||
|
</template>
|
||||||
|
|
||||||
<div class="file-info">
|
<div class="file-info">
|
||||||
<span class="file-path">Score file: <code>data/email_score.jsonl</code></span>
|
<span class="file-path">Score file: <code>data/email_score.jsonl</code></span>
|
||||||
<span class="file-size">{{ fileSizeLabel }}</span>
|
<span class="file-size">{{ fileSizeLabel }}</span>
|
||||||
|
|
@ -54,10 +125,30 @@
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { useApiFetch } from '../composables/useApi'
|
import { useApiFetch } from '../composables/useApi'
|
||||||
|
|
||||||
|
interface BenchmarkModelResult {
|
||||||
|
accuracy?: number
|
||||||
|
macro_f1?: number
|
||||||
|
weighted_f1?: number
|
||||||
|
[key: string]: number | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
interface LlmModelResult {
|
||||||
|
model_name: string
|
||||||
|
model_id: string
|
||||||
|
node_id: string
|
||||||
|
avg_tokens_per_sec: number
|
||||||
|
avg_completion_ms: number
|
||||||
|
avg_quality_score: number
|
||||||
|
finetune_candidates: number
|
||||||
|
error_count: number
|
||||||
|
quality_by_task_type: Record<string, number>
|
||||||
|
}
|
||||||
|
|
||||||
interface StatsResponse {
|
interface StatsResponse {
|
||||||
total: number
|
total: number
|
||||||
counts: Record<string, number>
|
counts: Record<string, number>
|
||||||
score_file_bytes: number
|
score_file_bytes: number
|
||||||
|
benchmark_results?: Record<string, BenchmarkModelResult>
|
||||||
}
|
}
|
||||||
|
|
||||||
// Canonical label order + metadata
|
// Canonical label order + metadata
|
||||||
|
|
@ -108,6 +199,85 @@ const fileSizeLabel = computed(() => {
|
||||||
return `${(b / 1024 / 1024).toFixed(2)} MB`
|
return `${(b / 1024 / 1024).toFixed(2)} MB`
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Benchmark results helpers
|
||||||
|
const BENCH_METRICS: Array<{ key: keyof BenchmarkModelResult; label: string }> = [
|
||||||
|
{ key: 'accuracy', label: 'Accuracy' },
|
||||||
|
{ key: 'macro_f1', label: 'Macro F1' },
|
||||||
|
{ key: 'weighted_f1', label: 'Weighted F1' },
|
||||||
|
]
|
||||||
|
|
||||||
|
const benchRows = computed(() => {
|
||||||
|
const br = stats.value.benchmark_results
|
||||||
|
if (!br || Object.keys(br).length === 0) return []
|
||||||
|
return Object.entries(br).map(([name, result]) => ({ name, result }))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Find the best model name for each metric
|
||||||
|
const bestByMetric = computed((): Record<string, string> => {
|
||||||
|
const result: Record<string, string> = {}
|
||||||
|
for (const { key } of BENCH_METRICS) {
|
||||||
|
let bestName = ''
|
||||||
|
let bestVal = -Infinity
|
||||||
|
for (const { name, result: r } of benchRows.value) {
|
||||||
|
const v = r[key]
|
||||||
|
if (v != null && v > bestVal) { bestVal = v; bestName = name }
|
||||||
|
}
|
||||||
|
result[key as string] = bestName
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
})
|
||||||
|
|
||||||
|
function formatMetric(v: number | undefined): string {
|
||||||
|
if (v == null) return '—'
|
||||||
|
// Values in 0-1 range: format as percentage
|
||||||
|
if (v <= 1) return `${(v * 100).toFixed(1)}%`
|
||||||
|
// Already a percentage
|
||||||
|
return `${v.toFixed(1)}%`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── LLM Benchmark results ────────────────────────────────────────────────────
|
||||||
|
const llmResults = ref<LlmModelResult[]>([])
|
||||||
|
|
||||||
|
const llmTaskTypeCols = computed(() => {
|
||||||
|
const types = new Set<string>()
|
||||||
|
for (const r of llmResults.value) {
|
||||||
|
for (const k of Object.keys(r.quality_by_task_type)) types.add(k)
|
||||||
|
}
|
||||||
|
return [...types].sort()
|
||||||
|
})
|
||||||
|
|
||||||
|
const llmBestByCol = computed((): Record<string, string> => {
|
||||||
|
const best: Record<string, string> = {}
|
||||||
|
if (llmResults.value.length === 0) return best
|
||||||
|
|
||||||
|
let bestId = '', bestVal = -Infinity
|
||||||
|
for (const r of llmResults.value) {
|
||||||
|
if (r.avg_quality_score > bestVal) { bestVal = r.avg_quality_score; bestId = r.model_id }
|
||||||
|
}
|
||||||
|
best['overall'] = bestId
|
||||||
|
|
||||||
|
for (const col of llmTaskTypeCols.value) {
|
||||||
|
bestId = ''; bestVal = -Infinity
|
||||||
|
for (const r of llmResults.value) {
|
||||||
|
const v = r.quality_by_task_type[col]
|
||||||
|
if (v != null && v > bestVal) { bestVal = v; bestId = r.model_id }
|
||||||
|
}
|
||||||
|
best[col] = bestId
|
||||||
|
}
|
||||||
|
return best
|
||||||
|
})
|
||||||
|
|
||||||
|
function llmPct(v: number): string {
|
||||||
|
return `${(v * 100).toFixed(1)}%`
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadLlmResults() {
|
||||||
|
const { data } = await useApiFetch<LlmModelResult[]>('/api/cforch/results')
|
||||||
|
if (Array.isArray(data) && data.length > 0) {
|
||||||
|
llmResults.value = data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function load() {
|
async function load() {
|
||||||
loading.value = true
|
loading.value = true
|
||||||
error.value = ''
|
error.value = ''
|
||||||
|
|
@ -120,7 +290,10 @@ async function load() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
onMounted(load)
|
onMounted(() => {
|
||||||
|
load()
|
||||||
|
loadLlmResults()
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
|
|
@ -234,6 +407,79 @@ onMounted(load)
|
||||||
padding: 1rem;
|
padding: 1rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ── Benchmark Results ──────────────────────────── */
|
||||||
|
.section-title {
|
||||||
|
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||||
|
font-size: 1.05rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--app-primary, #2A6080);
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bench-table-wrap {
|
||||||
|
overflow-x: auto;
|
||||||
|
border: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
border-radius: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bench-table {
|
||||||
|
border-collapse: collapse;
|
||||||
|
width: 100%;
|
||||||
|
font-size: 0.82rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bt-model-col {
|
||||||
|
text-align: left;
|
||||||
|
padding: 0.45rem 0.75rem;
|
||||||
|
background: var(--color-surface-raised, #e4ebf5);
|
||||||
|
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
font-weight: 600;
|
||||||
|
min-width: 12rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bt-metric-col {
|
||||||
|
text-align: right;
|
||||||
|
padding: 0.45rem 0.75rem;
|
||||||
|
background: var(--color-surface-raised, #e4ebf5);
|
||||||
|
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
font-weight: 600;
|
||||||
|
white-space: nowrap;
|
||||||
|
min-width: 6rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bt-model-cell {
|
||||||
|
padding: 0.4rem 0.75rem;
|
||||||
|
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-size: 0.76rem;
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
max-width: 16rem;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
.bt-metric-cell {
|
||||||
|
padding: 0.4rem 0.75rem;
|
||||||
|
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||||
|
text-align: right;
|
||||||
|
font-family: var(--font-mono, monospace);
|
||||||
|
font-variant-numeric: tabular-nums;
|
||||||
|
color: var(--color-text, #1a2338);
|
||||||
|
}
|
||||||
|
|
||||||
|
.bt-metric-cell.bt-best {
|
||||||
|
color: var(--color-success, #3a7a32);
|
||||||
|
font-weight: 700;
|
||||||
|
background: color-mix(in srgb, var(--color-success, #3a7a32) 8%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.bench-hint {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--color-text-secondary, #6b7a99);
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
@media (max-width: 480px) {
|
@media (max-width: 480px) {
|
||||||
.bar-row {
|
.bar-row {
|
||||||
grid-template-columns: 1.5rem 1fr 1fr 3rem;
|
grid-template-columns: 1.5rem 1fr 1fr 3rem;
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue