Compare commits
52 commits
0745bc3f70
...
32872d1ec6
| Author | SHA1 | Date | |
|---|---|---|---|
| 32872d1ec6 | |||
| 1521198cb1 | |||
| 8dda040480 | |||
| bf675ed1f6 | |||
| 0efd1aedbe | |||
| 4c225b94f5 | |||
| 1cd9c5d455 | |||
| 5702a7190b | |||
| 55b017ba3b | |||
| f952ec8971 | |||
| fd8cb622a1 | |||
| 47cb9f661f | |||
| c2de9e53da | |||
| c039ea4698 | |||
| 95afddb772 | |||
| cbe8c0f03e | |||
| 5df33b0f41 | |||
| 41584de5df | |||
| 1d4c07e4a0 | |||
| e823b5e76d | |||
| 88bc6bed67 | |||
| 4a64a6686d | |||
| f2f150b4fb | |||
| 72449561cf | |||
| c177fb1628 | |||
| 3be5055e31 | |||
| 78b64d007d | |||
| bce932461a | |||
| e11db5ccd9 | |||
| 13d1a394d5 | |||
| b077371107 | |||
| 53b25b27ab | |||
| e014da2dec | |||
| c48db45d91 | |||
| d0ba75b995 | |||
| a134af8b7b | |||
| 6ef6f06023 | |||
| 5bdb095235 | |||
| 0904967320 | |||
| 8fda821e15 | |||
| 0853ed7d56 | |||
| aa742bcfc0 | |||
| 32d3436bbd | |||
| 766fbafa02 | |||
| d432026fd7 | |||
| bccb385f61 | |||
| d74ad3f972 | |||
| 99ea39fe38 | |||
| 2054866ff1 | |||
| cbec776ef1 | |||
| 167d7351e3 | |||
| 6689ff07b1 |
65 changed files with 11024 additions and 2449 deletions
|
|
@ -17,3 +17,7 @@ CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx
|
|||
# Set one of these to use a cloud LLM instead of a local model.
|
||||
# ANTHROPIC_API_KEY=sk-ant-...
|
||||
# OPENAI_API_KEY=sk-...
|
||||
|
||||
# ── HuggingFace (required for gated/terms-restricted model downloads) ─────────
|
||||
# Generate at https://huggingface.co/settings/tokens and accept model terms first.
|
||||
# HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
|
||||
|
|
|
|||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -20,3 +20,7 @@ data/sft_approved.jsonl
|
|||
# Claude context — BSL 1.1, keep out of version control
|
||||
CLAUDE.md
|
||||
docs/superpowers/
|
||||
.superpowers/
|
||||
|
||||
# Git worktrees
|
||||
.worktrees/
|
||||
|
|
|
|||
625
app/api.py
625
app/api.py
|
|
@ -1,623 +1,62 @@
|
|||
"""Avocet — FastAPI REST layer.
|
||||
"""Avocet -- FastAPI app factory.
|
||||
|
||||
JSONL read/write helpers and FastAPI app instance.
|
||||
Endpoints and static file serving are added in subsequent tasks.
|
||||
Mounts all domain routers and serves the Vue SPA.
|
||||
All business logic lives in the domain modules below.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import subprocess as _subprocess
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
|
||||
# Process registry for running jobs — used by cancel endpoints.
|
||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
||||
_running_procs: dict = {}
|
||||
_cancelled_jobs: set = set()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
    """Redirect the module-level data directory to *path*.

    Intended for tests: every helper that builds a path from ``_DATA_DIR``
    will read and write under *path* after this call.
    """
    global _DATA_DIR
    _DATA_DIR = path
|
||||
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return the index of the GPU with the most free VRAM as a string.
|
||||
|
||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
||||
would OOM the GPU with less headroom.
|
||||
"""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def set_models_dir(path: Path) -> None:
    """Redirect the module-level models directory to *path* (test hook)."""
    global _MODELS_DIR
    _MODELS_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
    """Redirect the config directory to *path* (test hook).

    Passing ``None`` restores the default location under the project root.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
    """Return the path of ``label_tool.yaml``.

    Uses the override set via ``set_config_dir`` when present, otherwise
    ``<project root>/config``.
    """
    base_dir = _CONFIG_DIR if _CONFIG_DIR is not None else _ROOT / "config"
    return base_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def reset_last_action() -> None:
    """Forget the stored undo record (test hook)."""
    global _last_action
    _last_action = None
|
||||
|
||||
|
||||
def _queue_file() -> Path:
    """Return the path of the pending-label queue inside the data directory."""
    return _DATA_DIR.joinpath("email_label_queue.jsonl")
|
||||
|
||||
|
||||
def _score_file() -> Path:
    """Return the path of the labeled-records file inside the data directory."""
    return _DATA_DIR.joinpath("email_score.jsonl")
|
||||
|
||||
|
||||
def _discarded_file() -> Path:
    """Return the path of the discarded-records file inside the data directory."""
    return _DATA_DIR.joinpath("discarded.jsonl")
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
if not path.exists():
|
||||
return []
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
return [json.loads(l) for l in lines if l.strip()]
|
||||
|
||||
|
||||
def _write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
text = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
|
||||
path.write_text(text + "\n" if records else "", encoding="utf-8")
|
||||
|
||||
|
||||
def _append_jsonl(path: Path, record: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
"""Stable content-hash ID — matches label_tool.py _entry_key dedup logic."""
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
def _normalize(item: dict) -> dict:
|
||||
"""Normalize JSONL item to the Vue frontend schema.
|
||||
|
||||
label_tool.py stores: subject, body, from_addr, date, account (no id).
|
||||
The Vue app expects: id, subject, body, from, date, source.
|
||||
Both old (from_addr/account) and new (from/source) field names are handled.
|
||||
"""
|
||||
return {
|
||||
"id": item.get("id") or _item_id(item),
|
||||
"subject": item.get("subject", ""),
|
||||
"body": item.get("body", ""),
|
||||
"from": item.get("from") or item.get("from_addr", ""),
|
||||
"date": item.get("date", ""),
|
||||
"source": item.get("source") or item.get("account", ""),
|
||||
}
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI(title="Avocet API")
|
||||
|
||||
from app.sft import router as sft_router
|
||||
app.include_router(sft_router, prefix="/api/sft")
|
||||
# -- Domain routers --------------------------------------------------------
|
||||
|
||||
from app.models import router as models_router
|
||||
import app.models as _models_module
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
from app.data.label import router as label_router
|
||||
app.include_router(label_router, prefix="/api")
|
||||
|
||||
from app.cforch import router as cforch_router
|
||||
app.include_router(cforch_router, prefix="/api/cforch")
|
||||
from app.data.fetch import router as fetch_router
|
||||
app.include_router(fetch_router, prefix="/api")
|
||||
|
||||
from app.imitate import router as imitate_router
|
||||
from app.data.corrections import router as corrections_router
|
||||
app.include_router(corrections_router, prefix="/api/corrections")
|
||||
|
||||
# Backward-compat alias -- remove when Vue SPA is updated to /api/corrections/*
|
||||
app.include_router(corrections_router, prefix="/api/sft")
|
||||
|
||||
from app.data.imitate import router as imitate_router
|
||||
app.include_router(imitate_router, prefix="/api/imitate")
|
||||
|
||||
from app.style import router as style_router
|
||||
app.include_router(style_router, prefix="/api/style")
|
||||
from app.eval.cforch import router as eval_router
|
||||
app.include_router(eval_router, prefix="/api")
|
||||
|
||||
from app.train.train import router as train_router
|
||||
app.include_router(train_router, prefix="/api/train")
|
||||
|
||||
from app.plans_bench import router as plans_bench_router
|
||||
app.include_router(plans_bench_router, prefix="/api/plans-bench")
|
||||
|
||||
# In-memory last-action store (single user, local tool — in-memory is fine)
|
||||
_last_action: dict | None = None
|
||||
|
||||
from app.dashboard import router as dashboard_router
|
||||
app.include_router(dashboard_router, prefix="/api")
|
||||
|
||||
@app.get("/api/queue")
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
    """Return the first *limit* queue items (normalized) and the queue size."""
    queued = _read_jsonl(_queue_file())
    page = [_normalize(entry) for entry in queued[:limit]]
    return {"items": page, "total": len(queued)}
|
||||
from app.models import router as models_router
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
|
||||
from app.nodes import router as nodes_router
|
||||
app.include_router(nodes_router, prefix="/api/nodes-mgmt")
|
||||
|
||||
class LabelRequest(BaseModel):
    """Request body for POST /api/label."""
    # Normalized id of the queue item to label (compared against _normalize(x)["id"]).
    id: str
    # Label to store on the record.
    label: str
|
||||
# -- Static SPA -- MUST be last (catches all unmatched paths) ---------------
|
||||
|
||||
|
||||
@app.post("/api/label")
def post_label(req: LabelRequest):
    """Label a queued item: move it from the queue to the score file.

    Raises HTTP 404 when no queue item has the requested id. The original
    (unlabeled) item is remembered so the action can be undone.
    """
    global _last_action
    queue = _read_jsonl(_queue_file())

    target = None
    for entry in queue:
        if _normalize(entry)["id"] == req.id:
            target = entry
            break
    if not target:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")

    labeled = {
        **target,
        "label": req.label,
        "labeled_at": datetime.now(timezone.utc).isoformat(),
    }
    _append_jsonl(_score_file(), labeled)

    remaining = [entry for entry in queue if _normalize(entry)["id"] != req.id]
    _write_jsonl(_queue_file(), remaining)

    _last_action = {"type": "label", "item": target, "label": req.label}
    return {"ok": True}
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
    """Request body for POST /api/skip."""
    # Normalized id of the queue item to move to the back of the queue.
    id: str
|
||||
|
||||
|
||||
@app.post("/api/skip")
def post_skip(req: SkipRequest):
    """Skip a queued item by moving it to the end of the queue.

    Raises HTTP 404 when no queue item has the requested id.
    """
    global _last_action
    queue = _read_jsonl(_queue_file())

    target = None
    for entry in queue:
        if _normalize(entry)["id"] == req.id:
            target = entry
            break
    if not target:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")

    others = [entry for entry in queue if _normalize(entry)["id"] != req.id]
    _write_jsonl(_queue_file(), others + [target])

    _last_action = {"type": "skip", "item": target}
    return {"ok": True}
|
||||
|
||||
|
||||
class DiscardRequest(BaseModel):
    """Request body for POST /api/discard."""
    # Normalized id of the queue item to discard.
    id: str
|
||||
|
||||
|
||||
@app.post("/api/discard")
def post_discard(req: DiscardRequest):
    """Discard a queued item: move it from the queue to the discarded file.

    Raises HTTP 404 when no queue item has the requested id.
    """
    global _last_action
    queue = _read_jsonl(_queue_file())

    target = None
    for entry in queue:
        if _normalize(entry)["id"] == req.id:
            target = entry
            break
    if not target:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")

    discarded = {
        **target,
        "label": "__discarded__",
        "discarded_at": datetime.now(timezone.utc).isoformat(),
    }
    _append_jsonl(_discarded_file(), discarded)

    remaining = [entry for entry in queue if _normalize(entry)["id"] != req.id]
    _write_jsonl(_queue_file(), remaining)

    _last_action = {"type": "discard", "item": target}
    return {"ok": True}
|
||||
|
||||
|
||||
@app.delete("/api/label/undo")
def delete_undo():
    """Undo the most recent label, discard, or skip action.

    Raises HTTP 404 when there is nothing to undo and HTTP 409 when the
    score/discarded file that should hold the record is empty. On success
    the item is put back at the front of the queue and the stored action is
    cleared, so only one level of undo is available.
    """
    global _last_action
    if not _last_action:
        raise HTTPException(404, "No action to undo")
    action = _last_action
    item = action["item"]  # always the original clean queue item

    # Perform file operations FIRST — only clear _last_action on success
    if action["type"] == "label":
        records = _read_jsonl(_score_file())
        if not records:
            raise HTTPException(409, "Score file is empty — cannot undo label")
        # Drops the LAST score record; assumes it is the one just written.
        _write_jsonl(_score_file(), records[:-1])
        items = _read_jsonl(_queue_file())
        _write_jsonl(_queue_file(), [item] + items)
    elif action["type"] == "discard":
        records = _read_jsonl(_discarded_file())
        if not records:
            raise HTTPException(409, "Discarded file is empty — cannot undo discard")
        # Same last-record assumption as the label branch.
        _write_jsonl(_discarded_file(), records[:-1])
        items = _read_jsonl(_queue_file())
        _write_jsonl(_queue_file(), [item] + items)
    elif action["type"] == "skip":
        # The skipped item is still in the queue; move it back to the front.
        items = _read_jsonl(_queue_file())
        item_id = _normalize(item)["id"]
        items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
        _write_jsonl(_queue_file(), items)

    # Clear AFTER all file operations succeed
    _last_action = None
    return {"undone": {"type": action["type"], "item": _normalize(item)}}
|
||||
|
||||
|
||||
# Label metadata — 10 labels matching label_tool.py
|
||||
# Static label catalogue returned verbatim by GET /api/config/labels.
# Each entry carries: name, emoji, color (hex string) and key.
# NOTE(review): "key" looks like a keyboard shortcut for the UI — confirm.
_LABEL_META = [
    {"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
    {"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
    {"name": "rejected", "emoji": "\u274c", "color": "#F44336", "key": "3"},
    {"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
    {"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
    {"name": "neutral", "emoji": "\u2b1c", "color": "#607D8B", "key": "6"},
    {"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
    {"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
    {"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
    {"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
]
|
||||
|
||||
|
||||
@app.get("/api/config/labels")
def get_labels():
    """Return the static label metadata list unchanged."""
    label_catalogue = _LABEL_META
    return label_catalogue
|
||||
|
||||
|
||||
@app.get("/api/config")
def get_config():
    """Return ``accounts`` and ``max_per_account`` from label_tool.yaml.

    A missing or empty config file yields the defaults: no accounts and a
    per-account limit of 500.
    """
    cfg_path = _config_file()
    raw = {}
    if cfg_path.exists():
        raw = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
    accounts = raw.get("accounts", [])
    per_account = raw.get("max_per_account", 500)
    return {"accounts": accounts, "max_per_account": per_account}
|
||||
|
||||
|
||||
class ConfigPayload(BaseModel):
    """Request body for POST /api/config; dumped as-is into label_tool.yaml."""
    # Account entries; their inner schema is not validated here.
    accounts: list[dict]
    # Same default that GET /api/config reports when the key is absent.
    max_per_account: int = 500
|
||||
|
||||
|
||||
@app.post("/api/config")
def post_config(payload: ConfigPayload):
    """Persist *payload* to label_tool.yaml via a temp file and swap.

    The YAML is written to a sibling ``.tmp`` file first and then moved over
    the real file, so a crash mid-write never leaves a truncated config.
    """
    f = _config_file()
    f.parent.mkdir(parents=True, exist_ok=True)
    tmp = f.with_suffix(".tmp")
    tmp.write_text(
        yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )
    # Path.replace overwrites an existing target on every platform;
    # Path.rename raises FileExistsError on Windows once the config exists.
    tmp.replace(f)
    return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/api/stats")
def get_stats():
    """Summarize the score file: totals, per-label counts, size on disk.

    Also embeds ``benchmark_results.json`` from the data directory when it
    exists and parses; an unreadable file is reported as an empty dict.
    """
    score_path = _score_file()
    records = _read_jsonl(score_path)

    counts: dict[str, int] = {}
    for rec in records:
        label = rec.get("label", "")
        if not label:
            continue
        counts[label] = counts.get(label, 0) + 1

    benchmark_results: dict = {}
    benchmark_path = _DATA_DIR / "benchmark_results.json"
    if benchmark_path.exists():
        try:
            benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
        except Exception:
            # Best effort: a corrupt benchmark file must not break /api/stats.
            pass

    size_bytes = score_path.stat().st_size if score_path.exists() else 0
    return {
        "total": len(records),
        "counts": counts,
        "score_file_bytes": size_bytes,
        "benchmark_results": benchmark_results,
    }
|
||||
|
||||
|
||||
@app.get("/api/stats/download")
def download_stats():
    """Serve the score file as a download; HTTP 404 when it does not exist."""
    from fastapi.responses import FileResponse

    score_path = _score_file()
    if not score_path.exists():
        raise HTTPException(404, "No score file")

    disposition = {"Content-Disposition": 'attachment; filename="email_score.jsonl"'}
    return FileResponse(
        str(score_path),
        filename="email_score.jsonl",
        media_type="application/jsonlines",
        headers=disposition,
    )
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
    """Request body for POST /api/accounts/test."""
    # Account settings passed unchanged to app.imap_fetch.test_connection.
    account: dict
|
||||
|
||||
|
||||
@app.post("/api/accounts/test")
def test_account(req: AccountTestRequest):
    """Run ``test_connection`` on the supplied account and report its result."""
    from app.imap_fetch import test_connection

    outcome = test_connection(req.account)
    ok, message, count = outcome
    return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/benchmark/models")
def get_benchmark_models() -> dict:
    """Return installed models grouped by adapter type.

    Each sub-directory of the models directory is one model. Its
    ``model_info.json`` (when present and parseable) supplies the adapter
    type and repo id; anything unrecognized lands in the "Unknown" bucket.
    """
    root: Path = _models_module._MODELS_DIR
    categories: dict[str, list[dict]] = {
        bucket_name: []
        for bucket_name in (
            "ZeroShotAdapter", "RerankerAdapter", "GenerationAdapter", "Unknown",
        )
    }
    if not root.exists():
        return {"categories": categories}

    for model_dir in root.iterdir():
        if not model_dir.is_dir():
            continue
        adapter_type = "Unknown"
        repo_id: str | None = None
        info_file = model_dir / "model_info.json"
        if info_file.exists():
            try:
                info = json.loads(info_file.read_text(encoding="utf-8"))
                adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown"
                repo_id = info.get("repo_id")
            except Exception:
                # Unreadable metadata: keep the defaults for this model.
                pass
        bucket = adapter_type if adapter_type in categories else "Unknown"
        categories[bucket].append(
            {"name": model_dir.name, "repo_id": repo_id, "adapter_type": adapter_type}
        )
    return {"categories": categories}
|
||||
|
||||
|
||||
@app.get("/api/benchmark/results")
def get_benchmark_results():
    """Return the most recently saved benchmark results, or an empty envelope.

    The file is decoded as UTF-8 explicitly, like every other JSON read in
    this module, instead of relying on the platform's default encoding.
    """
    path = _DATA_DIR / "benchmark_results.json"
    if not path.exists():
        return {"models": {}, "sample_count": 0, "timestamp": None}
    return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.get("/api/benchmark/run")
def run_benchmark(include_slow: bool = False, model_names: str = ""):
    """Spawn the benchmark script and stream stdout as SSE progress events.

    Args:
        include_slow: adds ``--include-slow`` to the script's command line.
        model_names: comma-separated names forwarded after ``--models``.

    Each SSE event is a JSON object whose ``type`` is one of ``progress``,
    ``complete``, ``cancelled`` or ``error``.
    """
    # NOTE(review): interpreter path is hard-coded to one machine's conda env.
    python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
    script = str(_ROOT / "scripts" / "benchmark_classifier.py")
    cmd = [python_bin, script, "--score", "--save"]
    if include_slow:
        cmd.append("--include-slow")
    if model_names:
        names = [n.strip() for n in model_names.split(",") if n.strip()]
        if names:
            cmd.extend(["--models"] + names)

    def generate():
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,  # stderr is merged into the streamed output
                text=True,
                bufsize=1,
                cwd=str(_ROOT),
            )
            # Register the process so the cancel endpoint can find it.
            _running_procs["benchmark"] = proc
            _cancelled_jobs.discard("benchmark")  # clear any stale flag from a prior run
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
                proc.wait()
                if proc.returncode == 0:
                    yield f"data: {json.dumps({'type': 'complete'})}\n\n"
                elif "benchmark" in _cancelled_jobs:
                    # Non-zero exit caused by the cancel endpoint, not a failure.
                    _cancelled_jobs.discard("benchmark")
                    yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
            finally:
                # Always unregister, even if the stream is closed early.
                _running_procs.pop("benchmark", None)
        except Exception as exc:
            # Includes failure to start the process at all.
            yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Finetune endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/finetune/status")
def get_finetune_status():
    """Collect every ``training_info.json`` found one level under models/.

    Returns an empty list when the models directory is missing or holds no
    such files. Files that cannot be read or parsed are skipped.
    """
    root = _MODELS_DIR
    if not root.exists():
        return []

    collected = []
    for candidate in root.iterdir():
        if not candidate.is_dir():
            continue
        info_file = candidate / "training_info.json"
        if not info_file.exists():
            continue
        try:
            collected.append(json.loads(info_file.read_text(encoding="utf-8")))
        except Exception:
            # Best effort: one broken file must not hide the others.
            pass
    return collected
|
||||
|
||||
|
||||
@app.get("/api/finetune/run")
def run_finetune_endpoint(
    model: str = "deberta-small",
    epochs: int = 5,
    score: list[str] = Query(default=[]),
):
    """Spawn finetune_classifier.py and stream stdout as SSE progress events.

    Args:
        model: value forwarded as ``--model``.
        epochs: value forwarded as ``--epochs``.
        score: score-file paths relative to the data directory; each is
            forwarded as ``--score <absolute path>``.

    Raises:
        HTTPException(400): a score path resolves outside the data directory.

    Each SSE event is a JSON object whose ``type`` is one of ``progress``,
    ``complete``, ``cancelled`` or ``error``.
    """
    # NOTE(review): interpreter path is hard-coded to one machine's conda env.
    python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
    script = str(_ROOT / "scripts" / "finetune_classifier.py")
    cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
    data_root = _DATA_DIR.resolve()
    for score_file in score:
        resolved = (_DATA_DIR / score_file).resolve()
        # Compare path components, not string prefixes: a prefix test would
        # accept a sibling such as "<data>_evil/x" reached via "../".
        if resolved != data_root and data_root not in resolved.parents:
            raise HTTPException(400, f"Invalid score path: {score_file!r}")
        cmd.extend(["--score", str(resolved)])

    # Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
    # single device prevents DataParallel from replicating the model across all
    # GPUs, which would force a full copy onto the more memory-constrained device.
    proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
    best_gpu = _best_cuda_device()
    if best_gpu:
        proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu

    gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"

    def generate():
        yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,  # stderr is merged into the streamed output
                text=True,
                bufsize=1,
                cwd=str(_ROOT),
                env=proc_env,
            )
            # Register the process so the cancel endpoint can find it.
            _running_procs["finetune"] = proc
            _cancelled_jobs.discard("finetune")  # clear any stale flag from a prior run
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
                proc.wait()
                if proc.returncode == 0:
                    yield f"data: {json.dumps({'type': 'complete'})}\n\n"
                elif "finetune" in _cancelled_jobs:
                    # Non-zero exit caused by the cancel endpoint, not a failure.
                    _cancelled_jobs.discard("finetune")
                    yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
            finally:
                # Always unregister, even if the stream is closed early.
                _running_procs.pop("finetune", None)
        except Exception as exc:
            # Includes failure to start the process at all.
            yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
|
||||
@app.post("/api/benchmark/cancel")
def cancel_benchmark():
    """Stop the running benchmark subprocess; HTTP 404 if none is registered.

    Sends terminate first and escalates to kill when the process has not
    exited within 3 seconds.
    """
    running = _running_procs.get("benchmark")
    if running is None:
        raise HTTPException(404, "No benchmark is running")

    # Flag first so the streaming generator reports "cancelled", not "error".
    _cancelled_jobs.add("benchmark")
    running.terminate()
    try:
        running.wait(timeout=3)
    except _subprocess.TimeoutExpired:
        running.kill()
    return {"status": "cancelled"}
|
||||
|
||||
|
||||
@app.post("/api/finetune/cancel")
def cancel_finetune():
    """Stop the running fine-tune subprocess; HTTP 404 if none is registered.

    Sends terminate first and escalates to kill when the process has not
    exited within 3 seconds.
    """
    running = _running_procs.get("finetune")
    if running is None:
        raise HTTPException(404, "No finetune is running")

    # Flag first so the streaming generator reports "cancelled", not "error".
    _cancelled_jobs.add("finetune")
    running.terminate()
    try:
        running.wait(timeout=3)
    except _subprocess.TimeoutExpired:
        running.kill()
    return {"status": "cancelled"}
|
||||
|
||||
|
||||
@app.get("/api/fetch/stream")
def fetch_stream(
    accounts: str = Query(default=""),
    days_back: int = Query(default=90, ge=1, le=365),
    limit: int = Query(default=150, ge=1, le=1000),
    mode: str = Query(default="wide"),
):
    """Fetch mail for the selected accounts and stream progress as SSE.

    Args:
        accounts: comma-separated account names; only configured accounts
            whose ``name`` matches are fetched.
        days_back, limit: forwarded to ``fetch_account_stream``.
        mode: accepted but not used anywhere in this function.

    New emails are appended to the queue file after each account finishes.
    A final ``complete`` event reports the number added and the queue size.
    """
    from app.imap_fetch import fetch_account_stream

    selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
    config = get_config()  # reuse existing endpoint logic
    selected = [a for a in config["accounts"] if a.get("name") in selected_names]

    def generate():
        # Content hashes of items already queued, handed to the fetcher.
        known_keys = {_item_id(x) for x in _read_jsonl(_queue_file())}
        total_added = 0

        for acc in selected:
            try:
                batch_emails: list[dict] = []
                for event in fetch_account_stream(acc, days_back, limit, known_keys):
                    if event["type"] == "done":
                        # Strip the email payload so it is not sent to the client.
                        batch_emails = event.pop("emails", [])
                        total_added += event["added"]
                    yield f"data: {json.dumps(event)}\n\n"
                # Write new emails to queue after each account
                if batch_emails:
                    existing = _read_jsonl(_queue_file())
                    _write_jsonl(_queue_file(), existing + batch_emails)
            except Exception as exc:
                # One failing account is reported and the loop moves on.
                error_event = {"type": "error", "account": acc.get("name", "?"),
                               "message": str(exc)}
                yield f"data: {json.dumps(error_event)}\n\n"

        queue_size = len(_read_jsonl(_queue_file()))
        complete = {"type": "complete", "total_added": total_added, "queue_size": queue_size}
        yield f"data: {json.dumps(complete)}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
# Static SPA — MUST be last (catches all unmatched paths)
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DIST = _ROOT / "web" / "dist"
|
||||
if _DIST.exists():
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Serve index.html with no-cache so browsers always fetch fresh HTML after rebuilds.
|
||||
# Hashed assets (/assets/index-abc123.js) can be cached forever — they change names
|
||||
# when content changes (standard Vite cache-busting strategy).
|
||||
_NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"}
|
||||
|
||||
@app.get("/")
|
||||
|
|
|
|||
244
app/cforch.py
244
app/cforch.py
|
|
@ -17,9 +17,12 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import subprocess as _subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import urllib.parse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -75,9 +78,31 @@ def _load_cforch_config() -> dict:
|
|||
"license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"),
|
||||
"ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"),
|
||||
"ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"),
|
||||
"judge_url": _coalesce(file_cfg.get("judge_url", ""), "CF_JUDGE_URL"),
|
||||
"hf_token": _coalesce(file_cfg.get("hf_token", ""), "HF_TOKEN"),
|
||||
}
|
||||
|
||||
|
||||
def _validate_service_url(url: str, param_name: str) -> str:
|
||||
"""Validate that a URL is a well-formed http/https URL with a hostname.
|
||||
|
||||
Guards against SSRF: only http/https is allowed; the URL must have a
|
||||
non-empty host. Does not enforce an allowlist — call sites are internal
|
||||
tooling, not a public API.
|
||||
"""
|
||||
if not url:
|
||||
return url
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
except Exception:
|
||||
raise HTTPException(400, f"{param_name}: not a valid URL")
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise HTTPException(400, f"{param_name}: URL must start with http:// or https://")
|
||||
if not parsed.hostname:
|
||||
raise HTTPException(400, f"{param_name}: URL has no hostname")
|
||||
return url
|
||||
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape codes from a string."""
|
||||
return re.sub(r'\x1b\[[0-9;]*m', '', text)
|
||||
|
|
@ -147,48 +172,141 @@ def get_tasks() -> dict:
|
|||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
# Services and roles surfaced in the benchmark model picker.
|
||||
# Covers all cf-orch service types that benchmark.py can route tasks to.
|
||||
# Allowlist of service names: get_models() skips installed models whose
# service is not in this set.
_BENCH_SERVICES = frozenset({
    "cf-text", "vllm",   # LLM text generation
    "cf-stt",            # speech-to-text
    "cf-tts",            # text-to-speech
    "cf-vision",         # image classification / embedding
    "cf-voice",          # audio context classification
})
# Allowlist of model roles, checked together with _BENCH_SERVICES.
_BENCH_ROLES = frozenset({
    "generator", "vlm",        # LLM roles
    "stt", "alm",              # speech recognition
    "tts",                     # speech synthesis
    "vision", "embedding",     # image understanding
    "classifier",              # audio classification (cf-voice)
})
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return model list from bench_models.yaml."""
|
||||
"""Return model list from bench_models.yaml merged with locally installed models.
|
||||
|
||||
bench_models.yaml entries are listed first and take precedence; any installed
|
||||
model whose repo_id is already present in the YAML is skipped. Only models
|
||||
whose service is in _BENCH_SERVICES (cf-text, vllm, cf-stt, cf-tts, cf-vision,
|
||||
cf-voice) are surfaced from the installed registry.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
models_path = cfg.get("bench_models", "")
|
||||
if not models_path:
|
||||
return {"models": []}
|
||||
|
||||
p = Path(models_path)
|
||||
if not p.exists():
|
||||
return {"models": []}
|
||||
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
|
||||
return {"models": []}
|
||||
|
||||
models_raw = raw.get("models", []) or []
|
||||
models: list[dict] = []
|
||||
for m in models_raw:
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
models.append({
|
||||
"name": m.get("name", ""),
|
||||
"id": m.get("id", ""),
|
||||
"service": m.get("service", "ollama"),
|
||||
"tags": m.get("tags", []) or [],
|
||||
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
|
||||
})
|
||||
bench_ids: set[str] = set()
|
||||
|
||||
if models_path:
|
||||
p = Path(models_path)
|
||||
if p.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
|
||||
raw = {}
|
||||
for m in (raw.get("models", []) or []):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
model_id = m.get("id", "")
|
||||
models.append({
|
||||
"name": m.get("name", ""),
|
||||
"id": model_id,
|
||||
"service": m.get("service", "ollama"),
|
||||
"tags": m.get("tags", []) or [],
|
||||
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
|
||||
})
|
||||
if model_id:
|
||||
bench_ids.add(model_id)
|
||||
|
||||
# Merge installed generator models not already in bench_models.yaml.
|
||||
try:
|
||||
from app.models import list_installed # local import avoids circular dependency at module load
|
||||
for installed in list_installed():
|
||||
model_id: str = installed.get("model_id") or ""
|
||||
service: str = installed.get("service") or ""
|
||||
role: str = installed.get("role") or ""
|
||||
if not model_id:
|
||||
continue
|
||||
if service not in _BENCH_SERVICES or role not in _BENCH_ROLES:
|
||||
continue
|
||||
if model_id in bench_ids:
|
||||
continue
|
||||
display_name = model_id.split("/", 1)[-1] if "/" in model_id else model_id
|
||||
models.append({
|
||||
"name": display_name,
|
||||
"id": model_id,
|
||||
"service": service,
|
||||
"tags": [role],
|
||||
"vram_estimate_mb": installed.get("vram_mb") or 0,
|
||||
})
|
||||
bench_ids.add(model_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into model list: %s", exc)
|
||||
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /nodes, GET /run ───────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
def get_nodes() -> dict:
    """Proxy the coordinator's /api/nodes list as node_id, online flag and GPUs.

    A node is reported online when its ``last_heartbeat`` field is present
    (not ``None``); the age of the heartbeat is not inspected here.

    Returns:
        ``{"nodes": [...]}``. The list is empty when no coordinator URL is
        configured, or when the request or response handling fails (the
        failure is logged as a warning).
    """
    cfg = _load_cforch_config()
    # ``or ""`` also covers an explicit null value in the config.
    coordinator_url = (cfg.get("coordinator_url") or "").rstrip("/")
    if not coordinator_url:
        return {"nodes": []}
    try:
        import httpx as _httpx
        resp = _httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
        resp.raise_for_status()
        # A null "nodes"/"gpus" value is treated like a missing one, and
        # malformed (non-mapping) entries are skipped instead of discarding
        # the whole response.
        raw_nodes = resp.json().get("nodes") or []
        return {
            "nodes": [
                {
                    "node_id": n.get("node_id", ""),
                    "online": n.get("last_heartbeat") is not None,
                    "gpus": [
                        {
                            "gpu_id": g.get("gpu_id"),
                            "name": g.get("name", ""),
                            "vram_total_mb": g.get("vram_total_mb", 0),
                            "vram_free_mb": g.get("vram_free_mb", 0),
                        }
                        for g in (n.get("gpus") or [])
                        if isinstance(g, dict)
                    ],
                }
                for n in raw_nodes
                if isinstance(n, dict)
            ]
        }
    except Exception as exc:
        # Best-effort proxy: an unreachable coordinator must not break the UI.
        logger.warning("Could not fetch nodes from coordinator: %s", exc)
        return {"nodes": []}
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_benchmark(
|
||||
task_ids: str = "",
|
||||
model_ids: str = "",
|
||||
model_tags: str = "",
|
||||
coordinator_url: str = "",
|
||||
ollama_url: str = "",
|
||||
judge_url: str = "",
|
||||
judge_backend: str = "chat",
|
||||
workers: int = 1,
|
||||
node_ids: str = "",
|
||||
) -> StreamingResponse:
|
||||
"""Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
|
@ -205,6 +323,13 @@ def run_benchmark(
|
|||
cfg_coordinator = cfg.get("coordinator_url", "")
|
||||
cfg_ollama = cfg.get("ollama_url", "")
|
||||
cfg_license_key = cfg.get("license_key", "")
|
||||
cfg_judge_url = cfg.get("judge_url", "")
|
||||
|
||||
# Validate URL params before spawning the subprocess.
|
||||
# _validate_service_url raises HTTPException on bad input (caught by FastAPI before streaming starts).
|
||||
_validate_service_url(coordinator_url, "coordinator_url")
|
||||
_validate_service_url(ollama_url, "ollama_url")
|
||||
_validate_service_url(judge_url, "judge_url")
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
|
@ -213,16 +338,68 @@ def run_benchmark(
|
|||
yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n"
|
||||
return
|
||||
|
||||
# Build effective models file: bench_models.yaml + any installed models
|
||||
# whose IDs were selected but are absent from the YAML (e.g. downloaded
|
||||
# via the Models view). Written to a temp file so benchmark.py sees one
|
||||
# unified list; cleaned up in the finally block.
|
||||
effective_models_file = bench_models
|
||||
_tmp_models_path: str | None = None
|
||||
|
||||
if model_ids and bench_models and Path(bench_models).exists():
|
||||
requested_ids = set(model_ids.split(","))
|
||||
try:
|
||||
raw_bench = yaml.safe_load(Path(bench_models).read_text(encoding="utf-8")) or {}
|
||||
bench_entries: list[dict] = raw_bench.get("models", []) or []
|
||||
bench_id_set = {m.get("id", "") for m in bench_entries if isinstance(m, dict)}
|
||||
missing_ids = requested_ids - bench_id_set
|
||||
if missing_ids:
|
||||
from app.models import list_installed
|
||||
installed_map = {
|
||||
m["model_id"]: m
|
||||
for m in list_installed()
|
||||
if m.get("model_id") and m.get("service") in _BENCH_SERVICES
|
||||
}
|
||||
extra: list[dict] = []
|
||||
for mid in missing_ids:
|
||||
if mid in installed_map:
|
||||
inst = installed_map[mid]
|
||||
entry: dict[str, Any] = {
|
||||
"id": mid,
|
||||
"name": mid.split("/", 1)[-1] if "/" in mid else mid,
|
||||
"service": inst.get("service", "cf-text"),
|
||||
"vram_estimate_mb": inst.get("vram_mb") or 0,
|
||||
"tags": [inst.get("role", "generator")],
|
||||
"temperature": 0.0,
|
||||
}
|
||||
local_path = inst.get("path", "") or inst.get("local_path", "")
|
||||
if local_path:
|
||||
entry["model_path"] = local_path
|
||||
extra.append(entry)
|
||||
if extra:
|
||||
merged = {"models": bench_entries + extra}
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False,
|
||||
prefix="avocet_bench_models_",
|
||||
)
|
||||
yaml.dump(merged, tf)
|
||||
tf.close()
|
||||
_tmp_models_path = tf.name
|
||||
effective_models_file = _tmp_models_path
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into temp bench file: %s", exc)
|
||||
|
||||
cmd = [
|
||||
python_bin,
|
||||
bench_script,
|
||||
"--tasks", bench_tasks,
|
||||
"--models", bench_models,
|
||||
"--models", effective_models_file,
|
||||
"--output", results_dir,
|
||||
]
|
||||
|
||||
if task_ids:
|
||||
cmd.extend(["--filter-tasks"] + task_ids.split(","))
|
||||
if model_ids:
|
||||
cmd.extend(["--filter-models"] + model_ids.split(","))
|
||||
if model_tags:
|
||||
cmd.extend(["--filter-tags"] + model_tags.split(","))
|
||||
|
||||
|
|
@ -233,6 +410,15 @@ def run_benchmark(
|
|||
cmd.extend(["--coordinator", effective_coordinator])
|
||||
if effective_ollama:
|
||||
cmd.extend(["--ollama-url", effective_ollama])
|
||||
effective_judge = judge_url if judge_url else cfg_judge_url
|
||||
if effective_judge:
|
||||
cmd.extend(["--judge-url", effective_judge])
|
||||
if judge_backend and judge_backend != "chat":
|
||||
cmd.extend(["--judge-backend", judge_backend])
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
if node_ids:
|
||||
cmd.extend(["--nodes"] + node_ids.split(","))
|
||||
|
||||
# Pass license key as env var so subprocess can authenticate with cf-orch
|
||||
proc_env = {**os.environ}
|
||||
|
|
@ -273,6 +459,11 @@ def run_benchmark(
|
|||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
if _tmp_models_path:
|
||||
try:
|
||||
os.unlink(_tmp_models_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
|
|
@ -295,6 +486,7 @@ def get_cforch_config() -> dict:
|
|||
"coordinator_url": cfg.get("coordinator_url", ""),
|
||||
"ollama_url": cfg.get("ollama_url", ""),
|
||||
"ollama_model": cfg.get("ollama_model", ""),
|
||||
"judge_url": cfg.get("judge_url", ""),
|
||||
"license_key_set": bool(cfg.get("license_key", "")),
|
||||
"source": "env" if not _config_file().exists() else "yaml+env",
|
||||
}
|
||||
|
|
@ -303,7 +495,7 @@ def get_cforch_config() -> dict:
|
|||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def get_results() -> dict:
|
||||
def get_results() -> list:
|
||||
"""Return the latest benchmark summary.json from results_dir."""
|
||||
cfg = _load_cforch_config()
|
||||
results_dir = cfg.get("results_dir", "")
|
||||
|
|
|
|||
191
app/dashboard.py
Normal file
191
app/dashboard.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Avocet -- dashboard aggregate API.
|
||||
|
||||
GET /api/dashboard returns the current flywheel state:
|
||||
labeled_since_last_eval -- items labeled after the most recent eval run
|
||||
last_eval_timestamp -- ISO timestamp of newest bench_results summary
|
||||
last_eval_best_score -- best macro_f1 from that summary
|
||||
active_jobs -- jobs with status queued or running
|
||||
corrections_pending -- sft_candidates with status=needs_review
|
||||
corrections_export_ready -- approved sft candidates with non-blank correction
|
||||
signals -- computed booleans for UI nudge indicators
|
||||
|
||||
Thresholds in label_tool.yaml pipeline: section:
|
||||
pipeline:
|
||||
data_eval_threshold: 50 # labeled items since last eval to trigger nudge
|
||||
eval_train_threshold: 0.05 # improvement delta needed before retraining (future)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DEFAULT_DATA_EVAL_THRESHOLD = 50
|
||||
_DEFAULT_EVAL_TRAIN_THRESHOLD = 0.05
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
    """Point this module at a different data directory (rebinds ``_DATA_DIR``)."""
    global _DATA_DIR
    _DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _load_thresholds() -> tuple[int, float]:
    """Read the pipeline thresholds from the ``pipeline:`` section of the config.

    Returns:
        ``(data_eval_threshold, eval_train_threshold)``. The module defaults
        are returned when the config file is absent or when reading or
        converting its values fails (the failure is logged as a warning).
    """
    defaults = (_DEFAULT_DATA_EVAL_THRESHOLD, _DEFAULT_EVAL_TRAIN_THRESHOLD)
    cfg_path = _config_file()
    if not cfg_path.exists():
        return defaults
    try:
        parsed = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
        section = parsed.get("pipeline", {}) or {}
        data_eval = int(section.get("data_eval_threshold", _DEFAULT_DATA_EVAL_THRESHOLD))
        eval_train = float(section.get("eval_train_threshold", _DEFAULT_EVAL_TRAIN_THRESHOLD))
    except Exception as exc:
        logger.warning("Failed to read pipeline thresholds: %s", exc)
        return defaults
    return data_eval, eval_train
|
||||
|
||||
def _load_score_records() -> list[dict]:
|
||||
path = _DATA_DIR / "email_score.jsonl"
|
||||
if not path.exists():
|
||||
return []
|
||||
records = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return records
|
||||
|
||||
def _find_latest_eval(results_dir_override: str = "") -> tuple[str | None, float | None]:
|
||||
"""Return (iso_timestamp, best_macro_f1) from the newest bench_results summary.
|
||||
|
||||
Checks results_dir from cforch config if set, then falls back to
|
||||
_ROOT/bench_results/. Returns (None, None) if no results exist.
|
||||
"""
|
||||
candidates = []
|
||||
if results_dir_override:
|
||||
candidates.append(Path(results_dir_override))
|
||||
else:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
candidates.append(Path(rd))
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
|
||||
candidates.append(_ROOT / "bench_results")
|
||||
|
||||
for rdir in candidates:
|
||||
if not rdir.exists():
|
||||
continue
|
||||
subdirs = sorted([d for d in rdir.iterdir() if d.is_dir()], key=lambda d: d.name)
|
||||
for subdir in reversed(subdirs):
|
||||
summary = subdir / "summary.json"
|
||||
if summary.exists():
|
||||
try:
|
||||
data = json.loads(summary.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamp") or subdir.name
|
||||
score = data.get("best_macro_f1") or data.get("macro_f1")
|
||||
return ts, (float(score) if isinstance(score, (int, float)) else None)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse summary.json at %s: %s", summary, exc)
|
||||
return None, None
|
||||
|
||||
def _count_corrections() -> tuple[int, int]:
|
||||
"""Return (pending_count, export_ready_count)."""
|
||||
pending = 0
|
||||
export_ready = 0
|
||||
candidates_path = _DATA_DIR / "sft_candidates.jsonl"
|
||||
approved_path = _DATA_DIR / "sft_approved.jsonl"
|
||||
if candidates_path.exists():
|
||||
for line in candidates_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if r.get("status") == "needs_review":
|
||||
pending += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if approved_path.exists():
|
||||
for line in approved_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if (r.get("status") == "approved"
|
||||
and r.get("corrected_response")
|
||||
and str(r["corrected_response"]).strip()):
|
||||
export_ready += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return pending, export_ready
|
||||
|
||||
def _get_active_jobs() -> list[dict]:
    """Return the queued/running rows of the train jobs table as plain dicts.

    Yields ``[]`` when the DB file does not exist, and also when importing the
    train module or querying fails (the failure is logged as a warning).
    """
    wanted_columns = ("id", "type", "model_key", "status")
    try:
        from app.train.train import _DB_PATH, _db, _init_db
        if not _DB_PATH.exists():
            return []
        _init_db()
        with _db() as conn:
            cursor = conn.execute(
                "SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
            )
            rows = cursor.fetchall()
        return [{column: row[column] for column in wanted_columns} for row in rows]
    except Exception as exc:
        logger.warning("Failed to query train jobs DB: %s", exc)
        return []
|
||||
|
||||
def _count_labeled_since(since_ts: str | None) -> int:
    """Count score records labeled after *since_ts*.

    Args:
        since_ts: Timestamp string to compare against, or ``None`` to count
            every record. The comparison is a plain string comparison, so it
            is only meaningful when both sides share one sortable format.

    Returns:
        The number of records whose ``labeled_at`` string sorts strictly after
        *since_ts*. Records with a missing, null or non-string ``labeled_at``
        are not counted (a null value previously raised ``TypeError``).
    """
    records = _load_score_records()
    if since_ts is None:
        return len(records)
    count = 0
    for r in records:
        stamp = r.get("labeled_at") if isinstance(r, dict) else None
        if isinstance(stamp, str) and stamp > since_ts:
            count += 1
    return count
|
||||
|
||||
|
||||
@router.get("/dashboard")
def get_dashboard() -> dict:
    """Assemble the flywheel state served at GET /dashboard.

    Combines the latest eval summary, the count of items labeled since that
    eval, the correction counts and the active train jobs, plus the boolean
    nudge signals derived from them.
    """
    data_eval_threshold, eval_train_threshold = _load_thresholds()
    latest_ts, latest_score = _find_latest_eval()
    newly_labeled = _count_labeled_since(latest_ts)
    pending, export_ready = _count_corrections()
    jobs = _get_active_jobs()

    signals = {
        "data_to_eval": newly_labeled >= data_eval_threshold,
        "eval_to_train": False,  # future: implement delta-F1 comparison
        "train_to_fleet": False,  # future: implement fleet sync signal
    }
    return {
        "labeled_since_last_eval": newly_labeled,
        "last_eval_timestamp": latest_ts,
        "last_eval_best_score": latest_score,
        "active_jobs": jobs,
        "corrections_pending": pending,
        "corrections_export_ready": export_ready,
        "signals": signals,
    }
|
||||
0
app/data/__init__.py
Normal file
0
app/data/__init__.py
Normal file
393
app/data/corrections.py
Normal file
393
app/data/corrections.py
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
|
||||
|
||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||
testability pattern as api.py -- override them via set_data_dir() and
|
||||
set_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams ---------------------------------------------------------
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Internal helpers ----------------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
    """Return the directory that holds benchmark runs.

    Uses ``sft.bench_results_dir`` from the config file when it is set to a
    non-empty value; otherwise falls back to ``_DEFAULT_BENCH_RESULTS_DIR``.
    A YAML parse failure is logged and also falls back to the default.
    """
    f = _config_file()
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse SFT config %s: %s", f, exc)
        else:
            # ``sft:`` with a null value, or a non-mapping root, is treated
            # as "not configured" instead of raising AttributeError.
            sft_section = raw.get("sft") if isinstance(raw, dict) else None
            if isinstance(sft_section, dict):
                d = sft_section.get("bench_results_dir", "")
                if d:
                    return Path(d)
    return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# -- GET /runs -----------------------------------------------------------------
|
||||
|
||||
@router.get("/runs")
def get_runs():
    """Describe every benchmark run found in the configured bench_results_dir.

    Each entry carries the run's id, timestamp and candidate count, plus an
    ``already_imported`` flag that is true when a local candidate record
    references that run id.
    """
    from scripts.sft_import import discover_runs
    bench_dir = _get_bench_results_dir()
    # benchmark_run_id in each record equals the run's directory name by cf-orch convention
    imported_run_ids = set()
    for candidate in _read_candidates():
        run_ref = candidate.get("benchmark_run_id")
        if run_ref is not None:
            imported_run_ids.add(run_ref)

    listing = []
    for run in discover_runs(bench_dir):
        listing.append({
            "run_id": run["run_id"],
            "timestamp": run["timestamp"],
            "candidate_count": run["candidate_count"],
            "already_imported": run["run_id"] in imported_run_ids,
        })
    return listing
|
||||
|
||||
|
||||
# -- POST /import --------------------------------------------------------------
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
def post_import(req: ImportRequest):
    """Import the sft_candidates.jsonl of the requested run into the data dir.

    Raises:
        HTTPException: 404 when no discovered run has the requested ``run_id``.
    """
    from scripts.sft_import import discover_runs, import_run
    bench_dir = _get_bench_results_dir()
    for candidate_run in discover_runs(bench_dir):
        if candidate_run["run_id"] == req.run_id:
            return import_run(candidate_run["sft_path"], _DATA_DIR)
    raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
|
||||
|
||||
# -- GET /queue ----------------------------------------------------------------
|
||||
|
||||
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return one page of the candidates still awaiting review.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size; values below 1 are treated as 1.

    Returns:
        A dict with the page's ``items``, the ``total`` number of pending
        candidates, and the effective ``page`` / ``per_page`` that were used.
    """
    # Clamp first: a non-positive page or size would otherwise produce a
    # negative slice start and silently return the wrong window.
    page = max(page, 1)
    per_page = max(per_page, 1)
    records = _read_candidates()
    pending = [r for r in records if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
|
||||
|
||||
|
||||
# -- POST /submit --------------------------------------------------------------
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Apply a reviewer's decision (correct / discard / flag) to one candidate.

    Raises:
        HTTPException: 422 when a correction has a blank ``corrected_response``;
            404 when no record has the given id; 409 when the record is not
            currently in ``needs_review``.
    """
    is_correction = req.action == "correct"
    if is_correction and not (req.corrected_response and req.corrected_response.strip()):
        raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    records = _read_candidates()
    for idx, record in enumerate(records):
        if record.get("id") == req.id:
            break
    else:
        raise HTTPException(404, f"Record {req.id!r} not found")

    if record.get("status") != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")

    if is_correction:
        updated = {
            **record,
            "status": "approved",
            "corrected_response": req.corrected_response,
            "failure_category": req.failure_category,
        }
    elif req.action == "discard":
        updated = {**record, "status": "discarded"}
    else:  # flag
        updated = {**record, "status": "model_rejected"}

    records[idx] = updated
    _write_candidates(records)
    if is_correction:
        # Approved corrections are mirrored into the approved file.
        append_jsonl(_approved_file(), updated)

    return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /undo ----------------------------------------------------------------
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
def post_undo(req: UndoRequest):
    """Put an already-actioned candidate back into the review queue.

    The record's status becomes ``needs_review`` and its
    ``corrected_response`` is cleared; a previously approved record is also
    dropped from the approved file.

    Raises:
        HTTPException: 404 when no record has the given id; 409 when the
            record is already in ``needs_review``.
    """
    records = _read_candidates()
    for idx, record in enumerate(records):
        if record.get("id") == req.id:
            break
    else:
        raise HTTPException(404, f"Record {req.id!r} not found")

    previous_status = record.get("status")
    if previous_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    records[idx] = {**record, "status": "needs_review", "corrected_response": None}
    _write_candidates(records)

    if previous_status == "approved":
        # Keep the approved file in step with the candidates file.
        still_approved = [entry for entry in read_jsonl(_approved_file()) if entry.get("id") != req.id]
        write_jsonl(_approved_file(), still_approved)

    return {"ok": True}
|
||||
|
||||
|
||||
# -- GET /export ---------------------------------------------------------------
|
||||
|
||||
@router.get("/export")
def get_export() -> StreamingResponse:
    """Stream approved records as SFT-ready JSONL for download.

    Each exported line is ``{"messages": [...]}``: the record's
    ``prompt_messages`` followed by one assistant message that carries the
    corrected response. Only records accepted by ``_is_exportable`` are
    included.
    """
    exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]

    def generate():
        for r in exportable:
            # A null or non-list prompt_messages must not abort the stream
            # half-way through the download; treat it as an empty prompt.
            prompt = r.get("prompt_messages")
            if not isinstance(prompt, list):
                prompt = []
            record = {
                "messages": prompt + [
                    {"role": "assistant", "content": r["corrected_response"]}
                ]
            }
            yield json.dumps(record) + "\n"

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={
            "Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
        },
    )
|
||||
|
||||
|
||||
# -- GET /stats ----------------------------------------------------------------
|
||||
|
||||
@router.get("/stats")
def get_stats() -> dict[str, object]:
    """Summarise the candidate records: totals per status, model and task type.

    Records missing one of those fields are tallied under ``"unknown"``.
    ``export_ready`` counts the approved-file records accepted by
    ``_is_exportable``.
    """
    records = _read_candidates()
    # One tally bucket per record field being summarised.
    tallies: dict[str, dict[str, int]] = {"status": {}, "model_name": {}, "task_type": {}}
    for rec in records:
        for field_name, bucket in tallies.items():
            value = rec.get(field_name, "unknown")
            bucket[value] = bucket.get(value, 0) + 1

    export_ready = len([entry for entry in read_jsonl(_approved_file()) if _is_exportable(entry)])

    return {
        "total": len(records),
        "by_status": tallies["status"],
        "by_model": tallies["model_name"],
        "by_task_type": tallies["task_type"],
        "export_ready": export_ready,
    }
|
||||
|
||||
|
||||
# -- GET /config ---------------------------------------------------------------
|
||||
|
||||
@router.get("/config")
def get_sft_config() -> dict:
    """Report the configured ``bench_results_dir`` from the ``sft`` section.

    An empty string is reported when the config file is missing, cannot be
    parsed as YAML, or does not define the setting.
    """
    cfg_path = _config_file()
    if cfg_path.exists():
        try:
            parsed = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
        except yaml.YAMLError:
            return {"bench_results_dir": ""}
        section = parsed.get("sft") or {}
        return {"bench_results_dir": section.get("bench_results_dir", "")}
    return {"bench_results_dir": ""}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
    """Write the bench_results_dir setting to the config file.

    The existing file content is preserved: only ``sft.bench_results_dir`` is
    updated, and other keys (including other keys inside ``sft``) are kept.
    Unparseable or non-mapping content is replaced by a fresh document. The
    file is written to a sibling ``.tmp`` path and then moved into place.

    Returns:
        ``{"ok": True}`` on success.
    """
    f = _config_file()
    f.parent.mkdir(parents=True, exist_ok=True)
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
    except yaml.YAMLError:
        raw = {}
    if not isinstance(raw, dict):
        # Empty file (None) or a non-mapping root cannot hold an "sft" key.
        raw = {}
    sft_section = raw.get("sft")
    if not isinstance(sft_section, dict):
        sft_section = {}
    # Merge rather than overwrite so sibling settings under "sft" survive.
    raw["sft"] = {**sft_section, "bench_results_dir": payload.bench_results_dir}
    tmp = f.with_suffix(".tmp")
    tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
    # replace() overwrites an existing target on every platform; rename() does not.
    tmp.replace(f)
    return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /ingest --------------------------------------------------------------
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
source: str # e.g. "peregrine", "kiwi"
|
||||
task_type: str # e.g. "email_classification", "recipe_suggestion"
|
||||
prompt: str # the prompt that was sent to the LLM
|
||||
response: str # the LLM's original response
|
||||
correction: str # the human-corrected response
|
||||
label: str | None = None # optional label/category
|
||||
|
||||
|
||||
@router.post("/ingest")
def post_ingest(
    req: IngestRequest,
    authorization: str | None = Header(default=None),
) -> dict:
    """Ingest a correction from a sibling CF product.

    Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>

    Creates a sft_candidates record with status='approved' (pre-approved by
    the calling product -- human review already happened upstream). Also writes
    to sft_approved.jsonl so it is immediately included in export counts.

    Returns {"ok": True, "id": "<uuid>"}.

    Raises:
        HTTPException: 503 when AVOCET_INGESTION_SECRET is not set; 401 when
            the Authorization header is missing or not a Bearer token; 403
            when the token does not match the secret.
    """
    import hmac  # local import: only needed by this endpoint

    expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
    if not expected_secret:
        raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")

    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or malformed Authorization header")

    token = authorization.removeprefix("Bearer ").strip()
    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. Compared as bytes because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), expected_secret.encode("utf-8")):
        raise HTTPException(403, "Invalid ingestion secret")

    record_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc).isoformat()
    record = {
        "id": record_id,
        "source": req.source,
        "task_type": req.task_type,
        "status": "approved",
        "prompt_messages": [{"role": "user", "content": req.prompt}],
        "model_response": req.response,
        "corrected_response": req.correction,
        "label": req.label,
        "timestamp": now,
        "benchmark_run_id": None,
    }
    append_jsonl(_candidates_file(), record)
    append_jsonl(_approved_file(), record)
    return {"ok": True, "id": record_id}
|
||||
243
app/data/fetch.py
Normal file
243
app/data/fetch.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Avocet -- IMAP fetch utilities and fetch API routes.
|
||||
|
||||
All IMAP helper functions (from app/imap_fetch.py) plus the
|
||||
/api/accounts/test and /api/fetch/stream endpoints.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import extract_body, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
|
||||
def _get_config_accounts() -> list[dict]:
    """Return the ``accounts`` list from the config file.

    Returns:
        The configured accounts; ``[]`` when the file is missing, the YAML
        root is not a mapping, or ``accounts`` is absent, null or not a list.

    Raises:
        yaml.YAMLError: if the config file is not valid YAML (unchanged from
            the previous behaviour).
    """
    f = _config_file()
    if not f.exists():
        return []
    raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    # ``accounts:`` with a null value yields None from .get(); normalise it so
    # callers can always iterate the result.
    accounts = raw.get("accounts") if isinstance(raw, dict) else None
    return accounts if isinstance(accounts, list) else []
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
    """Stable MD5 content-hash used for dedup.

    The hashed text is the subject followed by the first 100 characters of
    the body; a missing or ``None`` body counts as empty. The original note
    states this matches ``_entry_key`` in label_tool.py.
    """
    subject = e.get("subject", "")
    body_prefix = (e.get("body", "") or "")[:100]
    digest = hashlib.md5((subject + body_prefix).encode("utf-8", errors="replace"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
    """Connect, log in and select the folder of one IMAP account.

    Args:
        acc: account settings; reads ``host``, ``port`` (default 993),
            ``use_ssl`` (default True), ``username``, ``password`` and
            ``folder`` (default "INBOX").

    Returns:
        ``(ok, human_message, message_count)``. ``message_count`` is
        ``None`` whenever ``ok`` is False. Failures (bad port, network,
        login, folder that cannot be selected) are reported through the
        message instead of being raised.
    """
    host = acc.get("host", "")
    username = acc.get("username", "")
    password = acc.get("password", "")
    folder = acc.get("folder", "INBOX")
    if not host or not username or not password:
        return False, "Host, username, and password are all required.", None

    conn = None
    try:
        port = int(acc.get("port", 993))
        imap_cls = imaplib.IMAP4_SSL if acc.get("use_ssl", True) else imaplib.IMAP4
        conn = imap_cls(host, port)
        conn.login(username, password)
        status, data = conn.select(folder, readonly=True)
        if status != "OK":
            # select() reports a missing/forbidden folder through its status
            # instead of raising; without this check the failure would be
            # shown as a successful connection with 0 messages.
            detail = data[0] if data else b""
            if isinstance(detail, bytes):
                detail = detail.decode("utf-8", errors="replace")
            return False, f"Could not open folder {folder}: {detail or status}", None
        count_raw = data[0].decode() if data and data[0] else "0"
        count = int(count_raw) if count_raw.isdigit() else 0
        return True, f"Connected — {count:,} message(s) in {folder}.", count
    except Exception as exc:
        return False, str(exc), None
    finally:
        # Always release the connection, including after a failed login.
        if conn is not None:
            try:
                conn.logout()
            except Exception:
                pass
|
||||
|
||||
|
||||
def fetch_account_stream(
    acc: dict,
    days_back: int,
    limit: int,
    known_keys: set[str],
) -> Iterator[dict]:
    """Generator — yields progress dicts while fetching emails via IMAP.

    Runs one subject search per entry of ``_WIDE_TERMS`` (restricted to the
    last ``days_back`` days), fetches at most ``limit * 3`` of the matching
    messages and keeps up to ``limit`` whose ``entry_key`` is not already in
    ``known_keys``.

    Mutates `known_keys` in place for cross-account dedup within one fetch session.

    Yields event dicts with "type" key:
        {"type": "start",    "account": str, "total_uids": int}
        {"type": "progress", "account": str, "fetched": int, "total_uids": int}
        {"type": "done",     "account": str, "added": int, "skipped": int, "emails": list}

    Raises:
        KeyError: ``acc`` has no ``username`` or ``password``.
        RuntimeError: the folder could not be selected.
        Exception: connection and login errors from ``imaplib`` propagate.

    The IMAP connection is logged out when the generator finishes, fails,
    or is closed early by its consumer.
    """
    name = acc.get("name", acc.get("username", "?"))
    host = acc.get("host", "imap.gmail.com")
    port = int(acc.get("port", 993))
    use_ssl = acc.get("use_ssl", True)
    username = acc["username"]
    password = acc["password"]
    folder = acc.get("folder", "INBOX")
    since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")

    emails: list[dict] = []
    skipped = 0
    conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
    try:
        conn.login(username, password)
        status, _ = conn.select(folder, readonly=True)
        if status != "OK":
            # Without a selected folder every search below would fail and be
            # swallowed, so the caller would see a "successful" empty fetch.
            raise RuntimeError(f"Could not open folder {folder!r}")

        # dict used as an insertion-ordered set of message ids
        seen_uids: dict[bytes, None] = {}
        for term in _WIDE_TERMS:
            try:
                _, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
                for uid in (data[0] or b"").split():
                    seen_uids[uid] = None
            except Exception:
                # Best effort: one failing search term must not abort the fetch.
                pass

        # Over-fetch (3x) because duplicates and unreadable messages are dropped.
        uids = list(seen_uids)[: limit * 3]
        yield {"type": "start", "account": name, "total_uids": len(uids)}

        for i, uid in enumerate(uids):
            if len(emails) >= limit:
                break
            if i % 5 == 0:
                yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
            try:
                _, raw_data = conn.fetch(uid, "(RFC822)")
                if not raw_data or not raw_data[0]:
                    continue
                msg = _email_lib.message_from_bytes(raw_data[0][1])
                subj = _decode_str(msg.get("Subject", ""))
                from_addr = _decode_str(msg.get("From", ""))
                date = _decode_str(msg.get("Date", ""))
                body = extract_body(msg)[:800]
                entry = {"subject": subj, "body": body, "from_addr": from_addr,
                         "date": date, "account": name}
                k = entry_key(entry)
                if k not in known_keys:
                    known_keys.add(k)
                    emails.append(entry)
                else:
                    skipped += 1
            except Exception:
                # Unreadable or unparsable message: count it and move on.
                skipped += 1
    finally:
        try:
            conn.logout()
        except Exception:
            pass

    yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
           "emails": emails}
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
    """Request body for ``POST /accounts/test``."""

    # Account settings, passed unchanged to test_connection().
    account: dict
|
||||
|
||||
|
||||
@router.post("/accounts/test")
def test_account_route(req: AccountTestRequest) -> dict:
    """Try the submitted IMAP account and report the outcome as JSON."""
    outcome = test_connection(req.account)
    return {"ok": outcome[0], "message": outcome[1], "count": outcome[2]}
|
||||
|
||||
|
||||
@router.get("/fetch/stream")
def fetch_stream(
    accounts: str = Query(default=""),
    days_back: int = Query(default=90, ge=1, le=365),
    limit: int = Query(default=150, ge=1, le=1000),
    mode: str = Query(default="wide"),
) -> StreamingResponse:
    """Fetch emails for the named accounts and stream progress as SSE.

    ``accounts`` is a comma-separated list of configured account names;
    names that match no configured account are ignored. Each account's
    events are forwarded as ``data:`` frames (the "done" event without its
    ``emails`` list), new emails are added to the queue file after each
    account, a failure for one account becomes an "error" event, and a final
    "complete" event carries the totals. ``mode`` is accepted but not read
    here.
    """
    wanted = {part.strip() for part in accounts.split(",") if part.strip()}
    selected = [cfg for cfg in _get_config_accounts() if cfg.get("name") in wanted]

    def frame(payload: dict) -> str:
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        # Seed dedup with what is already queued; shared across all accounts.
        known_keys = {entry_key(row) for row in read_jsonl(_queue_file())}
        total_added = 0
        for acc in selected:
            try:
                batch_emails: list[dict] = []
                for event in fetch_account_stream(acc, days_back, limit, known_keys):
                    if event["type"] == "done":
                        # The email payload is persisted, not streamed.
                        batch_emails = event.pop("emails", [])
                        total_added += event["added"]
                    yield frame(event)
                if batch_emails:
                    write_jsonl(_queue_file(), read_jsonl(_queue_file()) + batch_emails)
            except Exception as exc:
                yield frame({"type": "error", "account": acc.get("name", "?"), "message": str(exc)})
        queue_size = len(read_jsonl(_queue_file()))
        yield frame({"type": "complete", "total_added": total_added, "queue_size": queue_size})

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(generate(), media_type="text/event-stream", headers=sse_headers)
|
||||
644
app/data/imitate.py
Normal file
644
app/data/imitate.py
Normal file
|
|
@ -0,0 +1,644 @@
|
|||
"""Avocet — Imitate tab API.
|
||||
|
||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||
pushed into the SFT corrections queue for human review.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||
|
||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
    """Testability seam: override where ``label_tool.yaml`` is looked up.

    ``None`` restores the default location under ``_ROOT / "config"``.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
    """Testability seam: override the directory holding ``sft_candidates.jsonl``."""
    global _DATA_DIR
    _DATA_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
    """Return the ``label_tool.yaml`` path, honouring ``set_config_dir()``."""
    config_dir = _CONFIG_DIR if _CONFIG_DIR is not None else _ROOT / "config"
    return config_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_imitate_config() -> dict:
    """Return the ``imitate`` section of label_tool.yaml.

    Yields ``{}`` when the file is missing, fails to parse as YAML (a
    warning is logged), or has no truthy ``imitate`` entry.
    """
    path = _config_file()
    if not path.exists():
        return {}
    try:
        parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse imitate config %s: %s", path, exc)
        return {}
    return (parsed or {}).get("imitate", {}) or {}
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
    """Return the ``cforch`` section of label_tool.yaml.

    Read for the ``ollama_url`` fallback and the ``coordinator_url``.
    Returns ``{}`` when the file is missing, is not a YAML mapping, fails to
    parse (a warning is logged, as for the imitate section), or has no
    truthy ``cforch`` entry.
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Previously swallowed silently; a broken config is worth surfacing.
        logger.warning("Failed to parse cforch config %s: %s", f, exc)
        return {}
    if not isinstance(raw, dict):
        return {}
    return raw.get("cforch", {}) or {}
|
||||
|
||||
|
||||
def _ollama_url(cfg: dict) -> str:
    """Resolve the ollama base URL.

    Precedence: ``cfg["ollama_url"]``, then the cforch section's
    ``ollama_url``, then ``http://localhost:11434``.
    """
    fallback = _load_cforch_config().get("ollama_url")
    return cfg.get("ollama_url") or fallback or "http://localhost:11434"
|
||||
|
||||
|
||||
def _cforch_url() -> str:
    """Return the configured cf-orch ``coordinator_url`` or the localhost default."""
    configured = _load_cforch_config().get("coordinator_url")
    return configured if configured else "http://localhost:7700"
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str, node_id: str = "heimdall") -> list[dict]:
    """Fetch the live cf-text catalog from cf-orch.

    Filters out proxy entries (ollama://, vllm://, http://) — those models
    are served by their own services and should not be allocated via
    cf-text. Only entries whose ``path`` contains no ``://`` are returned.

    Args:
        cforch_base: base URL of the cf-orch coordinator.
        node_id: value sent as the ``node_id`` query parameter; the default
            is the value that used to be hard-coded.

    Returns:
        One ``{"id", "vram_mb", "description"}`` dict per kept model, or
        ``[]`` when the request or its parsing fails (a warning is logged).
    """
    try:
        resp = httpx.get(
            f"{cforch_base}/api/services/cf-text/catalog",
            params={"node_id": node_id},
            timeout=5.0,
        )
        resp.raise_for_status()
        raw = resp.json()
        result = []
        for model_id, entry in raw.items():
            if not isinstance(entry, dict):
                continue
            # ``path: null`` must not abort the whole catalog with a TypeError.
            path = entry.get("path") or ""
            # Skip proxy entries — they're routed through other services
            if "://" in str(path):
                continue
            result.append({
                "id": model_id,
                "vram_mb": entry.get("vram_mb", 0),
                "description": entry.get("description", ""),
            })
        return result
    except Exception as exc:
        logger.warning("Could not fetch cf-orch catalog: %s", exc)
        return []
|
||||
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
    """GET ``url`` with an ``Accept: application/json`` header and decode the body.

    Errors from ``urlopen`` (such as ``URLError``) and from UTF-8/JSON
    decoding propagate to the caller.
    """
    request = Request(url, headers={"Accept": "application/json"})
    with urlopen(request, timeout=timeout) as resp:
        payload = resp.read()
    return json.loads(payload.decode("utf-8"))
|
||||
|
||||
|
||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
    """Report whether the product's health endpoint answers with truthy JSON.

    Any failure (network, HTTP, decoding, bad arguments) counts as offline;
    so does a response that decodes to an empty/falsy JSON value.
    """
    try:
        health_url = base_url.rstrip("/") + health_path
        return bool(_http_get_json(health_url, timeout=2))
    except Exception:
        return False
|
||||
|
||||
|
||||
def _extract_sample(
|
||||
raw: Any,
|
||||
text_fields: list[str],
|
||||
sample_index: int = 0,
|
||||
sample_key: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Pull one item from a list or dict response and extract text_fields.
|
||||
|
||||
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
||||
Falls back to a set of conventional envelope keys if sample_key is absent.
|
||||
"""
|
||||
item: dict[str, Any]
|
||||
if isinstance(raw, list):
|
||||
if not raw:
|
||||
return {}
|
||||
item = raw[min(sample_index, len(raw) - 1)]
|
||||
elif isinstance(raw, dict):
|
||||
# Use declared sample_key first, then fall back to conventional names.
|
||||
_ENVELOPE_KEYS = (
|
||||
"samples", "items", "results", "data", "jobs", "listings",
|
||||
"pantry", "saved_searches", "entries", "calls", "records",
|
||||
)
|
||||
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
||||
for key in search_keys:
|
||||
if key in raw and isinstance(raw[key], list):
|
||||
lst = raw[key]
|
||||
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
||||
break
|
||||
else:
|
||||
item = raw
|
||||
else:
|
||||
return {}
|
||||
|
||||
parts = []
|
||||
for field in text_fields:
|
||||
val = item.get(field)
|
||||
if val and str(val).strip():
|
||||
parts.append(f"**{field}**: {val}")
|
||||
return {"item": item, "text": "\n\n".join(parts)}
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
    """Return the path of ``sft_candidates.jsonl`` inside ``_DATA_DIR``."""
    return _DATA_DIR.joinpath("sft_candidates.jsonl")
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _fetch_image_b64(image_url: str) -> str:
    """Download an image URL and return it as a base64 string for ollama.

    Returns empty string on any failure — a missing image is non-fatal;
    the model will still run against the text prompt alone.

    Only ``http://`` and ``https://`` URLs are fetched. ``run_imitate``
    passes its ``image_url`` query parameter here, and ``urlopen`` would
    otherwise also accept schemes such as ``file:``, letting a request read
    local files.
    """
    if not image_url.lower().startswith(("http://", "https://")):
        logger.warning("Refusing to fetch image with unsupported URL scheme: %s", image_url)
        return ""
    try:
        req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
        with urlopen(req, timeout=10) as resp:
            return base64.b64encode(resp.read()).decode("ascii")
    except Exception as exc:
        logger.warning("Failed to fetch image %s: %s", image_url, exc)
        return ""
|
||||
|
||||
|
||||
def _run_ollama_streaming(
    ollama_base: str,
    model_id: str,
    prompt: str,
    temperature: float,
    system: str = "",
    images: list[str] | None = None,
) -> tuple[str, int]:
    """POST to ollama ``/api/generate`` with ``stream=False``.

    Despite the name this call blocks until the model has finished; the
    SSE streaming towards the browser happens in ``run_imitate()``.

    Args:
        system: optional system prompt, sent as a separate ``system`` field.
        images: base64-encoded images (vision models only); omitted when empty.

    Returns:
        ``(full_response, elapsed_ms)``; the response is "" when the reply
        has no ``response`` key.

    Raises:
        RuntimeError: wraps any failure of the request or of decoding it.
    """
    endpoint = f"{ollama_base.rstrip('/')}/api/generate"
    request_body: dict = {
        "model": model_id,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature},
    }
    if system:
        request_body["system"] = system
    if images:
        request_body["images"] = images
    req = Request(
        endpoint,
        data=json.dumps(request_body).encode("utf-8"),
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    started = time.time()
    try:
        with urlopen(req, timeout=120) as resp:
            reply = json.loads(resp.read().decode("utf-8"))
        elapsed = int((time.time() - started) * 1000)
        return reply.get("response", ""), elapsed
    except Exception as exc:
        raise RuntimeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _run_cftext(
    cforch_base: str,
    model_id: str,
    prompt: str,
    system: str,
    temperature: float,
    startup_timeout_s: float = 180.0,
    user_id: str | None = None,
) -> tuple[str, int, bool]:
    """Allocate cf-text via cf-orch, generate, release.

    Returns:
        ``(response, elapsed_ms, cold_started)``. ``cold_started=True`` means
        the allocation reported the service as started and not warm (caller
        may log this).

    Raises:
        RuntimeError: the service stopped while loading, the cold start
            timed out, or generation failed.
        Exception: errors from the allocation request itself (HTTP status,
            transport, missing ``url`` in the reply) propagate unchanged.

    Cold-start detection uses coordinator state signals (running/stopped)
    rather than polling the service health endpoint — this fails fast on
    model load errors instead of waiting out the full timeout.

    The allocation is released exactly once, in a ``finally`` that spans both
    the readiness wait and the generation, so no exit path can leak it.
    """
    # Allocate
    alloc_resp = httpx.post(
        f"{cforch_base}/api/services/cf-text/allocate",
        json={
            "model_candidates": [model_id],
            "caller": "avocet",
            "pipeline": "imitate",
            **({"user_id": user_id} if user_id else {}),
        },
        timeout=30.0,
    )
    alloc_resp.raise_for_status()
    data = alloc_resp.json()
    allocation_id: str = data.get("allocation_id", "")

    try:
        service_url: str = data["url"]
        node_id: str = data.get("node_id", "")
        gpu_id: int | None = data.get("gpu_id")
        cold_started = data.get("started", False) and not data.get("warm", True)

        # Wait for ready using coordinator state signals
        if cold_started:
            deadline = time.monotonic() + startup_timeout_s
            probe_misses = 0
            while time.monotonic() < deadline:
                try:
                    status = httpx.get(
                        f"{cforch_base}/api/services/cf-text/status", timeout=5.0
                    )
                    if status.is_success:
                        instances = status.json().get("instances", [])
                        match = next(
                            (i for i in instances
                             if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
                            None,
                        )
                        if match:
                            probe_misses = 0
                            state = match.get("state", "")
                            if state == "running":
                                break
                            elif state == "stopped":
                                raise RuntimeError(
                                    f"cf-text failed to load {model_id!r} (service stopped)"
                                )
                        else:
                            probe_misses += 1
                            if probe_misses >= 6:
                                # Coordinator hasn't registered instance yet — fall back to health poll
                                try:
                                    if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
                                        break
                                except Exception:
                                    pass
                except RuntimeError:
                    raise
                except Exception:
                    # Transient probe failure: keep polling until the deadline.
                    pass
                time.sleep(2.0)
            else:
                raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")

        # Generate
        messages: list[dict] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        t0 = time.time()
        try:
            gen_resp = httpx.post(
                f"{service_url}/v1/chat/completions",
                json={
                    "model": model_id,
                    "messages": messages,
                    "max_tokens": 300,
                    "temperature": temperature,
                    "stream": False,
                },
                timeout=120.0,
            )
            gen_resp.raise_for_status()
            elapsed_ms = int((time.time() - t0) * 1000)
            content = gen_resp.json()["choices"][0]["message"]["content"]
            return content.strip(), elapsed_ms, cold_started
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc
    finally:
        # Best-effort release; a failing release must not mask the real outcome.
        if allocation_id:
            try:
                httpx.delete(
                    f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
                    timeout=5.0,
                )
            except Exception:
                pass
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/products")
def get_products() -> dict:
    """List configured CF products with live online status.

    Non-dict entries of the ``products`` config list are skipped. A product
    without ``base_url`` is reported offline without probing it.
    """
    configured = _load_imitate_config().get("products", []) or []
    listing = []
    for entry in configured:
        if not isinstance(entry, dict):
            continue
        base_url = entry.get("base_url", "")
        if base_url:
            online = _is_online(base_url, entry.get("health_path", "/api/health"))
        else:
            online = False
        listing.append({
            "id": entry.get("id", ""),
            "name": entry.get("name", ""),
            "icon": entry.get("icon", "📦"),
            "description": entry.get("description", ""),
            "base_url": base_url,
            "online": online,
        })
    return {"products": listing}
|
||||
|
||||
|
||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||
|
||||
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
    """Fetch a real sample from the given product's API.

    Looks the product up in the imitate config, GETs
    ``base_url + sample_endpoint``, extracts one item with
    ``_extract_sample`` and fills the product's ``prompt_template``
    (``{text}`` plus any ``{field_name}`` taken from the item).

    Raises:
        HTTPException: 404 for an unknown product or when no sample could be
            extracted; 422 when the product lacks ``base_url`` or
            ``sample_endpoint``; 503 when the product API is unreachable
            (``URLError``); 502 for any other fetch/decoding failure.
    """
    cfg = _load_imitate_config()
    products_raw = cfg.get("products", []) or []

    product: dict | None = next(
        (p for p in products_raw if isinstance(p, dict) and p.get("id") == product_id),
        None,
    )
    if product is None:
        raise HTTPException(404, f"Product '{product_id}' not in config")

    # ``or ""`` also covers keys that are present but null in the YAML.
    base_url = (product.get("base_url") or "").rstrip("/")
    endpoint = product.get("sample_endpoint") or ""
    if not base_url or not endpoint:
        raise HTTPException(422, "Product missing base_url or sample_endpoint")

    url = f"{base_url}{endpoint}"
    try:
        raw = _http_get_json(url, timeout=5)
    except URLError as exc:
        raise HTTPException(503, f"Product API unreachable: {exc}") from exc
    except Exception as exc:
        raise HTTPException(502, f"Bad response from product API: {exc}") from exc

    text_fields = product.get("text_fields", []) or []
    sample_key = product.get("sample_key") or None
    extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
    if not extracted:
        raise HTTPException(404, "No sample items returned by product API")

    prompt_template = product.get("prompt_template")
    if prompt_template is None:
        prompt_template = "{text}"
    prompt = str(prompt_template).replace("{text}", extracted["text"])

    item = extracted.get("item", {})
    # Guard before touching .items()/.get(): the item may not be a dict.
    fields = item if isinstance(item, dict) else {}

    # Also substitute any {field_name} placeholders from the raw item fields.
    for field, val in fields.items():
        prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")

    # Expose system_prompt and image_url if the product API returns them.
    #   system_prompt: Peregrine, Snipe (vision analysis instructions)
    #   image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
    # ``or ""`` keeps a JSON null from turning into the literal string "None".
    system_prompt = str(fields.get("system_prompt") or "")
    image_url = str(fields.get("image_url") or "")

    return {
        "product_id": product_id,
        "sample_index": index,
        "text": extracted["text"],
        "prompt": prompt,
        "system_prompt": system_prompt,
        "image_url": image_url,
        "raw_item": item,
    }
|
||||
|
||||
|
||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
def get_catalog() -> dict:
    """Return the live cf-text model catalog from the cf-orch coordinator.

    The list is empty when the coordinator cannot be reached.
    """
    return {"models": _cforch_catalog(_cforch_url())}
|
||||
|
||||
|
||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
||||
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
||||
try:
|
||||
from app.cloud_session import get_session
|
||||
return get_session(request, response)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/run")
def run_imitate(
    prompt: str = "",
    model_ids: str = "",            # comma-separated ollama model IDs
    cf_text_model_ids: str = "",    # comma-separated cf-text model IDs (via cf-orch)
    temperature: float = 0.7,
    product_id: str = "",
    system: str = "",               # optional system prompt
    image_url: str = "",            # optional image URL for vision models
    session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
    """Run a prompt through the selected models and stream results as SSE.

    Ollama models run one after another; cf-text models are submitted to a
    thread pool together and reported in completion order.

    If image_url is provided, the image is downloaded once, before
    streaming starts, and passed to every ollama model as a base64-encoded
    blob.

    Events: ``start``, then per model ``model_start`` and ``model_done``
    (``model_coldstart`` is sent just before ``model_done`` for a cf-text
    model that was cold started), and finally ``complete`` with all results.

    Raises:
        HTTPException: 422 when the prompt is blank or no model is selected.
    """
    if not prompt.strip():
        raise HTTPException(422, "prompt is required")

    ollama_ids = [part.strip() for part in model_ids.split(",") if part.strip()]
    cftext_ids = [part.strip() for part in cf_text_model_ids.split(",") if part.strip()]
    if not (ollama_ids or cftext_ids):
        raise HTTPException(422, "model_ids or cf_text_model_ids is required")

    ollama_base = _ollama_url(_load_imitate_config())
    cforch_base = _cforch_url()
    system_ctx = system.strip()
    total_models = len(ollama_ids) + len(cftext_ids)

    # Download image once before streaming — shared across ollama vision models
    images: list[str] = []
    if image_url.strip():
        encoded = _fetch_image_b64(image_url.strip())
        if encoded:
            images = [encoded]

    def generate():
        results: list[dict] = []

        def finished(model, response="", elapsed_ms=0, error=None) -> str:
            # Record the outcome and build its "model_done" frame.
            record = {
                "model": model,
                "response": response,
                "elapsed_ms": elapsed_ms,
                "error": error,
            }
            results.append(record)
            return _sse({"type": "model_done", **record})

        yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})

        # Ollama models
        for model_id in ollama_ids:
            yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
            try:
                response, elapsed_ms = _run_ollama_streaming(
                    ollama_base, model_id, prompt, temperature,
                    system=system_ctx, images=images or None,
                )
            except Exception as exc:
                frame = finished(model_id, error=str(exc))
            else:
                frame = finished(model_id, response, elapsed_ms)
            yield frame

        # cf-text models via cf-orch — fan out in parallel when multiple models selected
        if cftext_ids:
            from concurrent.futures import ThreadPoolExecutor, as_completed

            # Announce all models upfront so the UI can show loading states immediately
            for model_id in cftext_ids:
                yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})

            _user_id: str | None = getattr(session, "user_id", None)
            # Only forward real cloud user IDs — skip local/anon sessions
            if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
                _user_id = None

            with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
                future_to_model = {
                    pool.submit(
                        _run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
                        180.0, _user_id,
                    ): mid
                    for mid in cftext_ids
                }
                for future in as_completed(future_to_model):
                    model_id = future_to_model[future]
                    try:
                        response, elapsed_ms, cold_started = future.result()
                    except Exception as exc:
                        yield finished(model_id, error=str(exc))
                        continue
                    if cold_started:
                        yield _sse({"type": "model_coldstart", "model": model_id})
                    yield finished(model_id, response, elapsed_ms)

        yield _sse({"type": "complete", "results": results})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
|
||||
class ImitateResult(BaseModel):
    """One model's outcome as submitted to ``POST /push-corrections``."""

    # Identifier of the model that produced the response.
    model: str
    # Generated text.
    response: str
    # Generation time in milliseconds.
    elapsed_ms: int
    # Error message; results with a truthy error are not pushed.
    error: str | None = None
|
||||
|
||||
|
||||
class PushCorrectionsRequest(BaseModel):
    """Request body for ``POST /push-corrections``."""

    # Product the prompt was sampled from; copied onto every stored record.
    product_id: str
    # User prompt, stored as the single user message of each record.
    prompt: str
    # Per-model outcomes to queue for review.
    results: list[ImitateResult]
|
||||
|
||||
|
||||
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
    """Append imitate results to sft_candidates.jsonl for human review.

    Results that carry an error or have a blank response are dropped. Every
    remaining result becomes one "pending" record with a fresh UUID and a
    shared UTC timestamp.

    Returns:
        ``{"pushed": <number of records written>}``.

    Raises:
        HTTPException: 422 when the prompt is blank, the results list is
            empty, or no result is usable.
    """
    if not req.prompt.strip():
        raise HTTPException(422, "prompt is required")
    if not req.results:
        raise HTTPException(422, "results list is empty")

    ts = datetime.now(timezone.utc).isoformat()
    usable = [r for r in req.results if not r.error and r.response.strip()]
    if not usable:
        raise HTTPException(422, "No non-error results to push")

    records = [
        {
            "id": str(uuid.uuid4()),
            "source": "imitate",
            "product_id": req.product_id,
            "prompt_messages": [{"role": "user", "content": req.prompt}],
            "model_response": r.response,
            "model_id": r.model,
            "elapsed_ms": r.elapsed_ms,
            "status": "pending",
            "created_at": ts,
        }
        for r in usable
    ]

    dest = _candidates_file()
    dest.parent.mkdir(parents=True, exist_ok=True)
    for record in records:
        append_jsonl(dest, record)

    return {"pushed": len(records)}
|
||||
222
app/data/label.py
Normal file
222
app/data/label.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
"""Avocet -- label queue API.
|
||||
|
||||
All label/skip/discard/undo/stats/config endpoints.
|
||||
Extracted from app/api.py as part of the v2 domain split.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_last_action: dict | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
    """Redirect the queue, score and discard files to another directory."""
    global _DATA_DIR
    _DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
    """Set the directory holding ``label_tool.yaml``; ``None`` restores the default."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
def reset_last_action() -> None:
    """Clear the module-level ``_last_action`` record by setting it to ``None``."""
    global _last_action
    _last_action = None
|
||||
|
||||
def _config_file() -> Path:
    """Locate ``label_tool.yaml``: the override directory if set, else ``_ROOT/config``."""
    directory = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return directory / "label_tool.yaml"
|
||||
|
||||
def _queue_file() -> Path:
    """Path of the JSONL queue of emails waiting for a label."""
    return _DATA_DIR.joinpath("email_label_queue.jsonl")
|
||||
|
||||
def _score_file() -> Path:
    """Path of ``email_score.jsonl``, the file labeled records are appended to."""
    return _DATA_DIR.joinpath("email_score.jsonl")
|
||||
|
||||
def _discarded_file() -> Path:
    """Path of ``discarded.jsonl`` inside ``_DATA_DIR``."""
    return _DATA_DIR.joinpath("discarded.jsonl")
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
def _normalize(item: dict) -> dict:
    """Project a raw queue record onto the fixed shape returned by the API.

    ``id`` falls back to the content hash, ``from`` to ``from_addr`` and
    ``source`` to ``account`` when the primary key is missing or falsy.
    """
    item_id = item.get("id") or _item_id(item)
    sender = item.get("from") or item.get("from_addr", "")
    origin = item.get("source") or item.get("account", "")
    normalized = {
        "id": item_id,
        "subject": item.get("subject", ""),
        "body": item.get("body", ""),
        "from": sender,
        "date": item.get("date", ""),
        "source": origin,
    }
    return normalized
|
||||
|
||||
_LABEL_META = [
|
||||
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
|
||||
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
|
||||
{"name": "rejected", "emoji": "❌", "color": "#F44336", "key": "3"},
|
||||
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
|
||||
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
|
||||
{"name": "neutral", "emoji": "⬜", "color": "#607D8B", "key": "6"},
|
||||
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
|
||||
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
|
||||
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
|
||||
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
|
||||
]
|
||||
|
||||
@router.get("/queue")
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
    """Return the first ``limit`` queue items (normalized) and the queue size."""
    queued = read_jsonl(_queue_file())
    page = list(map(_normalize, queued[:limit]))
    return {"items": page, "total": len(queued)}
|
||||
|
||||
# Request body for POST /label.
class LabelRequest(BaseModel):
    id: str  # matched against the normalized id of each queue item
    label: str  # label value stored on the scored record
|
||||
|
||||
@router.post("/label")
def post_label(req: LabelRequest):
    """Label one queue item.

    The matching item is appended to the score file with ``label`` and a
    UTC ``labeled_at`` timestamp, removed from the queue, and remembered as
    the last action. Responds 404 when no queue item has the given id.
    """
    global _last_action
    queue = read_jsonl(_queue_file())
    ids = [_normalize(entry)["id"] for entry in queue]
    target = next((entry for entry, eid in zip(queue, ids) if eid == req.id), None)
    if not target:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")

    stamped = {
        **target,
        "label": req.label,
        "labeled_at": datetime.now(timezone.utc).isoformat(),
    }
    append_jsonl(_score_file(), stamped)

    remaining = [entry for entry, eid in zip(queue, ids) if eid != req.id]
    write_jsonl(_queue_file(), remaining)

    _last_action = {"type": "label", "item": target, "label": req.label}
    return {"ok": True}
|
||||
|
||||
# Request body for POST /skip.
class SkipRequest(BaseModel):
    id: str  # matched against the normalized id of each queue item
|
||||
|
||||
@router.post("/skip")
def post_skip(req: SkipRequest):
    """Move one queue item to the end of the queue.

    The skipped item is remembered as the last action. Responds 404 when no
    queue item has the given id.
    """
    global _last_action
    queue = read_jsonl(_queue_file())
    ids = [_normalize(entry)["id"] for entry in queue]
    target = next((entry for entry, eid in zip(queue, ids) if eid == req.id), None)
    if not target:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")

    others = [entry for entry, eid in zip(queue, ids) if eid != req.id]
    write_jsonl(_queue_file(), others + [target])

    _last_action = {"type": "skip", "item": target}
    return {"ok": True}
|
||||
|
||||
# Request body for POST /discard.
class DiscardRequest(BaseModel):
    id: str  # matched against the normalized id of each queue item
|
||||
|
||||
@router.post("/discard")
def post_discard(req: DiscardRequest):
    """Discard one queue item.

    The matching item is appended to the discarded file with the label
    ``__discarded__`` and a UTC ``discarded_at`` timestamp, removed from the
    queue, and remembered as the last action. Responds 404 when no queue
    item has the given id.
    """
    global _last_action
    queue = read_jsonl(_queue_file())
    ids = [_normalize(entry)["id"] for entry in queue]
    target = next((entry for entry, eid in zip(queue, ids) if eid == req.id), None)
    if not target:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")

    tombstone = {
        **target,
        "label": "__discarded__",
        "discarded_at": datetime.now(timezone.utc).isoformat(),
    }
    append_jsonl(_discarded_file(), tombstone)

    remaining = [entry for entry, eid in zip(queue, ids) if eid != req.id]
    write_jsonl(_queue_file(), remaining)

    _last_action = {"type": "discard", "item": target}
    return {"ok": True}
|
||||
|
||||
@router.delete("/label/undo")
def delete_undo():
    """Undo the single remembered action (label, discard or skip).

    * label / discard: drop the last record of the corresponding output
      file and put the item back at the front of the queue (409 if that
      file is empty).
    * skip: move the item back to the front of the queue.

    Responds 404 when there is nothing to undo. The remembered action is
    cleared once the undo has been applied.
    """
    global _last_action
    if not _last_action:
        raise HTTPException(404, "No action to undo")

    action = _last_action
    item = action["item"]
    kind = action["type"]

    if kind in ("label", "discard"):
        if kind == "label":
            log_path = _score_file()
            empty_msg = "Score file is empty -- cannot undo label"
        else:
            log_path = _discarded_file()
            empty_msg = "Discarded file is empty -- cannot undo discard"
        logged = read_jsonl(log_path)
        if not logged:
            raise HTTPException(409, empty_msg)
        # The action being undone is assumed to be the last record written.
        write_jsonl(log_path, logged[:-1])
        queue = read_jsonl(_queue_file())
        write_jsonl(_queue_file(), [item] + queue)
    elif kind == "skip":
        queue = read_jsonl(_queue_file())
        target_id = _normalize(item)["id"]
        others = [entry for entry in queue if _normalize(entry)["id"] != target_id]
        write_jsonl(_queue_file(), [item] + others)

    _last_action = None
    return {"undone": {"type": kind, "item": _normalize(item)}}
|
||||
|
||||
@router.get("/config/labels")
def get_labels():
    """Return the static label metadata list (``_LABEL_META``) unchanged."""
    return _LABEL_META
|
||||
|
||||
@router.get("/config")
def get_config():
    """Return ``accounts`` and ``max_per_account`` from ``label_tool.yaml``.

    Defaults (no accounts, 500 per account) are used when the file or a key
    is missing.
    """
    cfg_path = _config_file()
    loaded = {}
    if cfg_path.exists():
        loaded = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
    return {
        "accounts": loaded.get("accounts", []),
        "max_per_account": loaded.get("max_per_account", 500),
    }
|
||||
|
||||
# Request body for POST /config.
class ConfigPayload(BaseModel):
    accounts: list[dict]  # written under the "accounts" key of label_tool.yaml
    max_per_account: int = 500  # same default that GET /config falls back to
|
||||
|
||||
@router.post("/config")
def post_config(payload: ConfigPayload):
    """Persist ``accounts`` / ``max_per_account`` to ``label_tool.yaml``.

    The payload keys are merged over whatever mapping is already stored in
    the file, so other top-level keys in the same YAML file are preserved
    instead of being dropped on every save. An unparseable or non-mapping
    file is replaced by the payload alone.

    The file is written to a sibling ``.tmp`` path and then moved into
    place with ``Path.replace``, which overwrites an existing target on
    every platform (``Path.rename`` raises on Windows if the target exists).
    """
    cfg_path = _config_file()
    cfg_path.parent.mkdir(parents=True, exist_ok=True)

    existing: dict = {}
    if cfg_path.exists():
        try:
            loaded = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        except yaml.YAMLError:
            loaded = None  # corrupt file: fall back to writing the payload only
        if isinstance(loaded, dict):
            existing = loaded

    # Existing keys keep their position; payload values win on conflict.
    merged = {**existing, **payload.model_dump()}

    tmp = cfg_path.with_suffix(".tmp")
    tmp.write_text(
        yaml.dump(merged, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )
    tmp.replace(cfg_path)
    return {"ok": True}
|
||||
|
||||
@router.get("/stats")
def get_stats():
    """Summarise the score file.

    Returns the number of scored records, a per-label count (records with
    an empty or missing label are not counted), the score file size in
    bytes, and the parsed ``benchmark_results.json`` if it can be read
    (otherwise an empty dict).
    """
    records = read_jsonl(_score_file())

    counts: dict[str, int] = {}
    for label in (rec.get("label", "") for rec in records):
        if not label:
            continue
        counts[label] = counts.get(label, 0) + 1

    benchmark_results: dict = {}
    benchmark_path = _DATA_DIR / "benchmark_results.json"
    if benchmark_path.exists():
        try:
            benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
        except Exception:
            # Best effort: unreadable benchmark results are reported as empty.
            pass

    score_path = _score_file()
    score_bytes = score_path.stat().st_size if score_path.exists() else 0
    return {
        "total": len(records),
        "counts": counts,
        "score_file_bytes": score_bytes,
        "benchmark_results": benchmark_results,
    }
|
||||
|
||||
@router.get("/stats/download")
def download_stats():
    """Serve the score file as an attachment, or 404 when it does not exist."""
    score_path = _score_file()
    if score_path.exists():
        return FileResponse(
            str(score_path),
            filename="email_score.jsonl",
            media_type="application/jsonlines",
            headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
        )
    raise HTTPException(404, "No score file")
|
||||
0
app/eval/__init__.py
Normal file
0
app/eval/__init__.py
Normal file
38
app/eval/cforch.py
Normal file
38
app/eval/cforch.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Avocet -- eval router aggregator.
|
||||
|
||||
Collects benchmark sub-routers into a single importable `router`
|
||||
for the api.py factory. Each sub-router retains its established prefix
|
||||
so no frontend URL changes are needed.
|
||||
|
||||
Route prefixes when mounted at /api in api.py:
|
||||
/api/cforch/* -- cf-orch benchmark routes
|
||||
/api/style/* -- writing style benchmark routes
|
||||
/api/voice/* -- voice benchmark routes
|
||||
/api/plans-bench/* -- plans benchmark routes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.cforch import router as _cforch_router
|
||||
from app.style import router as _style_router
|
||||
from app.voice import router as _voice_router
|
||||
from app.plans_bench import router as _plans_router
|
||||
|
||||
# Aggregate router: every benchmark sub-router is mounted under its own prefix.
router = APIRouter()
router.include_router(_cforch_router, prefix="/cforch")
router.include_router(_style_router, prefix="/style")
router.include_router(_voice_router, prefix="/voice")
router.include_router(_plans_router, prefix="/plans-bench")
|
||||
|
||||
|
||||
def set_config_dir(path) -> None:
    """Forward a config-directory override to every benchmark sub-module (test hook)."""
    import app.cforch as _cforch_mod
    import app.style as _style_mod
    import app.voice as _voice_mod
    import app.plans_bench as _plans_mod

    # Same call order as the router registration: cforch, style, voice, plans.
    for sub_module in (_cforch_mod, _style_mod, _voice_mod, _plans_mod):
        sub_module.set_config_dir(path)
|
||||
|
|
@ -1,158 +1,9 @@
|
|||
"""Avocet — IMAP fetch utilities.
|
||||
|
||||
Shared between app/api.py (FastAPI SSE endpoint) and the label UI.
|
||||
No Streamlit imports here — stdlib + imaplib only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from typing import Any, Iterator
|
||||
|
||||
from app.utils import extract_body, strip_html # noqa: F401 (strip_html re-exported for callers)
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
    """Decode an RFC 2047 encoded header value into plain text.

    Returns ``""`` for ``None`` or an empty value. Decoded fragments are
    joined with single spaces and the result is stripped.

    An encoded word that declares a charset Python does not know makes
    ``bytes.decode`` raise ``LookupError``; such fragments are decoded as
    UTF-8 with replacement characters instead of failing the whole header.
    """
    if not value:
        return ""
    out = []
    for part, enc in _raw_decode(value):
        if isinstance(part, bytes):
            try:
                out.append(part.decode(enc or "utf-8", errors="replace"))
            except LookupError:
                # Unknown charset label in the header: fall back to UTF-8.
                out.append(part.decode("utf-8", errors="replace"))
        else:
            out.append(str(part))
    return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
    """Stable MD5 content-hash for dedup — matches label_tool.py _entry_key.

    The hashed key is the subject plus the first 100 characters of the
    body. A missing or ``None`` subject/body counts as the empty string, so
    an entry with a null subject no longer raises ``TypeError``; hashes of
    entries with string fields are unchanged.
    """
    subject = e.get("subject") or ""
    body = e.get("body") or ""
    key = subject + body[:100]
    return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
    """Connect, login, select folder. Returns (ok, human_message, message_count|None).

    ``acc`` supplies ``host``, ``port`` (default 993), ``use_ssl`` (default
    True), ``username``, ``password`` and ``folder`` (default "INBOX").

    Failures are reported through the returned tuple rather than raised:
    missing credentials, any exception from imaplib, and a non-"OK" status
    from ``select`` (imaplib returns that status instead of raising, e.g.
    for a folder that does not exist). The connection is always logged out
    once it has been opened.
    """
    host = acc.get("host", "")
    port = int(acc.get("port", 993))
    use_ssl = acc.get("use_ssl", True)
    username = acc.get("username", "")
    password = acc.get("password", "")
    folder = acc.get("folder", "INBOX")
    if not host or not username or not password:
        return False, "Host, username, and password are all required.", None

    conn = None
    try:
        conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
        conn.login(username, password)
        status, data = conn.select(folder, readonly=True)
        if status != "OK":
            detail = ""
            if data and isinstance(data[0], bytes):
                detail = data[0].decode("utf-8", errors="replace")
            message = f"Could not open folder {folder!r}"
            if detail:
                message += f": {detail}"
            return False, message, None
        count_raw = data[0].decode() if data and data[0] else "0"
        count = int(count_raw) if count_raw.isdigit() else 0
        return True, f"Connected — {count:,} message(s) in {folder}.", count
    except Exception as exc:
        return False, str(exc), None
    finally:
        # Release the socket on every path, including login/select failures.
        if conn is not None:
            try:
                conn.logout()
            except Exception:
                pass
|
||||
|
||||
|
||||
def fetch_account_stream(
    acc: dict,
    days_back: int,
    limit: int,
    known_keys: set[str],
) -> Iterator[dict]:
    """Generator — yields progress dicts while fetching emails via IMAP.

    Mutates `known_keys` in place for cross-account dedup within one fetch session.

    Yields event dicts with "type" key:
        {"type": "start",    "account": str, "total_uids": int}
        {"type": "progress", "account": str, "fetched": int, "total_uids": int}
        {"type": "done",     "account": str, "added": int, "skipped": int, "emails": list}

    The IMAP connection is logged out on every exit path: normal
    completion (before the "done" event is yielded), an exception from
    login/select, or the consumer closing the generator early.
    """
    name = acc.get("name", acc.get("username", "?"))
    host = acc.get("host", "imap.gmail.com")
    port = int(acc.get("port", 993))
    use_ssl = acc.get("use_ssl", True)
    username = acc["username"]
    password = acc["password"]
    folder = acc.get("folder", "INBOX")
    since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")

    emails: list[dict] = []
    skipped = 0

    conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
    try:
        conn.login(username, password)
        conn.select(folder, readonly=True)

        # A dict is used as an insertion-ordered set of message ids.
        seen_uids: dict[bytes, None] = {}
        for term in _WIDE_TERMS:
            try:
                _, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
                for uid in (data[0] or b"").split():
                    seen_uids[uid] = None
            except Exception:
                # Best effort: a failing search term must not abort the fetch.
                pass

        # Over-fetch candidates (3x limit) because duplicates are skipped below.
        uids = list(seen_uids.keys())[: limit * 3]
        yield {"type": "start", "account": name, "total_uids": len(uids)}

        for i, uid in enumerate(uids):
            if len(emails) >= limit:
                break
            if i % 5 == 0:
                yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
            try:
                _, raw_data = conn.fetch(uid, "(RFC822)")
                if not raw_data or not raw_data[0]:
                    continue
                msg = _email_lib.message_from_bytes(raw_data[0][1])
                subj = _decode_str(msg.get("Subject", ""))
                from_addr = _decode_str(msg.get("From", ""))
                date = _decode_str(msg.get("Date", ""))
                body = extract_body(msg)[:800]
                entry = {"subject": subj, "body": body, "from_addr": from_addr,
                         "date": date, "account": name}
                k = entry_key(entry)
                if k not in known_keys:
                    known_keys.add(k)
                    emails.append(entry)
                else:
                    skipped += 1
            except Exception:
                # Unfetchable or unparseable messages are counted as skipped.
                skipped += 1
    finally:
        try:
            conn.logout()
        except Exception:
            pass

    yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
           "emails": emails}
|
||||
"""Backward-compat shim -- logic moved to app/data/fetch.py."""
|
||||
import imaplib # noqa: F401 -- re-exported so existing patch("app.imap_fetch.imaplib...") calls still work
|
||||
from app.data.fetch import ( # noqa: F401
|
||||
entry_key,
|
||||
fetch_account_stream,
|
||||
test_connection,
|
||||
_decode_str,
|
||||
_WIDE_TERMS,
|
||||
)
|
||||
|
|
|
|||
647
app/imitate.py
647
app/imitate.py
|
|
@ -1,644 +1,3 @@
|
|||
"""Avocet — Imitate tab API.
|
||||
|
||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||
pushed into the SFT corrections queue for human review.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||
|
||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
    """Rebind the module-level ``_CONFIG_DIR`` override (``None`` clears it)."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
    """Rebind the module-level ``_DATA_DIR`` used to locate the SFT candidates file."""
    global _DATA_DIR
    _DATA_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
    """Return the path of ``label_tool.yaml``.

    The overridden config directory wins when set; otherwise the file is
    looked up under ``_ROOT / "config"``.
    """
    config_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return config_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_imitate_config() -> dict:
    """Return the ``imitate`` section of label_tool.yaml.

    An empty dict is returned when the file is missing, cannot be parsed
    as YAML (a warning is logged), or has no ``imitate`` section.
    """
    path = _config_file()
    if path.exists():
        try:
            parsed = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse imitate config %s: %s", path, exc)
        else:
            return parsed.get("imitate", {}) or {}
    return {}
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
    """Read cforch section for ollama_url fallback.

    Returns the ``cforch`` section of label_tool.yaml, or an empty dict
    when the file is missing, has no such section, or cannot be parsed.
    Parse failures are logged as warnings (matching ``_load_imitate_config``)
    instead of being swallowed silently.
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse cforch config %s: %s", f, exc)
        return {}
    return raw.get("cforch", {}) or {}
|
||||
|
||||
|
||||
def _ollama_url(cfg: dict) -> str:
    """Resolve the ollama base URL.

    Precedence: ``cfg["ollama_url"]``, then the cforch section's
    ``ollama_url``, then ``http://localhost:11434``.
    """
    fallback = _load_cforch_config().get("ollama_url")
    return cfg.get("ollama_url") or fallback or "http://localhost:11434"
|
||||
|
||||
|
||||
def _cforch_url() -> str:
    """Return the configured cf-orch ``coordinator_url``, defaulting to localhost:7700."""
    return _load_cforch_config().get("coordinator_url") or "http://localhost:7700"
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
    """Fetch the live cf-text catalog from cf-orch.

    Entries whose ``path`` contains ``://`` (ollama://, vllm://, http://)
    are proxies served by other services and are left out; only models with
    plain paths are returned, each as ``{"id", "vram_mb", "description"}``.

    Any failure (network, HTTP status, unexpected payload) is logged and
    yields an empty list.
    """
    try:
        resp = httpx.get(
            f"{cforch_base}/api/services/cf-text/catalog",
            # NOTE(review): the node id is hard-coded here.
            params={"node_id": "heimdall"},
            timeout=5.0,
        )
        resp.raise_for_status()
        return [
            {
                "id": model_id,
                "vram_mb": entry.get("vram_mb", 0),
                "description": entry.get("description", ""),
            }
            for model_id, entry in resp.json().items()
            if isinstance(entry, dict) and "://" not in entry.get("path", "")
        ]
    except Exception as exc:
        logger.warning("Could not fetch cf-orch catalog: %s", exc)
        return []
|
||||
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
    """GET ``url`` with an ``Accept: application/json`` header and parse the body.

    Errors from ``urlopen`` (e.g. ``URLError``) and from JSON decoding
    propagate to the caller.
    """
    request = Request(url, headers={"Accept": "application/json"})
    with urlopen(request, timeout=timeout) as resp:
        payload = resp.read().decode("utf-8")
        return json.loads(payload)
|
||||
|
||||
|
||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
    """Report whether the product's health endpoint answers with truthy JSON.

    Any error (unreachable host, bad JSON, malformed URL) counts as offline.
    """
    try:
        probe_url = base_url.rstrip("/") + health_path
        return bool(_http_get_json(probe_url, timeout=2))
    except Exception:
        return False
|
||||
|
||||
|
||||
def _extract_sample(
    raw: Any,
    text_fields: list[str],
    sample_index: int = 0,
    sample_key: str | None = None,
) -> dict[str, Any]:
    """Pull one item from a list or dict response and extract text_fields.

    sample_key: if provided, unwrap raw[sample_key] before looking for a list.
    Falls back to a set of conventional envelope keys if sample_key is absent.
    A dict without any list-valued envelope key is itself treated as the item.

    sample_index is clamped into ``[0, len - 1]``, so an out-of-range or
    negative index can no longer raise IndexError or count from the end.

    Returns ``{"item": <dict>, "text": <markdown-ish string>}``, or ``{}``
    when ``raw`` is an empty list, is neither list nor dict, or the selected
    element is not a dict (previously an AttributeError).
    """
    def _pick(seq: list) -> Any:
        # Clamp on both sides; callers pass the index straight from a query param.
        return seq[max(0, min(sample_index, len(seq) - 1))]

    item: Any
    if isinstance(raw, list):
        if not raw:
            return {}
        item = _pick(raw)
    elif isinstance(raw, dict):
        # Use declared sample_key first, then fall back to conventional names.
        _ENVELOPE_KEYS = (
            "samples", "items", "results", "data", "jobs", "listings",
            "pantry", "saved_searches", "entries", "calls", "records",
        )
        search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
        for key in search_keys:
            if key in raw and isinstance(raw[key], list):
                lst = raw[key]
                item = _pick(lst) if lst else {}
                break
        else:
            item = raw
    else:
        return {}

    if not isinstance(item, dict):
        return {}

    parts = []
    for field in text_fields:
        val = item.get(field)
        if val and str(val).strip():
            parts.append(f"**{field}**: {val}")
    return {"item": item, "text": "\n\n".join(parts)}
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
    """Return the path of ``sft_candidates.jsonl`` inside the data directory."""
    return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
    """Serialise ``data`` as one server-sent-event frame: ``data: <json>`` plus a blank line."""
    encoded = json.dumps(data)
    return "data: " + encoded + "\n\n"
|
||||
|
||||
|
||||
def _fetch_image_b64(image_url: str) -> str:
    """Download ``image_url`` and return its bytes base64-encoded (ASCII text).

    Failures are non-fatal: they are logged as a warning and an empty
    string is returned, so the caller can still run the text prompt alone.
    """
    try:
        request = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
        with urlopen(request, timeout=10) as resp:
            encoded = base64.b64encode(resp.read()).decode("ascii")
    except Exception as exc:
        logger.warning("Failed to fetch image %s: %s", image_url, exc)
        return ""
    return encoded
|
||||
|
||||
|
||||
def _run_ollama_streaming(
    ollama_base: str,
    model_id: str,
    prompt: str,
    temperature: float,
    system: str = "",
    images: list[str] | None = None,
) -> tuple[str, int]:
    """POST to ollama ``/api/generate`` with ``stream=False``.

    Despite the name, this call blocks until the model has finished and
    returns ``(full_response, elapsed_ms)``.

    system: optional system prompt, sent as its own field when non-empty.
    images: optional list of base64-encoded images (vision models only).

    Raises RuntimeError (chained to the original exception) on any failure.
    """
    endpoint = f"{ollama_base.rstrip('/')}/api/generate"
    request_body: dict = {
        "model": model_id,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature},
    }
    if system:
        request_body["system"] = system
    if images:
        request_body["images"] = images

    req = Request(
        endpoint,
        data=json.dumps(request_body).encode("utf-8"),
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    started = time.time()
    try:
        with urlopen(req, timeout=120) as resp:
            reply = json.loads(resp.read().decode("utf-8"))
        elapsed = int((time.time() - started) * 1000)
        return reply.get("response", ""), elapsed
    except Exception as exc:
        raise RuntimeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _run_cftext(
    cforch_base: str,
    model_id: str,
    prompt: str,
    system: str,
    temperature: float,
    startup_timeout_s: float = 180.0,
    user_id: str | None = None,
) -> tuple[str, int, bool]:
    """Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).

    Raises RuntimeError when the service stops or times out during a cold
    start, and on any generation error. The allocation request itself is
    not wrapped: an HTTP error status there propagates from
    ``raise_for_status()`` as whatever httpx raises.
    cold_started=True means the service was launched from scratch (caller may log this).

    Cold-start detection uses coordinator state signals (running/stopped) rather than
    polling the service health endpoint — this fails fast on model load errors instead
    of waiting out the full timeout.

    The allocation is released (best effort) after generation, and also
    before raising on a stopped service or a cold-start timeout.
    """
    # Allocate
    alloc_resp = httpx.post(
        f"{cforch_base}/api/services/cf-text/allocate",
        json={
            "model_candidates": [model_id],
            "caller": "avocet",
            "pipeline": "imitate",
            # user_id is only sent when the caller supplied one.
            **({"user_id": user_id} if user_id else {}),
        },
        timeout=30.0,
    )
    alloc_resp.raise_for_status()
    data = alloc_resp.json()
    service_url: str = data["url"]
    allocation_id: str = data.get("allocation_id", "")
    node_id: str = data.get("node_id", "")
    gpu_id: int | None = data.get("gpu_id")
    # Cold only when the coordinator says it started the service AND it is not warm.
    cold_started = data.get("started", False) and not data.get("warm", True)

    # Wait for ready using coordinator state signals
    if cold_started:
        deadline = time.monotonic() + startup_timeout_s
        probe_misses = 0
        while time.monotonic() < deadline:
            try:
                status = httpx.get(
                    f"{cforch_base}/api/services/cf-text/status", timeout=5.0
                )
                if status.is_success:
                    instances = status.json().get("instances", [])
                    # Find the instance this allocation landed on.
                    match = next(
                        (i for i in instances
                         if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
                        None,
                    )
                    if match:
                        probe_misses = 0
                        state = match.get("state", "")
                        if state == "running":
                            break
                        elif state == "stopped":
                            # Release before failing so the allocation is not left behind.
                            if allocation_id:
                                httpx.delete(
                                    f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
                                    timeout=5.0,
                                )
                            raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
                    else:
                        probe_misses += 1
                        if probe_misses >= 6:
                            # Coordinator hasn't registered instance yet — fall back to health poll
                            try:
                                if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
                                    break
                            except Exception:
                                pass
            except RuntimeError:
                # The "service stopped" failure above must escape the retry loop.
                raise
            except Exception:
                # Any other probe error is treated as transient; retry after the sleep.
                pass
            time.sleep(2.0)
        else:
            # while/else: reached only when the deadline passed without a break.
            if allocation_id:
                httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
            raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")

    # Generate
    messages: list[dict] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})

    t0 = time.time()
    try:
        gen_resp = httpx.post(
            f"{service_url}/v1/chat/completions",
            json={
                "model": model_id,
                "messages": messages,
                "max_tokens": 300,
                "temperature": temperature,
                "stream": False,
            },
            timeout=120.0,
        )
        gen_resp.raise_for_status()
        elapsed_ms = int((time.time() - t0) * 1000)
        content = gen_resp.json()["choices"][0]["message"]["content"]
        return content.strip(), elapsed_ms, cold_started
    except Exception as exc:
        # NOTE(review): elapsed_ms is computed here but not used on the error path.
        elapsed_ms = int((time.time() - t0) * 1000)
        raise RuntimeError(str(exc)) from exc
    finally:
        # Best-effort release; a failed delete must not mask the result or the error.
        if allocation_id:
            try:
                httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
            except Exception:
                pass
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/products")
def get_products() -> dict:
    """List configured CF products with live online status."""

    def _describe(entry: dict) -> dict:
        # A product without a base_url is reported offline without probing.
        base_url = entry.get("base_url", "")
        return {
            "id": entry.get("id", ""),
            "name": entry.get("name", ""),
            "icon": entry.get("icon", "📦"),
            "description": entry.get("description", ""),
            "base_url": base_url,
            "online": _is_online(base_url, entry.get("health_path", "/api/health")) if base_url else False,
        }

    configured = _load_imitate_config().get("products", []) or []
    # Non-dict entries in the config list are ignored.
    return {"products": [_describe(p) for p in configured if isinstance(p, dict)]}
|
||||
|
||||
|
||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||
|
||||
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
    """Fetch a real sample from the given product's API.

    Errors: 404 for an unknown product or an empty sample list, 422 when
    the product lacks ``base_url`` / ``sample_endpoint``, 503 when the
    product API is unreachable, 502 for any other fetch/parse failure.
    """
    configured = _load_imitate_config().get("products", []) or []
    product: dict | None = next(
        (p for p in configured if isinstance(p, dict) and p.get("id") == product_id),
        None,
    )
    if product is None:
        raise HTTPException(404, f"Product '{product_id}' not in config")

    base_url = product.get("base_url", "").rstrip("/")
    endpoint = product.get("sample_endpoint", "")
    if not (base_url and endpoint):
        raise HTTPException(422, "Product missing base_url or sample_endpoint")

    try:
        raw = _http_get_json(f"{base_url}{endpoint}", timeout=5)
    except URLError as exc:
        raise HTTPException(503, f"Product API unreachable: {exc}") from exc
    except Exception as exc:
        raise HTTPException(502, f"Bad response from product API: {exc}") from exc

    extracted = _extract_sample(
        raw,
        product.get("text_fields", []) or [],
        index,
        sample_key=product.get("sample_key") or None,
    )
    if not extracted:
        raise HTTPException(404, "No sample items returned by product API")

    item = extracted.get("item", {})

    # Fill {text} first, then any {field_name} placeholder from the raw item.
    prompt = product.get("prompt_template", "{text}").replace("{text}", extracted["text"])
    for field, val in item.items():
        prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")

    # Pass through system_prompt and image_url when the product API returned them.
    item_is_dict = isinstance(item, dict)
    system_prompt = str(item.get("system_prompt", "")) if item_is_dict else ""
    image_url = str(item.get("image_url", "")) if item_is_dict else ""

    return {
        "product_id": product_id,
        "sample_index": index,
        "text": extracted["text"],
        "prompt": prompt,
        "system_prompt": system_prompt,
        "image_url": image_url,
        "raw_item": item,
    }
|
||||
|
||||
|
||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
|
||||
def get_catalog() -> dict:
|
||||
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
||||
models = _cforch_catalog(_cforch_url())
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
||||
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
||||
try:
|
||||
from app.cloud_session import get_session
|
||||
return get_session(request, response)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/run")
def run_imitate(
    prompt: str = "",
    model_ids: str = "",            # comma-separated ollama model IDs
    cf_text_model_ids: str = "",    # comma-separated cf-text model IDs (via cf-orch)
    temperature: float = 0.7,
    product_id: str = "",
    system: str = "",               # optional system prompt
    image_url: str = "",            # optional image URL for vision models
    session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
    """Run a prompt through the selected models and stream results as SSE.

    Ollama models (``model_ids``) run one after another. cf-text models
    (``cf_text_model_ids``) are submitted to a bounded thread pool and are
    reported in completion order.

    If image_url is provided, the image is downloaded once, before streaming
    starts, and the base64 blob is passed to every Ollama call. The cf-text
    calls made here do not receive the image.

    Events, each encoded with ``_sse``:
        * ``start``           -- ``total_models`` and ``has_image``
        * ``model_start``     -- one per model, tagged with its ``service``
        * ``model_coldstart`` -- cf-text only, when the run reports a cold start
        * ``model_done``      -- ``model``, ``response``, ``elapsed_ms``, ``error``
        * ``complete``        -- the list of all ``model_done`` payloads

    ``product_id`` is accepted but not read by this function.

    Raises:
        HTTPException: 422 when ``prompt`` is blank or no model ID is given.
    """
    if not prompt.strip():
        raise HTTPException(422, "prompt is required")

    ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
    cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
    if not ollama_ids and not cftext_ids:
        raise HTTPException(422, "model_ids or cf_text_model_ids is required")

    cfg = _load_imitate_config()
    ollama_base = _ollama_url(cfg)
    cforch_base = _cforch_url()
    system_ctx = system.strip()
    total_models = len(ollama_ids) + len(cftext_ids)

    # The ID list comes straight from the query string, so the pool size must
    # not be derived from it unbounded. Extra models simply wait for a worker.
    max_cftext_workers = 8

    # Download image once before streaming — shared across ollama vision models
    images: list[str] = []
    if image_url.strip():
        b64 = _fetch_image_b64(image_url.strip())
        if b64:
            images = [b64]

    def generate():
        results: list[dict] = []
        yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})

        # Ollama models
        for model_id in ollama_ids:
            yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
            try:
                response, elapsed_ms = _run_ollama_streaming(
                    ollama_base, model_id, prompt, temperature,
                    system=system_ctx, images=images or None,
                )
                result = {
                    "model": model_id,
                    "response": response,
                    "elapsed_ms": elapsed_ms,
                    "error": None,
                }
            except Exception as exc:
                # A failing model is reported in-stream; it must not abort the run.
                result = {
                    "model": model_id,
                    "response": "",
                    "elapsed_ms": 0,
                    "error": str(exc),
                }
            results.append(result)
            yield _sse({"type": "model_done", **result})

        # cf-text models via cf-orch — fan out in parallel when multiple models selected
        if cftext_ids:
            from concurrent.futures import ThreadPoolExecutor, as_completed

            # Announce all models upfront so the UI can show loading states immediately
            for model_id in cftext_ids:
                yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})

            _user_id: str | None = getattr(session, "user_id", None)
            # Only forward real cloud user IDs — skip local/anon sessions
            if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
                _user_id = None

            worker_count = min(len(cftext_ids), max_cftext_workers)
            with ThreadPoolExecutor(max_workers=worker_count) as pool:
                future_to_model = {
                    pool.submit(
                        _run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
                        180.0, _user_id,
                    ): mid
                    for mid in cftext_ids
                }
                for future in as_completed(future_to_model):
                    model_id = future_to_model[future]
                    try:
                        response, elapsed_ms, cold_started = future.result()
                        if cold_started:
                            yield _sse({"type": "model_coldstart", "model": model_id})
                        result = {
                            "model": model_id,
                            "response": response,
                            "elapsed_ms": elapsed_ms,
                            "error": None,
                        }
                    except Exception as exc:
                        result = {
                            "model": model_id,
                            "response": "",
                            "elapsed_ms": 0,
                            "error": str(exc),
                        }
                    results.append(result)
                    yield _sse({"type": "model_done", **result})

        yield _sse({"type": "complete", "results": results})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
|
||||
class ImitateResult(BaseModel):
    """One model's outcome from an imitate run, as posted to /push-corrections."""

    # Identifier of the model that produced the response.
    model: str
    # Generated text; entries with blank text are skipped by push_corrections.
    response: str
    # Elapsed time for the run, in milliseconds.
    elapsed_ms: int
    # Error text when the run failed; entries with an error are skipped.
    error: str | None = None
|
||||
|
||||
|
||||
class PushCorrectionsRequest(BaseModel):
    """Request body for POST /push-corrections."""

    # Product the prompt belongs to; copied into every stored record.
    product_id: str
    # Prompt that was run; stored as the single user message of each record.
    prompt: str
    # Per-model results to store as candidates.
    results: list[ImitateResult]
|
||||
|
||||
|
||||
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
    """Store usable imitate results as pending candidates for human review.

    Results that carry an error or a blank response are dropped. Every
    remaining result becomes one record appended to the file returned by
    ``_candidates_file()``; all records of one call share a UTC timestamp.

    Returns:
        ``{"pushed": <number of records written>}``.

    Raises:
        HTTPException: 422 when the prompt is blank, the results list is
            empty, or no result survives filtering.
    """
    if not req.prompt.strip():
        raise HTTPException(422, "prompt is required")
    if not req.results:
        raise HTTPException(422, "results list is empty")

    created_at = datetime.now(timezone.utc).isoformat()
    usable = [res for res in req.results if not res.error and res.response.strip()]
    records = [
        {
            "id": str(uuid.uuid4()),
            "source": "imitate",
            "product_id": req.product_id,
            "prompt_messages": [{"role": "user", "content": req.prompt}],
            "model_response": res.response,
            "model_id": res.model,
            "elapsed_ms": res.elapsed_ms,
            "status": "pending",
            "created_at": created_at,
        }
        for res in usable
    ]
    if not records:
        raise HTTPException(422, "No non-error results to push")

    target = _candidates_file()
    target.parent.mkdir(parents=True, exist_ok=True)
    for rec in records:
        append_jsonl(target, rec)

    return {"pushed": len(records)}
|
||||
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
|
||||
from app.data.imitate import router # noqa: F401
|
||||
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401
|
||||
|
|
|
|||
151
app/models.py
151
app/models.py
|
|
@ -15,6 +15,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -60,6 +61,30 @@ _CF_ORCH_PROFILES_DIR: Path = Path(
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# ── HuggingFace auth ─────────────────────────────────────────────────────────
|
||||
|
||||
def _get_hf_token() -> str | None:
    """Return the HuggingFace token, or None when none is configured.

    Lookup order:
        1. ``hf_token`` at the top level of ``config/label_tool.yaml``
        2. ``cforch.hf_token`` in the same file
        3. the ``HF_TOKEN`` environment variable
        4. the ``HUGGING_FACE_HUB_TOKEN`` environment variable

    The config lookup is best-effort: an unreadable or malformed file is
    logged at debug level and the environment variables are used instead.
    """
    config_file = _ROOT / "config" / "label_tool.yaml"
    if config_file.exists():
        raw = None
        try:
            import yaml as _yaml
            raw = _yaml.safe_load(config_file.read_text(encoding="utf-8"))
        except Exception as exc:
            # Keep going, but leave a trace; a silently ignored config file is
            # very hard to diagnose when a gated download later fails.
            logger.debug("Could not read HF token from %s: %s", config_file, exc)

        if isinstance(raw, dict):
            cforch = raw.get("cforch")
            candidates = (
                raw.get("hf_token"),
                cforch.get("hf_token") if isinstance(cforch, dict) else None,
            )
            for candidate in candidates:
                # Skip blank and non-string values instead of failing on them.
                if isinstance(candidate, str) and candidate.strip():
                    return candidate.strip()

    return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None
|
||||
|
||||
|
||||
# ── GGUF quantization detection ───────────────────────────────────────────────
|
||||
# Matches quant identifiers in GGUF filenames: Q4_K_M, Q8_0, F16, IQ3_M, etc.
|
||||
_QUANT_RE = re.compile(
|
||||
r'[._-]((?:IQ\d|Q\d)[A-Z0-9_]*|F16|BF16)\.gguf$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ── Download progress shared state ────────────────────────────────────────────
|
||||
# Updated by the background download thread; read by GET /download/stream.
|
||||
_download_progress: dict[str, Any] = {}
|
||||
|
|
@ -91,12 +116,15 @@ _TAG_TO_INFO: dict[str, _TagInfo] = {
|
|||
"audio-classification": {"adapter": None, "role": "classifier", "service": "cf-voice"},
|
||||
# TTS — cf-tts text-to-speech service
|
||||
"text-to-speech": {"adapter": None, "role": "tts", "service": "cf-tts"},
|
||||
# Vision — cf-vision image classification / embedding / VLM service
|
||||
# Vision classifiers / embedders — cf-vision (SigLIP/CLIP-style models)
|
||||
"image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
|
||||
"zero-shot-image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
|
||||
"image-feature-extraction": {"adapter": None, "role": "embedding", "service": "cf-vision"},
|
||||
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "cf-vision"},
|
||||
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "cf-vision"},
|
||||
# Generative VLMs (image+text → text) — run under vllm, not cf-vision.
|
||||
# cf-vision is a classifier/embedder service; generative VLMs like Qwen-VL,
|
||||
# LLaVA, and InternVL are textgen models that happen to accept image inputs.
|
||||
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "vllm"},
|
||||
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "vllm"},
|
||||
# Image generation — cf-image (text → image; distinct from cf-vision image understanding)
|
||||
"text-to-image": {"adapter": None, "role": "image-gen", "service": "cf-image"},
|
||||
# Embedding — cf-core shared embedding layer
|
||||
|
|
@ -195,10 +223,17 @@ def _get_queue_entry(entry_id: str) -> dict | None:
|
|||
def _catalog_key(repo_id: str) -> str:
|
||||
"""Derive a readable catalog key from repo_id.
|
||||
|
||||
ibm-granite/granite-4.1-8b → granite-4.1-8b
|
||||
facebook/bart-large-cnn → bart-large-cnn
|
||||
ibm-granite/granite-4.1-8b → granite-4.1-8b
|
||||
facebook/bart-large-cnn → bart-large-cnn
|
||||
WithinUsAI/Opus4.7-GODs.Ghost.Codex-4B.GGuF → opus4.7-gods.ghost.codex-4b
|
||||
|
||||
The coordinator skips catalog lookup for keys ending in ".gguf" (treats them
|
||||
as direct file paths). Strip the suffix so GGUF repo names produce valid keys.
|
||||
"""
|
||||
return repo_id.split("/", 1)[-1].lower()
|
||||
key = repo_id.split("/", 1)[-1].lower()
|
||||
if key.endswith(".gguf"):
|
||||
key = key[:-5]
|
||||
return key
|
||||
|
||||
|
||||
def _insert_catalog_entry(content: str, entry_lines: str) -> str:
|
||||
|
|
@ -290,6 +325,15 @@ def _register_in_node_catalogs(
|
|||
max_mb: int = cf_text.get("max_mb", 0)
|
||||
catalog: dict = cf_text.get("catalog") or {}
|
||||
|
||||
# If the node has a different local model dir, remap the NFS path.
|
||||
model_base = cf_text.get("model_base_path", "").rstrip("/")
|
||||
if model_base:
|
||||
nfs_base = str(_CF_TEXT_MODELS_DIR).rstrip("/")
|
||||
model_name = local_path.name
|
||||
effective_path_str = f"{model_base}/{model_name}"
|
||||
else:
|
||||
effective_path_str = local_path_str
|
||||
|
||||
# Skip if key already exists
|
||||
if model_key in catalog:
|
||||
logger.debug("Key %r already in %s — skipping", model_key, yaml_file.name)
|
||||
|
|
@ -301,10 +345,10 @@ def _register_in_node_catalogs(
|
|||
for entry in catalog.values()
|
||||
if isinstance(entry, dict)
|
||||
}
|
||||
if local_path_str in registered_paths or any(
|
||||
p.startswith(local_path_str + "/") for p in registered_paths
|
||||
if effective_path_str in registered_paths or any(
|
||||
p.startswith(effective_path_str + "/") for p in registered_paths
|
||||
):
|
||||
logger.debug("Path %s already registered in %s — skipping", local_path_str, yaml_file.name)
|
||||
logger.debug("Path %s already registered in %s — skipping", effective_path_str, yaml_file.name)
|
||||
continue
|
||||
|
||||
# Determine whether model fits at FP16 or needs 4-bit
|
||||
|
|
@ -330,12 +374,18 @@ def _register_in_node_catalogs(
|
|||
if needs_4bit
|
||||
else f" # FP16 file-size estimate"
|
||||
)
|
||||
env_block = (
|
||||
f" env:\n"
|
||||
f" CF_TEXT_4BIT: \"1\"\n"
|
||||
if needs_4bit else ""
|
||||
)
|
||||
entry_block = (
|
||||
f" # auto-registered by avocet on download\n"
|
||||
f" {model_key}:\n"
|
||||
f" path: {local_path_str}\n"
|
||||
f" path: {effective_path_str}\n"
|
||||
f" vram_mb: {vram_for_node}{vram_comment}\n"
|
||||
f" description: \"{desc}\"\n"
|
||||
f"{env_block}"
|
||||
)
|
||||
|
||||
new_content = _insert_catalog_entry(content, entry_block)
|
||||
|
|
@ -388,12 +438,17 @@ def _run_download(
|
|||
role: str | None = None,
|
||||
service: str | None = None,
|
||||
model_size_bytes: int = 0,
|
||||
quant_pattern: str | None = None,
|
||||
) -> None:
|
||||
"""Background thread: download model via huggingface_hub.snapshot_download.
|
||||
|
||||
model_size_bytes is the sum of file sizes reported by the HF API (siblings).
|
||||
It is used to estimate vram_mb and written to model_info.json so cf-orch can
|
||||
budget VRAM when allocating a cf-text instance for this model.
|
||||
|
||||
quant_pattern: when set, restricts snapshot_download to only files matching
|
||||
*{quant_pattern}*.gguf (plus metadata). Avoids downloading every quant variant
|
||||
from GGUF-only repos like bartowski/*.
|
||||
"""
|
||||
global _download_progress
|
||||
local_dir = _model_dir_for(repo_id, service)
|
||||
|
|
@ -422,10 +477,20 @@ def _run_download(
|
|||
|
||||
local_dir.mkdir(parents=True, exist_ok=True)
|
||||
poll_thread.start()
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
local_dir=str(local_dir),
|
||||
)
|
||||
|
||||
dl_kwargs: dict[str, Any] = {"repo_id": repo_id, "local_dir": str(local_dir)}
|
||||
hf_token = _get_hf_token()
|
||||
if hf_token:
|
||||
dl_kwargs["token"] = hf_token
|
||||
if quant_pattern:
|
||||
# Include both cases: repos use mixed conventions (Q6_K vs q6_k).
|
||||
dl_kwargs["allow_patterns"] = [
|
||||
f"*{quant_pattern.upper()}*.gguf",
|
||||
f"*{quant_pattern.lower()}*.gguf",
|
||||
"*.json",
|
||||
"README.md",
|
||||
]
|
||||
snapshot_download(**dl_kwargs)
|
||||
|
||||
# Estimate VRAM from reported file size.
|
||||
# HF siblings sizes are pre-quantisation file sizes; add 10% for KV cache
|
||||
|
|
@ -531,9 +596,31 @@ def lookup_model(repo_id: str) -> dict:
|
|||
)
|
||||
logger.warning("Unsupported pipeline_tag %r for %s", pipeline_tag, repo_id)
|
||||
|
||||
# Estimate model size from siblings list
|
||||
# Detect GGUF files and parse quant names from siblings list.
|
||||
# For GGUF-only repos (bartowski, TheBloke, etc.) this lets the UI show
|
||||
# a per-quant size picker instead of downloading every variant.
|
||||
siblings = data.get("siblings") or []
|
||||
model_size_bytes: int = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
|
||||
gguf_files: list[dict] = []
|
||||
for s in siblings:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
fname: str = s.get("rfilename", "")
|
||||
if not fname.lower().endswith(".gguf"):
|
||||
continue
|
||||
m = _QUANT_RE.search(fname)
|
||||
gguf_files.append({
|
||||
"filename": fname,
|
||||
"size": s.get("size", 0) or 0,
|
||||
"quant_name": m.group(1).upper() if m else None,
|
||||
})
|
||||
gguf_files.sort(key=lambda f: f["size"])
|
||||
|
||||
# model_size_bytes: total of all siblings (for non-GGUF repos) or all GGUFs only.
|
||||
# For GGUF repos the frontend will substitute the selected quant's size on submit.
|
||||
if gguf_files:
|
||||
model_size_bytes: int = sum(f["size"] for f in gguf_files)
|
||||
else:
|
||||
model_size_bytes = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
|
||||
|
||||
# Description: first 300 chars of card data (modelId field used as fallback)
|
||||
card_data = data.get("cardData") or {}
|
||||
|
|
@ -549,6 +636,7 @@ def lookup_model(repo_id: str) -> dict:
|
|||
"compatible": compatible,
|
||||
"warning": warning,
|
||||
"model_size_bytes": model_size_bytes,
|
||||
"gguf_files": gguf_files if gguf_files else None,
|
||||
"description": description,
|
||||
"tags": data.get("tags") or [],
|
||||
"downloads": data.get("downloads") or 0,
|
||||
|
|
@ -579,6 +667,9 @@ class QueueAddRequest(BaseModel):
|
|||
# Stored in the queue entry so approve can pass it to _run_download
|
||||
# without a second HF API round-trip.
|
||||
model_size_bytes: int = 0
|
||||
# GGUF quantization pattern (e.g. "Q5_K_M"). When set, snapshot_download
|
||||
# restricts to *{quant_pattern}*.gguf instead of fetching all variants.
|
||||
quant_pattern: str | None = None
|
||||
|
||||
|
||||
@router.post("/queue", status_code=201)
|
||||
|
|
@ -597,6 +688,7 @@ def add_to_queue(req: QueueAddRequest) -> dict:
|
|||
"role": req.role,
|
||||
"service": req.service,
|
||||
"model_size_bytes": req.model_size_bytes,
|
||||
"quant_pattern": req.quant_pattern,
|
||||
"status": "pending",
|
||||
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
|
@ -629,6 +721,7 @@ def approve_queue_entry(entry_id: str) -> dict:
|
|||
entry.get("role"),
|
||||
entry.get("service"),
|
||||
entry.get("model_size_bytes", 0),
|
||||
entry.get("quant_pattern"),
|
||||
),
|
||||
daemon=True,
|
||||
name=f"model-download-{entry_id}",
|
||||
|
|
@ -638,6 +731,32 @@ def approve_queue_entry(entry_id: str) -> dict:
|
|||
return {"ok": True}
|
||||
|
||||
|
||||
# ── PATCH /queue/{id} ─────────────────────────────────────────────────────────
|
||||
|
||||
class QueuePatchRequest(BaseModel):
    """Request body for PATCH /queue/{entry_id}; fields left as None are not changed."""

    # New service for the queue entry, or None to keep the current one.
    service: str | None = None
    # New role for the queue entry, or None to keep the current one.
    role: str | None = None
|
||||
|
||||
|
||||
@router.patch("/queue/{entry_id}")
def patch_queue_entry(entry_id: str, body: QueuePatchRequest) -> dict:
    """Change the service and/or role of a queue entry that is still pending.

    Only fields present (not None) in the body are written.

    Returns:
        Whatever ``_update_queue_entry`` returns, or ``{}`` when that is falsy.

    Raises:
        HTTPException: 404 when the entry does not exist, 409 when its status
            is anything other than ``"pending"``.
    """
    current = _get_queue_entry(entry_id)
    if current is None:
        raise HTTPException(404, f"Queue entry {entry_id!r} not found")

    status = current.get("status")
    if status != "pending":
        raise HTTPException(409, f"Only pending entries can be patched (current: {status!r})")

    changes: dict = {
        field: value
        for field, value in (("service", body.service), ("role", body.role))
        if value is not None
    }
    return _update_queue_entry(entry_id, changes) or {}
|
||||
|
||||
|
||||
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.delete("/queue/{entry_id}")
|
||||
|
|
|
|||
359
app/nodes.py
Normal file
359
app/nodes.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""Avocet — Node Management API.
|
||||
|
||||
Proxies cf-orch coordinator and agent APIs to expose per-node GPU state,
|
||||
service affinity management, and Ollama model management.
|
||||
|
||||
Config is read from label_tool.yaml under the `cforch:` key.
|
||||
The `profiles_dir` key (new) points to the cf-orch node profile YAML directory.
|
||||
|
||||
Module-level globals follow the set_config_dir() testability pattern from cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
    """Override the directory that label_tool.yaml is read from.

    Intended for tests. Pass None to go back to the default location,
    ``<project root>/config``.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
    """Return the path of label_tool.yaml, honouring the test override directory."""
    base_dir = _CONFIG_DIR if _CONFIG_DIR is not None else _ROOT / "config"
    return base_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
    """Read the ``cforch`` section of label_tool.yaml.

    Returns:
        The section as a dict. An empty dict is returned when the file is
        missing, cannot be read or parsed, or when either the document or the
        ``cforch`` value is not a mapping.
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8"))
    except (yaml.YAMLError, OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to parse config %s: %s", f, exc)
        return {}

    # Callers use ``.get`` on the result, so anything that is not a mapping
    # (empty file, top-level list, ``cforch: some-string``) becomes {}.
    if not isinstance(raw, dict):
        return {}
    section = raw.get("cforch")
    return section if isinstance(section, dict) else {}
|
||||
|
||||
|
||||
def _profiles_dir() -> Path | None:
    """Locate the cf-orch node profiles directory.

    An explicit ``profiles_dir`` setting wins. Otherwise the directory is
    derived from ``bench_script`` as ``<script dir>/../profiles/nodes``.
    Returns None when neither setting is present.
    """
    settings = _load_config()

    explicit = settings.get("profiles_dir", "") or ""
    if explicit:
        return Path(explicit)

    bench_script = settings.get("bench_script", "") or ""
    if not bench_script:
        return None
    return Path(bench_script).parent.parent / "profiles" / "nodes"
|
||||
|
||||
|
||||
def _profile_path(node_id: str) -> Path | None:
    """Build ``<profiles dir>/<node_id>.yaml``; None when the profiles dir is unknown."""
    base = _profiles_dir()
    return None if base is None else base / f"{node_id}.yaml"
|
||||
|
||||
|
||||
def _load_profile(node_id: str) -> dict | None:
    """Load and parse a node profile YAML.

    Returns:
        The parsed mapping (``{}`` for an empty file), or None when the
        profile path is unknown, the file does not exist, the YAML is
        malformed, or the document is not a mapping.
    """
    p = _profile_path(node_id)
    if p is None or not p.exists():
        return None
    try:
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Malformed profile YAML %s: %s", p, exc)
        return None
    if not isinstance(data, dict):
        # Callers use ``.get`` on the profile; a list or scalar document would
        # otherwise surface as an AttributeError inside a request handler.
        logger.warning("Malformed profile YAML %s: top-level value is not a mapping", p)
        return None
    return data
|
||||
|
||||
|
||||
def _get_ollama_url(node_id: str) -> str:
    """Derive Ollama URL from the node profile's agent_url (same host, port 11434).

    Raises:
        HTTPException: 404 when the profile has no ``agent_url`` for the node,
            or when that URL has no usable scheme and host.
    """
    profile = _load_profile(node_id)
    if profile:
        nodes_section = profile.get("nodes", {}) or {}
        node_entry = nodes_section.get(node_id, {}) or {}
        agent_url = node_entry.get("agent_url", "") or ""
        if agent_url:
            try:
                parsed = urlparse(agent_url)
                scheme, host = parsed.scheme, parsed.hostname
            except ValueError:
                # urlparse rejects some malformed netlocs, e.g. an unclosed "[".
                scheme, host = "", None
            if not scheme or not host:
                # "host:7700" without a scheme parses with hostname=None and
                # would otherwise produce a URL like "host://None:11434".
                raise HTTPException(
                    status_code=404,
                    detail=(
                        f"Cannot determine Ollama URL for node {node_id}: "
                        f"agent_url {agent_url!r} has no scheme/host"
                    ),
                )
            if ":" in host:
                # ``hostname`` strips the brackets from IPv6 literals; restore them.
                host = f"[{host}]"
            return f"{scheme}://{host}:11434"
    raise HTTPException(
        status_code=404,
        detail=f"Cannot determine Ollama URL for node {node_id}: no agent_url in profile",
    )
|
||||
|
||||
|
||||
# ── Endpoints ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
def list_nodes() -> list:
    """Return all nodes with live GPU stats merged with profile YAML.

    Node and GPU state come from the coordinator's ``/api/nodes``; running
    services come from ``/api/services``; assigned services and the service
    catalog summary come from each node's profile YAML.

    Returns:
        One dict per node. An empty list is returned when no coordinator URL
        is configured or when the nodes request fails or is not valid JSON.
        A failing services request only leaves ``services_running`` empty.
    """
    import httpx

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""
    if not coordinator_url:
        return []

    try:
        r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
        r.raise_for_status()
        coord_nodes: list[dict] = r.json()
    except (httpx.HTTPError, ValueError) as exc:
        # ValueError covers a 2xx response whose body is not valid JSON.
        logger.warning("Coordinator unreachable: %s", exc)
        return []

    try:
        sr = httpx.get(f"{coordinator_url}/api/services", timeout=5.0)
        sr.raise_for_status()
        services_data: list[dict] = sr.json()
    except (httpx.HTTPError, ValueError):
        logger.warning("Services API unreachable for %s, skipping", coordinator_url)
        services_data = []

    # Build per-node, per-GPU running services map
    running: dict[str, dict[int, list[str]]] = {}
    for svc in services_data:
        nid = svc.get("node_id", "")
        gid = svc.get("gpu_id")
        svc_name = svc.get("service", "")
        if nid and gid is not None and svc_name:
            running.setdefault(nid, {}).setdefault(gid, []).append(svc_name)

    result = []
    for node in coord_nodes:
        node_id = node.get("node_id", "") or node.get("id", "")
        profile = _load_profile(node_id) if node_id else None
        profile_loaded = profile is not None

        # GPU entries declared in the profile; resolved once per node.
        profile_gpus: list[dict] = []
        if profile:
            node_entry = (profile.get("nodes", {}) or {}).get(node_id, {}) or {}
            profile_gpus = [
                g for g in (node_entry.get("gpus", []) or []) if isinstance(g, dict)
            ]

        gpus = []
        for gpu in (node.get("gpus", []) or []):
            gpu_id = gpu.get("gpu_id", gpu.get("id", 0))
            # First profile entry with a matching id wins.
            services_assigned: list[str] = next(
                (g.get("services", []) or [] for g in profile_gpus if g.get("id") == gpu_id),
                [],
            )
            gpus.append({
                "gpu_id": gpu_id,
                "card": gpu.get("card", ""),
                "vram_total_mb": gpu.get("vram_total_mb", 0),
                "vram_used_mb": gpu.get("vram_used_mb", 0),
                "vram_free_mb": gpu.get("vram_free_mb", 0),
                "temp_c": gpu.get("temp_c"),
                "utilization_pct": gpu.get("utilization_pct"),
                "compute_cap": gpu.get("compute_cap"),
                "services_assigned": services_assigned,
                "services_running": running.get(node_id, {}).get(gpu_id, []),
            })

        services_catalog: dict = {}
        if profile:
            for svc_name, svc_info in (profile.get("services", {}) or {}).items():
                if not isinstance(svc_info, dict):
                    # e.g. a bare "cf-text:" key with no body parses to None.
                    continue
                catalog = svc_info.get("catalog", {}) or {}
                services_catalog[svc_name] = {
                    "min_compute_cap": svc_info.get("min_compute_cap", 0.0),
                    "max_mb": svc_info.get("max_mb", 0),
                    "catalog_size": len(catalog),
                }

        result.append({
            "node_id": node_id,
            "online": node.get("online", True),
            "agent_url": node.get("agent_url", ""),
            "gpus": gpus,
            "profile_loaded": profile_loaded,
            "services_catalog": services_catalog,
        })

    return result
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
    """Return the full parsed profile YAML for a node.

    Raises:
        HTTPException: 404 when the node has no profile file, 500 when the
            file is not valid YAML.
    """
    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")
    try:
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Chain the cause so the parser error stays visible in tracebacks.
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    return data
|
||||
|
||||
|
||||
class UpdateServicesRequest(BaseModel):
    """Request body for POST /nodes/{node_id}/gpu/{gpu_id}/services."""

    # Complete list of service names for the GPU; it replaces the existing list.
    services: list[str]
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
    """Set service assignment for a GPU with compatibility validation, then atomic write.

    Every requested service must be defined in the profile's ``services``
    mapping, the GPU's ``compute_cap`` must meet the service's
    ``min_compute_cap``, and the GPU's ``vram_mb`` must cover the smallest
    catalog entry (or ``max_mb`` when the service has no catalog).

    On success the profile is rewritten through a temp file plus
    ``os.replace``, and the coordinator is asked to reload the profile. That
    reload is best-effort and its outcome is reported as ``reloaded``.

    Returns:
        ``{"ok": True, "reloaded": <bool>, "warnings": []}``.

    Raises:
        HTTPException: 404 for an unknown profile or GPU, 500 for malformed
            YAML, 422 when a service is undefined or incompatible.
    """
    import httpx

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""

    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")

    try:
        profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc

    nodes_section = profile.get("nodes", {}) or {}
    node_entry = nodes_section.get(node_id, {}) or {}
    gpu_list = node_entry.get("gpus", []) or []

    gpu_entry = next(
        (g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
        None,
    )
    if gpu_entry is None:
        raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")

    gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
    gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
    services_def = profile.get("services", {}) or {}

    for svc_name in body.services:
        if svc_name not in services_def:
            raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
        # A service declared with an empty body parses to None; treat it as {}.
        svc = services_def[svc_name] or {}
        min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
        if gpu_compute_cap < min_cap:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
            )
        catalog = svc.get("catalog", {}) or {}
        if catalog:
            # Smallest model in the catalog; null or non-mapping entries count as 0.
            min_catalog_vram = min(
                ((m.get("vram_mb") or 0) if isinstance(m, dict) else 0 for m in catalog.values()),
                default=0,
            )
        else:
            min_catalog_vram = svc.get("max_mb", 0) or 0
        if gpu_vram_mb < min_catalog_vram:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
            )

    # Immutable update of GPU services list
    new_gpu_list = [
        ({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
        for g in gpu_list
    ]
    new_profile = {
        **profile,
        "nodes": {
            **nodes_section,
            node_id: {**node_entry, "gpus": new_gpu_list},
        },
    }

    # Atomic write: write to .tmp then rename.
    # sort_keys=False keeps the key order of the file as it was read; the
    # yaml.dump default would re-sort every mapping alphabetically on each save.
    tmp_yaml = Path(str(p) + ".tmp")
    try:
        tmp_yaml.write_text(
            yaml.dump(new_profile, default_flow_style=False, sort_keys=False),
            encoding="utf-8",
        )
        os.replace(tmp_yaml, p)
    except OSError:
        # Do not leave a stale temp file next to the profile.
        tmp_yaml.unlink(missing_ok=True)
        raise

    # Trigger coordinator profile reload
    reloaded = False
    if coordinator_url:
        try:
            rr = httpx.post(
                f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
            )
            reloaded = rr.status_code < 300
        except Exception as exc:
            logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)

    return {"ok": True, "reloaded": reloaded, "warnings": []}
|
||||
|
||||
# ── Ollama model management ────────────────────────────────────────────────────
|
||||
|
||||
class PullRequest(BaseModel):
    """Request body for POST /nodes/{node_id}/models/ollama/pull."""

    # Name of the Ollama model to pull; forwarded as "name" to Ollama's /api/pull.
    name: str
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/models/ollama")
def list_ollama_models(node_id: str) -> dict:
    """Return the node's Ollama ``/api/tags`` response.

    Any failure while calling Ollama is reported in-band as
    ``{"error": "<message>"}`` rather than raised. An unresolvable node still
    raises, via ``_get_ollama_url``.
    """
    import httpx

    base_url = _get_ollama_url(node_id)
    try:
        resp = httpx.get(f"{base_url}/api/tags", timeout=10.0)
        resp.raise_for_status()
        payload = resp.json()
    except Exception as exc:
        return {"error": str(exc)}
    return payload
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/models/ollama/pull")
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
    """Stream Ollama pull progress as SSE events.

    Each non-empty line of Ollama's ``/api/pull`` response is forwarded as one
    ``data:`` event. A failure while streaming is reported as a final event
    ``{"error": "<message>"}``.

    Raises:
        HTTPException: 400 when ``name`` is empty or whitespace only.
    """
    import httpx

    if not body.name.strip():
        raise HTTPException(400, "name is required")

    ollama_url = _get_ollama_url(node_id)

    def stream():
        try:
            with httpx.stream(
                "POST",
                f"{ollama_url}/api/pull",
                json={"name": body.name, "stream": True},
                timeout=300.0,
            ) as resp:
                for line in resp.iter_lines():
                    if line:
                        yield f"data: {line}\n\n"
        except Exception as exc:
            yield f"data: {json.dumps({'error': str(exc)})}\n\n"

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        # Keep caches and buffering reverse proxies from holding back
        # progress events until the pull has finished.
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
def delete_ollama_model(node_id: str, name: str) -> dict:
    """Proxy DELETE to Ollama for a specific node.

    Returns:
        ``{"ok": True}`` when Ollama accepts the deletion.

    Raises:
        HTTPException: 404 when Ollama reports the model as missing, 502 for
            any other failure while talking to Ollama.
    """
    import httpx

    ollama_url = _get_ollama_url(node_id)
    try:
        r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
        if r.status_code == 404:
            raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
        r.raise_for_status()
        return {"ok": True}
    except HTTPException:
        # Let the 404 above through untouched instead of rewrapping it as 502.
        raise
    except Exception as exc:
        raise HTTPException(502, f"Ollama unreachable: {exc}") from exc
|
||||
323
app/plans_bench.py
Normal file
323
app/plans_bench.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""Avocet — CF planning benchmark integration API.
|
||||
|
||||
Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
|
||||
Connection config (api_base) is read from label_tool.yaml under the
|
||||
`plans_bench:` key (optional; falls back to localhost:8080).
|
||||
|
||||
All endpoints are registered on `router` (FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/plans-bench".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None
|
||||
|
||||
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
|
||||
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
|
||||
# Kept here so the UI can list them without importing the script.
|
||||
|
||||
MODEL_REGISTRY: dict[str, str] = {
|
||||
"llama3.2-3b": "Llama 3.2 3B Instruct (local via cf-text)",
|
||||
"llama3.2-1b": "Llama 3.2 1B Instruct (local via cf-text)",
|
||||
"mistral-7b": "Mistral 7B Instruct (local via cf-text)",
|
||||
"phi3-mini": "Phi-3 Mini 3.8B (local via cf-text)",
|
||||
"qwen2.5-3b": "Qwen 2.5 3B Instruct (local via cf-text)",
|
||||
}
|
||||
|
||||
RUBRIC_LABELS: dict[str, str] = {
|
||||
"task_structure": "Task structure (checkboxes + commits)",
|
||||
"tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
|
||||
"privacy_pillar": "Privacy pillar (local-first, no logging)",
|
||||
"safety_pillar": "Safety pillar (human approval, reversibility)",
|
||||
"accessibility": "Accessibility (ND/adaptive users)",
|
||||
"license_split": "License awareness (MIT vs BSL)",
|
||||
"file_paths": "File paths (plausible project paths)",
|
||||
"cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
|
||||
"length_ok": "Response length (200–2500 words)",
|
||||
}
|
||||
|
||||
|
||||
# ── Testability seam ───────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
    """Point config lookups at *path*; pass ``None`` to restore the default."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
    """Return the label_tool.yaml path, honouring the test override directory."""
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
    """Load benchmark connection settings from label_tool.yaml.

    Each setting is looked up in the ``cforch:`` section first, then in the
    ``plans_bench:`` section, and finally falls back to a built-in default.
    A missing file, a YAML syntax error, or a root/section that is not a
    mapping all degrade to the defaults instead of raising.

    Returns:
        Dict with ``coordinator_url`` and ``python_bin`` keys.
    """
    def _section(root: object, key: str) -> dict:
        # A root or section that is a list/scalar/null is treated as absent;
        # calling .get() on it would otherwise raise AttributeError.
        value = root.get(key) if isinstance(root, dict) else None
        return value if isinstance(value, dict) else {}

    f = _config_file()
    cforch_cfg: dict = {}
    bench_cfg: dict = {}
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
        else:
            cforch_cfg = _section(raw, "cforch")
            bench_cfg = _section(raw, "plans_bench")

    return {
        "coordinator_url": cforch_cfg.get(
            "coordinator_url",
            bench_cfg.get("coordinator_url", "http://10.1.10.71:7700"),
        ),
        "python_bin": cforch_cfg.get(
            "python_bin",
            bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python"),
        ),
    }
|
||||
|
||||
|
||||
def _results_file(run_id: str) -> Path:
    """Return the JSON results path for *run_id* inside the results directory."""
    return _RESULTS_DIR.joinpath(f"{run_id}.json")
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
def get_models() -> dict:
    """Return registry shortcuts, the live cf-orch catalog, and rubric labels.

    The catalog fetch is best-effort: on any failure a warning is logged and
    ``cforch_models`` is returned empty.
    """
    settings = _load_config()
    catalog_url = f"{settings['coordinator_url']}/api/services/cf-text/catalog"

    cforch_models: list[dict] = []
    try:
        catalog_resp = httpx.get(catalog_url, timeout=5.0)
        catalog_resp.raise_for_status()
        # Entries that are not mappings are skipped.
        cforch_models = [
            {
                "id": name,
                "name": name,
                "vram_mb": info.get("vram_mb"),
                "description": info.get("description", ""),
            }
            for name, info in catalog_resp.json().items()
            if isinstance(info, dict)
        ]
    except Exception as exc:
        logger.warning("Failed to fetch cf-orch catalog: %s", exc)

    registry = [
        {"key": key, "description": text}
        for key, text in MODEL_REGISTRY.items()
    ]
    return {
        "registry": registry,
        "cforch_models": cforch_models,
        "coordinator_url": settings["coordinator_url"],
        "rubric_labels": RUBRIC_LABELS,
    }
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
def run_plans_benchmark(
    models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
    prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
    use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
    api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
    workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
) -> StreamingResponse:
    """Spawn benchmark_plans.py and stream its output as SSE progress events.

    Every non-empty output line (stderr is merged into stdout) becomes a
    ``type: progress`` event.  On exit code 0 with a readable output file a
    ``type: result`` event carrying the parsed JSON is emitted, followed by
    ``type: complete``; any other outcome emits ``type: error``.  Results are
    written by the script to data/plans_bench_results/<run_id>.json.

    If the stream is closed before the script finishes (for example the
    client disconnects), the child process is terminated so it cannot keep
    running while the module reports that no benchmark is active.

    Raises:
        HTTPException: 409 when a benchmark is already running; 400 when no
            model key is supplied or a model/prompt ID starts with ``-``.
    """
    global _BENCH_RUNNING, _bench_proc

    if _BENCH_RUNNING:
        raise HTTPException(409, "A planning benchmark is already running")

    cfg = _load_config()
    python_bin = cfg["python_bin"]
    coordinator_url = cfg["coordinator_url"]

    model_keys = [m.strip() for m in models.split(",") if m.strip()]
    if not model_keys:
        raise HTTPException(400, "At least one model key is required")
    prompt_keys = [p.strip() for p in prompt_ids.split(",") if p.strip()]

    # These values come straight from the query string and are appended to the
    # script's argv; a leading dash would be parsed as an extra option.
    if any(value.startswith("-") for value in (*model_keys, *prompt_keys)):
        raise HTTPException(400, "Model and prompt IDs must not start with '-'")

    run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
    output_path = _results_file(run_id)
    _RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    def _sse(event: dict) -> str:
        return f"data: {json.dumps(event)}\n\n"

    def _build_cmd() -> list[str]:
        cmd = [python_bin, str(_BENCH_SCRIPT)]
        if len(model_keys) > 1:
            cmd.extend(["--compare"] + model_keys)
        else:
            cmd.extend(["--model", model_keys[0]])

        if use_cforch:
            cmd.extend(["--cforch", "--cforch-url", coordinator_url])
        elif api_base.strip():
            cmd.extend(["--api-base", api_base.strip()])

        cmd.extend(["--verbose", "--output", str(output_path)])
        if workers > 1:
            cmd.extend(["--workers", str(workers)])
        if prompt_keys:
            cmd.extend(["--prompts"] + prompt_keys)
        return cmd

    def generate():
        global _BENCH_RUNNING, _bench_proc

        if not _BENCH_SCRIPT.exists():
            yield _sse({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})
            return

        cmd = _build_cmd()
        _BENCH_RUNNING = True
        proc = None
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,
                text=True,
                bufsize=1,
                cwd=str(_ROOT),
            )
            _bench_proc = proc
            for line in proc.stdout:
                line = line.rstrip()
                if line:
                    yield _sse({'type': 'progress', 'message': line})
            proc.wait()
            if proc.returncode == 0 and output_path.exists():
                try:
                    results = json.loads(output_path.read_text(encoding="utf-8"))
                    yield _sse({'type': 'result', 'run_id': run_id, 'results': results})
                except Exception as exc:
                    logger.warning("Failed to read plans benchmark output: %s", exc)
                yield _sse({'type': 'complete'})
            else:
                yield _sse({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})
        except Exception as exc:
            yield _sse({'type': 'error', 'message': str(exc)})
        finally:
            if proc is not None:
                # Still alive here means the generator was closed early:
                # stop the child instead of orphaning it.
                if proc.poll() is None:
                    proc.terminate()
                    try:
                        proc.wait(timeout=5)
                    except _subprocess.TimeoutExpired:
                        proc.kill()
                        proc.wait()
                if proc.stdout is not None:
                    proc.stdout.close()
            # Only release the run state this generator owns: after a cancel a
            # newer run may already have installed its own process.
            if _bench_proc is proc:
                _bench_proc = None
                _BENCH_RUNNING = False

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
def list_results() -> list[dict]:
    """List past planning benchmark runs, newest first.

    Each entry carries the run id, file name, a display date derived from the
    run id, the model keys found in the file, and the mean ``total_score``
    over all non-error results.  Unreadable files are still listed, with no
    models and a score of 0.0.
    """
    if not _RESULTS_DIR.exists():
        return []

    def _summarise(path: Path) -> tuple[list, float]:
        # Any failure while reading or walking the file yields the empty summary.
        try:
            payload: dict = json.loads(path.read_text(encoding="utf-8"))
            names = list(payload.keys())
            scores = []
            for per_model in payload.values():
                for entry in per_model:
                    if not entry.get("error"):
                        scores.append(entry["total_score"])
            mean = round(sum(scores) / len(scores), 3) if scores else 0.0
        except Exception:
            return [], 0.0
        return names, mean

    def _display_date(stem: str) -> str:
        # plans_2026-04-27_143022 -> "2026-04-27 14:30"; fall back to the raw id.
        try:
            day, clock = stem.removeprefix("plans_").split("_")
            return f"{day} {clock[:2]}:{clock[2:4]}"
        except Exception:
            return stem

    runs: list[dict] = []
    for result_path in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
        stem = result_path.stem
        model_keys, avg_score = _summarise(result_path)
        runs.append({
            "run_id": stem,
            "filename": result_path.name,
            "date": _display_date(stem),
            "models": model_keys,
            "avg_score": avg_score,
        })
    return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
def get_latest_results() -> dict:
    """Return the parsed contents of the newest ``plans_*.json`` results file.

    "Newest" is the file whose path sorts last.

    Raises:
        HTTPException: 404 when no results exist; 500 when the file cannot
            be read or parsed.
    """
    candidates = (
        sorted(_RESULTS_DIR.glob("plans_*.json")) if _RESULTS_DIR.exists() else []
    )
    if not candidates:
        raise HTTPException(404, "No benchmark results found")

    newest = candidates[-1]
    try:
        return json.loads(newest.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{run_id}")
def get_results_by_run_id(run_id: str) -> dict:
    """Return planning benchmark results for a specific run.

    Raises:
        HTTPException: 400 when *run_id* lacks the ``plans_`` prefix or is
            not a bare file name; 404 when no results file exists for it;
            500 when the file cannot be read or parsed.
    """
    # run_id is joined into a filesystem path, so it must be a single path
    # component: Path(...).name differs from the input as soon as the value
    # contains a separator.
    if not run_id.startswith("plans_") or Path(run_id).name != run_id:
        raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")

    f = _results_file(run_id)
    if not f.exists():
        raise HTTPException(404, f"Results not found: {run_id}")
    try:
        return json.loads(f.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
def cancel_plans_benchmark() -> dict:
    """Terminate the running planning benchmark and clear the run state.

    A failure to terminate the process is logged and does not prevent the
    state from being reset.

    Raises:
        HTTPException: 404 when no benchmark is running.
    """
    global _BENCH_RUNNING, _bench_proc

    if not _BENCH_RUNNING:
        raise HTTPException(404, "No planning benchmark is currently running")

    target = _bench_proc
    if target is not None:
        try:
            target.terminate()
        except Exception as exc:
            logger.warning("Failed to terminate plans benchmark: %s", exc)

    _BENCH_RUNNING = False
    _bench_proc = None
    return {"status": "cancelled"}
|
||||
343
app/sft.py
343
app/sft.py
|
|
@ -1,335 +1,8 @@
|
|||
"""Avocet — SFT candidate import and correction API.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
|
||||
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
|
||||
testability pattern as api.py — override them via set_sft_data_dir() and
|
||||
set_sft_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_SFT_DATA_DIR: Path = _ROOT / "data"
|
||||
_SFT_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────
|
||||
|
||||
def set_sft_data_dir(path: Path) -> None:
    """Redirect SFT data files to *path* (test seam)."""
    global _SFT_DATA_DIR
    _SFT_DATA_DIR = path
|
||||
|
||||
|
||||
def set_sft_config_dir(path: Path | None) -> None:
    """Point config lookups at *path*; ``None`` restores the default location."""
    global _SFT_CONFIG_DIR
    _SFT_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
    """Return the label_tool.yaml path, honouring the test override directory."""
    base_dir = _ROOT / "config" if _SFT_CONFIG_DIR is None else _SFT_CONFIG_DIR
    return base_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
    """Replace the fallback bench_results_dir; tests use this to stay off the real filesystem."""
    global _DEFAULT_BENCH_RESULTS_DIR
    _DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
    """Return the directory that holds benchmark runs.

    Uses ``sft.bench_results_dir`` from label_tool.yaml when it is set to a
    non-empty value; otherwise falls back to the module default.  A missing
    file, a YAML syntax error, a non-mapping root, or an empty/null ``sft:``
    section all fall back to the default instead of raising.
    """
    f = _config_file()
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse SFT config %s: %s", f, exc)
        else:
            # `sft:` with no body parses as None, so it must not be chained
            # with .get() directly.
            sft_section = raw.get("sft") if isinstance(raw, dict) else None
            if isinstance(sft_section, dict):
                d = sft_section.get("bench_results_dir", "")
                if d:
                    return Path(d)
    return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
    """Return the path of the candidates JSONL file in the SFT data dir."""
    return _SFT_DATA_DIR.joinpath("sft_candidates.jsonl")
|
||||
|
||||
|
||||
def _approved_file() -> Path:
    """Return the path of the approved-records JSONL file in the SFT data dir."""
    return _SFT_DATA_DIR.joinpath("sft_approved.jsonl")
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
    """Load every record from the candidates file."""
    source = _candidates_file()
    return read_jsonl(source)
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
    """Write *records* to the candidates file via ``write_jsonl``."""
    target = _candidates_file()
    write_jsonl(target, records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# ── GET /runs ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/runs")
def get_runs():
    """List benchmark runs found in the configured bench_results_dir.

    Each entry reports the run id, timestamp, candidate count, and whether a
    local candidate record already references that run.
    """
    from scripts.sft_import import discover_runs

    bench_dir = _get_bench_results_dir()

    # benchmark_run_id in each record equals the run's directory name by cf-orch convention
    imported_run_ids = set()
    for record in _read_candidates():
        if record.get("benchmark_run_id") is not None:
            imported_run_ids.add(record["benchmark_run_id"])

    listing = []
    for run in discover_runs(bench_dir):
        listing.append({
            "run_id": run["run_id"],
            "timestamp": run["timestamp"],
            "candidate_count": run["candidate_count"],
            "already_imported": run["run_id"] in imported_run_ids,
        })
    return listing
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────
|
||||
|
||||
class ImportRequest(BaseModel):
    # Identifier of the benchmark run to import.
    run_id: str


@router.post("/import")
def post_import(req: ImportRequest):
    """Import one discovered benchmark run into the local SFT data dir.

    Returns whatever ``import_run`` returns for the matching run.

    Raises:
        HTTPException: 404 when no discovered run has the requested id.
    """
    from scripts.sft_import import discover_runs, import_run

    available = discover_runs(_get_bench_results_dir())
    for candidate in available:
        if candidate["run_id"] == req.run_id:
            return import_run(candidate["sft_path"], _SFT_DATA_DIR)
    raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
|
||||
|
||||
# ── GET /queue ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return one page of candidates that still need review.

    Args:
        page: 1-based page number.
        per_page: Number of items per page.

    Returns:
        Dict with the page ``items``, the ``total`` number of pending
        candidates, and the echoed ``page`` / ``per_page``.

    Raises:
        HTTPException: 422 when ``page`` or ``per_page`` is below 1.
    """
    # A non-positive page gives a negative slice start, which would silently
    # return items counted from the END of the queue.
    if page < 1 or per_page < 1:
        raise HTTPException(422, "page and per_page must be >= 1")

    records = _read_candidates()
    pending = [r for r in records if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
|
||||
|
||||
|
||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
||||
|
||||
# Closed set of reasons a reviewer can attach to a corrected candidate.
FailureCategory = Literal[
    "scoring_artifact",
    "style_violation",
    "partial_answer",
    "wrong_answer",
    "format_error",
    "hallucination",
]


class SubmitRequest(BaseModel):
    # Candidate id plus the reviewer's decision; corrected_response must be
    # non-blank when action is "correct" (enforced in post_submit).
    id: str
    action: Literal["correct", "discard", "flag"]
    corrected_response: str | None = None
    failure_category: FailureCategory | None = None


@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Apply a reviewer decision to one candidate in ``needs_review`` state.

    ``correct`` marks the record approved, stores the correction and failure
    category, and appends the record to the approved file; ``discard`` marks
    it discarded; ``flag`` marks it model_rejected.

    Raises:
        HTTPException: 422 for a blank correction on ``correct``; 404 when the
            id is unknown; 409 when the record is not awaiting review.
    """
    is_correction = req.action == "correct"
    if is_correction and not (req.corrected_response and req.corrected_response.strip()):
        raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    records = _read_candidates()
    for idx, record in enumerate(records):
        if record.get("id") == req.id:
            break
    else:
        raise HTTPException(404, f"Record {req.id!r} not found")

    if record.get("status") != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")

    if is_correction:
        updated = {
            **record,
            "status": "approved",
            "corrected_response": req.corrected_response,
            "failure_category": req.failure_category,
        }
    elif req.action == "discard":
        updated = {**record, "status": "discarded"}
    else:  # flag
        updated = {**record, "status": "model_rejected"}

    records[idx] = updated
    _write_candidates(records)
    if is_correction:
        # Approved records are additionally appended to the approved file.
        append_jsonl(_approved_file(), updated)

    return {"ok": True}
|
||||
|
||||
|
||||
# ── POST /undo ─────────────────────────────────────────────────────────────
|
||||
|
||||
class UndoRequest(BaseModel):
    # Id of the candidate whose decision should be reverted.
    id: str


@router.post("/undo")
def post_undo(req: UndoRequest):
    """Put a previously actioned candidate back into ``needs_review``.

    The stored correction is cleared, and a record that had been approved is
    also removed from the approved file.

    Raises:
        HTTPException: 404 when the id is unknown; 409 when the record is
            already awaiting review.
    """
    records = _read_candidates()
    for idx, record in enumerate(records):
        if record.get("id") == req.id:
            break
    else:
        raise HTTPException(404, f"Record {req.id!r} not found")

    previous_status = record.get("status")
    if previous_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    records[idx] = {**record, "status": "needs_review", "corrected_response": None}
    _write_candidates(records)

    if previous_status == "approved":
        # Drop the record from the approved file as well.
        remaining = [row for row in read_jsonl(_approved_file()) if row.get("id") != req.id]
        write_jsonl(_approved_file(), remaining)

    return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
def get_export() -> StreamingResponse:
    """Stream exportable approved records as a JSONL download.

    Each output line is ``{"messages": [...]}``: the record's
    ``prompt_messages`` followed by one assistant turn holding the corrected
    response.  The download file name carries a UTC timestamp.
    """
    approved_rows = read_jsonl(_approved_file())
    exportable = [row for row in approved_rows if _is_exportable(row)]

    def generate():
        for row in exportable:
            assistant_turn = {"role": "assistant", "content": row["corrected_response"]}
            conversation = row.get("prompt_messages", []) + [assistant_turn]
            yield json.dumps({"messages": conversation}) + "\n"

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    disposition = f'attachment; filename="sft_export_{timestamp}.jsonl"'
    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={"Content-Disposition": disposition},
    )
|
||||
|
||||
|
||||
# ── GET /stats ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
def get_stats() -> dict[str, object]:
    """Return candidate counts by status, model, and task type.

    Records missing a field are counted under ``"unknown"``.  ``export_ready``
    is the number of records in the approved file that pass ``_is_exportable``.
    """
    records = _read_candidates()

    def _tally(field: str) -> dict[str, int]:
        counts: dict[str, int] = {}
        for record in records:
            bucket = record.get(field, "unknown")
            counts[bucket] = counts.get(bucket, 0) + 1
        return counts

    by_status = _tally("status")
    by_model = _tally("model_name")
    by_task_type = _tally("task_type")

    export_ready = len([row for row in read_jsonl(_approved_file()) if _is_exportable(row)])

    return {
        "total": len(records),
        "by_status": by_status,
        "by_model": by_model,
        "by_task_type": by_task_type,
        "export_ready": export_ready,
    }
|
||||
|
||||
|
||||
# ── GET /config ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config")
def get_sft_config() -> dict:
    """Return the configured ``bench_results_dir``.

    The value is the empty string when the config file is missing, cannot be
    parsed as YAML, or does not set the key.
    """
    unset = {"bench_results_dir": ""}
    config_path = _config_file()
    if not config_path.exists():
        return unset
    try:
        parsed = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError:
        return unset
    section = parsed.get("sft") or {}
    return {"bench_results_dir": section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
    # Directory that holds the benchmark runs available for import.
    bench_results_dir: str


@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
    """Persist ``sft.bench_results_dir`` to the config file.

    All other top-level sections, and any other keys already present under
    ``sft:``, are kept.  A file that is missing, unparsable, or whose root is
    not a mapping is treated as empty.  The new content is written to a
    sibling ``.tmp`` file and then moved over the original.
    """
    f = _config_file()
    f.parent.mkdir(parents=True, exist_ok=True)
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
    except yaml.YAMLError:
        raw = {}
    if not isinstance(raw, dict):
        # Empty file (None) or a list/scalar root: item assignment below
        # needs a mapping.
        raw = {}

    existing = raw.get("sft")
    sft_section = dict(existing) if isinstance(existing, dict) else {}
    sft_section["bench_results_dir"] = payload.bench_results_dir
    raw["sft"] = sft_section

    tmp = f.with_suffix(".tmp")
    tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
    # replace() overwrites an existing target on every platform, whereas
    # rename() raises on Windows when the config file already exists.
    tmp.replace(f)
    return {"ok": True}
|
||||
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
|
||||
from app.data.corrections import ( # noqa: F401
|
||||
router,
|
||||
set_data_dir as set_sft_data_dir,
|
||||
set_config_dir as set_sft_config_dir,
|
||||
set_default_bench_results_dir,
|
||||
_DEFAULT_BENCH_RESULTS_DIR,
|
||||
)
|
||||
|
|
|
|||
0
app/train/__init__.py
Normal file
0
app/train/__init__.py
Normal file
339
app/train/train.py
Normal file
339
app/train/train.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
"""Avocet -- train job queue API.
|
||||
|
||||
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
|
||||
_running_procs dict in api.py with a persistent, inspectable queue.
|
||||
|
||||
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
|
||||
GET /jobs -- list all jobs, newest first
|
||||
POST /jobs -- create a new job
|
||||
GET /jobs/{id} -- get one job by id
|
||||
DELETE /jobs/{id}/cancel -- cancel a queued or running job
|
||||
GET /jobs/{id}/run -- SSE: run the job, stream stdout
|
||||
GET /results -- list completed models with training_info.json metrics
|
||||
|
||||
SQLite schema:
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
|
||||
Testability seam: _DB_PATH global, override via set_db_path().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_running_procs: dict[str, Any] = {}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams -------------------------------------------------
|
||||
|
||||
def set_db_path(path: Path) -> None:
    """Use *path* as the SQLite job database (test seam)."""
    global _DB_PATH
    _DB_PATH = path
|
||||
|
||||
def set_models_dir(path: Path) -> None:
    """Use *path* as the models directory (test seam)."""
    global _MODELS_DIR
    _MODELS_DIR = path
|
||||
|
||||
def set_config_dir(path: "Path | None") -> None:
    """Point config lookups at *path*; ``None`` restores the default location."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Config helpers ----------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
    """Return the label_tool.yaml path, honouring the test override directory."""
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_train_config() -> dict:
    """Read python_bin from label_tool.yaml.

    Priority (highest to lowest):
      1. label_tool.yaml train: python_bin
      2. label_tool.yaml cforch: python_bin
      3. Hardcoded default (classifiers conda env)

    A missing file, a YAML syntax error, or a root/section that is not a
    mapping all fall through to the next source instead of raising.

    Returns:
        Dict with a single ``python_bin`` key.
    """
    _DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"

    def _section(root: object, key: str) -> dict:
        # A root or section that is a list/scalar/null is treated as absent;
        # calling .get() on it would otherwise raise AttributeError.
        value = root.get(key) if isinstance(root, dict) else None
        return value if isinstance(value, dict) else {}

    f = _config_file()
    train_cfg: dict = {}
    cforch_cfg: dict = {}
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse train config %s: %s", f, exc)
        else:
            train_cfg = _section(raw, "train")
            cforch_cfg = _section(raw, "cforch")

    return {
        "python_bin": train_cfg.get(
            "python_bin",
            cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
        ),
    }
|
||||
|
||||
|
||||
# -- Database helpers --------------------------------------------------
|
||||
|
||||
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
    """Yield a connection to the job database with name-addressable rows.

    The transaction is committed when the ``with`` body finishes without an
    exception; the connection is closed in every case.
    """
    connection = sqlite3.connect(str(_DB_PATH))
    connection.row_factory = sqlite3.Row
    try:
        yield connection
        # Reached only when the caller's block did not raise.
        connection.commit()
    finally:
        connection.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
    """Ensure the database directory and the ``jobs`` table exist.

    Called lazily per request; creation is skipped when the table is present.
    """
    _DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    schema = """
            CREATE TABLE IF NOT EXISTS jobs (
                id           TEXT PRIMARY KEY,
                type         TEXT NOT NULL,
                model_key    TEXT NOT NULL,
                status       TEXT NOT NULL DEFAULT 'queued',
                config_json  TEXT NOT NULL DEFAULT '{}',
                created_at   TEXT NOT NULL,
                started_at   TEXT,
                completed_at TEXT,
                error        TEXT
            )
        """
    with _db() as connection:
        connection.execute(schema)
|
||||
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||
return {k: row[k] for k in row.keys()}
|
||||
|
||||
|
||||
# -- GPU selection (copied from api.py) --------------------------------
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return index of GPU with most free VRAM, or empty string."""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True, timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
# -- Pydantic models ---------------------------------------------------
|
||||
|
||||
class CreateJobRequest(BaseModel):
    # Request body for creating a job.
    type: str               # "classifier" | "llm-sft"
    model_key: str          # e.g. "deberta-small"
    config_json: dict = {}  # free-form job options, stored as JSON text
|
||||
|
||||
|
||||
# -- Routes ------------------------------------------------------------
|
||||
|
||||
@router.get("/jobs")
def list_jobs() -> dict:
    """Return every job, newest first, under the ``jobs`` key."""
    _init_db()
    with _db() as connection:
        cursor = connection.execute("SELECT * FROM jobs ORDER BY created_at DESC")
        jobs = [_row_to_dict(record) for record in cursor.fetchall()]
    return {"jobs": jobs}
|
||||
|
||||
|
||||
@router.post("/jobs")
def create_job(req: CreateJobRequest) -> dict:
    """Insert a new job in ``queued`` state and return its full record.

    Raises:
        HTTPException: 400 when ``req.type`` is not a supported job type.
    """
    supported_types = ("classifier", "llm-sft")
    if req.type not in supported_types:
        raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")

    _init_db()
    job_id = str(uuid.uuid4())
    created = datetime.now(timezone.utc).isoformat()
    insert_sql = (
        "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
        "VALUES (?, ?, ?, 'queued', ?, ?)"
    )
    with _db() as connection:
        # config_json is stored as JSON text but returned as the original dict.
        connection.execute(
            insert_sql,
            (job_id, req.type, req.model_key, json.dumps(req.config_json), created),
        )

    return {
        "id": job_id,
        "type": req.type,
        "model_key": req.model_key,
        "status": "queued",
        "config_json": req.config_json,
        "created_at": created,
        "started_at": None,
        "completed_at": None,
        "error": None,
    }
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
def get_job(job_id: str) -> dict:
    """Return one job record.

    Raises:
        HTTPException: 404 when no job has the given id.
    """
    _init_db()
    with _db() as connection:
        record = connection.execute(
            "SELECT * FROM jobs WHERE id = ?", (job_id,)
        ).fetchone()
    if record is not None:
        return _row_to_dict(record)
    raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}/cancel")
def cancel_job(job_id: str) -> dict:
    """Mark a queued or running job as cancelled and stop its process.

    The database row is updated first.  If a process is registered for the
    job it is terminated, and killed when it has not exited after 3 seconds.

    Raises:
        HTTPException: 404 when the job does not exist; 409 when it is in a
            state other than ``queued`` or ``running``.
    """
    _init_db()
    with _db() as connection:
        record = connection.execute(
            "SELECT * FROM jobs WHERE id = ?", (job_id,)
        ).fetchone()
        if record is None:
            raise HTTPException(404, f"Job {job_id!r} not found")
        if record["status"] not in ("queued", "running"):
            raise HTTPException(409, f"Job is {record['status']} -- cannot cancel")
        cancelled_at = datetime.now(timezone.utc).isoformat()
        connection.execute(
            "UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?",
            (cancelled_at, job_id),
        )

    running = _running_procs.pop(job_id, None)
    if running is None:
        return {"status": "cancelled"}

    try:
        running.terminate()
        running.wait(timeout=3)
    except _subprocess.TimeoutExpired:
        # Graceful stop timed out; force it and ignore a process already gone.
        try:
            running.kill()
        except OSError:
            pass
    return {"status": "cancelled"}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/run")
def run_job(job_id: str) -> StreamingResponse:
    """Start a queued training job and stream its output as Server-Sent Events.

    The job's fine-tuning script is launched as a subprocess. Every non-empty
    output line (stderr is merged into stdout) is forwarded as a ``progress``
    event, followed by one final ``complete`` or ``error`` event.

    Terminal status writes are guarded with ``AND status='running'``. Without
    the guard, a job cancelled via ``cancel_job`` while its subprocess is alive
    would have its ``cancelled`` status overwritten with ``failed`` as soon as
    the terminated process exits with a non-zero return code.

    Args:
        job_id: Primary key of the row in the ``jobs`` table.

    Returns:
        A ``text/event-stream`` response driven by the subprocess output.

    Raises:
        HTTPException: 404 when the job does not exist, 409 when it is not in
            the ``queued`` state.
    """
    _init_db()
    with _db() as conn:
        row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
        if row is None:
            raise HTTPException(404, f"Job {job_id!r} not found")
        if row["status"] != "queued":
            raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
        job = _row_to_dict(row)

    def _sse(payload: dict) -> str:
        """Serialise one event in SSE wire format."""
        return f"data: {json.dumps(payload)}\n\n"

    def _mark_finished(status: str, error=None) -> None:
        """Record a terminal status, but only if the job is still 'running'."""
        finished_at = datetime.now(timezone.utc).isoformat()
        with _db() as conn:
            if error is None:
                conn.execute(
                    "UPDATE jobs SET status=?, completed_at=? "
                    "WHERE id=? AND status='running'",
                    (status, finished_at, job_id))
            else:
                conn.execute(
                    "UPDATE jobs SET status=?, completed_at=?, error=? "
                    "WHERE id=? AND status='running'",
                    (status, finished_at, error, job_id))

    def generate():
        cfg = _load_train_config()
        python_bin = cfg["python_bin"]
        config = json.loads(job["config_json"] or "{}")
        model_key = job["model_key"]
        epochs = config.get("epochs", 5)

        if job["type"] == "classifier":
            script = str(_ROOT / "scripts" / "finetune_classifier.py")
            cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
            data_dir = _ROOT / "data"
            data_root = data_dir.resolve()
            for sf in config.get("score_files", []):
                resolved = (data_dir / sf).resolve()
                # Score files that resolve outside data/ are dropped so a
                # crafted relative path cannot reach arbitrary files.
                if resolved.is_relative_to(data_root):
                    cmd.extend(["--score", str(resolved)])
        elif job["type"] == "llm-sft":
            script = str(_ROOT / "scripts" / "finetune_sft.py")
            cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
        else:
            job_type = job["type"]
            yield _sse({"type": "error", "message": f"Unknown job type: {job_type}"})
            return

        proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
        best_gpu = _best_cuda_device()
        if best_gpu:
            proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu

        gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
        yield _sse({"type": "progress", "message": f"[train] Using {gpu_note}"})

        now = datetime.now(timezone.utc).isoformat()
        with _db() as conn:
            conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))

        try:
            proc = _subprocess.Popen(
                cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
                text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
            )
            # Registered so cancel_job can find and terminate the process.
            _running_procs[job_id] = proc
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield _sse({"type": "progress", "message": line})
                proc.wait()
                if proc.returncode == 0:
                    _mark_finished("completed")
                    yield _sse({"type": "complete"})
                else:
                    err = f"Process exited with code {proc.returncode}"
                    _mark_finished("failed", err)
                    yield _sse({"type": "error", "message": err})
            finally:
                _running_procs.pop(job_id, None)
        except Exception as exc:
            # Boundary handler: surface launch/stream failures to the client
            # and persist them on the job row.
            err = str(exc)
            _mark_finished("failed", err)
            yield _sse({"type": "error", "message": err})

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
@router.get("/results")
def list_results() -> dict:
    """Collect ``training_info.json`` from every sub-directory of ``_MODELS_DIR``.

    Directories without the file are ignored. Files that cannot be read or
    parsed are logged as warnings and skipped.

    Returns:
        ``{"results": [...]}`` with one parsed JSON document per readable file;
        the list is empty when ``_MODELS_DIR`` does not exist.
    """
    collected = []
    if _MODELS_DIR.exists():
        for model_dir in _MODELS_DIR.iterdir():
            info_file = model_dir / "training_info.json"
            if not (model_dir.is_dir() and info_file.exists()):
                continue
            try:
                collected.append(json.loads(info_file.read_text(encoding="utf-8")))
            except Exception as exc:
                logger.warning("Failed to read training_info.json from %s: %s", info_file, exc)
    return {"results": collected}
|
||||
|
|
@ -106,7 +106,7 @@ def read_jsonl(path: Path) -> list[dict]:
|
|||
def write_jsonl(path: Path, records: list[dict]) -> None:
    """Write records to a JSONL file, overwriting any existing content.

    Parent directories are created as needed. Each record is serialised on its
    own newline-terminated line with ``ensure_ascii=False``, so non-ASCII text
    is stored as UTF-8 rather than ``\\uXXXX`` escapes. An empty ``records``
    list produces an empty file.

    Args:
        path: Destination file.
        records: JSON-serialisable dicts, written in order.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialise once; every line, including the last, ends with "\n".
    content = "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in records)
    path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
|
|
@ -114,4 +114,4 @@ def append_jsonl(path: Path, record: dict) -> None:
|
|||
"""Append a single record to a JSONL file."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(record) + "\n")
|
||||
fh.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
|
|
|||
|
|
@ -41,11 +41,20 @@ cforch:
|
|||
# Python interpreter with cf-orch installed
|
||||
python_bin: /devl/miniconda3/envs/cf/bin/python
|
||||
|
||||
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST
|
||||
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST / CF_JUDGE_URL / HF_TOKEN
|
||||
# coordinator_url: http://localhost:7700
|
||||
# license_key: CFG-AVCT-xxxx-xxxx-xxxx
|
||||
# ollama_url: http://localhost:11434
|
||||
# ollama_model: llama3.2:3b
|
||||
# embed_model: nomic-embed-text # Ollama embedding model for EmbeddingKNNAdapter
|
||||
# judge_url: http://10.1.10.158:8008 # Sif cf-text — LLM-as-judge secondary scorer
|
||||
# judge_url: http://10.1.10.71:8008 # Heimdall cf-text (alternative)
|
||||
# Or set CF_JUDGE_URL. Populates the Judge URL field in the LLM Eval UI automatically.
|
||||
# hf_token: hf_xxxxxxxxxxxxxxxxxxxx # HuggingFace token — required for gated/terms-restricted models
|
||||
|
||||
# Directory containing per-node profile YAMLs (cf-orch node profiles).
|
||||
# Default: derived from bench_script location (../../profiles/nodes).
|
||||
# profiles_dir: /Library/Development/CircuitForge/circuitforge-orch/circuitforge_orch/profiles/nodes
|
||||
|
||||
# Imitate tab — pull real samples from sibling CF product APIs and run them
|
||||
# through local LLMs to build a corrections dataset.
|
||||
|
|
@ -101,12 +110,3 @@ imitate:
|
|||
sample_endpoint: /api/listings
|
||||
text_fields: [title, description, seller_info]
|
||||
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
|
||||
|
||||
- id: osprey
|
||||
name: Osprey
|
||||
icon: "📞"
|
||||
description: Gov't hold-line automation
|
||||
base_url: http://localhost:8520
|
||||
sample_endpoint: /api/calls/recent
|
||||
text_fields: [agency, issue, notes]
|
||||
prompt_template: "Draft a concise summary of this government call record:\n\n{text}"
|
||||
|
|
|
|||
35
manage.sh
35
manage.sh
|
|
@ -90,6 +90,12 @@ usage() {
|
|||
echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]"
|
||||
echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]"
|
||||
echo ""
|
||||
echo " Planning Benchmark:"
|
||||
echo -e " ${GREEN}plans-bench [args]${NC} Run benchmark_plans.py (args passed through)"
|
||||
echo -e " ${GREEN}plans-list${NC} Shortcut: --list-models"
|
||||
echo -e " ${GREEN}plans-run <model> [args]${NC} Run a single model (--verbose auto-added)"
|
||||
echo -e " ${GREEN}plans-compare <m1> <m2> [more]${NC} Compare models side-by-side"
|
||||
echo ""
|
||||
echo " Writing Style Benchmark:"
|
||||
echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)"
|
||||
echo -e " ${GREEN}style-list${NC} List available ollama models for style bench"
|
||||
|
|
@ -127,6 +133,8 @@ case "$CMD" in
|
|||
fi
|
||||
mkdir -p "$LOG_DIR"
|
||||
API_LOG="${LOG_DIR}/api.log"
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
info "Building Vue SPA…"
|
||||
(cd web && npm run build) >> "$API_LOG" 2>&1
|
||||
info "Starting FastAPI on port ${API_PORT}…"
|
||||
|
|
@ -179,6 +187,9 @@ case "$CMD" in
|
|||
mkdir -p "$LOG_DIR"
|
||||
DEV_API_LOG="${LOG_DIR}/dev-api.log"
|
||||
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
|
||||
if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
|
||||
warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))"
|
||||
else
|
||||
|
|
@ -255,6 +266,30 @@ case "$CMD" in
|
|||
exec "$0" benchmark --compare "$@"
|
||||
;;
|
||||
|
||||
plans-bench)
|
||||
info "Running planning benchmark (${ENV_UI})…"
|
||||
"$PYTHON_UI" scripts/benchmark_plans.py "$@"
|
||||
;;
|
||||
|
||||
plans-list)
|
||||
exec "$0" plans-bench --list-models
|
||||
;;
|
||||
|
||||
plans-run)
|
||||
if [[ $# -lt 1 ]]; then
|
||||
error "Usage: ./manage.sh plans-run <model-key> [extra args]"
|
||||
fi
|
||||
MODEL="$1"; shift
|
||||
exec "$0" plans-bench --model "$MODEL" --verbose "$@"
|
||||
;;
|
||||
|
||||
plans-compare)
|
||||
if [[ $# -lt 2 ]]; then
|
||||
error "Usage: ./manage.sh plans-compare <model1> <model2> [more…]"
|
||||
fi
|
||||
exec "$0" plans-bench --compare "$@" --verbose
|
||||
;;
|
||||
|
||||
style-bench)
|
||||
info "Running writing style benchmark (${ENV_BM})…"
|
||||
if [[ ! -x "$PYTHON_BM" ]]; then
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ pydantic>=2.0.0
|
|||
uvicorn[standard]>=0.20.0
|
||||
httpx>=0.24.0
|
||||
pytest>=7.0.0
|
||||
pyyaml>=6.0
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from scripts.classifier_adapters import (
|
|||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
EmbeddingKNNAdapter,
|
||||
FineTunedAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
|
|
@ -130,6 +131,13 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
|
|||
"params": "600M",
|
||||
"default": False,
|
||||
},
|
||||
"embed-knn-nomic": {
|
||||
"adapter": EmbeddingKNNAdapter,
|
||||
"model_id": "nomic-embed-text",
|
||||
"params": "local-embed",
|
||||
"default": False, # requires orch or ollama; use --include-slow
|
||||
"kwargs": {"k": 3},
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -184,6 +192,42 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
|||
return found
|
||||
|
||||
|
||||
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
    """Sample up to k_per_label formatted email texts per label from a scored JSONL.

    Each usable row is formatted as "Subject: <subject>", a blank line, then
    the first 600 characters of the body -- the same format
    EmbeddingKNNAdapter uses at classify() time.

    Rows are skipped when they are blank, are not a JSON object, have no
    'label', or have neither subject nor body. Malformed and non-object lines
    are reported on stdout. A JSON ``null`` subject or body is treated as an
    empty string instead of raising or rendering as "None".

    Args:
        path: Path to the scored JSONL file.
        k_per_label: Maximum number of texts kept per label; the first
            matching rows in file order win.

    Returns:
        dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
    """
    result: dict[str, list[str]] = {}
    with Path(path).open(encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"[build_exemplars] WARN: skipping malformed line: {exc}", flush=True)
                continue
            if not isinstance(row, dict):
                # Valid JSON but not an object (e.g. a list or number).
                print("[build_exemplars] WARN: skipping non-object line", flush=True)
                continue
            label = row.get("label")
            if not label:
                continue
            # `or ""` also covers keys that are present with a null value.
            subject = row.get("subject") or ""
            body = row.get("body") or ""
            if not subject and not body:
                continue
            texts = result.setdefault(label, [])
            if len(texts) < k_per_label:
                texts.append(f"Subject: {subject}\n\n{body[:600]}")
    return result
|
||||
|
||||
|
||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
||||
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
||||
active: dict[str, dict[str, Any]] = {
|
||||
|
|
|
|||
719
scripts/benchmark_plans.py
Normal file
719
scripts/benchmark_plans.py
Normal file
|
|
@ -0,0 +1,719 @@
|
|||
#!/usr/bin/env python
|
||||
"""CF-specific planning benchmark — compare base models before fine-tuning.
|
||||
|
||||
Sends held-out CircuitForge planning prompts to one or more models via the
|
||||
cf-text (local) or cf-orch API, then scores responses against CF-specific
|
||||
rubrics. Use this to select the best base model for SFT.
|
||||
|
||||
Scoring rubrics (each 0-1, summed to total/N):
|
||||
- task_structure : uses checkbox syntax (- [ ]), git commit steps
|
||||
- tier_awareness : mentions Free/Paid/Premium/Ultra tiers
|
||||
- privacy_pillar : mentions privacy/local-inference/no-logging
|
||||
- safety_pillar : mentions safety, human approval, or reversibility
|
||||
- accessibility : mentions ND/accessibility/adaptive needs
|
||||
- license_split : mentions MIT vs BSL or open-core model
|
||||
- file_paths : uses plausible file path references
|
||||
- cf_conventions : uses conda run -n cf, /Library/Development/, or known CF dirs
|
||||
- paired_coherence : (paired only) plan references the design doc's feature name
|
||||
- length_ok : 200–2500 words (under-short = hallucination risk; over-long = padding)
|
||||
|
||||
Usage
|
||||
-----
|
||||
# List available model targets
|
||||
python scripts/benchmark_plans.py --list-models
|
||||
|
||||
# Run all held-out prompts against a single model, print report
|
||||
python scripts/benchmark_plans.py --model llama3.2-3b
|
||||
|
||||
# Compare two models side-by-side
|
||||
python scripts/benchmark_plans.py --compare llama3.2-3b mistral-7b
|
||||
|
||||
# Run with a custom API base (cf-text default: http://localhost:8080/v1)
|
||||
python scripts/benchmark_plans.py --model llama3.2-3b --api-base http://localhost:8080/v1
|
||||
|
||||
# Export detailed results JSON
|
||||
python scripts/benchmark_plans.py --model llama3.2-3b --output data/bench_results.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR = _ROOT / "data"
|
||||
|
||||
CF_TEXT_BASE = "http://localhost:8080/v1"
|
||||
CF_ORCH_BASE = "http://localhost:8090/v1"
|
||||
CF_COORD_URL = "http://10.1.10.71:7700" # cf-orch coordinator (LAN)
|
||||
|
||||
# ── Held-out prompts ───────────────────────────────────────────────────────────
|
||||
# These are NOT in the training export (no matching docs in circuitforge-plans/).
|
||||
# Each prompt exercises a different CF planning domain.
|
||||
|
||||
HELD_OUT_PROMPTS: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "ho_001",
|
||||
"name": "kiwi_barcode_ocr",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Kiwi, a CircuitForge pantry-tracking product. "
|
||||
"Write a detailed implementation plan for adding barcode scanning via device camera "
|
||||
"and receipt OCR to the item-add flow.\n\n"
|
||||
"The plan should include: file structure (create/modify), step-by-step task checklist "
|
||||
"with checkboxes, any DB migrations, and git commit steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_002",
|
||||
"name": "peregrine_ats_scoring",
|
||||
"domain": "feature_design",
|
||||
"prompt": (
|
||||
"Write a design document for Peregrine: ATS keyword scoring for job applications.\n\n"
|
||||
"Context: Peregrine users paste job descriptions and their resume. "
|
||||
"We want to score how well the resume keywords match the JD and suggest rewrites. "
|
||||
"Describe the architecture, data flow, and key design decisions."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "tier_awareness", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_003",
|
||||
"name": "tier_gate_local_llm",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the tier-gating architecture for a new CircuitForge product. "
|
||||
"The product should:\n"
|
||||
"- Default to local LLM inference for all tiers\n"
|
||||
"- Unlock cloud LLM for Paid tier and above\n"
|
||||
"- Keep fine-tuned model weights for Premium/Ultra only\n\n"
|
||||
"Describe how the tier check integrates with the LLM router, "
|
||||
"what happens when a Free user tries a Paid-tier feature, "
|
||||
"and how BYOK (bring-your-own-key) fits in."
|
||||
),
|
||||
"expected_signals": ["tier_awareness", "privacy_pillar", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_004",
|
||||
"name": "heimdall_webhook_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Heimdall feature into a detailed implementation plan with "
|
||||
"file structure and task checkboxes — Stripe webhook handler for subscription lifecycle.\n\n"
|
||||
"Heimdall is the CircuitForge license server (FastAPI + SQLite). "
|
||||
"The webhook needs to handle checkout.session.completed, "
|
||||
"customer.subscription.updated, and customer.subscription.deleted events."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_005",
|
||||
"name": "nd_accessible_onboarding",
|
||||
"domain": "ux_design",
|
||||
"prompt": (
|
||||
"You are a product designer working on Harrier, a CircuitForge tool for "
|
||||
"helping people navigate government benefits applications.\n\n"
|
||||
"Design the onboarding flow for neurodivergent (ND) users. "
|
||||
"Consider: ADHD time-blindness, executive function challenges, demand avoidance, "
|
||||
"and rejection sensitivity. The flow should reduce cognitive load and "
|
||||
"never use urgency or panic patterns."
|
||||
),
|
||||
"expected_signals": ["accessibility", "safety_pillar", "privacy_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_006",
|
||||
"name": "circuitforge_core_extraction",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Produce a CircuitForge-style design document for the following circuitforge-core "
|
||||
"feature — shared ActivityPub federation module.\n\n"
|
||||
"Background: Multiple CF products (Kiwi, Rook, Snipe) want to publish updates "
|
||||
"to ActivityPub. Build it once in cf-core (MIT licensed) so all products can use it. "
|
||||
"Design the module API, describe what belongs in MIT vs BSL, and note federation "
|
||||
"privacy constraints."
|
||||
),
|
||||
"expected_signals": ["license_split", "privacy_pillar", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_007",
|
||||
"name": "snipe_trust_score_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Snipe, a CircuitForge eBay trust-scoring tool. "
|
||||
"Write a step-by-step engineering plan for: seller trust score calculation.\n\n"
|
||||
"The score should combine: feedback ratio, account age, item-specifics completeness, "
|
||||
"listing photo quality, and shipping time accuracy. "
|
||||
"Include file structure, test plan, and migration steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_008",
|
||||
"name": "avocet_training_pipeline",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Avocet feature into a detailed implementation plan — "
|
||||
"end-to-end fine-tuning pipeline from labeled JSONL to deployed GGUF model.\n\n"
|
||||
"Avocet is the CircuitForge email classifier training tool. "
|
||||
"The pipeline should: validate the dataset, run LoRA SFT via unsloth, "
|
||||
"quantize to Q5_K_M GGUF, run the benchmark harness, and register the model "
|
||||
"in the Avocet model queue if it beats the baseline."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_009",
|
||||
"name": "privacy_data_flow",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the data privacy architecture for a CircuitForge cloud product. "
|
||||
"Describe: what PII is collected, how it's stored, retention policy, "
|
||||
"obfuscation strategy for cloud-side logs, and how consent is obtained "
|
||||
"in plain language. The product handles job applications (resumes, cover letters)."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "safety_pillar", "accessibility"],
|
||||
},
|
||||
{
|
||||
"id": "ho_010",
|
||||
"name": "git_workflow_doc",
|
||||
"domain": "process_doc",
|
||||
"prompt": (
|
||||
"Write a developer process document for CircuitForge: conventional commit and "
|
||||
"branch workflow for a BSL 1.1 open-core product.\n\n"
|
||||
"Cover: commit message format (type: description), branch naming, "
|
||||
"when to use feature branches vs direct main commits, "
|
||||
"how the MIT/BSL split affects which commits go in which branch, "
|
||||
"and how CI gates on gitleaks for secret scanning."
|
||||
),
|
||||
"expected_signals": ["license_split", "cf_conventions", "task_structure"],
|
||||
},
|
||||
]
|
||||
|
||||
# ── Rubric scoring ─────────────────────────────────────────────────────────────
|
||||
|
||||
_TASK_STRUCTURE_RE = re.compile(r"- \[ \]", re.MULTILINE)
|
||||
_COMMIT_RE = re.compile(r"git commit|git add", re.IGNORECASE)
|
||||
_TIER_RE = re.compile(r"\b(Free|Paid|Premium|Ultra)\s+tier|\btier\s+(Free|Paid|Premium|Ultra)", re.IGNORECASE)
|
||||
_PRIVACY_RE = re.compile(r"\b(privacy|local.?inference|no.?logging|no.?pii|user.?data|data.?reten|obfuscat)", re.IGNORECASE)
|
||||
_SAFETY_RE = re.compile(r"\b(human.?approv|reversib|safety|safe.?default|fail.?safe|harm)", re.IGNORECASE)
|
||||
_A11Y_RE = re.compile(r"\b(neurodiverg|ND\b|accessib|adaptive|ADHD|autism|executive.?function|demand.?avoid)", re.IGNORECASE)
|
||||
_LICENSE_RE = re.compile(r"\b(MIT|BSL|open.?core|proprietary|commercial.?licens)", re.IGNORECASE)
|
||||
_FILE_PATH_RE = re.compile(r"(app/|tests?/|src/|scripts?/)\w[\w/.-]{3,}", re.IGNORECASE)
|
||||
_CF_CONV_RE = re.compile(r"(conda run -n cf|/Library/Development/CircuitForge|circuitforge-core|manage\.sh)", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass
class RubricScore:
    """Per-rubric scores for one model response; every field defaults to 0.0."""

    task_structure: float = 0.0
    tier_awareness: float = 0.0
    privacy_pillar: float = 0.0
    safety_pillar: float = 0.0
    accessibility: float = 0.0
    license_split: float = 0.0
    file_paths: float = 0.0
    cf_conventions: float = 0.0
    length_ok: float = 0.0

    def total(self) -> float:
        """Return the unweighted mean of all rubric fields."""
        # asdict() yields the fields in declaration order, so the summation
        # order matches an explicit field-by-field list.
        components = list(asdict(self).values())
        return sum(components) / len(components)

    def as_dict(self) -> dict[str, float]:
        """Return the scores as a plain ``{field_name: value}`` dict."""
        return asdict(self)
|
||||
|
||||
|
||||
def score_response(response: str, prompt_meta: dict[str, Any]) -> RubricScore:
    """Score one model response against the regex-based rubrics.

    Most rubrics count regex matches in ``response`` and saturate at 1.0 once
    a per-rubric number of hits is reached. ``task_structure`` combines
    checkbox count (70%) with the presence of a git add/commit step (30%).
    ``length_ok`` is 1.0 for 200-2500 words and falls off linearly outside
    that range.

    Args:
        response: Raw text returned by the model.
        prompt_meta: Prompt definition; accepted for interface symmetry but
            not read by the current scoring logic.

    Returns:
        A populated RubricScore.
    """

    def _saturating(pattern, hits_for_full_credit):
        """Match count scaled so `hits_for_full_credit` hits earn 1.0."""
        return min(1.0, len(pattern.findall(response)) / hits_for_full_credit)

    word_total = len(response.split())

    # Checkboxes are worth up to 0.7; a commit step adds a flat 0.3.
    commit_bonus = 0.3 if _COMMIT_RE.search(response) else 0.0
    structure = _saturating(_TASK_STRUCTURE_RE, 5) * 0.7 + commit_bonus

    # Length: 200-2500 words is healthy; outside that range earns partial credit.
    if word_total < 200:
        length_score = word_total / 200
    elif word_total <= 2500:
        length_score = 1.0
    else:
        length_score = max(0.0, 1.0 - (word_total - 2500) / 2500)

    return RubricScore(
        task_structure=structure,
        tier_awareness=_saturating(_TIER_RE, 2),
        privacy_pillar=_saturating(_PRIVACY_RE, 3),
        safety_pillar=_saturating(_SAFETY_RE, 2),
        accessibility=_saturating(_A11Y_RE, 2),
        license_split=_saturating(_LICENSE_RE, 2),
        file_paths=_saturating(_FILE_PATH_RE, 3),
        cf_conventions=_saturating(_CF_CONV_RE, 2),
        length_ok=length_score,
    )
|
||||
|
||||
|
||||
# ── Model client ───────────────────────────────────────────────────────────────
|
||||
|
||||
# Registry of named model targets (shorthand → {api_base, model_name})
|
||||
MODEL_REGISTRY: dict[str, dict[str, str]] = {
|
||||
"deepseek-r1-1.5b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-1.5b",
|
||||
"description": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-7b-4bit",
|
||||
"description": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-coder-6.7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-coder-6.7b-4bit",
|
||||
"description": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"granite-4.1-8b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "granite-4.1-8b",
|
||||
"description": "IBM Granite 4.1 8B, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"qwen2.5-3b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-3b",
|
||||
"description": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key, navi only)",
|
||||
},
|
||||
"qwen2.5-7b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-7b",
|
||||
"description": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key, navi only)",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── cf-orch allocation ─────────────────────────────────────────────────────────
|
||||
|
||||
def _cforch_allocate(
    model_id: str,
    cforch_url: str,
    startup_timeout_s: float = 300.0,
) -> tuple[str, str] | None:
    """Allocate a cf-text instance for model_id via the cf-orch coordinator.

    Posts an allocation request to the coordinator. When the response reports
    a freshly started, not-yet-warm instance, polls the coordinator's status
    endpoint every 3 seconds until the instance is running, has stopped, or
    ``startup_timeout_s`` has elapsed.

    Args:
        model_id: Model key sent as the only entry of ``model_candidates``.
        cforch_url: Base URL of the coordinator.
        startup_timeout_s: Maximum time to wait for a cold start.

    Returns (service_url, allocation_id) on success, None on failure.
    service_url is the direct node URL exposing /v1/chat/completions.
    Any exception (HTTP error, missing "url" key, ...) is reported on stderr
    and turned into a None return.
    """
    try:
        resp = httpx.post(
            f"{cforch_url}/api/services/cf-text/allocate",
            json={
                "model_candidates": [model_id],
                "caller": "avocet",
                "pipeline": "plans_benchmark",
            },
            timeout=120.0,
        )
        resp.raise_for_status()
        data = resp.json()
        service_url: str = data["url"]
        allocation_id: str = data.get("allocation_id", "")
        # node_id / gpu_id identify this allocation's instance in the status list.
        node_id: str = data.get("node_id", "")
        gpu_id: int | None = data.get("gpu_id")

        # Cold start: the coordinator started the service but it is not warm yet.
        if data.get("started", False) and not data.get("warm", True):
            # Each message is printed as its own flushed line so a consumer
            # streaming this output sees it immediately.
            print(f" [cold start] loading {model_id!r} — polling every 3s…", flush=True)
            t0 = time.monotonic()
            deadline = t0 + startup_timeout_s
            # Consecutive polls in which our instance was absent from the status list.
            probe_misses = 0

            while time.monotonic() < deadline:
                elapsed = time.monotonic() - t0
                try:
                    status = httpx.get(f"{cforch_url}/api/services/cf-text/status", timeout=5.0)
                    if status.is_success:
                        instances = status.json().get("instances", [])
                        match = next(
                            (i for i in instances
                             if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
                            None,
                        )
                        if match:
                            probe_misses = 0
                            state = match.get("state", "")
                            if state == "running":
                                print(f" [cold start] ready in {elapsed:.0f}s", flush=True)
                                return service_url, allocation_id
                            elif state == "stopped":
                                print(f" [cold start] failed — service stopped after {elapsed:.0f}s", flush=True)
                                return None
                            else:
                                # Still starting: emit a progress line so a
                                # streaming consumer keeps receiving output.
                                print(f" [cold start] state={state!r} elapsed={elapsed:.0f}s", flush=True)
                        else:
                            probe_misses += 1
                            print(f" [cold start] waiting… elapsed={elapsed:.0f}s", flush=True)
                            # After six consecutive misses, fall back to probing
                            # the service's own /health endpoint directly.
                            if probe_misses >= 6:
                                try:
                                    h = httpx.get(f"{service_url}/health", timeout=3.0)
                                    if h.is_success:
                                        print(f" [cold start] ready via health check in {elapsed:.0f}s", flush=True)
                                        return service_url, allocation_id
                                except Exception:
                                    # Health probe is best effort; keep polling.
                                    pass
                    else:
                        print(f" [cold start] status poll returned {status.status_code}, elapsed={elapsed:.0f}s", flush=True)
                except Exception as poll_exc:
                    # A failed poll is not fatal; retry until the deadline.
                    print(f" [cold start] poll error: {poll_exc} elapsed={elapsed:.0f}s", flush=True)
                time.sleep(3.0)

            print(f" [cold start] timed out after {time.monotonic()-t0:.0f}s", flush=True)
            return None

        # Warm (or already running) instance: usable immediately.
        return service_url, allocation_id
    except Exception as exc:
        print(f"[warn] cf-orch allocation failed for {model_id!r}: {exc}", file=sys.stderr)
        return None
|
||||
|
||||
|
||||
def _call_model_direct(service_url: str, model: str, prompt: str, timeout: int = 600) -> tuple[str, float]:
    """Call an OpenAI-compatible /v1/chat/completions on a direct service URL.

    Delegates to ``_call_model`` with the ``/v1`` API base derived from
    ``service_url``, so both call paths share one request and parsing
    implementation.

    Args:
        service_url: Root URL of the service; a trailing slash is tolerated.
        model: Model name sent in the request body.
        prompt: Sent as a single user message.
        timeout: Request timeout in seconds.

    Returns:
        (text, latency_s) exactly as returned by ``_call_model``.

    Raises:
        Whatever ``_call_model`` raises (HTTP errors, malformed responses).
    """
    api_base = f"{service_url.rstrip('/')}/v1"
    return _call_model(api_base, model, prompt, timeout=timeout)
|
||||
|
||||
|
||||
def _call_model(api_base: str, model: str, prompt: str, timeout: int = 180) -> tuple[str, float]:
    """Call an OpenAI-compatible /chat/completions endpoint. Returns (text, latency_s).

    Args:
        api_base: API base URL (e.g. ending in ``/v1``). A trailing slash is
            stripped so the request URL never contains ``//chat/completions``.
        model: Model name sent in the request body.
        prompt: Sent as a single user message.
        timeout: Request timeout in seconds.

    Returns:
        The first choice's message content and the wall-clock seconds between
        sending the request and receiving a successful response.

    Raises:
        httpx.HTTPStatusError: On a non-2xx response (via ``raise_for_status``).
        KeyError / IndexError: If the response body lacks the expected
            ``choices[0].message.content`` structure.
    """
    t0 = time.monotonic()
    resp = httpx.post(
        f"{api_base.rstrip('/')}/chat/completions",
        json={
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 2048,
            "temperature": 0.2,
        },
        timeout=timeout,
    )
    resp.raise_for_status()
    # Latency is taken before JSON parsing so it reflects the request only.
    latency = time.monotonic() - t0
    text = resp.json()["choices"][0]["message"]["content"]
    return text, latency
|
||||
|
||||
|
||||
# ── Benchmark runner ───────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class PromptResult:
    """Outcome of sending one benchmark prompt to one model."""

    prompt_id: str  # the prompt's "id" entry
    prompt_name: str  # the prompt's "name" entry
    model_key: str  # key of the model this result belongs to
    response: str  # raw model output; "" when allocation failed
    latency_s: float  # request latency in seconds, as recorded by the runner
    word_count: int  # whitespace-separated word count of the response
    scores: dict[str, float]  # per-rubric scores; {} when allocation failed
    total_score: float  # aggregate rubric score
    error: str | None = None  # failure description, or None on success
|
||||
|
||||
|
||||
def run_benchmark(
    model_key: str,
    model_name: str,
    prompts: list[dict[str, Any]] | None = None,
    verbose: bool = False,
    # cf-orch path
    use_cforch: bool = False,
    cforch_url: str = CF_COORD_URL,
    # direct path (used when not cf-orch)
    api_base: str = CF_TEXT_BASE,
) -> list[PromptResult]:
    """Run every prompt through one model and score each response.

    Args:
        model_key: Label stored on each result (registry key or raw name).
        model_name: Model name sent to the inference endpoint.
        prompts: Prompt dicts with ``id``, ``name`` and ``prompt`` keys;
            defaults to ``HELD_OUT_PROMPTS``.
        verbose: Print one progress line per prompt.
        use_cforch: Allocate a service through cf-orch once, then call it
            directly for every prompt.
        cforch_url: Coordinator URL used when ``use_cforch`` is true.
        api_base: Endpoint base used when ``use_cforch`` is false.

    Returns:
        One ``PromptResult`` per prompt, in prompt order. A failed prompt
        (or a failed allocation) yields a result with ``error`` set and
        zeroed metrics rather than raising.
    """

    def _failure(prompt: dict[str, Any], message: str) -> PromptResult:
        # Zeroed-out record used for every kind of per-prompt failure.
        return PromptResult(
            prompt_id=prompt["id"],
            prompt_name=prompt["name"],
            model_key=model_key,
            response="",
            latency_s=0.0,
            word_count=0,
            scores={},
            total_score=0.0,
            error=message,
        )

    batch = HELD_OUT_PROMPTS if prompts is None else prompts

    # Allocate once per model when using cf-orch
    service_url: str | None = None
    if use_cforch:
        print(f" Allocating {model_name!r} via cf-orch…", flush=True)
        allocation = _cforch_allocate(model_name, cforch_url)
        if allocation is None:
            # No service to talk to: report every prompt as failed.
            return [
                _failure(prompt, f"cf-orch allocation failed for {model_name!r}")
                for prompt in batch
            ]
        # NOTE(review): the allocation id is discarded and this function
        # never releases the allocation — confirm that is handled elsewhere.
        service_url, _ = allocation

    outcomes: list[PromptResult] = []
    for prompt in batch:
        if verbose:
            print(f" [{prompt['id']}] {prompt['name']} … ", end="", flush=True)
        try:
            if service_url:
                text, seconds = _call_model_direct(service_url, model_name, prompt["prompt"])
            else:
                text, seconds = _call_model(api_base, model_name, prompt["prompt"])
            rubric = score_response(text, prompt)
            outcome = PromptResult(
                prompt_id=prompt["id"],
                prompt_name=prompt["name"],
                model_key=model_key,
                response=text,
                latency_s=round(seconds, 2),
                word_count=len(text.split()),
                scores=rubric.as_dict(),
                total_score=round(rubric.total(), 3),
            )
            if verbose:
                print(f"score={outcome.total_score:.3f} ({outcome.word_count}w, {seconds:.1f}s)")
        except Exception as exc:
            # Best-effort: one bad prompt must not abort the whole run.
            outcome = _failure(prompt, str(exc))
            if verbose:
                print(f"ERROR: {exc}")
        outcomes.append(outcome)
    return outcomes
|
||||
|
||||
|
||||
# ── Reporting ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_single_report(results: list[PromptResult], model_key: str) -> None:
|
||||
ok = [r for r in results if not r.error]
|
||||
err = [r for r in results if r.error]
|
||||
if not ok:
|
||||
print(f"\n[{model_key}] All {len(err)} prompts failed.\n")
|
||||
return
|
||||
|
||||
avg_total = sum(r.total_score for r in ok) / len(ok)
|
||||
avg_latency = sum(r.latency_s for r in ok) / len(ok)
|
||||
|
||||
# Aggregate per-rubric averages
|
||||
rubric_keys = list(ok[0].scores.keys())
|
||||
rubric_avgs = {k: sum(r.scores.get(k, 0) for r in ok) / len(ok) for k in rubric_keys}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Model : {model_key}")
|
||||
print(f" Prompts: {len(ok)}/{len(results)} passed ({len(err)} errors)")
|
||||
print(f" Overall score : {avg_total:.3f} (avg latency {avg_latency:.1f}s)")
|
||||
print(f"\n Rubric breakdown:")
|
||||
for k, v in sorted(rubric_avgs.items(), key=lambda x: -x[1]):
|
||||
bar = "█" * int(v * 20)
|
||||
print(f" {k:<22} {v:.3f} {bar}")
|
||||
print(f"\n Per-prompt scores:")
|
||||
for r in sorted(ok, key=lambda x: -x.total_score):
|
||||
flag = "⚠" if r.total_score < 0.3 else " "
|
||||
print(f" {flag} {r.prompt_id} {r.prompt_name:<35} {r.total_score:.3f} ({r.word_count}w)")
|
||||
if err:
|
||||
print(f"\n Errors:")
|
||||
for r in err:
|
||||
print(f" {r.prompt_id} {r.prompt_name}: {r.error}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def _print_comparison_table(all_results: dict[str, list[PromptResult]]) -> None:
|
||||
model_keys = list(all_results.keys())
|
||||
prompt_ids = [p["id"] for p in HELD_OUT_PROMPTS]
|
||||
|
||||
# Scores by (model, prompt_id)
|
||||
score_map: dict[tuple[str, str], float] = {}
|
||||
for mk, results in all_results.items():
|
||||
for r in results:
|
||||
score_map[(mk, r.prompt_id)] = r.total_score if not r.error else 0.0
|
||||
|
||||
col_w = 10
|
||||
header = f"{'Prompt':<35}" + "".join(f"{mk[:col_w-1]:<{col_w}}" for mk in model_keys)
|
||||
print(f"\n{'='*len(header)}")
|
||||
print(" COMPARISON TABLE")
|
||||
print(f"{'='*len(header)}")
|
||||
print(f" {header}")
|
||||
print(f" {'-'*len(header)}")
|
||||
|
||||
for pid in prompt_ids:
|
||||
pname = next(p["name"] for p in HELD_OUT_PROMPTS if p["id"] == pid)
|
||||
row = f" {pname:<35}"
|
||||
best = max(score_map.get((mk, pid), 0.0) for mk in model_keys)
|
||||
for mk in model_keys:
|
||||
v = score_map.get((mk, pid), 0.0)
|
||||
marker = "*" if v == best and len(model_keys) > 1 else " "
|
||||
row += f"{v:.3f}{marker} "
|
||||
print(row)
|
||||
|
||||
print(f" {'-'*len(header)}")
|
||||
avgs_row = f" {'AVERAGE':<35}"
|
||||
best_avg = -1.0
|
||||
avgs: dict[str, float] = {}
|
||||
for mk in model_keys:
|
||||
vals = [score_map.get((mk, pid), 0.0) for pid in prompt_ids]
|
||||
avgs[mk] = sum(vals) / len(vals)
|
||||
best_avg = max(best_avg, avgs[mk])
|
||||
for mk in model_keys:
|
||||
marker = "*" if avgs[mk] == best_avg and len(model_keys) > 1 else " "
|
||||
avgs_row += f"{avgs[mk]:.3f}{marker} "
|
||||
print(avgs_row)
|
||||
print(f"{'='*len(header)}\n")
|
||||
if len(model_keys) > 1:
|
||||
winner = max(avgs, key=lambda k: avgs[k])
|
||||
print(f" Winner: {winner} (avg {avgs[winner]:.3f})\n")
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, run the benchmark(s), report results.

    Modes:
        ``--list-models``  print the registry and default endpoints, then return.
        ``--model KEY``    benchmark one model and print its report.
        ``--compare K...`` benchmark several models (optionally in parallel via
                           ``--workers``) and print a comparison table.

    With ``--output`` the per-prompt results of every model are written as
    JSON. Exits with status 1 when ``--prompts`` matches nothing, and via
    ``parser.error`` when ``--model-name`` is combined with more than one model.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--list-models", action="store_true",
                        help="Print registered model shortcuts and exit")
    parser.add_argument("--model", metavar="KEY",
                        help="Benchmark a single model (registry key or raw model name)")
    parser.add_argument("--compare", nargs="+", metavar="KEY",
                        help="Compare two or more models side-by-side")
    parser.add_argument("--cforch", action="store_true",
                        help="Route inference through cf-orch coordinator (allocate per model)")
    parser.add_argument("--cforch-url", default=CF_COORD_URL, metavar="URL",
                        help=f"cf-orch coordinator URL (default: {CF_COORD_URL})")
    parser.add_argument("--api-base", default=None,
                        help="Direct API base URL when not using cf-orch")
    parser.add_argument("--model-name", default=None,
                        help="Override model name sent to API (single-model runs only)")
    parser.add_argument("--prompts", nargs="+", metavar="ID",
                        help="Run only specific prompt IDs (e.g. ho_001 ho_003)")
    parser.add_argument("--output", type=Path, default=None,
                        help="Write detailed JSON results to this path")
    parser.add_argument("--workers", type=int, default=1, metavar="N",
                        help="Run N models concurrently (default 1). Set to number of available nodes.")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Print per-prompt progress")
    args = parser.parse_args()

    if args.list_models:
        print("\nRegistered model shortcuts:")
        for key, info in MODEL_REGISTRY.items():
            print(f" {key:<20} {info['description']}")
        print("\nDefault endpoints:")
        print(f" direct {CF_TEXT_BASE}")
        print(f" cf-orch {CF_COORD_URL}")
        return

    prompts = HELD_OUT_PROMPTS
    if args.prompts:
        ids = set(args.prompts)
        prompts = [p for p in HELD_OUT_PROMPTS if p["id"] in ids]
        if not prompts:
            print(f"No prompts matched IDs: {args.prompts}", file=sys.stderr)
            sys.exit(1)

    model_keys: list[str] = []
    if args.compare:
        # Drop repeated keys (order preserved): results are stored per key,
        # so a duplicate would be benchmarked twice for a single column.
        model_keys = list(dict.fromkeys(args.compare))
    elif args.model:
        model_keys = [args.model]
    else:
        parser.print_help()
        sys.exit(0)

    if args.model_name and len(model_keys) > 1:
        # The override would otherwise be sent for every key, so each
        # "model" in the comparison would be the same underlying model.
        parser.error("--model-name applies to single-model runs only; "
                     "it cannot be combined with several --compare keys")

    all_results: dict[str, list[PromptResult]] = {}
    # Serialises banner/report printing when models run concurrently.
    print_lock = threading.Lock()

    def _run_one(mk: str) -> tuple[str, list[PromptResult]]:
        # Resolve the model name and direct endpoint: registry entry first,
        # otherwise treat the key as a raw model name on the default endpoint.
        if mk in MODEL_REGISTRY:
            reg = MODEL_REGISTRY[mk]
            model_name = args.model_name or reg["model"]
            direct_base = args.api_base or reg["api_base"]
        else:
            model_name = args.model_name or mk
            direct_base = args.api_base or CF_TEXT_BASE

        if args.cforch:
            with print_lock:
                print(f"\nRunning [{mk}] via cf-orch ({args.cforch_url}) model={model_name}")
            results = run_benchmark(
                mk, model_name, prompts=prompts, verbose=args.verbose,
                use_cforch=True, cforch_url=args.cforch_url,
            )
        else:
            with print_lock:
                print(f"\nRunning [{mk}] → {direct_base} model={model_name}")
            results = run_benchmark(
                mk, model_name, prompts=prompts, verbose=args.verbose,
                api_base=direct_base,
            )

        with print_lock:
            _print_single_report(results, mk)
        return mk, results

    workers = max(1, args.workers)
    if workers == 1 or len(model_keys) == 1:
        for mk in model_keys:
            mk_out, results = _run_one(mk)
            all_results[mk_out] = results
    else:
        finished: dict[str, list[PromptResult]] = {}
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {pool.submit(_run_one, mk): mk for mk in model_keys}
            for fut in as_completed(futures):
                mk_out, results = fut.result()
                finished[mk_out] = results
        # as_completed yields in completion order; restore the order given
        # on the command line so the table and JSON are reproducible.
        all_results = {mk: finished[mk] for mk in model_keys}

    if len(model_keys) > 1:
        _print_comparison_table(all_results)

    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            mk: [asdict(r) for r in results]
            for mk, results in all_results.items()
        }
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
        print(f"Wrote detailed results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -7,19 +7,26 @@ from __future__ import annotations
|
|||
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
import httpx
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"LABELS",
|
||||
"LABEL_DESCRIPTIONS",
|
||||
"DEFAULT_EXEMPLARS",
|
||||
"compute_metrics",
|
||||
"ClassifierAdapter",
|
||||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
"EmbeddingKNNAdapter",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
LABELS: list[str] = [
|
||||
"interview_scheduled",
|
||||
"offer_received",
|
||||
|
|
@ -117,6 +124,81 @@ def compute_metrics(
|
|||
return result
|
||||
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
|
||||
|
||||
|
||||
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
|
||||
"interview_scheduled": [
|
||||
"Subject: Interview Invitation\n\nWe would like to invite you for a phone screen next week.",
|
||||
"Subject: Schedule a call\n\nCould you be available for a video interview on Tuesday?",
|
||||
"Subject: Next Steps\n\nWe'd like to move forward with a technical interview. Please select a time.",
|
||||
"Subject: Interview Details\n\nHere are the dial-in instructions for your interview tomorrow.",
|
||||
],
|
||||
"offer_received": [
|
||||
"Subject: Offer Letter Enclosed\n\nWe are pleased to extend you an offer of employment.",
|
||||
"Subject: Job Offer\n\nDear candidate, we are excited to offer you the position of Software Engineer.",
|
||||
"Subject: Employment Offer\n\nPlease find attached your formal offer letter and compensation details.",
|
||||
"Subject: Offer of Employment\n\nCongratulations! We would like to offer you a full-time position.",
|
||||
],
|
||||
"rejected": [
|
||||
"Subject: Your Application\n\nAfter careful consideration, we have decided to move forward with other candidates.",
|
||||
"Subject: Application Status\n\nWe regret to inform you that your application has not been selected.",
|
||||
"Subject: Thank you for applying\n\nWe appreciate your interest but have chosen not to proceed.",
|
||||
"Subject: Update on your candidacy\n\nWe will not be moving forward with your application at this time.",
|
||||
],
|
||||
"positive_response": [
|
||||
"Subject: Your profile\n\nI came across your LinkedIn and think you would be a great fit for our team.",
|
||||
"Subject: Exciting opportunity\n\nWe were impressed by your background and would love to connect.",
|
||||
"Subject: Following up\n\nThank you for your interest — we'd like to learn more about your experience.",
|
||||
"Subject: Great fit\n\nYour skills align well with what we are looking for. Let's set up a call.",
|
||||
],
|
||||
"survey_received": [
|
||||
"Subject: Candidate Experience Survey\n\nPlease complete this brief survey about your application experience.",
|
||||
"Subject: Culture Fit Assessment\n\nAs part of our process, we ask all candidates to complete a short assessment.",
|
||||
"Subject: Skills Assessment\n\nWe'd like you to complete our online coding assessment before proceeding.",
|
||||
"Subject: Personality Assessment\n\nPlease complete the following assessment as the next step in our process.",
|
||||
"Subject: Pre-interview questionnaire\n\nBefore we schedule your interview, please complete this brief skills survey.",
|
||||
],
|
||||
"neutral": [
|
||||
"Subject: Application Received\n\nWe have received your application and will be in touch.",
|
||||
"Subject: Thank you for applying\n\nYour application is under review. We will contact you if needed.",
|
||||
"Subject: Confirmation\n\nThis email confirms receipt of your application to our company.",
|
||||
"Subject: Application Confirmation\n\nThank you for your interest. We will review your materials and follow up.",
|
||||
],
|
||||
"event_rescheduled": [
|
||||
"Subject: Interview Rescheduled\n\nDue to a conflict, we need to move your interview to a new time.",
|
||||
"Subject: Change of interview time\n\nWe apologize — your interview has been rescheduled to Thursday.",
|
||||
"Subject: Updated interview details\n\nYour interview has been moved from Monday to Wednesday at 2pm.",
|
||||
"Subject: Reschedule request\n\nWould you be available to reschedule to a different time slot?",
|
||||
"Subject: New interview time\n\nYour phone screen has been moved from tomorrow to next week.",
|
||||
],
|
||||
"digest": [
|
||||
"Subject: 15 new jobs matching your search\n\nHere are the latest job postings that match your profile.",
|
||||
"Subject: Weekly Job Digest\n\nThis week's top opportunities for Software Engineers in your area.",
|
||||
"Subject: Jobs you might like\n\nBased on your profile, here are some positions we recommend.",
|
||||
"Subject: New jobs for you\n\nSee the latest openings from companies on your watchlist.",
|
||||
],
|
||||
"new_lead": [
|
||||
"Subject: Exciting opportunity at our company\n\nHi, I noticed your background and think you'd be a great fit.",
|
||||
"Subject: Are you open to new opportunities?\n\nI'm a recruiter reaching out about a role matching your experience.",
|
||||
"Subject: Quick question\n\nWould you be interested in hearing about a senior engineering role?",
|
||||
"Subject: Recruiting outreach\n\nI came across your profile and wanted to share an exciting opening.",
|
||||
],
|
||||
"hired": [
|
||||
"Subject: Welcome to the team!\n\nWe are thrilled to have you join us. Here are your onboarding details.",
|
||||
"Subject: Onboarding information\n\nCongratulations on accepting our offer. Your start date is confirmed.",
|
||||
"Subject: First day information\n\nWe look forward to your first day. Please arrive at 9am and ask for HR.",
|
||||
"Subject: Background check initiated\n\nAs part of your onboarding, we have initiated a background check.",
|
||||
"Subject: Equipment setup\n\nYour laptop and equipment will be ready for pickup on your first day.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ClassifierAdapter(abc.ABC):
|
||||
"""Abstract base for all email classifier adapters."""
|
||||
|
||||
|
|
@ -304,3 +386,148 @@ class FineTunedAdapter(ClassifierAdapter):
|
|||
text = f"{subject} [SEP] {body[:400]}"
|
||||
result = self._pipeline(text)
|
||||
return result[0]["label"]
|
||||
|
||||
|
||||
class EmbeddingKNNAdapter(ClassifierAdapter):
    """k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.

    load():
        1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
           Falls back to ollama_url directly if orch allocation fails or is not configured.
        2. Pre-embeds all exemplar texts and stores per-label vector lists.

    classify(subject, body):
        Embeds the input email, computes cosine similarity against all stored exemplar
        vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
        with the highest total similarity score among tied vote counts wins.

    unload():
        Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
    """

    def __init__(
        self,
        name: str,
        model_id: str,
        *,
        k: int = 3,
        orch_url: str = "",
        ollama_url: str = "",
        exemplar_texts: dict[str, list[str]] | None = None,
    ) -> None:
        """Configure the adapter; no network access happens until load().

        Args:
            name: Display name returned by the ``name`` property.
            model_id: Embedding model name sent to the allocate and
                embeddings endpoints.
            k: Number of nearest exemplars that vote; must be >= 1.
            orch_url: cf-orch coordinator base URL ("" to use config/fallback).
            ollama_url: Direct Ollama base URL used when no orch allocation
                is obtained.
            exemplar_texts: Label -> example emails; defaults to
                ``DEFAULT_EXEMPLARS``.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            # k == 0 would leave classify() with no votes (max() on an empty
            # dict); a negative k would slice from the wrong end.
            raise ValueError(f"k must be >= 1, got {k!r}")
        self._name = name
        self._model_id = model_id
        self._k = k
        self._orch_url = orch_url
        self._ollama_url = ollama_url
        self._exemplar_texts: dict[str, list[str]] = (
            exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
        )
        # Populated by load(), cleared by unload().
        self._exemplar_embeddings: dict[str, list[list[float]]] = {}
        self._node_url: str = ""
        self._allocation_id: str = ""
        self._orch_url_used: str = ""

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def _resolve_urls(self) -> tuple[str, str]:
        """Return ``(orch_url, ollama_url)``.

        Explicit constructor values win when either is set; otherwise both
        are read from the ``cforch`` section of ``config/label_tool.yaml``
        (missing file or keys give empty strings).
        """
        if self._orch_url or self._ollama_url:
            return self._orch_url, self._ollama_url
        import yaml  # noqa: PLC0415
        cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
        cfg: dict = {}
        if cfg_path.exists():
            try:
                cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
            except yaml.YAMLError as exc:
                # Keep going with empty config, but say why it was ignored.
                _logger.warning("ignoring unparseable config %s: %s", cfg_path, exc)
        cforch = cfg.get("cforch", {}) or {}
        return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")

    def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
        """Embed *texts* in one request; returns one vector per ``data`` item.

        NOTE(review): vectors are taken in response order — presumably that
        matches input order; confirm against the server.
        """
        resp = httpx.post(
            f"{node_url}/v1/embeddings",
            json={"model": self._model_id, "input": texts},
            timeout=30.0,
        )
        resp.raise_for_status()
        return [item["embedding"] for item in resp.json()["data"]]

    def load(self) -> None:
        """Acquire an embedding endpoint and pre-embed every exemplar.

        Raises:
            RuntimeError: If the adapter is already loaded.
            Exception: Any error from embedding the exemplars is re-raised
                after the adapter has been unloaded again.
        """
        if self._allocation_id or self._exemplar_embeddings:
            raise RuntimeError(
                "EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
            )
        orch_url, ollama_url = self._resolve_urls()
        node_url = ""
        allocation_id = ""
        if orch_url:
            try:
                resp = httpx.post(
                    f"{orch_url}/api/services/ollama/allocate",
                    json={"model": self._model_id},
                    timeout=15.0,
                )
                if resp.status_code == 200:
                    data = resp.json()
                    # Read both fields before committing either: a reply
                    # missing one key must not leave us pointed at an orch
                    # node whose allocation we cannot release.
                    node_url, allocation_id = data["url"], data["allocation_id"]
                else:
                    _logger.warning(
                        "cf-orch allocation returned HTTP %s, falling back to direct ollama_url",
                        resp.status_code,
                    )
            except Exception as exc:
                _logger.warning(
                    "cf-orch allocation failed, falling back to direct ollama_url: %s", exc
                )
        if node_url:
            self._allocation_id = allocation_id
            self._orch_url_used = orch_url
        else:
            node_url = ollama_url
            self._allocation_id = ""
            self._orch_url_used = ""
        self._node_url = node_url
        try:
            embeddings: dict[str, list[list[float]]] = {}
            for label, texts in self._exemplar_texts.items():
                embeddings[label] = self._embed(node_url, texts)
            # Publish only once every label embedded successfully.
            self._exemplar_embeddings = embeddings
        except Exception:
            # Give back the allocation instead of leaking it on failure.
            self.unload()
            raise

    def unload(self) -> None:
        """Release any cf-orch allocation (best effort) and reset all state."""
        if self._allocation_id and self._orch_url_used:
            try:
                httpx.request(
                    "DELETE",
                    f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
                    timeout=10.0,
                )
            except Exception as exc:
                # Still best effort, but leave a trace of the leaked allocation.
                _logger.warning(
                    "cf-orch release of allocation %s failed: %s", self._allocation_id, exc
                )
        self._exemplar_embeddings = {}
        self._node_url = ""
        self._allocation_id = ""
        self._orch_url_used = ""

    def classify(self, subject: str, body: str) -> str:
        """Return the label winning the top-k vote for this email.

        Loads lazily on first use. Only the first 600 characters of *body*
        are embedded.
        """
        if not self._exemplar_embeddings:
            self.load()
        text = f"Subject: {subject}\n\n{body[:600]}"
        [query_vec] = self._embed(self._node_url, [text])
        scored: list[tuple[float, str]] = [
            (_cosine(query_vec, vec), label)
            for label, vecs in self._exemplar_embeddings.items()
            for vec in vecs
        ]
        top_k = sorted(scored, reverse=True)[: self._k]
        votes: dict[str, list[float]] = {}
        for score, label in top_k:
            votes.setdefault(label, []).append(score)
        # Most votes wins; equal vote counts are broken by summed similarity.
        return max(
            votes,
            key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
        )
|
||||
|
|
|
|||
458
scripts/export_plans.py
Normal file
458
scripts/export_plans.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""Export circuitforge-plans/ documents as instruction-tuning JSONL pairs.
|
||||
|
||||
Each record is a HuggingFace chat-format example:
|
||||
|
||||
{
|
||||
"id": "<sha256>",
|
||||
"messages": [
|
||||
{"role": "user", "content": "<reconstructed planning prompt>"},
|
||||
{"role": "assistant", "content": "<cleaned document content>"}
|
||||
],
|
||||
"meta": {
|
||||
"source": "peregrine/2026-03-03-feedback-button-design.md",
|
||||
"product": "peregrine",
|
||||
"doc_type": "design", # design | plan | spec | implementation | other
|
||||
"date": "2026-03-03",
|
||||
"paired_with": "...", # sibling path, or null
|
||||
"word_count": 1847,
|
||||
"pair_role": "context" # "context" | "target" | "standalone"
|
||||
}
|
||||
}
|
||||
|
||||
Pairing strategy
|
||||
----------------
|
||||
When a design doc and a plan doc share the same date + feature-name prefix,
|
||||
they are treated as a pair:
|
||||
- design → plan: instruction = "Given this design doc, write the implementation plan."
|
||||
context appended = full design doc content.
|
||||
- Solo docs get a synthetic instruction from the title + first overview section.
|
||||
|
||||
Usage
|
||||
-----
|
||||
# Preview stats and 5 sample records
|
||||
python scripts/export_plans.py --preview
|
||||
|
||||
# Write full output
|
||||
python scripts/export_plans.py --output data/plan_pairs.jsonl
|
||||
|
||||
# Restrict to specific products
|
||||
python scripts/export_plans.py --products peregrine,kiwi --output data/plan_pairs.jsonl
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_SCRIPT_DIR = Path(__file__).parent
|
||||
_AVOCET_ROOT = _SCRIPT_DIR.parent
|
||||
_DEFAULT_PLANS_DIR = Path("/Library/Development/CircuitForge/circuitforge-plans")
|
||||
_DEFAULT_OUTPUT = _AVOCET_ROOT / "data" / "plan_pairs.jsonl"
|
||||
|
||||
# ── Doc type detection ─────────────────────────────────────────────────────────
|
||||
|
||||
_TYPE_RE = re.compile(
|
||||
r"-(design|plan|spec|implementation|specs|plans)s?$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_SKIP_DIRS = {"__pycache__", ".git", "node_modules"}
|
||||
|
||||
# Boilerplate lines to strip from document content before using as output.
|
||||
_BOILERPLATE_RE = re.compile(
|
||||
r"""
|
||||
^\s*>\s*\*\*For\s+agentic\s+workers.* # superpowers agent hints
|
||||
|^\s*>\s*REQUIRED\s+SUB-SKILL.*
|
||||
|^\s*\*\*Date:\*\*.* # metadata header lines
|
||||
|\*\*Status:\*\*\s*Complete.* # completed-feature noise
|
||||
|\*\*Status:\*\*\s*Done.*
|
||||
|\*\*Product:\*\*.*
|
||||
|\*\*Repo:\*\*.*
|
||||
|\*\*Tech\s+Stack:\*\*.*
|
||||
|\*\*Candidate:\*\*.* # old synthetic personas
|
||||
|^Candidate:.*
|
||||
|^Team:.*
|
||||
""",
|
||||
re.VERBOSE | re.MULTILINE,
|
||||
)
|
||||
|
||||
# Old repo/path names to normalise to current equivalents.
|
||||
_PATH_NORMALIZATIONS: list[tuple[re.Pattern, str]] = [
|
||||
(re.compile(r"/devl/job-seeker", re.IGNORECASE), "/Library/Development/CircuitForge/peregrine"),
|
||||
(re.compile(r"\bjob-seeker\b", re.IGNORECASE), "peregrine"),
|
||||
(re.compile(r"Alex Rivera", re.IGNORECASE), "[user]"),
|
||||
]
|
||||
|
||||
# Instruction paraphrase templates per doc type.
|
||||
# Each entry is (user_prefix, paired_prefix).
|
||||
# {title}, {product}, {type_phrase}, {overview}, {design_context} are substituted.
|
||||
_DESIGN_INSTRUCTIONS = [
|
||||
"Write a design document for {product}: {title}.\n\nContext: {overview}",
|
||||
"You are a software architect working on {product}. Draft a design spec for: {title}.\n\n{overview}",
|
||||
"Produce a CircuitForge-style design document for the following {product} feature — {title}.\n\nBackground: {overview}",
|
||||
]
|
||||
|
||||
_PLAN_INSTRUCTIONS = [
|
||||
"Write an implementation plan for {product}: {title}.\n\nContext: {overview}",
|
||||
"Break the following {product} feature into a detailed implementation plan with file structure and task checkboxes — {title}.\n\n{overview}",
|
||||
"You are a senior engineer on {product}. Produce a step-by-step engineering plan for: {title}.\n\n{overview}",
|
||||
]
|
||||
|
||||
_PAIRED_INSTRUCTIONS = [
|
||||
(
|
||||
"You are a software architect working on {product}, a CircuitForge product. "
|
||||
"Given the following design document, write a detailed implementation plan "
|
||||
"(file structure, task breakdown with checkboxes, migration steps if needed).\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"The following is a design spec for a {product} feature. "
|
||||
"Produce a concrete implementation plan: file list, task checklist, any DB migrations needed.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"Convert this {product} design document into an actionable implementation plan. "
|
||||
"Include all files to create/modify, step-by-step tasks with checkboxes, and migration steps.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _doc_type(stem: str) -> str:
    """Classify a file stem by its trailing type suffix.

    Returns the suffix matched by ``_TYPE_RE``, lower-cased and with any
    trailing "s" removed; "implementation" is reported as "plan". Stems
    without a recognised suffix yield "other".
    """
    suffix = _TYPE_RE.search(stem)
    if suffix is None:
        return "other"
    kind = suffix.group(1).lower().rstrip("s")
    if kind == "implementation":
        # Implementation docs are treated as plans.
        return "plan"
    return kind
|
||||
|
||||
|
||||
def _date_feature(stem: str) -> tuple[str, str]:
|
||||
"""Return (date, feature_slug) from '2026-03-03-feedback-button-design'."""
|
||||
m = re.match(r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$", stem, re.I)
|
||||
if m:
|
||||
return m.group(1), m.group(2)
|
||||
return "", stem
|
||||
|
||||
|
||||
# ── Content extraction ─────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_title(content: str) -> str:
|
||||
m = re.search(r"^#\s+(.+)", content, re.MULTILINE)
|
||||
return m.group(1).strip() if m else ""
|
||||
|
||||
|
||||
def _extract_overview(content: str) -> str:
|
||||
"""Return first substantive paragraph or h2 section body (≤300 chars)."""
|
||||
# Superpowers plans have an explicit **Goal:** line — prefer that.
|
||||
goal_m = re.search(r"\*\*Goal:\*\*\s*(.+)", content)
|
||||
if goal_m:
|
||||
return goal_m.group(1).strip()[:300]
|
||||
|
||||
# Otherwise use the body of the first h2 section.
|
||||
h2_m = re.search(
|
||||
r"^##\s+\d*\.?\s*.+\n([\s\S]+?)(?=^##|\Z)",
|
||||
content,
|
||||
re.MULTILINE,
|
||||
)
|
||||
if h2_m:
|
||||
body = h2_m.group(1).strip()
|
||||
# Strip markdown bullet/code noise for the instruction
|
||||
body = re.sub(r"```[\s\S]*?```", "", body)
|
||||
body = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), body)
|
||||
body = re.sub(r"\*\*([^*]+)\*\*", r"\1", body)
|
||||
body = re.sub(r"\s+", " ", body).strip()
|
||||
return body[:300]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _clean_content(content: str) -> str:
    """Return *content* with boilerplate removed and legacy names normalised.

    Steps, in order: delete everything ``_BOILERPLATE_RE`` matches, apply each
    ``_PATH_NORMALIZATIONS`` substitution, squeeze runs of three or more
    newlines down to one blank line, and trim surrounding whitespace.
    """
    text = _BOILERPLATE_RE.sub("", content)
    for matcher, substitute in _PATH_NORMALIZATIONS:
        text = matcher.sub(substitute, text)
    # Removed boilerplate lines leave gaps; keep at most one blank line.
    return re.sub(r"\n{3,}", "\n\n", text).strip()
|
||||
|
||||
|
||||
def _quality_flags(content: str) -> list[str]:
|
||||
"""Return a list of quality issue labels found in cleaned content."""
|
||||
flags = []
|
||||
if "Alex Rivera" in content or "[user]" in content:
|
||||
flags.append("persona-residue")
|
||||
if re.search(r"\bStatus:\s*(Complete|Done|Merged)\b", content):
|
||||
flags.append("completed-status")
|
||||
return flags
|
||||
|
||||
|
||||
def _make_instruction(
    title: str,
    product: str,
    doc_type: str,
    overview: str,
    design_context: str | None = None,
    variant: int = 0,
) -> str:
    """Synthesise a natural planning prompt for this document.

    Args:
        title: Document title substituted into solo templates.
        product: Product slug; hyphens become spaces and the result is
            title-cased. An empty value is rendered as "CircuitForge".
        doc_type: "plan" selects the plan templates; anything else selects
            the design templates. Ignored when ``design_context`` is given.
        overview: Short context text for solo templates (``None``/"" allowed).
        design_context: When non-empty, the paired design→plan templates are
            used and the first 2500 characters of this text are embedded.
        variant: Selects the paraphrase template; wraps around the number of
            templates available, so callers can cycle ``0, 1, 2, ...`` to
            produce multiple training examples per document.

    Returns:
        The formatted instruction string.
    """
    product_label = product.replace("-", " ").title() if product else "CircuitForge"

    if design_context:
        templates = _PAIRED_INSTRUCTIONS
        # Index by the list's own length rather than a hard-coded 3 so the
        # template lists can grow or shrink without touching this function.
        tmpl = templates[variant % len(templates)]
        return tmpl.format(
            product=product_label,
            # Cap the embedded design text to keep the prompt bounded.
            design_context=design_context[:2500],
        )

    templates = _PLAN_INSTRUCTIONS if doc_type == "plan" else _DESIGN_INSTRUCTIONS
    tmpl = templates[variant % len(templates)]
    return tmpl.format(
        product=product_label,
        title=title,
        overview=overview or "",
        type_phrase="planning document",
    )
|
||||
|
||||
|
||||
def _record_id(content: str, source: str) -> str:
|
||||
return hashlib.sha256(f"{source}:{content}".encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
# ── Pair discovery ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _find_pairs(plans_dir: Path) -> dict[str, list[tuple[str, Path]]]:
    """Return {prefix_key → [(doc_type, path), ...]} for docs sharing date+feature.

    Walks every ``*.md`` file under *plans_dir*, skipping ``README.md``, files
    inside a directory listed in ``_SKIP_DIRS``, and files whose stem does not
    start with a date. Files are grouped by ``<parent dir>/<date>-<feature>``.

    Files are visited in sorted path order, so both the key order and the
    order of entries within each group are reproducible between runs.
    """
    by_prefix: dict[str, list[tuple[str, Path]]] = {}
    # rglob() order is not specified; sort for deterministic output.
    for path in sorted(plans_dir.rglob("*.md")):
        # Only look at components below plans_dir: an ancestor of plans_dir
        # that happens to be called e.g. "node_modules" must not hide
        # the whole tree.
        if any(part in _SKIP_DIRS for part in path.relative_to(plans_dir).parts):
            continue
        if path.name == "README.md":
            continue
        stem = path.stem
        date, feature = _date_feature(stem)
        if not date:
            # Undated files cannot be paired by date+feature.
            continue
        key = str(path.parent / f"{date}-{feature}")
        by_prefix.setdefault(key, []).append((_doc_type(stem), path))
    return by_prefix
|
||||
|
||||
|
||||
# ── Record generation ──────────────────────────────────────────────────────────
|
||||
|
||||
def _variant_records(
    path: Path,
    plans_dir: Path,
    *,
    title: str,
    overview: str | None,
    doc_type: str,
    cleaned: str,
    design_context: str | None = None,
    paired_with: str | None = None,
    pair_role: str = "standalone",
) -> Iterator[dict]:
    """Yield the three instruction-variant records for one cleaned document."""
    # Everything here is identical for all three variants, so compute it once.
    product = _product_from_path(path, plans_dir)
    rel_src = str(path.relative_to(plans_dir))
    date = _date_feature(path.stem)[0]
    word_count = len(cleaned.split())
    flags = _quality_flags(cleaned)

    for variant in range(3):
        instruction = _make_instruction(
            title=title,
            product=product,
            doc_type=doc_type,
            overview=overview,
            design_context=design_context,
            variant=variant,
        )
        yield {
            # The variant index is part of the hashed content so the three
            # records built from one document get distinct ids.
            "id": _record_id(f"v{variant}:{cleaned}", rel_src),
            "messages": [
                {"role": "user", "content": instruction},
                {"role": "assistant", "content": cleaned},
            ],
            "meta": {
                "source": rel_src,
                "product": product,
                "doc_type": doc_type,
                "date": date,
                "paired_with": paired_with,
                "word_count": word_count,
                "pair_role": pair_role,
                "variant": variant,
                "quality_flags": flags,
            },
        }


def _records_for_group(
    doc_type_paths: list[tuple[str, Path]],
    plans_dir: Path,
    min_words: int = 80,
) -> Iterator[dict]:
    """Yield training records for a group of related docs.

    If the group holds at least one design/spec doc and at least one plan,
    the first design and the first plan are paired: the plan becomes the
    assistant answer and the cleaned design is passed to the instruction as
    context (``pair_role="target"``). The paired plan is then removed from
    the group, and every remaining doc (the design included) is emitted as a
    standalone record.

    Args:
        doc_type_paths: ``(doc_type, path)`` tuples belonging to one group.
        plans_dir: Root directory; record sources are stored relative to it.
        min_words: Documents whose cleaned text has fewer whitespace-separated
            words than this are skipped. Defaults to 80.

    Yields:
        Three records (instruction variants 0-2) per emitted document.
    """
    designs = [p for t, p in doc_type_paths if t in ("design", "spec")]
    plans_ = [p for t, p in doc_type_paths if t == "plan"]
    remaining = doc_type_paths

    if designs and plans_:
        design_path, plan_path = designs[0], plans_[0]
        plan_content = plan_path.read_text(encoding="utf-8")
        cleaned = _clean_content(plan_content)

        if len(cleaned.split()) >= min_words:
            design_content = design_path.read_text(encoding="utf-8")
            yield from _variant_records(
                plan_path,
                plans_dir,
                title=_extract_title(plan_content) or plan_path.stem,
                overview=_extract_overview(design_content),
                doc_type="plan",
                cleaned=cleaned,
                design_context=_clean_content(design_content),
                paired_with=str(design_path.relative_to(plans_dir)),
                pair_role="target",
            )

        # The paired plan must not be emitted again as a standalone record;
        # the design doc stays in the list and is emitted below.
        remaining = [(t, p) for t, p in doc_type_paths if p != plan_path]

    # Remaining docs as standalone records (3 instruction variants each)
    for doc_type, path in remaining:
        content = path.read_text(encoding="utf-8")
        cleaned = _clean_content(content)
        if len(cleaned.split()) < min_words:
            continue
        yield from _variant_records(
            path,
            plans_dir,
            title=_extract_title(content) or path.stem,
            overview=_extract_overview(content),
            doc_type=doc_type,
            cleaned=cleaned,
        )
|
||||
|
||||
|
||||
def _product_from_path(path: Path, plans_dir: Path) -> str:
|
||||
rel = path.relative_to(plans_dir)
|
||||
return rel.parts[0] if len(rel.parts) > 1 else "shared"
|
||||
|
||||
|
||||
# ── Main export ────────────────────────────────────────────────────────────────
|
||||
|
||||
def export(
    plans_dir: Path,
    products: list[str] | None = None,
) -> list[dict]:
    """Build the list of training records for every doc group under *plans_dir*.

    Args:
        plans_dir: Root directory that is scanned for documents.
        products: Optional product names; when given (and non-empty), a group
            is kept only if at least one of its documents belongs to one of
            these products.

    Returns:
        Records in discovery order, with records whose ``"id"`` was already
        seen dropped.
    """
    collected: list[dict] = []
    emitted: set[str] = set()

    for members in _find_pairs(plans_dir).values():
        if products:
            group_products = {
                _product_from_path(doc, plans_dir) for _, doc in members
            }
            if group_products.isdisjoint(products):
                continue

        for rec in _records_for_group(members, plans_dir):
            rec_id = rec["id"]
            if rec_id in emitted:
                continue
            emitted.add(rec_id)
            collected.append(rec)

    return collected
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_stats(records: list[dict]) -> None:
|
||||
from collections import Counter
|
||||
products = Counter(r["meta"]["product"] for r in records)
|
||||
doc_types = Counter(r["meta"]["doc_type"] for r in records)
|
||||
pair_roles = Counter(r["meta"]["pair_role"] for r in records)
|
||||
wc = [r["meta"]["word_count"] for r in records]
|
||||
wc.sort()
|
||||
|
||||
print(f"\n{'='*55}")
|
||||
print(f" Total records: {len(records)}")
|
||||
print(f" Word counts : min={wc[0]}, median={wc[len(wc)//2]}, max={wc[-1]}")
|
||||
print(f"\n By product:")
|
||||
for p, n in products.most_common():
|
||||
print(f" {p:<22} {n}")
|
||||
print(f"\n By doc type:")
|
||||
for t, n in doc_types.most_common():
|
||||
print(f" {t:<22} {n}")
|
||||
print(f"\n Pair roles:")
|
||||
for r, n in pair_roles.most_common():
|
||||
print(f" {r:<22} {n}")
|
||||
print(f"{'='*55}\n")
|
||||
|
||||
|
||||
def _print_sample(records: list[dict], n: int = 3) -> None:
|
||||
import random
|
||||
sample = random.sample(records, min(n, len(records)))
|
||||
for i, rec in enumerate(sample, 1):
|
||||
meta = rec["meta"]
|
||||
user_msg = rec["messages"][0]["content"]
|
||||
asst_msg = rec["messages"][1]["content"]
|
||||
print(f"\n{'─'*55}")
|
||||
print(f"SAMPLE {i}/{n} [{meta['product']} / {meta['doc_type']} / {meta['pair_role']}]")
|
||||
print(f"source: {meta['source']}")
|
||||
print(f"\nUSER ({len(user_msg)} chars):\n{user_msg[:500]}{'...' if len(user_msg)>500 else ''}")
|
||||
print(f"\nASSISTANT ({meta['word_count']} words):\n{asst_msg[:400]}{'...' if len(asst_msg)>400 else ''}")
|
||||
print(f"\n{'─'*55}\n")
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: scan the plans directory, then preview or write JSONL.

    Statistics are always printed. Sample records are printed when
    ``--preview`` is given or no ``--output`` path was supplied, and in both
    of those cases nothing is written. Otherwise one JSON object per record
    is written to ``--output``, creating parent directories as needed.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--plans-dir", type=Path, default=_DEFAULT_PLANS_DIR)
    parser.add_argument("--output", type=Path, default=None,
                        help="Write JSONL to this path (omit for preview-only)")
    parser.add_argument("--products", default=None,
                        help="Comma-separated product filter, e.g. peregrine,kiwi")
    parser.add_argument("--preview", action="store_true",
                        help="Print stats + sample records, don't write output")
    parser.add_argument("--samples", type=int, default=3,
                        help="Number of sample records to show in preview (default 3)")
    args = parser.parse_args()

    if not args.plans_dir.is_dir():
        # Fail with a usage error instead of scanning nothing and crashing later.
        parser.error(f"--plans-dir {args.plans_dir} is not a directory")

    # Drop empty entries so a stray comma ("kiwi," or "a,,b") cannot put an
    # empty product name into the filter.
    products = None
    if args.products:
        products = [name.strip() for name in args.products.split(",") if name.strip()]

    print(f"Scanning {args.plans_dir} …", file=sys.stderr)
    records = export(args.plans_dir, products=products)

    if not records:
        # Nothing to summarise, sample or write.
        print("No records found.", file=sys.stderr)
        return

    _print_stats(records)

    if args.preview or args.output is None:
        _print_sample(records, n=args.samples)
        if args.output is None:
            print("(Pass --output <path> to write JSONL)")
        # --preview is documented as "don't write output", so stop here even
        # when an --output path was also supplied.
        return

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

    print(f"Wrote {len(records)} records to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,23 +1,37 @@
|
|||
import json
|
||||
"""Smoke tests for the app factory (app/api.py).
|
||||
|
||||
Detailed route tests live in test_data_label.py, test_data_fetch.py,
|
||||
test_data_corrections.py, test_train.py, and test_dashboard.py.
|
||||
"""
|
||||
import pytest
|
||||
from app import api as api_module # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import api
|
||||
api.set_data_dir(tmp_path)
|
||||
api.reset_last_action()
|
||||
yield
|
||||
api.reset_last_action()
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_import():
|
||||
from app import api # noqa: F401
|
||||
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
def test_app_has_required_routes():
    """Every route the app factory is expected to mount is registered on the app."""
    from app.api import app

    registered = {route.path for route in app.routes}
    expected = (
        # Label routes
        "/api/queue",
        "/api/label",
        "/api/skip",
        "/api/discard",
        "/api/label/undo",
        "/api/config/labels",
        "/api/stats",
        # Fetch routes
        "/api/accounts/test",
        "/api/fetch/stream",
        # Train routes
        "/api/train/jobs",
        "/api/train/results",
        # Dashboard
        "/api/dashboard",
        # Corrections (new prefix)
        "/api/corrections/ingest",
    )
    for route_path in expected:
        assert route_path in registered
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -26,536 +40,8 @@ def client():
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_items():
|
||||
"""Write 3 test emails to the queue file."""
|
||||
from app import api as api_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
|
||||
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
def test_queue_endpoint_reachable(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = api_module._read_jsonl(api_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part A: POST /api/discard ---
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
# --- Part B: DELETE /api/label/undo ---
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["undone"]["type"] == "label"
|
||||
score = api_module._read_jsonl(api_module._score_file())
|
||||
assert score == []
|
||||
# Item should be restored to front of queue
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert discarded == []
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part C: GET /api/config/labels ---
|
||||
|
||||
def test_config_labels_returns_metadata(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
assert "emoji" in labels[0]
|
||||
assert "color" in labels[0]
|
||||
assert "name" in labels[0]
|
||||
|
||||
|
||||
# ── /api/config ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(tmp_path):
|
||||
"""Give the API a writable config directory."""
|
||||
from app import api as api_module
|
||||
api_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
api_module.set_config_dir(None) # reset to default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir():
|
||||
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
|
||||
from app import api as api_module
|
||||
return api_module._DATA_DIR
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client, config_dir):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, config_dir):
|
||||
import yaml
|
||||
payload = {
|
||||
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}],
|
||||
"max_per_account": 200,
|
||||
}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
cfg_file = config_dir / "label_tool.yaml"
|
||||
assert cfg_file.exists()
|
||||
saved = yaml.safe_load(cfg_file.read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, config_dir):
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
# ── /api/stats ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def score_with_labels(tmp_path, data_dir):
|
||||
"""Write a score file with 3 labels for stats tests."""
|
||||
score_path = data_dir / "email_score.jsonl"
|
||||
records = [
|
||||
{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"},
|
||||
]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
return records
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, score_with_labels):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, score_with_labels):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/accounts/test ───────────────────────────────────────────────────────
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
from unittest.mock import MagicMock, patch
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
"""Parse SSE response body into list of event dicts."""
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, config_dir):
|
||||
"""With no config, stream should immediately complete with 0 added."""
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
||||
"""With one configured account, stream should yield start/done/complete events."""
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Write a config with one account
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
# ---- /api/finetune/status tests ----
|
||||
|
||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
|
||||
"""GET /api/finetune/status must return [] if models/ does not exist."""
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
||||
import json as _json
|
||||
from app import api as api_module
|
||||
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
info = {
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401,
|
||||
}
|
||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
||||
|
||||
api_module.set_models_dir(tmp_path / "models")
|
||||
try:
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
finally:
|
||||
api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
|
||||
"""GET /api/finetune/run must return text/event-stream content type."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "text/event-stream" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_finetune_run_emits_complete_on_success(client):
|
||||
"""GET /api/finetune/run must emit a complete event on clean exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["progress line\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '{"type": "complete"}' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
|
||||
"""GET /api/finetune/run must emit an error event on non-zero exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '"type": "error"' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_passes_score_files_to_subprocess(client):
|
||||
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
captured_cmd = []
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
captured_cmd.extend(cmd)
|
||||
m = MagicMock()
|
||||
m.stdout = iter([])
|
||||
m.returncode = 0
|
||||
m.wait = MagicMock()
|
||||
return m
|
||||
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
|
||||
|
||||
assert "--score" in captured_cmd
|
||||
assert captured_cmd.count("--score") == 2
|
||||
# Paths are resolved to absolute — check filenames are present as substrings
|
||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
||||
|
||||
|
||||
# ---- Cancel endpoint tests ----
|
||||
|
||||
def test_benchmark_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/benchmark/cancel must return 404 if no benchmark is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_finetune_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_benchmark_cancel_terminates_running_process(client):
|
||||
"""POST /api/benchmark/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_cancel_terminates_running_process(client):
|
||||
"""POST /api/finetune/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["finetune"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_cancel_kills_process_on_timeout(client):
|
||||
"""POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
mock_proc.kill.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_run_emits_cancelled_event(client):
|
||||
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15 # SIGTERM
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("finetune")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_run_emits_cancelled_event(client):
|
||||
"""GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("benchmark")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/benchmark/run")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
assert "items" in r.json()
|
||||
assert "total" in r.json()
|
||||
|
|
|
|||
|
|
@ -2,11 +2,6 @@
|
|||
import pytest
|
||||
|
||||
|
||||
def test_registry_has_thirteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 13
|
||||
|
||||
|
||||
def test_registry_default_count():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
|
||||
|
|
@ -166,3 +161,95 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
|
|||
|
||||
assert "avocet-deberta-small" in models
|
||||
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
||||
|
||||
|
||||
# ---- build_exemplars_from_jsonl() tests ----
|
||||
|
||||
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
    """A label with more than k rows is capped at k; a smaller label is kept whole."""
    import json

    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    records = [
        {"subject": f"S{n}", "body": f"B{n}", "label": "rejected"}
        for n in range(15)
    ] + [{"subject": "Hire", "body": "Welcome", "label": "hired"}]
    score_file = tmp_path / "score.jsonl"
    score_file.write_text("\n".join(map(json.dumps, records)))

    exemplars = build_exemplars_from_jsonl(str(score_file), k_per_label=10)

    rejected = exemplars["rejected"]
    assert len(rejected) == 10
    assert len(exemplars["hired"]) == 1
    assert rejected[0].startswith("Subject: S")
|
||||
|
||||
|
||||
def test_build_exemplars_formats_text_correctly(tmp_path):
    """An exemplar is rendered as 'Subject: <subject>', a blank line, then the body."""
    import json

    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    score_file = tmp_path / "score.jsonl"
    score_file.write_text(
        json.dumps({"subject": "My Subject", "body": "My Body", "label": "neutral"})
    )

    exemplars = build_exemplars_from_jsonl(str(score_file))

    first_neutral = exemplars["neutral"][0]
    assert first_neutral == "Subject: My Subject\n\nMy Body"
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_missing_label(tmp_path):
    """A row without a "label" key adds no entry to the result."""
    import json
    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    labelled = {"subject": "A", "body": "B", "label": "neutral"}
    unlabelled = {"subject": "No label here", "body": "Body"}
    score_path = tmp_path / "score.jsonl"
    score_path.write_text(
        "\n".join(json.dumps(record) for record in (labelled, unlabelled))
    )

    exemplars = build_exemplars_from_jsonl(str(score_path))

    assert list(exemplars.keys()) == ["neutral"]
|
||||
|
||||
|
||||
def test_build_exemplars_truncates_body_at_600(tmp_path):
    """An 800-character body comes back cut to 600 characters."""
    import json
    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    oversized_body = "x" * 800
    score_path = tmp_path / "score.jsonl"
    score_path.write_text(
        json.dumps({"subject": "S", "body": oversized_body, "label": "neutral"})
    )

    exemplars = build_exemplars_from_jsonl(str(score_path))

    # Everything after the first blank-line separator is the body portion.
    pieces = exemplars["neutral"][0].split("\n\n", 1)
    body_part = pieces[1]
    assert len(body_part) == 600
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_with_no_content(tmp_path):
    """Rows whose subject and body are both absent or empty are dropped."""
    import json
    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    missing_fields = {"label": "neutral"}  # no subject, no body -> skip
    valid_row = {"subject": "S", "body": "B", "label": "neutral"}  # valid -> keep
    empty_fields = {"label": "rejected", "subject": "", "body": ""}  # empty strings -> skip

    score_path = tmp_path / "score.jsonl"
    serialized = [json.dumps(row) for row in (missing_fields, valid_row, empty_fields)]
    score_path.write_text("\n".join(serialized))

    exemplars = build_exemplars_from_jsonl(str(score_path))

    assert list(exemplars.keys()) == ["neutral"]
    assert len(exemplars["neutral"]) == 1
|
||||
|
||||
def test_registry_has_fourteen_models():
    """MODEL_REGISTRY exposes exactly fourteen entries."""
    from scripts.benchmark_classifier import MODEL_REGISTRY

    registry_size = len(MODEL_REGISTRY)
    assert registry_size == 14
|
||||
|
||||
|
||||
def test_embed_knn_nomic_registry_entry():
    """The "embed-knn-nomic" registry entry carries the expected adapter and settings."""
    from scripts.classifier_adapters import EmbeddingKNNAdapter
    from scripts.benchmark_classifier import MODEL_REGISTRY

    knn_entry = MODEL_REGISTRY["embed-knn-nomic"]

    assert knn_entry["adapter"] is EmbeddingKNNAdapter
    assert knn_entry["model_id"] == "nomic-embed-text"
    assert knn_entry["params"] == "local-embed"
    assert knn_entry["default"] is False
    # "kwargs" is read defensively so a missing key fails the comparison, not the lookup.
    adapter_kwargs = knn_entry.get("kwargs", {})
    assert adapter_kwargs.get("k") == 3
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ from fastapi.testclient import TestClient
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cforch_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path and reset running-state globals."""
|
||||
"""Redirect _CONFIG_DIR to tmp_path, reset running-state globals, and stub
|
||||
list_installed to return [] so real disk model directories don't bleed into
|
||||
tests that don't exercise the installed-model merge path."""
|
||||
from app import cforch as cforch_module
|
||||
|
||||
prev_config_dir = cforch_module._CONFIG_DIR
|
||||
|
|
@ -25,7 +27,8 @@ def reset_cforch_globals(tmp_path):
|
|||
cforch_module._BENCH_RUNNING = False
|
||||
cforch_module._bench_proc = None
|
||||
|
||||
yield tmp_path
|
||||
with patch("app.models.list_installed", return_value=[]):
|
||||
yield tmp_path
|
||||
|
||||
cforch_module.set_config_dir(prev_config_dir)
|
||||
cforch_module._BENCH_RUNNING = prev_running
|
||||
|
|
@ -141,6 +144,35 @@ def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
|
|||
assert m["vram_estimate_mb"] == 6000
|
||||
|
||||
|
||||
def test_models_merges_installed_generators(client, config_dir, tmp_path):
    """Installed generator models are merged into /api/cforch/models.

    Entries already present in bench_models.yaml are not duplicated, and
    installed models whose role is not "generator" are left out.
    """
    yaml_entries = [
        {"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
        {"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
    ]
    models_file = tmp_path / "bench_models.yaml"
    _write_models_yaml(models_file, yaml_entries)
    _write_config(config_dir, {"bench_models": str(models_file)})

    # should be included — cf-text generator not already in YAML
    new_generator = {"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000}
    # should be deduped — repo_id matches a YAML entry
    duplicate_generator = {"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000}
    # should be excluded — classifier, not a generator
    reranker = {"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500}
    fake_installed = [new_generator, duplicate_generator, reranker]

    with patch("app.models.list_installed", return_value=fake_installed):
        response = client.get("/api/cforch/models")

    assert response.status_code == 200
    ids = [model["id"] for model in response.json()["models"]]
    assert "llama3:8b" in ids  # from YAML
    assert "ibm-granite/granite-4.1-8b" in ids  # from YAML (not duplicated)
    assert "meta-llama/Llama-3.1-8B" in ids  # merged from installed
    assert "cross-encoder/ms-marco-MiniLM-L6" not in ids  # filtered out (reranker)
    assert ids.count("ibm-granite/granite-4.1-8b") == 1  # no duplicate
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_returns_409_when_already_running(client):
|
||||
|
|
@ -367,3 +399,13 @@ def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path,
|
|||
client.get("/api/cforch/run")
|
||||
|
||||
assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
|
||||
|
||||
|
||||
def test_eval_cforch_router_includes_all_sub_routers():
    """eval/cforch.py router must include routes from all four sub-routers."""
    from app.eval.cforch import router

    paths = {route.path for route in router.routes}
    # One path fragment per sub-router; each must appear in at least one route.
    for fragment in ("/cforch/", "/style/", "/voice/", "/plans-bench/"):
        assert any(fragment in p for p in paths), f"no {fragment} routes found in {paths}"
|
||||
|
|
|
|||
|
|
@ -268,3 +268,373 @@ def test_finetuned_adapter_unload_clears_pipeline():
|
|||
assert adapter._pipeline is not None
|
||||
adapter.unload()
|
||||
assert adapter._pipeline is None
|
||||
|
||||
# ---- _cosine() tests ----
|
||||
|
||||
def test_cosine_identical_unit_vectors():
    """_cosine of a unit vector with itself is 1.0 (within float tolerance)."""
    # The original also imported `math` here but never used it; dropped.
    from scripts.classifier_adapters import _cosine

    assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal_vectors():
    """_cosine of two perpendicular vectors is 0.0 (within float tolerance)."""
    from scripts.classifier_adapters import _cosine

    x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
    similarity = _cosine(x_axis, y_axis)
    assert similarity == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_known_value():
    """_cosine of [1, 0] against the 45-degree unit vector is 1/sqrt(2)."""
    import math
    from scripts.classifier_adapters import _cosine

    # [1,0] vs [1/sqrt(2), 1/sqrt(2)] → dot = 1/sqrt(2), both norms = 1 → 1/sqrt(2)
    component = 1.0 / math.sqrt(2)
    diagonal = [component, component]
    assert _cosine([1.0, 0.0], diagonal) == pytest.approx(1.0 / math.sqrt(2))
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
    """_cosine yields 0.0 when one operand is the zero vector."""
    from scripts.classifier_adapters import _cosine

    zero_vector = [0.0, 0.0]
    similarity = _cosine(zero_vector, [1.0, 0.0])
    assert similarity == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ---- DEFAULT_EXEMPLARS tests ----
|
||||
|
||||
def test_default_exemplars_covers_all_labels():
    """Every entry of LABELS is a DEFAULT_EXEMPLARS key with at least four texts."""
    from scripts.classifier_adapters import DEFAULT_EXEMPLARS, LABELS

    for label in LABELS:
        assert label in DEFAULT_EXEMPLARS, f"DEFAULT_EXEMPLARS missing label: {label}"
        exemplar_count = len(DEFAULT_EXEMPLARS[label])
        assert exemplar_count >= 4, f"{label} needs >= 4 exemplars for k=3 voting"
|
||||
|
||||
|
||||
def test_default_exemplars_sparse_labels_have_at_least_four():
    """The three sparse labels each ship with at least four default exemplars."""
    from scripts.classifier_adapters import DEFAULT_EXEMPLARS

    # These labels have very few real examples; need >= 4 so k=3 vote is meaningful
    sparse_labels = ("hired", "survey_received", "event_rescheduled")
    for label in sparse_labels:
        exemplar_count = len(DEFAULT_EXEMPLARS[label])
        assert exemplar_count >= 4, (
            f"{label} needs >= 4 exemplars for k=3 voting to work reliably"
        )
|
||||
|
||||
def test_default_exemplars_strings_are_formatted_correctly():
    """Each default exemplar starts with "Subject: " and contains a blank-line separator."""
    from scripts.classifier_adapters import DEFAULT_EXEMPLARS

    # Flatten to (label, text) pairs, keeping the original iteration order.
    labelled_texts = (
        (label, text)
        for label, texts in DEFAULT_EXEMPLARS.items()
        for text in texts
    )
    for label, text in labelled_texts:
        assert text.startswith("Subject: "), (
            f"{label!r} exemplar missing 'Subject: ' prefix: {text[:50]!r}"
        )
        assert "\n\n" in text, (
            f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
        )
|
||||
|
||||
# ---- EmbeddingKNNAdapter constructor tests ----
|
||||
|
||||
def test_embedding_knn_is_classifier_adapter():
    """An EmbeddingKNNAdapter instance is also a ClassifierAdapter."""
    from scripts.classifier_adapters import ClassifierAdapter, EmbeddingKNNAdapter

    knn = EmbeddingKNNAdapter(
        "test-knn",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )

    assert isinstance(knn, ClassifierAdapter)
|
||||
|
||||
|
||||
def test_embedding_knn_name_and_model_id():
    """The first two positional arguments are exposed as .name and .model_id."""
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    knn = EmbeddingKNNAdapter(
        "embed-knn-nomic",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )

    assert knn.name == "embed-knn-nomic"
    assert knn.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_embedding_knn_uses_default_exemplars_when_none_given():
    """Without exemplar_texts, the adapter holds the DEFAULT_EXEMPLARS object itself."""
    from scripts.classifier_adapters import DEFAULT_EXEMPLARS, EmbeddingKNNAdapter

    knn = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )

    # Identity, not equality: the module-level mapping must be reused as-is.
    assert knn._exemplar_texts is DEFAULT_EXEMPLARS
|
||||
|
||||
|
||||
def test_embedding_knn_accepts_custom_exemplars():
    """A caller-supplied exemplar_texts mapping is stored by identity."""
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    custom = {"rejected": ["Sorry, we went with others."]}
    knn = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
        exemplar_texts=custom,
    )

    assert knn._exemplar_texts is custom
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.load() tests ----
|
||||
|
||||
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
    """Return a side_effect function for patching httpx.post.

    Allocate calls (URL contains "/allocate") get a response whose JSON is
    ``{"allocation_id": alloc_id, "url": alloc_url}``. Any other call is
    treated as an embed call and returns one ``[0.1, 0.2, 0.3]`` embedding
    per element of the request's ``"input"`` value.

    Every response has ``status_code == 200`` and a ``raise_for_status()``
    that returns None.
    """
    # Imported once per factory call instead of once per mocked request.
    from unittest.mock import MagicMock

    def _side_effect(url, *, json=None, timeout=None, **kwargs):
        resp = MagicMock()
        resp.raise_for_status.return_value = None
        resp.status_code = 200
        if "/allocate" in url:
            resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
        else:
            n = len((json or {}).get("input", []))
            # Build a fresh dict and vector per item: `[...] * n` would alias a
            # single object n times, so mutating one embedding would change all.
            resp.json.return_value = {
                "data": [{"embedding": [0.1, 0.2, 0.3]} for _ in range(n)]
            }
        return resp

    return _side_effect
|
||||
|
||||
|
||||
def test_load_calls_allocate_then_embeds_each_label():
    """load() allocates through the orchestrator, then stores one embedding per exemplar."""
    from unittest.mock import patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
        exemplar_texts={
            "rejected": ["We went with others"],
            "hired": ["Welcome aboard!", "First day info"],
        },
    )

    post_urls = []
    canned_post = _make_post_mock()

    def capturing_mock(url, *, json=None, timeout=None, **kwargs):
        # Record the URL, then answer with the shared canned response logic.
        post_urls.append(url)
        return canned_post(url, json=json, timeout=timeout)

    with patch("httpx.post", side_effect=capturing_mock):
        adapter.load()

    assert any("/allocate" in u for u in post_urls), "expected allocate call"
    assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
    assert adapter._allocation_id == "alloc-abc"
    assert adapter._node_url == "http://navi:11434"
    assert adapter._orch_url_used == "http://orch:7700"

    embeddings = adapter._exemplar_embeddings
    assert "rejected" in embeddings
    assert "hired" in embeddings
    assert len(embeddings["rejected"]) == 1
    assert len(embeddings["hired"]) == 2
    assert embeddings["rejected"][0] == [0.1, 0.2, 0.3]
    assert embeddings["hired"][0] == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_fails():
    """A 503 from the allocate endpoint makes load() embed against ollama_url instead."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
        exemplar_texts={"rejected": ["We went with others"]},
    )

    def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
        reply = MagicMock()
        if "/allocate" not in url:
            # Embed call: succeeds with a single vector.
            reply.raise_for_status.return_value = None
            reply.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
            return reply
        # Allocate call: orchestrator unavailable, empty JSON body.
        reply.status_code = 503
        reply.json.return_value = {}
        return reply

    with patch("httpx.post", side_effect=failing_allocate_mock):
        adapter.load()

    assert adapter._allocation_id == ""
    assert adapter._orch_url_used == ""
    assert adapter._node_url == "http://ollama:11434"
    assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_raises():
    """A ConnectError on the allocate call makes load() embed against ollama_url instead."""
    from unittest.mock import MagicMock, patch
    import httpx as _httpx
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
        exemplar_texts={"rejected": ["We went with others"]},
    )

    def raising_mock(url, *, json=None, timeout=None, **kwargs):
        # Only the orchestrator is unreachable; embed calls still succeed.
        if "/allocate" in url:
            raise _httpx.ConnectError("connection refused")
        reply = MagicMock()
        reply.raise_for_status.return_value = None
        reply.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        return reply

    with patch("httpx.post", side_effect=raising_mock):
        adapter.load()

    assert adapter._allocation_id == ""
    assert adapter._orch_url_used == ""
    assert adapter._node_url == "http://ollama:11434"
    assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.unload() tests ----
|
||||
|
||||
def test_unload_releases_orch_allocation_and_clears_state():
    """unload() issues one DELETE naming the allocation and resets adapter state."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )
    # Simulate an adapter that was loaded through the orchestrator path.
    adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
    adapter._node_url = "http://navi:11434"
    adapter._allocation_id = "alloc-abc"
    adapter._orch_url_used = "http://orch:7700"

    delete_calls = []

    def mock_request(method, url, **kwargs):
        delete_calls.append((method, url))
        reply = MagicMock()
        reply.status_code = 200
        return reply

    with patch("httpx.request", side_effect=mock_request):
        adapter.unload()

    assert len(delete_calls) == 1
    method, url = delete_calls[0]
    assert method == "DELETE"
    assert "alloc-abc" in url
    assert adapter._exemplar_embeddings == {}
    assert adapter._allocation_id == ""
    assert adapter._node_url == ""
    assert adapter._orch_url_used == ""
|
||||
|
||||
|
||||
def test_unload_skips_delete_on_ollama_fallback_path():
    """With no allocation id, unload() makes no httpx.request call but still clears state."""
    from unittest.mock import patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )
    adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
    adapter._node_url = "http://ollama:11434"
    adapter._allocation_id = ""  # fallback path: no allocation was made
    adapter._orch_url_used = ""

    delete_calls = []

    def record_call(*args, **kwargs):
        # Same as the original lambda: store positional args, return None.
        return delete_calls.append(args)

    with patch("httpx.request", side_effect=record_call):
        adapter.unload()

    assert len(delete_calls) == 0
    assert adapter._exemplar_embeddings == {}
    assert adapter._node_url == ""
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.classify() tests ----
|
||||
|
||||
def _adapter_with_embeddings(exemplar_embeddings, k=3):
    """Build an EmbeddingKNNAdapter that looks already loaded.

    load() is bypassed: the given per-label vectors are assigned straight to
    ``_exemplar_embeddings`` and ``_node_url`` is set to "http://navi:11434".
    """
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    preloaded = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=k,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )
    preloaded._exemplar_embeddings = exemplar_embeddings
    preloaded._node_url = "http://navi:11434"
    return preloaded
|
||||
|
||||
|
||||
def _embed_resp(vec):
    """Build a mock httpx response for /v1/embeddings carrying the single vector *vec*.

    ``raise_for_status()`` returns None and ``json()`` returns
    ``{"data": [{"embedding": vec}]}``.
    """
    from unittest.mock import MagicMock

    payload = {"data": [{"embedding": vec}]}
    reply = MagicMock()
    reply.json.return_value = payload
    reply.raise_for_status.return_value = None
    return reply
|
||||
|
||||
|
||||
def test_classify_returns_majority_vote_label():
    """With k=3, a query nearest to three "rejected" exemplars is labelled "rejected"."""
    from unittest.mock import patch

    preloaded_vectors = {
        "rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
        "neutral": [[0.0, 1.0, 0.0]],
    }
    adapter = _adapter_with_embeddings(preloaded_vectors, k=3)

    # Query [1,0,0] is closest to all three "rejected" exemplars
    query_reply = _embed_resp([1.0, 0.0, 0.0])
    with patch("httpx.post", return_value=query_reply):
        predicted = adapter.classify("We went with others", "Thank you for applying.")

    assert predicted == "rejected"
|
||||
|
||||
|
||||
def test_classify_tiebreak_by_mean_score():
    """When two labels tie on votes, the one with the higher similarity wins."""
    from unittest.mock import patch

    # k=2: each label gets exactly 1 vote → tie-break by mean similarity
    # [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] ≈ 0.6 ("neutral")
    preloaded_vectors = {
        "rejected": [[1.0, 0.0]],
        "neutral": [[0.6, 0.8]],
    }
    adapter = _adapter_with_embeddings(preloaded_vectors, k=2)

    query_reply = _embed_resp([1.0, 0.0])
    with patch("httpx.post", return_value=query_reply):
        predicted = adapter.classify("Rejection", "Sorry")

    assert predicted == "rejected"
|
||||
|
||||
|
||||
def test_classify_sparse_label_can_win():
    """A label with a single exemplar wins when that exemplar is the nearest and k=1."""
    from unittest.mock import patch

    # "hired" has only 1 exemplar; with k=1, the single closest match wins
    preloaded_vectors = {
        "rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
        "hired": [[1.0, 0.0, 0.0]],
    }
    adapter = _adapter_with_embeddings(preloaded_vectors, k=1)

    # Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
    query_reply = _embed_resp([1.0, 0.0, 0.0])
    with patch("httpx.post", return_value=query_reply):
        predicted = adapter.classify("Welcome aboard", "Your first day details")

    assert predicted == "hired"
|
||||
|
||||
|
||||
def test_classify_lazy_loads_when_not_loaded():
    """classify() on an unloaded adapter first allocates and embeds the exemplars."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=1,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
        exemplar_texts={"rejected": ["We went with others"]},
    )
    assert adapter._exemplar_embeddings == {}

    post_urls = []

    def mock_post(url, *, json=None, timeout=None, **kwargs):
        post_urls.append(url)
        reply = MagicMock()
        reply.raise_for_status.return_value = None
        if "/allocate" in url:
            reply.status_code = 200
            reply.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
            return reply
        # Embed call: one [1.0, 0.0] vector per requested input.
        count = len((json or {}).get("input", []))
        reply.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * count}
        return reply

    with patch("httpx.post", side_effect=mock_post):
        predicted = adapter.classify("Rejection", "Sorry")

    assert predicted == "rejected"
    assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
    assert adapter._exemplar_embeddings != {}
    assert adapter._node_url == "http://navi:11434"
|
||||
|
|
|
|||
122
tests/test_dashboard.py
Normal file
122
tests/test_dashboard.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Tests for app/dashboard.py -- GET /api/dashboard."""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
    """Point the dashboard module's data and config directories at tmp_path.

    NOTE(review): nothing is restored after the yield, so the redirected
    directories stay in place once the test finishes -- confirm that is intended.
    """
    from app import dashboard as dash_module

    for redirect in (dash_module.set_data_dir, dash_module.set_config_dir):
        redirect(tmp_path)
    yield
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Return a TestClient wrapping the FastAPI app from app.api."""
    from app.api import app as fastapi_app

    test_client = TestClient(fastapi_app)
    return test_client
|
||||
|
||||
|
||||
def _write_score(tmp_path: Path, records: list[dict]) -> None:
    """Write *records* to tmp_path/email_score.jsonl, one JSON object per line.

    The file always ends with a newline, even when *records* is empty.
    """
    score_file = tmp_path / "email_score.jsonl"
    serialized = [json.dumps(record) for record in records]
    score_file.write_text("\n".join(serialized) + "\n")
|
||||
|
||||
def _write_summary(tmp_path: Path, run_id: str, ts: str, score: float) -> None:
    """Create tmp_path/bench_results/<run_id>/summary.json for a fake eval run.

    The summary holds *ts* under "timestamp" and *score* under "best_macro_f1".
    The run directory must not already exist (mkdir is called without exist_ok).
    """
    summary = {"timestamp": ts, "best_macro_f1": score}
    run_dir = tmp_path / "bench_results" / run_id
    run_dir.mkdir(parents=True)
    summary_file = run_dir / "summary.json"
    summary_file.write_text(json.dumps(summary))
|
||||
|
||||
|
||||
def test_dashboard_returns_expected_keys(client):
    """GET /api/dashboard returns every top-level key and every pipeline signal."""
    response = client.get("/api/dashboard")
    assert response.status_code == 200

    data = response.json()
    top_level_keys = (
        "labeled_since_last_eval", "last_eval_timestamp", "last_eval_best_score",
        "active_jobs", "corrections_pending", "corrections_export_ready", "signals",
    )
    for key in top_level_keys:
        assert key in data, f"missing key: {key}"

    signal_names = ("data_to_eval", "eval_to_train", "train_to_fleet")
    for sig in signal_names:
        assert sig in data["signals"], f"missing signal: {sig}"
|
||||
|
||||
|
||||
def test_dashboard_empty_state(client):
    """With no data files present, counters are zero and eval fields are None."""
    response = client.get("/api/dashboard")
    assert response.status_code == 200

    payload = response.json()
    assert payload["labeled_since_last_eval"] == 0
    assert payload["last_eval_timestamp"] is None
    assert payload["last_eval_best_score"] is None
    assert payload["active_jobs"] == []
    assert payload["corrections_pending"] == 0
    assert payload["corrections_export_ready"] == 0
|
||||
|
||||
|
||||
def test_labeled_since_counts_all_when_no_eval(client, tmp_path):
    """Without any eval summary, every labeled record counts as new."""
    labeled_records = [
        {"id": "a", "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"},
        {"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
    ]
    _write_score(tmp_path, labeled_records)

    payload = client.get("/api/dashboard").json()

    assert payload["labeled_since_last_eval"] == 2
|
||||
|
||||
|
||||
def test_labeled_since_filters_by_eval_timestamp(client, tmp_path):
    """Only records labeled after the last eval timestamp are counted."""
    # Eval ran at 10:00; one record is labeled before it, one after.
    _write_summary(tmp_path, "2026-05-01-100000", "2026-05-01T10:00:00+00:00", 0.80)
    before_eval = {"id": "a", "label": "neutral", "labeled_at": "2026-05-01T09:00:00+00:00"}
    after_eval = {"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"}
    _write_score(tmp_path, [before_eval, after_eval])

    config = {"cforch": {"results_dir": str(tmp_path / "bench_results")}}
    (tmp_path / "label_tool.yaml").write_text(yaml.dump(config))

    response = client.get("/api/dashboard")
    data = response.json()

    assert data["labeled_since_last_eval"] == 1
    assert abs(data["last_eval_best_score"] - 0.80) < 0.001
|
||||
|
||||
|
||||
def test_data_to_eval_false_below_threshold(client, tmp_path):
    """Ten labeled records against a threshold of 50 leave data_to_eval False."""
    labeled_records = [
        {"id": str(idx), "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"}
        for idx in range(10)
    ]
    _write_score(tmp_path, labeled_records)

    config = {"pipeline": {"data_eval_threshold": 50}}
    (tmp_path / "label_tool.yaml").write_text(yaml.dump(config))

    signals = client.get("/api/dashboard").json()["signals"]
    assert signals["data_to_eval"] is False
|
||||
|
||||
|
||||
def test_data_to_eval_true_at_threshold(client, tmp_path):
    """Exactly 50 labeled records against a threshold of 50 set data_to_eval True."""
    labeled_records = [
        {"id": str(idx), "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"}
        for idx in range(50)
    ]
    _write_score(tmp_path, labeled_records)

    config = {"pipeline": {"data_eval_threshold": 50}}
    (tmp_path / "label_tool.yaml").write_text(yaml.dump(config))

    signals = client.get("/api/dashboard").json()["signals"]
    assert signals["data_to_eval"] is True
|
||||
|
||||
|
||||
def test_corrections_pending_count(client, tmp_path):
    """corrections_pending is 2 when two of three candidates are "needs_review"."""
    candidate_statuses = (("c1", "needs_review"), ("c2", "needs_review"), ("c3", "discarded"))
    candidates = [{"id": cid, "status": status} for cid, status in candidate_statuses]
    candidates_file = tmp_path / "sft_candidates.jsonl"
    candidates_file.write_text("\n".join(json.dumps(c) for c in candidates) + "\n")

    payload = client.get("/api/dashboard").json()

    assert payload["corrections_pending"] == 2
|
||||
|
||||
|
||||
def test_corrections_export_ready_count(client, tmp_path):
    """corrections_export_ready is 2 when one of three approved records has an empty response."""
    corrected_responses = (("a1", "Good answer"), ("a2", ""), ("a3", "Another answer"))
    approved = [
        {"id": rec_id, "status": "approved", "corrected_response": text}
        for rec_id, text in corrected_responses
    ]
    approved_file = tmp_path / "sft_approved.jsonl"
    approved_file.write_text("\n".join(json.dumps(a) for a in approved) + "\n")

    payload = client.get("/api/dashboard").json()

    assert payload["corrections_export_ready"] == 2
|
||||
102
tests/test_data_corrections.py
Normal file
102
tests/test_data_corrections.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Tests for app/data/corrections.py -- POST /api/sft/ingest.
|
||||
|
||||
The corrections router is mounted at prefix="/api/sft" via the app/sft.py
|
||||
backward-compat shim, so ingest lives at /api/sft/ingest.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
    """Point the corrections module's data and config directories at tmp_path.

    NOTE(review): nothing is restored after the yield, so the redirected
    directories stay in place once the test finishes -- confirm that is intended.
    """
    from app.data import corrections as corr_module

    for redirect in (corr_module.set_data_dir, corr_module.set_config_dir):
        redirect(tmp_path)
    yield
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Return a TestClient wrapping the FastAPI app from app.api."""
    from app.api import app as fastapi_app

    test_client = TestClient(fastapi_app)
    return test_client
|
||||
|
||||
|
||||
# Request body the ingest tests post; individual tests override single fields.
_VALID_PAYLOAD = dict(
    source="peregrine",
    task_type="email_classification",
    prompt="Classify this email: ...",
    response="skip",
    correction="action_required",
    label="action_required",
)

# Value the tests place in AVOCET_INGESTION_SECRET and send as the Bearer token.
_SECRET = "test-secret-abc123"
|
||||
|
||||
|
||||
def test_ingest_503_when_secret_not_configured(client, monkeypatch):
    """Ingest answers 503 when AVOCET_INGESTION_SECRET is unset, even with a Bearer header."""
    monkeypatch.delenv("AVOCET_INGESTION_SECRET", raising=False)

    auth_headers = {"Authorization": f"Bearer {_SECRET}"}
    response = client.post("/api/sft/ingest", json=_VALID_PAYLOAD, headers=auth_headers)

    assert response.status_code == 503
|
||||
|
||||
|
||||
def test_ingest_401_when_no_auth_header(client, monkeypatch):
    """Ingest answers 401 when the request carries no Authorization header."""
    monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)

    response = client.post("/api/sft/ingest", json=_VALID_PAYLOAD)

    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_401_when_malformed_header(client, monkeypatch):
    """Ingest answers 401 when the Authorization header is not a Bearer token."""
    monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)

    bad_headers = {"Authorization": "Token bad-format"}
    response = client.post("/api/sft/ingest", json=_VALID_PAYLOAD, headers=bad_headers)

    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_403_when_wrong_secret(client, monkeypatch):
    """Ingest answers 403 when the Bearer token differs from the configured secret."""
    monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)

    wrong_headers = {"Authorization": "Bearer wrong-secret"}
    response = client.post("/api/sft/ingest", json=_VALID_PAYLOAD, headers=wrong_headers)

    assert response.status_code == 403
|
||||
|
||||
|
||||
def test_ingest_creates_approved_record(client, monkeypatch, tmp_path):
    """A valid ingest stores one candidate with status "approved" and the returned id."""
    from app.data import corrections as corr_module

    monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
    corr_module.set_data_dir(tmp_path)

    auth_headers = {"Authorization": f"Bearer {_SECRET}"}
    response = client.post("/api/sft/ingest", json=_VALID_PAYLOAD, headers=auth_headers)

    assert response.status_code == 200
    data = response.json()
    assert data["ok"] is True
    assert "id" in data

    candidates = corr_module.read_jsonl(corr_module._candidates_file())
    assert len(candidates) == 1
    stored = candidates[0]
    assert stored["status"] == "approved"
    assert stored["source"] == "peregrine"
    # The payload's "correction" value is stored under "corrected_response".
    assert stored["corrected_response"] == "action_required"
    assert stored["id"] == data["id"]
|
||||
|
||||
|
||||
def test_ingest_also_writes_to_approved_file(client, monkeypatch, tmp_path):
    """A valid ingest appends the record to the approved file as well."""
    from app.data import corrections as corr_module

    monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
    corr_module.set_data_dir(tmp_path)

    auth_headers = {"Authorization": f"Bearer {_SECRET}"}
    response = client.post("/api/sft/ingest", json=_VALID_PAYLOAD, headers=auth_headers)
    assert response.status_code == 200

    approved = corr_module.read_jsonl(corr_module._approved_file())
    assert len(approved) == 1
    assert approved[0]["id"] == response.json()["id"]
|
||||
|
||||
|
||||
def test_ingest_without_label_is_accepted(client, monkeypatch, tmp_path):
    """A payload whose "label" is null is still accepted with 200."""
    from app.data import corrections as corr_module

    monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
    corr_module.set_data_dir(tmp_path)

    unlabelled_payload = dict(_VALID_PAYLOAD, label=None)
    auth_headers = {"Authorization": f"Bearer {_SECRET}"}
    response = client.post("/api/sft/ingest", json=unlabelled_payload, headers=auth_headers)

    assert response.status_code == 200
|
||||
95
tests/test_data_fetch.py
Normal file
95
tests/test_data_fetch.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Tests for app/data/fetch.py"""
|
||||
import json
|
||||
import yaml
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
    """Point the fetch module's data and config directories at tmp_path.

    NOTE(review): nothing is restored after the yield, so the redirected
    directories stay in place once the test finishes -- confirm that is intended.
    """
    from app.data import fetch as fetch_module

    for redirect in (fetch_module.set_data_dir, fetch_module.set_config_dir):
        redirect(tmp_path)
    yield
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Return a TestClient wrapping the FastAPI app from app.api."""
    from app.api import app as fastapi_app

    test_client = TestClient(fastapi_app)
    return test_client
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
    """Decode an SSE response body and return the JSON payload of each "data: " line.

    Lines that do not start with "data: " are ignored.
    """
    prefix = "data: "
    return [
        json.loads(line[len(prefix):])
        for line in content.decode().splitlines()
        if line.startswith(prefix)
    ]
|
||||
|
||||
|
||||
def test_account_test_missing_fields(client):
    """Empty host/username/password yields ok=False with a "required" message."""
    blank_account = {"host": "", "username": "", "password": ""}

    response = client.post("/api/accounts/test", json={"account": blank_account})

    assert response.status_code == 200
    data = response.json()
    assert data["ok"] is False
    assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
    """With IMAP4_SSL mocked to select 99 messages, the endpoint reports ok and count=99."""
    fake_imap = MagicMock()
    fake_imap.select.return_value = ("OK", [b"99"])
    account = {
        "host": "imap.example.com", "port": 993, "use_ssl": True,
        "username": "u@example.com", "password": "pw", "folder": "INBOX",
    }

    with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=fake_imap):
        response = client.post("/api/accounts/test", json={"account": account})

    assert response.status_code == 200
    data = response.json()
    assert data["ok"] is True
    assert data["count"] == 99
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, tmp_path):
    """Requesting an unknown account still ends with a 'complete' event adding 0."""
    resp = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
    assert resp.status_code == 200

    complete = None
    for event in _parse_sse(resp.content):
        if event["type"] == "complete":
            complete = event
            break

    assert complete is not None
    assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, tmp_path):
    """With one configured account and a mocked IMAP server, the stream
    emits 'start', 'done' and 'complete' events."""
    from app.data import fetch as fetch_mod

    fetch_mod.set_config_dir(tmp_path)
    account = {"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
               "username": "u", "password": "p", "folder": "INBOX",
               "days_back": 30}
    config = {"accounts": [account], "max_per_account": 50}
    (tmp_path / "label_tool.yaml").write_text(yaml.dump(config))

    message_bytes = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
                     b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
    imap = MagicMock()
    imap.search.return_value = ("OK", [b"1"])
    imap.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", message_bytes)])

    with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=imap):
        resp = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")

    assert resp.status_code == 200
    seen_types = [event["type"] for event in _parse_sse(resp.content)]
    assert "start" in seen_types
    assert "done" in seen_types
    assert "complete" in seen_types
|
||||
|
||||
|
||||
def test_entry_key_deterministic():
    """Calling ``entry_key`` twice on the same entry gives the same key."""
    from app.data import fetch as fetch_mod

    entry = {"subject": "Test", "body": "Hello world"}
    first_key = fetch_mod.entry_key(entry)
    second_key = fetch_mod.entry_key(entry)
    assert first_key == second_key
|
||||
|
||||
|
||||
def test_entry_key_differs_by_subject():
    """Entries sharing a body but differing in subject get different keys."""
    from app.data import fetch as fetch_mod

    shared_body = "same body"
    key_a = fetch_mod.entry_key({"subject": "A", "body": shared_body})
    key_b = fetch_mod.entry_key({"subject": "B", "body": shared_body})
    assert key_a != key_b
|
||||
219
tests/test_data_label.py
Normal file
219
tests/test_data_label.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""Tests for app/data/label.py"""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
    """Redirect the label module to ``tmp_path`` and reset its last action
    both before and after each test."""
    from app.data import label as label_mod

    label_mod.set_data_dir(tmp_path)
    label_mod.set_config_dir(tmp_path)
    label_mod.reset_last_action()
    try:
        yield
    finally:
        label_mod.reset_last_action()
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping the FastAPI application."""
    from app import api as api_module

    return TestClient(api_module.app)
|
||||
|
||||
|
||||
@pytest.fixture
def queue_with_items(tmp_path):
    """Write three queue entries (id0..id2) to email_label_queue.jsonl and return them."""
    from app.data import label as label_mod

    items = []
    for i in range(3):
        items.append({
            "id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
            "from": "test@example.com", "date": "2026-03-01", "source": "imap:test",
        })

    queue_path = label_mod._DATA_DIR / "email_label_queue.jsonl"
    # One JSON document per line, each newline-terminated.
    queue_path.write_text("".join(json.dumps(item) + "\n" for item in items))
    return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
    """``limit`` caps the returned items while ``total`` reports the full queue size."""
    resp = client.get("/api/queue?limit=2")
    assert resp.status_code == 200

    body = resp.json()
    assert len(body["items"]) == 2
    assert body["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
    """Without a queue file the endpoint returns an empty page."""
    resp = client.get("/api/queue")
    expected = {"items": [], "total": 0}

    assert resp.status_code == 200
    assert resp.json() == expected
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
    """Labeling an item writes one timestamped record to the score file."""
    from app.data import label as label_mod

    resp = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
    assert resp.status_code == 200

    scored = label_mod.read_jsonl(label_mod._score_file())
    assert len(scored) == 1
    record = scored[0]
    assert record["id"] == "id0"
    assert record["label"] == "interview_scheduled"
    assert "labeled_at" in record
|
||||
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
    """A labeled item no longer appears in the queue file."""
    from app.data import label as label_mod

    client.post("/api/label", json={"id": "id0", "label": "rejected"})

    remaining_ids = [entry["id"] for entry in label_mod.read_jsonl(label_mod._queue_file())]
    assert "id0" not in remaining_ids
|
||||
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
    """Labeling an id that is not in the queue is a 404."""
    payload = {"id": "unknown", "label": "neutral"}
    resp = client.post("/api/label", json=payload)
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
    """Skipping the head item moves it to the end; the next item becomes the head."""
    from app.data import label as label_mod

    resp = client.post("/api/skip", json={"id": "id0"})
    assert resp.status_code == 200

    queue_after = label_mod.read_jsonl(label_mod._queue_file())
    assert queue_after[-1]["id"] == "id0"
    assert queue_after[0]["id"] == "id1"
|
||||
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
    """Skipping an id that is not in the queue is a 404."""
    payload = {"id": "nope"}
    resp = client.post("/api/skip", json=payload)
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
    """Discarding records the item in the discarded file with the sentinel label."""
    from app.data import label as label_mod

    resp = client.post("/api/discard", json={"id": "id1"})
    assert resp.status_code == 200

    discarded = label_mod.read_jsonl(label_mod._discarded_file())
    assert len(discarded) == 1
    record = discarded[0]
    assert record["id"] == "id1"
    assert record["label"] == "__discarded__"
|
||||
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
    """A discarded item no longer appears in the queue file."""
    from app.data import label as label_mod

    client.post("/api/discard", json={"id": "id1"})

    remaining_ids = [entry["id"] for entry in label_mod.read_jsonl(label_mod._queue_file())]
    assert "id1" not in remaining_ids
|
||||
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
    """Undoing a label empties the score file and puts the item back at the queue head."""
    from app.data import label as label_mod

    client.post("/api/label", json={"id": "id0", "label": "neutral"})
    undo_resp = client.delete("/api/label/undo")

    assert undo_resp.status_code == 200
    assert undo_resp.json()["undone"]["type"] == "label"
    assert label_mod.read_jsonl(label_mod._score_file()) == []

    head = label_mod.read_jsonl(label_mod._queue_file())[0]
    assert head["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
    """Undoing a discard leaves the discarded file empty."""
    from app.data import label as label_mod

    client.post("/api/discard", json={"id": "id0"})
    undo_resp = client.delete("/api/label/undo")

    assert undo_resp.status_code == 200
    discarded = label_mod.read_jsonl(label_mod._discarded_file())
    assert discarded == []
|
||||
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
    """Undoing a skip puts the skipped item back at the head of the queue."""
    from app.data import label as label_mod

    client.post("/api/skip", json={"id": "id0"})
    undo_resp = client.delete("/api/label/undo")

    assert undo_resp.status_code == 200
    head = label_mod.read_jsonl(label_mod._queue_file())[0]
    assert head["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
    """Undo with nothing to undo is a 404."""
    undo_resp = client.delete("/api/label/undo")
    assert undo_resp.status_code == 404
|
||||
|
||||
|
||||
def test_config_labels_returns_10_labels(client):
    """The labels endpoint returns ten entries, the first keyed "1", each
    carrying emoji, color and name."""
    resp = client.get("/api/config/labels")
    assert resp.status_code == 200

    labels = resp.json()
    assert len(labels) == 10
    assert labels[0]["key"] == "1"
    for entry in labels:
        assert all(field in entry for field in ("emoji", "color", "name"))
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client):
    """With no config file, GET /api/config reports no accounts and a 500 cap."""
    resp = client.get("/api/config")
    assert resp.status_code == 200

    config = resp.json()
    assert config["accounts"] == []
    assert config["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, tmp_path):
    """POST /api/config persists the payload to label_tool.yaml."""
    from app.data import label as label_mod

    label_mod.set_config_dir(tmp_path)
    account = {"name": "Test", "host": "imap.test.com", "port": 993,
               "use_ssl": True, "username": "u@t.com", "password": "pw",
               "folder": "INBOX", "days_back": 30}
    payload = {"accounts": [account], "max_per_account": 200}

    resp = client.post("/api/config", json=payload)
    assert resp.status_code == 200
    assert resp.json()["ok"] is True

    on_disk = yaml.safe_load((tmp_path / "label_tool.yaml").read_text())
    assert on_disk["max_per_account"] == 200
    assert on_disk["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, tmp_path):
    """A config saved via POST is returned unchanged by a following GET."""
    from app.data import label as label_mod

    label_mod.set_config_dir(tmp_path)
    account = {"name": "R", "host": "h", "port": 993, "use_ssl": True,
               "username": "u", "password": "p", "folder": "INBOX",
               "days_back": 90}
    client.post("/api/config", json={"accounts": [account], "max_per_account": 300})

    fetched = client.get("/api/config").json()
    assert fetched["max_per_account"] == 300
    assert fetched["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, tmp_path):
    """Stats reports the total record count and a per-label breakdown."""
    from app.data import label as label_mod

    label_mod.set_data_dir(tmp_path)
    scored = [
        {"id": "a", "label": "interview_scheduled"},
        {"id": "b", "label": "interview_scheduled"},
        {"id": "c", "label": "rejected"},
    ]
    score_path = tmp_path / "email_score.jsonl"
    score_path.write_text("".join(json.dumps(rec) + "\n" for rec in scored))

    resp = client.get("/api/stats")
    assert resp.status_code == 200

    stats = resp.json()
    assert stats["total"] == 3
    assert stats["counts"]["interview_scheduled"] == 2
    assert stats["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client):
    """Without a score file, stats are all zero/empty."""
    resp = client.get("/api/stats")
    assert resp.status_code == 200

    stats = resp.json()
    assert stats["total"] == 0
    assert stats["counts"] == {}
    assert stats["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, tmp_path):
    """Downloading an existing score file succeeds with a jsonlines content type."""
    from app.data import label as label_mod

    label_mod.set_data_dir(tmp_path)
    single_record = json.dumps({"id": "a", "label": "neutral"})
    (tmp_path / "email_score.jsonl").write_text(single_record + "\n")

    resp = client.get("/api/stats/download")
    assert resp.status_code == 200
    content_type = resp.headers.get("content-type", "")
    assert "jsonlines" in content_type
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client):
    """Downloading when no score file exists is a 404."""
    resp = client.get("/api/stats/download")
    assert resp.status_code == 404
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for app/imitate.py — product registry, sample extraction, corrections push."""
|
||||
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
|
@ -9,10 +9,10 @@ import pytest
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api import app
|
||||
from app import imitate as _imitate_module
|
||||
from app.data import imitate as _imitate_module
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||
# -- Fixtures ------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_globals(tmp_path):
|
||||
|
|
@ -70,7 +70,7 @@ def client() -> TestClient:
|
|||
return TestClient(app, raise_server_exceptions=True)
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
# -- GET /products -------------------------------------------------------------
|
||||
|
||||
def test_products_empty_when_no_config(config_dir, client):
|
||||
"""Returns empty list when label_tool.yaml has no imitate section."""
|
||||
|
|
@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client):
|
|||
assert all(not p["online"] for p in resp.json()["products"])
|
||||
|
||||
|
||||
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
|
||||
# -- GET /products/{id}/sample -------------------------------------------------
|
||||
|
||||
def test_sample_unknown_product(cfg_with_products, client):
|
||||
"""Returns 404 for a product id not in config."""
|
||||
|
|
@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client):
|
|||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
# -- POST /push-corrections ----------------------------------------------------
|
||||
|
||||
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
|
||||
"""Successful push writes records to sft_candidates.jsonl."""
|
||||
|
|
@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
|
|||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── _extract_sample helper ─────────────────────────────────────────────────────
|
||||
# -- _extract_sample helper ----------------------------------------------------
|
||||
|
||||
def test_extract_sample_list():
|
||||
result = _imitate_module._extract_sample(
|
||||
|
|
|
|||
|
|
@ -541,3 +541,84 @@ def test_delete_installed_name_with_slash_blocked(client):
|
|||
except _HTTPException as exc:
|
||||
assert exc.status_code in (400, 404)
|
||||
raise
|
||||
|
||||
|
||||
# ── Catalog registration ───────────────────────────────────────────────────────
|
||||
|
||||
_MINIMAL_YAML = """\
|
||||
services:
|
||||
cf-text:
|
||||
max_mb: {max_mb}
|
||||
catalog:
|
||||
existing-model:
|
||||
path: /some/path
|
||||
vram_mb: 1000
|
||||
description: "placeholder"
|
||||
"""
|
||||
|
||||
|
||||
def _make_node_yaml(tmp_path: Path, max_mb: int = 8192) -> Path:
    """Render ``_MINIMAL_YAML`` with ``max_mb`` into ``tmp_path/testnode.yaml``.

    Returns the path of the written file.
    """
    rendered = _MINIMAL_YAML.format(max_mb=max_mb)
    target = tmp_path / "testnode.yaml"
    target.write_text(rendered, encoding="utf-8")
    return target
|
||||
|
||||
|
||||
def test_catalog_registration_fp16_no_env_block(tmp_path):
    """A model that fits at FP16 is registered without any env block."""
    from app import models as models_module

    node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
    with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
        touched_nodes = models_module._register_in_node_catalogs(
            repo_id="org/SmallModel",
            local_path=tmp_path / "org--SmallModel",
            vram_mb_fp16=4000,
            role="generator",
        )

    assert "testnode" in touched_nodes
    yaml_text = node_yaml.read_text()
    # The expected catalog key for "org/SmallModel" is "smallmodel".
    assert "smallmodel:" in yaml_text
    assert "CF_TEXT_4BIT" not in yaml_text
    assert "env:" not in yaml_text
|
||||
|
||||
|
||||
def test_catalog_registration_needs_4bit_writes_env_block(tmp_path):
    """A model that only fits at 4-bit is registered with env CF_TEXT_4BIT: "1"."""
    from app import models as models_module

    node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
    with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
        touched_nodes = models_module._register_in_node_catalogs(
            repo_id="org/BigModel",
            local_path=tmp_path / "org--BigModel",
            vram_mb_fp16=20000,  # exceeds the 8192 MB budget at FP16
            role="generator",
        )

    assert "testnode" in touched_nodes
    yaml_text = node_yaml.read_text()
    # The expected catalog key for "org/BigModel" is "bigmodel".
    assert "bigmodel:" in yaml_text
    assert "env:" in yaml_text
    assert 'CF_TEXT_4BIT: "1"' in yaml_text
    # The description must carry the 4-bit note.
    assert "CF_TEXT_4BIT=1 required" in yaml_text
|
||||
|
||||
|
||||
def test_catalog_registration_too_large_skipped(tmp_path):
    """A model that does not fit even at 4-bit is left out of the catalog."""
    from app import models as models_module

    node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
    with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
        touched_nodes = models_module._register_in_node_catalogs(
            repo_id="org/HugeModel",
            local_path=tmp_path / "org--HugeModel",
            vram_mb_fp16=80000,  # 4-bit ~22 GB, still won't fit on 8 GB
            role="generator",
        )

    assert touched_nodes == []
    yaml_text = node_yaml.read_text()
    assert "hugemodel" not in yaml_text
|
||||
|
|
|
|||
471
tests/test_nodes.py
Normal file
471
tests/test_nodes.py
Normal file
|
|
@ -0,0 +1,471 @@
|
|||
"""Tests for app/nodes.py — /api/nodes-mgmt/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os as _os
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_nodes_globals(tmp_path):
    """Point the nodes module's config dir at ``tmp_path`` for the test, then
    restore the previous value. Yields ``tmp_path``."""
    from app import nodes as nodes_mod

    original_dir = nodes_mod._CONFIG_DIR
    nodes_mod.set_config_dir(tmp_path)
    yield tmp_path
    nodes_mod.set_config_dir(original_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping the FastAPI application."""
    from app import api as api_module

    return TestClient(api_module.app)
|
||||
|
||||
|
||||
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
    """Write ``label_tool.yaml`` in ``config_dir`` with ``cforch_cfg`` under the ``cforch`` key."""
    target = config_dir / "label_tool.yaml"
    document = yaml.dump({"cforch": cforch_cfg})
    target.write_text(document, encoding="utf-8")
|
||||
|
||||
|
||||
def _write_profile(profiles_dir: Path, node_id: str, profile: dict) -> None:
    """Dump ``profile`` as YAML to ``<profiles_dir>/<node_id>.yaml``, creating the dir if needed."""
    profiles_dir.mkdir(parents=True, exist_ok=True)
    target = profiles_dir / f"{node_id}.yaml"
    target.write_text(yaml.dump(profile), encoding="utf-8")
|
||||
|
||||
|
||||
def test_nodes_module_imports():
    """The nodes module imports and exposes ``router`` and ``set_config_dir``."""
    from app import nodes as nodes_mod

    assert hasattr(nodes_mod, "router")
    assert hasattr(nodes_mod, "set_config_dir")
|
||||
|
||||
|
||||
def test_list_nodes_returns_empty_when_no_coordinator(client):
    """With no cforch config the endpoint answers 200 with an empty list."""
    resp = client.get("/api/nodes-mgmt/nodes")

    assert resp.status_code == 200
    assert resp.json() == []
|
||||
|
||||
|
||||
|
||||
|
||||
def _fake_nodes_response(nodes_json: list, services_json: list | None = None):
|
||||
"""Build side_effect list for two httpx.get calls: nodes then services."""
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.raise_for_status = MagicMock()
|
||||
mock_nodes.json.return_value = nodes_json
|
||||
|
||||
mock_services = MagicMock()
|
||||
mock_services.raise_for_status = MagicMock()
|
||||
mock_services.json.return_value = services_json or []
|
||||
|
||||
return [mock_nodes, mock_services]
|
||||
|
||||
|
||||
def test_list_nodes_coordinator_unreachable_returns_empty(client, tmp_path):
    """A connection error talking to the coordinator gives 200 with an empty list."""
    import httpx

    _write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
    refused = httpx.ConnectError("refused")

    with patch("httpx.get", side_effect=refused):
        resp = client.get("/api/nodes-mgmt/nodes")

    assert resp.status_code == 200
    assert resp.json() == []
|
||||
|
||||
|
||||
def test_list_nodes_merges_profile_data(client, tmp_path):
    """Services assigned in the profile YAML appear alongside live GPU stats."""
    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {
        "coordinator_url": "http://fake-coord:7700",
        "profiles_dir": str(profiles_dir),
    })

    profile_gpu = {"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
                   "services": ["cf-text"], "role": "primary", "card": "RTX 3090",
                   "always_on": True}
    _write_profile(profiles_dir, "heimdall", {
        "services": {
            "cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}},
        },
        "nodes": {
            "heimdall": {
                "gpus": [profile_gpu],
                "agent_url": "http://10.1.10.71:7701",
            }
        }
    })

    live_gpu = {"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
                "vram_used_mb": 4096, "vram_free_mb": 20480,
                "temp_c": 42.0, "utilization_pct": 15.0, "compute_cap": 8.6}
    coord_nodes = [{
        "node_id": "heimdall", "online": True, "agent_url": "http://10.1.10.71:7701",
        "gpus": [live_gpu],
    }]

    with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
        resp = client.get("/api/nodes-mgmt/nodes")

    assert resp.status_code == 200
    listed = resp.json()
    assert len(listed) == 1
    node = listed[0]
    assert node["node_id"] == "heimdall"
    assert node["profile_loaded"] is True
    first_gpu = node["gpus"][0]
    assert first_gpu["services_assigned"] == ["cf-text"]
    assert first_gpu["vram_total_mb"] == 24576
    assert "cf-text" in node["services_catalog"]
|
||||
|
||||
|
||||
def test_list_nodes_no_profile_returns_profile_loaded_false(client, tmp_path):
    """A node without a profile YAML reports profile_loaded=False but keeps its GPU stats."""
    _write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})

    live_gpu = {"gpu_id": 0, "card": "RTX 5060 Ti", "vram_total_mb": 16384,
                "vram_used_mb": 0, "vram_free_mb": 16384,
                "temp_c": None, "utilization_pct": None, "compute_cap": 10.0}
    coord_nodes = [{
        "node_id": "sif", "online": True, "agent_url": "http://10.1.10.158:7701",
        "gpus": [live_gpu],
    }]

    with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
        resp = client.get("/api/nodes-mgmt/nodes")

    assert resp.status_code == 200
    node = resp.json()[0]
    assert node["profile_loaded"] is False
    assert node["gpus"][0]["card"] == "RTX 5060 Ti"
    assert node["services_catalog"] == {}
|
||||
|
||||
|
||||
def test_list_nodes_marks_running_services(client, tmp_path):
    """A service reported by the coordinator shows up in the GPU's services_running."""
    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {
        "coordinator_url": "http://fake-coord:7700",
        "profiles_dir": str(profiles_dir),
    })

    profile_gpu = {"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
                   "services": ["cf-text"], "role": "p",
                   "card": "RTX 3090", "always_on": True}
    _write_profile(profiles_dir, "heimdall", {
        "services": {},
        "nodes": {"heimdall": {"gpus": [profile_gpu],
                               "agent_url": "http://10.1.10.71:7701"}}
    })

    live_gpu = {"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
                "vram_used_mb": 8192, "vram_free_mb": 16384,
                "temp_c": 55.0, "utilization_pct": 80.0, "compute_cap": 8.6}
    coord_nodes = [{"node_id": "heimdall", "online": True,
                    "agent_url": "http://10.1.10.71:7701",
                    "gpus": [live_gpu]}]
    coord_services = [{"service": "cf-text", "node_id": "heimdall", "gpu_id": 0}]

    with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes, coord_services)):
        resp = client.get("/api/nodes-mgmt/nodes")

    listed = resp.json()
    assert listed[0]["gpus"][0]["services_running"] == ["cf-text"]
|
||||
|
||||
|
||||
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
|
||||
|
||||
def test_get_profile_returns_parsed_yaml(client, tmp_path):
    """The profile endpoint returns the node's YAML profile as JSON."""
    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
    _write_profile(profiles_dir, "heimdall", {
        "services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
        "nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
    })

    resp = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
    assert resp.status_code == 200

    parsed = resp.json()
    assert "services" in parsed
    assert "cf-text" in parsed["services"]
|
||||
|
||||
|
||||
def test_get_profile_404_when_missing(client, tmp_path):
    """Requesting the profile of a node with no YAML file is a 404."""
    missing_dir = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(missing_dir)})

    resp = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
    """A profile file containing invalid YAML makes the endpoint answer 500."""
    profiles_dir = tmp_path / "profiles"
    profiles_dir.mkdir()
    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})

    broken_profile = profiles_dir / "bad.yaml"
    broken_profile.write_text("key: [unclosed", encoding="utf-8")

    resp = client.get("/api/nodes-mgmt/nodes/bad/profile")
    assert resp.status_code == 500
|
||||
|
||||
|
||||
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
|
||||
|
||||
|
||||
_BASE_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1,
|
||||
"catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3",
|
||||
"description": "", "multi_gpu": False, "env": {}}}},
|
||||
"ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": [], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _setup_profile(tmp_path, profile=None):
    """Write the cforch config plus a "heimdall" profile under ``tmp_path``.

    ``_BASE_PROFILE`` is used when ``profile`` is falsy. Returns the
    profiles directory.
    """
    profiles_dir = tmp_path / "profiles"
    cforch_cfg = {
        "coordinator_url": "http://fake-coord:7700",
        "profiles_dir": str(profiles_dir),
    }
    _write_config(tmp_path, cforch_cfg)

    chosen_profile = profile or _BASE_PROFILE
    _write_profile(profiles_dir, "heimdall", chosen_profile)
    return profiles_dir
|
||||
|
||||
|
||||
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
    """Assigning a compatible service persists it to the YAML and reports a reload."""
    profiles_dir = _setup_profile(tmp_path)

    reload_ok = MagicMock()
    reload_ok.status_code = 200

    with patch("httpx.post", return_value=reload_ok):
        resp = client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["cf-text"]},
        )

    assert resp.status_code == 200
    result = resp.json()
    assert result["ok"] is True
    assert result["reloaded"] is True

    on_disk = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
    assert on_disk["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"]
|
||||
|
||||
|
||||
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
    """The profile must be written to a ``.tmp`` file and then renamed into place.

    ``os.replace`` is wrapped so every rename is recorded while still being
    carried out.
    """
    _setup_profile(tmp_path)
    renamed_pairs: list[tuple] = []

    original_replace = _os.replace

    def capture(src, dst):
        # Record the rename, then perform it so the endpoint behaves normally.
        renamed_pairs.append((str(src), str(dst)))
        original_replace(src, dst)

    with patch("os.replace", side_effect=capture), \
         patch("httpx.post", return_value=MagicMock(status_code=200)):
        r = client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["ollama"]},
        )

    # Check the request itself first: a rejected update would otherwise show
    # up below as a misleading "no atomic write" failure.
    assert r.status_code == 200, r.text
    assert any(src.endswith(".tmp") for src, _dst in renamed_pairs), \
        "Expected atomic write via .tmp rename"
|
||||
|
||||
|
||||
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
    """Assigning cf-text (min compute cap 7.0) to a 6.0 GPU is rejected with 422."""
    weak_gpu = {"id": 0, "vram_mb": 24576, "compute_cap": 6.0,
                "services": [], "role": "p", "card": "GTX 1080",
                "always_on": False}
    low_cap_profile = dict(_BASE_PROFILE)
    low_cap_profile["nodes"] = {
        "heimdall": {
            "gpus": [weak_gpu],
            "agent_url": "http://10.1.10.71:7701",
        }
    }
    _setup_profile(tmp_path, low_cap_profile)

    resp = client.post(
        "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
        json={"services": ["cf-text"]},
    )
    assert resp.status_code == 422
    assert "compute_cap" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
    """Assigning cf-text to a GPU with only 512 MB of VRAM is rejected with 422."""
    small_gpu = {"id": 0, "vram_mb": 512, "compute_cap": 8.6,
                 "services": [], "role": "p", "card": "old",
                 "always_on": False}
    tiny_vram_profile = dict(_BASE_PROFILE)
    tiny_vram_profile["nodes"] = {
        "heimdall": {
            "gpus": [small_gpu],
            "agent_url": "http://10.1.10.71:7701",
        }
    }
    _setup_profile(tmp_path, tiny_vram_profile)

    resp = client.post(
        "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
        json={"services": ["cf-text"]},
    )
    assert resp.status_code == 422
    assert "VRAM" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_unknown_service_returns_422(client, tmp_path):
    """Assigning a service that the profile does not define is rejected with 422."""
    _setup_profile(tmp_path)
    payload = {"services": ["not-a-real-service"]}

    resp = client.post("/api/nodes-mgmt/nodes/heimdall/gpu/0/services", json=payload)
    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
    """When the coordinator reload answers 500 the update still succeeds,
    reporting ok=True and reloaded=False."""
    _setup_profile(tmp_path)

    reload_failed = MagicMock()
    reload_failed.status_code = 500

    with patch("httpx.post", return_value=reload_failed):
        resp = client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["ollama"]},
        )

    assert resp.status_code == 200
    result = resp.json()
    assert result["ok"] is True
    assert result["reloaded"] is False
|
||||
|
||||
# ── Ollama endpoints ───────────────────────────────────────────────────────────
|
||||
|
||||
_OLLAMA_PROFILE = {
|
||||
"services": {},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_list_ollama_models_proxies_tags(client, tmp_path):
    """The model list returned by the upstream GET is passed through to the caller."""
    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
    _write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)

    upstream = MagicMock()
    upstream.raise_for_status = MagicMock()
    upstream.json.return_value = {
        "models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
    }

    with patch("httpx.get", return_value=upstream):
        resp = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")

    assert resp.status_code == 200
    models = resp.json()["models"]
    assert len(models) == 1
    assert models[0]["name"] == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
    """A connection error on the upstream GET gives 200 with an ``error`` key."""
    import httpx as _httpx

    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
    _write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
    refused = _httpx.ConnectError("refused")

    with patch("httpx.get", side_effect=refused):
        resp = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")

    assert resp.status_code == 200
    assert "error" in resp.json()
|
||||
|
||||
|
||||
def test_pull_ollama_model_streams_sse(client, tmp_path):
    """Each JSON line from the upstream pull stream is re-emitted as an SSE data event."""
    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
    _write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)

    progress_lines = [
        '{"status": "pulling manifest"}',
        '{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
        '{"status": "success"}',
    ]
    upstream = MagicMock()
    upstream.iter_lines.return_value = iter(progress_lines)

    with patch("httpx.stream") as mock_stream_fn:
        # httpx.stream is used as a context manager; wire up enter/exit by hand.
        stream_cm = mock_stream_fn.return_value
        stream_cm.__enter__ = MagicMock(return_value=upstream)
        stream_cm.__exit__ = MagicMock(return_value=False)
        resp = client.post(
            "/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
            json={"name": "nomic-embed-text"},
        )

    assert resp.status_code == 200
    streamed = resp.text
    assert 'data: {"status": "pulling manifest"}' in streamed
    assert 'data: {"status": "success"}' in streamed
|
||||
|
||||
|
||||
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
    """An upstream ``error`` line is passed through in the streamed body (status stays 200)."""
    node_profiles = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(node_profiles)})
    _write_profile(node_profiles, "heimdall", _OLLAMA_PROFILE)

    upstream = MagicMock()
    upstream.iter_lines.return_value = iter(
        ['{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}']
    )

    with patch("httpx.stream") as stream_patch:
        # The patched httpx.stream(...) result acts as a context manager
        # yielding the fake upstream response.
        stream_ctx = stream_patch.return_value
        stream_ctx.__enter__ = MagicMock(return_value=upstream)
        stream_ctx.__exit__ = MagicMock(return_value=False)
        response = client.post(
            "/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
            json={"name": "nomic-embed-text"},
        )

    assert response.status_code == 200
    assert "permission denied" in response.text
|
||||
|
||||
|
||||
def test_delete_ollama_model_proxies_delete(client, tmp_path):
    """A 200 from the patched ``httpx.request`` yields ``{"ok": True}``."""
    node_profiles = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(node_profiles)})
    _write_profile(node_profiles, "heimdall", _OLLAMA_PROFILE)

    # Fake upstream reply: success status and a raise_for_status() that does nothing.
    upstream_reply = MagicMock(status_code=200, raise_for_status=MagicMock())

    with patch("httpx.request", return_value=upstream_reply):
        response = client.delete(
            "/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text"
        )

    assert response.status_code == 200
    assert response.json() == {"ok": True}
|
||||
|
||||
|
||||
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
    """A 404 from the patched ``httpx.request`` is returned to the caller as 404."""
    node_profiles = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(node_profiles)})
    _write_profile(node_profiles, "heimdall", _OLLAMA_PROFILE)

    upstream_reply = MagicMock(status_code=404)

    with patch("httpx.request", return_value=upstream_reply):
        response = client.delete(
            "/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model"
        )

    assert response.status_code == 404
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""API integration tests for app/sft.py — /api/sft/* endpoints."""
|
||||
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -7,17 +7,17 @@ from pathlib import Path
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_sft_globals(tmp_path):
|
||||
from app import sft as sft_module
|
||||
_prev_data = sft_module._SFT_DATA_DIR
|
||||
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
||||
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
sft_module.set_sft_data_dir(tmp_path)
|
||||
sft_module.set_sft_config_dir(tmp_path)
|
||||
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
from app.data import corrections as corr_module
|
||||
_prev_data = corr_module._DATA_DIR
|
||||
_prev_cfg = corr_module._CONFIG_DIR
|
||||
_prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
yield
|
||||
sft_module.set_sft_data_dir(_prev_data)
|
||||
sft_module.set_sft_config_dir(_prev_cfg)
|
||||
sft_module.set_default_bench_results_dir(_prev_default)
|
||||
corr_module.set_data_dir(_prev_data)
|
||||
corr_module.set_config_dir(_prev_cfg)
|
||||
corr_module.set_default_bench_results_dir(_prev_default)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
|
|||
)
|
||||
|
||||
|
||||
# ── /api/sft/runs ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/runs -------------------------------------------------------------
|
||||
|
||||
def test_runs_returns_empty_when_no_config(client):
|
||||
r = client.get("/api/sft/runs")
|
||||
|
|
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
|
|||
def test_runs_marks_already_imported(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
candidates = sft_module._candidates_file()
|
||||
candidates.parent.mkdir(parents=True, exist_ok=True)
|
||||
candidates.write_text(
|
||||
|
|
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
|
|||
assert r.json()[0]["already_imported"] is True
|
||||
|
||||
|
||||
# ── /api/sft/import ─────────────────────────────────────────────────────────
|
||||
# -- /api/sft/import -----------------------------------------------------------
|
||||
|
||||
def test_import_adds_records(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||
|
|
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
|
|||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/sft/queue ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/queue ------------------------------------------------------------
|
||||
|
||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
path = sft_module._candidates_file()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
|
|
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
|
|||
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
||||
|
||||
|
||||
# ── /api/sft/submit ─────────────────────────────────────────────────────────
|
||||
# -- /api/sft/submit -----------------------------------------------------------
|
||||
|
||||
def test_submit_correct_sets_approved(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
|
|
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["status"] == "approved"
|
||||
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
||||
|
|
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
|
|||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
|
|
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
||||
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
|
|||
"failure_category": "style_violation",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] == "style_violation"
|
||||
|
||||
|
|
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] is None
|
||||
|
||||
|
|
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
|||
assert r.status_code == 422
|
||||
|
||||
|
||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
||||
# -- /api/sft/undo -------------------------------------------------------------
|
||||
|
||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
||||
|
||||
|
||||
|
|
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
client.post("/api/sft/undo", json={"id": "a"})
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert not any(r["id"] == "a" for r in approved)
|
||||
|
|
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
|
|||
assert r.status_code == 409
|
||||
|
||||
|
||||
# ── /api/sft/export ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/export -----------------------------------------------------------
|
||||
|
||||
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
approved = {
|
||||
**_make_record("a"),
|
||||
|
|
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
|||
|
||||
|
||||
def test_export_excludes_non_approved(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
||||
|
|
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
|
|||
assert r.text.strip() == ""
|
||||
|
||||
|
||||
# ── /api/sft/stats ───────────────────────────────────────────────────────────
|
||||
# -- /api/sft/stats ------------------------------------------------------------
|
||||
|
||||
def test_stats_counts_by_status(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
_make_record("a"),
|
||||
|
|
|
|||
187
tests/test_train.py
Normal file
187
tests/test_train.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for app/train/train.py -- /api/train/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
    """Redirect the train module's storage to ``tmp_path`` for every test.

    Sets the job database path and the models directory under the per-test
    temporary directory, and empties ``_running_procs`` both before and after
    the test body runs.
    """
    from app.train import train as train_module
    train_module.set_db_path(tmp_path / "train_jobs.db")
    train_module.set_models_dir(tmp_path / "models")
    train_module._running_procs.clear()
    yield
    # Teardown: drop any process handles a test registered.
    train_module._running_procs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping the application object from ``app.api``."""
    # Imported inside the fixture rather than at module import time.
    from app.api import app
    return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_list_jobs_empty(client):
    """With no jobs created, the listing endpoint returns an empty ``jobs`` array."""
    response = client.get("/api/train/jobs")
    assert (response.status_code, response.json()) == (200, {"jobs": []})
|
||||
|
||||
|
||||
def test_create_job_returns_queued_record(client):
    """Creating a job returns a ``queued`` record echoing type and model key, with an id."""
    request_body = {
        "type": "classifier",
        "model_key": "deberta-small",
        "config_json": {"epochs": 3},
    }
    response = client.post("/api/train/jobs", json=request_body)
    assert response.status_code == 200

    job = response.json()
    expected_fields = {
        "status": "queued",
        "type": "classifier",
        "model_key": "deberta-small",
    }
    for field, value in expected_fields.items():
        assert job[field] == value
    assert "id" in job
|
||||
|
||||
|
||||
def test_create_job_invalid_type_returns_400(client):
    """A job ``type`` of ``unknown-type`` is rejected with HTTP 400."""
    bad_request = {"type": "unknown-type", "model_key": "deberta-small"}
    response = client.post("/api/train/jobs", json=bad_request)
    assert response.status_code == 400
|
||||
|
||||
|
||||
def test_create_job_appears_in_list(client):
    """After one job is created, the listing contains exactly one entry."""
    client.post(
        "/api/train/jobs",
        json={"type": "classifier", "model_key": "deberta-small"},
    )
    listing = client.get("/api/train/jobs")
    assert listing.status_code == 200
    jobs = listing.json()["jobs"]
    assert len(jobs) == 1
|
||||
|
||||
|
||||
def test_get_job_returns_record(client):
    """Fetching a job by the id returned at creation gives back the same id."""
    created = client.post(
        "/api/train/jobs",
        json={"type": "classifier", "model_key": "deberta-small"},
    )
    job_id = created.json()["id"]

    fetched = client.get(f"/api/train/jobs/{job_id}")
    assert fetched.status_code == 200
    assert fetched.json()["id"] == job_id
|
||||
|
||||
|
||||
def test_get_job_404_for_unknown(client):
    """Looking up a job id that was never created returns HTTP 404."""
    response = client.get("/api/train/jobs/no-such-id")
    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_queued_job(client):
    """Cancelling a queued job reports ``cancelled`` and a later fetch agrees."""
    created = client.post(
        "/api/train/jobs",
        json={"type": "classifier", "model_key": "deberta-small"},
    )
    job_url = f"/api/train/jobs/{created.json()['id']}"

    cancelled = client.delete(f"{job_url}/cancel")
    assert cancelled.status_code == 200
    assert cancelled.json()["status"] == "cancelled"

    refetched = client.get(job_url)
    assert refetched.json()["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_cancel_completed_job_returns_409(client):
    """Cancelling a job whose stored status is ``completed`` returns HTTP 409."""
    from app.train import train as train_module

    # Seed a finished job straight into the jobs table.
    insert_completed = (
        "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
        "VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
    )
    train_module._init_db()
    with train_module._db() as conn:
        conn.execute(insert_completed)

    response = client.delete("/api/train/jobs/abc/cancel")
    assert response.status_code == 409
|
||||
|
||||
|
||||
def test_cancel_terminates_running_proc(client):
    """Cancelling a running job calls ``terminate()`` exactly once on its tracked process."""
    from app.train import train as train_module

    fake_proc = MagicMock()
    fake_proc.wait = MagicMock()

    created = client.post(
        "/api/train/jobs",
        json={"type": "classifier", "model_key": "deberta-small"},
    )
    job_id = created.json()["id"]

    # Make the job look like it is running: register the process handle
    # and flip the stored status.
    train_module._running_procs[job_id] = fake_proc
    with train_module._db() as conn:
        conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))

    cancelled = client.delete(f"/api/train/jobs/{job_id}/cancel")
    assert cancelled.status_code == 200
    fake_proc.terminate.assert_called_once()
|
||||
|
||||
|
||||
def test_run_job_streams_sse(client):
    """Running a queued job streams ``text/event-stream`` including a ``complete`` event."""
    created = client.post(
        "/api/train/jobs",
        json={"type": "classifier", "model_key": "deberta-small"},
    )
    job_id = created.json()["id"]

    # Fake subprocess: two output lines, then a clean exit.
    fake_proc = MagicMock()
    fake_proc.stdout = iter(["Epoch 1\n", "Done\n"])
    fake_proc.returncode = 0
    fake_proc.wait = MagicMock()

    with patch("app.train.train._subprocess.Popen", return_value=fake_proc):
        streamed = client.get(f"/api/train/jobs/{job_id}/run")

    assert streamed.status_code == 200
    assert "text/event-stream" in streamed.headers.get("content-type", "")
    event_types = [event["type"] for event in _parse_sse(streamed.content)]
    assert "complete" in event_types
|
||||
|
||||
|
||||
def test_run_job_marks_completed_in_db(client):
    """A run whose subprocess exits with code 0 leaves the job ``completed``."""
    # The unused local import of the train module was dropped; this test only
    # goes through the HTTP API and the patched Popen.
    r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
    job_id = r.json()["id"]

    # Fake subprocess: no output, zero exit status.
    mock_proc = MagicMock()
    mock_proc.stdout = iter([])
    mock_proc.returncode = 0
    mock_proc.wait = MagicMock()

    with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
        client.get(f"/api/train/jobs/{job_id}/run")

    r2 = client.get(f"/api/train/jobs/{job_id}")
    assert r2.json()["status"] == "completed"
|
||||
|
||||
|
||||
def test_run_job_marks_failed_on_nonzero_exit(client):
    """A run whose subprocess exits with code 1 leaves the job ``failed``."""
    created = client.post(
        "/api/train/jobs",
        json={"type": "classifier", "model_key": "deberta-small"},
    )
    job_url = f"/api/train/jobs/{created.json()['id']}"

    # Fake subprocess: no output, non-zero exit status.
    fake_proc = MagicMock()
    fake_proc.stdout = iter([])
    fake_proc.returncode = 1
    fake_proc.wait = MagicMock()

    with patch("app.train.train._subprocess.Popen", return_value=fake_proc):
        client.get(f"{job_url}/run")

    stored = client.get(job_url).json()
    assert stored["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_nonqueued_job_returns_409(client):
    """Trying to run a job whose stored status is ``running`` returns HTTP 409."""
    from app.train import train as train_module

    # Seed a job that is already marked as running.
    insert_running = (
        "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
        "VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
    )
    train_module._init_db()
    with train_module._db() as conn:
        conn.execute(insert_running)

    response = client.get("/api/train/jobs/xyz/run")
    assert response.status_code == 409
|
||||
|
||||
|
||||
def test_run_unknown_job_returns_404(client):
    """Trying to run a job id that does not exist returns HTTP 404."""
    response = client.get("/api/train/jobs/no-such/run")
    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_results_empty_when_no_models_dir(client):
    """With nothing created under the models directory, results are an empty list."""
    response = client.get("/api/train/results")
    assert (response.status_code, response.json()) == (200, {"results": []})
|
||||
|
||||
|
||||
def test_results_returns_training_info(client, tmp_path):
    """A model directory containing ``training_info.json`` shows up in the results."""
    from app.train import train as train_module

    models_root = tmp_path / "models"
    model_dir = models_root / "avocet-deberta-small"
    model_dir.mkdir(parents=True)
    train_module.set_models_dir(models_root)

    training_info = {
        "name": "avocet-deberta-small",
        "val_macro_f1": 0.712,
        "sample_count": 401,
    }
    (model_dir / "training_info.json").write_text(json.dumps(training_info))

    response = client.get("/api/train/results")
    assert response.status_code == 200
    reported_names = [entry["name"] for entry in response.json()["results"]]
    assert "avocet-deberta-small" in reported_names
|
||||
124
web/src/components/AppSidebar.test.ts
Normal file
124
web/src/components/AppSidebar.test.ts
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import AppSidebar from './AppSidebar.vue'
|
||||
|
||||
// Minimal router so RouterLink renders without warnings. Every path the
// sidebar links to needs a matching route here (including the top-level
// /nodes link); otherwise vue-router logs a "No match found" warning.
const router = createRouter({
  history: createWebHashHistory(),
  routes: [
    '/',
    '/fleet',
    '/nodes',
    '/data/label',
    '/data/fetch',
    '/data/corrections',
    '/data/imitate',
    '/eval/benchmark',
    '/eval/compare',
    '/train/jobs',
    '/train/results',
    '/settings',
  ].map(path => ({ path, component: { template: '<div />' } })),
})
|
||||
|
||||
// Build a fetch stub whose every call resolves to an OK response carrying a
// dashboard payload with the supplied `signals` (empty by default).
function makeFetch(signals: Record<string, boolean> = {}) {
  const dashboardPayload = {
    labeled_since_last_eval: 0,
    last_eval_timestamp: null,
    last_eval_best_score: null,
    active_jobs: [],
    corrections_export_ready: 0,
    signals,
  }
  const response = {
    ok: true,
    json: async () => dashboardPayload,
    text: async () => '',
  }
  return vi.fn().mockResolvedValue(response)
}
|
||||
|
||||
// Before each test: clear localStorage and install a fetch stub that
// reports no signals.
beforeEach(() => {
  localStorage.clear()
  vi.stubGlobal('fetch', makeFetch())
})
|
||||
|
||||
describe('AppSidebar structure', () => {
  // Mount the sidebar with the stub router and let pending promises
  // (the signals fetch issued on mount) settle.
  async function mountSidebar() {
    const wrapper = mount(AppSidebar, { global: { plugins: [router] } })
    await flushPromises()
    return wrapper
  }

  type SidebarWrapper = Awaited<ReturnType<typeof mountSidebar>>

  // href of every rendered anchor ('' when the attribute is missing).
  const linkHrefs = (wrapper: SidebarWrapper) =>
    wrapper.findAll('a').map(anchor => anchor.attributes('href') ?? '')

  // Replace the default fetch stub with one reporting the given signals.
  const stubSignals = (signals: Record<string, boolean>) =>
    vi.stubGlobal('fetch', makeFetch(signals))

  it('renders section headers for Data, Eval, Train', async () => {
    const rendered = (await mountSidebar()).text()
    for (const heading of ['Data', 'Eval', 'Train']) {
      expect(rendered).toContain(heading)
    }
  })

  it('renders all sub-links', async () => {
    const hrefs = linkHrefs(await mountSidebar())
    const expectedPaths = [
      '/data/label',
      '/data/fetch',
      '/data/corrections',
      '/data/imitate',
      '/eval/benchmark',
      '/eval/compare',
      '/train/jobs',
      '/train/results',
      '/fleet',
      '/settings',
    ]
    for (const path of expectedPaths) {
      expect(hrefs.some(h => h.includes(path))).toBe(true)
    }
  })

  it('does NOT render the old /benchmark or /models links', async () => {
    const hrefs = linkHrefs(await mountSidebar())
    // Old paths must not appear as direct links (they're only redirects)
    for (const legacy of ['/#/benchmark', '/#/models', '/#/stats']) {
      expect(hrefs.every(h => !h.endsWith(legacy))).toBe(true)
    }
  })

  it('shows no signal badges when all signals are false', async () => {
    stubSignals({ data_to_eval: false, eval_to_train: false, train_to_fleet: false })
    const sidebar = await mountSidebar()
    expect(sidebar.findAll('.signal-badge').length).toBe(0)
  })

  it('shows signal badge on Data section when data_to_eval is true', async () => {
    stubSignals({ data_to_eval: true, eval_to_train: false, train_to_fleet: false })
    const sidebar = await mountSidebar()
    expect(sidebar.findAll('.signal-badge').length).toBe(1)
    // It should be inside the Data section header
    const dataHeader = sidebar.find('[data-section="data"]')
    expect(dataHeader.find('.signal-badge').exists()).toBe(true)
  })

  it('shows signal badge on Eval section when eval_to_train is true', async () => {
    stubSignals({ data_to_eval: false, eval_to_train: true, train_to_fleet: false })
    const sidebar = await mountSidebar()
    const evalHeader = sidebar.find('[data-section="eval"]')
    expect(evalHeader.find('.signal-badge').exists()).toBe(true)
  })

  it('shows signal badge on Train section when train_to_fleet is true', async () => {
    stubSignals({ data_to_eval: false, eval_to_train: false, train_to_fleet: true })
    const sidebar = await mountSidebar()
    const trainHeader = sidebar.find('[data-section="train"]')
    expect(trainHeader.find('.signal-badge').exists()).toBe(true)
  })

  it('stow toggle still works', async () => {
    const sidebar = await mountSidebar()
    const nav = sidebar.find('nav')
    expect(nav.classes()).not.toContain('stowed')
    await sidebar.find('.stow-btn').trigger('click')
    expect(nav.classes()).toContain('stowed')
  })
})
|
||||
|
|
@ -28,12 +28,70 @@
|
|||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Nav items -->
|
||||
<!-- Nav -->
|
||||
<ul class="nav-list" role="list">
|
||||
<li v-for="item in navItems" :key="item.path">
|
||||
<!-- Top-level links -->
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Dashboard' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">📊</span>
|
||||
<span v-if="!stowed" class="nav-label">Dashboard</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/fleet"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Fleet' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚡</span>
|
||||
<span v-if="!stowed" class="nav-label">Fleet</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/nodes"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Nodes' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">🖥️</span>
|
||||
<span v-if="!stowed" class="nav-label">Nodes</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ① Data section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="data" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">① Data</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge"
|
||||
title="Enough new labels to run eval"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">①</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Eval recommended"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in dataItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
|
|
@ -41,10 +99,94 @@
|
|||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ② Eval section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="eval" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">② Eval</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge"
|
||||
title="Strong eval result — consider finetuning"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">②</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Finetune recommended"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in evalItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ③ Train section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="train" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">③ Train</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge"
|
||||
title="Trained model ready for fleet registration"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">③</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Fleet registration recommended"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in trainItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- Divider + Settings -->
|
||||
<li class="nav-divider" aria-hidden="true" />
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/settings"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Settings' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚙️</span>
|
||||
<span v-if="!stowed" class="nav-label">Settings</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
</ul>
|
||||
</nav>
|
||||
|
||||
<!-- Mobile hamburger button rendered outside the sidebar so it's visible when stowed -->
|
||||
<!-- Mobile hamburger button — visible when sidebar is stowed on mobile -->
|
||||
<button
|
||||
v-if="isMobile && stowed"
|
||||
class="mobile-hamburger"
|
||||
|
|
@ -61,25 +203,66 @@ import { RouterLink } from 'vue-router'
|
|||
|
||||
const LS_KEY = 'cf-avocet-nav-stowed'
|
||||
|
||||
const navItems = [
|
||||
{ path: '/', icon: '🃏', label: 'Label' },
|
||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
||||
{ path: '/models', icon: '🤗', label: 'Models' },
|
||||
{ path: '/imitate', icon: '🪞', label: 'Imitate' },
|
||||
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
|
||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||
interface NavItem {
|
||||
path: string
|
||||
icon: string
|
||||
label: string
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
const dataItems: NavItem[] = [
|
||||
{ path: '/data/label', icon: '🏷', label: 'Label' },
|
||||
{ path: '/data/fetch', icon: '📬', label: 'Fetch' },
|
||||
{ path: '/data/corrections', icon: '✏️', label: 'Corrections' },
|
||||
{ path: '/data/imitate', icon: '🪞', label: 'Imitate' },
|
||||
]
|
||||
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
const evalItems: NavItem[] = [
|
||||
{ path: '/eval/benchmark', icon: '📊', label: 'Benchmark' },
|
||||
{ path: '/eval/compare', icon: '🔍', label: 'Compare' },
|
||||
]
|
||||
|
||||
const trainItems: NavItem[] = [
|
||||
{ path: '/train/jobs', icon: '🧠', label: 'Jobs' },
|
||||
{ path: '/train/results', icon: '📈', label: 'Results' },
|
||||
]
|
||||
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
|
||||
const signals = ref<DashboardSignals>({
|
||||
data_to_eval: false,
|
||||
eval_to_train: false,
|
||||
train_to_fleet: false,
|
||||
})
|
||||
|
||||
/**
 * Refresh the section badges from `GET /api/dashboard`.
 *
 * Copies the three flags of the response's optional `signals` object into
 * the `signals` ref, treating a missing flag as `false`. A non-OK response,
 * a response without `signals`, a network failure, or an unparsable body
 * leaves the current badge state untouched.
 */
async function loadSignals() {
  try {
    const response = await fetch('/api/dashboard')
    if (!response.ok) return

    const payload = await response.json() as { signals?: DashboardSignals }
    const incoming = payload.signals
    if (!incoming) return

    signals.value = {
      data_to_eval: incoming.data_to_eval ?? false,
      eval_to_train: incoming.eval_to_train ?? false,
      train_to_fleet: incoming.train_to_fleet ?? false,
    }
  } catch {
    // Best-effort only: when the API cannot be reached the badges keep
    // their current (initially hidden) state.
  }
}
|
||||
|
||||
/**
 * Flip the sidebar between expanded and stowed.
 *
 * The new state is stored in the `stowed` ref, persisted to localStorage
 * under `LS_KEY`, and published as the `--sidebar-width` custom property
 * on the document root.
 */
function toggle() {
  const nextStowed = !stowed.value
  stowed.value = nextStowed
  localStorage.setItem(LS_KEY, String(nextStowed))

  // .app-main derives its margin-left from this :root variable.
  const sidebarWidth = nextStowed ? '56px' : '200px'
  document.documentElement.style.setProperty('--sidebar-width', sidebarWidth)
}
|
||||
|
||||
|
|
@ -93,13 +276,12 @@ function onResize() { winWidth.value = window.innerWidth }
|
|||
|
||||
onMounted(() => {
|
||||
window.addEventListener('resize', onResize)
|
||||
// Apply persisted sidebar width to :root on mount
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
// On mobile, default to stowed
|
||||
if (isMobile.value && !localStorage.getItem(LS_KEY)) {
|
||||
stowed.value = true
|
||||
document.documentElement.style.setProperty('--sidebar-width', '56px')
|
||||
}
|
||||
loadSignals()
|
||||
})
|
||||
|
||||
onUnmounted(() => window.removeEventListener('resize', onResize))
|
||||
|
|
@ -121,18 +303,15 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar.stowed {
|
||||
width: 56px;
|
||||
}
|
||||
.sidebar.stowed { width: 56px; }
|
||||
|
||||
/* Mobile: slide in/out from left */
|
||||
.sidebar.mobile {
|
||||
box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.sidebar.mobile.stowed {
|
||||
transform: translateX(-100%);
|
||||
width: 200px; /* keep width so slide-in looks right */
|
||||
width: 200px;
|
||||
transition: transform 250ms ease, width 250ms ease;
|
||||
}
|
||||
|
||||
|
|
@ -165,10 +344,7 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.logo-icon {
|
||||
font-size: 1.25rem;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.logo-icon { font-size: 1.25rem; flex-shrink: 0; }
|
||||
|
||||
.logo-name {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
|
|
@ -193,16 +369,76 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.stow-btn:hover {
|
||||
background: var(--color-border, #d0d7e8);
|
||||
}
|
||||
.stow-btn:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
.nav-list {
|
||||
list-style: none;
|
||||
padding: 0.5rem 0;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
/* ── Section headers ── */
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
padding: 0.55rem 0.75rem 0.25rem;
|
||||
margin-top: 0.5rem;
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.07em;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.section-icon {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
width: 24px;
|
||||
text-align: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* ── Signal badges ── */
|
||||
.signal-badge {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: var(--color-warning, #d4891a);
|
||||
flex-shrink: 0;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.signal-badge-stowed {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
}
|
||||
|
||||
/* Make the stowed section header container position:relative for the badge */
|
||||
.sidebar.stowed .section-header {
|
||||
position: relative;
|
||||
justify-content: center;
|
||||
padding: 0.55rem 0 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Nav divider ── */
|
||||
.nav-divider {
|
||||
height: 1px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
margin: 0.5rem 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Nav items ── */
|
||||
.nav-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
|
@ -238,6 +474,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
border-radius: 0 2px 2px 0;
|
||||
}
|
||||
|
||||
/* Sub-items are indented slightly in expanded state */
|
||||
.nav-subitem { padding-left: 1.1rem; font-size: 0.875rem; }
|
||||
|
||||
.nav-icon {
|
||||
font-size: 1.1rem;
|
||||
flex-shrink: 0;
|
||||
|
|
@ -245,12 +484,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
text-align: center;
|
||||
}
|
||||
|
||||
.nav-label {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.nav-label { overflow: hidden; text-overflow: ellipsis; }
|
||||
|
||||
/* Mobile hamburger — visible when sidebar is stowed on mobile */
|
||||
/* Mobile hamburger */
|
||||
.mobile-hamburger {
|
||||
position: fixed;
|
||||
top: 0.75rem;
|
||||
|
|
|
|||
129
web/src/components/nodes/GpuRow.vue
Normal file
129
web/src/components/nodes/GpuRow.vue
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import ServiceBadge from './ServiceBadge.vue'
|
||||
import type { GpuEntry, ServiceInfo } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
gpu: GpuEntry
|
||||
nodeId: string
|
||||
profileLoaded: boolean
|
||||
servicesCatalog: Record<string, ServiceInfo>
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const saving = ref(false)
|
||||
const saveError = ref('')
|
||||
|
||||
// VRAM in use as a whole-number percentage of the total.
// Yields 0 when the total is missing or zero, which avoids dividing by zero.
const vramPct = computed(() => {
  const total = props.gpu.vram_total_mb
  if (!total) return 0
  const fraction = props.gpu.vram_used_mb / total
  return Math.round(fraction * 100)
})
|
||||
|
||||
/**
 * Classify a catalog service for display on this GPU.
 *
 * Checks are applied in priority order:
 *   1. not in `servicesCatalog`                      -> 'unknown'
 *   2. GPU compute capability below the service min  -> 'incompatible'
 *   3. listed in `gpu.services_running`              -> 'running'
 *   4. listed in `gpu.services_assigned`             -> 'assigned-only'
 *   5. otherwise                                     -> 'available'
 *
 * A GPU without a reported `compute_cap` is compared as capability 0.
 */
function serviceState(svcName: string): 'running' | 'stopped' | 'assigned-only' | 'available' | 'incompatible' | 'unknown' {
  const entry = props.servicesCatalog[svcName]
  if (!entry) return 'unknown'

  const gpu = props.gpu
  const capability = gpu.compute_cap ?? 0
  if (capability < entry.min_compute_cap) return 'incompatible'

  if (gpu.services_running.includes(svcName)) return 'running'
  return gpu.services_assigned.includes(svcName) ? 'assigned-only' : 'available'
}
|
||||
|
||||
/**
 * Add or remove `svcName` from this GPU's assigned services and save the
 * resulting list on the backend.
 *
 * Does nothing when the node has no profile loaded or a save is already in
 * flight. Removing a service asks for confirmation first. On success the
 * `updated` event is emitted so the parent can refresh; server warnings are
 * surfaced through `saveError`. On failure `saveError` holds the server's
 * `detail` (or the HTTP status when no usable detail is available).
 *
 * @param svcName  Name of the service to toggle.
 */
async function toggleService(svcName: string) {
  if (!props.profileLoaded || saving.value) return

  const current = [...props.gpu.services_assigned]
  const removing = current.includes(svcName)
  if (removing && !confirm(`Remove ${svcName} from GPU ${props.gpu.gpu_id}?`)) return
  const next = removing ? current.filter(s => s !== svcName) : [...current, svcName]

  saving.value = true
  saveError.value = ''
  try {
    // Encode each path segment so an id containing '/', '?', '#' or spaces
    // cannot alter the route (same treatment the Ollama panel gives model names).
    const nodeSegment = encodeURIComponent(props.nodeId)
    const gpuSegment = encodeURIComponent(String(props.gpu.gpu_id))
    const r = await fetch(
      `/api/nodes-mgmt/nodes/${nodeSegment}/gpu/${gpuSegment}/services`,
      {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ services: next }),
      },
    )
    if (!r.ok) {
      const errBody = await r.json().catch(() => null) as { detail?: unknown } | null
      const detail = errBody?.detail
      let message = `HTTP ${r.status}`
      if (typeof detail === 'string') {
        if (detail) message = detail
      } else if (detail != null) {
        // Structured detail (e.g. a validation error list) would otherwise
        // render as "[object Object]".
        message = JSON.stringify(detail)
      }
      throw new Error(message)
    }
    const data = await r.json() as { ok: boolean; reloaded: boolean; warnings: string[] }
    if (data.warnings?.length) saveError.value = `Saved (warning: ${data.warnings.join(', ')})`
    emit('updated')
  } catch (e) {
    saveError.value = e instanceof Error ? e.message : 'Failed to update services'
  } finally {
    saving.value = false
  }
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="gpu-row">
|
||||
<div class="gpu-info">
|
||||
<span class="gpu-label">GPU {{ gpu.gpu_id }}: {{ gpu.card }}</span>
|
||||
<span v-if="gpu.compute_cap != null" class="gpu-meta">sm{{ gpu.compute_cap }}</span>
|
||||
<span v-if="gpu.temp_c != null" class="gpu-meta">{{ gpu.temp_c }}°C</span>
|
||||
<span v-if="gpu.utilization_pct != null" class="gpu-meta">{{ gpu.utilization_pct }}%</span>
|
||||
</div>
|
||||
|
||||
<div class="vram-wrap">
|
||||
<div
|
||||
class="vram-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="gpu.vram_used_mb"
|
||||
aria-valuemin="0"
|
||||
:aria-valuemax="gpu.vram_total_mb"
|
||||
:aria-label="`VRAM: ${gpu.vram_used_mb} of ${gpu.vram_total_mb} MB used`"
|
||||
>
|
||||
<div class="vram-fill" :style="{ width: `${vramPct}%` }" />
|
||||
</div>
|
||||
<span class="vram-text">{{ gpu.vram_used_mb }} / {{ gpu.vram_total_mb }} MB ({{ vramPct }}%)</span>
|
||||
</div>
|
||||
|
||||
<div v-if="profileLoaded" class="services-row" aria-label="Service assignments">
|
||||
<ServiceBadge
|
||||
v-for="(_, svcName) in servicesCatalog"
|
||||
:key="String(svcName)"
|
||||
:service-name="String(svcName)"
|
||||
:state="serviceState(String(svcName))"
|
||||
:assigned="gpu.services_assigned.includes(String(svcName))"
|
||||
:disabled="saving"
|
||||
@toggle="toggleService(String(svcName))"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="saveError" class="save-msg" role="alert">{{ saveError }}</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.gpu-row {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-radius: 4px;
|
||||
background: var(--bg-secondary, #111);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
.gpu-info { display: flex; gap: 0.75rem; align-items: center; flex-wrap: wrap; font-size: 0.875rem; }
|
||||
.gpu-label { font-weight: 500; }
|
||||
.gpu-meta { color: var(--text-secondary, #888); font-size: 0.8rem; }
|
||||
.vram-wrap { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.vram-bar {
|
||||
flex: 1;
|
||||
height: 8px;
|
||||
background: var(--bg-bar, #2a2a2a);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.vram-fill { height: 100%; background: var(--color-primary, #4080ff); transition: width 0.3s; }
|
||||
.vram-text { font-size: 0.75rem; color: var(--text-secondary, #888); white-space: nowrap; }
|
||||
.services-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
|
||||
.save-msg { color: var(--color-warning, #ed8936); font-size: 0.8rem; }
|
||||
</style>
|
||||
132
web/src/components/nodes/HfNodeModelPanel.vue
Normal file
132
web/src/components/nodes/HfNodeModelPanel.vue
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
interface CatalogEntry {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description: string
|
||||
multi_gpu: boolean
|
||||
}
|
||||
|
||||
interface ServiceProfile {
|
||||
catalog: Record<string, CatalogEntry>
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
}
|
||||
|
||||
interface NodeProfile {
|
||||
services: Record<string, ServiceProfile>
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
}>()
|
||||
|
||||
const profile = ref<NodeProfile | null>(null)
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
/**
 * Load this node's profile into `profile`.
 *
 * Any request still in flight is aborted first, so only the newest call may
 * update component state. A 404 means "no profile" and sets `profile` to
 * null without raising an error; any other non-OK status is reported through
 * `error`. Aborted requests are ignored silently.
 */
async function fetchProfile() {
  fetchAbort?.abort()
  const ctrl = new AbortController()
  fetchAbort = ctrl
  loading.value = true
  error.value = ''
  try {
    // Encode the id so it is always treated as a single path segment.
    const r = await fetch(`/api/nodes-mgmt/nodes/${encodeURIComponent(props.nodeId)}/profile`, {
      signal: ctrl.signal,
    })
    if (r.status === 404) { profile.value = null; return }
    if (!r.ok) throw new Error(`HTTP ${r.status}`)
    profile.value = await r.json() as NodeProfile
  } catch (e) {
    if (ctrl.signal.aborted || (e instanceof Error && e.name === 'AbortError')) return
    error.value = e instanceof Error ? e.message : 'Failed to load profile'
  } finally {
    // A superseded request must not clear the spinner of the request that
    // replaced it; only the current controller owns the loading flag.
    if (fetchAbort === ctrl) loading.value = false
  }
}
|
||||
|
||||
onMounted(fetchProfile)
|
||||
onUnmounted(() => { fetchAbort?.abort() })
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="hf-panel">
|
||||
<h3 class="panel-title">Model Catalog</h3>
|
||||
<p class="hf-hint">
|
||||
To download a new HuggingFace model,
|
||||
<a href="#/fleet" class="hf-link">go to Fleet</a>.
|
||||
Models downloaded there are automatically registered in node catalogs.
|
||||
</p>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading catalog...</span>
|
||||
</div>
|
||||
<div v-if="error" class="panel-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && !profile" class="panel-empty">No profile loaded for this node.</div>
|
||||
<div v-else-if="!loading && profile" class="catalog-body">
|
||||
<div
|
||||
v-for="(svcInfo, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-section"
|
||||
>
|
||||
<h4 class="svc-name">{{ svcName }}</h4>
|
||||
<ul class="catalog-list" role="list">
|
||||
<li
|
||||
v-if="!Object.keys(svcInfo.catalog ?? {}).length"
|
||||
class="catalog-empty"
|
||||
>
|
||||
No models in catalog.
|
||||
</li>
|
||||
<li
|
||||
v-for="(entry, modelName) in (svcInfo.catalog ?? {})"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.hf-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--border, #333);
|
||||
border-radius: 6px;
|
||||
}
|
||||
.panel-title { margin: 0 0 0.5rem; font-size: 0.9rem; }
|
||||
.hf-hint { font-size: 0.8rem; color: var(--text-secondary, #888); margin: 0 0 0.75rem; }
|
||||
.hf-link { color: var(--color-primary, #4080ff); }
|
||||
.svc-section { margin-bottom: 0.75rem; }
|
||||
.svc-name {
|
||||
margin: 0 0 0.25rem;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--text-secondary, #888);
|
||||
}
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.2rem; }
|
||||
.catalog-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--bg-secondary, #111);
|
||||
border-radius: 4px;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.catalog-model { font-family: monospace; flex: 1; }
|
||||
.catalog-vram { color: var(--text-secondary, #888); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--text-secondary, #888); font-size: 0.75rem; flex: 2; }
|
||||
.catalog-empty, .panel-empty { color: var(--text-secondary, #888); font-size: 0.875rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-error { color: var(--color-error, #fc8181); font-size: 0.8rem; }
|
||||
</style>
|
||||
90
web/src/components/nodes/NodeCard.vue
Normal file
90
web/src/components/nodes/NodeCard.vue
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import GpuRow from './GpuRow.vue'
|
||||
import OllamaModelPanel from './OllamaModelPanel.vue'
|
||||
import HfNodeModelPanel from './HfNodeModelPanel.vue'
|
||||
import type { NodeSummary } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{ node: NodeSummary }>()
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const showOllama = ref(false)
|
||||
const showHf = ref(false)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="node-card" :class="{ offline: !node.online }">
|
||||
<header class="node-card-header">
|
||||
<div class="node-identity">
|
||||
<span
|
||||
class="status-dot"
|
||||
:class="node.online ? 'online' : 'offline'"
|
||||
:aria-label="node.online ? 'Online' : 'Offline'"
|
||||
role="img"
|
||||
/>
|
||||
<h2 class="node-name">{{ node.node_id }}</h2>
|
||||
<span class="node-agent">{{ node.agent_url }}</span>
|
||||
</div>
|
||||
<div v-if="node.profile_loaded" class="node-actions">
|
||||
<button class="btn-secondary btn-sm" @click="showOllama = !showOllama">
|
||||
{{ showOllama ? 'Hide Ollama' : 'Ollama' }}
|
||||
</button>
|
||||
<button class="btn-secondary btn-sm" @click="showHf = !showHf">
|
||||
{{ showHf ? 'Hide Catalog' : 'Catalog' }}
|
||||
</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div v-if="!node.profile_loaded" class="no-profile" role="status">
|
||||
No profile configured for this node. GPU stats are visible; service assignment is disabled.
|
||||
</div>
|
||||
|
||||
<div class="gpu-list">
|
||||
<GpuRow
|
||||
v-for="gpu in node.gpus"
|
||||
:key="gpu.gpu_id"
|
||||
:gpu="gpu"
|
||||
:node-id="node.node_id"
|
||||
:profile-loaded="node.profile_loaded"
|
||||
:services-catalog="node.services_catalog"
|
||||
@updated="emit('updated')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<OllamaModelPanel v-if="showOllama" :node-id="node.node_id" />
|
||||
<HfNodeModelPanel v-if="showHf" :node-id="node.node_id" />
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.node-card {
|
||||
border: 1px solid var(--border, #333);
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
background: var(--bg-card, #1a1a1a);
|
||||
}
|
||||
.node-card.offline { opacity: 0.65; }
|
||||
.node-card-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.node-identity { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; }
|
||||
.node-name { margin: 0; font-size: 1rem; font-weight: 600; }
|
||||
.node-agent { color: var(--text-secondary, #888); font-size: 0.8rem; font-family: monospace; }
|
||||
.status-dot { width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0; }
|
||||
.status-dot.online { background: var(--color-success, #48bb78); }
|
||||
.status-dot.offline { background: var(--color-warning, #ed8936); }
|
||||
.node-actions { display: flex; gap: 0.5rem; flex-shrink: 0; }
|
||||
.no-profile {
|
||||
padding: 0.6rem 0.75rem;
|
||||
background: var(--bg-notice, #1e1e1e);
|
||||
border-radius: 4px;
|
||||
color: var(--text-secondary, #888);
|
||||
font-size: 0.875rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.gpu-list { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
</style>
|
||||
241
web/src/components/nodes/OllamaModelPanel.vue
Normal file
241
web/src/components/nodes/OllamaModelPanel.vue
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const props = defineProps<{ nodeId: string }>()
|
||||
|
||||
interface OllamaModel {
|
||||
name: string
|
||||
size: number
|
||||
modified_at: string
|
||||
}
|
||||
|
||||
const models = ref<OllamaModel[]>([])
|
||||
const loading = ref(true)
|
||||
const loadError = ref('')
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
|
||||
// AbortController for the SSE pull stream
|
||||
const abortCtrl = ref<AbortController | null>(null)
|
||||
|
||||
// AbortController for the one-shot fetchModels request
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
/**
 * Load the list of Ollama models installed on this node into `models`.
 *
 * Any request still in flight is aborted first, so only the newest call may
 * update component state. An `error` field in the response body is shown
 * as-is; a non-OK response without one is reported as `HTTP <status>`
 * instead of being mistaken for an empty model list. Aborted requests are
 * ignored silently.
 */
async function fetchModels() {
  fetchAbort?.abort()
  const ctrl = new AbortController()
  fetchAbort = ctrl
  loading.value = true
  loadError.value = ''
  try {
    const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`, {
      signal: ctrl.signal,
    })

    let data: { models?: OllamaModel[]; error?: string } | null = null
    try {
      data = await r.json() as { models?: OllamaModel[]; error?: string } | null
    } catch (parseErr) {
      // An unparsable body is only tolerated on an error response, where the
      // HTTP status below is the more useful message.
      if (ctrl.signal.aborted || r.ok) throw parseErr
    }

    if (data?.error) { loadError.value = data.error; return }
    if (!r.ok) throw new Error(`HTTP ${r.status}`)
    models.value = data?.models ?? []
  } catch (e) {
    if (ctrl.signal.aborted || (e instanceof Error && e.name === 'AbortError')) return
    loadError.value = e instanceof Error ? e.message : 'Failed to load models'
  } finally {
    // A superseded request must not clear the spinner of the request that
    // replaced it; only the current controller owns the loading flag.
    if (fetchAbort === ctrl) loading.value = false
  }
}
|
||||
|
||||
/**
 * Pull the model named in `pullName` onto this node and follow the progress
 * stream of `data: <json>` lines returned by the backend.
 *
 * Each event may carry `status`, `total`/`completed` (used for the
 * percentage) or `error`. Reading stops at the first `error` event or at
 * `status === 'success'`, so later events cannot overwrite the final state.
 * The model list is refreshed once the stream has ended. An abort (for
 * example on unmount) ends the pull silently.
 */
async function doPull() {
  const name = pullName.value.trim()
  if (!name || pulling.value) return
  pulling.value = true
  pullStatus.value = 'Starting...'
  pullError.value = ''
  pullPct.value = 0

  const ctrl = new AbortController()
  abortCtrl.value = ctrl

  // Apply one stream line to the progress state.
  // Returns true when the event is terminal (error or success).
  const applyLine = (line: string): boolean => {
    if (!line.startsWith('data: ')) return false
    let evt: { status?: string; error?: string; total?: number; completed?: number } | null
    try {
      evt = JSON.parse(line.slice(6))
    } catch {
      return false // skip malformed line
    }
    if (evt === null || typeof evt !== 'object') return false

    if (evt.error) {
      pullError.value = evt.error
      return true
    }
    if (evt.status) pullStatus.value = evt.status
    if (evt.total && evt.completed) {
      pullPct.value = Math.round((evt.completed / evt.total) * 100)
    }
    if (evt.status === 'success') {
      pullStatus.value = 'Done!'
      pullName.value = ''
      return true
    }
    return false
  }

  try {
    const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ name }),
      signal: ctrl.signal,
    })
    if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
    if (!resp.body) throw new Error('No response body')

    const reader = resp.body.getReader()
    const decoder = new TextDecoder()
    let buf = ''
    let finished = false

    try {
      while (!finished) {
        const { done, value } = await reader.read()
        buf += done ? decoder.decode() : decoder.decode(value, { stream: true })
        const lines = buf.split('\n')
        // Keep the trailing partial line for the next chunk; once the stream
        // has ended there is no next chunk, so process it as well.
        buf = done ? '' : (lines.pop() ?? '')
        for (const line of lines) {
          if (applyLine(line)) {
            finished = true
            break
          }
        }
        if (done) break
      }
    } finally {
      // Stop consuming the response once a terminal event was seen or an
      // error occurred; a failed cancel is not actionable.
      await reader.cancel().catch(() => {})
    }

    // Refresh model list after the stream closes (success, error event or benign end)
    await fetchModels()
  } catch (e) {
    if (e instanceof Error && e.name === 'AbortError') return
    pullError.value = e instanceof Error ? e.message : 'Pull failed'
  } finally {
    pulling.value = false
    abortCtrl.value = null
  }
}
|
||||
|
||||
/**
 * Delete an Ollama model from this node after the user confirms.
 *
 * On success the model list is reloaded. On failure the message is stored
 * in `loadError`.
 *
 * @param name  Model name as listed by Ollama.
 */
async function deleteModel(name: string) {
  const confirmed = confirm(`Delete model "${name}" from node ${props.nodeId}?`)
  if (!confirmed) return

  try {
    const modelSegment = encodeURIComponent(name)
    const response = await fetch(
      `/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/${modelSegment}`,
      { method: 'DELETE' },
    )
    if (!response.ok) throw new Error(`HTTP ${response.status}`)
    await fetchModels()
  } catch (err) {
    loadError.value = err instanceof Error ? err.message : 'Delete failed'
  }
}
|
||||
|
||||
/**
 * Format a byte count for display using decimal units.
 *
 * Sizes that would round to at least 1.0 GB are shown in gigabytes with one
 * decimal place. Smaller sizes are shown in whole megabytes, so small models
 * no longer appear as "0.0 GB".
 *
 * @param bytes  Size in bytes.
 * @returns      e.g. "4.7 GB" or "274 MB".
 */
function formatSize(bytes: number): string {
  // 999.5e6 is the point at which the MB figure would round up to 1000.
  if (bytes >= 0 && bytes < 999.5e6) {
    return `${Math.round(bytes / 1e6)} MB`
  }
  return (bytes / 1e9).toFixed(1) + ' GB'
}
|
||||
|
||||
onMounted(fetchModels)
|
||||
onUnmounted(() => {
|
||||
abortCtrl.value?.abort()
|
||||
fetchAbort?.abort()
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="ollama-panel">
|
||||
<h3 class="panel-title">Ollama Models</h3>
|
||||
|
||||
<form class="pull-form" @submit.prevent="doPull">
|
||||
<input
|
||||
v-model="pullName"
|
||||
type="text"
|
||||
placeholder="nomic-embed-text, llama3.2:3b, ..."
|
||||
:disabled="pulling"
|
||||
aria-label="Model name to pull from Ollama"
|
||||
class="pull-input"
|
||||
/>
|
||||
<button type="submit" :disabled="pulling || !pullName.trim()" class="btn-primary btn-sm">
|
||||
{{ pulling ? 'Pulling...' : 'Pull' }}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div v-if="pulling || pullStatus" class="pull-progress" aria-live="polite">
|
||||
<div
|
||||
class="progress-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="pullPct"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
:aria-label="`Pull progress: ${pullStatus}`"
|
||||
>
|
||||
<div class="progress-fill" :style="{ width: `${pullPct}%` }" />
|
||||
</div>
|
||||
<span class="progress-label">{{ pullStatus }}{{ pullPct > 0 ? ` (${pullPct}%)` : '' }}</span>
|
||||
</div>
|
||||
|
||||
<div v-if="pullError" class="pull-error" role="alert">
|
||||
{{ pullError }}
|
||||
<span v-if="pullError.includes('permission denied')">
|
||||
— Remove the partial file on the node and retry.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading...</span>
|
||||
</div>
|
||||
<div v-if="loadError" class="panel-error" role="alert">{{ loadError }}</div>
|
||||
<ul v-if="!loading && !loadError" class="model-list" role="list">
|
||||
<li v-if="!models.length" class="model-empty">No Ollama models installed on this node.</li>
|
||||
<li v-for="m in models" :key="m.name" class="model-item">
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span class="model-size">{{ formatSize(m.size) }}</span>
|
||||
<button
|
||||
class="btn-danger btn-xs"
|
||||
@click="deleteModel(m.name)"
|
||||
:aria-label="`Delete ${m.name}`"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.ollama-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--border, #333);
|
||||
border-radius: 6px;
|
||||
}
|
||||
.panel-title { margin: 0 0 0.75rem; font-size: 0.9rem; }
|
||||
.pull-form { display: flex; gap: 0.5rem; margin-bottom: 0.5rem; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--bg-input, #111);
|
||||
border: 1px solid var(--border, #333);
|
||||
border-radius: 4px;
|
||||
color: inherit;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.pull-progress { margin-bottom: 0.5rem; }
|
||||
.progress-bar {
|
||||
height: 8px;
|
||||
background: var(--bg-bar, #2a2a2a);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
.progress-fill { height: 100%; background: var(--color-primary, #4080ff); transition: width 0.2s; }
|
||||
.progress-label { font-size: 0.75rem; color: var(--text-secondary, #888); }
|
||||
.pull-error, .panel-error { color: var(--color-error, #fc8181); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-loading { color: var(--text-secondary, #888); font-size: 0.875rem; }
|
||||
.model-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.model-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--bg-secondary, #111);
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.model-name { flex: 1; font-family: monospace; }
|
||||
.model-size { color: var(--text-secondary, #888); font-size: 0.8rem; }
|
||||
.model-empty { color: var(--text-secondary, #888); font-size: 0.875rem; padding: 0.25rem 0; }
|
||||
</style>
|
||||
81
web/src/components/nodes/ServiceBadge.vue
Normal file
81
web/src/components/nodes/ServiceBadge.vue
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
<script setup lang="ts">
|
||||
type ServiceState =
|
||||
| 'running'
|
||||
| 'stopped'
|
||||
| 'assigned-only'
|
||||
| 'available'
|
||||
| 'incompatible'
|
||||
| 'vram-tight'
|
||||
| 'unknown'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName: string
|
||||
state: ServiceState
|
||||
assigned: boolean
|
||||
disabled?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ toggle: [] }>()
|
||||
|
||||
const STATE_LABELS: Record<ServiceState, string> = {
|
||||
running: 'Running',
|
||||
stopped: 'Stopped',
|
||||
'assigned-only': 'Assigned',
|
||||
available: 'Available',
|
||||
incompatible: 'Incompatible',
|
||||
'vram-tight': 'VRAM tight',
|
||||
unknown: 'Unknown',
|
||||
}
|
||||
|
||||
const STATE_ICONS: Record<ServiceState, string> = {
|
||||
running: '▶',
|
||||
stopped: '⏹',
|
||||
'assigned-only': '📌',
|
||||
available: '○',
|
||||
incompatible: '✕',
|
||||
'vram-tight': '⚠',
|
||||
unknown: '?',
|
||||
}
|
||||
|
||||
/**
 * Forward a click as a `toggle` event, unless the badge is disabled or the
 * service is incompatible with this GPU.
 */
function handleToggle() {
  if (props.disabled) return
  if (props.state === 'incompatible') return
  emit('toggle')
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<button
|
||||
class="service-badge"
|
||||
:class="[`state-${state}`, { assigned, 'is-disabled': disabled || state === 'incompatible' }]"
|
||||
:aria-pressed="assigned"
|
||||
:aria-label="`${serviceName}: ${STATE_LABELS[state] ?? state}${assigned ? ' (assigned)' : ''}`"
|
||||
:disabled="disabled || state === 'incompatible'"
|
||||
@click="handleToggle"
|
||||
>
|
||||
<span class="badge-icon" aria-hidden="true">{{ STATE_ICONS[state] ?? '?' }}</span>
|
||||
<span class="badge-name">{{ serviceName }}</span>
|
||||
<span class="badge-state">{{ STATE_LABELS[state] ?? state }}</span>
|
||||
</button>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.service-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--border, #333);
|
||||
background: var(--bg-badge, #1e1e1e);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.1s, border-color 0.1s;
|
||||
}
|
||||
.service-badge:hover:not(.is-disabled) { opacity: 0.8; }
|
||||
.service-badge.is-disabled { cursor: not-allowed; opacity: 0.5; }
|
||||
.service-badge.state-running { border-color: var(--color-success, #48bb78); }
|
||||
.service-badge.state-stopped { border-color: var(--color-warning, #ed8936); }
|
||||
.service-badge.state-assigned-only { border-color: var(--color-info, #4299e1); }
|
||||
.service-badge.state-incompatible { border-color: var(--color-error, #fc8181); }
|
||||
.service-badge.state-vram-tight { border-color: var(--color-warning, #ed8936); }
|
||||
.badge-state { color: var(--text-secondary, #888); }
|
||||
</style>
|
||||
|
|
@ -1,25 +1,51 @@
|
|||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import LabelView from '../views/LabelView.vue'
|
||||
|
||||
// Views are lazy-loaded to keep initial bundle small
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const StatsView = () => import('../views/StatsView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
const ImitateView = () => import('../views/ImitateView.vue')
|
||||
// Lazy-loaded views
|
||||
const DashboardView = () => import('../views/DashboardView.vue')
|
||||
const LabelView = () => import('../views/LabelView.vue')
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||
const ImitateView = () => import('../views/ImitateView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const CompareView = () => import('../views/CompareView.vue')
|
||||
const TrainJobsView = () => import('../views/TrainJobsView.vue')
|
||||
const TrainResultsView = () => import('../views/TrainResultsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const NodeManagementView = () => import('../views/NodeManagementView.vue')
|
||||
|
||||
export const routes = [
|
||||
// ── Top-level ────────────────────────────────────────────
|
||||
{ path: '/', component: DashboardView, meta: { title: 'Dashboard' } },
|
||||
{ path: '/fleet', component: ModelsView, meta: { title: 'Fleet' } },
|
||||
{ path: '/nodes', component: NodeManagementView, meta: { title: 'Nodes' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
|
||||
// ── Data domain ──────────────────────────────────────────
|
||||
{ path: '/data/label', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/data/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/data/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/data/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||
|
||||
// ── Eval domain ──────────────────────────────────────────
|
||||
{ path: '/eval/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/eval/compare', component: CompareView, meta: { title: 'Compare' } },
|
||||
|
||||
// ── Train domain ─────────────────────────────────────────
|
||||
{ path: '/train/jobs', component: TrainJobsView, meta: { title: 'Training Jobs' } },
|
||||
{ path: '/train/results', component: TrainResultsView, meta: { title: 'Training Results' } },
|
||||
|
||||
// ── Backward-compat redirects ────────────────────────────
|
||||
{ path: '/benchmark', redirect: '/eval/benchmark' },
|
||||
{ path: '/models', redirect: '/fleet' },
|
||||
{ path: '/stats', redirect: '/' },
|
||||
{ path: '/label', redirect: '/data/label' },
|
||||
{ path: '/fetch', redirect: '/data/fetch' },
|
||||
{ path: '/corrections', redirect: '/data/corrections' },
|
||||
{ path: '/imitate', redirect: '/data/imitate' },
|
||||
]
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
|
||||
{ path: '/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
],
|
||||
routes,
|
||||
})
|
||||
|
|
|
|||
94
web/src/router/router.test.ts
Normal file
94
web/src/router/router.test.ts
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import { describe, it, expect } from 'vitest'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
|
||||
// Import the raw routes array so we can test structure without mounting App
|
||||
import { routes } from './index'
|
||||
|
||||
/**
 * Structural checks on the exported `routes` table.
 *
 * Only the raw array is inspected (plus one smoke test that vue-router accepts
 * it), so no view component is mounted by this suite.
 */
describe('router routes', () => {
  // Paths that must be served by a view component rather than a redirect.
  const viewPaths = [
    '/fleet',
    '/nodes',
    '/settings',
    '/data/label',
    '/data/fetch',
    '/data/corrections',
    '/data/imitate',
    '/eval/benchmark',
    '/eval/compare',
    '/train/jobs',
    '/train/results',
  ]

  // [legacy path, new location] — one entry per backward-compat redirect.
  const legacyRedirects: Array<[string, string]> = [
    ['/benchmark', '/eval/benchmark'],
    ['/models', '/fleet'],
    ['/stats', '/'],
    ['/label', '/data/label'],
    ['/fetch', '/data/fetch'],
    ['/corrections', '/data/corrections'],
    ['/imitate', '/data/imitate'],
  ]

  it('exports a routes array', () => {
    expect(Array.isArray(routes)).toBe(true)
  })

  it('has / pointing to DashboardView', () => {
    const root = routes.find(r => r.path === '/')
    expect(root).toBeDefined()
    // The view is a lazy import, so only its presence can be asserted here.
    expect(root?.component).toBeDefined()
  })

  it.each(viewPaths)('has %s route', (path) => {
    const route = routes.find(r => r.path === path)
    expect(route).toBeDefined()
    // A view route must carry a component; a bare redirect would not.
    expect(route?.component).toBeDefined()
  })

  it.each(legacyRedirects)('has backward-compat redirect from %s to %s', (from, to) => {
    const route = routes.find(r => r.path === from)
    expect(route).toBeDefined()
    expect((route as { redirect?: string }).redirect).toBe(to)
  })

  it('can create a functional router instance', () => {
    const router = createRouter({
      history: createWebHashHistory(),
      routes,
    })
    expect(router).toBeDefined()
  })
})
|
||||
27
web/src/types/nodes.ts
Normal file
27
web/src/types/nodes.ts
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
/**
 * One GPU record, as carried in `NodeSummary.gpus`.
 *
 * NOTE(review): field meanings are inferred from their names only — confirm
 * units and semantics against the backend node API before relying on them.
 */
export interface GpuEntry {
|
||||
gpu_id: number
|
||||
card: string
|
||||
// The three *_mb fields are presumably megabytes — TODO confirm unit.
vram_total_mb: number
|
||||
vram_used_mb: number
|
||||
vram_free_mb: number
|
||||
// The next three readings may be null; when the backend sends null is not
// visible from this file.
temp_c: number | null
|
||||
utilization_pct: number | null
|
||||
compute_cap: number | null
|
||||
// Lists of service names (presumably keys of NodeSummary.services_catalog — verify).
services_assigned: string[]
|
||||
services_running: string[]
|
||||
}
|
||||
|
||||
/**
 * Per-service entry; the value type of `NodeSummary.services_catalog`.
 *
 * NOTE(review): semantics below are inferred from field names — confirm.
 */
export interface ServiceInfo {
|
||||
// Presumably compared against GpuEntry.compute_cap — TODO confirm.
min_compute_cap: number
|
||||
// Presumably a memory ceiling in megabytes — TODO confirm unit.
max_mb: number
|
||||
catalog_size: number
|
||||
}
|
||||
|
||||
/**
 * Summary of one node: identity, reachability, its GPUs and the services it
 * knows about.
 */
export interface NodeSummary {
|
||||
node_id: string
|
||||
online: boolean
|
||||
agent_url: string
|
||||
// One entry per GPU on the node.
gpus: GpuEntry[]
|
||||
profile_loaded: boolean
|
||||
// Keyed by a string, presumably the service name — verify against the API.
services_catalog: Record<string, ServiceInfo>
|
||||
}
|
||||
82
web/src/views/BenchmarkView.test.ts
Normal file
82
web/src/views/BenchmarkView.test.ts
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import BenchmarkView from './BenchmarkView.vue'
|
||||
|
||||
// Install network stubs before every test so mounting the view never reaches
// a real backend.
beforeEach(() => {
  // Build a successful fetch result; the body factory runs on each json() call
  // so every caller receives a fresh object.
  const okResponse = (makeBody: () => unknown) => ({
    ok: true,
    json: async () => makeBody(),
    text: async () => '',
  })

  // LlmEvalTab calls /api/cforch/models and expects { models: CfOrchModel[] }
  const orchModelsBody = () => ({ models: [] })

  // Default: satisfies ClassifierTab (/api/benchmark/results, /api/benchmark/models,
  // /api/finetune/status), StyleTab (/api/style/models, /api/style/results),
  // and any other tab that tolerates empty arrays/objects.
  const defaultBody = () => ({ models: {}, categories: {}, tasks: [], types: [], results: [] })

  const fakeFetch = vi.fn().mockImplementation((url: string) => {
    const wantsOrchModels = url.includes('/api/cforch/models')
    return Promise.resolve(okResponse(wantsOrchModels ? orchModelsBody : defaultBody))
  })
  vi.stubGlobal('fetch', fakeFetch)

  // Inert EventSource: exposes the members the tabs touch but never emits.
  vi.stubGlobal('EventSource', class {
    onmessage = null
    onerror = null
    close() {}
  })
})
|
||||
|
||||
// Behavioural tests for the Benchmark page: title, mode buttons and tab switching.
describe('BenchmarkView', () => {
  // Mount the view and wait for the initial (stubbed) requests to settle.
  async function mountSettled() {
    const wrapper = mount(BenchmarkView)
    await flushPromises()
    return wrapper
  }

  type Mounted = Awaited<ReturnType<typeof mountSettled>>

  // Locate the mode button whose label contains the given text.
  const modeButton = (wrapper: Mounted, label: string) =>
    wrapper.findAll('.mode-btn').find(btn => btn.text().includes(label))!

  it('renders page title "Benchmark"', async () => {
    const wrapper = await mountSettled()
    expect(wrapper.text()).toContain('Benchmark')
  })

  it('has mode buttons: Classifier, LLM Eval, Writing Style', async () => {
    const rendered = (await mountSettled()).text()
    for (const label of ['Classifier', 'LLM Eval', 'Writing Style']) {
      expect(rendered).toContain(label)
    }
  })

  it('does NOT have a Compare mode button', async () => {
    const wrapper = await mountSettled()
    const hasCompare = wrapper
      .findAll('.mode-btn')
      .map(btn => btn.text())
      .some(label => label.includes('Compare'))
    expect(!hasCompare).toBe(true)
  })

  it('shows Classifier tab by default', async () => {
    const wrapper = await mountSettled()
    // ClassifierTab has a .classifier-tab root
    expect(wrapper.find('.classifier-tab').exists()).toBe(true)
  })

  it('switches to LlmEvalTab when LLM Eval clicked', async () => {
    const wrapper = await mountSettled()
    const llmButton = modeButton(wrapper, 'LLM Eval')
    await llmButton.trigger('click')
    await flushPromises()
    expect(wrapper.find('.llm-eval-tab').exists()).toBe(true)
    expect(wrapper.find('.classifier-tab').exists()).toBe(false)
    expect(llmButton.classes()).toContain('active')
  })

  it('switches to StyleTab when Writing Style clicked', async () => {
    const wrapper = await mountSettled()
    const styleButton = modeButton(wrapper, 'Writing Style')
    await styleButton.trigger('click')
    await flushPromises()
    expect(wrapper.find('.style-tab').exists()).toBe(true)
    expect(wrapper.find('.classifier-tab').exists()).toBe(false)
  })
})
|
||||
|
|
@ -16,33 +16,33 @@
|
|||
:class="{ active: benchMode === 'llm' }"
|
||||
@click="benchMode = 'llm'"
|
||||
>🤖 LLM Eval</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: benchMode === 'compare' }"
|
||||
@click="benchMode = 'compare'"
|
||||
>⚖️ Compare</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: benchMode === 'style' }"
|
||||
@click="benchMode = 'style'"
|
||||
>✍️ Writing Style</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: benchMode === 'plans' }"
|
||||
@click="benchMode = 'plans'"
|
||||
>📐 Planning</button>
|
||||
</div>
|
||||
|
||||
<ClassifierTab v-if="benchMode === 'classifier'" />
|
||||
<LlmEvalTab v-if="benchMode === 'llm'" />
|
||||
<CompareTab v-if="benchMode === 'compare'" />
|
||||
<StyleTab v-if="benchMode === 'style'" />
|
||||
<ClassifierTab v-if="benchMode === 'classifier'" />
|
||||
<LlmEvalTab v-if="benchMode === 'llm'" />
|
||||
<StyleTab v-if="benchMode === 'style'" />
|
||||
<PlansBenchTab v-if="benchMode === 'plans'" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import ClassifierTab from './ClassifierTab.vue'
|
||||
import LlmEvalTab from './LlmEvalTab.vue'
|
||||
import CompareTab from './CompareTab.vue'
|
||||
import StyleTab from './StyleTab.vue'
|
||||
import ClassifierTab from './ClassifierTab.vue'
|
||||
import LlmEvalTab from './LlmEvalTab.vue'
|
||||
import StyleTab from './StyleTab.vue'
|
||||
import PlansBenchTab from './PlansBenchTab.vue'
|
||||
|
||||
type BenchMode = 'classifier' | 'llm' | 'compare' | 'style'
|
||||
type BenchMode = 'classifier' | 'llm' | 'style' | 'plans'
|
||||
const benchMode = ref<BenchMode>('classifier')
|
||||
</script>
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ const benchMode = ref<BenchMode>('classifier')
|
|||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Mode toggle (segmented control) ────────────────────── */
|
||||
/* ── Mode toggle (segmented control) ── */
|
||||
.mode-toggle {
|
||||
display: inline-flex;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
|
|||
31
web/src/views/CompareView.test.ts
Normal file
31
web/src/views/CompareView.test.ts
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import CompareView from './CompareView.vue'
|
||||
|
||||
// Stub the browser globals CompareTab touches so the view mounts offline.
beforeEach(() => {
  // Every request succeeds with the same empty payload.
  const emptyOk = {
    ok: true,
    json: async () => ({ tasks: [], types: [], models: [] }),
    text: async () => '',
  }
  const fakeFetch = vi.fn().mockResolvedValue(emptyOk)
  vi.stubGlobal('fetch', fakeFetch)

  // Inert EventSource: present but never emits.
  vi.stubGlobal('EventSource', class {
    onmessage = null
    onerror = null
    close() {}
  })
})
|
||||
|
||||
// CompareView is a thin page wrapper: it must show its title and host CompareTab.
describe('CompareView', () => {
  // Mount the view and wait for stubbed requests to resolve.
  async function mountSettled() {
    const wrapper = mount(CompareView)
    await flushPromises()
    return wrapper
  }

  it('renders page title "Compare"', async () => {
    const wrapper = await mountSettled()
    const heading = wrapper.find('h1.page-title')
    expect(heading.text()).toContain('Compare')
  })

  it('wraps CompareTab component', async () => {
    const wrapper = await mountSettled()
    // CompareTab renders a .compare-tab root div
    const tabRoot = wrapper.find('.compare-tab')
    expect(tabRoot.exists()).toBe(true)
  })
})
|
||||
36
web/src/views/CompareView.vue
Normal file
36
web/src/views/CompareView.vue
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
<template>
|
||||
<div class="compare-view">
|
||||
<header class="compare-header">
|
||||
<h1 class="page-title">🔍 Compare</h1>
|
||||
</header>
|
||||
<CompareTab />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import CompareTab from './CompareTab.vue'
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.compare-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.compare-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
119
web/src/views/DashboardView.test.ts
Normal file
119
web/src/views/DashboardView.test.ts
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import DashboardView from './DashboardView.vue'
|
||||
|
||||
// Minimal router for the tests: the root plus every RouterLink target used by
// the dashboard's call-to-action buttons, each backed by an empty component.
const stubbedPaths = ['/', '/eval/benchmark', '/train/jobs', '/fleet']

const router = createRouter({
  history: createWebHashHistory(),
  routes: stubbedPaths.map(path => ({
    path,
    component: { template: '<div />' },
  })),
})
|
||||
|
||||
// Shape of the dashboard payload returned by the mocked fetch. It is declared
// explicitly because `typeof baseDashboard` would infer `null` and `never[]`
// from the empty defaults, which rejects overrides such as a numeric score or
// a non-empty job list under strict type-checking.
interface DashboardFixture {
  labeled_since_last_eval: number
  last_eval_timestamp: string | null
  last_eval_best_score: number | null
  active_jobs: Array<{ id: string; type: string; model_key: string; status: string }>
  corrections_export_ready: number
  signals: { data_to_eval: boolean; eval_to_train: boolean; train_to_fleet: boolean }
}

// Neutral payload: nothing labeled, no eval yet, no jobs, no signals raised.
const baseDashboard: DashboardFixture = {
  labeled_since_last_eval: 0,
  last_eval_timestamp: null,
  last_eval_best_score: null,
  active_jobs: [],
  corrections_export_ready: 0,
  signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false },
}

/**
 * Replace the global `fetch` with a stub that always succeeds and returns the
 * base payload shallow-merged with `overrides`.
 *
 * The merge is shallow, so an overriding `signals` object must supply all
 * three flags.
 */
function mockFetch(overrides: Partial<DashboardFixture> = {}) {
  vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
    ok: true,
    json: async () => ({ ...baseDashboard, ...overrides }),
    text: async () => '',
  }))
}
|
||||
|
||||
// Start every test from the neutral payload; individual tests re-stub as needed.
beforeEach(() => {
  mockFetch()
})

// Rendering tests for the dashboard: stage cards, CTAs, job pills and error state.
describe('DashboardView', () => {
  // Mount with the stub router and wait for the initial load to finish.
  async function mountDashboard() {
    const wrapper = mount(DashboardView, { global: { plugins: [router] } })
    await flushPromises()
    return wrapper
  }

  type Mounted = Awaited<ReturnType<typeof mountDashboard>>
  type Stage = 'data' | 'eval' | 'train'

  // Find the card for one flywheel stage.
  const stageCard = (wrapper: Mounted, stage: Stage) =>
    wrapper.find(`.stage-card[data-stage="${stage}"]`)

  it('renders page title', async () => {
    const wrapper = await mountDashboard()
    expect(wrapper.text()).toContain('Dashboard')
  })

  it('shows three stage cards: Data, Eval, Train', async () => {
    const wrapper = await mountDashboard()
    const stages: Stage[] = ['data', 'eval', 'train']
    for (const stage of stages) {
      expect(stageCard(wrapper, stage).exists()).toBe(true)
    }
  })

  it('shows labeled_since_last_eval count in Data card', async () => {
    mockFetch({ labeled_since_last_eval: 42 })
    const wrapper = await mountDashboard()
    expect(stageCard(wrapper, 'data').text()).toContain('42')
  })

  it('does NOT show Run Eval CTA when data_to_eval is false', async () => {
    mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false } })
    const wrapper = await mountDashboard()
    const cta = stageCard(wrapper, 'data').find('.cta-btn')
    expect(cta.exists()).toBe(false)
  })

  it('shows Run Eval CTA when data_to_eval is true', async () => {
    mockFetch({ signals: { data_to_eval: true, eval_to_train: false, train_to_fleet: false } })
    const wrapper = await mountDashboard()
    const card = stageCard(wrapper, 'data')
    expect(card.find('.cta-btn').exists()).toBe(true)
    expect(card.find('.cta-btn').text()).toContain('Run Eval')
  })

  it('shows Queue Finetune CTA when eval_to_train is true', async () => {
    mockFetch({ signals: { data_to_eval: false, eval_to_train: true, train_to_fleet: false } })
    const wrapper = await mountDashboard()
    const cta = stageCard(wrapper, 'eval').find('.cta-btn')
    expect(cta.text()).toContain('Queue Finetune')
  })

  it('shows Register in Fleet CTA when train_to_fleet is true', async () => {
    mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: true } })
    const wrapper = await mountDashboard()
    const cta = stageCard(wrapper, 'train').find('.cta-btn')
    expect(cta.text()).toContain('Register in Fleet')
  })

  it('shows active job status pills in Train card', async () => {
    mockFetch({ active_jobs: [{ id: 'j1', type: 'classifier', model_key: 'deberta-v3', status: 'running' }] })
    const wrapper = await mountDashboard()
    const card = stageCard(wrapper, 'train')
    expect(card.find('.status-pill').exists()).toBe(true)
    expect(card.text()).toContain('deberta-v3')
  })

  it('shows last eval score in Eval card when present', async () => {
    mockFetch({ last_eval_best_score: 0.821 })
    const wrapper = await mountDashboard()
    expect(stageCard(wrapper, 'eval').text()).toContain('82.1%')
  })

  it('shows error state when API call fails', async () => {
    // A non-ok response, not a rejected promise: exercises the HTTP-status branch.
    const failing = vi.fn().mockResolvedValue({ ok: false, status: 503, text: async () => '' })
    vi.stubGlobal('fetch', failing)
    const wrapper = await mountDashboard()
    expect(wrapper.find('.error-notice').exists()).toBe(true)
  })

  it('shows refresh button', async () => {
    const wrapper = await mountDashboard()
    expect(wrapper.find('.refresh-btn').exists()).toBe(true)
  })
})
|
||||
347
web/src/views/DashboardView.vue
Normal file
347
web/src/views/DashboardView.vue
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
<template>
|
||||
<div class="dashboard-view">
|
||||
<header class="dashboard-header">
|
||||
<h1 class="page-title">📊 Dashboard</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="load" aria-label="Refresh dashboard">
|
||||
🔄
|
||||
</button>
|
||||
</header>
|
||||
|
||||
<div v-if="loading && !data" class="loading-state">Loading…</div>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="load">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="data" class="flywheel-grid">
|
||||
|
||||
<!-- ① Data card -->
|
||||
<div class="stage-card" data-stage="data">
|
||||
<div class="card-header">
|
||||
<span class="card-step">①</span>
|
||||
<h2 class="card-title">Data</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p class="card-metric">
|
||||
<strong class="metric-value">{{ data.labeled_since_last_eval.toLocaleString() }}</strong>
|
||||
<span class="metric-label"> labeled since last eval</span>
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="data.signals.data_to_eval" class="card-cta">
|
||||
<RouterLink to="/eval/benchmark" class="cta-btn">Run Eval</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ② Eval card -->
|
||||
<div class="stage-card" data-stage="eval">
|
||||
<div class="card-header">
|
||||
<span class="card-step">②</span>
|
||||
<h2 class="card-title">Eval</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p class="card-metric">
|
||||
<span class="metric-label">Last run: </span>
|
||||
<strong class="metric-value">{{ formattedEvalTime }}</strong>
|
||||
</p>
|
||||
<p v-if="data.last_eval_best_score != null" class="card-metric">
|
||||
<span class="metric-label">Best score: </span>
|
||||
<strong class="metric-value">{{ formatScore(data.last_eval_best_score) }}</strong>
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="data.signals.eval_to_train" class="card-cta">
|
||||
<RouterLink to="/train/jobs" class="cta-btn">Queue Finetune</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ③ Train card -->
|
||||
<div class="stage-card" data-stage="train">
|
||||
<div class="card-header">
|
||||
<span class="card-step">③</span>
|
||||
<h2 class="card-title">Train</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<template v-if="data.active_jobs.length > 0">
|
||||
<div
|
||||
v-for="job in data.active_jobs"
|
||||
:key="job.id"
|
||||
class="job-row"
|
||||
>
|
||||
<span class="job-key">{{ job.model_key }}</span>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<p v-else class="card-metric metric-muted">No active jobs</p>
|
||||
|
||||
<p v-if="data.corrections_export_ready > 0" class="card-metric">
|
||||
<strong class="metric-value">{{ data.corrections_export_ready }}</strong>
|
||||
<span class="metric-label"> corrections ready</span>
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="data.signals.train_to_fleet" class="card-cta">
|
||||
<RouterLink to="/fleet" class="cta-btn">Register in Fleet</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
/**
 * One entry of `DashboardData.active_jobs`, rendered as a row in the Train card.
 */
interface ActiveJob {
|
||||
// Used as the v-for key for the job row.
id: string
|
||||
type: string
|
||||
// Displayed as the row label.
model_key: string
|
||||
// Displayed in the pill and appended to its CSS class as `status-<value>`.
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
}
|
||||
|
||||
/**
 * Flags that decide which stage card shows its call-to-action link.
 */
interface DashboardSignals {
|
||||
// Data card: shows the "Run Eval" link to /eval/benchmark.
data_to_eval: boolean
|
||||
// Eval card: shows the "Queue Finetune" link to /train/jobs.
eval_to_train: boolean
|
||||
// Train card: shows the "Register in Fleet" link to /fleet.
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
/**
 * Shape the `/api/dashboard` JSON response is cast to by `load()`.
 * The cast is unchecked; nothing here validates the payload at runtime.
 */
interface DashboardData {
|
||||
labeled_since_last_eval: number
|
||||
// Parsed with `new Date(...)`; null/empty renders as "Never".
last_eval_timestamp: string | null
|
||||
// Rendered as a percentage (value * 100); hidden when null.
last_eval_best_score: number | null
|
||||
active_jobs: ActiveJob[]
|
||||
// Only displayed when greater than zero.
corrections_export_ready: number
|
||||
signals: DashboardSignals
|
||||
}
|
||||
|
||||
// Last successfully loaded payload; stays null until a load succeeds.
const data = ref<DashboardData | null>(null)
|
||||
// True while load() is in flight; the refresh button is disabled meanwhile.
const loading = ref(false)
|
||||
// Message for the error notice; reset to null at the start of every load().
const error = ref<string | null>(null)
|
||||
|
||||
// Human-readable age of the last eval run: "Never" when there is no timestamp,
// "Unknown" when it cannot be parsed, otherwise a coarse relative time.
const formattedEvalTime = computed(() => {
  const stamp = data.value?.last_eval_timestamp
  if (!stamp) return 'Never'

  const evalMs = new Date(stamp).getTime()
  if (isNaN(evalMs)) return 'Unknown'

  // Whole minutes elapsed; anything under a minute (including a timestamp
  // slightly in the future) reads as "just now".
  const elapsedMinutes = Math.floor((Date.now() - evalMs) / 60000)
  if (elapsedMinutes < 1) return 'just now'
  if (elapsedMinutes < 60) return `${elapsedMinutes}m ago`

  const elapsedHours = Math.floor(elapsedMinutes / 60)
  if (elapsedHours < 24) return `${elapsedHours}h ago`

  const elapsedDays = Math.floor(elapsedHours / 24)
  return `${elapsedDays}d ago`
})
|
||||
|
||||
/** Render a 0–1 score as a percentage with one decimal place, e.g. 0.821 -> "82.1%". */
function formatScore(score: number): string {
  const percent = score * 100
  return percent.toFixed(1) + '%'
}
|
||||
|
||||
/**
 * Fetch `/api/dashboard` and publish the result.
 *
 * On a non-ok HTTP status the status code is reported via `error`; any thrown
 * failure (request or body parsing) is reported as a network error. Previously
 * loaded `data` is left untouched on failure. `loading` is always cleared.
 */
async function load() {
  loading.value = true
  error.value = null
  try {
    const response = await fetch('/api/dashboard')
    if (response.ok) {
      const payload = await response.json() as DashboardData
      data.value = payload
    } else {
      error.value = `Could not load dashboard (HTTP ${response.status}).`
    }
  } catch {
    error.value = 'Network error. Is the Avocet API running?'
  } finally {
    loading.value = false
  }
}

// Kick off the first load as soon as the component is mounted.
onMounted(load)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.dashboard-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.dashboard-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
/* ── Flywheel grid ── */
|
||||
.flywheel-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
@media (max-width: 680px) {
|
||||
.flywheel-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── Stage cards ── */
|
||||
.stage-card {
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-lg, 1rem);
|
||||
padding: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
padding-bottom: 0.6rem;
|
||||
}
|
||||
|
||||
.card-step {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.card-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.card-metric {
|
||||
margin: 0;
|
||||
font-size: 0.875rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 1.05rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.metric-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.card-cta { margin-top: auto; }
|
||||
|
||||
.cta-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
text-align: center;
|
||||
padding: 0.5rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 0.375rem;
|
||||
text-decoration: none;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.cta-btn:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 85%, black); }
|
||||
|
||||
/* ── Job pills ── */
|
||||
.job-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.job-key {
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.75rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: 100px;
|
||||
font-weight: 600;
|
||||
flex-shrink: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.status-pill.status-running { background: #d4f4e0; color: #1a7a3a; }
|
||||
.status-pill.status-queued { background: #fef3cd; color: #856404; }
|
||||
.status-pill.status-failed { background: #fde8e8; color: #842029; }
|
||||
.status-pill.status-completed { background: #e0f0ff; color: #0c5481; }
|
||||
|
||||
/* ── State indicators ── */
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
background: #fde8e8;
|
||||
color: #842029;
|
||||
border: 1px solid #f5c2c7;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.75rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
background: transparent;
|
||||
border: 1px solid currentColor;
|
||||
border-radius: 0.25rem;
|
||||
color: inherit;
|
||||
cursor: pointer;
|
||||
font-size: 0.75rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -6,6 +6,8 @@
|
|||
<summary class="picker-summary">
|
||||
<span class="picker-title">📋 Task Selection</span>
|
||||
<span class="picker-badge">{{ llmTaskBadge }}</span>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="selectAllTasks()">All</button>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="clearAllTasks()">None</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks…</div>
|
||||
|
|
@ -44,6 +46,8 @@
|
|||
<summary class="picker-summary">
|
||||
<span class="picker-title">🎯 Model Selection</span>
|
||||
<span class="picker-badge">{{ llmModelBadge }}</span>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="selectAllModels()">All</button>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="clearAllModels()">None</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmModelsLoading" class="picker-loading">Loading models…</div>
|
||||
|
|
@ -78,6 +82,33 @@
|
|||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Node Selection -->
|
||||
<div class="node-picker" v-if="llmNodes.length > 0">
|
||||
<span class="node-picker-label">Nodes:</span>
|
||||
<label
|
||||
v-for="node in llmNodes"
|
||||
:key="node.node_id"
|
||||
class="node-chip"
|
||||
:class="{ 'node-chip--off': !enabledNodes.has(node.node_id), 'node-chip--offline': !node.online }"
|
||||
:title="node.online ? `${node.node_id} — ${node.gpus.length} GPU(s)` : `${node.node_id} — offline`"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
class="node-chip-check"
|
||||
:checked="enabledNodes.has(node.node_id)"
|
||||
:disabled="!node.online || llmRunning"
|
||||
@change="toggleNode(node.node_id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
{{ node.node_id }}
|
||||
<span class="node-chip-status" v-if="!node.online">offline</span>
|
||||
</label>
|
||||
<!--
  Routing hint shown next to the node chips.
  startLlmBenchmark only sends `node_ids` when the selection is a non-empty
  strict subset of the online nodes, so an empty selection is still
  auto-routed; say so instead of printing "restricted to:" with no ids.
-->
<span class="node-picker-hint">
  {{ enabledNodeIds.length === llmNodes.filter(n => n.online).length
    ? 'auto-routing (all nodes)'
    : enabledNodeIds.length === 0
      ? 'none selected — auto-routing (all nodes)'
      : `restricted to: ${enabledNodeIds.join(', ')}` }}
</span>
|
||||
</div>
|
||||
|
||||
<!-- Run Controls -->
|
||||
<div class="run-controls">
|
||||
<button
|
||||
|
|
@ -88,6 +119,24 @@
|
|||
{{ llmRunning ? '⏳ Running…' : '▶ Run LLM Eval' }}
|
||||
</button>
|
||||
<button v-if="llmRunning" class="btn-cancel" @click="cancelLlmBenchmark">✕ Cancel</button>
|
||||
<input
|
||||
v-model="llmJudgeUrl"
|
||||
class="judge-url-input"
|
||||
placeholder="Judge URL — leave empty to skip LLM judge scoring"
|
||||
:disabled="llmRunning"
|
||||
title="Optional: URL of a running cf-text service (e.g. http://10.1.10.158:8008). When set, each LLM response gets a secondary score from the judge model — adds a 'judge' column to results. Empty = primary quality scoring only."
|
||||
/>
|
||||
<label class="workers-label" title="Run this many models concurrently (requires multiple GPUs)">
|
||||
<span class="workers-prefix">workers</span>
|
||||
<input
|
||||
v-model.number="llmWorkers"
|
||||
type="number"
|
||||
min="1"
|
||||
max="8"
|
||||
class="workers-input"
|
||||
:disabled="llmRunning"
|
||||
/>
|
||||
</label>
|
||||
<span v-if="selectedLlmTasks.size === 0 || selectedLlmModels.size === 0" class="run-hint">
|
||||
Select at least one task and one model to run.
|
||||
</span>
|
||||
|
|
@ -119,6 +168,7 @@
|
|||
<tr>
|
||||
<th class="hm-label-col">Model</th>
|
||||
<th class="hm-model-col">overall</th>
|
||||
<th v-if="llmHasJudge" class="hm-model-col hm-judge-col">judge</th>
|
||||
<th v-for="col in llmTaskTypeCols" :key="col" class="hm-model-col">{{ col }}</th>
|
||||
<th class="hm-model-col">tok/s</th>
|
||||
</tr>
|
||||
|
|
@ -130,6 +180,12 @@
|
|||
class="hm-value-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
|
||||
>{{ pct(row.avg_quality_score) }}</td>
|
||||
<td
|
||||
v-if="llmHasJudge"
|
||||
class="hm-value-cell hm-judge-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['judge'] === row.model_id }"
|
||||
title="LLM-as-judge secondary score"
|
||||
>{{ row.avg_judge_score != null ? pct(row.avg_judge_score) : '—' }}</td>
|
||||
<td
|
||||
v-for="col in llmTaskTypeCols"
|
||||
:key="col"
|
||||
|
|
@ -168,6 +224,12 @@ interface CfOrchModel {
|
|||
vram_estimate_mb?: number
|
||||
}
|
||||
|
||||
/**
 * One entry of the `nodes` array loaded from `/api/cforch/nodes`.
 *
 * - `node_id`: identifier used as the chip key and sent in the `node_ids`
 *   run parameter.
 * - `online`: offline nodes get a disabled checkbox and are left out of the
 *   default selection.
 * - `gpus`: per-GPU inventory; only its length is shown in the chip tooltip.
 *   The `vram_*_mb` fields are presumably megabytes (confirm against the API).
 */
interface CfOrchNode {
|
||||
node_id: string
|
||||
online: boolean
|
||||
gpus: { gpu_id: number; name: string; vram_total_mb: number; vram_free_mb: number }[]
|
||||
}
|
||||
|
||||
interface LlmModelResult {
|
||||
model_name: string
|
||||
model_id: string
|
||||
|
|
@ -175,9 +237,11 @@ interface LlmModelResult {
|
|||
avg_tokens_per_sec: number
|
||||
avg_completion_ms: number
|
||||
avg_quality_score: number
|
||||
avg_judge_score: number | null
|
||||
finetune_candidates: number
|
||||
error_count: number
|
||||
quality_by_task_type: Record<string, number>
|
||||
judge_score_by_task_type?: Record<string, number>
|
||||
}
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
|
|
@ -195,6 +259,10 @@ const llmError = ref('')
|
|||
const llmResults = ref<LlmModelResult[]>([])
|
||||
const llmEventSource = ref<EventSource | null>(null)
|
||||
const llmLogEl = ref<HTMLElement | null>(null)
|
||||
const llmJudgeUrl = ref('')
|
||||
const llmWorkers = ref(1)
|
||||
const llmNodes = ref<CfOrchNode[]>([])
|
||||
const enabledNodes = ref<Set<string>>(new Set())
|
||||
|
||||
// ── Computed ────────────────────────────────────────────────────────────────
|
||||
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
|
||||
|
|
@ -239,6 +307,14 @@ const llmTaskTypeCols = computed(() => {
|
|||
return [...types].sort()
|
||||
})
|
||||
|
||||
// True as soon as one result row carries a judge score; gates the "judge" column.
const llmHasJudge = computed(() => {
  for (const row of llmResults.value) {
    if (row.avg_judge_score != null) return true
  }
  return false
})
|
||||
|
||||
// Ids of nodes that are both online and currently checked, in list order.
const enabledNodeIds = computed(() => {
  const ids: string[] = []
  for (const node of llmNodes.value) {
    if (node.online && enabledNodes.value.has(node.node_id)) {
      ids.push(node.node_id)
    }
  }
  return ids
})
|
||||
|
||||
const llmBestByCol = computed((): Record<string, string> => {
|
||||
const best: Record<string, string> = {}
|
||||
if (llmResults.value.length === 0) return best
|
||||
|
|
@ -249,6 +325,16 @@ const llmBestByCol = computed((): Record<string, string> => {
|
|||
}
|
||||
best['overall'] = bestId
|
||||
|
||||
if (llmHasJudge.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
if (r.avg_judge_score != null && r.avg_judge_score > bestVal) {
|
||||
bestVal = r.avg_judge_score; bestId = r.model_id
|
||||
}
|
||||
}
|
||||
best['judge'] = bestId
|
||||
}
|
||||
|
||||
for (const col of llmTaskTypeCols.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
|
|
@ -306,6 +392,15 @@ function toggleService(models: CfOrchModel[], checked: boolean) {
|
|||
}
|
||||
selectedLlmModels.value = next
|
||||
}
|
||||
/** Select every loaded task. */
function selectAllTasks() {
  const allTaskIds = llmTasks.value.map(task => task.id)
  selectedLlmTasks.value = new Set(allTaskIds)
}

/** Clear the task selection. */
function clearAllTasks() {
  selectedLlmTasks.value = new Set()
}

/** Select every loaded model. */
function selectAllModels() {
  const allModelIds = llmModels.value.map(model => model.id)
  selectedLlmModels.value = new Set(allModelIds)
}

/** Clear the model selection. */
function clearAllModels() {
  selectedLlmModels.value = new Set()
}
|
||||
/**
 * Add or remove a node from the enabled set.
 *
 * A fresh Set is assigned rather than mutating the current one in place.
 */
function toggleNode(id: string, checked: boolean) {
  const updated = new Set(enabledNodes.value)
  if (checked) {
    updated.add(id)
  } else {
    updated.delete(id)
  }
  enabledNodes.value = updated
}
|
||||
|
||||
// ── Data loaders ─────────────────────────────────────────────────────────────
|
||||
async function loadLlmTasks() {
|
||||
|
|
@ -335,6 +430,21 @@ async function loadLlmResults() {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Prefill the judge URL input from `/api/cforch/config`.
 *
 * Does nothing when the config has no `judge_url` or the input already
 * holds a value.
 */
async function loadLlmConfig() {
  const { data: config } = await useApiFetch<{ judge_url?: string }>('/api/cforch/config')
  const configuredUrl = config?.judge_url
  if (!configuredUrl) return
  if (llmJudgeUrl.value) return
  llmJudgeUrl.value = configuredUrl
}
|
||||
|
||||
/**
 * Load the node list from `/api/cforch/nodes` and enable every online node.
 *
 * State is left untouched when the response carries no `nodes` array.
 */
async function loadLlmNodes() {
  const { data: payload } = await useApiFetch<{ nodes: CfOrchNode[] }>('/api/cforch/nodes')
  const nodeList = payload?.nodes
  if (!nodeList) return

  llmNodes.value = nodeList
  const onlineIds = nodeList.filter(node => node.online).map(node => node.node_id)
  enabledNodes.value = new Set(onlineIds)
}
|
||||
|
||||
// ── Run / cancel ──────────────────────────────────────────────────────────────
|
||||
function startLlmBenchmark() {
|
||||
llmRunning.value = true
|
||||
|
|
@ -344,6 +454,15 @@ function startLlmBenchmark() {
|
|||
const params = new URLSearchParams()
|
||||
const taskIds = [...selectedLlmTasks.value].join(',')
|
||||
if (taskIds) params.set('task_ids', taskIds)
|
||||
const modelIds = [...selectedLlmModels.value].join(',')
|
||||
if (modelIds) params.set('model_ids', modelIds)
|
||||
if (llmJudgeUrl.value.trim()) params.set('judge_url', llmJudgeUrl.value.trim())
|
||||
if (llmWorkers.value > 1) params.set('workers', String(llmWorkers.value))
|
||||
const onlineNodeIds = llmNodes.value.filter(n => n.online).map(n => n.node_id)
|
||||
const isRestricted = enabledNodeIds.value.length < onlineNodeIds.length
|
||||
if (isRestricted && enabledNodeIds.value.length > 0) {
|
||||
params.set('node_ids', enabledNodeIds.value.join(','))
|
||||
}
|
||||
|
||||
const es = new EventSource(`/api/cforch/run?${params}`)
|
||||
llmEventSource.value = es
|
||||
|
|
@ -387,6 +506,8 @@ onMounted(() => {
|
|||
loadLlmTasks()
|
||||
loadLlmModels()
|
||||
loadLlmResults()
|
||||
loadLlmConfig()
|
||||
loadLlmNodes()
|
||||
})
|
||||
</script>
|
||||
|
||||
|
|
@ -451,6 +572,43 @@ onMounted(() => {
|
|||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.judge-url-input {
|
||||
flex: 1;
|
||||
min-width: 14rem;
|
||||
max-width: 24rem;
|
||||
padding: 0.35rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
.judge-url-input:disabled { opacity: 0.5; }
|
||||
.judge-url-input::placeholder { color: var(--color-text-secondary, #6b7a99); }
|
||||
|
||||
.workers-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
white-space: nowrap;
|
||||
}
|
||||
.workers-prefix { font-family: var(--font-mono, monospace); }
|
||||
.workers-input {
|
||||
width: 3.2rem;
|
||||
padding: 0.35rem 0.4rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
text-align: center;
|
||||
}
|
||||
.workers-input:disabled { opacity: 0.5; }
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
@ -592,6 +750,15 @@ onMounted(() => {
|
|||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.hm-judge-col {
|
||||
background: color-mix(in srgb, var(--color-surface-raised, #e4ebf5) 80%, #c6d5f5);
|
||||
}
|
||||
.hm-judge-cell {
|
||||
background: color-mix(in srgb, var(--color-surface, #fff) 85%, #c6d5f5);
|
||||
font-style: italic;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
/* ── Model Picker ───────────────────────────────────────── */
|
||||
.model-picker {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
@ -630,6 +797,24 @@ details[open] .picker-summary::before { content: '▼ '; }
|
|||
margin-left: auto;
|
||||
}
|
||||
|
||||
.picker-bulk-btn {
|
||||
padding: 0.1rem 0.45rem;
|
||||
font-size: 0.7rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.12s, color 0.12s;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.picker-bulk-btn:hover {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.picker-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
@ -712,4 +897,61 @@ details[open] .picker-summary::before { content: '▼ '; }
|
|||
.picker-model-list { padding-left: 0; }
|
||||
.picker-model-name { max-width: 14ch; }
|
||||
}
|
||||
|
||||
/* ── Node picker ────────────────────────────────────── */
|
||||
.node-picker {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
padding: 0.5rem 0.75rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
|
||||
.node-picker-label {
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.node-chip {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 1rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
color: var(--color-text, #1a2338);
|
||||
cursor: pointer;
|
||||
transition: background 0.12s, opacity 0.12s;
|
||||
}
|
||||
.node-chip--off {
|
||||
opacity: 0.45;
|
||||
background: transparent;
|
||||
}
|
||||
.node-chip--offline {
|
||||
opacity: 0.35;
|
||||
cursor: not-allowed;
|
||||
font-style: italic;
|
||||
}
|
||||
.node-chip-check { accent-color: var(--app-primary, #2A6080); }
|
||||
.node-chip-status {
|
||||
font-size: 0.66rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.node-picker-hint {
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -51,8 +51,31 @@
|
|||
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
|
||||
{{ lookupResult.adapter_recommendation }}
|
||||
</span>
|
||||
<span v-if="lookupResult.size != null" class="preview-size">
|
||||
{{ humanBytes(lookupResult.size) }}
|
||||
<span v-if="selectedQuantSize > 0" class="preview-size">
|
||||
{{ humanBytes(selectedQuantSize) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- GGUF quantization picker — only shown for GGUF repos -->
|
||||
<div v-if="lookupResult.gguf_files?.length" class="quant-picker">
|
||||
<label class="quant-label" for="quant-select">Quantization</label>
|
||||
<select
|
||||
id="quant-select"
|
||||
v-model="selectedQuant"
|
||||
class="quant-select"
|
||||
aria-label="Select quantization variant"
|
||||
>
|
||||
<option :value="null" disabled>Select quantization…</option>
|
||||
<option
|
||||
v-for="f in lookupResult.gguf_files"
|
||||
:key="f.filename"
|
||||
:value="f.quant_name ?? f.filename"
|
||||
>
|
||||
{{ f.quant_name ?? f.filename }} — {{ humanBytes(f.size) }}
|
||||
</option>
|
||||
</select>
|
||||
<span class="quant-hint">
|
||||
Q5_K_M or Q6_K recommended for 8 GB GPUs. Q8_0 for max quality.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
|
@ -67,7 +90,7 @@
|
|||
|
||||
<button
|
||||
class="btn-primary btn-add-queue"
|
||||
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue"
|
||||
:disabled="!canAddToQueue"
|
||||
@click="addToQueue"
|
||||
>
|
||||
{{ addingToQueue ? 'Adding…' : 'Add to queue' }}
|
||||
|
|
@ -99,9 +122,39 @@
|
|||
<span v-if="model.role" class="chip chip-role">{{ model.role }}</span>
|
||||
<span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span>
|
||||
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
|
||||
<span v-if="model.quant_pattern" class="chip chip-quant">{{ model.quant_pattern }}</span>
|
||||
</div>
|
||||
<!-- Allow manual service/role assignment for unrecognized pipeline tags -->
|
||||
<div v-if="!model.service" class="classify-row queue-classify">
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.id]?.service ?? ''"
|
||||
@change="onServiceChange(model.id, ($event.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign service"
|
||||
>
|
||||
<option value="" disabled>Service…</option>
|
||||
<option v-for="svc in CLASSIFIABLE_SERVICES" :key="svc.value" :value="svc.value">{{ svc.label }}</option>
|
||||
</select>
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.id]?.role ?? ''"
|
||||
:disabled="!classifyDraft[model.id]?.service"
|
||||
@change="(e) => setClassifyRole(model.id, (e.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign role"
|
||||
>
|
||||
<option value="" disabled>Role…</option>
|
||||
<option
|
||||
v-for="role in rolesForService(classifyDraft[model.id]?.service ?? '')"
|
||||
:key="role"
|
||||
:value="role"
|
||||
>{{ role }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="model-card-actions">
|
||||
<button class="btn-primary btn-sm" @click="approveModel(model.id)">
|
||||
<button
|
||||
class="btn-primary btn-sm"
|
||||
@click="approveModel(model.id, classifyDraft[model.id])"
|
||||
>
|
||||
Approve download
|
||||
</button>
|
||||
</div>
|
||||
|
|
@ -252,6 +305,12 @@ import { ref, computed, onMounted, onUnmounted } from 'vue'
|
|||
|
||||
// ── Type definitions ──────────────────────────────────
|
||||
|
||||
/**
 * One GGUF file offered by a looked-up repo (`LookupResult.gguf_files`).
 *
 * - `filename`: used as the option key, and as label/value when `quant_name`
 *   is null.
 * - `size`: rendered through `humanBytes`, so presumably a byte count.
 * - `quant_name`: quantization label shown in the picker; null when none is
 *   available for the file.
 */
interface GgufFile {
|
||||
filename: string
|
||||
size: number
|
||||
quant_name: string | null
|
||||
}
|
||||
|
||||
interface LookupResult {
|
||||
repo_id: string
|
||||
pipeline_tag: string | null
|
||||
|
|
@ -260,7 +319,8 @@ interface LookupResult {
|
|||
service: string | null
|
||||
compatible: boolean
|
||||
warning: string | null
|
||||
size: number | null
|
||||
model_size_bytes: number
|
||||
gguf_files: GgufFile[] | null
|
||||
description: string | null
|
||||
already_installed: boolean
|
||||
already_queued: boolean
|
||||
|
|
@ -274,6 +334,7 @@ interface QueuedModel {
|
|||
adapter_recommendation: string | null
|
||||
role: string | null
|
||||
service: string | null
|
||||
quant_pattern: string | null
|
||||
}
|
||||
|
||||
interface InstalledModel {
|
||||
|
|
@ -302,6 +363,26 @@ const lookupLoading = ref(false)
|
|||
const lookupError = ref<string | null>(null)
|
||||
const lookupResult = ref<LookupResult | null>(null)
|
||||
const addingToQueue = ref(false)
|
||||
const selectedQuant = ref<string | null>(null)
|
||||
|
||||
// Size to show and queue: the chosen GGUF file's size when a quantization is
// picked, otherwise the repo's total model size (0 while nothing is looked up).
const selectedQuantSize = computed<number>(() => {
  const result = lookupResult.value
  if (!result) return 0

  const files = result.gguf_files
  if (files?.length && selectedQuant.value) {
    // Options are keyed by quant name, falling back to the filename.
    const chosen = files.find(file => (file.quant_name ?? file.filename) === selectedQuant.value)
    return chosen?.size ?? result.model_size_bytes
  }
  return result.model_size_bytes
})
|
||||
|
||||
// "Add to queue" is enabled only when there is a lookup result that is neither
// installed nor queued, no add is in flight, and — for GGUF repos — a
// quantization has been chosen.
const canAddToQueue = computed(() => {
  const result = lookupResult.value
  if (!result) return false

  const blocked = result.already_installed || result.already_queued || addingToQueue.value
  if (blocked) return false

  const needsQuant = Boolean(result.gguf_files?.length)
  return !(needsQuant && !selectedQuant.value)
})
|
||||
|
||||
const queuedModels = ref<QueuedModel[]>([])
|
||||
const installedModels = ref<InstalledModel[]>([])
|
||||
|
|
@ -411,6 +492,7 @@ async function doLookup() {
|
|||
lookupLoading.value = true
|
||||
lookupError.value = null
|
||||
lookupResult.value = null
|
||||
selectedQuant.value = null
|
||||
|
||||
try {
|
||||
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
|
||||
|
|
@ -442,7 +524,15 @@ async function addToQueue() {
|
|||
const res = await fetch('/api/models/queue', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ repo_id, pipeline_tag, adapter_recommendation, role, service }),
|
||||
body: JSON.stringify({
|
||||
repo_id,
|
||||
pipeline_tag,
|
||||
adapter_recommendation,
|
||||
role,
|
||||
service,
|
||||
model_size_bytes: selectedQuantSize.value,
|
||||
quant_pattern: selectedQuant.value,
|
||||
}),
|
||||
})
|
||||
if (res.ok) {
|
||||
lookupResult.value = { ...lookupResult.value, already_queued: true }
|
||||
|
|
@ -454,8 +544,16 @@ async function addToQueue() {
|
|||
}
|
||||
}
|
||||
|
||||
async function approveModel(id: string) {
|
||||
async function approveModel(id: string, draft?: { service: string; role: string }) {
|
||||
try {
|
||||
// If the user picked a service/role for an unrecognized model, patch it first.
|
||||
if (draft?.service && draft?.role) {
|
||||
await fetch(`/api/models/queue/${encodeURIComponent(id)}`, {
|
||||
method: 'PATCH',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ service: draft.service, role: draft.role }),
|
||||
})
|
||||
}
|
||||
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
|
||||
if (res.ok) {
|
||||
await loadQueue()
|
||||
|
|
@ -774,6 +872,44 @@ onUnmounted(() => {
|
|||
align-self: flex-start;
|
||||
}
|
||||
|
||||
/* ── Quant picker ── */
|
||||
.quant-picker {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.quant-label {
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.quant-select {
|
||||
padding: 0.4rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.9rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.quant-hint {
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.chip-quant {
|
||||
background: color-mix(in srgb, var(--color-primary, #2A6080) 15%, transparent);
|
||||
color: var(--color-primary, #2A6080);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Model cards (queue + downloads) ── */
|
||||
.model-card {
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
|
|
|
|||
69
web/src/views/NodeManagementView.vue
Normal file
69
web/src/views/NodeManagementView.vue
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import NodeCard from '../components/nodes/NodeCard.vue'
|
||||
import type { NodeSummary } from '../types/nodes'
|
||||
|
||||
const nodes = ref<NodeSummary[]>([])
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
/**
 * Load the node list from `/api/nodes-mgmt/nodes` into `nodes`.
 *
 * Sets `loading` for the duration of the request and stores a message in
 * `error` when the request fails or returns a non-2xx status.
 */
async function fetchNodes() {
  loading.value = true
  error.value = ''
  try {
    const response = await fetch('/api/nodes-mgmt/nodes')
    if (response.ok) {
      nodes.value = (await response.json()) as NodeSummary[]
    } else {
      throw new Error(`HTTP ${response.status}`)
    }
  } catch (err) {
    const message = err instanceof Error ? err.message : 'Failed to load nodes'
    error.value = message
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
onMounted(fetchNodes)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<main class="nodes-page">
|
||||
<header class="nodes-header">
|
||||
<h1>Nodes</h1>
|
||||
<button class="btn-secondary" @click="fetchNodes" :disabled="loading">Refresh</button>
|
||||
</header>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading nodes...</span>
|
||||
</div>
|
||||
<div v-if="error" class="nodes-status nodes-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && nodes.length === 0" class="nodes-status">
|
||||
No nodes found. Check <code>coordinator_url</code> in config.
|
||||
</div>
|
||||
<div v-else-if="!loading" class="nodes-grid">
|
||||
<NodeCard
|
||||
v-for="node in nodes"
|
||||
:key="node.node_id"
|
||||
:node="node"
|
||||
@updated="fetchNodes"
|
||||
/>
|
||||
</div>
|
||||
</main>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.nodes-page { padding: 1.5rem; }
|
||||
.nodes-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
.nodes-header h1 { margin: 0; font-size: 1.5rem; }
|
||||
.nodes-grid { display: flex; flex-direction: column; gap: 1.5rem; }
|
||||
.nodes-status {
|
||||
color: var(--text-secondary, #888);
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
}
|
||||
.nodes-error { color: var(--color-error, #fc8181); }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
</style>
|
||||
1043
web/src/views/PlansBenchTab.vue
Normal file
1043
web/src/views/PlansBenchTab.vue
Normal file
File diff suppressed because it is too large
Load diff
161
web/src/views/TrainJobsView.test.ts
Normal file
161
web/src/views/TrainJobsView.test.ts
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import TrainJobsView from './TrainJobsView.vue'
|
||||
|
||||
// Baseline fixture: a queued classifier job. The default fetch stub installed in
// beforeEach serves it as the only job; individual tests override fields by spreading.
const sampleJob = {
|
||||
id: 'job-abc123',
|
||||
type: 'classifier',
|
||||
model_key: 'deberta-v3-small',
|
||||
status: 'queued',
|
||||
created_at: '2026-05-01T10:00:00Z',
|
||||
config: null,
|
||||
}
|
||||
|
||||
/**
 * Build a `fetch` mock that dispatches on the HTTP method only.
 *
 * POST resolves to a newly queued job, DELETE to an empty object, and any
 * other method (GET by default) to `{ jobs }`. Every response reports
 * `ok: true` and an empty text body.
 */
function makeFetch(jobs: unknown[] = []) {
  // Response-like stub exposing ok/json/text; the payload is rebuilt on each json() call.
  const okResponse = (buildPayload: () => unknown) =>
    Promise.resolve({
      ok: true,
      json: async () => buildPayload(),
      text: async () => '',
    })

  return vi.fn().mockImplementation((url: string, opts?: RequestInit) => {
    const method = opts?.method ?? 'GET'
    switch (method) {
      case 'POST':
        return okResponse(() => ({ ...sampleJob, id: 'new-job', status: 'queued' }))
      case 'DELETE':
        return okResponse(() => ({}))
      default:
        // GET
        return okResponse(() => ({ jobs }))
    }
  })
}
|
||||
|
||||
/**
 * Inert stand-in for the browser `EventSource`.
 *
 * It records the URL, never emits events, and `close()` does nothing.
 */
class MockEventSource {
  onmessage: ((e: MessageEvent) => void) | null = null
  onerror: ((e: Event) => void) | null = null
  private _url: string

  constructor(url: string) {
    this._url = url
  }

  close() {
    // no-op: there is no connection to tear down
  }
}
|
||||
|
||||
// Default environment for every test: fetch serves a single queued job and
// EventSource is replaced by the inert mock. Tests that need other responses
// re-stub `fetch` before mounting.
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', makeFetch([sampleJob]))
|
||||
vi.stubGlobal('EventSource', MockEventSource)
|
||||
})
|
||||
|
||||
describe('TrainJobsView', () => {
  // Mount the view and wait for pending promises to flush before asserting.
  async function mountSettled() {
    const wrapper = mount(TrainJobsView)
    await flushPromises()
    return wrapper
  }

  type ViewWrapper = Awaited<ReturnType<typeof mountSettled>>

  // Current `disabled` state of the submit button.
  function submitDisabled(wrapper: ViewWrapper): boolean {
    const button = wrapper.find('button.submit-job-btn')
    return (button.element as HTMLButtonElement).disabled
  }

  it('renders page title "Training Jobs"', async () => {
    const view = await mountSettled()
    expect(view.find('h1.page-title').text()).toContain('Training Jobs')
  })

  it('renders the new job form with type selector and model key input', async () => {
    const view = await mountSettled()
    expect(view.find('select.job-type-select').exists()).toBe(true)
    expect(view.find('input.model-key-input').exists()).toBe(true)
    expect(view.find('button.submit-job-btn').exists()).toBe(true)
  })

  it('type selector has classifier and llm-sft options', async () => {
    const view = await mountSettled()
    const optionValues = view
      .findAll('select.job-type-select option')
      .map(o => o.attributes('value') ?? o.element.textContent)
    expect(optionValues).toContain('classifier')
    expect(optionValues).toContain('llm-sft')
  })

  it('submit button is disabled when model key is empty', async () => {
    const view = await mountSettled()
    expect(submitDisabled(view)).toBe(true)
  })

  it('submit button is enabled when model key is entered', async () => {
    const view = await mountSettled()
    await view.find('input.model-key-input').setValue('deberta-v3-small')
    expect(submitDisabled(view)).toBe(false)
  })

  it('shows job table with existing jobs', async () => {
    const view = await mountSettled()
    expect(view.find('table.jobs-table').exists()).toBe(true)
    expect(view.text()).toContain('deberta-v3-small')
  })

  it('shows status pill for each job', async () => {
    const view = await mountSettled()
    expect(view.find('.status-pill').exists()).toBe(true)
    expect(view.find('.status-queued').exists()).toBe(true)
  })

  it('shows cancel button for queued/running jobs', async () => {
    const view = await mountSettled()
    expect(view.find('button.cancel-btn').exists()).toBe(true)
  })

  it('submitting new job calls POST /api/train/jobs and refreshes', async () => {
    const fetchMock = makeFetch([])
    vi.stubGlobal('fetch', fetchMock)
    const view = await mountSettled()

    await view.find('input.model-key-input').setValue('my-model')
    await view.find('button.submit-job-btn').trigger('click')
    await flushPromises()

    const recorded = (fetchMock as ReturnType<typeof vi.fn>).mock.calls as [string, RequestInit?][]
    const postCall = recorded.find(([, opts]) => (opts?.method ?? 'GET') === 'POST')
    expect(postCall).toBeDefined()
    expect(postCall![0]).toContain('/api/train/jobs')
  })

  it('shows View Log button for running jobs', async () => {
    vi.stubGlobal('fetch', makeFetch([{ ...sampleJob, status: 'running' }]))
    const view = await mountSettled()
    expect(view.find('button.view-log-btn').exists()).toBe(true)
  })

  it('shows error when config JSON is invalid', async () => {
    const view = await mountSettled()
    await view.find('input.model-key-input').setValue('my-model')
    await view.find('textarea.config-textarea').setValue('{ not valid json }')
    await view.find('button.submit-job-btn').trigger('click')
    await flushPromises()

    expect(view.find('.error-notice').exists()).toBe(true)
    expect(view.find('.error-notice').text()).toContain('not valid')
  })

  it('shows error notice when jobs load fails', async () => {
    const failingFetch = vi.fn().mockResolvedValue({
      ok: false,
      status: 500,
      json: async () => ({}),
      text: async () => '',
    })
    vi.stubGlobal('fetch', failingFetch)
    const view = await mountSettled()

    expect(view.find('.error-notice').exists()).toBe(true)
    expect(view.find('table.jobs-table').exists()).toBe(false)
  })

  it('cancel button optimistically updates job status to cancelled', async () => {
    const view = await mountSettled()
    await view.find('button.cancel-btn').trigger('click')
    await flushPromises()

    // After cancel, job should show status-cancelled pill (not status-queued)
    expect(view.find('.status-cancelled').exists()).toBe(true)
    expect(view.find('.status-queued').exists()).toBe(false)
  })
})
|
||||
593
web/src/views/TrainJobsView.vue
Normal file
593
web/src/views/TrainJobsView.vue
Normal file
|
|
@ -0,0 +1,593 @@
|
|||
<template>
|
||||
<div class="train-jobs-view">
|
||||
<header class="view-header">
|
||||
<h1 class="page-title">🧠 Training Jobs</h1>
|
||||
</header>
|
||||
|
||||
<!-- New Job form -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">New Job</h2>
|
||||
<form class="new-job-form" @submit.prevent="submitJob">
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="job-type">Type</label>
|
||||
<select
|
||||
id="job-type"
|
||||
v-model="form.type"
|
||||
class="job-type-select form-control"
|
||||
>
|
||||
<option value="classifier">classifier</option>
|
||||
<option value="llm-sft">llm-sft</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="model-key">Model key</label>
|
||||
<input
|
||||
id="model-key"
|
||||
v-model.trim="form.model_key"
|
||||
type="text"
|
||||
class="model-key-input form-control"
|
||||
placeholder="e.g. microsoft/deberta-v3-small"
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="job-config">Config JSON <span class="form-hint">(optional)</span></label>
|
||||
<textarea
|
||||
id="job-config"
|
||||
v-model="form.config_raw"
|
||||
class="config-textarea form-control"
|
||||
rows="4"
|
||||
placeholder='{"learning_rate": 2e-5}'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="submitError" class="error-notice" role="alert">{{ submitError }}</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
class="submit-job-btn btn-primary"
|
||||
:disabled="submitting || !form.model_key"
|
||||
@click.prevent="submitJob"
|
||||
>
|
||||
{{ submitting ? 'Queuing…' : 'Queue Job' }}
|
||||
</button>
|
||||
</form>
|
||||
</section>
|
||||
|
||||
<!-- Job queue table -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">Job Queue</h2>
|
||||
|
||||
<div v-if="loadError" class="error-notice" role="alert">
|
||||
{{ loadError }}
|
||||
<button class="btn-retry" @click="loadJobs">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-else-if="jobs.length === 0" class="empty-notice">
|
||||
No training jobs yet.
|
||||
</div>
|
||||
|
||||
<div v-else class="jobs-table-wrap">
|
||||
<table class="jobs-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>Type</th>
|
||||
<th>Model</th>
|
||||
<th>Status</th>
|
||||
<th>Created</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="job in jobs" :key="job.id">
|
||||
<td class="td-id" :title="job.id">{{ job.id.slice(0, 8) }}</td>
|
||||
<td>
|
||||
<span class="type-chip">{{ job.type }}</span>
|
||||
</td>
|
||||
<td class="td-model">{{ job.model_key }}</td>
|
||||
<td>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</td>
|
||||
<td class="td-date">{{ formatDate(job.created_at) }}</td>
|
||||
<td class="td-actions">
|
||||
<button
|
||||
v-if="job.status === 'running'"
|
||||
class="view-log-btn btn-sm"
|
||||
@click="openLog(job.id)"
|
||||
>
|
||||
View Log
|
||||
</button>
|
||||
<button
|
||||
v-if="job.status === 'queued' || job.status === 'running'"
|
||||
class="cancel-btn btn-sm btn-danger-sm"
|
||||
:disabled="cancellingId === job.id"
|
||||
@click="cancelJob(job.id)"
|
||||
>
|
||||
{{ cancellingId === job.id ? '…' : 'Cancel' }}
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div v-if="cancelError" class="error-notice" role="alert">{{ cancelError }}</div>
|
||||
</section>
|
||||
|
||||
<!-- Log panel (SSE) -->
|
||||
<section v-if="logJobId" class="section log-section">
|
||||
<div class="log-header">
|
||||
<h2 class="section-title">Log — {{ logJobId.slice(0, 8) }}</h2>
|
||||
<button class="btn-close-log" @click="closeLog">✕ Close</button>
|
||||
</div>
|
||||
<div class="log-panel" ref="logPanelEl">
|
||||
<div
|
||||
v-for="(line, i) in logLines"
|
||||
:key="i"
|
||||
class="log-line"
|
||||
>{{ line }}</div>
|
||||
<div v-if="logLines.length === 0" class="log-line log-muted">Connecting…</div>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, nextTick, onUnmounted } from 'vue'
|
||||
import { useApiSSE } from '../composables/useApi'
|
||||
|
||||
interface TrainJob {
|
||||
id: string
|
||||
type: 'classifier' | 'llm-sft'
|
||||
model_key: string
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
created_at: string
|
||||
config: Record<string, unknown> | null
|
||||
}
|
||||
|
||||
const jobs = ref<TrainJob[]>([])
|
||||
const loadError = ref<string | null>(null)
|
||||
const submitError = ref<string | null>(null)
|
||||
const submitting = ref(false)
|
||||
const cancellingId = ref<string | null>(null)
|
||||
const cancelError = ref<string | null>(null)
|
||||
|
||||
const form = ref({
|
||||
type: 'classifier' as 'classifier' | 'llm-sft',
|
||||
model_key: '',
|
||||
config_raw: '',
|
||||
})
|
||||
|
||||
// ── Log panel state ──
|
||||
const logJobId = ref<string | null>(null)
|
||||
const logLines = ref<string[]>([])
|
||||
const logPanelEl = ref<HTMLElement | null>(null)
|
||||
let closeSSE: (() => void) | null = null
|
||||
|
||||
// ── Data loading ──
|
||||
|
||||
/**
 * Fetch the training-job list from `/api/train/jobs` and publish it to `jobs`.
 *
 * Clears `loadError` first. A non-2xx response or a thrown fetch/parse error
 * sets `loadError` to a user-facing message and leaves `jobs` untouched.
 */
async function loadJobs() {
  loadError.value = null
  try {
    const res = await fetch('/api/train/jobs')
    if (!res.ok) {
      loadError.value = `Failed to load jobs (HTTP ${res.status}).`
      return
    }
    const data = await res.json() as { jobs?: TrainJob[] } | null
    // Only accept a real array: a null body or a malformed `jobs` field would
    // otherwise either throw (and be misreported as a network error) or put a
    // non-array into the table's v-for.
    const list = data?.jobs
    jobs.value = Array.isArray(list) ? list : []
  } catch {
    loadError.value = 'Network error loading jobs.'
  }
}
|
||||
|
||||
// ── Submit ──
|
||||
|
||||
/**
 * Validate the "New Job" form and POST it to `/api/train/jobs`.
 *
 * Does nothing when no model key is entered. The optional config text must
 * parse as a JSON object; invalid JSON, scalars and arrays are rejected with a
 * message in `submitError` before any request is made. On success the created
 * job is prepended to `jobs` and the form is reset; on failure `submitError`
 * carries the HTTP status (plus the response body when one is available).
 * `submitting` is true for the duration of the call.
 */
async function submitJob() {
  if (!form.value.model_key) return
  submitError.value = null
  submitting.value = true

  let config: Record<string, unknown> | null = null
  const rawConfig = form.value.config_raw.trim()
  if (rawConfig) {
    let parsed: unknown
    try {
      parsed = JSON.parse(rawConfig)
    } catch {
      submitError.value = 'Config JSON is not valid. Fix it before submitting.'
      submitting.value = false
      return
    }
    // JSON.parse happily returns numbers, strings, booleans and arrays; none of
    // those is a usable key/value config. A literal `null` is treated the same
    // as leaving the field empty.
    if (parsed !== null) {
      if (typeof parsed !== 'object' || Array.isArray(parsed)) {
        submitError.value = 'Config JSON must be an object, e.g. {"learning_rate": 2e-5}.'
        submitting.value = false
        return
      }
      config = parsed as Record<string, unknown>
    }
  }

  try {
    const res = await fetch('/api/train/jobs', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        type: form.value.type,
        model_key: form.value.model_key,
        config_json: config,
      }),
    })
    if (!res.ok) {
      // Reading the body is best-effort; fall back to the bare status line.
      const detail = await res.text().catch(() => '')
      submitError.value = `Failed to queue job (HTTP ${res.status})${detail ? `: ${detail}` : '.'}`
      return
    }
    const newJob = await res.json() as TrainJob
    jobs.value = [newJob, ...jobs.value]
    form.value = { type: 'classifier', model_key: '', config_raw: '' }
  } catch {
    submitError.value = 'Network error submitting job.'
  } finally {
    submitting.value = false
  }
}
|
||||
|
||||
// ── Cancel ──
|
||||
|
||||
/**
 * Ask the API to cancel one job and reflect the result locally.
 *
 * While the request is in flight `cancellingId` holds the job id (the template
 * uses it to disable that row's button). A 2xx response flips the matching
 * entry in `jobs` to `cancelled`; anything else is reported via `cancelError`.
 */
async function cancelJob(id: string) {
  cancelError.value = null
  cancellingId.value = id
  try {
    const url = `/api/train/jobs/${encodeURIComponent(id)}/cancel`
    const response = await fetch(url, { method: 'DELETE' })
    if (!response.ok) {
      cancelError.value = `Failed to cancel job (HTTP ${response.status}).`
      return
    }
    // Replace the array rather than mutating the row in place.
    jobs.value = jobs.value.map((job) => {
      if (job.id !== id) return job
      return { ...job, status: 'cancelled' as const }
    })
  } catch {
    cancelError.value = 'Network error cancelling job.'
  } finally {
    cancellingId.value = null
  }
}
|
||||
|
||||
// ── Log SSE ──
|
||||
|
||||
/**
 * Open the log panel for a job and start streaming its events.
 *
 * Any previously open stream is closed first. Events of type `log`,
 * `progress` or `error` contribute their `message` as a line; an `error`
 * event additionally appends an end-of-stream marker. The function returned
 * by `useApiSSE` is kept in `closeSSE` so `closeLog` can stop the stream.
 *
 * NOTE(review): the second and third callbacks are presumably useApiSSE's
 * "complete" and "connection error" hooks, judging by the markers they write;
 * confirm against the composable.
 */
function openLog(id: string) {
  closeLog()
  logJobId.value = id
  logLines.value = []

  // Single place that appends a line and keeps the panel pinned to the bottom.
  // The scroll runs after nextTick so the new line is already in the DOM.
  const appendLine = (text: string) => {
    logLines.value = [...logLines.value, text]
    nextTick(() => {
      const panel = logPanelEl.value
      if (panel) {
        panel.scrollTop = panel.scrollHeight
      }
    })
  }

  closeSSE = useApiSSE(
    `/api/train/jobs/${encodeURIComponent(id)}/run`,
    (data) => {
      if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
        appendLine(String(data.message ?? ''))
      }
      if (data.type === 'error') {
        appendLine('--- stream ended with error ---')
      }
    },
    () => {
      appendLine('--- stream complete ---')
    },
    () => {
      appendLine('--- connection lost ---')
    },
  )
}
|
||||
|
||||
/**
 * Stop the active log stream (if any) and hide/reset the log panel.
 * Safe to call when no log is open.
 */
function closeLog() {
  if (closeSSE) {
    closeSSE()
  }
  closeSSE = null
  logLines.value = []
  logJobId.value = null
}
|
||||
|
||||
// ── Helpers ──
|
||||
|
||||
/**
 * Render a timestamp string as a short locale date + time.
 * Input that `Date` cannot parse is returned unchanged.
 */
function formatDate(iso: string): string {
  const parsed = new Date(iso)
  return Number.isNaN(parsed.getTime())
    ? iso
    : parsed.toLocaleString(undefined, { dateStyle: 'short', timeStyle: 'short' })
}
|
||||
|
||||
// ── Lifecycle ──
|
||||
|
||||
loadJobs()
|
||||
|
||||
onUnmounted(() => {
|
||||
closeSSE?.()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.train-jobs-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2rem;
|
||||
}
|
||||
|
||||
.view-header { display: flex; align-items: center; }
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
padding-bottom: 0.4rem;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.new-job-form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
max-width: 480px;
|
||||
}
|
||||
|
||||
.form-row {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.form-label {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.form-hint {
|
||||
font-weight: 400;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.form-control {
|
||||
padding: 0.45rem 0.65rem;
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.9rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
}
|
||||
|
||||
.form-control:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: -1px;
|
||||
}
|
||||
|
||||
.config-textarea {
|
||||
resize: vertical;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
padding: 0.4rem 0.9rem;
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
align-self: flex-start;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-primary:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-sm {
|
||||
padding: 0.2rem 0.55rem;
|
||||
font-size: 0.78rem;
|
||||
border-radius: 0.3rem;
|
||||
cursor: pointer;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
border: 1px solid;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.view-log-btn {
|
||||
border-color: var(--color-info, #1e6091);
|
||||
background: transparent;
|
||||
color: var(--color-info, #1e6091);
|
||||
}
|
||||
|
||||
.view-log-btn:hover {
|
||||
background: color-mix(in srgb, var(--color-info, #1e6091) 10%, transparent);
|
||||
}
|
||||
|
||||
.btn-danger-sm {
|
||||
border-color: var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
}
|
||||
|
||||
.btn-danger-sm:hover:not(:disabled) {
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
}
|
||||
|
||||
.btn-danger-sm:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.btn-retry {
|
||||
margin-left: 0.5rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
cursor: pointer;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
padding: 0.6rem 0.8rem;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
color: var(--color-error, #c0392b);
|
||||
font-size: 0.88rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.empty-notice {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px dashed var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
}
|
||||
|
||||
.jobs-table-wrap { overflow-x: auto; }
|
||||
|
||||
.jobs-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.jobs-table th {
|
||||
text-align: left;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.jobs-table td {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.td-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.td-model {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.td-date {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.td-actions {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.68rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: var(--radius-full, 9999px);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.status-queued { background: var(--color-surface-alt, #dde4f0); color: var(--color-text-muted, #4a5c7a); }
|
||||
.status-running { background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent); color: var(--color-info, #1e6091); }
|
||||
.status-completed { background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent); color: var(--color-success, #3a7a32); }
|
||||
.status-failed { background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent); color: var(--color-error, #c0392b); }
|
||||
.status-cancelled { background: color-mix(in srgb, var(--color-warning, #d4891a) 15%, transparent); color: var(--color-warning, #d4891a); }
|
||||
|
||||
.type-chip {
|
||||
font-size: 0.72rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 0.25rem;
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.log-section { gap: 0.5rem; }
|
||||
|
||||
.log-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.btn-close-log {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.8rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.btn-close-log:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
|
||||
.log-panel {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
max-height: 320px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.log-line {
|
||||
color: var(--color-text, #1a2338);
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.log-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
@media (max-width: 560px) {
|
||||
.jobs-table th:nth-child(4),
|
||||
.jobs-table td:nth-child(4),
|
||||
.jobs-table th:nth-child(5),
|
||||
.jobs-table td:nth-child(5) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
101
web/src/views/TrainResultsView.test.ts
Normal file
101
web/src/views/TrainResultsView.test.ts
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import TrainResultsView from './TrainResultsView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
const sampleResult = {
|
||||
id: 'run-xyz',
|
||||
job_id: 'job-abc123',
|
||||
model_type: 'classifier',
|
||||
base_model: 'microsoft/deberta-v3-small',
|
||||
val_macro_f1: 0.847,
|
||||
val_accuracy: 0.891,
|
||||
sample_count: 1240,
|
||||
duration_seconds: 842,
|
||||
created_at: '2026-05-01T11:30:00Z',
|
||||
}
|
||||
|
||||
/**
 * Build a `fetch` mock that always resolves to a successful response whose
 * JSON body is `{ results }` and whose text body is empty.
 */
function makeFetch(results: unknown[] = []) {
  const response = {
    ok: true,
    json: async () => ({ results }),
    text: async () => '',
  }
  return vi.fn().mockResolvedValue(response)
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', makeFetch([sampleResult]))
|
||||
})
|
||||
|
||||
describe('TrainResultsView', () => {
  // Mount the view with the shared router and wait for the initial fetch to settle.
  async function mountView() {
    const wrapper = mount(TrainResultsView, { global: { plugins: [router] } })
    await flushPromises()
    return wrapper
  }

  it('renders page title "Training Results"', async () => {
    const view = await mountView()
    expect(view.find('h1.page-title').text()).toContain('Training Results')
  })

  it('shows empty notice when there are no results', async () => {
    vi.stubGlobal('fetch', makeFetch([]))
    const view = await mountView()
    expect(view.find('.empty-notice').exists()).toBe(true)
  })

  it('renders results table when results exist', async () => {
    const view = await mountView()
    expect(view.find('table.results-table').exists()).toBe(true)
  })

  it('shows base_model in table', async () => {
    const view = await mountView()
    expect(view.text()).toContain('deberta-v3-small')
  })

  it('shows val_macro_f1 formatted as percentage', async () => {
    const view = await mountView()
    expect(view.text()).toContain('84.7%')
  })

  it('shows val_accuracy formatted as percentage', async () => {
    const view = await mountView()
    expect(view.text()).toContain('89.1%')
  })

  it('shows duration formatted as minutes and seconds', async () => {
    const view = await mountView()
    // 842 seconds = 14m 2s
    expect(view.text()).toContain('14m')
  })

  it('shows Register in Fleet button for classifier results', async () => {
    const view = await mountView()
    expect(view.find('a.register-btn').exists()).toBe(true)
  })

  it('does NOT show Register in Fleet button for llm-sft results', async () => {
    vi.stubGlobal('fetch', makeFetch([{ ...sampleResult, model_type: 'llm-sft' }]))
    const view = await mountView()
    expect(view.find('a.register-btn').exists()).toBe(false)
  })

  it('shows error notice when API fails', async () => {
    const failingFetch = vi.fn().mockResolvedValue({ ok: false, status: 500, text: async () => '' })
    vi.stubGlobal('fetch', failingFetch)
    const view = await mountView()
    expect(view.find('.error-notice').exists()).toBe(true)
  })
})
|
||||
296
web/src/views/TrainResultsView.vue
Normal file
296
web/src/views/TrainResultsView.vue
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
<template>
|
||||
<div class="train-results-view">
|
||||
<header class="view-header">
|
||||
<h1 class="page-title">Training Results</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="loadResults" aria-label="Refresh">🔄</button>
|
||||
</header>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="loadResults">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="loading" class="loading-state" aria-live="polite">Loading…</div>
|
||||
|
||||
<div v-if="!error && results.length === 0 && !loading" class="empty-notice">
|
||||
No training results yet. Completed jobs will appear here.
|
||||
</div>
|
||||
|
||||
<div v-if="results.length > 0" class="results-table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Run</th>
|
||||
<th>Type</th>
|
||||
<th>Base Model</th>
|
||||
<th class="th-numeric">Macro F1</th>
|
||||
<th class="th-numeric">Accuracy</th>
|
||||
<th class="th-numeric">Samples</th>
|
||||
<th class="th-numeric">Duration</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="r in results" :key="r.id">
|
||||
<td class="td-id" :title="r.id">{{ r.id.slice(0, 8) }}</td>
|
||||
<td>
|
||||
<span class="type-chip">{{ r.model_type }}</span>
|
||||
</td>
|
||||
<td class="td-model" :title="r.base_model">{{ shortModel(r.base_model) }}</td>
|
||||
<td class="td-numeric">
|
||||
<span class="metric-val" :class="scoreClass(r.val_macro_f1)">
|
||||
{{ formatPct(r.val_macro_f1) }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="td-numeric">{{ formatPct(r.val_accuracy) }}</td>
|
||||
<td class="td-numeric">{{ r.sample_count.toLocaleString() }}</td>
|
||||
<td class="td-numeric">{{ formatDuration(r.duration_seconds) }}</td>
|
||||
<td class="td-actions">
|
||||
<RouterLink
|
||||
v-if="r.model_type === 'classifier'"
|
||||
:to="`/fleet?model=${encodeURIComponent(r.base_model)}`"
|
||||
class="register-btn btn-sm-link"
|
||||
>
|
||||
Register in Fleet
|
||||
</RouterLink>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
interface TrainResult {
|
||||
id: string
|
||||
job_id: string
|
||||
model_type: string
|
||||
base_model: string
|
||||
val_macro_f1: number | null
|
||||
val_accuracy: number | null
|
||||
sample_count: number
|
||||
duration_seconds: number | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
const results = ref<TrainResult[]>([])
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
/**
 * Fetch training results from `/api/train/results` into `results`.
 *
 * `loading` is true while the request is in flight. A non-2xx response or a
 * thrown fetch/parse error sets `error`; a body without an array `results`
 * field is treated as an empty list.
 */
async function loadResults() {
  error.value = null
  loading.value = true
  try {
    const response = await fetch('/api/train/results')
    if (response.ok) {
      const payload = await response.json() as { results?: TrainResult[] }
      const rows = payload?.results
      results.value = Array.isArray(rows) ? rows : []
    } else {
      error.value = `Failed to load results (HTTP ${response.status}).`
    }
  } catch {
    error.value = 'Network error loading results.'
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
/**
 * Format a 0..1 fraction as a percentage with one decimal place.
 * `null`/`undefined` render as an em dash.
 */
function formatPct(v: number | null | undefined): string {
  return v == null ? '—' : `${(v * 100).toFixed(1)}%`
}
|
||||
|
||||
/**
 * Format a duration in seconds as `Ns` or `Mm Ns`.
 *
 * The value is rounded to whole seconds before splitting so fractional input
 * (e.g. 842.37) does not leak floating-point noise into the label. Missing or
 * non-finite values render as an em dash; negative values are clamped to zero.
 */
function formatDuration(seconds: number | null | undefined): string {
  if (seconds == null || !Number.isFinite(seconds)) return '—'
  const total = Math.max(0, Math.round(seconds))
  const mins = Math.floor(total / 60)
  const secs = total % 60
  if (mins === 0) return `${secs}s`
  return `${mins}m ${secs}s`
}
|
||||
|
||||
/**
 * Return the last non-empty `/`-separated segment of a model identifier,
 * e.g. `microsoft/deberta-v3-small` -> `deberta-v3-small`.
 *
 * Empty segments are ignored so a trailing slash (`org/model/`) still yields
 * `model`; if no non-empty segment exists the input is returned unchanged.
 */
function shortModel(model: string): string {
  const segments = model.split('/').filter((segment) => segment.length > 0)
  return segments[segments.length - 1] ?? model
}
|
||||
|
||||
/**
 * Map a macro-F1 score to the CSS class used to colour it.
 * Missing scores get no class; otherwise the first tier whose minimum the
 * score reaches wins, and anything below every tier is `score-fair`.
 */
function scoreClass(f1: number | null | undefined): string {
  if (f1 == null) return ''
  // Ordered from highest threshold to lowest.
  const tiers: Array<[number, string]> = [
    [0.85, 'score-great'],
    [0.75, 'score-good'],
  ]
  for (const [minimum, cssClass] of tiers) {
    if (f1 >= minimum) return cssClass
  }
  return 'score-fair'
}
|
||||
|
||||
onMounted(() => loadResults())
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.train-results-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.view-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.error-notice {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 0.75rem 1rem;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
color: var(--color-error, #c0392b);
|
||||
font-size: 0.88rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
cursor: pointer;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.empty-notice {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px dashed var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
}
|
||||
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
}
|
||||
|
||||
.results-table-wrap { overflow-x: auto; }
|
||||
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.results-table th {
|
||||
text-align: left;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.th-numeric { text-align: right; }
|
||||
|
||||
.results-table td {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.td-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.td-model {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
max-width: 16ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.td-numeric {
|
||||
text-align: right;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.td-actions { text-align: right; }
|
||||
|
||||
.metric-val { font-weight: 600; }
|
||||
.score-great { color: var(--color-success, #3a7a32); }
|
||||
.score-good { color: var(--color-warning, #d4891a); }
|
||||
.score-fair { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.type-chip {
|
||||
font-size: 0.72rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 0.25rem;
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.btn-sm-link {
|
||||
font-size: 0.78rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.3rem;
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
color: var(--app-primary, #2A6080);
|
||||
background: transparent;
|
||||
text-decoration: none;
|
||||
white-space: nowrap;
|
||||
display: inline-block;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.btn-sm-link:hover {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.results-table th:nth-child(6),
|
||||
.results-table td:nth-child(6),
|
||||
.results-table th:nth-child(7),
|
||||
.results-table td:nth-child(7) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
Loading…
Reference in a new issue