Compare commits

..

52 commits

Author SHA1 Message Date
32872d1ec6 fix: assigned-only state, remove dead HfNodeModelPanel prop, deduplicate yaml example 2026-05-05 22:11:02 -07:00
1521198cb1 fix: code quality fixes from review (SSE abort, aria-live, shared types, type safety)
- Add AbortController to SSE pull stream in OllamaModelPanel; abort on unmount
- Fix SSE loop: break on success/error events, call fetchModels() after the loop
- Add AbortController to fetchModels() and fetchProfile() one-shot fetches
- Add onUnmounted cleanup to both panel components
- Extract GpuEntry, ServiceInfo, NodeSummary to web/src/types/nodes.ts
- Remove duplicate interface definitions from NodeCard, GpuRow, NodeManagementView
- Fix aria-live regions: persistent container with v-if on inner span (avoids
  screen reader announcement miss on initial mount)
- Tighten STATE_LABELS/STATE_ICONS to Record<ServiceState, string> for exhaustiveness
- Add explicit (await r.json()) as NodeSummary[] cast in fetchNodes()
2026-05-05 21:35:13 -07:00
8dda040480 fix: move /nodes route immediately after /fleet per spec 2026-05-05 21:29:35 -07:00
bf675ed1f6 feat: add OllamaModelPanel and HfNodeModelPanel Vue components 2026-05-05 21:24:38 -07:00
0efd1aedbe feat: add NodeCard, GpuRow, ServiceBadge Vue components 2026-05-05 21:24:32 -07:00
4c225b94f5 feat: add /nodes route, AppSidebar nav item, and NodeManagementView 2026-05-05 21:24:27 -07:00
1cd9c5d455 fix: move json import to module scope in nodes.py 2026-05-05 21:01:32 -07:00
5702a7190b feat: add Ollama list/pull-SSE/delete endpoints 2026-05-05 20:41:29 -07:00
55b017ba3b fix: log coordinator reload failures in update_gpu_services
- Replace bare `except Exception: pass` with `except Exception as exc` and a
  logger.warning call that surfaces node_id and the exception for diagnostics.
- Move `import os as _os` from mid-file (between test functions) to the
  top-level import block to satisfy PEP 8 and linter expectations.
2026-05-05 20:36:08 -07:00
f952ec8971 feat: add profile endpoint and GPU service assignment with compatibility check 2026-05-05 20:33:41 -07:00
fd8cb622a1 feat: add GET /api/nodes-mgmt/nodes/{node_id}/profile endpoint 2026-05-05 20:31:22 -07:00
47cb9f661f fix: narrow exception handling in list_nodes, move mock imports to top
- Remove redundant httpx.ConnectError from nodes except clause (it's a
  subclass of HTTPError so the tuple catch was redundant)
- Narrow services except clause from bare Exception to httpx.HTTPError,
  add logger.warning with coordinator_url for debuggability
- Move `from unittest.mock import MagicMock, patch` from mid-file to
  the top-of-file import block with the other stdlib/third-party imports
2026-05-05 20:18:50 -07:00
c2de9e53da feat: implement GET /api/nodes-mgmt/nodes with coordinator proxy and profile merge 2026-05-05 20:16:06 -07:00
c039ea4698 fix: remove unused imports and em dash in nodes.py scaffold
- Drop unused StreamingResponse import from app/nodes.py (will be
  re-added in Task 2 when the SSE endpoint is implemented)
- Replace em dash with colon in _get_ollama_url HTTPException detail
- Remove unused os and unittest.mock imports from test_nodes.py
  (mock imports will return in Task 2 tests)
2026-05-05 19:59:32 -07:00
95afddb772 feat: add nodes.py scaffold with set_config_dir and router mount
- Create app/nodes.py with _CONFIG_DIR testability seam, _load_config,
  _profiles_dir, _profile_path, _load_profile, _get_ollama_url helpers,
  and stub list_nodes endpoint returning [] when no coordinator_url is set
- Mount nodes router at /api/nodes-mgmt in app/api.py
- Add profiles_dir comment to config/label_tool.yaml.example cforch section
- Create tests/test_nodes.py with autouse fixture and two passing tests
2026-05-05 19:35:28 -07:00
cbe8c0f03e feat(benchmark): wire EmbeddingKNNAdapter into MODEL_REGISTRY; add embed_model config
- Add embed_model: nomic-embed-text to config/label_tool.yaml (local, gitignored)
- Add # embed_model: commented example to config/label_tool.yaml.example
- Add pyyaml>=6.0 to requirements.txt (explicit dep for _resolve_urls yaml.safe_load)
- Add params assertion to test_embed_knn_nomic_registry_entry
2026-05-05 14:05:45 -07:00
5df33b0f41 feat(benchmark): wire EmbeddingKNNAdapter into MODEL_REGISTRY as embed-knn-nomic 2026-05-05 12:43:48 -07:00
41584de5df fix(benchmark): guard empty exemplars, warn on malformed JSON in build_exemplars_from_jsonl 2026-05-05 12:41:46 -07:00
1d4c07e4a0 feat(benchmark): add build_exemplars_from_jsonl() for k-NN seed 2026-05-05 11:43:12 -07:00
e823b5e76d fix(classifier): majority-vote key, partial-load guard, sparse label test 2026-05-05 11:39:24 -07:00
88bc6bed67 feat(classifier): implement EmbeddingKNNAdapter.classify() with k-NN vote 2026-05-05 08:04:54 -07:00
4a64a6686d fix(classifier): atomic embed assignment, logging on orch failure, guard double load 2026-05-05 07:53:15 -07:00
f2f150b4fb feat(classifier): implement EmbeddingKNNAdapter.load() and unload() 2026-05-05 07:12:53 -07:00
72449561cf feat(classifier): add EmbeddingKNNAdapter skeleton and constructor tests 2026-05-05 06:08:21 -07:00
c177fb1628 fix(classifier): quality fixes for DEFAULT_EXEMPLARS — remove forward __all__ entry, tighten tests, fix survey exemplar 2026-05-04 20:03:18 -07:00
3be5055e31 feat(classifier): add DEFAULT_EXEMPLARS for embedding k-NN fallback 2026-05-04 17:44:44 -07:00
78b64d007d feat(classifier): add _cosine() helper for embedding similarity 2026-05-04 17:41:45 -07:00
bce932461a feat: plans benchmark harness — model scoring for CF planning prompts
Adds benchmark_plans.py script, plans_bench API router, PlansBenchTab Vue
component, and registers /api/plans-bench in api.py. Also extends models
registry (cf-text catalog integration), cforch client, LlmEvalTab, and
ModelsView with cf-orch fleet support. Wires Planning mode into BenchmarkView.
2026-05-02 23:36:04 -07:00
e11db5ccd9 fix: align train job/results API envelope, config_json key, progress SSE, dashboard model_key
- GET /api/train/jobs now returns {"jobs":[...]} instead of bare array
- GET /api/train/results now returns {"results":[...]} instead of bare array
- POST /api/train/jobs body key renamed config -> config_json to match Pydantic model
- SSE log handler now handles 'progress' event type (backend never emits 'log')
- Dashboard _get_active_jobs() adds model_key to SELECT and return dict
- corrections.py docstring updated: both /api/corrections and /api/sft prefixes noted
- test_train.py assertions updated to unwrap new envelope shapes
2026-05-02 21:22:18 -07:00
13d1a394d5 fix: add loading state, widen nullable types, add API response guard in TrainResultsView 2026-05-02 20:49:34 -07:00
b077371107 feat: add TrainResultsView with training history table and Fleet registration links 2026-05-02 20:46:03 -07:00
53b25b27ab fix: surface cancel errors, fix SSE sentinel scroll, add missing test coverage in TrainJobsView 2026-05-02 20:33:03 -07:00
e014da2dec feat: add TrainJobsView with job queue, form submission, cancel, and SSE log streaming 2026-05-02 20:28:19 -07:00
c48db45d91 test: fix async flush and add mode-switch coverage in BenchmarkView 2026-05-02 19:35:02 -07:00
d0ba75b995 feat: extract CompareView at /eval/compare; remove Compare tab from BenchmarkView 2026-05-02 18:03:13 -07:00
a134af8b7b feat: add DashboardView with flywheel stage cards and CTA nudges 2026-05-02 16:50:24 -07:00
6ef6f06023 feat: restructure AppSidebar into two-domain nav with section headers and flywheel signal badges 2026-05-02 13:52:45 -07:00
5bdb095235 feat: restructure router into /data/* /eval/* /train/* domains with backward-compat redirects
- Export named `routes` array from router/index.ts for testability
- Move label/fetch/corrections/imitate under /data/* namespace
- Move benchmark/compare under /eval/* namespace
- Add /train/jobs and /train/results under /train/* namespace
- Add / -> DashboardView and /fleet -> ModelsView (replaces old / -> LabelView)
- Add backward-compat redirects for all old flat paths (/benchmark, /models, /stats, /label, /fetch, /corrections, /imitate)
- Add stub views for DashboardView, CompareView, TrainJobsView, TrainResultsView (implemented in later tasks)
- Add router.test.ts: 16 tests covering route structure and redirect targets
2026-05-02 13:00:04 -07:00
0904967320 feat: slim api.py to factory-only; all domain routes in dedicated modules
Replace 149-line api.py (with inline helpers, JSONL utilities, and ad-hoc
router registrations) with a 57-line pure factory. All business logic was
already extracted to domain modules in B1-B7; this removes the dead code
and adds the /api/corrections/* prefix alongside the /api/sft/* backward-
compat alias. Smoke tests updated to cover the new /api/corrections/ingest
and /api/dashboard routes.
2026-05-02 09:55:58 -07:00
8fda821e15 feat: add POST /ingest endpoint to corrections API with Bearer auth
Adds IngestRequest model and POST /api/sft/ingest route to
app/data/corrections.py. Sibling CF products (Peregrine, Kiwi, etc.)
can push pre-approved corrections via Bearer token auth
(AVOCET_INGESTION_SECRET). Records land as status=approved in both
sft_candidates.jsonl and sft_approved.jsonl immediately.

7 tests in tests/test_data_corrections.py cover 503 (secret unset),
401 (missing/malformed header), 403 (wrong secret), happy-path writes
to both files, and optional label field.
2026-05-02 09:07:10 -07:00
0853ed7d56 fix: add logger.warning to silent except blocks in dashboard._find_latest_eval 2026-05-01 23:36:19 -07:00
aa742bcfc0 feat: add GET /api/dashboard flywheel aggregate endpoint 2026-05-01 23:30:04 -07:00
32d3436bbd fix: path traversal guard, python_bin config, completed_at on Popen failure 2026-05-01 23:24:00 -07:00
766fbafa02 feat: build SQLite-backed train job queue in app/train/train.py
Replaces the ad-hoc _running_procs dict in api.py with a persistent,
inspectable SQLite job queue. Removes old /api/finetune/* routes and
_best_cuda_device from api.py. Adds /api/train/* routes (list, create,
get, cancel, run SSE, results). 16 new tests all passing.
2026-05-01 23:05:11 -07:00
d432026fd7 fix: restore real plans_bench.py (was accidentally stubbed) 2026-05-01 22:25:22 -07:00
bccb385f61 feat: build app/eval/cforch.py aggregating eval benchmark routers 2026-05-01 22:23:06 -07:00
d74ad3f972 feat: move imitate API into app/data/imitate.py 2026-05-01 22:12:19 -07:00
99ea39fe38 feat: move SFT corrections API into app/data/corrections.py 2026-05-01 22:02:22 -07:00
2054866ff1 feat: extract fetch routes and IMAP helpers into app/data/fetch.py 2026-05-01 21:57:31 -07:00
cbec776ef1 fix: restore ensure_ascii=False in utils jsonl helpers; remove dead _last_action from api.py 2026-05-01 20:59:44 -07:00
167d7351e3 feat: extract label queue API into app/data/label.py 2026-05-01 18:48:14 -07:00
6689ff07b1 chore: gitignore .worktrees/ directory 2026-05-01 12:25:23 -07:00
65 changed files with 11024 additions and 2449 deletions

View file

@ -17,3 +17,7 @@ CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx
# Set one of these to use a cloud LLM instead of a local model. # Set one of these to use a cloud LLM instead of a local model.
# ANTHROPIC_API_KEY=sk-ant-... # ANTHROPIC_API_KEY=sk-ant-...
# OPENAI_API_KEY=sk-... # OPENAI_API_KEY=sk-...
# ── HuggingFace (required for gated/terms-restricted model downloads) ─────────
# Generate at https://huggingface.co/settings/tokens and accept model terms first.
# HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx

4
.gitignore vendored
View file

@ -20,3 +20,7 @@ data/sft_approved.jsonl
# Claude context — BSL 1.1, keep out of version control # Claude context — BSL 1.1, keep out of version control
CLAUDE.md CLAUDE.md
docs/superpowers/ docs/superpowers/
.superpowers/
# Git worktrees
.worktrees/

View file

@ -1,623 +1,62 @@
"""Avocet — FastAPI REST layer. """Avocet -- FastAPI app factory.
JSONL read/write helpers and FastAPI app instance. Mounts all domain routers and serves the Vue SPA.
Endpoints and static file serving are added in subsequent tasks. All business logic lives in the domain modules below.
""" """
from __future__ import annotations from __future__ import annotations
import hashlib
import json
import os
import subprocess as _subprocess
import yaml
from pathlib import Path from pathlib import Path
from datetime import datetime, timezone from fastapi import FastAPI
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
_ROOT = Path(__file__).parent.parent
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
_CONFIG_DIR: Path | None = None # None = use real path
# Process registry for running jobs — used by cancel endpoints.
# Keys: "benchmark" | "finetune". Values: the live Popen object.
_running_procs: dict = {}
_cancelled_jobs: set = set()
def set_data_dir(path: Path) -> None:
"""Override data directory — used by tests."""
global _DATA_DIR
_DATA_DIR = path
def _best_cuda_device() -> str:
"""Return the index of the GPU with the most free VRAM as a string.
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
if nvidia-smi is unavailable or no GPUs are found. Restricting the
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
PyTorch DataParallel from replicating the model across all GPUs, which
would OOM the GPU with less headroom.
"""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True,
timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
def set_models_dir(path: Path) -> None:
"""Override models directory — used by tests."""
global _MODELS_DIR
_MODELS_DIR = path
def set_config_dir(path: Path | None) -> None:
"""Override config directory — used by tests."""
global _CONFIG_DIR
_CONFIG_DIR = path
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def reset_last_action() -> None:
"""Reset undo state — used by tests."""
global _last_action
_last_action = None
def _queue_file() -> Path:
return _DATA_DIR / "email_label_queue.jsonl"
def _score_file() -> Path:
return _DATA_DIR / "email_score.jsonl"
def _discarded_file() -> Path:
return _DATA_DIR / "discarded.jsonl"
def _read_jsonl(path: Path) -> list[dict]:
if not path.exists():
return []
lines = path.read_text(encoding="utf-8").splitlines()
return [json.loads(l) for l in lines if l.strip()]
def _write_jsonl(path: Path, records: list[dict]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
text = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
path.write_text(text + "\n" if records else "", encoding="utf-8")
def _append_jsonl(path: Path, record: dict) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
def _item_id(item: dict) -> str:
"""Stable content-hash ID — matches label_tool.py _entry_key dedup logic."""
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
def _normalize(item: dict) -> dict:
"""Normalize JSONL item to the Vue frontend schema.
label_tool.py stores: subject, body, from_addr, date, account (no id).
The Vue app expects: id, subject, body, from, date, source.
Both old (from_addr/account) and new (from/source) field names are handled.
"""
return {
"id": item.get("id") or _item_id(item),
"subject": item.get("subject", ""),
"body": item.get("body", ""),
"from": item.get("from") or item.get("from_addr", ""),
"date": item.get("date", ""),
"source": item.get("source") or item.get("account", ""),
}
app = FastAPI(title="Avocet API") app = FastAPI(title="Avocet API")
from app.sft import router as sft_router # -- Domain routers --------------------------------------------------------
app.include_router(sft_router, prefix="/api/sft")
from app.models import router as models_router from app.data.label import router as label_router
import app.models as _models_module app.include_router(label_router, prefix="/api")
app.include_router(models_router, prefix="/api/models")
from app.cforch import router as cforch_router from app.data.fetch import router as fetch_router
app.include_router(cforch_router, prefix="/api/cforch") app.include_router(fetch_router, prefix="/api")
from app.imitate import router as imitate_router from app.data.corrections import router as corrections_router
app.include_router(corrections_router, prefix="/api/corrections")
# Backward-compat alias -- remove when Vue SPA is updated to /api/corrections/*
app.include_router(corrections_router, prefix="/api/sft")
from app.data.imitate import router as imitate_router
app.include_router(imitate_router, prefix="/api/imitate") app.include_router(imitate_router, prefix="/api/imitate")
from app.style import router as style_router from app.eval.cforch import router as eval_router
app.include_router(style_router, prefix="/api/style") app.include_router(eval_router, prefix="/api")
from app.train.train import router as train_router
app.include_router(train_router, prefix="/api/train")
from app.plans_bench import router as plans_bench_router
app.include_router(plans_bench_router, prefix="/api/plans-bench")
# In-memory last-action store (single user, local tool — in-memory is fine) # In-memory last-action store (single user, local tool — in-memory is fine)
_last_action: dict | None = None _last_action: dict | None = None
from app.dashboard import router as dashboard_router
app.include_router(dashboard_router, prefix="/api")
@app.get("/api/queue") from app.models import router as models_router
def get_queue(limit: int = Query(default=10, ge=1, le=50)): app.include_router(models_router, prefix="/api/models")
items = _read_jsonl(_queue_file())
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
from app.nodes import router as nodes_router
app.include_router(nodes_router, prefix="/api/nodes-mgmt")
class LabelRequest(BaseModel): # -- Static SPA -- MUST be last (catches all unmatched paths) ---------------
id: str
label: str
_ROOT = Path(__file__).parent.parent
@app.post("/api/label")
def post_label(req: LabelRequest):
global _last_action
items = _read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": req.label,
"labeled_at": datetime.now(timezone.utc).isoformat()}
_append_jsonl(_score_file(), record)
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "label", "item": match, "label": req.label}
return {"ok": True}
class SkipRequest(BaseModel):
id: str
@app.post("/api/skip")
def post_skip(req: SkipRequest):
global _last_action
items = _read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
_write_jsonl(_queue_file(), reordered)
_last_action = {"type": "skip", "item": match}
return {"ok": True}
class DiscardRequest(BaseModel):
id: str
@app.post("/api/discard")
def post_discard(req: DiscardRequest):
global _last_action
items = _read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": "__discarded__",
"discarded_at": datetime.now(timezone.utc).isoformat()}
_append_jsonl(_discarded_file(), record)
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "discard", "item": match}
return {"ok": True}
@app.delete("/api/label/undo")
def delete_undo():
global _last_action
if not _last_action:
raise HTTPException(404, "No action to undo")
action = _last_action
item = action["item"] # always the original clean queue item
# Perform file operations FIRST — only clear _last_action on success
if action["type"] == "label":
records = _read_jsonl(_score_file())
if not records:
raise HTTPException(409, "Score file is empty — cannot undo label")
_write_jsonl(_score_file(), records[:-1])
items = _read_jsonl(_queue_file())
_write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "discard":
records = _read_jsonl(_discarded_file())
if not records:
raise HTTPException(409, "Discarded file is empty — cannot undo discard")
_write_jsonl(_discarded_file(), records[:-1])
items = _read_jsonl(_queue_file())
_write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "skip":
items = _read_jsonl(_queue_file())
item_id = _normalize(item)["id"]
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
_write_jsonl(_queue_file(), items)
# Clear AFTER all file operations succeed
_last_action = None
return {"undone": {"type": action["type"], "item": _normalize(item)}}
# Label metadata — 10 labels matching label_tool.py
_LABEL_META = [
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
{"name": "rejected", "emoji": "\u274c", "color": "#F44336", "key": "3"},
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
{"name": "neutral", "emoji": "\u2b1c", "color": "#607D8B", "key": "6"},
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
]
@app.get("/api/config/labels")
def get_labels():
return _LABEL_META
@app.get("/api/config")
def get_config():
f = _config_file()
if not f.exists():
return {"accounts": [], "max_per_account": 500}
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
class ConfigPayload(BaseModel):
accounts: list[dict]
max_per_account: int = 500
@app.post("/api/config")
def post_config(payload: ConfigPayload):
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
encoding="utf-8")
tmp.rename(f)
return {"ok": True}
@app.get("/api/stats")
def get_stats():
records = _read_jsonl(_score_file())
counts: dict[str, int] = {}
for r in records:
lbl = r.get("label", "")
if lbl:
counts[lbl] = counts.get(lbl, 0) + 1
benchmark_results: dict = {}
benchmark_path = _DATA_DIR / "benchmark_results.json"
if benchmark_path.exists():
try:
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
except Exception:
pass
return {
"total": len(records),
"counts": counts,
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
"benchmark_results": benchmark_results,
}
@app.get("/api/stats/download")
def download_stats():
from fastapi.responses import FileResponse
if not _score_file().exists():
raise HTTPException(404, "No score file")
return FileResponse(
str(_score_file()),
filename="email_score.jsonl",
media_type="application/jsonlines",
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
)
class AccountTestRequest(BaseModel):
account: dict
@app.post("/api/accounts/test")
def test_account(req: AccountTestRequest):
from app.imap_fetch import test_connection
ok, message, count = test_connection(req.account)
return {"ok": ok, "message": message, "count": count}
from fastapi.responses import StreamingResponse
# ---------------------------------------------------------------------------
# Benchmark endpoints
# ---------------------------------------------------------------------------
@app.get("/api/benchmark/models")
def get_benchmark_models() -> dict:
"""Return installed models grouped by adapter_type category."""
models_dir: Path = _models_module._MODELS_DIR
categories: dict[str, list[dict]] = {
"ZeroShotAdapter": [],
"RerankerAdapter": [],
"GenerationAdapter": [],
"Unknown": [],
}
if models_dir.exists():
for sub in models_dir.iterdir():
if not sub.is_dir():
continue
info_path = sub / "model_info.json"
adapter_type = "Unknown"
repo_id: str | None = None
if info_path.exists():
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown"
repo_id = info.get("repo_id")
except Exception:
pass
bucket = adapter_type if adapter_type in categories else "Unknown"
entry: dict = {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type}
categories[bucket].append(entry)
return {"categories": categories}
@app.get("/api/benchmark/results")
def get_benchmark_results():
"""Return the most recently saved benchmark results, or an empty envelope."""
path = _DATA_DIR / "benchmark_results.json"
if not path.exists():
return {"models": {}, "sample_count": 0, "timestamp": None}
return json.loads(path.read_text())
@app.get("/api/benchmark/run")
def run_benchmark(include_slow: bool = False, model_names: str = ""):
"""Spawn the benchmark script and stream stdout as SSE progress events."""
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
cmd = [python_bin, script, "--score", "--save"]
if include_slow:
cmd.append("--include-slow")
if model_names:
names = [n.strip() for n in model_names.split(",") if n.strip()]
if names:
cmd.extend(["--models"] + names)
def generate():
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
)
_running_procs["benchmark"] = proc
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0:
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
elif "benchmark" in _cancelled_jobs:
_cancelled_jobs.discard("benchmark")
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_running_procs.pop("benchmark", None)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
# ---------------------------------------------------------------------------
# Finetune endpoints
# ---------------------------------------------------------------------------
@app.get("/api/finetune/status")
def get_finetune_status():
"""Scan models/ for training_info.json files. Returns [] if none exist."""
models_dir = _MODELS_DIR
if not models_dir.exists():
return []
results = []
for sub in models_dir.iterdir():
if not sub.is_dir():
continue
info_path = sub / "training_info.json"
if not info_path.exists():
continue
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception:
pass
return results
@app.get("/api/finetune/run")
def run_finetune_endpoint(
model: str = "deberta-small",
epochs: int = 5,
score: list[str] = Query(default=[]),
):
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
script = str(_ROOT / "scripts" / "finetune_classifier.py")
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
data_root = _DATA_DIR.resolve()
for score_file in score:
resolved = (_DATA_DIR / score_file).resolve()
if not str(resolved).startswith(str(data_root)):
raise HTTPException(400, f"Invalid score path: {score_file!r}")
cmd.extend(["--score", str(resolved)])
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
# single device prevents DataParallel from replicating the model across all
# GPUs, which would force a full copy onto the more memory-constrained device.
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
best_gpu = _best_cuda_device()
if best_gpu:
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
def generate():
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
env=proc_env,
)
_running_procs["finetune"] = proc
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0:
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
elif "finetune" in _cancelled_jobs:
_cancelled_jobs.discard("finetune")
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_running_procs.pop("finetune", None)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.post("/api/benchmark/cancel")
def cancel_benchmark():
"""Kill the running benchmark subprocess. 404 if none is running."""
proc = _running_procs.get("benchmark")
if proc is None:
raise HTTPException(404, "No benchmark is running")
_cancelled_jobs.add("benchmark")
proc.terminate()
try:
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
proc.kill()
return {"status": "cancelled"}
@app.post("/api/finetune/cancel")
def cancel_finetune():
"""Kill the running fine-tune subprocess. 404 if none is running."""
proc = _running_procs.get("finetune")
if proc is None:
raise HTTPException(404, "No finetune is running")
_cancelled_jobs.add("finetune")
proc.terminate()
try:
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
proc.kill()
return {"status": "cancelled"}
@app.get("/api/fetch/stream")
def fetch_stream(
accounts: str = Query(default=""),
days_back: int = Query(default=90, ge=1, le=365),
limit: int = Query(default=150, ge=1, le=1000),
mode: str = Query(default="wide"),
):
from app.imap_fetch import fetch_account_stream
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
config = get_config() # reuse existing endpoint logic
selected = [a for a in config["accounts"] if a.get("name") in selected_names]
def generate():
known_keys = {_item_id(x) for x in _read_jsonl(_queue_file())}
total_added = 0
for acc in selected:
try:
batch_emails: list[dict] = []
for event in fetch_account_stream(acc, days_back, limit, known_keys):
if event["type"] == "done":
batch_emails = event.pop("emails", [])
total_added += event["added"]
yield f"data: {json.dumps(event)}\n\n"
# Write new emails to queue after each account
if batch_emails:
existing = _read_jsonl(_queue_file())
_write_jsonl(_queue_file(), existing + batch_emails)
except Exception as exc:
error_event = {"type": "error", "account": acc.get("name", "?"),
"message": str(exc)}
yield f"data: {json.dumps(error_event)}\n\n"
queue_size = len(_read_jsonl(_queue_file()))
complete = {"type": "complete", "total_added": total_added, "queue_size": queue_size}
yield f"data: {json.dumps(complete)}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
# Static SPA — MUST be last (catches all unmatched paths)
_DIST = _ROOT / "web" / "dist" _DIST = _ROOT / "web" / "dist"
if _DIST.exists(): if _DIST.exists():
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
# Serve index.html with no-cache so browsers always fetch fresh HTML after rebuilds.
# Hashed assets (/assets/index-abc123.js) can be cached forever — they change names
# when content changes (standard Vite cache-busting strategy).
_NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"} _NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"}
@app.get("/") @app.get("/")

View file

@ -17,9 +17,12 @@ import logging
import os import os
import re import re
import subprocess as _subprocess import subprocess as _subprocess
import tempfile
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import urllib.parse
import yaml import yaml
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -75,9 +78,31 @@ def _load_cforch_config() -> dict:
"license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"), "license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"),
"ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"), "ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"),
"ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"), "ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"),
"judge_url": _coalesce(file_cfg.get("judge_url", ""), "CF_JUDGE_URL"),
"hf_token": _coalesce(file_cfg.get("hf_token", ""), "HF_TOKEN"),
} }
def _validate_service_url(url: str, param_name: str) -> str:
"""Validate that a URL is a well-formed http/https URL with a hostname.
Guards against SSRF: only http/https is allowed; the URL must have a
non-empty host. Does not enforce an allowlist call sites are internal
tooling, not a public API.
"""
if not url:
return url
try:
parsed = urllib.parse.urlparse(url)
except Exception:
raise HTTPException(400, f"{param_name}: not a valid URL")
if parsed.scheme not in ("http", "https"):
raise HTTPException(400, f"{param_name}: URL must start with http:// or https://")
if not parsed.hostname:
raise HTTPException(400, f"{param_name}: URL has no hostname")
return url
def _strip_ansi(text: str) -> str: def _strip_ansi(text: str) -> str:
"""Remove ANSI escape codes from a string.""" """Remove ANSI escape codes from a string."""
return re.sub(r'\x1b\[[0-9;]*m', '', text) return re.sub(r'\x1b\[[0-9;]*m', '', text)
@ -147,48 +172,141 @@ def get_tasks() -> dict:
# ── GET /models ──────────────────────────────────────────────────────────────── # ── GET /models ────────────────────────────────────────────────────────────────
# Services and roles surfaced in the benchmark model picker.
# Covers all cf-orch service types that benchmark.py can route tasks to.
_BENCH_SERVICES = frozenset({
"cf-text", "vllm", # LLM text generation
"cf-stt", # speech-to-text
"cf-tts", # text-to-speech
"cf-vision", # image classification / embedding
"cf-voice", # audio context classification
})
_BENCH_ROLES = frozenset({
"generator", "vlm", # LLM roles
"stt", "alm", # speech recognition
"tts", # speech synthesis
"vision", "embedding", # image understanding
"classifier", # audio classification (cf-voice)
})
@router.get("/models") @router.get("/models")
def get_models() -> dict: def get_models() -> dict:
"""Return model list from bench_models.yaml.""" """Return model list from bench_models.yaml merged with locally installed models.
bench_models.yaml entries are listed first and take precedence; any installed
model whose repo_id is already present in the YAML is skipped. Only models
whose service is in _BENCH_SERVICES (cf-text, vllm, cf-stt, cf-tts, cf-vision,
cf-voice) are surfaced from the installed registry.
"""
cfg = _load_cforch_config() cfg = _load_cforch_config()
models_path = cfg.get("bench_models", "") models_path = cfg.get("bench_models", "")
if not models_path:
return {"models": []}
p = Path(models_path)
if not p.exists():
return {"models": []}
try:
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
return {"models": []}
models_raw = raw.get("models", []) or []
models: list[dict] = [] models: list[dict] = []
for m in models_raw: bench_ids: set[str] = set()
if not isinstance(m, dict):
continue if models_path:
models.append({ p = Path(models_path)
"name": m.get("name", ""), if p.exists():
"id": m.get("id", ""), try:
"service": m.get("service", "ollama"), raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
"tags": m.get("tags", []) or [], except yaml.YAMLError as exc:
"vram_estimate_mb": m.get("vram_estimate_mb", 0), logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
}) raw = {}
for m in (raw.get("models", []) or []):
if not isinstance(m, dict):
continue
model_id = m.get("id", "")
models.append({
"name": m.get("name", ""),
"id": model_id,
"service": m.get("service", "ollama"),
"tags": m.get("tags", []) or [],
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
})
if model_id:
bench_ids.add(model_id)
# Merge installed generator models not already in bench_models.yaml.
try:
from app.models import list_installed # local import avoids circular dependency at module load
for installed in list_installed():
model_id: str = installed.get("model_id") or ""
service: str = installed.get("service") or ""
role: str = installed.get("role") or ""
if not model_id:
continue
if service not in _BENCH_SERVICES or role not in _BENCH_ROLES:
continue
if model_id in bench_ids:
continue
display_name = model_id.split("/", 1)[-1] if "/" in model_id else model_id
models.append({
"name": display_name,
"id": model_id,
"service": service,
"tags": [role],
"vram_estimate_mb": installed.get("vram_mb") or 0,
})
bench_ids.add(model_id)
except Exception as exc:
logger.warning("Could not merge installed models into model list: %s", exc)
return {"models": models} return {"models": models}
# ── GET /run ─────────────────────────────────────────────────────────────────── # ── GET /run ───────────────────────────────────────────────────────────────────
@router.get("/nodes")
def get_nodes() -> dict:
"""Proxy the coordinator's /api/nodes list, returning node_id + online status.
Online is inferred from last_heartbeat: any node with a recent heartbeat is online.
Returns an empty list if the coordinator is unreachable.
"""
cfg = _load_cforch_config()
coordinator_url = cfg.get("coordinator_url", "").rstrip("/")
if not coordinator_url:
return {"nodes": []}
try:
import httpx as _httpx
resp = _httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
resp.raise_for_status()
raw_nodes = resp.json().get("nodes", [])
return {
"nodes": [
{
"node_id": n.get("node_id", ""),
"online": n.get("last_heartbeat") is not None,
"gpus": [
{
"gpu_id": g.get("gpu_id"),
"name": g.get("name", ""),
"vram_total_mb": g.get("vram_total_mb", 0),
"vram_free_mb": g.get("vram_free_mb", 0),
}
for g in n.get("gpus", [])
],
}
for n in raw_nodes
]
}
except Exception as exc:
logger.warning("Could not fetch nodes from coordinator: %s", exc)
return {"nodes": []}
@router.get("/run") @router.get("/run")
def run_benchmark( def run_benchmark(
task_ids: str = "", task_ids: str = "",
model_ids: str = "",
model_tags: str = "", model_tags: str = "",
coordinator_url: str = "", coordinator_url: str = "",
ollama_url: str = "", ollama_url: str = "",
judge_url: str = "",
judge_backend: str = "chat",
workers: int = 1,
node_ids: str = "",
) -> StreamingResponse: ) -> StreamingResponse:
"""Spawn cf-orch benchmark.py and stream stdout as SSE progress events.""" """Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
global _BENCH_RUNNING, _bench_proc global _BENCH_RUNNING, _bench_proc
@ -205,6 +323,13 @@ def run_benchmark(
cfg_coordinator = cfg.get("coordinator_url", "") cfg_coordinator = cfg.get("coordinator_url", "")
cfg_ollama = cfg.get("ollama_url", "") cfg_ollama = cfg.get("ollama_url", "")
cfg_license_key = cfg.get("license_key", "") cfg_license_key = cfg.get("license_key", "")
cfg_judge_url = cfg.get("judge_url", "")
# Validate URL params before spawning the subprocess.
# _validate_service_url raises HTTPException on bad input (caught by FastAPI before streaming starts).
_validate_service_url(coordinator_url, "coordinator_url")
_validate_service_url(ollama_url, "ollama_url")
_validate_service_url(judge_url, "judge_url")
def generate(): def generate():
global _BENCH_RUNNING, _bench_proc global _BENCH_RUNNING, _bench_proc
@ -213,16 +338,68 @@ def run_benchmark(
yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n" yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n"
return return
# Build effective models file: bench_models.yaml + any installed models
# whose IDs were selected but are absent from the YAML (e.g. downloaded
# via the Models view). Written to a temp file so benchmark.py sees one
# unified list; cleaned up in the finally block.
effective_models_file = bench_models
_tmp_models_path: str | None = None
if model_ids and bench_models and Path(bench_models).exists():
requested_ids = set(model_ids.split(","))
try:
raw_bench = yaml.safe_load(Path(bench_models).read_text(encoding="utf-8")) or {}
bench_entries: list[dict] = raw_bench.get("models", []) or []
bench_id_set = {m.get("id", "") for m in bench_entries if isinstance(m, dict)}
missing_ids = requested_ids - bench_id_set
if missing_ids:
from app.models import list_installed
installed_map = {
m["model_id"]: m
for m in list_installed()
if m.get("model_id") and m.get("service") in _BENCH_SERVICES
}
extra: list[dict] = []
for mid in missing_ids:
if mid in installed_map:
inst = installed_map[mid]
entry: dict[str, Any] = {
"id": mid,
"name": mid.split("/", 1)[-1] if "/" in mid else mid,
"service": inst.get("service", "cf-text"),
"vram_estimate_mb": inst.get("vram_mb") or 0,
"tags": [inst.get("role", "generator")],
"temperature": 0.0,
}
local_path = inst.get("path", "") or inst.get("local_path", "")
if local_path:
entry["model_path"] = local_path
extra.append(entry)
if extra:
merged = {"models": bench_entries + extra}
tf = tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False,
prefix="avocet_bench_models_",
)
yaml.dump(merged, tf)
tf.close()
_tmp_models_path = tf.name
effective_models_file = _tmp_models_path
except Exception as exc:
logger.warning("Could not merge installed models into temp bench file: %s", exc)
cmd = [ cmd = [
python_bin, python_bin,
bench_script, bench_script,
"--tasks", bench_tasks, "--tasks", bench_tasks,
"--models", bench_models, "--models", effective_models_file,
"--output", results_dir, "--output", results_dir,
] ]
if task_ids: if task_ids:
cmd.extend(["--filter-tasks"] + task_ids.split(",")) cmd.extend(["--filter-tasks"] + task_ids.split(","))
if model_ids:
cmd.extend(["--filter-models"] + model_ids.split(","))
if model_tags: if model_tags:
cmd.extend(["--filter-tags"] + model_tags.split(",")) cmd.extend(["--filter-tags"] + model_tags.split(","))
@ -233,6 +410,15 @@ def run_benchmark(
cmd.extend(["--coordinator", effective_coordinator]) cmd.extend(["--coordinator", effective_coordinator])
if effective_ollama: if effective_ollama:
cmd.extend(["--ollama-url", effective_ollama]) cmd.extend(["--ollama-url", effective_ollama])
effective_judge = judge_url if judge_url else cfg_judge_url
if effective_judge:
cmd.extend(["--judge-url", effective_judge])
if judge_backend and judge_backend != "chat":
cmd.extend(["--judge-backend", judge_backend])
if workers > 1:
cmd.extend(["--workers", str(workers)])
if node_ids:
cmd.extend(["--nodes"] + node_ids.split(","))
# Pass license key as env var so subprocess can authenticate with cf-orch # Pass license key as env var so subprocess can authenticate with cf-orch
proc_env = {**os.environ} proc_env = {**os.environ}
@ -273,6 +459,11 @@ def run_benchmark(
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n" yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
finally: finally:
_BENCH_RUNNING = False _BENCH_RUNNING = False
if _tmp_models_path:
try:
os.unlink(_tmp_models_path)
except OSError:
pass
return StreamingResponse( return StreamingResponse(
generate(), generate(),
@ -295,6 +486,7 @@ def get_cforch_config() -> dict:
"coordinator_url": cfg.get("coordinator_url", ""), "coordinator_url": cfg.get("coordinator_url", ""),
"ollama_url": cfg.get("ollama_url", ""), "ollama_url": cfg.get("ollama_url", ""),
"ollama_model": cfg.get("ollama_model", ""), "ollama_model": cfg.get("ollama_model", ""),
"judge_url": cfg.get("judge_url", ""),
"license_key_set": bool(cfg.get("license_key", "")), "license_key_set": bool(cfg.get("license_key", "")),
"source": "env" if not _config_file().exists() else "yaml+env", "source": "env" if not _config_file().exists() else "yaml+env",
} }
@ -303,7 +495,7 @@ def get_cforch_config() -> dict:
# ── GET /results ─────────────────────────────────────────────────────────────── # ── GET /results ───────────────────────────────────────────────────────────────
@router.get("/results") @router.get("/results")
def get_results() -> dict: def get_results() -> list:
"""Return the latest benchmark summary.json from results_dir.""" """Return the latest benchmark summary.json from results_dir."""
cfg = _load_cforch_config() cfg = _load_cforch_config()
results_dir = cfg.get("results_dir", "") results_dir = cfg.get("results_dir", "")

191
app/dashboard.py Normal file
View file

@ -0,0 +1,191 @@
"""Avocet -- dashboard aggregate API.
GET /api/dashboard returns the current flywheel state:
labeled_since_last_eval -- items labeled after the most recent eval run
last_eval_timestamp -- ISO timestamp of newest bench_results summary
last_eval_best_score -- best macro_f1 from that summary
active_jobs -- jobs with status queued or running
corrections_pending -- sft_candidates with status=needs_review
corrections_export_ready -- approved sft candidates with non-blank correction
signals -- computed booleans for UI nudge indicators
Thresholds in label_tool.yaml pipeline: section:
pipeline:
data_eval_threshold: 50 # labeled items since last eval to trigger nudge
eval_train_threshold: 0.05 # improvement delta needed before retraining (future)
"""
from __future__ import annotations
import json
import logging
import yaml
from pathlib import Path
from fastapi import APIRouter
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
router = APIRouter()
_DEFAULT_DATA_EVAL_THRESHOLD = 50
_DEFAULT_EVAL_TRAIN_THRESHOLD = 0.05
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_thresholds() -> tuple[int, float]:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
pipeline = raw.get("pipeline", {}) or {}
return (
int(pipeline.get("data_eval_threshold", _DEFAULT_DATA_EVAL_THRESHOLD)),
float(pipeline.get("eval_train_threshold", _DEFAULT_EVAL_TRAIN_THRESHOLD)),
)
except Exception as exc:
logger.warning("Failed to read pipeline thresholds: %s", exc)
return _DEFAULT_DATA_EVAL_THRESHOLD, _DEFAULT_EVAL_TRAIN_THRESHOLD
def _load_score_records() -> list[dict]:
path = _DATA_DIR / "email_score.jsonl"
if not path.exists():
return []
records = []
for line in path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
records.append(json.loads(line))
except json.JSONDecodeError:
pass
return records
def _find_latest_eval(results_dir_override: str = "") -> tuple[str | None, float | None]:
"""Return (iso_timestamp, best_macro_f1) from the newest bench_results summary.
Checks results_dir from cforch config if set, then falls back to
_ROOT/bench_results/. Returns (None, None) if no results exist.
"""
candidates = []
if results_dir_override:
candidates.append(Path(results_dir_override))
else:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
if rd:
candidates.append(Path(rd))
except Exception as exc:
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
candidates.append(_ROOT / "bench_results")
for rdir in candidates:
if not rdir.exists():
continue
subdirs = sorted([d for d in rdir.iterdir() if d.is_dir()], key=lambda d: d.name)
for subdir in reversed(subdirs):
summary = subdir / "summary.json"
if summary.exists():
try:
data = json.loads(summary.read_text(encoding="utf-8"))
ts = data.get("timestamp") or subdir.name
score = data.get("best_macro_f1") or data.get("macro_f1")
return ts, (float(score) if isinstance(score, (int, float)) else None)
except Exception as exc:
logger.warning("Failed to parse summary.json at %s: %s", summary, exc)
return None, None
def _count_corrections() -> tuple[int, int]:
"""Return (pending_count, export_ready_count)."""
pending = 0
export_ready = 0
candidates_path = _DATA_DIR / "sft_candidates.jsonl"
approved_path = _DATA_DIR / "sft_approved.jsonl"
if candidates_path.exists():
for line in candidates_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
r = json.loads(line)
if r.get("status") == "needs_review":
pending += 1
except json.JSONDecodeError:
pass
if approved_path.exists():
for line in approved_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
r = json.loads(line)
if (r.get("status") == "approved"
and r.get("corrected_response")
and str(r["corrected_response"]).strip()):
export_ready += 1
except json.JSONDecodeError:
pass
return pending, export_ready
def _get_active_jobs() -> list[dict]:
"""Query train SQLite DB for queued/running jobs. Returns [] if DB absent."""
try:
from app.train.train import _DB_PATH, _db, _init_db
if not _DB_PATH.exists():
return []
_init_db()
with _db() as conn:
rows = conn.execute(
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
).fetchall()
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
except Exception as exc:
logger.warning("Failed to query train jobs DB: %s", exc)
return []
def _count_labeled_since(since_ts: str | None) -> int:
records = _load_score_records()
if since_ts is None:
return len(records)
return sum(1 for r in records if r.get("labeled_at", "") > since_ts)
@router.get("/dashboard")
def get_dashboard() -> dict:
data_eval_threshold, eval_train_threshold = _load_thresholds()
last_eval_ts, last_eval_score = _find_latest_eval()
labeled_since = _count_labeled_since(last_eval_ts)
corrections_pending, corrections_export_ready = _count_corrections()
active_jobs = _get_active_jobs()
return {
"labeled_since_last_eval": labeled_since,
"last_eval_timestamp": last_eval_ts,
"last_eval_best_score": last_eval_score,
"active_jobs": active_jobs,
"corrections_pending": corrections_pending,
"corrections_export_ready": corrections_export_ready,
"signals": {
"data_to_eval": labeled_since >= data_eval_threshold,
"eval_to_train": False, # future: implement delta-F1 comparison
"train_to_fleet": False, # future: implement fleet sync signal
},
}

0
app/data/__init__.py Normal file
View file

393
app/data/corrections.py Normal file
View file

@ -0,0 +1,393 @@
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
All endpoints are registered on `router` (a FastAPI APIRouter).
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
testability pattern as api.py -- override them via set_data_dir() and
set_config_dir() in test fixtures.
"""
from __future__ import annotations
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, Header, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
router = APIRouter()
# -- Testability seams ---------------------------------------------------------
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Internal helpers ----------------------------------------------------------
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
def set_default_bench_results_dir(path: str) -> None:
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
global _DEFAULT_BENCH_RESULTS_DIR
_DEFAULT_BENCH_RESULTS_DIR = path
def _get_bench_results_dir() -> Path:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
d = raw.get("sft", {}).get("bench_results_dir", "")
if d:
return Path(d)
except yaml.YAMLError as exc:
logger.warning("Failed to parse SFT config %s: %s", f, exc)
return Path(_DEFAULT_BENCH_RESULTS_DIR)
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _approved_file() -> Path:
return _DATA_DIR / "sft_approved.jsonl"
def _read_candidates() -> list[dict]:
return read_jsonl(_candidates_file())
def _write_candidates(records: list[dict]) -> None:
write_jsonl(_candidates_file(), records)
def _is_exportable(r: dict) -> bool:
"""Return True if an approved record is ready to include in SFT export."""
return (
r.get("status") == "approved"
and bool(r.get("corrected_response"))
and str(r["corrected_response"]).strip() != ""
)
# -- GET /runs -----------------------------------------------------------------
@router.get("/runs")
def get_runs():
"""List available benchmark runs in the configured bench_results_dir."""
from scripts.sft_import import discover_runs
bench_dir = _get_bench_results_dir()
existing = _read_candidates()
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
imported_run_ids = {
r["benchmark_run_id"]
for r in existing
if r.get("benchmark_run_id") is not None
}
runs = discover_runs(bench_dir)
return [
{
"run_id": r["run_id"],
"timestamp": r["timestamp"],
"candidate_count": r["candidate_count"],
"already_imported": r["run_id"] in imported_run_ids,
}
for r in runs
]
# -- POST /import --------------------------------------------------------------
class ImportRequest(BaseModel):
run_id: str
@router.post("/import")
def post_import(req: ImportRequest):
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
from scripts.sft_import import discover_runs, import_run
bench_dir = _get_bench_results_dir()
runs = discover_runs(bench_dir)
run = next((r for r in runs if r["run_id"] == req.run_id), None)
if run is None:
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
return import_run(run["sft_path"], _DATA_DIR)
# -- GET /queue ----------------------------------------------------------------
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
"""Return paginated needs_review candidates."""
records = _read_candidates()
pending = [r for r in records if r.get("status") == "needs_review"]
start = (page - 1) * per_page
return {
"items": pending[start:start + per_page],
"total": len(pending),
"page": page,
"per_page": per_page,
}
# -- POST /submit --------------------------------------------------------------
FailureCategory = Literal[
"scoring_artifact",
"style_violation",
"partial_answer",
"wrong_answer",
"format_error",
"hallucination",
]
class SubmitRequest(BaseModel):
id: str
action: Literal["correct", "discard", "flag"]
corrected_response: str | None = None
failure_category: FailureCategory | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
"""Record a reviewer decision for one SFT candidate."""
if req.action == "correct":
if not req.corrected_response or not req.corrected_response.strip():
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
if record.get("status") != "needs_review":
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
if req.action == "correct":
records[idx] = {
**record,
"status": "approved",
"corrected_response": req.corrected_response,
"failure_category": req.failure_category,
}
_write_candidates(records)
append_jsonl(_approved_file(), records[idx])
elif req.action == "discard":
records[idx] = {**record, "status": "discarded"}
_write_candidates(records)
else: # flag
records[idx] = {**record, "status": "model_rejected"}
_write_candidates(records)
return {"ok": True}
# -- POST /undo ----------------------------------------------------------------
class UndoRequest(BaseModel):
id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
"""Restore a previously actioned candidate back to needs_review."""
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
old_status = record.get("status")
if old_status == "needs_review":
raise HTTPException(409, "Record is already in needs_review state")
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
_write_candidates(records)
# If it was approved, remove from the approved file too
if old_status == "approved":
approved = read_jsonl(_approved_file())
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
return {"ok": True}
# -- GET /export ---------------------------------------------------------------
@router.get("/export")
def get_export() -> StreamingResponse:
"""Stream approved records as SFT-ready JSONL for download."""
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
def generate():
for r in exportable:
record = {
"messages": r.get("prompt_messages", []) + [
{"role": "assistant", "content": r["corrected_response"]}
]
}
yield json.dumps(record) + "\n"
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
},
)
# -- GET /stats ----------------------------------------------------------------
@router.get("/stats")
def get_stats() -> dict[str, object]:
"""Return counts by status, model, and task type."""
records = _read_candidates()
by_status: dict[str, int] = {}
by_model: dict[str, int] = {}
by_task_type: dict[str, int] = {}
for r in records:
status = r.get("status", "unknown")
by_status[status] = by_status.get(status, 0) + 1
model = r.get("model_name", "unknown")
by_model[model] = by_model.get(model, 0) + 1
task_type = r.get("task_type", "unknown")
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
approved = read_jsonl(_approved_file())
export_ready = sum(1 for r in approved if _is_exportable(r))
return {
"total": len(records),
"by_status": by_status,
"by_model": by_model,
"by_task_type": by_task_type,
"export_ready": export_ready,
}
# -- GET /config ---------------------------------------------------------------
@router.get("/config")
def get_sft_config() -> dict:
"""Return the current SFT configuration (bench_results_dir)."""
f = _config_file()
if not f.exists():
return {"bench_results_dir": ""}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
return {"bench_results_dir": ""}
sft_section = raw.get("sft") or {}
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
class SftConfigPayload(BaseModel):
bench_results_dir: str
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
"""Write the bench_results_dir setting to the config file."""
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
raw = raw or {}
except yaml.YAMLError:
raw = {}
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
tmp.rename(f)
return {"ok": True}
# -- POST /ingest --------------------------------------------------------------
class IngestRequest(BaseModel):
source: str # e.g. "peregrine", "kiwi"
task_type: str # e.g. "email_classification", "recipe_suggestion"
prompt: str # the prompt that was sent to the LLM
response: str # the LLM's original response
correction: str # the human-corrected response
label: str | None = None # optional label/category
@router.post("/ingest")
def post_ingest(
req: IngestRequest,
authorization: str | None = Header(default=None),
) -> dict:
"""Ingest a correction from a sibling CF product.
Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>
Creates a sft_candidates record with status='approved' (pre-approved by
the calling product -- human review already happened upstream). Also writes
to sft_approved.jsonl so it is immediately included in export counts.
Returns {"ok": True, "id": "<uuid>"}.
"""
expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
if not expected_secret:
raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(401, "Missing or malformed Authorization header")
token = authorization.removeprefix("Bearer ").strip()
if token != expected_secret:
raise HTTPException(403, "Invalid ingestion secret")
record_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
record = {
"id": record_id,
"source": req.source,
"task_type": req.task_type,
"status": "approved",
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": req.response,
"corrected_response": req.correction,
"label": req.label,
"timestamp": now,
"benchmark_run_id": None,
}
append_jsonl(_candidates_file(), record)
append_jsonl(_approved_file(), record)
return {"ok": True, "id": record_id}

243
app/data/fetch.py Normal file
View file

@ -0,0 +1,243 @@
"""Avocet -- IMAP fetch utilities and fetch API routes.
All IMAP helper functions (from app/imap_fetch.py) plus the
/api/accounts/test and /api/fetch/stream endpoints.
"""
from __future__ import annotations
import email as _email_lib
import hashlib
import imaplib
import json
import yaml
from datetime import datetime, timedelta
from email.header import decode_header as _raw_decode
from pathlib import Path
from typing import Iterator
from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import extract_body, read_jsonl, write_jsonl
_ROOT = Path(__file__).parent.parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
router = APIRouter()
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _queue_file() -> Path:
return _DATA_DIR / "email_label_queue.jsonl"
def _get_config_accounts() -> list[dict]:
f = _config_file()
if not f.exists():
return []
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return raw.get("accounts", [])
# ── IMAP decode helpers ───────────────────────────────────────────────────────
def _decode_str(value: str | None) -> str:
if not value:
return ""
parts = _raw_decode(value)
out = []
for part, enc in parts:
if isinstance(part, bytes):
out.append(part.decode(enc or "utf-8", errors="replace"))
else:
out.append(str(part))
return " ".join(out).strip()
def entry_key(e: dict) -> str:
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
# ── Wide search terms ────────────────────────────────────────────────────────
_WIDE_TERMS = [
"interview", "phone screen", "video call", "zoom link", "schedule a call",
"offer letter", "job offer", "offer of employment", "pleased to offer",
"unfortunately", "not moving forward", "other candidates", "regret to inform",
"no longer", "decided not to", "decided to go with",
"opportunity", "interested in your background", "reached out", "great fit",
"exciting role", "love to connect",
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
"application received", "thank you for applying", "application confirmation",
"you applied", "your application for",
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
"new jobs", "job alert",
"came across your profile", "reaching out about", "great fit for a role",
"exciting opportunity",
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
"application", "recruiter", "recruiting", "hiring", "candidate",
]
# ── Public API ────────────────────────────────────────────────────────────────
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
host = acc.get("host", "")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc.get("username", "")
password = acc.get("password", "")
folder = acc.get("folder", "INBOX")
if not host or not username or not password:
return False, "Host, username, and password are all required.", None
try:
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
_, data = conn.select(folder, readonly=True)
count_raw = data[0].decode() if data and data[0] else "0"
count = int(count_raw) if count_raw.isdigit() else 0
conn.logout()
return True, f"Connected — {count:,} message(s) in {folder}.", count
except Exception as exc:
return False, str(exc), None
def fetch_account_stream(
acc: dict,
days_back: int,
limit: int,
known_keys: set[str],
) -> Iterator[dict]:
"""Generator — yields progress dicts while fetching emails via IMAP.
Mutates `known_keys` in place for cross-account dedup within one fetch session.
Yields event dicts with "type" key:
{"type": "start", "account": str, "total_uids": int}
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
"""
name = acc.get("name", acc.get("username", "?"))
host = acc.get("host", "imap.gmail.com")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc["username"]
password = acc["password"]
folder = acc.get("folder", "INBOX")
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
conn.select(folder, readonly=True)
seen_uids: dict[bytes, None] = {}
for term in _WIDE_TERMS:
try:
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
for uid in (data[0] or b"").split():
seen_uids[uid] = None
except Exception:
pass
uids = list(seen_uids.keys())[: limit * 3]
yield {"type": "start", "account": name, "total_uids": len(uids)}
emails: list[dict] = []
skipped = 0
for i, uid in enumerate(uids):
if len(emails) >= limit:
break
if i % 5 == 0:
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
try:
_, raw_data = conn.fetch(uid, "(RFC822)")
if not raw_data or not raw_data[0]:
continue
msg = _email_lib.message_from_bytes(raw_data[0][1])
subj = _decode_str(msg.get("Subject", ""))
from_addr = _decode_str(msg.get("From", ""))
date = _decode_str(msg.get("Date", ""))
body = extract_body(msg)[:800]
entry = {"subject": subj, "body": body, "from_addr": from_addr,
"date": date, "account": name}
k = entry_key(entry)
if k not in known_keys:
known_keys.add(k)
emails.append(entry)
else:
skipped += 1
except Exception:
skipped += 1
try:
conn.logout()
except Exception:
pass
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
"emails": emails}
class AccountTestRequest(BaseModel):
account: dict
@router.post("/accounts/test")
def test_account_route(req: AccountTestRequest) -> dict:
ok, message, count = test_connection(req.account)
return {"ok": ok, "message": message, "count": count}
@router.get("/fetch/stream")
def fetch_stream(
accounts: str = Query(default=""),
days_back: int = Query(default=90, ge=1, le=365),
limit: int = Query(default=150, ge=1, le=1000),
mode: str = Query(default="wide"),
) -> StreamingResponse:
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
all_accounts = _get_config_accounts()
selected = [a for a in all_accounts if a.get("name") in selected_names]
def generate():
known_keys = {entry_key(x) for x in read_jsonl(_queue_file())}
total_added = 0
for acc in selected:
try:
batch_emails: list[dict] = []
for event in fetch_account_stream(acc, days_back, limit, known_keys):
if event["type"] == "done":
batch_emails = event.pop("emails", [])
total_added += event["added"]
yield f"data: {json.dumps(event)}\n\n"
if batch_emails:
existing = read_jsonl(_queue_file())
write_jsonl(_queue_file(), existing + batch_emails)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'account': acc.get('name', '?'), 'message': str(exc)})}\n\n"
queue_size = len(read_jsonl(_queue_file()))
yield f"data: {json.dumps({'type': 'complete', 'total_added': total_added, 'queue_size': queue_size})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

644
app/data/imitate.py Normal file
View file

@ -0,0 +1,644 @@
"""Avocet — Imitate tab API.
Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs (ollama), and streams responses back to the UI. Results can be
pushed into the SFT corrections queue for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_imitate_config() -> dict:
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse imitate config %s: %s", f, exc)
return {}
return raw.get("imitate", {}) or {}
def _load_cforch_config() -> dict:
"""Read cforch section for ollama_url fallback."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
return {}
return raw.get("cforch", {}) or {}
def _ollama_url(cfg: dict) -> str:
cforch = _load_cforch_config()
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
def _cforch_url() -> str:
cforch = _load_cforch_config()
return cforch.get("coordinator_url") or "http://localhost:7700"
def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch.
Filters out proxy entries (ollama://, vllm://, http://) those models are
served by their own services and should not be allocated via cf-text.
Returns only models with real file-system paths that cf-text can load directly.
"""
try:
resp = httpx.get(
f"{cforch_base}/api/services/cf-text/catalog",
params={"node_id": "heimdall"},
timeout=5.0,
)
resp.raise_for_status()
raw = resp.json()
result = []
for model_id, entry in raw.items():
if not isinstance(entry, dict):
continue
path = entry.get("path", "")
# Skip proxy entries — they're routed through other services
if "://" in path:
continue
result.append({
"id": model_id,
"vram_mb": entry.get("vram_mb", 0),
"description": entry.get("description", ""),
})
return result
except Exception as exc:
logger.warning("Could not fetch cf-orch catalog: %s", exc)
return []
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
"""Return True if the product's health endpoint responds OK."""
try:
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
return bool(data)
except Exception:
return False
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
"""Download an image URL and return it as a base64 string for ollama.
Returns empty string on any failure a missing image is non-fatal;
the model will still run against the text prompt alone.
"""
try:
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
with urlopen(req, timeout=10) as resp:
return base64.b64encode(resp.read()).decode("ascii")
except Exception as exc:
logger.warning("Failed to fetch image %s: %s", image_url, exc)
return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _run_cftext(
cforch_base: str,
model_id: str,
prompt: str,
system: str,
temperature: float,
startup_timeout_s: float = 180.0,
user_id: str | None = None,
) -> tuple[str, int, bool]:
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
Raises RuntimeError on allocation failure or generation error.
cold_started=True means the service was launched from scratch (caller may log this).
Cold-start detection uses coordinator state signals (running/stopped) rather than
polling the service health endpoint this fails fast on model load errors instead
of waiting out the full timeout.
"""
# Allocate
alloc_resp = httpx.post(
f"{cforch_base}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "imitate",
**({"user_id": user_id} if user_id else {}),
},
timeout=30.0,
)
alloc_resp.raise_for_status()
data = alloc_resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
cold_started = data.get("started", False) and not data.get("warm", True)
# Wait for ready using coordinator state signals
if cold_started:
deadline = time.monotonic() + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
try:
status = httpx.get(
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
break
elif state == "stopped":
if allocation_id:
httpx.delete(
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
timeout=5.0,
)
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
else:
probe_misses += 1
if probe_misses >= 6:
# Coordinator hasn't registered instance yet — fall back to health poll
try:
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
break
except Exception:
pass
except RuntimeError:
raise
except Exception:
pass
time.sleep(2.0)
else:
if allocation_id:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
# Generate
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
t0 = time.time()
try:
gen_resp = httpx.post(
f"{service_url}/v1/chat/completions",
json={
"model": model_id,
"messages": messages,
"max_tokens": 300,
"temperature": temperature,
"stream": False,
},
timeout=120.0,
)
gen_resp.raise_for_status()
elapsed_ms = int((time.time() - t0) * 1000)
content = gen_resp.json()["choices"][0]["message"]["content"]
return content.strip(), elapsed_ms, cold_started
except Exception as exc:
elapsed_ms = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
finally:
if allocation_id:
try:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
except Exception:
pass
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
"""List configured CF products with live online status."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
products = []
for p in products_raw:
if not isinstance(p, dict):
continue
base_url = p.get("base_url", "")
products.append({
"id": p.get("id", ""),
"name": p.get("name", ""),
"icon": p.get("icon", "📦"),
"description": p.get("description", ""),
"base_url": base_url,
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
})
return {"products": products}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
"""Fetch a real sample from the given product's API."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
product: dict | None = None
for p in products_raw:
if isinstance(p, dict) and p.get("id") == product_id:
product = p
break
if product is None:
raise HTTPException(404, f"Product '{product_id}' not in config")
base_url = product.get("base_url", "").rstrip("/")
endpoint = product.get("sample_endpoint", "")
if not base_url or not endpoint:
raise HTTPException(422, "Product missing base_url or sample_endpoint")
url = f"{base_url}{endpoint}"
try:
raw = _http_get_json(url, timeout=5)
except URLError as exc:
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
except Exception as exc:
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
text_fields = product.get("text_fields", []) or []
sample_key = product.get("sample_key") or None
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
if not extracted:
raise HTTPException(404, "No sample items returned by product API")
prompt_template = product.get("prompt_template", "{text}")
prompt = prompt_template.replace("{text}", extracted["text"])
# Also substitute any {field_name} placeholders from the raw item fields.
item = extracted.get("item", {})
for field, val in item.items():
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
# Expose system_prompt and image_url if the product API returns them.
# system_prompt: Peregrine, Snipe (vision analysis instructions)
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
item = extracted.get("item", {})
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
return {
"product_id": product_id,
"sample_index": index,
"text": extracted["text"],
"prompt": prompt,
"system_prompt": system_prompt,
"image_url": image_url,
"raw_item": item,
}
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
"""Return the live cf-text model catalog from cf-orch coordinator."""
models = _cforch_catalog(_cforch_url())
return {"models": models}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
"""Optional session dependency — returns None when cloud_session is unavailable."""
try:
from app.cloud_session import get_session
return get_session(request, response)
except Exception:
return None
@router.get("/run")
def run_imitate(
prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
temperature: float = 0.7,
product_id: str = "",
system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models
session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
"""Run a prompt through selected ollama models and stream results as SSE.
If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to
evaluate listing photos the same way Snipe's background task pipeline does.
"""
if not prompt.strip():
raise HTTPException(422, "prompt is required")
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg)
cforch_base = _cforch_url()
system_ctx = system.strip() or ""
total_models = len(ollama_ids) + len(cftext_ids)
# Download image once before streaming — shared across ollama vision models
images: list[str] = []
if image_url.strip():
b64 = _fetch_image_b64(image_url.strip())
if b64:
images = [b64]
def generate():
results: list[dict] = []
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
# Ollama models
for model_id in ollama_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
try:
response, elapsed_ms = _run_ollama_streaming(
ollama_base, model_id, prompt, temperature,
system=system_ctx, images=images or None,
)
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected
if cftext_ids:
from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
_user_id: str | None = getattr(session, "user_id", None)
# Only forward real cloud user IDs — skip local/anon sessions
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
_user_id = None
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
future_to_model = {
pool.submit(
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
180.0, _user_id,
): mid
for mid in cftext_ids
}
for future in as_completed(future_to_model):
model_id = future_to_model[future]
try:
response, elapsed_ms, cold_started = future.result()
if cold_started:
yield _sse({"type": "model_coldstart", "model": model_id})
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
yield _sse({"type": "complete", "results": results})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
model: str
response: str
elapsed_ms: int
error: str | None = None
class PushCorrectionsRequest(BaseModel):
product_id: str
prompt: str
results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
"""Append imitate results to sft_candidates.jsonl for human review."""
if not req.prompt.strip():
raise HTTPException(422, "prompt is required")
if not req.results:
raise HTTPException(422, "results list is empty")
ts = datetime.now(timezone.utc).isoformat()
records = []
for r in req.results:
if r.error or not r.response.strip():
continue
records.append({
"id": str(uuid.uuid4()),
"source": "imitate",
"product_id": req.product_id,
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": r.response,
"model_id": r.model,
"elapsed_ms": r.elapsed_ms,
"status": "pending",
"created_at": ts,
})
if not records:
raise HTTPException(422, "No non-error results to push")
dest = _candidates_file()
dest.parent.mkdir(parents=True, exist_ok=True)
for record in records:
append_jsonl(dest, record)
return {"pushed": len(records)}

222
app/data/label.py Normal file
View file

@ -0,0 +1,222 @@
"""Avocet -- label queue API.
All label/skip/discard/undo/stats/config endpoints.
Extracted from app/api.py as part of the v2 domain split.
"""
from __future__ import annotations
import hashlib
import json
import yaml
from datetime import datetime, timezone
from pathlib import Path
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
_ROOT = Path(__file__).parent.parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
_last_action: dict | None = None
router = APIRouter()
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def reset_last_action() -> None:
global _last_action
_last_action = None
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _queue_file() -> Path:
return _DATA_DIR / "email_label_queue.jsonl"
def _score_file() -> Path:
return _DATA_DIR / "email_score.jsonl"
def _discarded_file() -> Path:
return _DATA_DIR / "discarded.jsonl"
def _item_id(item: dict) -> str:
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
def _normalize(item: dict) -> dict:
return {
"id": item.get("id") or _item_id(item),
"subject": item.get("subject", ""),
"body": item.get("body", ""),
"from": item.get("from") or item.get("from_addr", ""),
"date": item.get("date", ""),
"source": item.get("source") or item.get("account", ""),
}
_LABEL_META = [
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
{"name": "rejected", "emoji": "", "color": "#F44336", "key": "3"},
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
{"name": "neutral", "emoji": "", "color": "#607D8B", "key": "6"},
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
]
@router.get("/queue")
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
items = read_jsonl(_queue_file())
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
class LabelRequest(BaseModel):
id: str
label: str
@router.post("/label")
def post_label(req: LabelRequest):
global _last_action
items = read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": req.label,
"labeled_at": datetime.now(timezone.utc).isoformat()}
append_jsonl(_score_file(), record)
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "label", "item": match, "label": req.label}
return {"ok": True}
class SkipRequest(BaseModel):
id: str
@router.post("/skip")
def post_skip(req: SkipRequest):
global _last_action
items = read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
write_jsonl(_queue_file(), reordered)
_last_action = {"type": "skip", "item": match}
return {"ok": True}
class DiscardRequest(BaseModel):
id: str
@router.post("/discard")
def post_discard(req: DiscardRequest):
global _last_action
items = read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": "__discarded__",
"discarded_at": datetime.now(timezone.utc).isoformat()}
append_jsonl(_discarded_file(), record)
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "discard", "item": match}
return {"ok": True}
@router.delete("/label/undo")
def delete_undo():
global _last_action
if not _last_action:
raise HTTPException(404, "No action to undo")
action = _last_action
item = action["item"]
if action["type"] == "label":
records = read_jsonl(_score_file())
if not records:
raise HTTPException(409, "Score file is empty -- cannot undo label")
write_jsonl(_score_file(), records[:-1])
items = read_jsonl(_queue_file())
write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "discard":
records = read_jsonl(_discarded_file())
if not records:
raise HTTPException(409, "Discarded file is empty -- cannot undo discard")
write_jsonl(_discarded_file(), records[:-1])
items = read_jsonl(_queue_file())
write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "skip":
items = read_jsonl(_queue_file())
item_id = _normalize(item)["id"]
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
write_jsonl(_queue_file(), items)
_last_action = None
return {"undone": {"type": action["type"], "item": _normalize(item)}}
@router.get("/config/labels")
def get_labels():
return _LABEL_META
@router.get("/config")
def get_config():
f = _config_file()
if not f.exists():
return {"accounts": [], "max_per_account": 500}
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
class ConfigPayload(BaseModel):
accounts: list[dict]
max_per_account: int = 500
@router.post("/config")
def post_config(payload: ConfigPayload):
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
encoding="utf-8")
tmp.rename(f)
return {"ok": True}
@router.get("/stats")
def get_stats():
records = read_jsonl(_score_file())
counts: dict[str, int] = {}
for r in records:
lbl = r.get("label", "")
if lbl:
counts[lbl] = counts.get(lbl, 0) + 1
benchmark_results: dict = {}
benchmark_path = _DATA_DIR / "benchmark_results.json"
if benchmark_path.exists():
try:
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
except Exception:
pass
return {
"total": len(records),
"counts": counts,
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
"benchmark_results": benchmark_results,
}
@router.get("/stats/download")
def download_stats():
if not _score_file().exists():
raise HTTPException(404, "No score file")
return FileResponse(
str(_score_file()),
filename="email_score.jsonl",
media_type="application/jsonlines",
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
)

0
app/eval/__init__.py Normal file
View file

38
app/eval/cforch.py Normal file
View file

@ -0,0 +1,38 @@
"""Avocet -- eval router aggregator.
Collects benchmark sub-routers into a single importable `router`
for the api.py factory. Each sub-router retains its established prefix
so no frontend URL changes are needed.
Route prefixes when mounted at /api in api.py:
/api/cforch/* -- cf-orch benchmark routes
/api/style/* -- writing style benchmark routes
/api/voice/* -- voice benchmark routes
/api/plans-bench/* -- plans benchmark routes
"""
from __future__ import annotations
from fastapi import APIRouter
from app.cforch import router as _cforch_router
from app.style import router as _style_router
from app.voice import router as _voice_router
from app.plans_bench import router as _plans_router
router = APIRouter()
router.include_router(_cforch_router, prefix="/cforch")
router.include_router(_style_router, prefix="/style")
router.include_router(_voice_router, prefix="/voice")
router.include_router(_plans_router, prefix="/plans-bench")
def set_config_dir(path) -> None:
"""Propagate config dir override to all sub-modules -- used by tests."""
import app.cforch as _cforch_mod
import app.style as _style_mod
import app.voice as _voice_mod
import app.plans_bench as _plans_mod
_cforch_mod.set_config_dir(path)
_style_mod.set_config_dir(path)
_voice_mod.set_config_dir(path)
_plans_mod.set_config_dir(path)

View file

@ -1,158 +1,9 @@
"""Avocet — IMAP fetch utilities. """Backward-compat shim -- logic moved to app/data/fetch.py."""
import imaplib # noqa: F401 -- re-exported so existing patch("app.imap_fetch.imaplib...") calls still work
Shared between app/api.py (FastAPI SSE endpoint) and the label UI. from app.data.fetch import ( # noqa: F401
No Streamlit imports here stdlib + imaplib only. entry_key,
""" fetch_account_stream,
from __future__ import annotations test_connection,
_decode_str,
import email as _email_lib _WIDE_TERMS,
import hashlib )
import imaplib
from datetime import datetime, timedelta
from email.header import decode_header as _raw_decode
from typing import Any, Iterator
from app.utils import extract_body, strip_html # noqa: F401 (strip_html re-exported for callers)
# ── IMAP decode helpers ───────────────────────────────────────────────────────
def _decode_str(value: str | None) -> str:
if not value:
return ""
parts = _raw_decode(value)
out = []
for part, enc in parts:
if isinstance(part, bytes):
out.append(part.decode(enc or "utf-8", errors="replace"))
else:
out.append(str(part))
return " ".join(out).strip()
def entry_key(e: dict) -> str:
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
# ── Wide search terms ────────────────────────────────────────────────────────
_WIDE_TERMS = [
"interview", "phone screen", "video call", "zoom link", "schedule a call",
"offer letter", "job offer", "offer of employment", "pleased to offer",
"unfortunately", "not moving forward", "other candidates", "regret to inform",
"no longer", "decided not to", "decided to go with",
"opportunity", "interested in your background", "reached out", "great fit",
"exciting role", "love to connect",
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
"application received", "thank you for applying", "application confirmation",
"you applied", "your application for",
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
"new jobs", "job alert",
"came across your profile", "reaching out about", "great fit for a role",
"exciting opportunity",
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
"application", "recruiter", "recruiting", "hiring", "candidate",
]
# ── Public API ────────────────────────────────────────────────────────────────
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
host = acc.get("host", "")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc.get("username", "")
password = acc.get("password", "")
folder = acc.get("folder", "INBOX")
if not host or not username or not password:
return False, "Host, username, and password are all required.", None
try:
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
_, data = conn.select(folder, readonly=True)
count_raw = data[0].decode() if data and data[0] else "0"
count = int(count_raw) if count_raw.isdigit() else 0
conn.logout()
return True, f"Connected — {count:,} message(s) in {folder}.", count
except Exception as exc:
return False, str(exc), None
def fetch_account_stream(
acc: dict,
days_back: int,
limit: int,
known_keys: set[str],
) -> Iterator[dict]:
"""Generator — yields progress dicts while fetching emails via IMAP.
Mutates `known_keys` in place for cross-account dedup within one fetch session.
Yields event dicts with "type" key:
{"type": "start", "account": str, "total_uids": int}
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
"""
name = acc.get("name", acc.get("username", "?"))
host = acc.get("host", "imap.gmail.com")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc["username"]
password = acc["password"]
folder = acc.get("folder", "INBOX")
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
conn.select(folder, readonly=True)
seen_uids: dict[bytes, None] = {}
for term in _WIDE_TERMS:
try:
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
for uid in (data[0] or b"").split():
seen_uids[uid] = None
except Exception:
pass
uids = list(seen_uids.keys())[: limit * 3]
yield {"type": "start", "account": name, "total_uids": len(uids)}
emails: list[dict] = []
skipped = 0
for i, uid in enumerate(uids):
if len(emails) >= limit:
break
if i % 5 == 0:
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
try:
_, raw_data = conn.fetch(uid, "(RFC822)")
if not raw_data or not raw_data[0]:
continue
msg = _email_lib.message_from_bytes(raw_data[0][1])
subj = _decode_str(msg.get("Subject", ""))
from_addr = _decode_str(msg.get("From", ""))
date = _decode_str(msg.get("Date", ""))
body = extract_body(msg)[:800]
entry = {"subject": subj, "body": body, "from_addr": from_addr,
"date": date, "account": name}
k = entry_key(entry)
if k not in known_keys:
known_keys.add(k)
emails.append(entry)
else:
skipped += 1
except Exception:
skipped += 1
try:
conn.logout()
except Exception:
pass
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
"emails": emails}

View file

@ -1,644 +1,3 @@
"""Avocet — Imitate tab API. """Backward-compat shim -- logic moved to app/data/imitate.py."""
from app.data.imitate import router # noqa: F401
Fetches real samples from sibling CF product APIs, sends them through selected from app.data.imitate import set_config_dir, set_data_dir # noqa: F401
local LLMs (ollama), and streams responses back to the UI. Results can be
pushed into the SFT corrections queue for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_imitate_config() -> dict:
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse imitate config %s: %s", f, exc)
return {}
return raw.get("imitate", {}) or {}
def _load_cforch_config() -> dict:
"""Read cforch section for ollama_url fallback."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
return {}
return raw.get("cforch", {}) or {}
def _ollama_url(cfg: dict) -> str:
cforch = _load_cforch_config()
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
def _cforch_url() -> str:
cforch = _load_cforch_config()
return cforch.get("coordinator_url") or "http://localhost:7700"
def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch.
Filters out proxy entries (ollama://, vllm://, http://) those models are
served by their own services and should not be allocated via cf-text.
Returns only models with real file-system paths that cf-text can load directly.
"""
try:
resp = httpx.get(
f"{cforch_base}/api/services/cf-text/catalog",
params={"node_id": "heimdall"},
timeout=5.0,
)
resp.raise_for_status()
raw = resp.json()
result = []
for model_id, entry in raw.items():
if not isinstance(entry, dict):
continue
path = entry.get("path", "")
# Skip proxy entries — they're routed through other services
if "://" in path:
continue
result.append({
"id": model_id,
"vram_mb": entry.get("vram_mb", 0),
"description": entry.get("description", ""),
})
return result
except Exception as exc:
logger.warning("Could not fetch cf-orch catalog: %s", exc)
return []
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
"""Return True if the product's health endpoint responds OK."""
try:
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
return bool(data)
except Exception:
return False
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
"""Download an image URL and return it as a base64 string for ollama.
Returns empty string on any failure a missing image is non-fatal;
the model will still run against the text prompt alone.
"""
try:
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
with urlopen(req, timeout=10) as resp:
return base64.b64encode(resp.read()).decode("ascii")
except Exception as exc:
logger.warning("Failed to fetch image %s: %s", image_url, exc)
return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _run_cftext(
cforch_base: str,
model_id: str,
prompt: str,
system: str,
temperature: float,
startup_timeout_s: float = 180.0,
user_id: str | None = None,
) -> tuple[str, int, bool]:
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
Raises RuntimeError on allocation failure or generation error.
cold_started=True means the service was launched from scratch (caller may log this).
Cold-start detection uses coordinator state signals (running/stopped) rather than
polling the service health endpoint this fails fast on model load errors instead
of waiting out the full timeout.
"""
# Allocate
alloc_resp = httpx.post(
f"{cforch_base}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "imitate",
**({"user_id": user_id} if user_id else {}),
},
timeout=30.0,
)
alloc_resp.raise_for_status()
data = alloc_resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
cold_started = data.get("started", False) and not data.get("warm", True)
# Wait for ready using coordinator state signals
if cold_started:
deadline = time.monotonic() + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
try:
status = httpx.get(
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
break
elif state == "stopped":
if allocation_id:
httpx.delete(
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
timeout=5.0,
)
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
else:
probe_misses += 1
if probe_misses >= 6:
# Coordinator hasn't registered instance yet — fall back to health poll
try:
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
break
except Exception:
pass
except RuntimeError:
raise
except Exception:
pass
time.sleep(2.0)
else:
if allocation_id:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
# Generate
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
t0 = time.time()
try:
gen_resp = httpx.post(
f"{service_url}/v1/chat/completions",
json={
"model": model_id,
"messages": messages,
"max_tokens": 300,
"temperature": temperature,
"stream": False,
},
timeout=120.0,
)
gen_resp.raise_for_status()
elapsed_ms = int((time.time() - t0) * 1000)
content = gen_resp.json()["choices"][0]["message"]["content"]
return content.strip(), elapsed_ms, cold_started
except Exception as exc:
elapsed_ms = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
finally:
if allocation_id:
try:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
except Exception:
pass
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
"""List configured CF products with live online status."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
products = []
for p in products_raw:
if not isinstance(p, dict):
continue
base_url = p.get("base_url", "")
products.append({
"id": p.get("id", ""),
"name": p.get("name", ""),
"icon": p.get("icon", "📦"),
"description": p.get("description", ""),
"base_url": base_url,
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
})
return {"products": products}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
"""Fetch a real sample from the given product's API."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
product: dict | None = None
for p in products_raw:
if isinstance(p, dict) and p.get("id") == product_id:
product = p
break
if product is None:
raise HTTPException(404, f"Product '{product_id}' not in config")
base_url = product.get("base_url", "").rstrip("/")
endpoint = product.get("sample_endpoint", "")
if not base_url or not endpoint:
raise HTTPException(422, "Product missing base_url or sample_endpoint")
url = f"{base_url}{endpoint}"
try:
raw = _http_get_json(url, timeout=5)
except URLError as exc:
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
except Exception as exc:
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
text_fields = product.get("text_fields", []) or []
sample_key = product.get("sample_key") or None
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
if not extracted:
raise HTTPException(404, "No sample items returned by product API")
prompt_template = product.get("prompt_template", "{text}")
prompt = prompt_template.replace("{text}", extracted["text"])
# Also substitute any {field_name} placeholders from the raw item fields.
item = extracted.get("item", {})
for field, val in item.items():
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
# Expose system_prompt and image_url if the product API returns them.
# system_prompt: Peregrine, Snipe (vision analysis instructions)
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
item = extracted.get("item", {})
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
return {
"product_id": product_id,
"sample_index": index,
"text": extracted["text"],
"prompt": prompt,
"system_prompt": system_prompt,
"image_url": image_url,
"raw_item": item,
}
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
"""Return the live cf-text model catalog from cf-orch coordinator."""
models = _cforch_catalog(_cforch_url())
return {"models": models}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
"""Optional session dependency — returns None when cloud_session is unavailable."""
try:
from app.cloud_session import get_session
return get_session(request, response)
except Exception:
return None
@router.get("/run")
def run_imitate(
prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
temperature: float = 0.7,
product_id: str = "",
system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models
session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
"""Run a prompt through selected ollama models and stream results as SSE.
If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to
evaluate listing photos the same way Snipe's background task pipeline does.
"""
if not prompt.strip():
raise HTTPException(422, "prompt is required")
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg)
cforch_base = _cforch_url()
system_ctx = system.strip() or ""
total_models = len(ollama_ids) + len(cftext_ids)
# Download image once before streaming — shared across ollama vision models
images: list[str] = []
if image_url.strip():
b64 = _fetch_image_b64(image_url.strip())
if b64:
images = [b64]
def generate():
results: list[dict] = []
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
# Ollama models
for model_id in ollama_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
try:
response, elapsed_ms = _run_ollama_streaming(
ollama_base, model_id, prompt, temperature,
system=system_ctx, images=images or None,
)
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected
if cftext_ids:
from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
_user_id: str | None = getattr(session, "user_id", None)
# Only forward real cloud user IDs — skip local/anon sessions
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
_user_id = None
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
future_to_model = {
pool.submit(
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
180.0, _user_id,
): mid
for mid in cftext_ids
}
for future in as_completed(future_to_model):
model_id = future_to_model[future]
try:
response, elapsed_ms, cold_started = future.result()
if cold_started:
yield _sse({"type": "model_coldstart", "model": model_id})
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
yield _sse({"type": "complete", "results": results})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
model: str
response: str
elapsed_ms: int
error: str | None = None
class PushCorrectionsRequest(BaseModel):
product_id: str
prompt: str
results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
"""Append imitate results to sft_candidates.jsonl for human review."""
if not req.prompt.strip():
raise HTTPException(422, "prompt is required")
if not req.results:
raise HTTPException(422, "results list is empty")
ts = datetime.now(timezone.utc).isoformat()
records = []
for r in req.results:
if r.error or not r.response.strip():
continue
records.append({
"id": str(uuid.uuid4()),
"source": "imitate",
"product_id": req.product_id,
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": r.response,
"model_id": r.model,
"elapsed_ms": r.elapsed_ms,
"status": "pending",
"created_at": ts,
})
if not records:
raise HTTPException(422, "No non-error results to push")
dest = _candidates_file()
dest.parent.mkdir(parents=True, exist_ok=True)
for record in records:
append_jsonl(dest, record)
return {"pushed": len(records)}

View file

@ -15,6 +15,7 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
import re
import shutil import shutil
import threading import threading
from datetime import datetime, timezone from datetime import datetime, timezone
@ -60,6 +61,30 @@ _CF_ORCH_PROFILES_DIR: Path = Path(
router = APIRouter() router = APIRouter()
# ── HuggingFace auth ─────────────────────────────────────────────────────────
def _get_hf_token() -> str | None:
"""Return HF token from label_tool.yaml, then HF_TOKEN / HUGGING_FACE_HUB_TOKEN env vars."""
config_file = _ROOT / "config" / "label_tool.yaml"
if config_file.exists():
try:
import yaml as _yaml
raw = _yaml.safe_load(config_file.read_text(encoding="utf-8")) or {}
token = (raw.get("hf_token") or raw.get("cforch", {}).get("hf_token") or "").strip()
if token:
return token
except Exception:
pass
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None
# ── GGUF quantization detection ───────────────────────────────────────────────
# Matches quant identifiers in GGUF filenames: Q4_K_M, Q8_0, F16, IQ3_M, etc.
_QUANT_RE = re.compile(
r'[._-]((?:IQ\d|Q\d)[A-Z0-9_]*|F16|BF16)\.gguf$',
re.IGNORECASE,
)
# ── Download progress shared state ──────────────────────────────────────────── # ── Download progress shared state ────────────────────────────────────────────
# Updated by the background download thread; read by GET /download/stream. # Updated by the background download thread; read by GET /download/stream.
_download_progress: dict[str, Any] = {} _download_progress: dict[str, Any] = {}
@ -91,12 +116,15 @@ _TAG_TO_INFO: dict[str, _TagInfo] = {
"audio-classification": {"adapter": None, "role": "classifier", "service": "cf-voice"}, "audio-classification": {"adapter": None, "role": "classifier", "service": "cf-voice"},
# TTS — cf-tts text-to-speech service # TTS — cf-tts text-to-speech service
"text-to-speech": {"adapter": None, "role": "tts", "service": "cf-tts"}, "text-to-speech": {"adapter": None, "role": "tts", "service": "cf-tts"},
# Vision — cf-vision image classification / embedding / VLM service # Vision classifiers / embedders — cf-vision (SigLIP/CLIP-style models)
"image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"}, "image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
"zero-shot-image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"}, "zero-shot-image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
"image-feature-extraction": {"adapter": None, "role": "embedding", "service": "cf-vision"}, "image-feature-extraction": {"adapter": None, "role": "embedding", "service": "cf-vision"},
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "cf-vision"}, # Generative VLMs (image+text → text) — run under vllm, not cf-vision.
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "cf-vision"}, # cf-vision is a classifier/embedder service; generative VLMs like Qwen-VL,
# LLaVA, and InternVL are textgen models that happen to accept image inputs.
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "vllm"},
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "vllm"},
# Image generation — cf-image (text → image; distinct from cf-vision image understanding) # Image generation — cf-image (text → image; distinct from cf-vision image understanding)
"text-to-image": {"adapter": None, "role": "image-gen", "service": "cf-image"}, "text-to-image": {"adapter": None, "role": "image-gen", "service": "cf-image"},
# Embedding — cf-core shared embedding layer # Embedding — cf-core shared embedding layer
@ -195,10 +223,17 @@ def _get_queue_entry(entry_id: str) -> dict | None:
def _catalog_key(repo_id: str) -> str: def _catalog_key(repo_id: str) -> str:
"""Derive a readable catalog key from repo_id. """Derive a readable catalog key from repo_id.
ibm-granite/granite-4.1-8b granite-4.1-8b ibm-granite/granite-4.1-8b granite-4.1-8b
facebook/bart-large-cnn bart-large-cnn facebook/bart-large-cnn bart-large-cnn
WithinUsAI/Opus4.7-GODs.Ghost.Codex-4B.GGuF opus4.7-gods.ghost.codex-4b
The coordinator skips catalog lookup for keys ending in ".gguf" (treats them
as direct file paths). Strip the suffix so GGUF repo names produce valid keys.
""" """
return repo_id.split("/", 1)[-1].lower() key = repo_id.split("/", 1)[-1].lower()
if key.endswith(".gguf"):
key = key[:-5]
return key
def _insert_catalog_entry(content: str, entry_lines: str) -> str: def _insert_catalog_entry(content: str, entry_lines: str) -> str:
@ -290,6 +325,15 @@ def _register_in_node_catalogs(
max_mb: int = cf_text.get("max_mb", 0) max_mb: int = cf_text.get("max_mb", 0)
catalog: dict = cf_text.get("catalog") or {} catalog: dict = cf_text.get("catalog") or {}
# If the node has a different local model dir, remap the NFS path.
model_base = cf_text.get("model_base_path", "").rstrip("/")
if model_base:
nfs_base = str(_CF_TEXT_MODELS_DIR).rstrip("/")
model_name = local_path.name
effective_path_str = f"{model_base}/{model_name}"
else:
effective_path_str = local_path_str
# Skip if key already exists # Skip if key already exists
if model_key in catalog: if model_key in catalog:
logger.debug("Key %r already in %s — skipping", model_key, yaml_file.name) logger.debug("Key %r already in %s — skipping", model_key, yaml_file.name)
@ -301,10 +345,10 @@ def _register_in_node_catalogs(
for entry in catalog.values() for entry in catalog.values()
if isinstance(entry, dict) if isinstance(entry, dict)
} }
if local_path_str in registered_paths or any( if effective_path_str in registered_paths or any(
p.startswith(local_path_str + "/") for p in registered_paths p.startswith(effective_path_str + "/") for p in registered_paths
): ):
logger.debug("Path %s already registered in %s — skipping", local_path_str, yaml_file.name) logger.debug("Path %s already registered in %s — skipping", effective_path_str, yaml_file.name)
continue continue
# Determine whether model fits at FP16 or needs 4-bit # Determine whether model fits at FP16 or needs 4-bit
@ -330,12 +374,18 @@ def _register_in_node_catalogs(
if needs_4bit if needs_4bit
else f" # FP16 file-size estimate" else f" # FP16 file-size estimate"
) )
env_block = (
f" env:\n"
f" CF_TEXT_4BIT: \"1\"\n"
if needs_4bit else ""
)
entry_block = ( entry_block = (
f" # auto-registered by avocet on download\n" f" # auto-registered by avocet on download\n"
f" {model_key}:\n" f" {model_key}:\n"
f" path: {local_path_str}\n" f" path: {effective_path_str}\n"
f" vram_mb: {vram_for_node}{vram_comment}\n" f" vram_mb: {vram_for_node}{vram_comment}\n"
f" description: \"{desc}\"\n" f" description: \"{desc}\"\n"
f"{env_block}"
) )
new_content = _insert_catalog_entry(content, entry_block) new_content = _insert_catalog_entry(content, entry_block)
@ -388,12 +438,17 @@ def _run_download(
role: str | None = None, role: str | None = None,
service: str | None = None, service: str | None = None,
model_size_bytes: int = 0, model_size_bytes: int = 0,
quant_pattern: str | None = None,
) -> None: ) -> None:
"""Background thread: download model via huggingface_hub.snapshot_download. """Background thread: download model via huggingface_hub.snapshot_download.
model_size_bytes is the sum of file sizes reported by the HF API (siblings). model_size_bytes is the sum of file sizes reported by the HF API (siblings).
It is used to estimate vram_mb and written to model_info.json so cf-orch can It is used to estimate vram_mb and written to model_info.json so cf-orch can
budget VRAM when allocating a cf-text instance for this model. budget VRAM when allocating a cf-text instance for this model.
quant_pattern: when set, restricts snapshot_download to only files matching
*{quant_pattern}*.gguf (plus metadata). Avoids downloading every quant variant
from GGUF-only repos like bartowski/*.
""" """
global _download_progress global _download_progress
local_dir = _model_dir_for(repo_id, service) local_dir = _model_dir_for(repo_id, service)
@ -422,10 +477,20 @@ def _run_download(
local_dir.mkdir(parents=True, exist_ok=True) local_dir.mkdir(parents=True, exist_ok=True)
poll_thread.start() poll_thread.start()
snapshot_download(
repo_id=repo_id, dl_kwargs: dict[str, Any] = {"repo_id": repo_id, "local_dir": str(local_dir)}
local_dir=str(local_dir), hf_token = _get_hf_token()
) if hf_token:
dl_kwargs["token"] = hf_token
if quant_pattern:
# Include both cases: repos use mixed conventions (Q6_K vs q6_k).
dl_kwargs["allow_patterns"] = [
f"*{quant_pattern.upper()}*.gguf",
f"*{quant_pattern.lower()}*.gguf",
"*.json",
"README.md",
]
snapshot_download(**dl_kwargs)
# Estimate VRAM from reported file size. # Estimate VRAM from reported file size.
# HF siblings sizes are pre-quantisation file sizes; add 10% for KV cache # HF siblings sizes are pre-quantisation file sizes; add 10% for KV cache
@ -531,9 +596,31 @@ def lookup_model(repo_id: str) -> dict:
) )
logger.warning("Unsupported pipeline_tag %r for %s", pipeline_tag, repo_id) logger.warning("Unsupported pipeline_tag %r for %s", pipeline_tag, repo_id)
# Estimate model size from siblings list # Detect GGUF files and parse quant names from siblings list.
# For GGUF-only repos (bartowski, TheBloke, etc.) this lets the UI show
# a per-quant size picker instead of downloading every variant.
siblings = data.get("siblings") or [] siblings = data.get("siblings") or []
model_size_bytes: int = sum(s.get("size", 0) for s in siblings if isinstance(s, dict)) gguf_files: list[dict] = []
for s in siblings:
if not isinstance(s, dict):
continue
fname: str = s.get("rfilename", "")
if not fname.lower().endswith(".gguf"):
continue
m = _QUANT_RE.search(fname)
gguf_files.append({
"filename": fname,
"size": s.get("size", 0) or 0,
"quant_name": m.group(1).upper() if m else None,
})
gguf_files.sort(key=lambda f: f["size"])
# model_size_bytes: total of all siblings (for non-GGUF repos) or all GGUFs only.
# For GGUF repos the frontend will substitute the selected quant's size on submit.
if gguf_files:
model_size_bytes: int = sum(f["size"] for f in gguf_files)
else:
model_size_bytes = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
# Description: first 300 chars of card data (modelId field used as fallback) # Description: first 300 chars of card data (modelId field used as fallback)
card_data = data.get("cardData") or {} card_data = data.get("cardData") or {}
@ -549,6 +636,7 @@ def lookup_model(repo_id: str) -> dict:
"compatible": compatible, "compatible": compatible,
"warning": warning, "warning": warning,
"model_size_bytes": model_size_bytes, "model_size_bytes": model_size_bytes,
"gguf_files": gguf_files if gguf_files else None,
"description": description, "description": description,
"tags": data.get("tags") or [], "tags": data.get("tags") or [],
"downloads": data.get("downloads") or 0, "downloads": data.get("downloads") or 0,
@ -579,6 +667,9 @@ class QueueAddRequest(BaseModel):
# Stored in the queue entry so approve can pass it to _run_download # Stored in the queue entry so approve can pass it to _run_download
# without a second HF API round-trip. # without a second HF API round-trip.
model_size_bytes: int = 0 model_size_bytes: int = 0
# GGUF quantization pattern (e.g. "Q5_K_M"). When set, snapshot_download
# restricts to *{quant_pattern}*.gguf instead of fetching all variants.
quant_pattern: str | None = None
@router.post("/queue", status_code=201) @router.post("/queue", status_code=201)
@ -597,6 +688,7 @@ def add_to_queue(req: QueueAddRequest) -> dict:
"role": req.role, "role": req.role,
"service": req.service, "service": req.service,
"model_size_bytes": req.model_size_bytes, "model_size_bytes": req.model_size_bytes,
"quant_pattern": req.quant_pattern,
"status": "pending", "status": "pending",
"queued_at": datetime.now(timezone.utc).isoformat(), "queued_at": datetime.now(timezone.utc).isoformat(),
} }
@ -629,6 +721,7 @@ def approve_queue_entry(entry_id: str) -> dict:
entry.get("role"), entry.get("role"),
entry.get("service"), entry.get("service"),
entry.get("model_size_bytes", 0), entry.get("model_size_bytes", 0),
entry.get("quant_pattern"),
), ),
daemon=True, daemon=True,
name=f"model-download-{entry_id}", name=f"model-download-{entry_id}",
@ -638,6 +731,32 @@ def approve_queue_entry(entry_id: str) -> dict:
return {"ok": True} return {"ok": True}
# ── PATCH /queue/{id} ─────────────────────────────────────────────────────────
class QueuePatchRequest(BaseModel):
service: str | None = None
role: str | None = None
@router.patch("/queue/{entry_id}")
def patch_queue_entry(entry_id: str, body: QueuePatchRequest) -> dict:
"""Update mutable fields (service, role) on a pending queue entry."""
entry = _get_queue_entry(entry_id)
if entry is None:
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
if entry.get("status") != "pending":
raise HTTPException(409, f"Only pending entries can be patched (current: {entry.get('status')!r})")
updates: dict = {}
if body.service is not None:
updates["service"] = body.service
if body.role is not None:
updates["role"] = body.role
updated = _update_queue_entry(entry_id, updates)
return updated or {}
# ── DELETE /queue/{id} ───────────────────────────────────────────────────────── # ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
@router.delete("/queue/{entry_id}") @router.delete("/queue/{entry_id}")

359
app/nodes.py Normal file
View file

@ -0,0 +1,359 @@
"""Avocet — Node Management API.
Proxies cf-orch coordinator and agent APIs to expose per-node GPU state,
service affinity management, and Ollama model management.
Config is read from label_tool.yaml under the `cforch:` key.
The `profiles_dir` key (new) points to the cf-orch node profile YAML directory.
Module-level globals follow the set_config_dir() testability pattern from cforch.py.
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from urllib.parse import urlparse
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None # override in tests
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_config() -> dict:
"""Read label_tool.yaml cforch section. Returns empty dict on missing or parse error."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse config %s: %s", f, exc)
return {}
def _profiles_dir() -> Path | None:
"""Return the cf-orch node profiles directory, or None if not configured."""
cfg = _load_config()
pd = cfg.get("profiles_dir", "") or ""
if pd:
return Path(pd)
bench = cfg.get("bench_script", "") or ""
if bench:
return Path(bench).parent.parent / "profiles" / "nodes"
return None
def _profile_path(node_id: str) -> Path | None:
"""Return the path to a node's profile YAML, or None if profiles_dir is unknown."""
pd = _profiles_dir()
if pd is None:
return None
return pd / f"{node_id}.yaml"
def _load_profile(node_id: str) -> dict | None:
"""Load and parse a node profile YAML. Returns None if not found or malformed."""
p = _profile_path(node_id)
if p is None or not p.exists():
return None
try:
return yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Malformed profile YAML %s: %s", p, exc)
return None
def _get_ollama_url(node_id: str) -> str:
"""Derive Ollama URL from the node profile's agent_url (same host, port 11434)."""
profile = _load_profile(node_id)
if profile:
nodes_section = profile.get("nodes", {}) or {}
node_entry = nodes_section.get(node_id, {}) or {}
agent_url = node_entry.get("agent_url", "") or ""
if agent_url:
parsed = urlparse(agent_url)
return f"{parsed.scheme}://{parsed.hostname}:11434"
raise HTTPException(
status_code=404,
detail=f"Cannot determine Ollama URL for node {node_id}: no agent_url in profile",
)
# ── Endpoints ──────────────────────────────────────────────────────────────────
@router.get("/nodes")
def list_nodes() -> list:
"""Return all nodes with live GPU stats merged with profile YAML."""
import httpx
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
if not coordinator_url:
return []
try:
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
r.raise_for_status()
coord_nodes: list[dict] = r.json()
except httpx.HTTPError as exc:
logger.warning("Coordinator unreachable: %s", exc)
return []
try:
sr = httpx.get(f"{coordinator_url}/api/services", timeout=5.0)
sr.raise_for_status()
services_data: list[dict] = sr.json()
except httpx.HTTPError:
logger.warning("Services API unreachable for %s, skipping", coordinator_url)
services_data = []
# Build per-node, per-GPU running services map
running: dict[str, dict[int, list[str]]] = {}
for svc in services_data:
nid = svc.get("node_id", "")
gid = svc.get("gpu_id")
svc_name = svc.get("service", "")
if nid and gid is not None and svc_name:
running.setdefault(nid, {}).setdefault(gid, []).append(svc_name)
result = []
for node in coord_nodes:
node_id = node.get("node_id", "") or node.get("id", "")
profile = _load_profile(node_id) if node_id else None
profile_loaded = profile is not None
gpus = []
for gpu in (node.get("gpus", []) or []):
gpu_id = gpu.get("gpu_id", gpu.get("id", 0))
services_assigned: list[str] = []
if profile:
node_entry = (profile.get("nodes", {}) or {}).get(node_id, {}) or {}
for g in (node_entry.get("gpus", []) or []):
if isinstance(g, dict) and g.get("id") == gpu_id:
services_assigned = g.get("services", []) or []
break
gpus.append({
"gpu_id": gpu_id,
"card": gpu.get("card", ""),
"vram_total_mb": gpu.get("vram_total_mb", 0),
"vram_used_mb": gpu.get("vram_used_mb", 0),
"vram_free_mb": gpu.get("vram_free_mb", 0),
"temp_c": gpu.get("temp_c"),
"utilization_pct": gpu.get("utilization_pct"),
"compute_cap": gpu.get("compute_cap"),
"services_assigned": services_assigned,
"services_running": running.get(node_id, {}).get(gpu_id, []),
})
services_catalog: dict = {}
if profile:
for svc_name, svc_info in (profile.get("services", {}) or {}).items():
catalog = svc_info.get("catalog", {}) or {}
services_catalog[svc_name] = {
"min_compute_cap": svc_info.get("min_compute_cap", 0.0),
"max_mb": svc_info.get("max_mb", 0),
"catalog_size": len(catalog),
}
result.append({
"node_id": node_id,
"online": node.get("online", True),
"agent_url": node.get("agent_url", ""),
"gpus": gpus,
"profile_loaded": profile_loaded,
"services_catalog": services_catalog,
})
return result
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
"""Return the full parsed profile YAML for a node."""
p = _profile_path(node_id)
if p is None or not p.exists():
raise HTTPException(404, f"No profile found for node {node_id}")
try:
data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
raise HTTPException(500, f"Malformed profile YAML: {exc}")
return data
class UpdateServicesRequest(BaseModel):
services: list[str]
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
"""Set service assignment for a GPU with compatibility validation, then atomic write."""
import httpx
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
p = _profile_path(node_id)
if p is None or not p.exists():
raise HTTPException(404, f"No profile found for node {node_id}")
try:
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
raise HTTPException(500, f"Malformed profile YAML: {exc}")
nodes_section = profile.get("nodes", {}) or {}
node_entry = nodes_section.get(node_id, {}) or {}
gpu_list = node_entry.get("gpus", []) or []
gpu_entry = next(
(g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
None,
)
if gpu_entry is None:
raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")
gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
services_def = profile.get("services", {}) or {}
for svc_name in body.services:
if svc_name not in services_def:
raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
svc = services_def[svc_name]
min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
if gpu_compute_cap < min_cap:
raise HTTPException(
422,
f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
)
catalog = svc.get("catalog", {}) or {}
min_catalog_vram = (
min((m.get("vram_mb", 0) for m in catalog.values()), default=0)
if catalog else svc.get("max_mb", 0)
)
if gpu_vram_mb < min_catalog_vram:
raise HTTPException(
422,
f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
)
# Immutable update of GPU services list
new_gpu_list = [
({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
for g in gpu_list
]
new_profile = {
**profile,
"nodes": {
**nodes_section,
node_id: {**node_entry, "gpus": new_gpu_list},
},
}
# Atomic write: write to .tmp then rename
tmp_yaml = Path(str(p) + ".tmp")
tmp_yaml.write_text(yaml.dump(new_profile, default_flow_style=False), encoding="utf-8")
os.replace(tmp_yaml, p)
# Trigger coordinator profile reload
reloaded = False
if coordinator_url:
try:
rr = httpx.post(
f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
)
reloaded = rr.status_code < 300
except Exception as exc:
logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)
return {"ok": True, "reloaded": reloaded, "warnings": []}
# ── Ollama model management ────────────────────────────────────────────────────
class PullRequest(BaseModel):
name: str
@router.get("/nodes/{node_id}/models/ollama")
def list_ollama_models(node_id: str) -> dict:
"""Proxy GET {ollama_url}/api/tags for a specific node."""
import httpx
ollama_url = _get_ollama_url(node_id)
try:
r = httpx.get(f"{ollama_url}/api/tags", timeout=10.0)
r.raise_for_status()
return r.json()
except Exception as exc:
return {"error": str(exc)}
@router.post("/nodes/{node_id}/models/ollama/pull")
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
"""Stream Ollama pull progress as SSE events."""
import httpx
if not body.name:
raise HTTPException(400, "name is required")
ollama_url = _get_ollama_url(node_id)
def stream():
try:
with httpx.stream(
"POST",
f"{ollama_url}/api/pull",
json={"name": body.name, "stream": True},
timeout=300.0,
) as resp:
for line in resp.iter_lines():
if line:
yield f"data: {line}\n\n"
except Exception as exc:
yield f"data: {json.dumps({'error': str(exc)})}\n\n"
return StreamingResponse(stream(), media_type="text/event-stream")
@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
def delete_ollama_model(node_id: str, name: str) -> dict:
"""Proxy DELETE to Ollama for a specific node."""
import httpx
ollama_url = _get_ollama_url(node_id)
try:
r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
if r.status_code == 404:
raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
r.raise_for_status()
return {"ok": True}
except HTTPException:
raise
except Exception as exc:
raise HTTPException(502, f"Ollama unreachable: {exc}")

323
app/plans_bench.py Normal file
View file

@ -0,0 +1,323 @@
"""Avocet — CF planning benchmark integration API.
Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
Connection config (api_base) is read from label_tool.yaml under the
`plans_bench:` key (optional; falls back to localhost:8080).
All endpoints are registered on `router` (FastAPI APIRouter).
api.py includes this router with prefix="/api/plans-bench".
"""
from __future__ import annotations
import json
import logging
import subprocess as _subprocess
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import httpx
import yaml
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
_BENCH_RUNNING: bool = False
_bench_proc: Any = None
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"
router = APIRouter()
# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
# Kept here so the UI can list them without importing the script.
MODEL_REGISTRY: dict[str, str] = {
"llama3.2-3b": "Llama 3.2 3B Instruct (local via cf-text)",
"llama3.2-1b": "Llama 3.2 1B Instruct (local via cf-text)",
"mistral-7b": "Mistral 7B Instruct (local via cf-text)",
"phi3-mini": "Phi-3 Mini 3.8B (local via cf-text)",
"qwen2.5-3b": "Qwen 2.5 3B Instruct (local via cf-text)",
}
RUBRIC_LABELS: dict[str, str] = {
"task_structure": "Task structure (checkboxes + commits)",
"tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
"privacy_pillar": "Privacy pillar (local-first, no logging)",
"safety_pillar": "Safety pillar (human approval, reversibility)",
"accessibility": "Accessibility (ND/adaptive users)",
"license_split": "License awareness (MIT vs BSL)",
"file_paths": "File paths (plausible project paths)",
"cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
"length_ok": "Response length (2002500 words)",
}
# ── Testability seam ───────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_config() -> dict:
f = _config_file()
cforch_cfg: dict = {}
bench_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
cforch_cfg = raw.get("cforch", {}) or {}
bench_cfg = raw.get("plans_bench", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
return {
"coordinator_url": cforch_cfg.get("coordinator_url",
bench_cfg.get("coordinator_url", "http://10.1.10.71:7700")),
"python_bin": cforch_cfg.get("python_bin",
bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")),
}
def _results_file(run_id: str) -> Path:
return _RESULTS_DIR / f"{run_id}.json"
# ── GET /models ────────────────────────────────────────────────────────────────
@router.get("/models")
def get_models() -> dict:
"""Return registered model shortcuts, live cf-orch catalog, and rubric labels."""
cfg = _load_config()
cforch_models: list[dict] = []
try:
resp = httpx.get(
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
timeout=5.0,
)
resp.raise_for_status()
for model_id, entry in resp.json().items():
if isinstance(entry, dict):
cforch_models.append({
"id": model_id,
"name": model_id,
"vram_mb": entry.get("vram_mb"),
"description": entry.get("description", ""),
})
except Exception as exc:
logger.warning("Failed to fetch cf-orch catalog: %s", exc)
return {
"registry": [
{"key": k, "description": v}
for k, v in MODEL_REGISTRY.items()
],
"cforch_models": cforch_models,
"coordinator_url": cfg["coordinator_url"],
"rubric_labels": RUBRIC_LABELS,
}
# ── GET /run ───────────────────────────────────────────────────────────────────
@router.get("/run")
def run_plans_benchmark(
models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
) -> StreamingResponse:
"""Spawn benchmark_plans.py and stream stdout as SSE progress events.
On successful completion emits a `type: result` event with parsed JSON
and saves results to data/plans_bench_results/<run_id>.json.
"""
global _BENCH_RUNNING, _bench_proc
if _BENCH_RUNNING:
raise HTTPException(409, "A planning benchmark is already running")
cfg = _load_config()
python_bin = cfg["python_bin"]
coordinator_url = cfg["coordinator_url"]
model_keys = [m.strip() for m in models.split(",") if m.strip()]
if not model_keys:
raise HTTPException(400, "At least one model key is required")
run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
output_path = _results_file(run_id)
_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
def generate():
global _BENCH_RUNNING, _bench_proc
if not _BENCH_SCRIPT.exists():
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})}\n\n"
return
cmd = [python_bin, str(_BENCH_SCRIPT)]
if len(model_keys) > 1:
cmd.extend(["--compare"] + model_keys)
else:
cmd.extend(["--model", model_keys[0]])
if use_cforch:
cmd.extend(["--cforch", "--cforch-url", coordinator_url])
elif api_base.strip():
cmd.extend(["--api-base", api_base.strip()])
cmd.extend(["--verbose", "--output", str(output_path)])
if workers > 1:
cmd.extend(["--workers", str(workers)])
if prompt_ids.strip():
cmd.extend(["--prompts"] + [p.strip() for p in prompt_ids.split(",") if p.strip()])
_BENCH_RUNNING = True
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
)
_bench_proc = proc
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0 and output_path.exists():
try:
results = json.loads(output_path.read_text(encoding="utf-8"))
yield f"data: {json.dumps({'type': 'result', 'run_id': run_id, 'results': results})}\n\n"
except Exception as exc:
logger.warning("Failed to read plans benchmark output: %s", exc)
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_bench_proc = None
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
finally:
_BENCH_RUNNING = False
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
# ── GET /results ───────────────────────────────────────────────────────────────
@router.get("/results")
def list_results() -> list[dict]:
"""List past planning benchmark runs, newest first."""
if not _RESULTS_DIR.exists():
return []
runs: list[dict] = []
for f in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
run_id = f.stem
try:
data: dict = json.loads(f.read_text(encoding="utf-8"))
model_keys = list(data.keys())
# Average total_score across all models and prompts
all_scores = [
r["total_score"]
for results in data.values()
for r in results
if not r.get("error")
]
avg_score = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
except Exception:
model_keys = []
avg_score = 0.0
# Parse display date from run_id (plans_2026-04-27_143022)
try:
date_part = run_id.removeprefix("plans_") # 2026-04-27_143022
date, time = date_part.split("_")
display_date = f"{date} {time[:2]}:{time[2:4]}"
except Exception:
display_date = run_id
runs.append({
"run_id": run_id,
"filename": f.name,
"date": display_date,
"models": model_keys,
"avg_score": avg_score,
})
return runs
@router.get("/results/latest")
def get_latest_results() -> dict:
"""Return the most recent planning benchmark results dict."""
if not _RESULTS_DIR.exists():
raise HTTPException(404, "No benchmark results found")
files = sorted(_RESULTS_DIR.glob("plans_*.json"))
if not files:
raise HTTPException(404, "No benchmark results found")
try:
return json.loads(files[-1].read_text(encoding="utf-8"))
except Exception as exc:
raise HTTPException(500, f"Failed to read results: {exc}") from exc
@router.get("/results/{run_id}")
def get_results_by_run_id(run_id: str) -> dict:
"""Return planning benchmark results for a specific run."""
if not run_id.startswith("plans_"):
raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")
f = _results_file(run_id)
if not f.exists():
raise HTTPException(404, f"Results not found: {run_id}")
try:
return json.loads(f.read_text(encoding="utf-8"))
except Exception as exc:
raise HTTPException(500, f"Failed to read results: {exc}") from exc
# ── POST /cancel ───────────────────────────────────────────────────────────────
@router.post("/cancel")
def cancel_plans_benchmark() -> dict:
"""Kill the running planning benchmark subprocess."""
global _BENCH_RUNNING, _bench_proc
if not _BENCH_RUNNING:
raise HTTPException(404, "No planning benchmark is currently running")
if _bench_proc is not None:
try:
_bench_proc.terminate()
except Exception as exc:
logger.warning("Failed to terminate plans benchmark: %s", exc)
_BENCH_RUNNING = False
_bench_proc = None
return {"status": "cancelled"}

View file

@ -1,335 +1,8 @@
"""Avocet — SFT candidate import and correction API. """Backward-compat shim -- logic moved to app/data/corrections.py."""
from app.data.corrections import ( # noqa: F401
All endpoints are registered on `router` (a FastAPI APIRouter). router,
api.py includes this router with prefix="/api/sft". set_data_dir as set_sft_data_dir,
set_config_dir as set_sft_config_dir,
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same set_default_bench_results_dir,
testability pattern as api.py override them via set_sft_data_dir() and _DEFAULT_BENCH_RESULTS_DIR,
set_sft_config_dir() in test fixtures. )
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_SFT_DATA_DIR: Path = _ROOT / "data"
_SFT_CONFIG_DIR: Path | None = None
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────
def set_sft_data_dir(path: Path) -> None:
global _SFT_DATA_DIR
_SFT_DATA_DIR = path
def set_sft_config_dir(path: Path | None) -> None:
global _SFT_CONFIG_DIR
_SFT_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────
def _config_file() -> Path:
if _SFT_CONFIG_DIR is not None:
return _SFT_CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
def set_default_bench_results_dir(path: str) -> None:
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
global _DEFAULT_BENCH_RESULTS_DIR
_DEFAULT_BENCH_RESULTS_DIR = path
def _get_bench_results_dir() -> Path:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
d = raw.get("sft", {}).get("bench_results_dir", "")
if d:
return Path(d)
except yaml.YAMLError as exc:
logger.warning("Failed to parse SFT config %s: %s", f, exc)
return Path(_DEFAULT_BENCH_RESULTS_DIR)
def _candidates_file() -> Path:
return _SFT_DATA_DIR / "sft_candidates.jsonl"
def _approved_file() -> Path:
return _SFT_DATA_DIR / "sft_approved.jsonl"
def _read_candidates() -> list[dict]:
return read_jsonl(_candidates_file())
def _write_candidates(records: list[dict]) -> None:
write_jsonl(_candidates_file(), records)
def _is_exportable(r: dict) -> bool:
"""Return True if an approved record is ready to include in SFT export."""
return (
r.get("status") == "approved"
and bool(r.get("corrected_response"))
and str(r["corrected_response"]).strip() != ""
)
# ── GET /runs ──────────────────────────────────────────────────────────────
@router.get("/runs")
def get_runs():
"""List available benchmark runs in the configured bench_results_dir."""
from scripts.sft_import import discover_runs
bench_dir = _get_bench_results_dir()
existing = _read_candidates()
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
imported_run_ids = {
r["benchmark_run_id"]
for r in existing
if r.get("benchmark_run_id") is not None
}
runs = discover_runs(bench_dir)
return [
{
"run_id": r["run_id"],
"timestamp": r["timestamp"],
"candidate_count": r["candidate_count"],
"already_imported": r["run_id"] in imported_run_ids,
}
for r in runs
]
# ── POST /import ───────────────────────────────────────────────────────────
class ImportRequest(BaseModel):
run_id: str
@router.post("/import")
def post_import(req: ImportRequest):
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
from scripts.sft_import import discover_runs, import_run
bench_dir = _get_bench_results_dir()
runs = discover_runs(bench_dir)
run = next((r for r in runs if r["run_id"] == req.run_id), None)
if run is None:
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
return import_run(run["sft_path"], _SFT_DATA_DIR)
# ── GET /queue ─────────────────────────────────────────────────────────────
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
"""Return paginated needs_review candidates."""
records = _read_candidates()
pending = [r for r in records if r.get("status") == "needs_review"]
start = (page - 1) * per_page
return {
"items": pending[start:start + per_page],
"total": len(pending),
"page": page,
"per_page": per_page,
}
# ── POST /submit ───────────────────────────────────────────────────────────
FailureCategory = Literal[
"scoring_artifact",
"style_violation",
"partial_answer",
"wrong_answer",
"format_error",
"hallucination",
]
class SubmitRequest(BaseModel):
id: str
action: Literal["correct", "discard", "flag"]
corrected_response: str | None = None
failure_category: FailureCategory | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
"""Record a reviewer decision for one SFT candidate."""
if req.action == "correct":
if not req.corrected_response or not req.corrected_response.strip():
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
if record.get("status") != "needs_review":
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
if req.action == "correct":
records[idx] = {
**record,
"status": "approved",
"corrected_response": req.corrected_response,
"failure_category": req.failure_category,
}
_write_candidates(records)
append_jsonl(_approved_file(), records[idx])
elif req.action == "discard":
records[idx] = {**record, "status": "discarded"}
_write_candidates(records)
else: # flag
records[idx] = {**record, "status": "model_rejected"}
_write_candidates(records)
return {"ok": True}
# ── POST /undo ─────────────────────────────────────────────────────────────
class UndoRequest(BaseModel):
id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
"""Restore a previously actioned candidate back to needs_review."""
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
old_status = record.get("status")
if old_status == "needs_review":
raise HTTPException(409, "Record is already in needs_review state")
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
_write_candidates(records)
# If it was approved, remove from the approved file too
if old_status == "approved":
approved = read_jsonl(_approved_file())
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
return {"ok": True}
# ── GET /export ─────────────────────────────────────────────────────────────
@router.get("/export")
def get_export() -> StreamingResponse:
"""Stream approved records as SFT-ready JSONL for download."""
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
def generate():
for r in exportable:
record = {
"messages": r.get("prompt_messages", []) + [
{"role": "assistant", "content": r["corrected_response"]}
]
}
yield json.dumps(record) + "\n"
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
},
)
# ── GET /stats ──────────────────────────────────────────────────────────────
@router.get("/stats")
def get_stats() -> dict[str, object]:
"""Return counts by status, model, and task type."""
records = _read_candidates()
by_status: dict[str, int] = {}
by_model: dict[str, int] = {}
by_task_type: dict[str, int] = {}
for r in records:
status = r.get("status", "unknown")
by_status[status] = by_status.get(status, 0) + 1
model = r.get("model_name", "unknown")
by_model[model] = by_model.get(model, 0) + 1
task_type = r.get("task_type", "unknown")
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
approved = read_jsonl(_approved_file())
export_ready = sum(1 for r in approved if _is_exportable(r))
return {
"total": len(records),
"by_status": by_status,
"by_model": by_model,
"by_task_type": by_task_type,
"export_ready": export_ready,
}
# ── GET /config ─────────────────────────────────────────────────────────────
@router.get("/config")
def get_sft_config() -> dict:
"""Return the current SFT configuration (bench_results_dir)."""
f = _config_file()
if not f.exists():
return {"bench_results_dir": ""}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
return {"bench_results_dir": ""}
sft_section = raw.get("sft") or {}
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
class SftConfigPayload(BaseModel):
bench_results_dir: str
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
"""Write the bench_results_dir setting to the config file."""
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
raw = raw or {}
except yaml.YAMLError:
raw = {}
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
tmp.rename(f)
return {"ok": True}

0
app/train/__init__.py Normal file
View file

339
app/train/train.py Normal file
View file

@ -0,0 +1,339 @@
"""Avocet -- train job queue API.
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
_running_procs dict in api.py with a persistent, inspectable queue.
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
GET /jobs -- list all jobs, newest first
POST /jobs -- create a new job
GET /jobs/{id} -- get one job by id
DELETE /jobs/{id}/cancel -- cancel a queued or running job
GET /jobs/{id}/run -- SSE: run the job, stream stdout
GET /results -- list completed models with training_info.json metrics
SQLite schema:
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
Testability seam: _DB_PATH global, override via set_db_path().
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import subprocess as _subprocess
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
_MODELS_DIR: Path = _ROOT / "models"
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
_running_procs: dict[str, Any] = {}
router = APIRouter()
# -- Testability seams -------------------------------------------------
def set_db_path(path: Path) -> None:
global _DB_PATH
_DB_PATH = path
def set_models_dir(path: Path) -> None:
global _MODELS_DIR
_MODELS_DIR = path
def set_config_dir(path: "Path | None") -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Config helpers ----------------------------------------------------
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_train_config() -> dict:
"""Read python_bin from label_tool.yaml.
Priority (highest to lowest):
1. label_tool.yaml train: python_bin
2. label_tool.yaml cforch: python_bin
3. Hardcoded default (classifiers conda env)
"""
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
f = _config_file()
train_cfg: dict = {}
cforch_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
train_cfg = raw.get("train", {}) or {}
cforch_cfg = raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse train config %s: %s", f, exc)
return {
"python_bin": train_cfg.get(
"python_bin",
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
),
}
# -- Database helpers --------------------------------------------------
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(str(_DB_PATH))
conn.row_factory = sqlite3.Row
try:
yield conn
conn.commit()
finally:
conn.close()
def _init_db() -> None:
"""Create jobs table if it does not exist. Called lazily per request."""
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
with _db() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
""")
def _row_to_dict(row: sqlite3.Row) -> dict:
return {k: row[k] for k in row.keys()}
# -- GPU selection (copied from api.py) --------------------------------
def _best_cuda_device() -> str:
"""Return index of GPU with most free VRAM, or empty string."""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True, timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
# -- Pydantic models ---------------------------------------------------
class CreateJobRequest(BaseModel):
type: str # "classifier" | "llm-sft"
model_key: str # e.g. "deberta-small"
config_json: dict = {}
# -- Routes ------------------------------------------------------------
@router.get("/jobs")
def list_jobs() -> dict:
_init_db()
with _db() as conn:
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
return {"jobs": [_row_to_dict(r) for r in rows]}
@router.post("/jobs")
def create_job(req: CreateJobRequest) -> dict:
if req.type not in ("classifier", "llm-sft"):
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
_init_db()
job_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES (?, ?, ?, 'queued', ?, ?)",
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
)
return {"id": job_id, "type": req.type, "model_key": req.model_key,
"status": "queued", "config_json": req.config_json,
"created_at": now, "started_at": None, "completed_at": None, "error": None}
@router.get("/jobs/{job_id}")
def get_job(job_id: str) -> dict:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
return _row_to_dict(row)
@router.delete("/jobs/{job_id}/cancel")
def cancel_job(job_id: str) -> dict:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
if row["status"] not in ("queued", "running"):
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
now = datetime.now(timezone.utc).isoformat()
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
proc = _running_procs.pop(job_id, None)
if proc is not None:
try:
proc.terminate()
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
try:
proc.kill()
except OSError:
pass
return {"status": "cancelled"}
@router.get("/jobs/{job_id}/run")
def run_job(job_id: str) -> StreamingResponse:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
if row["status"] != "queued":
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
job = _row_to_dict(row)
def generate():
cfg = _load_train_config()
python_bin = cfg["python_bin"]
config = json.loads(job["config_json"] or "{}")
model_key = job["model_key"]
epochs = config.get("epochs", 5)
if job["type"] == "classifier":
script = str(_ROOT / "scripts" / "finetune_classifier.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
data_dir = _ROOT / "data"
for sf in config.get("score_files", []):
resolved = (data_dir / sf).resolve()
if resolved.is_relative_to(data_dir.resolve()):
cmd.extend(["--score", str(resolved)])
elif job["type"] == "llm-sft":
script = str(_ROOT / "scripts" / "finetune_sft.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
else:
job_type = job["type"]
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
return
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
best_gpu = _best_cuda_device()
if best_gpu:
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
now = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
try:
proc = _subprocess.Popen(
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
)
_running_procs[job_id] = proc
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
finished_at = datetime.now(timezone.utc).isoformat()
if proc.returncode == 0:
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
(finished_at, job_id))
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
else:
err = f"Process exited with code {proc.returncode}"
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
finally:
_running_procs.pop(job_id, None)
except Exception as exc:
err = str(exc)
finished_at = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@router.get("/results")
def list_results() -> dict:
if not _MODELS_DIR.exists():
return {"results": []}
results = []
for sub in _MODELS_DIR.iterdir():
if not sub.is_dir():
continue
info_path = sub / "training_info.json"
if not info_path.exists():
continue
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception as exc:
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
return {"results": results}

View file

@ -106,7 +106,7 @@ def read_jsonl(path: Path) -> list[dict]:
def write_jsonl(path: Path, records: list[dict]) -> None: def write_jsonl(path: Path, records: list[dict]) -> None:
"""Write records to a JSONL file, overwriting any existing content.""" """Write records to a JSONL file, overwriting any existing content."""
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
content = "\n".join(json.dumps(r) for r in records) content = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
path.write_text(content + ("\n" if records else ""), encoding="utf-8") path.write_text(content + ("\n" if records else ""), encoding="utf-8")
@ -114,4 +114,4 @@ def append_jsonl(path: Path, record: dict) -> None:
"""Append a single record to a JSONL file.""" """Append a single record to a JSONL file."""
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "a", encoding="utf-8") as fh: with open(path, "a", encoding="utf-8") as fh:
fh.write(json.dumps(record) + "\n") fh.write(json.dumps(record, ensure_ascii=False) + "\n")

View file

@ -41,11 +41,20 @@ cforch:
# Python interpreter with cf-orch installed # Python interpreter with cf-orch installed
python_bin: /devl/miniconda3/envs/cf/bin/python python_bin: /devl/miniconda3/envs/cf/bin/python
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST # Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST / CF_JUDGE_URL / HF_TOKEN
# coordinator_url: http://localhost:7700 # coordinator_url: http://localhost:7700
# license_key: CFG-AVCT-xxxx-xxxx-xxxx # license_key: CFG-AVCT-xxxx-xxxx-xxxx
# ollama_url: http://localhost:11434 # ollama_url: http://localhost:11434
# ollama_model: llama3.2:3b # ollama_model: llama3.2:3b
# embed_model: nomic-embed-text # Ollama embedding model for EmbeddingKNNAdapter
# judge_url: http://10.1.10.158:8008 # Sif cf-text — LLM-as-judge secondary scorer
# judge_url: http://10.1.10.71:8008 # Heimdall cf-text (alternative)
# Or set CF_JUDGE_URL. Populates the Judge URL field in the LLM Eval UI automatically.
# hf_token: hf_xxxxxxxxxxxxxxxxxxxx # HuggingFace token — required for gated/terms-restricted models
# Directory containing per-node profile YAMLs (cf-orch node profiles).
# Default: derived from bench_script location (../../profiles/nodes).
# profiles_dir: /Library/Development/CircuitForge/circuitforge-orch/circuitforge_orch/profiles/nodes
# Imitate tab — pull real samples from sibling CF product APIs and run them # Imitate tab — pull real samples from sibling CF product APIs and run them
# through local LLMs to build a corrections dataset. # through local LLMs to build a corrections dataset.
@ -101,12 +110,3 @@ imitate:
sample_endpoint: /api/listings sample_endpoint: /api/listings
text_fields: [title, description, seller_info] text_fields: [title, description, seller_info]
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}" prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
- id: osprey
name: Osprey
icon: "📞"
description: Gov't hold-line automation
base_url: http://localhost:8520
sample_endpoint: /api/calls/recent
text_fields: [agency, issue, notes]
prompt_template: "Draft a concise summary of this government call record:\n\n{text}"

View file

@ -90,6 +90,12 @@ usage() {
echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]" echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]"
echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]" echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]"
echo "" echo ""
echo " Planning Benchmark:"
echo -e " ${GREEN}plans-bench [args]${NC} Run benchmark_plans.py (args passed through)"
echo -e " ${GREEN}plans-list${NC} Shortcut: --list-models"
echo -e " ${GREEN}plans-run <model> [args]${NC} Run a single model (--verbose auto-added)"
echo -e " ${GREEN}plans-compare <m1> <m2> [more]${NC} Compare models side-by-side"
echo ""
echo " Writing Style Benchmark:" echo " Writing Style Benchmark:"
echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)" echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)"
echo -e " ${GREEN}style-list${NC} List available ollama models for style bench" echo -e " ${GREEN}style-list${NC} List available ollama models for style bench"
@ -127,6 +133,8 @@ case "$CMD" in
fi fi
mkdir -p "$LOG_DIR" mkdir -p "$LOG_DIR"
API_LOG="${LOG_DIR}/api.log" API_LOG="${LOG_DIR}/api.log"
# Load .env if present — sets HF_TOKEN and other optional overrides.
[[ -f .env ]] && set -a && source .env && set +a
info "Building Vue SPA…" info "Building Vue SPA…"
(cd web && npm run build) >> "$API_LOG" 2>&1 (cd web && npm run build) >> "$API_LOG" 2>&1
info "Starting FastAPI on port ${API_PORT}" info "Starting FastAPI on port ${API_PORT}"
@ -179,6 +187,9 @@ case "$CMD" in
mkdir -p "$LOG_DIR" mkdir -p "$LOG_DIR"
DEV_API_LOG="${LOG_DIR}/dev-api.log" DEV_API_LOG="${LOG_DIR}/dev-api.log"
# Load .env if present — sets HF_TOKEN and other optional overrides.
[[ -f .env ]] && set -a && source .env && set +a
if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))" warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))"
else else
@ -255,6 +266,30 @@ case "$CMD" in
exec "$0" benchmark --compare "$@" exec "$0" benchmark --compare "$@"
;; ;;
plans-bench)
info "Running planning benchmark (${ENV_UI})…"
"$PYTHON_UI" scripts/benchmark_plans.py "$@"
;;
plans-list)
exec "$0" plans-bench --list-models
;;
plans-run)
if [[ $# -lt 1 ]]; then
error "Usage: ./manage.sh plans-run <model-key> [extra args]"
fi
MODEL="$1"; shift
exec "$0" plans-bench --model "$MODEL" --verbose "$@"
;;
plans-compare)
if [[ $# -lt 2 ]]; then
error "Usage: ./manage.sh plans-compare <model1> <model2> [more…]"
fi
exec "$0" plans-bench --compare "$@" --verbose
;;
style-bench) style-bench)
info "Running writing style benchmark (${ENV_BM})…" info "Running writing style benchmark (${ENV_BM})…"
if [[ ! -x "$PYTHON_BM" ]]; then if [[ ! -x "$PYTHON_BM" ]]; then

View file

@ -3,3 +3,4 @@ pydantic>=2.0.0
uvicorn[standard]>=0.20.0 uvicorn[standard]>=0.20.0
httpx>=0.24.0 httpx>=0.24.0
pytest>=7.0.0 pytest>=7.0.0
pyyaml>=6.0

View file

@ -39,6 +39,7 @@ from scripts.classifier_adapters import (
LABELS, LABELS,
LABEL_DESCRIPTIONS, LABEL_DESCRIPTIONS,
ClassifierAdapter, ClassifierAdapter,
EmbeddingKNNAdapter,
FineTunedAdapter, FineTunedAdapter,
GLiClassAdapter, GLiClassAdapter,
RerankerAdapter, RerankerAdapter,
@ -130,6 +131,13 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
"params": "600M", "params": "600M",
"default": False, "default": False,
}, },
"embed-knn-nomic": {
"adapter": EmbeddingKNNAdapter,
"model_id": "nomic-embed-text",
"params": "local-embed",
"default": False, # requires orch or ollama; use --include-slow
"kwargs": {"k": 3},
},
} }
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -184,6 +192,42 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
return found return found
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
"""Sample up to k_per_label formatted email texts per label from a scored JSONL.
Formats each row as 'Subject: {subject}\n\n{body[:600]}' the same format
EmbeddingKNNAdapter uses at classify() time. Rows missing the 'label' key
are skipped silently.
Returns dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
"""
result: dict[str, list[str]] = {}
p = Path(path)
with p.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
except json.JSONDecodeError as exc:
print(f"[build_exemplars] WARN: skipping malformed line: {exc}", flush=True)
continue
label = row.get("label")
if not label:
continue
subject = row.get("subject", "")
body = row.get("body", "")
if not subject and not body:
continue
texts = result.setdefault(label, [])
if len(texts) < k_per_label:
texts.append(
f"Subject: {subject}\n\n{body[:600]}"
)
return result
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]: def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
"""Return the active model registry, merged with any discovered fine-tuned models.""" """Return the active model registry, merged with any discovered fine-tuned models."""
active: dict[str, dict[str, Any]] = { active: dict[str, dict[str, Any]] = {

719
scripts/benchmark_plans.py Normal file
View file

@ -0,0 +1,719 @@
#!/usr/bin/env python
"""CF-specific planning benchmark — compare base models before fine-tuning.
Sends held-out CircuitForge planning prompts to one or more models via the
cf-text (local) or cf-orch API, then scores responses against CF-specific
rubrics. Use this to select the best base model for SFT.
Scoring rubrics (each 0-1, summed to total/N):
- task_structure : uses checkbox syntax (- [ ]), git commit steps
- tier_awareness : mentions Free/Paid/Premium/Ultra tiers
- privacy_pillar : mentions privacy/local-inference/no-logging
- safety_pillar : mentions safety, human approval, or reversibility
- accessibility : mentions ND/accessibility/adaptive needs
- license_split : mentions MIT vs BSL or open-core model
- file_paths : uses plausible file path references
- cf_conventions : uses conda run -n cf, /Library/Development/, or known CF dirs
- paired_coherence : (paired only) plan references the design doc's feature name
- length_ok : 3002500 words (under-short = hallucination risk; over-long = padding)
Usage
-----
# List available model targets
python scripts/benchmark_plans.py --list-models
# Run all held-out prompts against a single model, print report
python scripts/benchmark_plans.py --model llama3.2-3b
# Compare two models side-by-side
python scripts/benchmark_plans.py --compare llama3.2-3b mistral-7b
# Run with a custom API base (cf-text default: http://localhost:8080/v1)
python scripts/benchmark_plans.py --model llama3.2-3b --api-base http://localhost:8080/v1
# Export detailed results JSON
python scripts/benchmark_plans.py --model llama3.2-3b --output data/bench_results.json
"""
from __future__ import annotations
import argparse
import json
import re
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any
import httpx
# ── Paths ──────────────────────────────────────────────────────────────────────
_ROOT = Path(__file__).parent.parent
_DATA_DIR = _ROOT / "data"
CF_TEXT_BASE = "http://localhost:8080/v1"
CF_ORCH_BASE = "http://localhost:8090/v1"
CF_COORD_URL = "http://10.1.10.71:7700" # cf-orch coordinator (LAN)
# ── Held-out prompts ───────────────────────────────────────────────────────────
# These are NOT in the training export (no matching docs in circuitforge-plans/).
# Each prompt exercises a different CF planning domain.
HELD_OUT_PROMPTS: list[dict[str, Any]] = [
{
"id": "ho_001",
"name": "kiwi_barcode_ocr",
"domain": "feature_plan",
"prompt": (
"You are a senior engineer on Kiwi, a CircuitForge pantry-tracking product. "
"Write a detailed implementation plan for adding barcode scanning via device camera "
"and receipt OCR to the item-add flow.\n\n"
"The plan should include: file structure (create/modify), step-by-step task checklist "
"with checkboxes, any DB migrations, and git commit steps."
),
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
},
{
"id": "ho_002",
"name": "peregrine_ats_scoring",
"domain": "feature_design",
"prompt": (
"Write a design document for Peregrine: ATS keyword scoring for job applications.\n\n"
"Context: Peregrine users paste job descriptions and their resume. "
"We want to score how well the resume keywords match the JD and suggest rewrites. "
"Describe the architecture, data flow, and key design decisions."
),
"expected_signals": ["privacy_pillar", "tier_awareness", "license_split"],
},
{
"id": "ho_003",
"name": "tier_gate_local_llm",
"domain": "architecture",
"prompt": (
"Design the tier-gating architecture for a new CircuitForge product. "
"The product should:\n"
"- Default to local LLM inference for all tiers\n"
"- Unlock cloud LLM for Paid tier and above\n"
"- Keep fine-tuned model weights for Premium/Ultra only\n\n"
"Describe how the tier check integrates with the LLM router, "
"what happens when a Free user tries a Paid-tier feature, "
"and how BYOK (bring-your-own-key) fits in."
),
"expected_signals": ["tier_awareness", "privacy_pillar", "license_split"],
},
{
"id": "ho_004",
"name": "heimdall_webhook_plan",
"domain": "feature_plan",
"prompt": (
"Break the following Heimdall feature into a detailed implementation plan with "
"file structure and task checkboxes — Stripe webhook handler for subscription lifecycle.\n\n"
"Heimdall is the CircuitForge license server (FastAPI + SQLite). "
"The webhook needs to handle checkout.session.completed, "
"customer.subscription.updated, and customer.subscription.deleted events."
),
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
},
{
"id": "ho_005",
"name": "nd_accessible_onboarding",
"domain": "ux_design",
"prompt": (
"You are a product designer working on Harrier, a CircuitForge tool for "
"helping people navigate government benefits applications.\n\n"
"Design the onboarding flow for neurodivergent (ND) users. "
"Consider: ADHD time-blindness, executive function challenges, demand avoidance, "
"and rejection sensitivity. The flow should reduce cognitive load and "
"never use urgency or panic patterns."
),
"expected_signals": ["accessibility", "safety_pillar", "privacy_pillar"],
},
{
"id": "ho_006",
"name": "circuitforge_core_extraction",
"domain": "architecture",
"prompt": (
"Produce a CircuitForge-style design document for the following circuitforge-core "
"feature — shared ActivityPub federation module.\n\n"
"Background: Multiple CF products (Kiwi, Rook, Snipe) want to publish updates "
"to ActivityPub. Build it once in cf-core (MIT licensed) so all products can use it. "
"Design the module API, describe what belongs in MIT vs BSL, and note federation "
"privacy constraints."
),
"expected_signals": ["license_split", "privacy_pillar", "cf_conventions"],
},
{
"id": "ho_007",
"name": "snipe_trust_score_plan",
"domain": "feature_plan",
"prompt": (
"You are a senior engineer on Snipe, a CircuitForge eBay trust-scoring tool. "
"Write a step-by-step engineering plan for: seller trust score calculation.\n\n"
"The score should combine: feedback ratio, account age, item-specifics completeness, "
"listing photo quality, and shipping time accuracy. "
"Include file structure, test plan, and migration steps."
),
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
},
{
"id": "ho_008",
"name": "avocet_training_pipeline",
"domain": "feature_plan",
"prompt": (
"Break the following Avocet feature into a detailed implementation plan — "
"end-to-end fine-tuning pipeline from labeled JSONL to deployed GGUF model.\n\n"
"Avocet is the CircuitForge email classifier training tool. "
"The pipeline should: validate the dataset, run LoRA SFT via unsloth, "
"quantize to Q5_K_M GGUF, run the benchmark harness, and register the model "
"in the Avocet model queue if it beats the baseline."
),
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
},
{
"id": "ho_009",
"name": "privacy_data_flow",
"domain": "architecture",
"prompt": (
"Design the data privacy architecture for a CircuitForge cloud product. "
"Describe: what PII is collected, how it's stored, retention policy, "
"obfuscation strategy for cloud-side logs, and how consent is obtained "
"in plain language. The product handles job applications (resumes, cover letters)."
),
"expected_signals": ["privacy_pillar", "safety_pillar", "accessibility"],
},
{
"id": "ho_010",
"name": "git_workflow_doc",
"domain": "process_doc",
"prompt": (
"Write a developer process document for CircuitForge: conventional commit and "
"branch workflow for a BSL 1.1 open-core product.\n\n"
"Cover: commit message format (type: description), branch naming, "
"when to use feature branches vs direct main commits, "
"how the MIT/BSL split affects which commits go in which branch, "
"and how CI gates on gitleaks for secret scanning."
),
"expected_signals": ["license_split", "cf_conventions", "task_structure"],
},
]
# ── Rubric scoring ─────────────────────────────────────────────────────────────
_TASK_STRUCTURE_RE = re.compile(r"- \[ \]", re.MULTILINE)
_COMMIT_RE = re.compile(r"git commit|git add", re.IGNORECASE)
_TIER_RE = re.compile(r"\b(Free|Paid|Premium|Ultra)\s+tier|\btier\s+(Free|Paid|Premium|Ultra)", re.IGNORECASE)
_PRIVACY_RE = re.compile(r"\b(privacy|local.?inference|no.?logging|no.?pii|user.?data|data.?reten|obfuscat)", re.IGNORECASE)
_SAFETY_RE = re.compile(r"\b(human.?approv|reversib|safety|safe.?default|fail.?safe|harm)", re.IGNORECASE)
_A11Y_RE = re.compile(r"\b(neurodiverg|ND\b|accessib|adaptive|ADHD|autism|executive.?function|demand.?avoid)", re.IGNORECASE)
_LICENSE_RE = re.compile(r"\b(MIT|BSL|open.?core|proprietary|commercial.?licens)", re.IGNORECASE)
_FILE_PATH_RE = re.compile(r"(app/|tests?/|src/|scripts?/)\w[\w/.-]{3,}", re.IGNORECASE)
_CF_CONV_RE = re.compile(r"(conda run -n cf|/Library/Development/CircuitForge|circuitforge-core|manage\.sh)", re.IGNORECASE)
@dataclass
class RubricScore:
task_structure: float = 0.0
tier_awareness: float = 0.0
privacy_pillar: float = 0.0
safety_pillar: float = 0.0
accessibility: float = 0.0
license_split: float = 0.0
file_paths: float = 0.0
cf_conventions: float = 0.0
length_ok: float = 0.0
def total(self) -> float:
vals = [self.task_structure, self.tier_awareness, self.privacy_pillar,
self.safety_pillar, self.accessibility, self.license_split,
self.file_paths, self.cf_conventions, self.length_ok]
return sum(vals) / len(vals)
def as_dict(self) -> dict[str, float]:
return asdict(self)
def score_response(response: str, prompt_meta: dict[str, Any]) -> RubricScore:
words = len(response.split())
s = RubricScore()
# Task structure: needs checkboxes AND at least one commit step
checkbox_hits = len(_TASK_STRUCTURE_RE.findall(response))
has_commit = bool(_COMMIT_RE.search(response))
s.task_structure = min(1.0, checkbox_hits / 5) * 0.7 + (0.3 if has_commit else 0.0)
# Tier awareness
s.tier_awareness = min(1.0, len(_TIER_RE.findall(response)) / 2)
# Privacy pillar
s.privacy_pillar = min(1.0, len(_PRIVACY_RE.findall(response)) / 3)
# Safety pillar
s.safety_pillar = min(1.0, len(_SAFETY_RE.findall(response)) / 2)
# Accessibility
s.accessibility = min(1.0, len(_A11Y_RE.findall(response)) / 2)
# License split awareness
s.license_split = min(1.0, len(_LICENSE_RE.findall(response)) / 2)
# File paths: at least 3 plausible path references
s.file_paths = min(1.0, len(_FILE_PATH_RE.findall(response)) / 3)
# CF conventions
s.cf_conventions = min(1.0, len(_CF_CONV_RE.findall(response)) / 2)
# Length: 2002500 words is healthy; outside = partial credit
if 200 <= words <= 2500:
s.length_ok = 1.0
elif words < 200:
s.length_ok = words / 200
else:
s.length_ok = max(0.0, 1.0 - (words - 2500) / 2500)
return s
# ── Model client ───────────────────────────────────────────────────────────────
# Registry of named model targets (shorthand → {api_base, model_name})
MODEL_REGISTRY: dict[str, dict[str, str]] = {
"deepseek-r1-1.5b": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-r1-1.5b",
"description": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
},
"deepseek-r1-7b-4bit": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-r1-7b-4bit",
"description": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
},
"deepseek-coder-6.7b-4bit": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-coder-6.7b-4bit",
"description": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
},
"granite-4.1-8b": {
"api_base": CF_TEXT_BASE,
"model": "granite-4.1-8b",
"description": "IBM Granite 4.1 8B, 4-bit (cf-orch catalog key)",
},
"qwen2.5-3b": {
"api_base": CF_TEXT_BASE,
"model": "qwen2.5-3b",
"description": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key, navi only)",
},
"qwen2.5-7b": {
"api_base": CF_TEXT_BASE,
"model": "qwen2.5-7b",
"description": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key, navi only)",
},
}
# ── cf-orch allocation ─────────────────────────────────────────────────────────
def _cforch_allocate(
model_id: str,
cforch_url: str,
startup_timeout_s: float = 300.0,
) -> tuple[str, str] | None:
"""Allocate a cf-text instance for model_id via the cf-orch coordinator.
Returns (service_url, allocation_id) on success, None on failure.
service_url is the direct node URL exposing /v1/chat/completions.
"""
try:
resp = httpx.post(
f"{cforch_url}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "plans_benchmark",
},
timeout=120.0,
)
resp.raise_for_status()
data = resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
if data.get("started", False) and not data.get("warm", True):
# Use \n so the SSE generator sees the line immediately
print(f" [cold start] loading {model_id!r} — polling every 3s…", flush=True)
t0 = time.monotonic()
deadline = t0 + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
elapsed = time.monotonic() - t0
try:
status = httpx.get(f"{cforch_url}/api/services/cf-text/status", timeout=5.0)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
print(f" [cold start] ready in {elapsed:.0f}s", flush=True)
return service_url, allocation_id
elif state == "stopped":
print(f" [cold start] failed — service stopped after {elapsed:.0f}s", flush=True)
return None
else:
# still starting — emit keepalive so SSE stream stays alive
print(f" [cold start] state={state!r} elapsed={elapsed:.0f}s", flush=True)
else:
probe_misses += 1
print(f" [cold start] waiting… elapsed={elapsed:.0f}s", flush=True)
if probe_misses >= 6:
try:
h = httpx.get(f"{service_url}/health", timeout=3.0)
if h.is_success:
print(f" [cold start] ready via health check in {elapsed:.0f}s", flush=True)
return service_url, allocation_id
except Exception:
pass
else:
print(f" [cold start] status poll returned {status.status_code}, elapsed={elapsed:.0f}s", flush=True)
except Exception as poll_exc:
print(f" [cold start] poll error: {poll_exc} elapsed={elapsed:.0f}s", flush=True)
time.sleep(3.0)
print(f" [cold start] timed out after {time.monotonic()-t0:.0f}s", flush=True)
return None
return service_url, allocation_id
except Exception as exc:
print(f"[warn] cf-orch allocation failed for {model_id!r}: {exc}", file=sys.stderr)
return None
def _call_model_direct(service_url: str, model: str, prompt: str, timeout: int = 600) -> tuple[str, float]:
"""Call an OpenAI-compatible /v1/chat/completions on a direct service URL."""
t0 = time.monotonic()
resp = httpx.post(
f"{service_url.rstrip('/')}/v1/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 2048,
"temperature": 0.2,
},
timeout=timeout,
)
resp.raise_for_status()
latency = time.monotonic() - t0
text = resp.json()["choices"][0]["message"]["content"]
return text, latency
def _call_model(api_base: str, model: str, prompt: str, timeout: int = 180) -> tuple[str, float]:
"""Call an OpenAI-compatible /chat/completions endpoint. Returns (text, latency_s)."""
t0 = time.monotonic()
resp = httpx.post(
f"{api_base}/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 2048,
"temperature": 0.2,
},
timeout=timeout,
)
resp.raise_for_status()
latency = time.monotonic() - t0
text = resp.json()["choices"][0]["message"]["content"]
return text, latency
# ── Benchmark runner ───────────────────────────────────────────────────────────
@dataclass
class PromptResult:
prompt_id: str
prompt_name: str
model_key: str
response: str
latency_s: float
word_count: int
scores: dict[str, float]
total_score: float
error: str | None = None
def run_benchmark(
model_key: str,
model_name: str,
prompts: list[dict[str, Any]] | None = None,
verbose: bool = False,
# cf-orch path
use_cforch: bool = False,
cforch_url: str = CF_COORD_URL,
# direct path (used when not cf-orch)
api_base: str = CF_TEXT_BASE,
) -> list[PromptResult]:
"""Run all prompts through one model. Uses cf-orch allocation when use_cforch=True."""
if prompts is None:
prompts = HELD_OUT_PROMPTS
# Allocate once per model when using cf-orch
service_url: str | None = None
if use_cforch:
print(f" Allocating {model_name!r} via cf-orch…", flush=True)
alloc = _cforch_allocate(model_name, cforch_url)
if alloc is None:
# Return all prompts as errors
return [
PromptResult(
prompt_id=p["id"], prompt_name=p["name"], model_key=model_key,
response="", latency_s=0.0, word_count=0, scores={}, total_score=0.0,
error=f"cf-orch allocation failed for {model_name!r}",
)
for p in prompts
]
service_url, _alloc_id = alloc
results: list[PromptResult] = []
for p in prompts:
if verbose:
print(f" [{p['id']}] {p['name']}", end="", flush=True)
try:
if service_url:
response, latency = _call_model_direct(service_url, model_name, p["prompt"])
else:
response, latency = _call_model(api_base, model_name, p["prompt"])
rubric = score_response(response, p)
result = PromptResult(
prompt_id=p["id"],
prompt_name=p["name"],
model_key=model_key,
response=response,
latency_s=round(latency, 2),
word_count=len(response.split()),
scores=rubric.as_dict(),
total_score=round(rubric.total(), 3),
)
if verbose:
print(f"score={result.total_score:.3f} ({result.word_count}w, {latency:.1f}s)")
except Exception as exc:
result = PromptResult(
prompt_id=p["id"],
prompt_name=p["name"],
model_key=model_key,
response="",
latency_s=0.0,
word_count=0,
scores={},
total_score=0.0,
error=str(exc),
)
if verbose:
print(f"ERROR: {exc}")
results.append(result)
return results
# ── Reporting ──────────────────────────────────────────────────────────────────
def _print_single_report(results: list[PromptResult], model_key: str) -> None:
ok = [r for r in results if not r.error]
err = [r for r in results if r.error]
if not ok:
print(f"\n[{model_key}] All {len(err)} prompts failed.\n")
return
avg_total = sum(r.total_score for r in ok) / len(ok)
avg_latency = sum(r.latency_s for r in ok) / len(ok)
# Aggregate per-rubric averages
rubric_keys = list(ok[0].scores.keys())
rubric_avgs = {k: sum(r.scores.get(k, 0) for r in ok) / len(ok) for k in rubric_keys}
print(f"\n{'='*60}")
print(f" Model : {model_key}")
print(f" Prompts: {len(ok)}/{len(results)} passed ({len(err)} errors)")
print(f" Overall score : {avg_total:.3f} (avg latency {avg_latency:.1f}s)")
print(f"\n Rubric breakdown:")
for k, v in sorted(rubric_avgs.items(), key=lambda x: -x[1]):
bar = "" * int(v * 20)
print(f" {k:<22} {v:.3f} {bar}")
print(f"\n Per-prompt scores:")
for r in sorted(ok, key=lambda x: -x.total_score):
flag = "" if r.total_score < 0.3 else " "
print(f" {flag} {r.prompt_id} {r.prompt_name:<35} {r.total_score:.3f} ({r.word_count}w)")
if err:
print(f"\n Errors:")
for r in err:
print(f" {r.prompt_id} {r.prompt_name}: {r.error}")
print(f"{'='*60}\n")
def _print_comparison_table(all_results: dict[str, list[PromptResult]]) -> None:
model_keys = list(all_results.keys())
prompt_ids = [p["id"] for p in HELD_OUT_PROMPTS]
# Scores by (model, prompt_id)
score_map: dict[tuple[str, str], float] = {}
for mk, results in all_results.items():
for r in results:
score_map[(mk, r.prompt_id)] = r.total_score if not r.error else 0.0
col_w = 10
header = f"{'Prompt':<35}" + "".join(f"{mk[:col_w-1]:<{col_w}}" for mk in model_keys)
print(f"\n{'='*len(header)}")
print(" COMPARISON TABLE")
print(f"{'='*len(header)}")
print(f" {header}")
print(f" {'-'*len(header)}")
for pid in prompt_ids:
pname = next(p["name"] for p in HELD_OUT_PROMPTS if p["id"] == pid)
row = f" {pname:<35}"
best = max(score_map.get((mk, pid), 0.0) for mk in model_keys)
for mk in model_keys:
v = score_map.get((mk, pid), 0.0)
marker = "*" if v == best and len(model_keys) > 1 else " "
row += f"{v:.3f}{marker} "
print(row)
print(f" {'-'*len(header)}")
avgs_row = f" {'AVERAGE':<35}"
best_avg = -1.0
avgs: dict[str, float] = {}
for mk in model_keys:
vals = [score_map.get((mk, pid), 0.0) for pid in prompt_ids]
avgs[mk] = sum(vals) / len(vals)
best_avg = max(best_avg, avgs[mk])
for mk in model_keys:
marker = "*" if avgs[mk] == best_avg and len(model_keys) > 1 else " "
avgs_row += f"{avgs[mk]:.3f}{marker} "
print(avgs_row)
print(f"{'='*len(header)}\n")
if len(model_keys) > 1:
winner = max(avgs, key=lambda k: avgs[k])
print(f" Winner: {winner} (avg {avgs[winner]:.3f})\n")
# ── CLI ────────────────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--list-models", action="store_true",
help="Print registered model shortcuts and exit")
parser.add_argument("--model", metavar="KEY",
help="Benchmark a single model (registry key or raw model name)")
parser.add_argument("--compare", nargs="+", metavar="KEY",
help="Compare two or more models side-by-side")
parser.add_argument("--cforch", action="store_true",
help="Route inference through cf-orch coordinator (allocate per model)")
parser.add_argument("--cforch-url", default=CF_COORD_URL, metavar="URL",
help=f"cf-orch coordinator URL (default: {CF_COORD_URL})")
parser.add_argument("--api-base", default=None,
help="Direct API base URL when not using cf-orch")
parser.add_argument("--model-name", default=None,
help="Override model name sent to API (single-model runs only)")
parser.add_argument("--prompts", nargs="+", metavar="ID",
help="Run only specific prompt IDs (e.g. ho_001 ho_003)")
parser.add_argument("--output", type=Path, default=None,
help="Write detailed JSON results to this path")
parser.add_argument("--workers", type=int, default=1, metavar="N",
help="Run N models concurrently (default 1). Set to number of available nodes.")
parser.add_argument("--verbose", "-v", action="store_true",
help="Print per-prompt progress")
args = parser.parse_args()
if args.list_models:
print("\nRegistered model shortcuts:")
for key, info in MODEL_REGISTRY.items():
print(f" {key:<20} {info['description']}")
print(f"\nDefault endpoints:")
print(f" direct {CF_TEXT_BASE}")
print(f" cf-orch {CF_COORD_URL}")
return
prompts = HELD_OUT_PROMPTS
if args.prompts:
ids = set(args.prompts)
prompts = [p for p in HELD_OUT_PROMPTS if p["id"] in ids]
if not prompts:
print(f"No prompts matched IDs: {args.prompts}", file=sys.stderr)
sys.exit(1)
model_keys: list[str] = []
if args.compare:
model_keys = args.compare
elif args.model:
model_keys = [args.model]
else:
parser.print_help()
sys.exit(0)
all_results: dict[str, list[PromptResult]] = {}
print_lock = threading.Lock()
def _run_one(mk: str) -> tuple[str, list[PromptResult]]:
if mk in MODEL_REGISTRY:
reg = MODEL_REGISTRY[mk]
model_name = args.model_name or reg["model"]
direct_base = args.api_base or reg["api_base"]
else:
model_name = args.model_name or mk
direct_base = args.api_base or CF_TEXT_BASE
if args.cforch:
with print_lock:
print(f"\nRunning [{mk}] via cf-orch ({args.cforch_url}) model={model_name}")
results = run_benchmark(
mk, model_name, prompts=prompts, verbose=args.verbose,
use_cforch=True, cforch_url=args.cforch_url,
)
else:
with print_lock:
print(f"\nRunning [{mk}] → {direct_base} model={model_name}")
results = run_benchmark(
mk, model_name, prompts=prompts, verbose=args.verbose,
api_base=direct_base,
)
with print_lock:
_print_single_report(results, mk)
return mk, results
workers = max(1, args.workers)
if workers == 1 or len(model_keys) == 1:
for mk in model_keys:
mk_out, results = _run_one(mk)
all_results[mk_out] = results
else:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {pool.submit(_run_one, mk): mk for mk in model_keys}
for fut in as_completed(futures):
mk_out, results = fut.result()
all_results[mk_out] = results
if len(model_keys) > 1:
_print_comparison_table(all_results)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)
payload = {
mk: [asdict(r) for r in results]
for mk, results in all_results.items()
}
with open(args.output, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2, ensure_ascii=False)
print(f"Wrote detailed results to {args.output}")
if __name__ == "__main__":
main()

View file

@ -7,19 +7,26 @@ from __future__ import annotations
import abc import abc
from collections import defaultdict from collections import defaultdict
import httpx
import logging
from pathlib import Path
from typing import Any from typing import Any
__all__ = [ __all__ = [
"LABELS", "LABELS",
"LABEL_DESCRIPTIONS", "LABEL_DESCRIPTIONS",
"DEFAULT_EXEMPLARS",
"compute_metrics", "compute_metrics",
"ClassifierAdapter", "ClassifierAdapter",
"ZeroShotAdapter", "ZeroShotAdapter",
"GLiClassAdapter", "GLiClassAdapter",
"RerankerAdapter", "RerankerAdapter",
"FineTunedAdapter", "FineTunedAdapter",
"EmbeddingKNNAdapter",
] ]
_logger = logging.getLogger(__name__)
LABELS: list[str] = [ LABELS: list[str] = [
"interview_scheduled", "interview_scheduled",
"offer_received", "offer_received",
@ -117,6 +124,81 @@ def compute_metrics(
return result return result
def _cosine(a: list[float], b: list[float]) -> float:
dot = sum(x * y for x, y in zip(a, b))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
"interview_scheduled": [
"Subject: Interview Invitation\n\nWe would like to invite you for a phone screen next week.",
"Subject: Schedule a call\n\nCould you be available for a video interview on Tuesday?",
"Subject: Next Steps\n\nWe'd like to move forward with a technical interview. Please select a time.",
"Subject: Interview Details\n\nHere are the dial-in instructions for your interview tomorrow.",
],
"offer_received": [
"Subject: Offer Letter Enclosed\n\nWe are pleased to extend you an offer of employment.",
"Subject: Job Offer\n\nDear candidate, we are excited to offer you the position of Software Engineer.",
"Subject: Employment Offer\n\nPlease find attached your formal offer letter and compensation details.",
"Subject: Offer of Employment\n\nCongratulations! We would like to offer you a full-time position.",
],
"rejected": [
"Subject: Your Application\n\nAfter careful consideration, we have decided to move forward with other candidates.",
"Subject: Application Status\n\nWe regret to inform you that your application has not been selected.",
"Subject: Thank you for applying\n\nWe appreciate your interest but have chosen not to proceed.",
"Subject: Update on your candidacy\n\nWe will not be moving forward with your application at this time.",
],
"positive_response": [
"Subject: Your profile\n\nI came across your LinkedIn and think you would be a great fit for our team.",
"Subject: Exciting opportunity\n\nWe were impressed by your background and would love to connect.",
"Subject: Following up\n\nThank you for your interest — we'd like to learn more about your experience.",
"Subject: Great fit\n\nYour skills align well with what we are looking for. Let's set up a call.",
],
"survey_received": [
"Subject: Candidate Experience Survey\n\nPlease complete this brief survey about your application experience.",
"Subject: Culture Fit Assessment\n\nAs part of our process, we ask all candidates to complete a short assessment.",
"Subject: Skills Assessment\n\nWe'd like you to complete our online coding assessment before proceeding.",
"Subject: Personality Assessment\n\nPlease complete the following assessment as the next step in our process.",
"Subject: Pre-interview questionnaire\n\nBefore we schedule your interview, please complete this brief skills survey.",
],
"neutral": [
"Subject: Application Received\n\nWe have received your application and will be in touch.",
"Subject: Thank you for applying\n\nYour application is under review. We will contact you if needed.",
"Subject: Confirmation\n\nThis email confirms receipt of your application to our company.",
"Subject: Application Confirmation\n\nThank you for your interest. We will review your materials and follow up.",
],
"event_rescheduled": [
"Subject: Interview Rescheduled\n\nDue to a conflict, we need to move your interview to a new time.",
"Subject: Change of interview time\n\nWe apologize — your interview has been rescheduled to Thursday.",
"Subject: Updated interview details\n\nYour interview has been moved from Monday to Wednesday at 2pm.",
"Subject: Reschedule request\n\nWould you be available to reschedule to a different time slot?",
"Subject: New interview time\n\nYour phone screen has been moved from tomorrow to next week.",
],
"digest": [
"Subject: 15 new jobs matching your search\n\nHere are the latest job postings that match your profile.",
"Subject: Weekly Job Digest\n\nThis week's top opportunities for Software Engineers in your area.",
"Subject: Jobs you might like\n\nBased on your profile, here are some positions we recommend.",
"Subject: New jobs for you\n\nSee the latest openings from companies on your watchlist.",
],
"new_lead": [
"Subject: Exciting opportunity at our company\n\nHi, I noticed your background and think you'd be a great fit.",
"Subject: Are you open to new opportunities?\n\nI'm a recruiter reaching out about a role matching your experience.",
"Subject: Quick question\n\nWould you be interested in hearing about a senior engineering role?",
"Subject: Recruiting outreach\n\nI came across your profile and wanted to share an exciting opening.",
],
"hired": [
"Subject: Welcome to the team!\n\nWe are thrilled to have you join us. Here are your onboarding details.",
"Subject: Onboarding information\n\nCongratulations on accepting our offer. Your start date is confirmed.",
"Subject: First day information\n\nWe look forward to your first day. Please arrive at 9am and ask for HR.",
"Subject: Background check initiated\n\nAs part of your onboarding, we have initiated a background check.",
"Subject: Equipment setup\n\nYour laptop and equipment will be ready for pickup on your first day.",
],
}
class ClassifierAdapter(abc.ABC): class ClassifierAdapter(abc.ABC):
"""Abstract base for all email classifier adapters.""" """Abstract base for all email classifier adapters."""
@ -304,3 +386,148 @@ class FineTunedAdapter(ClassifierAdapter):
text = f"{subject} [SEP] {body[:400]}" text = f"{subject} [SEP] {body[:400]}"
result = self._pipeline(text) result = self._pipeline(text)
return result[0]["label"] return result[0]["label"]
class EmbeddingKNNAdapter(ClassifierAdapter):
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
load():
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
Falls back to ollama_url directly if orch allocation fails or is not configured.
2. Pre-embeds all exemplar texts and stores per-label vector lists.
classify(subject, body):
Embeds the input email, computes cosine similarity against all stored exemplar
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
with the highest total similarity score among tied vote counts wins.
unload():
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
"""
def __init__(
self,
name: str,
model_id: str,
*,
k: int = 3,
orch_url: str = "",
ollama_url: str = "",
exemplar_texts: dict[str, list[str]] | None = None,
) -> None:
self._name = name
self._model_id = model_id
self._k = k
self._orch_url = orch_url
self._ollama_url = ollama_url
self._exemplar_texts: dict[str, list[str]] = (
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
)
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
self._node_url: str = ""
self._allocation_id: str = ""
self._orch_url_used: str = ""
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def _resolve_urls(self) -> tuple[str, str]:
if self._orch_url or self._ollama_url:
return self._orch_url, self._ollama_url
import yaml # noqa: PLC0415
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
cfg: dict = {}
if cfg_path.exists():
try:
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
pass
cforch = cfg.get("cforch", {}) or {}
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
resp = httpx.post(
f"{node_url}/v1/embeddings",
json={"model": self._model_id, "input": texts},
timeout=30.0,
)
resp.raise_for_status()
return [item["embedding"] for item in resp.json()["data"]]
def load(self) -> None:
if self._allocation_id or self._exemplar_embeddings:
raise RuntimeError(
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
)
orch_url, ollama_url = self._resolve_urls()
node_url = ""
orch_url_used = ""
if orch_url:
try:
resp = httpx.post(
f"{orch_url}/api/services/ollama/allocate",
json={"model": self._model_id},
timeout=15.0,
)
if resp.status_code == 200:
data = resp.json()
node_url = data["url"]
self._allocation_id = data["allocation_id"]
orch_url_used = orch_url
except Exception as exc:
_logger.warning(
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
)
if not node_url:
node_url = ollama_url
self._allocation_id = ""
orch_url_used = ""
self._node_url = node_url
self._orch_url_used = orch_url_used
try:
embeddings: dict[str, list[list[float]]] = {}
for label, texts in self._exemplar_texts.items():
embeddings[label] = self._embed(node_url, texts)
self._exemplar_embeddings = embeddings
except Exception:
self.unload()
raise
def unload(self) -> None:
if self._allocation_id and self._orch_url_used:
try:
httpx.request(
"DELETE",
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
timeout=10.0,
)
except Exception:
pass
self._exemplar_embeddings = {}
self._node_url = ""
self._allocation_id = ""
self._orch_url_used = ""
def classify(self, subject: str, body: str) -> str:
if not self._exemplar_embeddings:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
[query_vec] = self._embed(self._node_url, [text])
scored: list[tuple[float, str]] = [
(_cosine(query_vec, vec), label)
for label, vecs in self._exemplar_embeddings.items()
for vec in vecs
]
top_k = sorted(scored, reverse=True)[: self._k]
votes: dict[str, list[float]] = {}
for score, label in top_k:
votes.setdefault(label, []).append(score)
return max(
votes,
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
)

458
scripts/export_plans.py Normal file
View file

@ -0,0 +1,458 @@
"""Export circuitforge-plans/ documents as instruction-tuning JSONL pairs.
Each record is a HuggingFace chat-format example:
{
"id": "<sha256>",
"messages": [
{"role": "user", "content": "<reconstructed planning prompt>"},
{"role": "assistant", "content": "<cleaned document content>"}
],
"meta": {
"source": "peregrine/2026-03-03-feedback-button-design.md",
"product": "peregrine",
"doc_type": "design", # design | plan | spec | implementation | other
"date": "2026-03-03",
"paired_with": "...", # sibling path, or null
"word_count": 1847,
"pair_role": "context" # "context" | "target" | "standalone"
}
}
Pairing strategy
----------------
When a design doc and a plan doc share the same date + feature-name prefix,
they are treated as a pair:
- design plan: instruction = "Given this design doc, write the implementation plan."
context appended = full design doc content.
- Solo docs get a synthetic instruction from the title + first overview section.
Usage
-----
# Preview stats and 5 sample records
python scripts/export_plans.py --preview
# Write full output
python scripts/export_plans.py --output data/plan_pairs.jsonl
# Restrict to specific products
python scripts/export_plans.py --products peregrine,kiwi --output data/plan_pairs.jsonl
"""
from __future__ import annotations
import argparse
import hashlib
import json
import re
import sys
from pathlib import Path
from typing import Iterator
# ── Paths ──────────────────────────────────────────────────────────────────────
_SCRIPT_DIR = Path(__file__).parent
_AVOCET_ROOT = _SCRIPT_DIR.parent
_DEFAULT_PLANS_DIR = Path("/Library/Development/CircuitForge/circuitforge-plans")
_DEFAULT_OUTPUT = _AVOCET_ROOT / "data" / "plan_pairs.jsonl"
# ── Doc type detection ─────────────────────────────────────────────────────────
_TYPE_RE = re.compile(
r"-(design|plan|spec|implementation|specs|plans)s?$",
re.IGNORECASE,
)
_SKIP_DIRS = {"__pycache__", ".git", "node_modules"}
# Boilerplate lines to strip from document content before using as output.
_BOILERPLATE_RE = re.compile(
r"""
^\s*>\s*\*\*For\s+agentic\s+workers.* # superpowers agent hints
|^\s*>\s*REQUIRED\s+SUB-SKILL.*
|^\s*\*\*Date:\*\*.* # metadata header lines
|\*\*Status:\*\*\s*Complete.* # completed-feature noise
|\*\*Status:\*\*\s*Done.*
|\*\*Product:\*\*.*
|\*\*Repo:\*\*.*
|\*\*Tech\s+Stack:\*\*.*
|\*\*Candidate:\*\*.* # old synthetic personas
|^Candidate:.*
|^Team:.*
""",
re.VERBOSE | re.MULTILINE,
)
# Old repo/path names to normalise to current equivalents.
_PATH_NORMALIZATIONS: list[tuple[re.Pattern, str]] = [
(re.compile(r"/devl/job-seeker", re.IGNORECASE), "/Library/Development/CircuitForge/peregrine"),
(re.compile(r"\bjob-seeker\b", re.IGNORECASE), "peregrine"),
(re.compile(r"Alex Rivera", re.IGNORECASE), "[user]"),
]
# Instruction paraphrase templates per doc type.
# Each entry is (user_prefix, paired_prefix).
# {title}, {product}, {type_phrase}, {overview}, {design_context} are substituted.
_DESIGN_INSTRUCTIONS = [
"Write a design document for {product}: {title}.\n\nContext: {overview}",
"You are a software architect working on {product}. Draft a design spec for: {title}.\n\n{overview}",
"Produce a CircuitForge-style design document for the following {product} feature — {title}.\n\nBackground: {overview}",
]
_PLAN_INSTRUCTIONS = [
"Write an implementation plan for {product}: {title}.\n\nContext: {overview}",
"Break the following {product} feature into a detailed implementation plan with file structure and task checkboxes — {title}.\n\n{overview}",
"You are a senior engineer on {product}. Produce a step-by-step engineering plan for: {title}.\n\n{overview}",
]
_PAIRED_INSTRUCTIONS = [
(
"You are a software architect working on {product}, a CircuitForge product. "
"Given the following design document, write a detailed implementation plan "
"(file structure, task breakdown with checkboxes, migration steps if needed).\n\n"
"---\n{design_context}\n---"
),
(
"The following is a design spec for a {product} feature. "
"Produce a concrete implementation plan: file list, task checklist, any DB migrations needed.\n\n"
"---\n{design_context}\n---"
),
(
"Convert this {product} design document into an actionable implementation plan. "
"Include all files to create/modify, step-by-step tasks with checkboxes, and migration steps.\n\n"
"---\n{design_context}\n---"
),
]
def _doc_type(stem: str) -> str:
m = _TYPE_RE.search(stem)
if not m:
return "other"
raw = m.group(1).lower().rstrip("s")
return {"implementation": "plan"}.get(raw, raw)
def _date_feature(stem: str) -> tuple[str, str]:
"""Return (date, feature_slug) from '2026-03-03-feedback-button-design'."""
m = re.match(r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$", stem, re.I)
if m:
return m.group(1), m.group(2)
return "", stem
# ── Content extraction ─────────────────────────────────────────────────────────
def _extract_title(content: str) -> str:
m = re.search(r"^#\s+(.+)", content, re.MULTILINE)
return m.group(1).strip() if m else ""
def _extract_overview(content: str) -> str:
"""Return first substantive paragraph or h2 section body (≤300 chars)."""
# Superpowers plans have an explicit **Goal:** line — prefer that.
goal_m = re.search(r"\*\*Goal:\*\*\s*(.+)", content)
if goal_m:
return goal_m.group(1).strip()[:300]
# Otherwise use the body of the first h2 section.
h2_m = re.search(
r"^##\s+\d*\.?\s*.+\n([\s\S]+?)(?=^##|\Z)",
content,
re.MULTILINE,
)
if h2_m:
body = h2_m.group(1).strip()
# Strip markdown bullet/code noise for the instruction
body = re.sub(r"```[\s\S]*?```", "", body)
body = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), body)
body = re.sub(r"\*\*([^*]+)\*\*", r"\1", body)
body = re.sub(r"\s+", " ", body).strip()
return body[:300]
return ""
def _clean_content(content: str) -> str:
"""Remove boilerplate, normalize old paths/names, collapse whitespace."""
cleaned = _BOILERPLATE_RE.sub("", content)
for pattern, replacement in _PATH_NORMALIZATIONS:
cleaned = pattern.sub(replacement, cleaned)
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
return cleaned.strip()
def _quality_flags(content: str) -> list[str]:
"""Return a list of quality issue labels found in cleaned content."""
flags = []
if "Alex Rivera" in content or "[user]" in content:
flags.append("persona-residue")
if re.search(r"\bStatus:\s*(Complete|Done|Merged)\b", content):
flags.append("completed-status")
return flags
def _make_instruction(
title: str,
product: str,
doc_type: str,
overview: str,
design_context: str | None = None,
variant: int = 0,
) -> str:
"""Synthesise a natural planning prompt for this document.
variant: 0-2 selects which paraphrase template to use. Caller cycles
through all three to produce multiple training examples per document.
"""
product_label = product.replace("-", " ").title() if product else "CircuitForge"
idx = variant % 3
if design_context:
tmpl = _PAIRED_INSTRUCTIONS[idx]
return tmpl.format(
product=product_label,
design_context=design_context[:2500],
)
templates = _PLAN_INSTRUCTIONS if doc_type in ("plan",) else _DESIGN_INSTRUCTIONS
tmpl = templates[idx]
return tmpl.format(
product=product_label,
title=title,
overview=overview or "",
type_phrase="planning document",
)
def _record_id(content: str, source: str) -> str:
return hashlib.sha256(f"{source}:{content}".encode()).hexdigest()[:16]
# ── Pair discovery ─────────────────────────────────────────────────────────────
def _find_pairs(plans_dir: Path) -> dict[str, list[tuple[str, Path]]]:
"""Return {prefix_key → [(doc_type, path), ...]} for docs sharing date+feature."""
by_prefix: dict[str, list[tuple[str, Path]]] = {}
for path in plans_dir.rglob("*.md"):
if any(part in _SKIP_DIRS for part in path.parts):
continue
if path.name == "README.md":
continue
stem = path.stem
date, feature = _date_feature(stem)
if not date:
continue
key = str(path.parent / f"{date}-{feature}")
by_prefix.setdefault(key, []).append((_doc_type(stem), path))
return by_prefix
# ── Record generation ──────────────────────────────────────────────────────────
def _records_for_group(
doc_type_paths: list[tuple[str, Path]],
plans_dir: Path,
) -> Iterator[dict]:
"""Yield one or more training records for a group of related docs."""
# Separate design vs plan docs within this group
designs = [(t, p) for t, p in doc_type_paths if t in ("design", "spec")]
plans_ = [(t, p) for t, p in doc_type_paths if t in ("plan",)]
others = [(t, p) for t, p in doc_type_paths if t not in ("design", "spec", "plan")]
all_paths = doc_type_paths
if designs and plans_:
# Paired: yield a design→plan record (3 instruction variants)
design_type, design_path = designs[0]
plan_type, plan_path = plans_[0]
design_content = design_path.read_text(encoding="utf-8")
plan_content = plan_path.read_text(encoding="utf-8")
product = _product_from_path(plan_path, plans_dir)
title = _extract_title(plan_content) or plan_path.stem
cleaned = _clean_content(plan_content)
design_cleaned = _clean_content(design_content)
flags = _quality_flags(cleaned)
if len(cleaned.split()) >= 80:
rel_src = str(plan_path.relative_to(plans_dir))
rel_design = str(design_path.relative_to(plans_dir))
for variant in range(3):
instruction = _make_instruction(
title=title,
product=product,
doc_type="plan",
overview=_extract_overview(design_content),
design_context=design_cleaned,
variant=variant,
)
yield {
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
"messages": [
{"role": "user", "content": instruction},
{"role": "assistant", "content": cleaned},
],
"meta": {
"source": rel_src,
"product": product,
"doc_type": "plan",
"date": _date_feature(plan_path.stem)[0],
"paired_with": rel_design,
"word_count": len(cleaned.split()),
"pair_role": "target",
"variant": variant,
"quality_flags": flags,
},
}
# Also yield the design doc as standalone variants
all_paths = [(t, p) for t, p in all_paths if p != plan_path]
# Remaining docs as standalone records (3 instruction variants each)
for doc_type, path in all_paths:
content = path.read_text(encoding="utf-8")
cleaned = _clean_content(content)
if len(cleaned.split()) < 80:
continue
product = _product_from_path(path, plans_dir)
title = _extract_title(content) or path.stem
overview = _extract_overview(content)
flags = _quality_flags(cleaned)
rel_src = str(path.relative_to(plans_dir))
for variant in range(3):
instruction = _make_instruction(
title=title,
product=product,
doc_type=doc_type,
overview=overview,
variant=variant,
)
yield {
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
"messages": [
{"role": "user", "content": instruction},
{"role": "assistant", "content": cleaned},
],
"meta": {
"source": rel_src,
"product": product,
"doc_type": doc_type,
"date": _date_feature(path.stem)[0],
"paired_with": None,
"word_count": len(cleaned.split()),
"pair_role": "standalone",
"variant": variant,
"quality_flags": flags,
},
}
def _product_from_path(path: Path, plans_dir: Path) -> str:
rel = path.relative_to(plans_dir)
return rel.parts[0] if len(rel.parts) > 1 else "shared"
# ── Main export ────────────────────────────────────────────────────────────────
def export(
plans_dir: Path,
products: list[str] | None = None,
) -> list[dict]:
groups = _find_pairs(plans_dir)
records: list[dict] = []
seen_ids: set[str] = set()
for group_key, doc_type_paths in groups.items():
# Filter by product if requested
if products:
paths = [p for _, p in doc_type_paths]
prods = {_product_from_path(p, plans_dir) for p in paths}
if not prods.intersection(products):
continue
for record in _records_for_group(doc_type_paths, plans_dir):
if record["id"] not in seen_ids:
seen_ids.add(record["id"])
records.append(record)
return records
# ── CLI ────────────────────────────────────────────────────────────────────────
def _print_stats(records: list[dict]) -> None:
from collections import Counter
products = Counter(r["meta"]["product"] for r in records)
doc_types = Counter(r["meta"]["doc_type"] for r in records)
pair_roles = Counter(r["meta"]["pair_role"] for r in records)
wc = [r["meta"]["word_count"] for r in records]
wc.sort()
print(f"\n{'='*55}")
print(f" Total records: {len(records)}")
print(f" Word counts : min={wc[0]}, median={wc[len(wc)//2]}, max={wc[-1]}")
print(f"\n By product:")
for p, n in products.most_common():
print(f" {p:<22} {n}")
print(f"\n By doc type:")
for t, n in doc_types.most_common():
print(f" {t:<22} {n}")
print(f"\n Pair roles:")
for r, n in pair_roles.most_common():
print(f" {r:<22} {n}")
print(f"{'='*55}\n")
def _print_sample(records: list[dict], n: int = 3) -> None:
import random
sample = random.sample(records, min(n, len(records)))
for i, rec in enumerate(sample, 1):
meta = rec["meta"]
user_msg = rec["messages"][0]["content"]
asst_msg = rec["messages"][1]["content"]
print(f"\n{''*55}")
print(f"SAMPLE {i}/{n} [{meta['product']} / {meta['doc_type']} / {meta['pair_role']}]")
print(f"source: {meta['source']}")
print(f"\nUSER ({len(user_msg)} chars):\n{user_msg[:500]}{'...' if len(user_msg)>500 else ''}")
print(f"\nASSISTANT ({meta['word_count']} words):\n{asst_msg[:400]}{'...' if len(asst_msg)>400 else ''}")
print(f"\n{''*55}\n")
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--plans-dir", type=Path, default=_DEFAULT_PLANS_DIR)
parser.add_argument("--output", type=Path, default=None,
help="Write JSONL to this path (omit for preview-only)")
parser.add_argument("--products", default=None,
help="Comma-separated product filter, e.g. peregrine,kiwi")
parser.add_argument("--preview", action="store_true",
help="Print stats + sample records, don't write output")
parser.add_argument("--samples", type=int, default=3,
help="Number of sample records to show in preview (default 3)")
args = parser.parse_args()
products = [p.strip() for p in args.products.split(",")] if args.products else None
print(f"Scanning {args.plans_dir}", file=sys.stderr)
records = export(args.plans_dir, products=products)
_print_stats(records)
if args.preview or args.output is None:
_print_sample(records, n=args.samples)
if args.output is None:
print("(Pass --output <path> to write JSONL)")
return
args.output.parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w", encoding="utf-8") as f:
for rec in records:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
print(f"Wrote {len(records)} records to {args.output}")
if __name__ == "__main__":
main()

View file

@ -1,23 +1,37 @@
import json """Smoke tests for the app factory (app/api.py).
Detailed route tests live in test_data_label.py, test_data_fetch.py,
test_data_corrections.py, test_train.py, and test_dashboard.py.
"""
import pytest import pytest
from app import api as api_module # noqa: F401 from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app import api
api.set_data_dir(tmp_path)
api.reset_last_action()
yield
api.reset_last_action()
def test_import(): def test_import():
from app import api # noqa: F401 from app import api # noqa: F401
from fastapi.testclient import TestClient def test_app_has_required_routes():
from app.api import app
paths = {r.path for r in app.routes}
# Label routes
assert "/api/queue" in paths
assert "/api/label" in paths
assert "/api/skip" in paths
assert "/api/discard" in paths
assert "/api/label/undo" in paths
assert "/api/config/labels" in paths
assert "/api/stats" in paths
# Fetch routes
assert "/api/accounts/test" in paths
assert "/api/fetch/stream" in paths
# Train routes
assert "/api/train/jobs" in paths
assert "/api/train/results" in paths
# Dashboard
assert "/api/dashboard" in paths
# Corrections (new prefix)
assert "/api/corrections/ingest" in paths
@pytest.fixture @pytest.fixture
@ -26,536 +40,8 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture def test_queue_endpoint_reachable(client):
def queue_with_items():
"""Write 3 test emails to the queue file."""
from app import api as api_module
items = [
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
for i in range(3)
]
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
return items
def test_queue_returns_items(client, queue_with_items):
r = client.get("/api/queue?limit=2")
assert r.status_code == 200
data = r.json()
assert len(data["items"]) == 2
assert data["total"] == 3
def test_queue_empty_when_no_file(client):
r = client.get("/api/queue") r = client.get("/api/queue")
assert r.status_code == 200 assert r.status_code == 200
assert r.json() == {"items": [], "total": 0} assert "items" in r.json()
assert "total" in r.json()
def test_label_appends_to_score(client, queue_with_items):
from app import api as api_module
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
assert r.status_code == 200
records = api_module._read_jsonl(api_module._score_file())
assert len(records) == 1
assert records[0]["id"] == "id0"
assert records[0]["label"] == "interview_scheduled"
assert "labeled_at" in records[0]
def test_label_removes_from_queue(client, queue_with_items):
from app import api as api_module
client.post("/api/label", json={"id": "id0", "label": "rejected"})
queue = api_module._read_jsonl(api_module._queue_file())
assert not any(x["id"] == "id0" for x in queue)
def test_label_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
assert r.status_code == 404
def test_skip_moves_to_back(client, queue_with_items):
from app import api as api_module
r = client.post("/api/skip", json={"id": "id0"})
assert r.status_code == 200
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[-1]["id"] == "id0"
assert queue[0]["id"] == "id1"
def test_skip_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/skip", json={"id": "nope"})
assert r.status_code == 404
# --- Part A: POST /api/discard ---
def test_discard_writes_to_discarded_file(client, queue_with_items):
from app import api as api_module
r = client.post("/api/discard", json={"id": "id1"})
assert r.status_code == 200
discarded = api_module._read_jsonl(api_module._discarded_file())
assert len(discarded) == 1
assert discarded[0]["id"] == "id1"
assert discarded[0]["label"] == "__discarded__"
def test_discard_removes_from_queue(client, queue_with_items):
from app import api as api_module
client.post("/api/discard", json={"id": "id1"})
queue = api_module._read_jsonl(api_module._queue_file())
assert not any(x["id"] == "id1" for x in queue)
# --- Part B: DELETE /api/label/undo ---
def test_undo_label_removes_from_score(client, queue_with_items):
from app import api as api_module
client.post("/api/label", json={"id": "id0", "label": "neutral"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
data = r.json()
assert data["undone"]["type"] == "label"
score = api_module._read_jsonl(api_module._score_file())
assert score == []
# Item should be restored to front of queue
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_discard_removes_from_discarded(client, queue_with_items):
from app import api as api_module
client.post("/api/discard", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
discarded = api_module._read_jsonl(api_module._discarded_file())
assert discarded == []
def test_undo_skip_restores_to_front(client, queue_with_items):
from app import api as api_module
client.post("/api/skip", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_with_no_action_returns_404(client):
r = client.delete("/api/label/undo")
assert r.status_code == 404
# --- Part C: GET /api/config/labels ---
def test_config_labels_returns_metadata(client):
r = client.get("/api/config/labels")
assert r.status_code == 200
labels = r.json()
assert len(labels) == 10
assert labels[0]["key"] == "1"
assert "emoji" in labels[0]
assert "color" in labels[0]
assert "name" in labels[0]
# ── /api/config ──────────────────────────────────────────────────────────────
@pytest.fixture
def config_dir(tmp_path):
"""Give the API a writable config directory."""
from app import api as api_module
api_module.set_config_dir(tmp_path)
yield tmp_path
api_module.set_config_dir(None) # reset to default
@pytest.fixture
def data_dir():
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
from app import api as api_module
return api_module._DATA_DIR
def test_get_config_returns_empty_when_no_file(client, config_dir):
r = client.get("/api/config")
assert r.status_code == 200
data = r.json()
assert data["accounts"] == []
assert data["max_per_account"] == 500
def test_post_config_writes_yaml(client, config_dir):
import yaml
payload = {
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
"use_ssl": True, "username": "u@t.com", "password": "pw",
"folder": "INBOX", "days_back": 30}],
"max_per_account": 200,
}
r = client.post("/api/config", json=payload)
assert r.status_code == 200
assert r.json()["ok"] is True
cfg_file = config_dir / "label_tool.yaml"
assert cfg_file.exists()
saved = yaml.safe_load(cfg_file.read_text())
assert saved["max_per_account"] == 200
assert saved["accounts"][0]["name"] == "Test"
def test_get_config_round_trips(client, config_dir):
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 90}], "max_per_account": 300}
client.post("/api/config", json=payload)
r = client.get("/api/config")
data = r.json()
assert data["max_per_account"] == 300
assert data["accounts"][0]["name"] == "R"
# ── /api/stats ───────────────────────────────────────────────────────────────
@pytest.fixture
def score_with_labels(tmp_path, data_dir):
"""Write a score file with 3 labels for stats tests."""
score_path = data_dir / "email_score.jsonl"
records = [
{"id": "a", "label": "interview_scheduled"},
{"id": "b", "label": "interview_scheduled"},
{"id": "c", "label": "rejected"},
]
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
return records
def test_stats_returns_counts(client, score_with_labels):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 3
assert data["counts"]["interview_scheduled"] == 2
assert data["counts"]["rejected"] == 1
def test_stats_empty_when_no_file(client, data_dir):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 0
assert data["counts"] == {}
assert data["score_file_bytes"] == 0
def test_stats_download_returns_file(client, score_with_labels):
r = client.get("/api/stats/download")
assert r.status_code == 200
assert "jsonlines" in r.headers.get("content-type", "")
def test_stats_download_404_when_no_file(client, data_dir):
r = client.get("/api/stats/download")
assert r.status_code == 404
# ── /api/accounts/test ───────────────────────────────────────────────────────
def test_account_test_missing_fields(client):
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is False
assert "required" in data["message"].lower()
def test_account_test_success(client):
from unittest.mock import MagicMock, patch
mock_conn = MagicMock()
mock_conn.select.return_value = ("OK", [b"99"])
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.post("/api/accounts/test", json={"account": {
"host": "imap.example.com", "port": 993, "use_ssl": True,
"username": "u@example.com", "password": "pw", "folder": "INBOX",
}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["count"] == 99
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
def _parse_sse(content: bytes) -> list[dict]:
"""Parse SSE response body into list of event dicts."""
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_fetch_stream_no_accounts_configured(client, config_dir):
"""With no config, stream should immediately complete with 0 added."""
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
assert r.status_code == 200
events = _parse_sse(r.content)
complete = next((e for e in events if e["type"] == "complete"), None)
assert complete is not None
assert complete["total_added"] == 0
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
"""With one configured account, stream should yield start/done/complete events."""
import yaml
from unittest.mock import MagicMock, patch
# Write a config with one account
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 30}], "max_per_account": 50}
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
mock_conn = MagicMock()
mock_conn.search.return_value = ("OK", [b"1"])
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
assert r.status_code == 200
events = _parse_sse(r.content)
types = [e["type"] for e in events]
assert "start" in types
assert "done" in types
assert "complete" in types
# ---- /api/finetune/status tests ----
def test_finetune_status_returns_empty_when_no_models_dir(client):
"""GET /api/finetune/status must return [] if models/ does not exist."""
r = client.get("/api/finetune/status")
assert r.status_code == 200
assert r.json() == []
def test_finetune_status_returns_training_info(client, tmp_path):
"""GET /api/finetune/status must return one entry per training_info.json found."""
import json as _json
from app import api as api_module
models_dir = tmp_path / "models" / "avocet-deberta-small"
models_dir.mkdir(parents=True)
info = {
"name": "avocet-deberta-small",
"base_model_id": "cross-encoder/nli-deberta-v3-small",
"val_macro_f1": 0.712,
"timestamp": "2026-03-15T12:00:00Z",
"sample_count": 401,
}
(models_dir / "training_info.json").write_text(_json.dumps(info))
api_module.set_models_dir(tmp_path / "models")
try:
r = client.get("/api/finetune/status")
assert r.status_code == 200
data = r.json()
assert any(d["name"] == "avocet-deberta-small" for d in data)
finally:
api_module.set_models_dir(api_module._ROOT / "models")
def test_finetune_run_streams_sse_events(client):
"""GET /api/finetune/run must return text/event-stream content type."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert r.status_code == 200
assert "text/event-stream" in r.headers.get("content-type", "")
def test_finetune_run_emits_complete_on_success(client):
"""GET /api/finetune/run must emit a complete event on clean exit."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter(["progress line\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '{"type": "complete"}' in r.text
def test_finetune_run_emits_error_on_nonzero_exit(client):
"""GET /api/finetune/run must emit an error event on non-zero exit."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 1
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '"type": "error"' in r.text
def test_finetune_run_passes_score_files_to_subprocess(client):
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
from unittest.mock import patch, MagicMock
captured_cmd = []
def mock_popen(cmd, **kwargs):
captured_cmd.extend(cmd)
m = MagicMock()
m.stdout = iter([])
m.returncode = 0
m.wait = MagicMock()
return m
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
assert "--score" in captured_cmd
assert captured_cmd.count("--score") == 2
# Paths are resolved to absolute — check filenames are present as substrings
assert any("run1.jsonl" in arg for arg in captured_cmd)
assert any("run2.jsonl" in arg for arg in captured_cmd)
# ---- Cancel endpoint tests ----
def test_benchmark_cancel_returns_404_when_not_running(client):
"""POST /api/benchmark/cancel must return 404 if no benchmark is running."""
from app import api as api_module
api_module._running_procs.pop("benchmark", None)
r = client.post("/api/benchmark/cancel")
assert r.status_code == 404
def test_finetune_cancel_returns_404_when_not_running(client):
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
from app import api as api_module
api_module._running_procs.pop("finetune", None)
r = client.post("/api/finetune/cancel")
assert r.status_code == 404
def test_benchmark_cancel_terminates_running_process(client):
"""POST /api/benchmark/cancel must call terminate() on the running process."""
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
api_module._running_procs["benchmark"] = mock_proc
try:
r = client.post("/api/benchmark/cancel")
assert r.status_code == 200
assert r.json()["status"] == "cancelled"
mock_proc.terminate.assert_called_once()
finally:
api_module._running_procs.pop("benchmark", None)
api_module._cancelled_jobs.discard("benchmark")
def test_finetune_cancel_terminates_running_process(client):
"""POST /api/finetune/cancel must call terminate() on the running process."""
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
api_module._running_procs["finetune"] = mock_proc
try:
r = client.post("/api/finetune/cancel")
assert r.status_code == 200
assert r.json()["status"] == "cancelled"
mock_proc.terminate.assert_called_once()
finally:
api_module._running_procs.pop("finetune", None)
api_module._cancelled_jobs.discard("finetune")
def test_benchmark_cancel_kills_process_on_timeout(client):
"""POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
import subprocess
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
api_module._running_procs["benchmark"] = mock_proc
try:
r = client.post("/api/benchmark/cancel")
assert r.status_code == 200
mock_proc.kill.assert_called_once()
finally:
api_module._running_procs.pop("benchmark", None)
api_module._cancelled_jobs.discard("benchmark")
def test_finetune_run_emits_cancelled_event(client):
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
from unittest.mock import patch, MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = -15 # SIGTERM
def mock_wait():
# Simulate cancel being called while the process is running (after discard clears stale flag)
api_module._cancelled_jobs.add("finetune")
mock_proc.wait = mock_wait
def mock_popen(cmd, **kwargs):
return mock_proc
try:
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '{"type": "cancelled"}' in r.text
assert '"type": "error"' not in r.text
finally:
api_module._cancelled_jobs.discard("finetune")
def test_benchmark_run_emits_cancelled_event(client):
"""GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
from unittest.mock import patch, MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = -15
def mock_wait():
# Simulate cancel being called while the process is running (after discard clears stale flag)
api_module._cancelled_jobs.add("benchmark")
mock_proc.wait = mock_wait
def mock_popen(cmd, **kwargs):
return mock_proc
try:
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
r = client.get("/api/benchmark/run")
assert '{"type": "cancelled"}' in r.text
assert '"type": "error"' not in r.text
finally:
api_module._cancelled_jobs.discard("benchmark")

View file

@ -2,11 +2,6 @@
import pytest import pytest
def test_registry_has_thirteen_models():
from scripts.benchmark_classifier import MODEL_REGISTRY
assert len(MODEL_REGISTRY) == 13
def test_registry_default_count(): def test_registry_default_count():
from scripts.benchmark_classifier import MODEL_REGISTRY from scripts.benchmark_classifier import MODEL_REGISTRY
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]] defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
@ -166,3 +161,95 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
assert "avocet-deberta-small" in models assert "avocet-deberta-small" in models
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter) assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
# ---- build_exemplars_from_jsonl() tests ----
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
rows = [{"subject": f"S{i}", "body": f"B{i}", "label": "rejected"} for i in range(15)]
rows.append({"subject": "Hire", "body": "Welcome", "label": "hired"})
f = tmp_path / "score.jsonl"
f.write_text("\n".join(json.dumps(r) for r in rows))
result = build_exemplars_from_jsonl(str(f), k_per_label=10)
assert len(result["rejected"]) == 10
assert len(result["hired"]) == 1
assert result["rejected"][0].startswith("Subject: S")
def test_build_exemplars_formats_text_correctly(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
row = {"subject": "My Subject", "body": "My Body", "label": "neutral"}
f = tmp_path / "score.jsonl"
f.write_text(json.dumps(row))
result = build_exemplars_from_jsonl(str(f))
assert result["neutral"][0] == "Subject: My Subject\n\nMy Body"
def test_build_exemplars_skips_rows_missing_label(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
rows = [
{"subject": "A", "body": "B", "label": "neutral"},
{"subject": "No label here", "body": "Body"},
]
f = tmp_path / "score.jsonl"
f.write_text("\n".join(json.dumps(r) for r in rows))
result = build_exemplars_from_jsonl(str(f))
assert list(result.keys()) == ["neutral"]
def test_build_exemplars_truncates_body_at_600(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
row = {"subject": "S", "body": "x" * 800, "label": "neutral"}
f = tmp_path / "score.jsonl"
f.write_text(json.dumps(row))
result = build_exemplars_from_jsonl(str(f))
body_part = result["neutral"][0].split("\n\n", 1)[1]
assert len(body_part) == 600
def test_build_exemplars_skips_rows_with_no_content(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
rows = [
{"label": "neutral"}, # no subject, no body -> skip
{"subject": "S", "body": "B", "label": "neutral"}, # valid -> keep
{"label": "rejected", "subject": "", "body": ""}, # empty strings -> skip
]
f = tmp_path / "score.jsonl"
lines = [json.dumps(r) for r in rows]
f.write_text("\n".join(lines))
result = build_exemplars_from_jsonl(str(f))
assert list(result.keys()) == ["neutral"]
assert len(result["neutral"]) == 1
def test_registry_has_fourteen_models():
from scripts.benchmark_classifier import MODEL_REGISTRY
assert len(MODEL_REGISTRY) == 14
def test_embed_knn_nomic_registry_entry():
from scripts.benchmark_classifier import MODEL_REGISTRY
from scripts.classifier_adapters import EmbeddingKNNAdapter
entry = MODEL_REGISTRY["embed-knn-nomic"]
assert entry["adapter"] is EmbeddingKNNAdapter
assert entry["model_id"] == "nomic-embed-text"
assert entry["params"] == "local-embed"
assert entry["default"] is False
assert entry.get("kwargs", {}).get("k") == 3

View file

@ -14,7 +14,9 @@ from fastapi.testclient import TestClient
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_cforch_globals(tmp_path): def reset_cforch_globals(tmp_path):
"""Redirect _CONFIG_DIR to tmp_path and reset running-state globals.""" """Redirect _CONFIG_DIR to tmp_path, reset running-state globals, and stub
list_installed to return [] so real disk model directories don't bleed into
tests that don't exercise the installed-model merge path."""
from app import cforch as cforch_module from app import cforch as cforch_module
prev_config_dir = cforch_module._CONFIG_DIR prev_config_dir = cforch_module._CONFIG_DIR
@ -25,7 +27,8 @@ def reset_cforch_globals(tmp_path):
cforch_module._BENCH_RUNNING = False cforch_module._BENCH_RUNNING = False
cforch_module._bench_proc = None cforch_module._bench_proc = None
yield tmp_path with patch("app.models.list_installed", return_value=[]):
yield tmp_path
cforch_module.set_config_dir(prev_config_dir) cforch_module.set_config_dir(prev_config_dir)
cforch_module._BENCH_RUNNING = prev_running cforch_module._BENCH_RUNNING = prev_running
@ -141,6 +144,35 @@ def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
assert m["vram_estimate_mb"] == 6000 assert m["vram_estimate_mb"] == 6000
def test_models_merges_installed_generators(client, config_dir, tmp_path):
"""Installed cf-text/vllm generator models appear in the model list,
deduplicated against bench_models.yaml entries."""
models_file = tmp_path / "bench_models.yaml"
_write_models_yaml(models_file, [
{"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
{"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
])
_write_config(config_dir, {"bench_models": str(models_file)})
fake_installed = [
# should be included — cf-text generator not already in YAML
{"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000},
# should be deduped — repo_id matches a YAML entry
{"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000},
# should be excluded — classifier, not a generator
{"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500},
]
with patch("app.models.list_installed", return_value=fake_installed):
r = client.get("/api/cforch/models")
assert r.status_code == 200
ids = [m["id"] for m in r.json()["models"]]
assert "llama3:8b" in ids # from YAML
assert "ibm-granite/granite-4.1-8b" in ids # from YAML (not duplicated)
assert "meta-llama/Llama-3.1-8B" in ids # merged from installed
assert "cross-encoder/ms-marco-MiniLM-L6" not in ids # filtered out (reranker)
assert ids.count("ibm-granite/granite-4.1-8b") == 1 # no duplicate
# ── GET /run ─────────────────────────────────────────────────────────────────── # ── GET /run ───────────────────────────────────────────────────────────────────
def test_run_returns_409_when_already_running(client): def test_run_returns_409_when_already_running(client):
@ -367,3 +399,13 @@ def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path,
client.get("/api/cforch/run") client.get("/api/cforch/run")
assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY" assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
def test_eval_cforch_router_includes_all_sub_routers():
"""eval/cforch.py router must include routes from all four sub-routers."""
from app.eval.cforch import router
paths = {r.path for r in router.routes}
assert any("/cforch/" in p for p in paths), f"no /cforch/ routes found in {paths}"
assert any("/style/" in p for p in paths), f"no /style/ routes found in {paths}"
assert any("/voice/" in p for p in paths), f"no /voice/ routes found in {paths}"
assert any("/plans-bench/" in p for p in paths), f"no /plans-bench/ routes found in {paths}"

View file

@ -268,3 +268,373 @@ def test_finetuned_adapter_unload_clears_pipeline():
assert adapter._pipeline is not None assert adapter._pipeline is not None
adapter.unload() adapter.unload()
assert adapter._pipeline is None assert adapter._pipeline is None
# ---- _cosine() tests ----
def test_cosine_identical_unit_vectors():
import math
from scripts.classifier_adapters import _cosine
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
def test_cosine_orthogonal_vectors():
from scripts.classifier_adapters import _cosine
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
def test_cosine_known_value():
import math
from scripts.classifier_adapters import _cosine
# [1,0] vs [1/sqrt(2), 1/sqrt(2)] → dot = 1/sqrt(2), both norms = 1 → 1/sqrt(2)
v = [1.0 / math.sqrt(2), 1.0 / math.sqrt(2)]
assert _cosine([1.0, 0.0], v) == pytest.approx(1.0 / math.sqrt(2))
def test_cosine_zero_vector_returns_zero():
from scripts.classifier_adapters import _cosine
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
# ---- DEFAULT_EXEMPLARS tests ----
def test_default_exemplars_covers_all_labels():
from scripts.classifier_adapters import DEFAULT_EXEMPLARS, LABELS
for label in LABELS:
assert label in DEFAULT_EXEMPLARS, f"DEFAULT_EXEMPLARS missing label: {label}"
assert len(DEFAULT_EXEMPLARS[label]) >= 4, f"{label} needs >= 4 exemplars for k=3 voting"
def test_default_exemplars_sparse_labels_have_at_least_four():
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
# These labels have very few real examples; need >= 4 so k=3 vote is meaningful
for label in ("hired", "survey_received", "event_rescheduled"):
assert len(DEFAULT_EXEMPLARS[label]) >= 4, (
f"{label} needs >= 4 exemplars for k=3 voting to work reliably"
)
def test_default_exemplars_strings_are_formatted_correctly():
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
for label, texts in DEFAULT_EXEMPLARS.items():
for text in texts:
assert text.startswith("Subject: "), (
f"{label!r} exemplar missing 'Subject: ' prefix: {text[:50]!r}"
)
assert "\n\n" in text, (
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
)
# ---- EmbeddingKNNAdapter constructor tests ----
def test_embedding_knn_is_classifier_adapter():
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
adapter = EmbeddingKNNAdapter(
"test-knn", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
assert isinstance(adapter, ClassifierAdapter)
def test_embedding_knn_name_and_model_id():
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"embed-knn-nomic", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
assert adapter.name == "embed-knn-nomic"
assert adapter.model_id == "nomic-embed-text"
def test_embedding_knn_uses_default_exemplars_when_none_given():
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
def test_embedding_knn_accepts_custom_exemplars():
from scripts.classifier_adapters import EmbeddingKNNAdapter
custom = {"rejected": ["Sorry, we went with others."]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=custom,
)
assert adapter._exemplar_texts is custom
# ---- EmbeddingKNNAdapter.load() tests ----
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
"""Return a side_effect function for patching httpx.post.
Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3]
embedding per input text.
"""
def _side_effect(url, *, json=None, timeout=None, **kwargs):
from unittest.mock import MagicMock
resp = MagicMock()
resp.raise_for_status.return_value = None
if "/allocate" in url:
resp.status_code = 200
resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
else:
n = len((json or {}).get("input", []))
resp.status_code = 200
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n}
return resp
return _side_effect
def test_load_calls_allocate_then_embeds_each_label():
from unittest.mock import patch
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {
"rejected": ["We went with others"],
"hired": ["Welcome aboard!", "First day info"],
}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
post_urls = []
def capturing_mock(url, *, json=None, timeout=None, **kwargs):
post_urls.append(url)
return _make_post_mock()(url, json=json, timeout=timeout)
with patch("httpx.post", side_effect=capturing_mock):
adapter.load()
assert any("/allocate" in u for u in post_urls), "expected allocate call"
assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
assert adapter._allocation_id == "alloc-abc"
assert adapter._node_url == "http://navi:11434"
assert adapter._orch_url_used == "http://orch:7700"
assert "rejected" in adapter._exemplar_embeddings
assert "hired" in adapter._exemplar_embeddings
assert len(adapter._exemplar_embeddings["rejected"]) == 1
assert len(adapter._exemplar_embeddings["hired"]) == 2
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
def test_load_falls_back_to_ollama_when_allocate_fails():
from unittest.mock import patch, MagicMock
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
resp = MagicMock()
if "/allocate" in url:
resp.status_code = 503
resp.json.return_value = {}
else:
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
return resp
with patch("httpx.post", side_effect=failing_allocate_mock):
adapter.load()
assert adapter._allocation_id == ""
assert adapter._orch_url_used == ""
assert adapter._node_url == "http://ollama:11434"
assert "rejected" in adapter._exemplar_embeddings
def test_load_falls_back_to_ollama_when_allocate_raises():
from unittest.mock import patch, MagicMock
import httpx as _httpx
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
def raising_mock(url, *, json=None, timeout=None, **kwargs):
if "/allocate" in url:
raise _httpx.ConnectError("connection refused")
resp = MagicMock()
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
return resp
with patch("httpx.post", side_effect=raising_mock):
adapter.load()
assert adapter._allocation_id == ""
assert adapter._orch_url_used == ""
assert adapter._node_url == "http://ollama:11434"
assert "rejected" in adapter._exemplar_embeddings
# ---- EmbeddingKNNAdapter.unload() tests ----
def test_unload_releases_orch_allocation_and_clears_state():
from unittest.mock import patch, MagicMock
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
adapter._node_url = "http://navi:11434"
adapter._allocation_id = "alloc-abc"
adapter._orch_url_used = "http://orch:7700"
delete_calls = []
def mock_request(method, url, **kwargs):
delete_calls.append((method, url))
resp = MagicMock()
resp.status_code = 200
return resp
with patch("httpx.request", side_effect=mock_request):
adapter.unload()
assert len(delete_calls) == 1
method, url = delete_calls[0]
assert method == "DELETE"
assert "alloc-abc" in url
assert adapter._exemplar_embeddings == {}
assert adapter._allocation_id == ""
assert adapter._node_url == ""
assert adapter._orch_url_used == ""
def test_unload_skips_delete_on_ollama_fallback_path():
from unittest.mock import patch
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
adapter._node_url = "http://ollama:11434"
adapter._allocation_id = "" # fallback path: no allocation was made
adapter._orch_url_used = ""
delete_calls = []
with patch("httpx.request", side_effect=lambda *a, **k: delete_calls.append(a)):
adapter.unload()
assert len(delete_calls) == 0
assert adapter._exemplar_embeddings == {}
assert adapter._node_url == ""
# ---- EmbeddingKNNAdapter.classify() tests ----
def _adapter_with_embeddings(exemplar_embeddings, k=3):
"""Return a pre-loaded EmbeddingKNNAdapter (bypass load()) with given per-label vectors."""
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=k,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
adapter._exemplar_embeddings = exemplar_embeddings
adapter._node_url = "http://navi:11434"
return adapter
def _embed_resp(vec):
"""Return a mock httpx response for /v1/embeddings returning a single vector."""
from unittest.mock import MagicMock
resp = MagicMock()
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": vec}]}
return resp
def test_classify_returns_majority_vote_label():
from unittest.mock import patch
adapter = _adapter_with_embeddings({
"rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
"neutral": [[0.0, 1.0, 0.0]],
}, k=3)
# Query [1,0,0] is closest to all three "rejected" exemplars
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
result = adapter.classify("We went with others", "Thank you for applying.")
assert result == "rejected"
def test_classify_tiebreak_by_mean_score():
from unittest.mock import patch
# k=2: each label gets exactly 1 vote → tie-break by mean similarity
# [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] ≈ 0.6 ("neutral")
adapter = _adapter_with_embeddings({
"rejected": [[1.0, 0.0]],
"neutral": [[0.6, 0.8]],
}, k=2)
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
result = adapter.classify("Rejection", "Sorry")
assert result == "rejected"
def test_classify_sparse_label_can_win():
from unittest.mock import patch
# "hired" has only 1 exemplar; with k=1, the single closest match wins
adapter = _adapter_with_embeddings({
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
"hired": [[1.0, 0.0, 0.0]],
}, k=1)
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
result = adapter.classify("Welcome aboard", "Your first day details")
assert result == "hired"
def test_classify_lazy_loads_when_not_loaded():
from unittest.mock import patch
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=1,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
assert adapter._exemplar_embeddings == {}
post_urls = []
def mock_post(url, *, json=None, timeout=None, **kwargs):
post_urls.append(url)
from unittest.mock import MagicMock
resp = MagicMock()
resp.raise_for_status.return_value = None
if "/allocate" in url:
resp.status_code = 200
resp.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
else:
n = len((json or {}).get("input", []))
resp.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * n}
return resp
with patch("httpx.post", side_effect=mock_post):
result = adapter.classify("Rejection", "Sorry")
assert result == "rejected"
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
assert adapter._exemplar_embeddings != {}
assert adapter._node_url == "http://navi:11434"

122
tests/test_dashboard.py Normal file
View file

@ -0,0 +1,122 @@
"""Tests for app/dashboard.py -- GET /api/dashboard."""
import json
import pytest
import yaml
from fastapi.testclient import TestClient
from pathlib import Path
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app import dashboard as dash_module
dash_module.set_data_dir(tmp_path)
dash_module.set_config_dir(tmp_path)
yield
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _write_score(tmp_path: Path, records: list[dict]) -> None:
(tmp_path / "email_score.jsonl").write_text(
"\n".join(json.dumps(r) for r in records) + "\n"
)
def _write_summary(tmp_path: Path, run_id: str, ts: str, score: float) -> None:
run_dir = tmp_path / "bench_results" / run_id
run_dir.mkdir(parents=True)
(run_dir / "summary.json").write_text(
json.dumps({"timestamp": ts, "best_macro_f1": score})
)
def test_dashboard_returns_expected_keys(client):
r = client.get("/api/dashboard")
assert r.status_code == 200
data = r.json()
for key in ("labeled_since_last_eval", "last_eval_timestamp", "last_eval_best_score",
"active_jobs", "corrections_pending", "corrections_export_ready", "signals"):
assert key in data, f"missing key: {key}"
for sig in ("data_to_eval", "eval_to_train", "train_to_fleet"):
assert sig in data["signals"], f"missing signal: {sig}"
def test_dashboard_empty_state(client):
r = client.get("/api/dashboard")
assert r.status_code == 200
data = r.json()
assert data["labeled_since_last_eval"] == 0
assert data["last_eval_timestamp"] is None
assert data["last_eval_best_score"] is None
assert data["active_jobs"] == []
assert data["corrections_pending"] == 0
assert data["corrections_export_ready"] == 0
def test_labeled_since_counts_all_when_no_eval(client, tmp_path):
_write_score(tmp_path, [
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"},
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
])
r = client.get("/api/dashboard")
assert r.json()["labeled_since_last_eval"] == 2
def test_labeled_since_filters_by_eval_timestamp(client, tmp_path):
_write_summary(tmp_path, "2026-05-01-100000", "2026-05-01T10:00:00+00:00", 0.80)
_write_score(tmp_path, [
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T09:00:00+00:00"},
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
])
(tmp_path / "label_tool.yaml").write_text(
yaml.dump({"cforch": {"results_dir": str(tmp_path / "bench_results")}})
)
r = client.get("/api/dashboard")
data = r.json()
assert data["labeled_since_last_eval"] == 1
assert abs(data["last_eval_best_score"] - 0.80) < 0.001
def test_data_to_eval_false_below_threshold(client, tmp_path):
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(10)])
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
r = client.get("/api/dashboard")
assert r.json()["signals"]["data_to_eval"] is False
def test_data_to_eval_true_at_threshold(client, tmp_path):
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(50)])
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
r = client.get("/api/dashboard")
assert r.json()["signals"]["data_to_eval"] is True
def test_corrections_pending_count(client, tmp_path):
candidates = [
{"id": "c1", "status": "needs_review"},
{"id": "c2", "status": "needs_review"},
{"id": "c3", "status": "discarded"},
]
(tmp_path / "sft_candidates.jsonl").write_text(
"\n".join(json.dumps(c) for c in candidates) + "\n"
)
r = client.get("/api/dashboard")
assert r.json()["corrections_pending"] == 2
def test_corrections_export_ready_count(client, tmp_path):
approved = [
{"id": "a1", "status": "approved", "corrected_response": "Good answer"},
{"id": "a2", "status": "approved", "corrected_response": ""},
{"id": "a3", "status": "approved", "corrected_response": "Another answer"},
]
(tmp_path / "sft_approved.jsonl").write_text(
"\n".join(json.dumps(a) for a in approved) + "\n"
)
r = client.get("/api/dashboard")
assert r.json()["corrections_export_ready"] == 2

View file

@ -0,0 +1,102 @@
"""Tests for app/data/corrections.py -- POST /api/sft/ingest.
The corrections router is mounted at prefix="/api/sft" via the app/sft.py
backward-compat shim, so ingest lives at /api/sft/ingest.
"""
import json
import pytest
from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.data import corrections as corr_module
corr_module.set_data_dir(tmp_path)
corr_module.set_config_dir(tmp_path)
yield
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
_VALID_PAYLOAD = {
"source": "peregrine",
"task_type": "email_classification",
"prompt": "Classify this email: ...",
"response": "skip",
"correction": "action_required",
"label": "action_required",
}
_SECRET = "test-secret-abc123"
def test_ingest_503_when_secret_not_configured(client, monkeypatch):
monkeypatch.delenv("AVOCET_INGESTION_SECRET", raising=False)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 503
def test_ingest_401_when_no_auth_header(client, monkeypatch):
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD)
assert r.status_code == 401
def test_ingest_401_when_malformed_header(client, monkeypatch):
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": "Token bad-format"})
assert r.status_code == 401
def test_ingest_403_when_wrong_secret(client, monkeypatch):
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": "Bearer wrong-secret"})
assert r.status_code == 403
def test_ingest_creates_approved_record(client, monkeypatch, tmp_path):
from app.data import corrections as corr_module
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
corr_module.set_data_dir(tmp_path)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert "id" in data
candidates = corr_module.read_jsonl(corr_module._candidates_file())
assert len(candidates) == 1
rec = candidates[0]
assert rec["status"] == "approved"
assert rec["source"] == "peregrine"
assert rec["corrected_response"] == "action_required"
assert rec["id"] == data["id"]
def test_ingest_also_writes_to_approved_file(client, monkeypatch, tmp_path):
from app.data import corrections as corr_module
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
corr_module.set_data_dir(tmp_path)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 200
approved = corr_module.read_jsonl(corr_module._approved_file())
assert len(approved) == 1
assert approved[0]["id"] == r.json()["id"]
def test_ingest_without_label_is_accepted(client, monkeypatch, tmp_path):
from app.data import corrections as corr_module
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
corr_module.set_data_dir(tmp_path)
payload = {**_VALID_PAYLOAD, "label": None}
r = client.post("/api/sft/ingest", json=payload,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 200

95
tests/test_data_fetch.py Normal file
View file

@ -0,0 +1,95 @@
"""Tests for app/data/fetch.py"""
import json
import yaml
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.data import fetch as fetch_module
fetch_module.set_data_dir(tmp_path)
fetch_module.set_config_dir(tmp_path)
yield
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _parse_sse(content: bytes) -> list[dict]:
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_account_test_missing_fields(client):
r = client.post("/api/accounts/test",
json={"account": {"host": "", "username": "", "password": ""}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is False
assert "required" in data["message"].lower()
def test_account_test_success(client):
mock_conn = MagicMock()
mock_conn.select.return_value = ("OK", [b"99"])
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.post("/api/accounts/test", json={"account": {
"host": "imap.example.com", "port": 993, "use_ssl": True,
"username": "u@example.com", "password": "pw", "folder": "INBOX",
}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["count"] == 99
def test_fetch_stream_no_accounts_configured(client, tmp_path):
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
assert r.status_code == 200
events = _parse_sse(r.content)
complete = next((e for e in events if e["type"] == "complete"), None)
assert complete is not None
assert complete["total_added"] == 0
def test_fetch_stream_with_mock_imap(client, tmp_path):
from app.data import fetch as fetch_module
fetch_module.set_config_dir(tmp_path)
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 30}], "max_per_account": 50}
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
mock_conn = MagicMock()
mock_conn.search.return_value = ("OK", [b"1"])
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
assert r.status_code == 200
events = _parse_sse(r.content)
types = [e["type"] for e in events]
assert "start" in types
assert "done" in types
assert "complete" in types
def test_entry_key_deterministic():
from app.data.fetch import entry_key
e = {"subject": "Test", "body": "Hello world"}
assert entry_key(e) == entry_key(e)
def test_entry_key_differs_by_subject():
from app.data.fetch import entry_key
a = {"subject": "A", "body": "same body"}
b = {"subject": "B", "body": "same body"}
assert entry_key(a) != entry_key(b)

219
tests/test_data_label.py Normal file
View file

@ -0,0 +1,219 @@
"""Tests for app/data/label.py"""
import json
import pytest
import yaml
from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.data import label as label_module
label_module.set_data_dir(tmp_path)
label_module.set_config_dir(tmp_path)
label_module.reset_last_action()
yield
label_module.reset_last_action()
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
@pytest.fixture
def queue_with_items(tmp_path):
from app.data import label as label_module
items = [
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
for i in range(3)
]
(label_module._DATA_DIR / "email_label_queue.jsonl").write_text(
"\n".join(json.dumps(x) for x in items) + "\n")
return items
def test_queue_returns_items(client, queue_with_items):
r = client.get("/api/queue?limit=2")
assert r.status_code == 200
data = r.json()
assert len(data["items"]) == 2
assert data["total"] == 3
def test_queue_empty_when_no_file(client):
r = client.get("/api/queue")
assert r.status_code == 200
assert r.json() == {"items": [], "total": 0}
def test_label_appends_to_score(client, queue_with_items):
from app.data import label as label_module
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
assert r.status_code == 200
records = label_module.read_jsonl(label_module._score_file())
assert len(records) == 1
assert records[0]["id"] == "id0"
assert records[0]["label"] == "interview_scheduled"
assert "labeled_at" in records[0]
def test_label_removes_from_queue(client, queue_with_items):
from app.data import label as label_module
client.post("/api/label", json={"id": "id0", "label": "rejected"})
queue = label_module.read_jsonl(label_module._queue_file())
assert not any(x["id"] == "id0" for x in queue)
def test_label_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
assert r.status_code == 404
def test_skip_moves_to_back(client, queue_with_items):
from app.data import label as label_module
r = client.post("/api/skip", json={"id": "id0"})
assert r.status_code == 200
queue = label_module.read_jsonl(label_module._queue_file())
assert queue[-1]["id"] == "id0"
assert queue[0]["id"] == "id1"
def test_skip_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/skip", json={"id": "nope"})
assert r.status_code == 404
def test_discard_writes_to_discarded_file(client, queue_with_items):
from app.data import label as label_module
r = client.post("/api/discard", json={"id": "id1"})
assert r.status_code == 200
discarded = label_module.read_jsonl(label_module._discarded_file())
assert len(discarded) == 1
assert discarded[0]["id"] == "id1"
assert discarded[0]["label"] == "__discarded__"
def test_discard_removes_from_queue(client, queue_with_items):
from app.data import label as label_module
client.post("/api/discard", json={"id": "id1"})
queue = label_module.read_jsonl(label_module._queue_file())
assert not any(x["id"] == "id1" for x in queue)
def test_undo_label_removes_from_score(client, queue_with_items):
from app.data import label as label_module
client.post("/api/label", json={"id": "id0", "label": "neutral"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
assert r.json()["undone"]["type"] == "label"
assert label_module.read_jsonl(label_module._score_file()) == []
queue = label_module.read_jsonl(label_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_discard_removes_from_discarded(client, queue_with_items):
from app.data import label as label_module
client.post("/api/discard", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
assert label_module.read_jsonl(label_module._discarded_file()) == []
def test_undo_skip_restores_to_front(client, queue_with_items):
from app.data import label as label_module
client.post("/api/skip", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
queue = label_module.read_jsonl(label_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_with_no_action_returns_404(client):
r = client.delete("/api/label/undo")
assert r.status_code == 404
def test_config_labels_returns_10_labels(client):
r = client.get("/api/config/labels")
assert r.status_code == 200
labels = r.json()
assert len(labels) == 10
assert labels[0]["key"] == "1"
for lbl in labels:
assert "emoji" in lbl and "color" in lbl and "name" in lbl
def test_get_config_returns_empty_when_no_file(client):
r = client.get("/api/config")
assert r.status_code == 200
data = r.json()
assert data["accounts"] == []
assert data["max_per_account"] == 500
def test_post_config_writes_yaml(client, tmp_path):
from app.data import label as label_module
label_module.set_config_dir(tmp_path)
payload = {"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
"use_ssl": True, "username": "u@t.com", "password": "pw",
"folder": "INBOX", "days_back": 30}], "max_per_account": 200}
r = client.post("/api/config", json=payload)
assert r.status_code == 200
assert r.json()["ok"] is True
saved = yaml.safe_load((tmp_path / "label_tool.yaml").read_text())
assert saved["max_per_account"] == 200
assert saved["accounts"][0]["name"] == "Test"
def test_get_config_round_trips(client, tmp_path):
from app.data import label as label_module
label_module.set_config_dir(tmp_path)
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 90}], "max_per_account": 300}
client.post("/api/config", json=payload)
r = client.get("/api/config")
data = r.json()
assert data["max_per_account"] == 300
assert data["accounts"][0]["name"] == "R"
def test_stats_returns_counts(client, tmp_path):
from app.data import label as label_module
label_module.set_data_dir(tmp_path)
score_path = tmp_path / "email_score.jsonl"
records = [{"id": "a", "label": "interview_scheduled"},
{"id": "b", "label": "interview_scheduled"},
{"id": "c", "label": "rejected"}]
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 3
assert data["counts"]["interview_scheduled"] == 2
assert data["counts"]["rejected"] == 1
def test_stats_empty_when_no_file(client):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 0
assert data["counts"] == {}
assert data["score_file_bytes"] == 0
def test_stats_download_returns_file(client, tmp_path):
from app.data import label as label_module
label_module.set_data_dir(tmp_path)
(tmp_path / "email_score.jsonl").write_text(json.dumps({"id": "a", "label": "neutral"}) + "\n")
r = client.get("/api/stats/download")
assert r.status_code == 200
assert "jsonlines" in r.headers.get("content-type", "")
def test_stats_download_404_when_no_file(client):
r = client.get("/api/stats/download")
assert r.status_code == 404

View file

@ -1,4 +1,4 @@
"""Tests for app/imitate.py product registry, sample extraction, corrections push.""" """Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
from __future__ import annotations from __future__ import annotations
import json import json
@ -9,10 +9,10 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.api import app from app.api import app
from app import imitate as _imitate_module from app.data import imitate as _imitate_module
# ── Fixtures ─────────────────────────────────────────────────────────────────── # -- Fixtures ------------------------------------------------------------------
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_module_globals(tmp_path): def reset_module_globals(tmp_path):
@ -70,7 +70,7 @@ def client() -> TestClient:
return TestClient(app, raise_server_exceptions=True) return TestClient(app, raise_server_exceptions=True)
# ── GET /products ────────────────────────────────────────────────────────────── # -- GET /products -------------------------------------------------------------
def test_products_empty_when_no_config(config_dir, client): def test_products_empty_when_no_config(config_dir, client):
"""Returns empty list when label_tool.yaml has no imitate section.""" """Returns empty list when label_tool.yaml has no imitate section."""
@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client):
assert all(not p["online"] for p in resp.json()["products"]) assert all(not p["online"] for p in resp.json()["products"])
# ── GET /products/{id}/sample ───────────────────────────────────────────────── # -- GET /products/{id}/sample -------------------------------------------------
def test_sample_unknown_product(cfg_with_products, client): def test_sample_unknown_product(cfg_with_products, client):
"""Returns 404 for a product id not in config.""" """Returns 404 for a product id not in config."""
@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client):
assert resp.status_code == 404 assert resp.status_code == 404
# ── POST /push-corrections ───────────────────────────────────────────────────── # -- POST /push-corrections ----------------------------------------------------
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client): def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
"""Successful push writes records to sft_candidates.jsonl.""" """Successful push writes records to sft_candidates.jsonl."""
@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
assert resp.status_code == 422 assert resp.status_code == 422
# ── _extract_sample helper ───────────────────────────────────────────────────── # -- _extract_sample helper ----------------------------------------------------
def test_extract_sample_list(): def test_extract_sample_list():
result = _imitate_module._extract_sample( result = _imitate_module._extract_sample(

View file

@ -541,3 +541,84 @@ def test_delete_installed_name_with_slash_blocked(client):
except _HTTPException as exc: except _HTTPException as exc:
assert exc.status_code in (400, 404) assert exc.status_code in (400, 404)
raise raise
# ── Catalog registration ───────────────────────────────────────────────────────
_MINIMAL_YAML = """\
services:
cf-text:
max_mb: {max_mb}
catalog:
existing-model:
path: /some/path
vram_mb: 1000
description: "placeholder"
"""
def _make_node_yaml(tmp_path: Path, max_mb: int = 8192) -> Path:
p = tmp_path / "testnode.yaml"
p.write_text(_MINIMAL_YAML.format(max_mb=max_mb), encoding="utf-8")
return p
def test_catalog_registration_fp16_no_env_block(tmp_path):
"""When model fits at FP16, no env block should be written."""
from app import models as models_module
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
updated = models_module._register_in_node_catalogs(
repo_id="org/SmallModel",
local_path=tmp_path / "org--SmallModel",
vram_mb_fp16=4000,
role="generator",
)
assert "testnode" in updated
content = node_yaml.read_text()
# _catalog_key strips org prefix and lowercases: "org/SmallModel" → "smallmodel"
assert "smallmodel:" in content
assert "CF_TEXT_4BIT" not in content
assert "env:" not in content
def test_catalog_registration_needs_4bit_writes_env_block(tmp_path):
"""When model only fits at 4-bit, env: CF_TEXT_4BIT: '1' must be written."""
from app import models as models_module
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
updated = models_module._register_in_node_catalogs(
repo_id="org/BigModel",
local_path=tmp_path / "org--BigModel",
vram_mb_fp16=20000, # won't fit at FP16 on 8 GB
role="generator",
)
assert "testnode" in updated
content = node_yaml.read_text()
# _catalog_key: "org/BigModel" → "bigmodel"
assert "bigmodel:" in content
assert "env:" in content
assert 'CF_TEXT_4BIT: "1"' in content
assert "CF_TEXT_4BIT=1 required" in content # description note
def test_catalog_registration_too_large_skipped(tmp_path):
"""Model too large even at 4-bit should not be registered."""
from app import models as models_module
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
updated = models_module._register_in_node_catalogs(
repo_id="org/HugeModel",
local_path=tmp_path / "org--HugeModel",
vram_mb_fp16=80000, # 4-bit ~22 GB, still won't fit on 8 GB
role="generator",
)
assert updated == []
content = node_yaml.read_text()
assert "hugemodel" not in content

471
tests/test_nodes.py Normal file
View file

@ -0,0 +1,471 @@
"""Tests for app/nodes.py — /api/nodes-mgmt/* endpoints."""
from __future__ import annotations
from pathlib import Path
import pytest
import yaml
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
import os as _os
@pytest.fixture(autouse=True)
def reset_nodes_globals(tmp_path):
"""Redirect _CONFIG_DIR to tmp_path so tests never read the real config."""
from app import nodes as nodes_module
prev = nodes_module._CONFIG_DIR
nodes_module.set_config_dir(tmp_path)
yield tmp_path
nodes_module.set_config_dir(prev)
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
cfg = {"cforch": cforch_cfg}
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg), encoding="utf-8")
def _write_profile(profiles_dir: Path, node_id: str, profile: dict) -> None:
profiles_dir.mkdir(parents=True, exist_ok=True)
(profiles_dir / f"{node_id}.yaml").write_text(yaml.dump(profile), encoding="utf-8")
def test_nodes_module_imports():
from app import nodes
assert hasattr(nodes, "router")
assert hasattr(nodes, "set_config_dir")
def test_list_nodes_returns_empty_when_no_coordinator(client):
"""No cforch config — endpoint returns empty list, not 500."""
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
assert r.json() == []
def _fake_nodes_response(nodes_json: list, services_json: list | None = None):
"""Build side_effect list for two httpx.get calls: nodes then services."""
mock_nodes = MagicMock()
mock_nodes.raise_for_status = MagicMock()
mock_nodes.json.return_value = nodes_json
mock_services = MagicMock()
mock_services.raise_for_status = MagicMock()
mock_services.json.return_value = services_json or []
return [mock_nodes, mock_services]
def test_list_nodes_coordinator_unreachable_returns_empty(client, tmp_path):
"""Coordinator unreachable — returns [] with no 500."""
import httpx
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
with patch("httpx.get", side_effect=httpx.ConnectError("refused")):
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
assert r.json() == []
def test_list_nodes_merges_profile_data(client, tmp_path):
"""Profile YAML services_assigned merged with live GPU stats."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", {
"services": {
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}},
},
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
"services": ["cf-text"], "role": "primary", "card": "RTX 3090",
"always_on": True}],
"agent_url": "http://10.1.10.71:7701",
}
}
})
coord_nodes = [{
"node_id": "heimdall", "online": True, "agent_url": "http://10.1.10.71:7701",
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
"vram_used_mb": 4096, "vram_free_mb": 20480,
"temp_c": 42.0, "utilization_pct": 15.0, "compute_cap": 8.6}],
}]
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
data = r.json()
assert len(data) == 1
node = data[0]
assert node["node_id"] == "heimdall"
assert node["profile_loaded"] is True
assert node["gpus"][0]["services_assigned"] == ["cf-text"]
assert node["gpus"][0]["vram_total_mb"] == 24576
assert "cf-text" in node["services_catalog"]
def test_list_nodes_no_profile_returns_profile_loaded_false(client, tmp_path):
"""Node with no profile YAML — profile_loaded: false, GPU stats still returned."""
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
coord_nodes = [{
"node_id": "sif", "online": True, "agent_url": "http://10.1.10.158:7701",
"gpus": [{"gpu_id": 0, "card": "RTX 5060 Ti", "vram_total_mb": 16384,
"vram_used_mb": 0, "vram_free_mb": 16384,
"temp_c": None, "utilization_pct": None, "compute_cap": 10.0}],
}]
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
data = r.json()
node = data[0]
assert node["profile_loaded"] is False
assert node["gpus"][0]["card"] == "RTX 5060 Ti"
assert node["services_catalog"] == {}
def test_list_nodes_marks_running_services(client, tmp_path):
"""services_running populated from coordinator /api/services response."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", {
"services": {},
"nodes": {"heimdall": {"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
"services": ["cf-text"], "role": "p",
"card": "RTX 3090", "always_on": True}],
"agent_url": "http://10.1.10.71:7701"}}
})
coord_nodes = [{"node_id": "heimdall", "online": True,
"agent_url": "http://10.1.10.71:7701",
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
"vram_used_mb": 8192, "vram_free_mb": 16384,
"temp_c": 55.0, "utilization_pct": 80.0, "compute_cap": 8.6}]}]
coord_services = [{"service": "cf-text", "node_id": "heimdall", "gpu_id": 0}]
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes, coord_services)):
r = client.get("/api/nodes-mgmt/nodes")
data = r.json()
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
def test_get_profile_returns_parsed_yaml(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
profile = {
"services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
"nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
}
_write_profile(profiles_dir, "heimdall", profile)
r = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
assert r.status_code == 200
data = r.json()
assert "services" in data
assert "cf-text" in data["services"]
def test_get_profile_404_when_missing(client, tmp_path):
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
r = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
assert r.status_code == 404
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
profiles_dir = tmp_path / "profiles"
profiles_dir.mkdir()
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
(profiles_dir / "bad.yaml").write_text("key: [unclosed", encoding="utf-8")
r = client.get("/api/nodes-mgmt/nodes/bad/profile")
assert r.status_code == 500
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
_BASE_PROFILE = {
"services": {
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1,
"catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3",
"description": "", "multi_gpu": False, "env": {}}}},
"ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}},
},
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
"services": [], "role": "primary", "card": "RTX 3090",
"always_on": True}],
"agent_url": "http://10.1.10.71:7701",
}
}
}
def _setup_profile(tmp_path, profile=None):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", profile or _BASE_PROFILE)
return profiles_dir
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
profiles_dir = _setup_profile(tmp_path)
mock_reload = MagicMock()
mock_reload.status_code = 200
with patch("httpx.post", return_value=mock_reload):
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["cf-text"]},
)
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["reloaded"] is True
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"]
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
"""YAML must be written to .tmp then renamed — never written directly."""
profiles_dir = _setup_profile(tmp_path)
renamed_pairs: list[tuple] = []
original_replace = _os.replace
def capture(src, dst):
renamed_pairs.append((str(src), str(dst)))
original_replace(src, dst)
with patch("os.replace", side_effect=capture), \
patch("httpx.post", return_value=MagicMock(status_code=200)):
client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["ollama"]},
)
assert any(src.endswith(".tmp") for src, dst in renamed_pairs), \
"Expected atomic write via .tmp rename"
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
low_cap_profile = {
**_BASE_PROFILE,
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 6.0,
"services": [], "role": "p", "card": "GTX 1080",
"always_on": False}],
"agent_url": "http://10.1.10.71:7701",
}
}
}
_setup_profile(tmp_path, low_cap_profile)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["cf-text"]},
)
assert r.status_code == 422
assert "compute_cap" in r.json()["detail"]
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
tiny_vram_profile = {
**_BASE_PROFILE,
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 512, "compute_cap": 8.6,
"services": [], "role": "p", "card": "old",
"always_on": False}],
"agent_url": "http://10.1.10.71:7701",
}
}
}
_setup_profile(tmp_path, tiny_vram_profile)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["cf-text"]},
)
assert r.status_code == 422
assert "VRAM" in r.json()["detail"]
def test_update_services_unknown_service_returns_422(client, tmp_path):
_setup_profile(tmp_path)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["not-a-real-service"]},
)
assert r.status_code == 422
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
"""YAML saved but coordinator reload fails — ok: true, reloaded: false."""
_setup_profile(tmp_path)
mock_reload = MagicMock()
mock_reload.status_code = 500
with patch("httpx.post", return_value=mock_reload):
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["ollama"]},
)
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["reloaded"] is False
# ── Ollama endpoints ───────────────────────────────────────────────────────────
_OLLAMA_PROFILE = {
"services": {},
"nodes": {
"heimdall": {
"gpus": [],
"agent_url": "http://10.1.10.71:7701",
}
}
}
def test_list_ollama_models_proxies_tags(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_tags = MagicMock()
mock_tags.raise_for_status = MagicMock()
mock_tags.json.return_value = {
"models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
}
with patch("httpx.get", return_value=mock_tags):
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
assert r.status_code == 200
data = r.json()
assert len(data["models"]) == 1
assert data["models"][0]["name"] == "nomic-embed-text"
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
import httpx as _httpx
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
with patch("httpx.get", side_effect=_httpx.ConnectError("refused")):
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
assert r.status_code == 200
data = r.json()
assert "error" in data
def test_pull_ollama_model_streams_sse(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_resp = MagicMock()
mock_resp.iter_lines.return_value = iter([
'{"status": "pulling manifest"}',
'{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
'{"status": "success"}',
])
with patch("httpx.stream") as mock_stream_fn:
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
json={"name": "nomic-embed-text"},
)
assert r.status_code == 200
body = r.text
assert 'data: {"status": "pulling manifest"}' in body
assert 'data: {"status": "success"}' in body
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_resp = MagicMock()
mock_resp.iter_lines.return_value = iter([
'{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}',
])
with patch("httpx.stream") as mock_stream_fn:
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
json={"name": "nomic-embed-text"},
)
assert r.status_code == 200
assert "permission denied" in r.text
def test_delete_ollama_model_proxies_delete(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_del = MagicMock()
mock_del.status_code = 200
mock_del.raise_for_status = MagicMock()
with patch("httpx.request", return_value=mock_del):
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text")
assert r.status_code == 200
assert r.json() == {"ok": True}
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_del = MagicMock()
mock_del.status_code = 404
with patch("httpx.request", return_value=mock_del):
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model")
assert r.status_code == 404

View file

@ -1,4 +1,4 @@
"""API integration tests for app/sft.py /api/sft/* endpoints.""" """API integration tests for app/sft.py -- /api/sft/* endpoints."""
import json import json
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -7,17 +7,17 @@ from pathlib import Path
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_sft_globals(tmp_path): def reset_sft_globals(tmp_path):
from app import sft as sft_module from app.data import corrections as corr_module
_prev_data = sft_module._SFT_DATA_DIR _prev_data = corr_module._DATA_DIR
_prev_cfg = sft_module._SFT_CONFIG_DIR _prev_cfg = corr_module._CONFIG_DIR
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR _prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
sft_module.set_sft_data_dir(tmp_path) corr_module.set_data_dir(tmp_path)
sft_module.set_sft_config_dir(tmp_path) corr_module.set_config_dir(tmp_path)
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results")) corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
yield yield
sft_module.set_sft_data_dir(_prev_data) corr_module.set_data_dir(_prev_data)
sft_module.set_sft_config_dir(_prev_cfg) corr_module.set_config_dir(_prev_cfg)
sft_module.set_default_bench_results_dir(_prev_default) corr_module.set_default_bench_results_dir(_prev_default)
@pytest.fixture @pytest.fixture
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
) )
# ── /api/sft/runs ────────────────────────────────────────────────────────── # -- /api/sft/runs -------------------------------------------------------------
def test_runs_returns_empty_when_no_config(client): def test_runs_returns_empty_when_no_config(client):
r = client.get("/api/sft/runs") r = client.get("/api/sft/runs")
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
def test_runs_marks_already_imported(client, tmp_path): def test_runs_marks_already_imported(client, tmp_path):
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")]) _write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
_write_config(tmp_path, tmp_path / "bench_results") _write_config(tmp_path, tmp_path / "bench_results")
from app import sft as sft_module from app.data import corrections as sft_module
candidates = sft_module._candidates_file() candidates = sft_module._candidates_file()
candidates.parent.mkdir(parents=True, exist_ok=True) candidates.parent.mkdir(parents=True, exist_ok=True)
candidates.write_text( candidates.write_text(
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
assert r.json()[0]["already_imported"] is True assert r.json()[0]["already_imported"] is True
# ── /api/sft/import ───────────────────────────────────────────────────────── # -- /api/sft/import -----------------------------------------------------------
def test_import_adds_records(client, tmp_path): def test_import_adds_records(client, tmp_path):
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")]) _write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
assert r.status_code == 404 assert r.status_code == 404
# ── /api/sft/queue ────────────────────────────────────────────────────────── # -- /api/sft/queue ------------------------------------------------------------
def _populate_candidates(tmp_path, records: list[dict]) -> None: def _populate_candidates(tmp_path, records: list[dict]) -> None:
from app import sft as sft_module from app.data import corrections as sft_module
path = sft_module._candidates_file() path = sft_module._candidates_file()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
path.write_text( path.write_text(
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20} assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
# ── /api/sft/submit ───────────────────────────────────────────────────────── # -- /api/sft/submit -----------------------------------------------------------
def test_submit_correct_sets_approved(client, tmp_path): def test_submit_correct_sets_approved(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
records = sft_module._read_candidates() records = sft_module._read_candidates()
assert records[0]["status"] == "approved" assert records[0]["status"] == "approved"
assert records[0]["corrected_response"] == "def add(a, b): return a + b" assert records[0]["corrected_response"] == "def add(a, b): return a + b"
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
"id": "a", "action": "correct", "id": "a", "action": "correct",
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import read_jsonl from app.utils import read_jsonl
approved = read_jsonl(sft_module._approved_file()) approved = read_jsonl(sft_module._approved_file())
assert len(approved) == 1 assert len(approved) == 1
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "discarded" assert sft_module._read_candidates()[0]["status"] == "discarded"
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"}) r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "model_rejected" assert sft_module._read_candidates()[0]["status"] == "model_rejected"
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
"failure_category": "style_violation", "failure_category": "style_violation",
}) })
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
records = sft_module._read_candidates() records = sft_module._read_candidates()
assert records[0]["failure_category"] == "style_violation" assert records[0]["failure_category"] == "style_violation"
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
records = sft_module._read_candidates() records = sft_module._read_candidates()
assert records[0]["failure_category"] is None assert records[0]["failure_category"] is None
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
assert r.status_code == 422 assert r.status_code == 422
# ── /api/sft/undo ──────────────────────────────────────────────────────────── # -- /api/sft/undo -------------------------------------------------------------
def test_undo_restores_discarded_to_needs_review(client, tmp_path): def test_undo_restores_discarded_to_needs_review(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
r = client.post("/api/sft/undo", json={"id": "a"}) r = client.post("/api/sft/undo", json={"id": "a"})
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "needs_review" assert sft_module._read_candidates()[0]["status"] == "needs_review"
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
client.post("/api/sft/undo", json={"id": "a"}) client.post("/api/sft/undo", json={"id": "a"})
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import read_jsonl from app.utils import read_jsonl
approved = read_jsonl(sft_module._approved_file()) approved = read_jsonl(sft_module._approved_file())
assert not any(r["id"] == "a" for r in approved) assert not any(r["id"] == "a" for r in approved)
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
assert r.status_code == 409 assert r.status_code == 409
# ── /api/sft/export ────────────────────────────────────────────────────────── # -- /api/sft/export -----------------------------------------------------------
def test_export_returns_approved_as_sft_jsonl(client, tmp_path): def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import write_jsonl from app.utils import write_jsonl
approved = { approved = {
**_make_record("a"), **_make_record("a"),
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
def test_export_excludes_non_approved(client, tmp_path): def test_export_excludes_non_approved(client, tmp_path):
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import write_jsonl from app.utils import write_jsonl
records = [ records = [
{**_make_record("a"), "status": "discarded", "corrected_response": None}, {**_make_record("a"), "status": "discarded", "corrected_response": None},
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
assert r.text.strip() == "" assert r.text.strip() == ""
# ── /api/sft/stats ─────────────────────────────────────────────────────────── # -- /api/sft/stats ------------------------------------------------------------
def test_stats_counts_by_status(client, tmp_path): def test_stats_counts_by_status(client, tmp_path):
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import write_jsonl from app.utils import write_jsonl
records = [ records = [
_make_record("a"), _make_record("a"),

187
tests/test_train.py Normal file
View file

@ -0,0 +1,187 @@
"""Tests for app/train/train.py -- /api/train/* endpoints."""
import json
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.train import train as train_module
train_module.set_db_path(tmp_path / "train_jobs.db")
train_module.set_models_dir(tmp_path / "models")
train_module._running_procs.clear()
yield
train_module._running_procs.clear()
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _parse_sse(content: bytes) -> list[dict]:
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_list_jobs_empty(client):
r = client.get("/api/train/jobs")
assert r.status_code == 200
assert r.json() == {"jobs": []}
def test_create_job_returns_queued_record(client):
r = client.post("/api/train/jobs",
json={"type": "classifier", "model_key": "deberta-small",
"config_json": {"epochs": 3}})
assert r.status_code == 200
data = r.json()
assert data["status"] == "queued"
assert data["type"] == "classifier"
assert data["model_key"] == "deberta-small"
assert "id" in data
def test_create_job_invalid_type_returns_400(client):
r = client.post("/api/train/jobs",
json={"type": "unknown-type", "model_key": "deberta-small"})
assert r.status_code == 400
def test_create_job_appears_in_list(client):
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
r = client.get("/api/train/jobs")
assert r.status_code == 200
assert len(r.json()["jobs"]) == 1
def test_get_job_returns_record(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.status_code == 200
assert r2.json()["id"] == job_id
def test_get_job_404_for_unknown(client):
r = client.get("/api/train/jobs/no-such-id")
assert r.status_code == 404
def test_cancel_queued_job(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
assert r2.status_code == 200
assert r2.json()["status"] == "cancelled"
r3 = client.get(f"/api/train/jobs/{job_id}")
assert r3.json()["status"] == "cancelled"
def test_cancel_completed_job_returns_409(client):
from app.train import train as train_module
train_module._init_db()
with train_module._db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
)
r = client.delete("/api/train/jobs/abc/cancel")
assert r.status_code == 409
def test_cancel_terminates_running_proc(client):
from app.train import train as train_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
train_module._running_procs[job_id] = mock_proc
with train_module._db() as conn:
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
assert r2.status_code == 200
mock_proc.terminate.assert_called_once()
def test_run_job_streams_sse(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
r2 = client.get(f"/api/train/jobs/{job_id}/run")
assert r2.status_code == 200
assert "text/event-stream" in r2.headers.get("content-type", "")
events = _parse_sse(r2.content)
assert any(e["type"] == "complete" for e in events)
def test_run_job_marks_completed_in_db(client):
from app.train import train as train_module
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
client.get(f"/api/train/jobs/{job_id}/run")
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.json()["status"] == "completed"
def test_run_job_marks_failed_on_nonzero_exit(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 1
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
client.get(f"/api/train/jobs/{job_id}/run")
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.json()["status"] == "failed"
def test_run_nonqueued_job_returns_409(client):
from app.train import train as train_module
train_module._init_db()
with train_module._db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
)
r = client.get("/api/train/jobs/xyz/run")
assert r.status_code == 409
def test_run_unknown_job_returns_404(client):
r = client.get("/api/train/jobs/no-such/run")
assert r.status_code == 404
def test_results_empty_when_no_models_dir(client):
r = client.get("/api/train/results")
assert r.status_code == 200
assert r.json() == {"results": []}
def test_results_returns_training_info(client, tmp_path):
from app.train import train as train_module
models_dir = tmp_path / "models" / "avocet-deberta-small"
models_dir.mkdir(parents=True)
train_module.set_models_dir(tmp_path / "models")
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
(models_dir / "training_info.json").write_text(json.dumps(info))
r = client.get("/api/train/results")
assert r.status_code == 200
data = r.json()
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])

View file

@ -0,0 +1,124 @@
import { mount, flushPromises } from '@vue/test-utils'
import { createRouter, createWebHashHistory } from 'vue-router'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import AppSidebar from './AppSidebar.vue'
// Minimal router so RouterLink renders without warnings
const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/', component: { template: '<div />' } },
{ path: '/fleet', component: { template: '<div />' } },
{ path: '/data/label', component: { template: '<div />' } },
{ path: '/data/fetch', component: { template: '<div />' } },
{ path: '/data/corrections', component: { template: '<div />' } },
{ path: '/data/imitate', component: { template: '<div />' } },
{ path: '/eval/benchmark', component: { template: '<div />' } },
{ path: '/eval/compare', component: { template: '<div />' } },
{ path: '/train/jobs', component: { template: '<div />' } },
{ path: '/train/results', component: { template: '<div />' } },
{ path: '/settings', component: { template: '<div />' } },
],
})
function makeFetch(signals: Record<string, boolean> = {}) {
return vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
labeled_since_last_eval: 0,
last_eval_timestamp: null,
last_eval_best_score: null,
active_jobs: [],
corrections_export_ready: 0,
signals,
}),
text: async () => '',
})
}
beforeEach(() => {
localStorage.clear()
vi.stubGlobal('fetch', makeFetch())
})
describe('AppSidebar structure', () => {
it('renders section headers for Data, Eval, Train', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const text = w.text()
expect(text).toContain('Data')
expect(text).toContain('Eval')
expect(text).toContain('Train')
})
it('renders all sub-links', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const anchors = w.findAll('a')
const hrefs = anchors.map(a => a.attributes('href') ?? '')
expect(hrefs.some(h => h.includes('/data/label'))).toBe(true)
expect(hrefs.some(h => h.includes('/data/fetch'))).toBe(true)
expect(hrefs.some(h => h.includes('/data/corrections'))).toBe(true)
expect(hrefs.some(h => h.includes('/data/imitate'))).toBe(true)
expect(hrefs.some(h => h.includes('/eval/benchmark'))).toBe(true)
expect(hrefs.some(h => h.includes('/eval/compare'))).toBe(true)
expect(hrefs.some(h => h.includes('/train/jobs'))).toBe(true)
expect(hrefs.some(h => h.includes('/train/results'))).toBe(true)
expect(hrefs.some(h => h.includes('/fleet'))).toBe(true)
expect(hrefs.some(h => h.includes('/settings'))).toBe(true)
})
it('does NOT render the old /benchmark or /models links', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const anchors = w.findAll('a')
const hrefs = anchors.map(a => a.attributes('href') ?? '')
// Old paths must not appear as direct links (they're only redirects)
expect(hrefs.every(h => !h.endsWith('/#/benchmark'))).toBe(true)
expect(hrefs.every(h => !h.endsWith('/#/models'))).toBe(true)
expect(hrefs.every(h => !h.endsWith('/#/stats'))).toBe(true)
})
it('shows no signal badges when all signals are false', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: false }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
expect(w.findAll('.signal-badge').length).toBe(0)
})
it('shows signal badge on Data section when data_to_eval is true', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: true, eval_to_train: false, train_to_fleet: false }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const badges = w.findAll('.signal-badge')
expect(badges.length).toBe(1)
// It should be inside the Data section header
const dataHeader = w.find('[data-section="data"]')
expect(dataHeader.find('.signal-badge').exists()).toBe(true)
})
it('shows signal badge on Eval section when eval_to_train is true', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: true, train_to_fleet: false }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const evalHeader = w.find('[data-section="eval"]')
expect(evalHeader.find('.signal-badge').exists()).toBe(true)
})
it('shows signal badge on Train section when train_to_fleet is true', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: true }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const trainHeader = w.find('[data-section="train"]')
expect(trainHeader.find('.signal-badge').exists()).toBe(true)
})
it('stow toggle still works', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const nav = w.find('nav')
expect(nav.classes()).not.toContain('stowed')
await w.find('.stow-btn').trigger('click')
expect(nav.classes()).toContain('stowed')
})
})

View file

@ -28,12 +28,70 @@
</button> </button>
</div> </div>
<!-- Nav items --> <!-- Nav -->
<ul class="nav-list" role="list"> <ul class="nav-list" role="list">
<li v-for="item in navItems" :key="item.path"> <!-- Top-level links -->
<li>
<RouterLink
to="/"
class="nav-item"
:title="stowed ? 'Dashboard' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">📊</span>
<span v-if="!stowed" class="nav-label">Dashboard</span>
</RouterLink>
</li>
<li>
<RouterLink
to="/fleet"
class="nav-item"
:title="stowed ? 'Fleet' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true"></span>
<span v-if="!stowed" class="nav-label">Fleet</span>
</RouterLink>
</li>
<li>
<RouterLink
to="/nodes"
class="nav-item"
:title="stowed ? 'Nodes' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">🖥</span>
<span v-if="!stowed" class="nav-label">Nodes</span>
</RouterLink>
</li>
<!-- Data section -->
<li>
<div class="section-header" data-section="data" aria-hidden="true">
<template v-if="!stowed">
<span class="section-label"> Data</span>
<span
v-if="signals.data_to_eval"
class="signal-badge"
title="Enough new labels to run eval"
aria-label="Eval recommended"
/>
</template>
<template v-else>
<span class="section-icon"></span>
<span
v-if="signals.data_to_eval"
class="signal-badge signal-badge-stowed"
title="Eval recommended"
aria-label="Eval recommended"
/>
</template>
</div>
</li>
<li v-for="item in dataItems" :key="item.path">
<RouterLink <RouterLink
:to="item.path" :to="item.path"
class="nav-item" class="nav-item nav-subitem"
:title="stowed ? item.label : ''" :title="stowed ? item.label : ''"
@click="isMobile && stow()" @click="isMobile && stow()"
> >
@ -41,10 +99,94 @@
<span v-if="!stowed" class="nav-label">{{ item.label }}</span> <span v-if="!stowed" class="nav-label">{{ item.label }}</span>
</RouterLink> </RouterLink>
</li> </li>
<!-- Eval section -->
<li>
<div class="section-header" data-section="eval" aria-hidden="true">
<template v-if="!stowed">
<span class="section-label"> Eval</span>
<span
v-if="signals.eval_to_train"
class="signal-badge"
title="Strong eval result — consider finetuning"
aria-label="Finetune recommended"
/>
</template>
<template v-else>
<span class="section-icon"></span>
<span
v-if="signals.eval_to_train"
class="signal-badge signal-badge-stowed"
title="Finetune recommended"
aria-label="Finetune recommended"
/>
</template>
</div>
</li>
<li v-for="item in evalItems" :key="item.path">
<RouterLink
:to="item.path"
class="nav-item nav-subitem"
:title="stowed ? item.label : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
</RouterLink>
</li>
<!-- Train section -->
<li>
<div class="section-header" data-section="train" aria-hidden="true">
<template v-if="!stowed">
<span class="section-label"> Train</span>
<span
v-if="signals.train_to_fleet"
class="signal-badge"
title="Trained model ready for fleet registration"
aria-label="Fleet registration recommended"
/>
</template>
<template v-else>
<span class="section-icon"></span>
<span
v-if="signals.train_to_fleet"
class="signal-badge signal-badge-stowed"
title="Fleet registration recommended"
aria-label="Fleet registration recommended"
/>
</template>
</div>
</li>
<li v-for="item in trainItems" :key="item.path">
<RouterLink
:to="item.path"
class="nav-item nav-subitem"
:title="stowed ? item.label : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
</RouterLink>
</li>
<!-- Divider + Settings -->
<li class="nav-divider" aria-hidden="true" />
<li>
<RouterLink
to="/settings"
class="nav-item"
:title="stowed ? 'Settings' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true"></span>
<span v-if="!stowed" class="nav-label">Settings</span>
</RouterLink>
</li>
</ul> </ul>
</nav> </nav>
<!-- Mobile hamburger button rendered outside the sidebar so it's visible when stowed --> <!-- Mobile hamburger button visible when sidebar is stowed on mobile -->
<button <button
v-if="isMobile && stowed" v-if="isMobile && stowed"
class="mobile-hamburger" class="mobile-hamburger"
@ -61,25 +203,66 @@ import { RouterLink } from 'vue-router'
const LS_KEY = 'cf-avocet-nav-stowed' const LS_KEY = 'cf-avocet-nav-stowed'
const navItems = [ interface NavItem {
{ path: '/', icon: '🃏', label: 'Label' }, path: string
{ path: '/fetch', icon: '📥', label: 'Fetch' }, icon: string
{ path: '/stats', icon: '📊', label: 'Stats' }, label: string
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' }, }
{ path: '/models', icon: '🤗', label: 'Models' },
{ path: '/imitate', icon: '🪞', label: 'Imitate' }, interface DashboardSignals {
{ path: '/corrections', icon: '✍️', label: 'Corrections' }, data_to_eval: boolean
{ path: '/settings', icon: '⚙️', label: 'Settings' }, eval_to_train: boolean
train_to_fleet: boolean
}
const dataItems: NavItem[] = [
{ path: '/data/label', icon: '🏷', label: 'Label' },
{ path: '/data/fetch', icon: '📬', label: 'Fetch' },
{ path: '/data/corrections', icon: '✏️', label: 'Corrections' },
{ path: '/data/imitate', icon: '🪞', label: 'Imitate' },
] ]
const stowed = ref(localStorage.getItem(LS_KEY) === 'true') const evalItems: NavItem[] = [
const winWidth = ref(window.innerWidth) { path: '/eval/benchmark', icon: '📊', label: 'Benchmark' },
const isMobile = computed(() => winWidth.value < 640) { path: '/eval/compare', icon: '🔍', label: 'Compare' },
]
const trainItems: NavItem[] = [
{ path: '/train/jobs', icon: '🧠', label: 'Jobs' },
{ path: '/train/results', icon: '📈', label: 'Results' },
]
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
const winWidth = ref(window.innerWidth)
const isMobile = computed(() => winWidth.value < 640)
const signals = ref<DashboardSignals>({
data_to_eval: false,
eval_to_train: false,
train_to_fleet: false,
})
async function loadSignals() {
try {
const res = await fetch('/api/dashboard')
if (res.ok) {
const data = await res.json() as { signals?: DashboardSignals }
if (data.signals) {
signals.value = {
data_to_eval: data.signals.data_to_eval ?? false,
eval_to_train: data.signals.eval_to_train ?? false,
train_to_fleet: data.signals.train_to_fleet ?? false,
}
}
}
} catch {
// Non-fatal: badges simply stay hidden if API is unreachable
}
}
function toggle() { function toggle() {
stowed.value = !stowed.value stowed.value = !stowed.value
localStorage.setItem(LS_KEY, String(stowed.value)) localStorage.setItem(LS_KEY, String(stowed.value))
// Update CSS variable on :root so .app-main margin-left syncs
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px') document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
} }
@ -93,13 +276,12 @@ function onResize() { winWidth.value = window.innerWidth }
onMounted(() => { onMounted(() => {
window.addEventListener('resize', onResize) window.addEventListener('resize', onResize)
// Apply persisted sidebar width to :root on mount
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px') document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
// On mobile, default to stowed
if (isMobile.value && !localStorage.getItem(LS_KEY)) { if (isMobile.value && !localStorage.getItem(LS_KEY)) {
stowed.value = true stowed.value = true
document.documentElement.style.setProperty('--sidebar-width', '56px') document.documentElement.style.setProperty('--sidebar-width', '56px')
} }
loadSignals()
}) })
onUnmounted(() => window.removeEventListener('resize', onResize)) onUnmounted(() => window.removeEventListener('resize', onResize))
@ -121,18 +303,15 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
overflow: hidden; overflow: hidden;
} }
.sidebar.stowed { .sidebar.stowed { width: 56px; }
width: 56px;
}
/* Mobile: slide in/out from left */
.sidebar.mobile { .sidebar.mobile {
box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15); box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15);
} }
.sidebar.mobile.stowed { .sidebar.mobile.stowed {
transform: translateX(-100%); transform: translateX(-100%);
width: 200px; /* keep width so slide-in looks right */ width: 200px;
transition: transform 250ms ease, width 250ms ease; transition: transform 250ms ease, width 250ms ease;
} }
@ -165,10 +344,7 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
white-space: nowrap; white-space: nowrap;
} }
.logo-icon { .logo-icon { font-size: 1.25rem; flex-shrink: 0; }
font-size: 1.25rem;
flex-shrink: 0;
}
.logo-name { .logo-name {
font-family: var(--font-display, var(--font-body, sans-serif)); font-family: var(--font-display, var(--font-body, sans-serif));
@ -193,16 +369,76 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
transition: background 0.15s; transition: background 0.15s;
} }
.stow-btn:hover { .stow-btn:hover { background: var(--color-border, #d0d7e8); }
background: var(--color-border, #d0d7e8);
}
.nav-list { .nav-list {
list-style: none; list-style: none;
padding: 0.5rem 0; padding: 0.5rem 0;
flex: 1; flex: 1;
overflow-y: auto;
overflow-x: hidden;
} }
/* ── Section headers ── */
.section-header {
display: flex;
align-items: center;
gap: 0.4rem;
padding: 0.55rem 0.75rem 0.25rem;
margin-top: 0.5rem;
pointer-events: none;
user-select: none;
}
.section-label {
font-size: 0.7rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.07em;
color: var(--color-text-muted, #4a5c7a);
white-space: nowrap;
flex: 1;
}
.section-icon {
font-size: 0.75rem;
color: var(--color-text-muted, #4a5c7a);
width: 24px;
text-align: center;
flex-shrink: 0;
}
/* ── Signal badges ── */
.signal-badge {
width: 8px;
height: 8px;
border-radius: 50%;
background: var(--color-warning, #d4891a);
flex-shrink: 0;
display: inline-block;
}
.signal-badge-stowed {
position: absolute;
top: 4px;
right: 4px;
}
/* Make the stowed section header container position:relative for the badge */
.sidebar.stowed .section-header {
position: relative;
justify-content: center;
padding: 0.55rem 0 0.25rem;
}
/* ── Nav divider ── */
.nav-divider {
height: 1px;
background: var(--color-border, #d0d7e8);
margin: 0.5rem 0.75rem;
}
/* ── Nav items ── */
.nav-item { .nav-item {
display: flex; display: flex;
align-items: center; align-items: center;
@ -238,6 +474,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
border-radius: 0 2px 2px 0; border-radius: 0 2px 2px 0;
} }
/* Sub-items are indented slightly in expanded state */
.nav-subitem { padding-left: 1.1rem; font-size: 0.875rem; }
.nav-icon { .nav-icon {
font-size: 1.1rem; font-size: 1.1rem;
flex-shrink: 0; flex-shrink: 0;
@ -245,12 +484,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
text-align: center; text-align: center;
} }
.nav-label { .nav-label { overflow: hidden; text-overflow: ellipsis; }
overflow: hidden;
text-overflow: ellipsis;
}
/* Mobile hamburger — visible when sidebar is stowed on mobile */ /* Mobile hamburger */
.mobile-hamburger { .mobile-hamburger {
position: fixed; position: fixed;
top: 0.75rem; top: 0.75rem;

View file

@ -0,0 +1,129 @@
<script setup lang="ts">
import { ref, computed } from 'vue'
import ServiceBadge from './ServiceBadge.vue'
import type { GpuEntry, ServiceInfo } from '../../types/nodes'
const props = defineProps<{
gpu: GpuEntry
nodeId: string
profileLoaded: boolean
servicesCatalog: Record<string, ServiceInfo>
}>()
const emit = defineEmits<{ updated: [] }>()
const saving = ref(false)
const saveError = ref('')
const vramPct = computed(() => {
if (!props.gpu.vram_total_mb) return 0
return Math.round((props.gpu.vram_used_mb / props.gpu.vram_total_mb) * 100)
})
function serviceState(svcName: string): 'running' | 'stopped' | 'assigned-only' | 'available' | 'incompatible' | 'unknown' {
const svc = props.servicesCatalog[svcName]
if (!svc) return 'unknown'
const cap = props.gpu.compute_cap ?? 0
if (cap < svc.min_compute_cap) return 'incompatible'
if (props.gpu.services_running.includes(svcName)) return 'running'
if (props.gpu.services_assigned.includes(svcName)) return 'assigned-only'
return 'available'
}
async function toggleService(svcName: string) {
if (!props.profileLoaded || saving.value) return
const current = [...props.gpu.services_assigned]
const removing = current.includes(svcName)
if (removing && !confirm(`Remove ${svcName} from GPU ${props.gpu.gpu_id}?`)) return
const next = removing ? current.filter(s => s !== svcName) : [...current, svcName]
saving.value = true
saveError.value = ''
try {
const r = await fetch(
`/api/nodes-mgmt/nodes/${props.nodeId}/gpu/${props.gpu.gpu_id}/services`,
{
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ services: next }),
},
)
if (!r.ok) {
const data = await r.json().catch(() => ({}))
throw new Error((data as { detail?: string }).detail ?? `HTTP ${r.status}`)
}
const data = await r.json() as { ok: boolean; reloaded: boolean; warnings: string[] }
if (data.warnings?.length) saveError.value = `Saved (warning: ${data.warnings.join(', ')})`
emit('updated')
} catch (e) {
saveError.value = e instanceof Error ? e.message : 'Failed to update services'
} finally {
saving.value = false
}
}
</script>
<template>
<div class="gpu-row">
<div class="gpu-info">
<span class="gpu-label">GPU {{ gpu.gpu_id }}: {{ gpu.card }}</span>
<span v-if="gpu.compute_cap != null" class="gpu-meta">sm{{ gpu.compute_cap }}</span>
<span v-if="gpu.temp_c != null" class="gpu-meta">{{ gpu.temp_c }}°C</span>
<span v-if="gpu.utilization_pct != null" class="gpu-meta">{{ gpu.utilization_pct }}%</span>
</div>
<div class="vram-wrap">
<div
class="vram-bar"
role="progressbar"
:aria-valuenow="gpu.vram_used_mb"
aria-valuemin="0"
:aria-valuemax="gpu.vram_total_mb"
:aria-label="`VRAM: ${gpu.vram_used_mb} of ${gpu.vram_total_mb} MB used`"
>
<div class="vram-fill" :style="{ width: `${vramPct}%` }" />
</div>
<span class="vram-text">{{ gpu.vram_used_mb }} / {{ gpu.vram_total_mb }} MB ({{ vramPct }}%)</span>
</div>
<div v-if="profileLoaded" class="services-row" aria-label="Service assignments">
<ServiceBadge
v-for="(_, svcName) in servicesCatalog"
:key="String(svcName)"
:service-name="String(svcName)"
:state="serviceState(String(svcName))"
:assigned="gpu.services_assigned.includes(String(svcName))"
:disabled="saving"
@toggle="toggleService(String(svcName))"
/>
</div>
<div v-if="saveError" class="save-msg" role="alert">{{ saveError }}</div>
</div>
</template>
<style scoped>
.gpu-row {
padding: 0.5rem 0.75rem;
border-radius: 4px;
background: var(--bg-secondary, #111);
display: flex;
flex-direction: column;
gap: 0.4rem;
}
.gpu-info { display: flex; gap: 0.75rem; align-items: center; flex-wrap: wrap; font-size: 0.875rem; }
.gpu-label { font-weight: 500; }
.gpu-meta { color: var(--text-secondary, #888); font-size: 0.8rem; }
.vram-wrap { display: flex; align-items: center; gap: 0.5rem; }
.vram-bar {
flex: 1;
height: 8px;
background: var(--bg-bar, #2a2a2a);
border-radius: 4px;
overflow: hidden;
}
.vram-fill { height: 100%; background: var(--color-primary, #4080ff); transition: width 0.3s; }
.vram-text { font-size: 0.75rem; color: var(--text-secondary, #888); white-space: nowrap; }
.services-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
.save-msg { color: var(--color-warning, #ed8936); font-size: 0.8rem; }
</style>

View file

@ -0,0 +1,132 @@
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
interface CatalogEntry {
path: string
vram_mb: number
description: string
multi_gpu: boolean
}
interface ServiceProfile {
catalog: Record<string, CatalogEntry>
min_compute_cap: number
max_mb: number
}
interface NodeProfile {
services: Record<string, ServiceProfile>
}
const props = defineProps<{
nodeId: string
}>()
const profile = ref<NodeProfile | null>(null)
const loading = ref(true)
const error = ref('')
let fetchAbort: AbortController | null = null
async function fetchProfile() {
fetchAbort?.abort()
fetchAbort = new AbortController()
loading.value = true
error.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
signal: fetchAbort.signal,
})
if (r.status === 404) { profile.value = null; return }
if (!r.ok) throw new Error(`HTTP ${r.status}`)
profile.value = await r.json() as NodeProfile
} catch (e) {
if (e instanceof Error && e.name === 'AbortError') return
error.value = e instanceof Error ? e.message : 'Failed to load profile'
} finally {
loading.value = false
}
}
onMounted(fetchProfile)
onUnmounted(() => { fetchAbort?.abort() })
</script>
<template>
<section class="hf-panel">
<h3 class="panel-title">Model Catalog</h3>
<p class="hf-hint">
To download a new HuggingFace model,
<a href="#/fleet" class="hf-link">go to Fleet</a>.
Models downloaded there are automatically registered in node catalogs.
</p>
<div aria-live="polite" aria-atomic="true" class="sr-announce">
<span v-if="loading">Loading catalog...</span>
</div>
<div v-if="error" class="panel-error" role="alert">{{ error }}</div>
<div v-else-if="!loading && !profile" class="panel-empty">No profile loaded for this node.</div>
<div v-else-if="!loading && profile" class="catalog-body">
<div
v-for="(svcInfo, svcName) in profile.services"
:key="String(svcName)"
class="svc-section"
>
<h4 class="svc-name">{{ svcName }}</h4>
<ul class="catalog-list" role="list">
<li
v-if="!Object.keys(svcInfo.catalog ?? {}).length"
class="catalog-empty"
>
No models in catalog.
</li>
<li
v-for="(entry, modelName) in (svcInfo.catalog ?? {})"
:key="String(modelName)"
class="catalog-item"
>
<span class="catalog-model">{{ modelName }}</span>
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
</li>
</ul>
</div>
</div>
</section>
</template>
<style scoped>
.hf-panel {
margin-top: 0.75rem;
padding: 0.75rem;
border: 1px solid var(--border, #333);
border-radius: 6px;
}
.panel-title { margin: 0 0 0.5rem; font-size: 0.9rem; }
.hf-hint { font-size: 0.8rem; color: var(--text-secondary, #888); margin: 0 0 0.75rem; }
.hf-link { color: var(--color-primary, #4080ff); }
.svc-section { margin-bottom: 0.75rem; }
.svc-name {
margin: 0 0 0.25rem;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
color: var(--text-secondary, #888);
}
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.2rem; }
.catalog-item {
display: flex;
align-items: center;
gap: 0.5rem;
padding: 0.25rem 0.5rem;
background: var(--bg-secondary, #111);
border-radius: 4px;
font-size: 0.8rem;
}
.catalog-model { font-family: monospace; flex: 1; }
.catalog-vram { color: var(--text-secondary, #888); white-space: nowrap; }
.catalog-desc { color: var(--text-secondary, #888); font-size: 0.75rem; flex: 2; }
.catalog-empty, .panel-empty { color: var(--text-secondary, #888); font-size: 0.875rem; }
.sr-announce { min-height: 1.2em; }
.panel-error { color: var(--color-error, #fc8181); font-size: 0.8rem; }
</style>

View file

@ -0,0 +1,90 @@
<script setup lang="ts">
import { ref } from 'vue'
import GpuRow from './GpuRow.vue'
import OllamaModelPanel from './OllamaModelPanel.vue'
import HfNodeModelPanel from './HfNodeModelPanel.vue'
import type { NodeSummary } from '../../types/nodes'
const props = defineProps<{ node: NodeSummary }>()
const emit = defineEmits<{ updated: [] }>()
const showOllama = ref(false)
const showHf = ref(false)
</script>
<template>
<section class="node-card" :class="{ offline: !node.online }">
<header class="node-card-header">
<div class="node-identity">
<span
class="status-dot"
:class="node.online ? 'online' : 'offline'"
:aria-label="node.online ? 'Online' : 'Offline'"
role="img"
/>
<h2 class="node-name">{{ node.node_id }}</h2>
<span class="node-agent">{{ node.agent_url }}</span>
</div>
<div v-if="node.profile_loaded" class="node-actions">
<button class="btn-secondary btn-sm" @click="showOllama = !showOllama">
{{ showOllama ? 'Hide Ollama' : 'Ollama' }}
</button>
<button class="btn-secondary btn-sm" @click="showHf = !showHf">
{{ showHf ? 'Hide Catalog' : 'Catalog' }}
</button>
</div>
</header>
<div v-if="!node.profile_loaded" class="no-profile" role="status">
No profile configured for this node. GPU stats are visible; service assignment is disabled.
</div>
<div class="gpu-list">
<GpuRow
v-for="gpu in node.gpus"
:key="gpu.gpu_id"
:gpu="gpu"
:node-id="node.node_id"
:profile-loaded="node.profile_loaded"
:services-catalog="node.services_catalog"
@updated="emit('updated')"
/>
</div>
<OllamaModelPanel v-if="showOllama" :node-id="node.node_id" />
<HfNodeModelPanel v-if="showHf" :node-id="node.node_id" />
</section>
</template>
<style scoped>
.node-card {
border: 1px solid var(--border, #333);
border-radius: 8px;
padding: 1rem;
background: var(--bg-card, #1a1a1a);
}
.node-card.offline { opacity: 0.65; }
.node-card-header {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: 0.5rem;
margin-bottom: 0.75rem;
}
.node-identity { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; }
.node-name { margin: 0; font-size: 1rem; font-weight: 600; }
.node-agent { color: var(--text-secondary, #888); font-size: 0.8rem; font-family: monospace; }
.status-dot { width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0; }
.status-dot.online { background: var(--color-success, #48bb78); }
.status-dot.offline { background: var(--color-warning, #ed8936); }
.node-actions { display: flex; gap: 0.5rem; flex-shrink: 0; }
.no-profile {
padding: 0.6rem 0.75rem;
background: var(--bg-notice, #1e1e1e);
border-radius: 4px;
color: var(--text-secondary, #888);
font-size: 0.875rem;
margin-bottom: 0.5rem;
}
.gpu-list { display: flex; flex-direction: column; gap: 0.5rem; }
</style>

View file

@ -0,0 +1,241 @@
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
const props = defineProps<{ nodeId: string }>()
interface OllamaModel {
name: string
size: number
modified_at: string
}
const models = ref<OllamaModel[]>([])
const loading = ref(true)
const loadError = ref('')
const pullName = ref('')
const pulling = ref(false)
const pullStatus = ref('')
const pullPct = ref(0)
const pullError = ref('')
// AbortController for the SSE pull stream
const abortCtrl = ref<AbortController | null>(null)
// AbortController for the one-shot fetchModels request
let fetchAbort: AbortController | null = null
async function fetchModels() {
fetchAbort?.abort()
fetchAbort = new AbortController()
loading.value = true
loadError.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`, {
signal: fetchAbort.signal,
})
const data = await r.json() as { models?: OllamaModel[]; error?: string }
if (data.error) { loadError.value = data.error; return }
models.value = data.models ?? []
} catch (e) {
if (e instanceof Error && e.name === 'AbortError') return
loadError.value = e instanceof Error ? e.message : 'Failed to load models'
} finally {
loading.value = false
}
}
async function doPull() {
const name = pullName.value.trim()
if (!name || pulling.value) return
pulling.value = true
pullStatus.value = 'Starting...'
pullError.value = ''
pullPct.value = 0
const ctrl = new AbortController()
abortCtrl.value = ctrl
try {
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ name }),
signal: ctrl.signal,
})
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
if (!resp.body) throw new Error('No response body')
const reader = resp.body.getReader()
const decoder = new TextDecoder()
let buf = ''
while (true) {
const { done, value } = await reader.read()
if (done) break
buf += decoder.decode(value, { stream: true })
const lines = buf.split('\n')
buf = lines.pop() ?? ''
for (const line of lines) {
if (!line.startsWith('data: ')) continue
try {
const evt = JSON.parse(line.slice(6)) as {
status?: string; error?: string; total?: number; completed?: number
}
if (evt.error) {
pullError.value = evt.error
break
}
if (evt.status) pullStatus.value = evt.status
if (evt.total && evt.completed) {
pullPct.value = Math.round((evt.completed / evt.total) * 100)
}
if (evt.status === 'success') {
pullStatus.value = 'Done!'
pullName.value = ''
break
}
} catch { /* skip malformed line */ }
}
}
// Refresh model list after the stream closes (success or benign end)
await fetchModels()
} catch (e) {
if (e instanceof Error && e.name === 'AbortError') return
pullError.value = e instanceof Error ? e.message : 'Pull failed'
} finally {
pulling.value = false
abortCtrl.value = null
}
}
async function deleteModel(name: string) {
if (!confirm(`Delete model "${name}" from node ${props.nodeId}?`)) return
try {
const r = await fetch(
`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/${encodeURIComponent(name)}`,
{ method: 'DELETE' },
)
if (!r.ok) throw new Error(`HTTP ${r.status}`)
await fetchModels()
} catch (e) {
loadError.value = e instanceof Error ? e.message : 'Delete failed'
}
}
function formatSize(bytes: number): string {
return (bytes / 1e9).toFixed(1) + ' GB'
}
onMounted(fetchModels)
onUnmounted(() => {
abortCtrl.value?.abort()
fetchAbort?.abort()
})
</script>
<template>
<section class="ollama-panel">
<h3 class="panel-title">Ollama Models</h3>
<form class="pull-form" @submit.prevent="doPull">
<input
v-model="pullName"
type="text"
placeholder="nomic-embed-text, llama3.2:3b, ..."
:disabled="pulling"
aria-label="Model name to pull from Ollama"
class="pull-input"
/>
<button type="submit" :disabled="pulling || !pullName.trim()" class="btn-primary btn-sm">
{{ pulling ? 'Pulling...' : 'Pull' }}
</button>
</form>
<div v-if="pulling || pullStatus" class="pull-progress" aria-live="polite">
<div
class="progress-bar"
role="progressbar"
:aria-valuenow="pullPct"
aria-valuemin="0"
aria-valuemax="100"
:aria-label="`Pull progress: ${pullStatus}`"
>
<div class="progress-fill" :style="{ width: `${pullPct}%` }" />
</div>
<span class="progress-label">{{ pullStatus }}{{ pullPct > 0 ? ` (${pullPct}%)` : '' }}</span>
</div>
<div v-if="pullError" class="pull-error" role="alert">
{{ pullError }}
<span v-if="pullError.includes('permission denied')">
Remove the partial file on the node and retry.
</span>
</div>
<div aria-live="polite" aria-atomic="true" class="sr-announce">
<span v-if="loading">Loading...</span>
</div>
<div v-if="loadError" class="panel-error" role="alert">{{ loadError }}</div>
<ul v-if="!loading && !loadError" class="model-list" role="list">
<li v-if="!models.length" class="model-empty">No Ollama models installed on this node.</li>
<li v-for="m in models" :key="m.name" class="model-item">
<span class="model-name">{{ m.name }}</span>
<span class="model-size">{{ formatSize(m.size) }}</span>
<button
class="btn-danger btn-xs"
@click="deleteModel(m.name)"
:aria-label="`Delete ${m.name}`"
>
Delete
</button>
</li>
</ul>
</section>
</template>
<style scoped>
.ollama-panel {
margin-top: 0.75rem;
padding: 0.75rem;
border: 1px solid var(--border, #333);
border-radius: 6px;
}
.panel-title { margin: 0 0 0.75rem; font-size: 0.9rem; }
.pull-form { display: flex; gap: 0.5rem; margin-bottom: 0.5rem; }
.pull-input {
flex: 1;
padding: 0.3rem 0.5rem;
background: var(--bg-input, #111);
border: 1px solid var(--border, #333);
border-radius: 4px;
color: inherit;
font-size: 0.875rem;
}
.pull-progress { margin-bottom: 0.5rem; }
.progress-bar {
height: 8px;
background: var(--bg-bar, #2a2a2a);
border-radius: 4px;
overflow: hidden;
margin-bottom: 0.25rem;
}
.progress-fill { height: 100%; background: var(--color-primary, #4080ff); transition: width 0.2s; }
.progress-label { font-size: 0.75rem; color: var(--text-secondary, #888); }
.pull-error, .panel-error { color: var(--color-error, #fc8181); font-size: 0.8rem; margin-bottom: 0.5rem; }
.sr-announce { min-height: 1.2em; }
.panel-loading { color: var(--text-secondary, #888); font-size: 0.875rem; }
.model-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.3rem; }
.model-item {
display: flex;
align-items: center;
gap: 0.5rem;
padding: 0.3rem 0.5rem;
background: var(--bg-secondary, #111);
border-radius: 4px;
font-size: 0.875rem;
}
.model-name { flex: 1; font-family: monospace; }
.model-size { color: var(--text-secondary, #888); font-size: 0.8rem; }
.model-empty { color: var(--text-secondary, #888); font-size: 0.875rem; padding: 0.25rem 0; }
</style>

View file

@ -0,0 +1,81 @@
<script setup lang="ts">
type ServiceState =
| 'running'
| 'stopped'
| 'assigned-only'
| 'available'
| 'incompatible'
| 'vram-tight'
| 'unknown'
const props = defineProps<{
serviceName: string
state: ServiceState
assigned: boolean
disabled?: boolean
}>()
const emit = defineEmits<{ toggle: [] }>()
const STATE_LABELS: Record<ServiceState, string> = {
running: 'Running',
stopped: 'Stopped',
'assigned-only': 'Assigned',
available: 'Available',
incompatible: 'Incompatible',
'vram-tight': 'VRAM tight',
unknown: 'Unknown',
}
const STATE_ICONS: Record<ServiceState, string> = {
running: '▶',
stopped: '⏹',
'assigned-only': '📌',
available: '○',
incompatible: '✕',
'vram-tight': '⚠',
unknown: '?',
}
function handleToggle() {
if (!props.disabled && props.state !== 'incompatible') emit('toggle')
}
</script>
<template>
<button
class="service-badge"
:class="[`state-${state}`, { assigned, 'is-disabled': disabled || state === 'incompatible' }]"
:aria-pressed="assigned"
:aria-label="`${serviceName}: ${STATE_LABELS[state] ?? state}${assigned ? ' (assigned)' : ''}`"
:disabled="disabled || state === 'incompatible'"
@click="handleToggle"
>
<span class="badge-icon" aria-hidden="true">{{ STATE_ICONS[state] ?? '?' }}</span>
<span class="badge-name">{{ serviceName }}</span>
<span class="badge-state">{{ STATE_LABELS[state] ?? state }}</span>
</button>
</template>
<style scoped>
.service-badge {
display: inline-flex;
align-items: center;
gap: 0.3rem;
padding: 0.2rem 0.5rem;
border-radius: 4px;
border: 1px solid var(--border, #333);
background: var(--bg-badge, #1e1e1e);
font-size: 0.75rem;
cursor: pointer;
transition: opacity 0.1s, border-color 0.1s;
}
.service-badge:hover:not(.is-disabled) { opacity: 0.8; }
.service-badge.is-disabled { cursor: not-allowed; opacity: 0.5; }
.service-badge.state-running { border-color: var(--color-success, #48bb78); }
.service-badge.state-stopped { border-color: var(--color-warning, #ed8936); }
.service-badge.state-assigned-only { border-color: var(--color-info, #4299e1); }
.service-badge.state-incompatible { border-color: var(--color-error, #fc8181); }
.service-badge.state-vram-tight { border-color: var(--color-warning, #ed8936); }
.badge-state { color: var(--text-secondary, #888); }
</style>

View file

@ -1,25 +1,51 @@
import { createRouter, createWebHashHistory } from 'vue-router' import { createRouter, createWebHashHistory } from 'vue-router'
import LabelView from '../views/LabelView.vue'
// Views are lazy-loaded to keep initial bundle small // Lazy-loaded views
const FetchView = () => import('../views/FetchView.vue') const DashboardView = () => import('../views/DashboardView.vue')
const StatsView = () => import('../views/StatsView.vue') const LabelView = () => import('../views/LabelView.vue')
const BenchmarkView = () => import('../views/BenchmarkView.vue') const FetchView = () => import('../views/FetchView.vue')
const SettingsView = () => import('../views/SettingsView.vue') const CorrectionsView = () => import('../views/CorrectionsView.vue')
const CorrectionsView = () => import('../views/CorrectionsView.vue') const ImitateView = () => import('../views/ImitateView.vue')
const ModelsView = () => import('../views/ModelsView.vue') const BenchmarkView = () => import('../views/BenchmarkView.vue')
const ImitateView = () => import('../views/ImitateView.vue') const CompareView = () => import('../views/CompareView.vue')
const TrainJobsView = () => import('../views/TrainJobsView.vue')
const TrainResultsView = () => import('../views/TrainResultsView.vue')
const ModelsView = () => import('../views/ModelsView.vue')
const SettingsView = () => import('../views/SettingsView.vue')
const NodeManagementView = () => import('../views/NodeManagementView.vue')
export const routes = [
// ── Top-level ────────────────────────────────────────────
{ path: '/', component: DashboardView, meta: { title: 'Dashboard' } },
{ path: '/fleet', component: ModelsView, meta: { title: 'Fleet' } },
{ path: '/nodes', component: NodeManagementView, meta: { title: 'Nodes' } },
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
// ── Data domain ──────────────────────────────────────────
{ path: '/data/label', component: LabelView, meta: { title: 'Label' } },
{ path: '/data/fetch', component: FetchView, meta: { title: 'Fetch' } },
{ path: '/data/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
{ path: '/data/imitate', component: ImitateView, meta: { title: 'Imitate' } },
// ── Eval domain ──────────────────────────────────────────
{ path: '/eval/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
{ path: '/eval/compare', component: CompareView, meta: { title: 'Compare' } },
// ── Train domain ─────────────────────────────────────────
{ path: '/train/jobs', component: TrainJobsView, meta: { title: 'Training Jobs' } },
{ path: '/train/results', component: TrainResultsView, meta: { title: 'Training Results' } },
// ── Backward-compat redirects ────────────────────────────
{ path: '/benchmark', redirect: '/eval/benchmark' },
{ path: '/models', redirect: '/fleet' },
{ path: '/stats', redirect: '/' },
{ path: '/label', redirect: '/data/label' },
{ path: '/fetch', redirect: '/data/fetch' },
{ path: '/corrections', redirect: '/data/corrections' },
{ path: '/imitate', redirect: '/data/imitate' },
]
export const router = createRouter({ export const router = createRouter({
history: createWebHashHistory(), history: createWebHashHistory(),
routes: [ routes,
{ path: '/', component: LabelView, meta: { title: 'Label' } },
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
{ path: '/imitate', component: ImitateView, meta: { title: 'Imitate' } },
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
],
}) })

View file

@ -0,0 +1,94 @@
import { describe, it, expect } from 'vitest'
import { createRouter, createWebHashHistory } from 'vue-router'
// Import the raw routes array so we can test structure without mounting App
import { routes } from './index'
describe('router routes', () => {
it('exports a routes array', () => {
expect(Array.isArray(routes)).toBe(true)
})
it('has / pointing to DashboardView', () => {
const root = routes.find(r => r.path === '/')
expect(root).toBeDefined()
// Component should be async (lazy) or have a name
expect(root?.component).toBeDefined()
})
it('has /fleet route', () => {
const r = routes.find(r => r.path === '/fleet')
expect(r).toBeDefined()
})
it('has /data/label route', () => {
const r = routes.find(r => r.path === '/data/label')
expect(r).toBeDefined()
})
it('has /data/fetch route', () => {
const r = routes.find(r => r.path === '/data/fetch')
expect(r).toBeDefined()
})
it('has /data/corrections route', () => {
const r = routes.find(r => r.path === '/data/corrections')
expect(r).toBeDefined()
})
it('has /data/imitate route', () => {
const r = routes.find(r => r.path === '/data/imitate')
expect(r).toBeDefined()
})
it('has /eval/benchmark route', () => {
const r = routes.find(r => r.path === '/eval/benchmark')
expect(r).toBeDefined()
})
it('has /eval/compare route', () => {
const r = routes.find(r => r.path === '/eval/compare')
expect(r).toBeDefined()
})
it('has /train/jobs route', () => {
const r = routes.find(r => r.path === '/train/jobs')
expect(r).toBeDefined()
})
it('has /train/results route', () => {
const r = routes.find(r => r.path === '/train/results')
expect(r).toBeDefined()
})
it('has /settings route', () => {
const r = routes.find(r => r.path === '/settings')
expect(r).toBeDefined()
})
it('has backward-compat redirect from /benchmark to /eval/benchmark', () => {
const r = routes.find(r => r.path === '/benchmark')
expect(r).toBeDefined()
expect((r as { redirect?: string }).redirect).toBe('/eval/benchmark')
})
it('has backward-compat redirect from /models to /fleet', () => {
const r = routes.find(r => r.path === '/models')
expect(r).toBeDefined()
expect((r as { redirect?: string }).redirect).toBe('/fleet')
})
it('has backward-compat redirect from /stats to /', () => {
const r = routes.find(r => r.path === '/stats')
expect(r).toBeDefined()
expect((r as { redirect?: string }).redirect).toBe('/')
})
it('can create a functional router instance', () => {
const router = createRouter({
history: createWebHashHistory(),
routes,
})
expect(router).toBeDefined()
})
})

27
web/src/types/nodes.ts Normal file
View file

@ -0,0 +1,27 @@
export interface GpuEntry {
gpu_id: number
card: string
vram_total_mb: number
vram_used_mb: number
vram_free_mb: number
temp_c: number | null
utilization_pct: number | null
compute_cap: number | null
services_assigned: string[]
services_running: string[]
}
export interface ServiceInfo {
min_compute_cap: number
max_mb: number
catalog_size: number
}
export interface NodeSummary {
node_id: string
online: boolean
agent_url: string
gpus: GpuEntry[]
profile_loaded: boolean
services_catalog: Record<string, ServiceInfo>
}

View file

@ -0,0 +1,82 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import BenchmarkView from './BenchmarkView.vue'
beforeEach(() => {
vi.stubGlobal('fetch', vi.fn().mockImplementation((url: string) => {
// LlmEvalTab calls /api/cforch/models and expects { models: CfOrchModel[] }
if (url.includes('/api/cforch/models')) {
return Promise.resolve({
ok: true,
json: async () => ({ models: [] }),
text: async () => '',
})
}
// Default: satisfies ClassifierTab (/api/benchmark/results, /api/benchmark/models,
// /api/finetune/status), StyleTab (/api/style/models, /api/style/results),
// and any other tab that tolerates empty arrays/objects.
return Promise.resolve({
ok: true,
json: async () => ({ models: {}, categories: {}, tasks: [], types: [], results: [] }),
text: async () => '',
})
}))
vi.stubGlobal('EventSource', class {
onmessage = null
onerror = null
close() {}
})
})
describe('BenchmarkView', () => {
it('renders page title "Benchmark"', async () => {
const w = mount(BenchmarkView)
await flushPromises()
expect(w.text()).toContain('Benchmark')
})
it('has mode buttons: Classifier, LLM Eval, Writing Style', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const text = w.text()
expect(text).toContain('Classifier')
expect(text).toContain('LLM Eval')
expect(text).toContain('Writing Style')
})
it('does NOT have a Compare mode button', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const buttons = w.findAll('.mode-btn')
const labels = buttons.map(b => b.text())
expect(labels.every(l => !l.includes('Compare'))).toBe(true)
})
it('shows Classifier tab by default', async () => {
const w = mount(BenchmarkView)
await flushPromises()
// ClassifierTab has a .classifier-tab root
expect(w.find('.classifier-tab').exists()).toBe(true)
})
it('switches to LlmEvalTab when LLM Eval clicked', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const llmBtn = w.findAll('.mode-btn').find(b => b.text().includes('LLM Eval'))!
await llmBtn.trigger('click')
await flushPromises()
expect(w.find('.llm-eval-tab').exists()).toBe(true)
expect(w.find('.classifier-tab').exists()).toBe(false)
expect(llmBtn.classes()).toContain('active')
})
it('switches to StyleTab when Writing Style clicked', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const styleBtn = w.findAll('.mode-btn').find(b => b.text().includes('Writing Style'))!
await styleBtn.trigger('click')
await flushPromises()
expect(w.find('.style-tab').exists()).toBe(true)
expect(w.find('.classifier-tab').exists()).toBe(false)
})
})

View file

@ -16,33 +16,33 @@
:class="{ active: benchMode === 'llm' }" :class="{ active: benchMode === 'llm' }"
@click="benchMode = 'llm'" @click="benchMode = 'llm'"
>🤖 LLM Eval</button> >🤖 LLM Eval</button>
<button
class="mode-btn"
:class="{ active: benchMode === 'compare' }"
@click="benchMode = 'compare'"
> Compare</button>
<button <button
class="mode-btn" class="mode-btn"
:class="{ active: benchMode === 'style' }" :class="{ active: benchMode === 'style' }"
@click="benchMode = 'style'" @click="benchMode = 'style'"
> Writing Style</button> > Writing Style</button>
<button
class="mode-btn"
:class="{ active: benchMode === 'plans' }"
@click="benchMode = 'plans'"
>📐 Planning</button>
</div> </div>
<ClassifierTab v-if="benchMode === 'classifier'" /> <ClassifierTab v-if="benchMode === 'classifier'" />
<LlmEvalTab v-if="benchMode === 'llm'" /> <LlmEvalTab v-if="benchMode === 'llm'" />
<CompareTab v-if="benchMode === 'compare'" /> <StyleTab v-if="benchMode === 'style'" />
<StyleTab v-if="benchMode === 'style'" /> <PlansBenchTab v-if="benchMode === 'plans'" />
</div> </div>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref } from 'vue' import { ref } from 'vue'
import ClassifierTab from './ClassifierTab.vue' import ClassifierTab from './ClassifierTab.vue'
import LlmEvalTab from './LlmEvalTab.vue' import LlmEvalTab from './LlmEvalTab.vue'
import CompareTab from './CompareTab.vue' import StyleTab from './StyleTab.vue'
import StyleTab from './StyleTab.vue' import PlansBenchTab from './PlansBenchTab.vue'
type BenchMode = 'classifier' | 'llm' | 'compare' | 'style' type BenchMode = 'classifier' | 'llm' | 'style' | 'plans'
const benchMode = ref<BenchMode>('classifier') const benchMode = ref<BenchMode>('classifier')
</script> </script>
@ -69,7 +69,7 @@ const benchMode = ref<BenchMode>('classifier')
margin: 0; margin: 0;
} }
/* ── Mode toggle (segmented control) ────────────────────── */ /* ── Mode toggle (segmented control) ── */
.mode-toggle { .mode-toggle {
display: inline-flex; display: inline-flex;
border: 1px solid var(--color-border, #d0d7e8); border: 1px solid var(--color-border, #d0d7e8);

View file

@ -0,0 +1,31 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import CompareView from './CompareView.vue'
beforeEach(() => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ tasks: [], types: [], models: [] }),
text: async () => '',
}))
vi.stubGlobal('EventSource', class {
onmessage = null
onerror = null
close() {}
})
})
describe('CompareView', () => {
it('renders page title "Compare"', async () => {
const w = mount(CompareView)
await flushPromises()
expect(w.find('h1.page-title').text()).toContain('Compare')
})
it('wraps CompareTab component', async () => {
const w = mount(CompareView)
await flushPromises()
// CompareTab renders a .compare-tab root div
expect(w.find('.compare-tab').exists()).toBe(true)
})
})

View file

@ -0,0 +1,36 @@
<template>
<div class="compare-view">
<header class="compare-header">
<h1 class="page-title">🔍 Compare</h1>
</header>
<CompareTab />
</div>
</template>
<script setup lang="ts">
import CompareTab from './CompareTab.vue'
</script>
<style scoped>
.compare-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 1.75rem;
}
.compare-header {
display: flex;
align-items: center;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
}
</style>

View file

@ -0,0 +1,119 @@
import { mount, flushPromises } from '@vue/test-utils'
import { createRouter, createWebHashHistory } from 'vue-router'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import DashboardView from './DashboardView.vue'
const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/', component: { template: '<div />' } },
{ path: '/eval/benchmark', component: { template: '<div />' } },
{ path: '/train/jobs', component: { template: '<div />' } },
{ path: '/fleet', component: { template: '<div />' } },
],
})
const baseDashboard = {
labeled_since_last_eval: 0,
last_eval_timestamp: null,
last_eval_best_score: null,
active_jobs: [],
corrections_export_ready: 0,
signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false },
}
function mockFetch(overrides: Partial<typeof baseDashboard> = {}) {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ ...baseDashboard, ...overrides }),
text: async () => '',
}))
}
beforeEach(() => mockFetch())
describe('DashboardView', () => {
it('renders page title', async () => {
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('Dashboard')
})
it('shows three stage cards: Data, Eval, Train', async () => {
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.stage-card[data-stage="data"]').exists()).toBe(true)
expect(w.find('.stage-card[data-stage="eval"]').exists()).toBe(true)
expect(w.find('.stage-card[data-stage="train"]').exists()).toBe(true)
})
it('shows labeled_since_last_eval count in Data card', async () => {
mockFetch({ labeled_since_last_eval: 42 })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.stage-card[data-stage="data"]').text()).toContain('42')
})
it('does NOT show Run Eval CTA when data_to_eval is false', async () => {
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const dataCard = w.find('.stage-card[data-stage="data"]')
expect(dataCard.find('.cta-btn').exists()).toBe(false)
})
it('shows Run Eval CTA when data_to_eval is true', async () => {
mockFetch({ signals: { data_to_eval: true, eval_to_train: false, train_to_fleet: false } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const dataCard = w.find('.stage-card[data-stage="data"]')
expect(dataCard.find('.cta-btn').exists()).toBe(true)
expect(dataCard.find('.cta-btn').text()).toContain('Run Eval')
})
it('shows Queue Finetune CTA when eval_to_train is true', async () => {
mockFetch({ signals: { data_to_eval: false, eval_to_train: true, train_to_fleet: false } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const evalCard = w.find('.stage-card[data-stage="eval"]')
expect(evalCard.find('.cta-btn').text()).toContain('Queue Finetune')
})
it('shows Register in Fleet CTA when train_to_fleet is true', async () => {
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: true } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const trainCard = w.find('.stage-card[data-stage="train"]')
expect(trainCard.find('.cta-btn').text()).toContain('Register in Fleet')
})
it('shows active job status pills in Train card', async () => {
mockFetch({ active_jobs: [{ id: 'j1', type: 'classifier', model_key: 'deberta-v3', status: 'running' }] })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const trainCard = w.find('.stage-card[data-stage="train"]')
expect(trainCard.find('.status-pill').exists()).toBe(true)
expect(trainCard.text()).toContain('deberta-v3')
})
it('shows last eval score in Eval card when present', async () => {
mockFetch({ last_eval_best_score: 0.821 })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const evalCard = w.find('.stage-card[data-stage="eval"]')
expect(evalCard.text()).toContain('82.1%')
})
it('shows error state when API call fails', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 503, text: async () => '' }))
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
})
it('shows refresh button', async () => {
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.refresh-btn').exists()).toBe(true)
})
})

View file

@ -0,0 +1,347 @@
<template>
<div class="dashboard-view">
<header class="dashboard-header">
<h1 class="page-title">📊 Dashboard</h1>
<button class="refresh-btn" :disabled="loading" @click="load" aria-label="Refresh dashboard">
🔄
</button>
</header>
<div v-if="loading && !data" class="loading-state">Loading</div>
<div v-if="error" class="error-notice" role="alert">
{{ error }}
<button class="btn-retry" @click="load">Retry</button>
</div>
<div v-if="data" class="flywheel-grid">
<!-- Data card -->
<div class="stage-card" data-stage="data">
<div class="card-header">
<span class="card-step"></span>
<h2 class="card-title">Data</h2>
</div>
<div class="card-body">
<p class="card-metric">
<strong class="metric-value">{{ data.labeled_since_last_eval.toLocaleString() }}</strong>
<span class="metric-label"> labeled since last eval</span>
</p>
</div>
<div v-if="data.signals.data_to_eval" class="card-cta">
<RouterLink to="/eval/benchmark" class="cta-btn">Run Eval</RouterLink>
</div>
</div>
<!-- Eval card -->
<div class="stage-card" data-stage="eval">
<div class="card-header">
<span class="card-step"></span>
<h2 class="card-title">Eval</h2>
</div>
<div class="card-body">
<p class="card-metric">
<span class="metric-label">Last run: </span>
<strong class="metric-value">{{ formattedEvalTime }}</strong>
</p>
<p v-if="data.last_eval_best_score != null" class="card-metric">
<span class="metric-label">Best score: </span>
<strong class="metric-value">{{ formatScore(data.last_eval_best_score) }}</strong>
</p>
</div>
<div v-if="data.signals.eval_to_train" class="card-cta">
<RouterLink to="/train/jobs" class="cta-btn">Queue Finetune</RouterLink>
</div>
</div>
<!-- Train card -->
<div class="stage-card" data-stage="train">
<div class="card-header">
<span class="card-step"></span>
<h2 class="card-title">Train</h2>
</div>
<div class="card-body">
<template v-if="data.active_jobs.length > 0">
<div
v-for="job in data.active_jobs"
:key="job.id"
class="job-row"
>
<span class="job-key">{{ job.model_key }}</span>
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
</div>
</template>
<p v-else class="card-metric metric-muted">No active jobs</p>
<p v-if="data.corrections_export_ready > 0" class="card-metric">
<strong class="metric-value">{{ data.corrections_export_ready }}</strong>
<span class="metric-label"> corrections ready</span>
</p>
</div>
<div v-if="data.signals.train_to_fleet" class="card-cta">
<RouterLink to="/fleet" class="cta-btn">Register in Fleet</RouterLink>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { RouterLink } from 'vue-router'
interface ActiveJob {
id: string
type: string
model_key: string
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
}
interface DashboardSignals {
data_to_eval: boolean
eval_to_train: boolean
train_to_fleet: boolean
}
interface DashboardData {
labeled_since_last_eval: number
last_eval_timestamp: string | null
last_eval_best_score: number | null
active_jobs: ActiveJob[]
corrections_export_ready: number
signals: DashboardSignals
}
const data = ref<DashboardData | null>(null)
const loading = ref(false)
const error = ref<string | null>(null)
const formattedEvalTime = computed(() => {
if (!data.value?.last_eval_timestamp) return 'Never'
const date = new Date(data.value.last_eval_timestamp)
if (isNaN(date.getTime())) return 'Unknown'
const now = Date.now()
const diff = now - date.getTime()
const mins = Math.floor(diff / 60000)
if (mins < 1) return 'just now'
if (mins < 60) return `${mins}m ago`
const hrs = Math.floor(mins / 60)
if (hrs < 24) return `${hrs}h ago`
const days = Math.floor(hrs / 24)
return `${days}d ago`
})
function formatScore(score: number): string {
return `${(score * 100).toFixed(1)}%`
}
async function load() {
loading.value = true
error.value = null
try {
const res = await fetch('/api/dashboard')
if (!res.ok) {
error.value = `Could not load dashboard (HTTP ${res.status}).`
return
}
data.value = await res.json() as DashboardData
} catch {
error.value = 'Network error. Is the Avocet API running?'
} finally {
loading.value = false
}
}
onMounted(() => load())
</script>
<style scoped>
.dashboard-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 1.75rem;
}
.dashboard-header {
display: flex;
align-items: center;
gap: 0.75rem;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
flex: 1;
}
.refresh-btn {
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
cursor: pointer;
font-size: 1rem;
padding: 0.3rem 0.5rem;
transition: background 0.15s;
}
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
/* ── Flywheel grid ── */
.flywheel-grid {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 1rem;
}
@media (max-width: 680px) {
.flywheel-grid {
grid-template-columns: 1fr;
}
}
/* ── Stage cards ── */
.stage-card {
background: var(--color-surface-raised, #f5f7fc);
border: 1px solid var(--color-border, #d0d7e8);
border-radius: var(--radius-lg, 1rem);
padding: 1rem;
display: flex;
flex-direction: column;
gap: 0.75rem;
box-shadow: var(--shadow-sm);
}
.card-header {
display: flex;
align-items: center;
gap: 0.5rem;
border-bottom: 1px solid var(--color-border, #d0d7e8);
padding-bottom: 0.6rem;
}
.card-step {
font-size: 1.1rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
flex-shrink: 0;
}
.card-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1rem;
font-weight: 600;
color: var(--color-text, #1a2338);
margin: 0;
}
.card-body {
display: flex;
flex-direction: column;
gap: 0.4rem;
flex: 1;
}
.card-metric {
margin: 0;
font-size: 0.875rem;
color: var(--color-text, #1a2338);
}
.metric-value {
font-size: 1.05rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
}
.metric-label {
color: var(--color-text-muted, #4a5c7a);
}
.metric-muted { color: var(--color-text-muted, #4a5c7a); }
.card-cta { margin-top: auto; }
.cta-btn {
display: block;
width: 100%;
text-align: center;
padding: 0.5rem;
background: var(--app-primary, #2A6080);
color: #fff;
border-radius: 0.375rem;
text-decoration: none;
font-size: 0.875rem;
font-weight: 600;
transition: background 0.15s;
}
.cta-btn:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 85%, black); }
/* ── Job pills ── */
.job-row {
display: flex;
align-items: center;
gap: 0.5rem;
font-size: 0.875rem;
}
.job-key {
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
color: var(--color-text, #1a2338);
}
.status-pill {
font-size: 0.75rem;
padding: 0.15rem 0.45rem;
border-radius: 100px;
font-weight: 600;
flex-shrink: 0;
background: var(--color-surface-raised, #e4ebf5);
color: var(--color-text-muted, #4a5c7a);
}
.status-pill.status-running { background: #d4f4e0; color: #1a7a3a; }
.status-pill.status-queued { background: #fef3cd; color: #856404; }
.status-pill.status-failed { background: #fde8e8; color: #842029; }
.status-pill.status-completed { background: #e0f0ff; color: #0c5481; }
/* ── State indicators ── */
.loading-state {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
}
.error-notice {
background: #fde8e8;
color: #842029;
border: 1px solid #f5c2c7;
border-radius: 0.5rem;
padding: 0.75rem 1rem;
font-size: 0.875rem;
display: flex;
align-items: center;
gap: 0.75rem;
}
.btn-retry {
background: transparent;
border: 1px solid currentColor;
border-radius: 0.25rem;
color: inherit;
cursor: pointer;
font-size: 0.75rem;
padding: 0.2rem 0.5rem;
margin-left: auto;
}
</style>

View file

@ -6,6 +6,8 @@
<summary class="picker-summary"> <summary class="picker-summary">
<span class="picker-title">📋 Task Selection</span> <span class="picker-title">📋 Task Selection</span>
<span class="picker-badge">{{ llmTaskBadge }}</span> <span class="picker-badge">{{ llmTaskBadge }}</span>
<button class="picker-bulk-btn" @click.stop.prevent="selectAllTasks()">All</button>
<button class="picker-bulk-btn" @click.stop.prevent="clearAllTasks()">None</button>
</summary> </summary>
<div class="picker-body"> <div class="picker-body">
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks</div> <div v-if="llmTasksLoading" class="picker-loading">Loading tasks</div>
@ -44,6 +46,8 @@
<summary class="picker-summary"> <summary class="picker-summary">
<span class="picker-title">🎯 Model Selection</span> <span class="picker-title">🎯 Model Selection</span>
<span class="picker-badge">{{ llmModelBadge }}</span> <span class="picker-badge">{{ llmModelBadge }}</span>
<button class="picker-bulk-btn" @click.stop.prevent="selectAllModels()">All</button>
<button class="picker-bulk-btn" @click.stop.prevent="clearAllModels()">None</button>
</summary> </summary>
<div class="picker-body"> <div class="picker-body">
<div v-if="llmModelsLoading" class="picker-loading">Loading models</div> <div v-if="llmModelsLoading" class="picker-loading">Loading models</div>
@ -78,6 +82,33 @@
</div> </div>
</details> </details>
<!-- Node Selection -->
<div class="node-picker" v-if="llmNodes.length > 0">
<span class="node-picker-label">Nodes:</span>
<label
v-for="node in llmNodes"
:key="node.node_id"
class="node-chip"
:class="{ 'node-chip--off': !enabledNodes.has(node.node_id), 'node-chip--offline': !node.online }"
:title="node.online ? `${node.node_id} — ${node.gpus.length} GPU(s)` : `${node.node_id} — offline`"
>
<input
type="checkbox"
class="node-chip-check"
:checked="enabledNodes.has(node.node_id)"
:disabled="!node.online || llmRunning"
@change="toggleNode(node.node_id, ($event.target as HTMLInputElement).checked)"
/>
{{ node.node_id }}
<span class="node-chip-status" v-if="!node.online">offline</span>
</label>
<span class="node-picker-hint">
{{ enabledNodeIds.length === llmNodes.filter(n => n.online).length
? 'auto-routing (all nodes)'
: `restricted to: ${enabledNodeIds.join(', ')}` }}
</span>
</div>
<!-- Run Controls --> <!-- Run Controls -->
<div class="run-controls"> <div class="run-controls">
<button <button
@ -88,6 +119,24 @@
{{ llmRunning ? '⏳ Running…' : '▶ Run LLM Eval' }} {{ llmRunning ? '⏳ Running…' : '▶ Run LLM Eval' }}
</button> </button>
<button v-if="llmRunning" class="btn-cancel" @click="cancelLlmBenchmark"> Cancel</button> <button v-if="llmRunning" class="btn-cancel" @click="cancelLlmBenchmark"> Cancel</button>
<input
v-model="llmJudgeUrl"
class="judge-url-input"
placeholder="Judge URL — leave empty to skip LLM judge scoring"
:disabled="llmRunning"
title="Optional: URL of a running cf-text service (e.g. http://10.1.10.158:8008). When set, each LLM response gets a secondary score from the judge model — adds a 'judge' column to results. Empty = primary quality scoring only."
/>
<label class="workers-label" title="Run this many models concurrently (requires multiple GPUs)">
<span class="workers-prefix">workers</span>
<input
v-model.number="llmWorkers"
type="number"
min="1"
max="8"
class="workers-input"
:disabled="llmRunning"
/>
</label>
<span v-if="selectedLlmTasks.size === 0 || selectedLlmModels.size === 0" class="run-hint"> <span v-if="selectedLlmTasks.size === 0 || selectedLlmModels.size === 0" class="run-hint">
Select at least one task and one model to run. Select at least one task and one model to run.
</span> </span>
@ -119,6 +168,7 @@
<tr> <tr>
<th class="hm-label-col">Model</th> <th class="hm-label-col">Model</th>
<th class="hm-model-col">overall</th> <th class="hm-model-col">overall</th>
<th v-if="llmHasJudge" class="hm-model-col hm-judge-col">judge</th>
<th v-for="col in llmTaskTypeCols" :key="col" class="hm-model-col">{{ col }}</th> <th v-for="col in llmTaskTypeCols" :key="col" class="hm-model-col">{{ col }}</th>
<th class="hm-model-col">tok/s</th> <th class="hm-model-col">tok/s</th>
</tr> </tr>
@ -130,6 +180,12 @@
class="hm-value-cell" class="hm-value-cell"
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }" :class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
>{{ pct(row.avg_quality_score) }}</td> >{{ pct(row.avg_quality_score) }}</td>
<td
v-if="llmHasJudge"
class="hm-value-cell hm-judge-cell"
:class="{ 'bt-best': llmBestByCol['judge'] === row.model_id }"
title="LLM-as-judge secondary score"
>{{ row.avg_judge_score != null ? pct(row.avg_judge_score) : '—' }}</td>
<td <td
v-for="col in llmTaskTypeCols" v-for="col in llmTaskTypeCols"
:key="col" :key="col"
@ -168,6 +224,12 @@ interface CfOrchModel {
vram_estimate_mb?: number vram_estimate_mb?: number
} }
interface CfOrchNode {
node_id: string
online: boolean
gpus: { gpu_id: number; name: string; vram_total_mb: number; vram_free_mb: number }[]
}
interface LlmModelResult { interface LlmModelResult {
model_name: string model_name: string
model_id: string model_id: string
@ -175,9 +237,11 @@ interface LlmModelResult {
avg_tokens_per_sec: number avg_tokens_per_sec: number
avg_completion_ms: number avg_completion_ms: number
avg_quality_score: number avg_quality_score: number
avg_judge_score: number | null
finetune_candidates: number finetune_candidates: number
error_count: number error_count: number
quality_by_task_type: Record<string, number> quality_by_task_type: Record<string, number>
judge_score_by_task_type?: Record<string, number>
} }
// State // State
@ -195,6 +259,10 @@ const llmError = ref('')
const llmResults = ref<LlmModelResult[]>([]) const llmResults = ref<LlmModelResult[]>([])
const llmEventSource = ref<EventSource | null>(null) const llmEventSource = ref<EventSource | null>(null)
const llmLogEl = ref<HTMLElement | null>(null) const llmLogEl = ref<HTMLElement | null>(null)
const llmJudgeUrl = ref('')
const llmWorkers = ref(1)
const llmNodes = ref<CfOrchNode[]>([])
const enabledNodes = ref<Set<string>>(new Set())
// Computed // Computed
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => { const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
@ -239,6 +307,14 @@ const llmTaskTypeCols = computed(() => {
return [...types].sort() return [...types].sort()
}) })
const llmHasJudge = computed(() =>
llmResults.value.some(r => r.avg_judge_score != null)
)
const enabledNodeIds = computed(() =>
llmNodes.value.filter(n => n.online && enabledNodes.value.has(n.node_id)).map(n => n.node_id)
)
const llmBestByCol = computed((): Record<string, string> => { const llmBestByCol = computed((): Record<string, string> => {
const best: Record<string, string> = {} const best: Record<string, string> = {}
if (llmResults.value.length === 0) return best if (llmResults.value.length === 0) return best
@ -249,6 +325,16 @@ const llmBestByCol = computed((): Record<string, string> => {
} }
best['overall'] = bestId best['overall'] = bestId
if (llmHasJudge.value) {
bestId = ''; bestVal = -Infinity
for (const r of llmResults.value) {
if (r.avg_judge_score != null && r.avg_judge_score > bestVal) {
bestVal = r.avg_judge_score; bestId = r.model_id
}
}
best['judge'] = bestId
}
for (const col of llmTaskTypeCols.value) { for (const col of llmTaskTypeCols.value) {
bestId = ''; bestVal = -Infinity bestId = ''; bestVal = -Infinity
for (const r of llmResults.value) { for (const r of llmResults.value) {
@ -306,6 +392,15 @@ function toggleService(models: CfOrchModel[], checked: boolean) {
} }
selectedLlmModels.value = next selectedLlmModels.value = next
} }
function selectAllTasks() { selectedLlmTasks.value = new Set(llmTasks.value.map(t => t.id)) }
function clearAllTasks() { selectedLlmTasks.value = new Set() }
function selectAllModels() { selectedLlmModels.value = new Set(llmModels.value.map(m => m.id)) }
function clearAllModels() { selectedLlmModels.value = new Set() }
function toggleNode(id: string, checked: boolean) {
const next = new Set(enabledNodes.value)
checked ? next.add(id) : next.delete(id)
enabledNodes.value = next
}
// Data loaders // Data loaders
async function loadLlmTasks() { async function loadLlmTasks() {
@ -335,6 +430,21 @@ async function loadLlmResults() {
} }
} }
async function loadLlmConfig() {
const { data } = await useApiFetch<{ judge_url?: string }>('/api/cforch/config')
if (data?.judge_url && !llmJudgeUrl.value) {
llmJudgeUrl.value = data.judge_url
}
}
async function loadLlmNodes() {
const { data } = await useApiFetch<{ nodes: CfOrchNode[] }>('/api/cforch/nodes')
if (data?.nodes) {
llmNodes.value = data.nodes
enabledNodes.value = new Set(data.nodes.filter(n => n.online).map(n => n.node_id))
}
}
// Run / cancel // Run / cancel
function startLlmBenchmark() { function startLlmBenchmark() {
llmRunning.value = true llmRunning.value = true
@ -344,6 +454,15 @@ function startLlmBenchmark() {
const params = new URLSearchParams() const params = new URLSearchParams()
const taskIds = [...selectedLlmTasks.value].join(',') const taskIds = [...selectedLlmTasks.value].join(',')
if (taskIds) params.set('task_ids', taskIds) if (taskIds) params.set('task_ids', taskIds)
const modelIds = [...selectedLlmModels.value].join(',')
if (modelIds) params.set('model_ids', modelIds)
if (llmJudgeUrl.value.trim()) params.set('judge_url', llmJudgeUrl.value.trim())
if (llmWorkers.value > 1) params.set('workers', String(llmWorkers.value))
const onlineNodeIds = llmNodes.value.filter(n => n.online).map(n => n.node_id)
const isRestricted = enabledNodeIds.value.length < onlineNodeIds.length
if (isRestricted && enabledNodeIds.value.length > 0) {
params.set('node_ids', enabledNodeIds.value.join(','))
}
const es = new EventSource(`/api/cforch/run?${params}`) const es = new EventSource(`/api/cforch/run?${params}`)
llmEventSource.value = es llmEventSource.value = es
@ -387,6 +506,8 @@ onMounted(() => {
loadLlmTasks() loadLlmTasks()
loadLlmModels() loadLlmModels()
loadLlmResults() loadLlmResults()
loadLlmConfig()
loadLlmNodes()
}) })
</script> </script>
@ -451,6 +572,43 @@ onMounted(() => {
color: var(--color-text-secondary, #6b7a99); color: var(--color-text-secondary, #6b7a99);
} }
.judge-url-input {
flex: 1;
min-width: 14rem;
max-width: 24rem;
padding: 0.35rem 0.6rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2338);
font-size: 0.8rem;
font-family: var(--font-mono, monospace);
}
.judge-url-input:disabled { opacity: 0.5; }
.judge-url-input::placeholder { color: var(--color-text-secondary, #6b7a99); }
.workers-label {
display: flex;
align-items: center;
gap: 0.3rem;
font-size: 0.8rem;
color: var(--color-text-secondary, #6b7a99);
white-space: nowrap;
}
.workers-prefix { font-family: var(--font-mono, monospace); }
.workers-input {
width: 3.2rem;
padding: 0.35rem 0.4rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2338);
font-size: 0.8rem;
font-family: var(--font-mono, monospace);
text-align: center;
}
.workers-input:disabled { opacity: 0.5; }
/* ── Run log ────────────────────────────────────────────── */ /* ── Run log ────────────────────────────────────────────── */
.run-log { .run-log {
border: 1px solid var(--color-border, #d0d7e8); border: 1px solid var(--color-border, #d0d7e8);
@ -592,6 +750,15 @@ onMounted(() => {
white-space: nowrap; white-space: nowrap;
} }
.hm-judge-col {
background: color-mix(in srgb, var(--color-surface-raised, #e4ebf5) 80%, #c6d5f5);
}
.hm-judge-cell {
background: color-mix(in srgb, var(--color-surface, #fff) 85%, #c6d5f5);
font-style: italic;
opacity: 0.9;
}
/* ── Model Picker ───────────────────────────────────────── */ /* ── Model Picker ───────────────────────────────────────── */
.model-picker { .model-picker {
border: 1px solid var(--color-border, #d0d7e8); border: 1px solid var(--color-border, #d0d7e8);
@ -630,6 +797,24 @@ details[open] .picker-summary::before { content: '▼ '; }
margin-left: auto; margin-left: auto;
} }
.picker-bulk-btn {
padding: 0.1rem 0.45rem;
font-size: 0.7rem;
font-family: var(--font-mono, monospace);
background: var(--color-surface, #fff);
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.25rem;
color: var(--color-text-secondary, #6b7a99);
cursor: pointer;
transition: background 0.12s, color 0.12s;
flex-shrink: 0;
}
.picker-bulk-btn:hover {
background: var(--app-primary, #2A6080);
color: #fff;
border-color: var(--app-primary, #2A6080);
}
.picker-body { .picker-body {
padding: 0.75rem; padding: 0.75rem;
border-top: 1px solid var(--color-border, #d0d7e8); border-top: 1px solid var(--color-border, #d0d7e8);
@ -712,4 +897,61 @@ details[open] .picker-summary::before { content: '▼ '; }
.picker-model-list { padding-left: 0; } .picker-model-list { padding-left: 0; }
.picker-model-name { max-width: 14ch; } .picker-model-name { max-width: 14ch; }
} }
/* ── Node picker ────────────────────────────────────── */
.node-picker {
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
padding: 0.5rem 0.75rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.5rem;
background: var(--color-surface-raised, #e4ebf5);
}
.node-picker-label {
font-size: 0.78rem;
font-weight: 600;
color: var(--color-text-secondary, #6b7a99);
text-transform: uppercase;
letter-spacing: 0.04em;
white-space: nowrap;
}
.node-chip {
display: inline-flex;
align-items: center;
gap: 0.3rem;
padding: 0.2rem 0.55rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 1rem;
background: var(--color-surface, #fff);
font-size: 0.78rem;
font-family: var(--font-mono, monospace);
color: var(--color-text, #1a2338);
cursor: pointer;
transition: background 0.12s, opacity 0.12s;
}
.node-chip--off {
opacity: 0.45;
background: transparent;
}
.node-chip--offline {
opacity: 0.35;
cursor: not-allowed;
font-style: italic;
}
.node-chip-check { accent-color: var(--app-primary, #2A6080); }
.node-chip-status {
font-size: 0.66rem;
color: var(--color-text-secondary, #6b7a99);
}
.node-picker-hint {
font-size: 0.72rem;
color: var(--color-text-secondary, #6b7a99);
font-family: var(--font-mono, monospace);
margin-left: auto;
}
</style> </style>

View file

@ -51,8 +51,31 @@
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter"> <span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
{{ lookupResult.adapter_recommendation }} {{ lookupResult.adapter_recommendation }}
</span> </span>
<span v-if="lookupResult.size != null" class="preview-size"> <span v-if="selectedQuantSize > 0" class="preview-size">
{{ humanBytes(lookupResult.size) }} {{ humanBytes(selectedQuantSize) }}
</span>
</div>
<!-- GGUF quantization picker only shown for GGUF repos -->
<div v-if="lookupResult.gguf_files?.length" class="quant-picker">
<label class="quant-label" for="quant-select">Quantization</label>
<select
id="quant-select"
v-model="selectedQuant"
class="quant-select"
aria-label="Select quantization variant"
>
<option :value="null" disabled>Select quantization</option>
<option
v-for="f in lookupResult.gguf_files"
:key="f.filename"
:value="f.quant_name ?? f.filename"
>
{{ f.quant_name ?? f.filename }} {{ humanBytes(f.size) }}
</option>
</select>
<span class="quant-hint">
Q5_K_M or Q6_K recommended for 8 GB GPUs. Q8_0 for max quality.
</span> </span>
</div> </div>
@ -67,7 +90,7 @@
<button <button
class="btn-primary btn-add-queue" class="btn-primary btn-add-queue"
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue" :disabled="!canAddToQueue"
@click="addToQueue" @click="addToQueue"
> >
{{ addingToQueue ? 'Adding…' : 'Add to queue' }} {{ addingToQueue ? 'Adding…' : 'Add to queue' }}
@ -99,9 +122,39 @@
<span v-if="model.role" class="chip chip-role">{{ model.role }}</span> <span v-if="model.role" class="chip chip-role">{{ model.role }}</span>
<span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span> <span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span>
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span> <span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
<span v-if="model.quant_pattern" class="chip chip-quant">{{ model.quant_pattern }}</span>
</div>
<!-- Allow manual service/role assignment for unrecognized pipeline tags -->
<div v-if="!model.service" class="classify-row queue-classify">
<select
class="classify-select"
:value="classifyDraft[model.id]?.service ?? ''"
@change="onServiceChange(model.id, ($event.target as HTMLSelectElement).value)"
aria-label="Assign service"
>
<option value="" disabled>Service</option>
<option v-for="svc in CLASSIFIABLE_SERVICES" :key="svc.value" :value="svc.value">{{ svc.label }}</option>
</select>
<select
class="classify-select"
:value="classifyDraft[model.id]?.role ?? ''"
:disabled="!classifyDraft[model.id]?.service"
@change="(e) => setClassifyRole(model.id, (e.target as HTMLSelectElement).value)"
aria-label="Assign role"
>
<option value="" disabled>Role</option>
<option
v-for="role in rolesForService(classifyDraft[model.id]?.service ?? '')"
:key="role"
:value="role"
>{{ role }}</option>
</select>
</div> </div>
<div class="model-card-actions"> <div class="model-card-actions">
<button class="btn-primary btn-sm" @click="approveModel(model.id)"> <button
class="btn-primary btn-sm"
@click="approveModel(model.id, classifyDraft[model.id])"
>
Approve download Approve download
</button> </button>
</div> </div>
@ -252,6 +305,12 @@ import { ref, computed, onMounted, onUnmounted } from 'vue'
// Type definitions // Type definitions
interface GgufFile {
filename: string
size: number
quant_name: string | null
}
interface LookupResult { interface LookupResult {
repo_id: string repo_id: string
pipeline_tag: string | null pipeline_tag: string | null
@ -260,7 +319,8 @@ interface LookupResult {
service: string | null service: string | null
compatible: boolean compatible: boolean
warning: string | null warning: string | null
size: number | null model_size_bytes: number
gguf_files: GgufFile[] | null
description: string | null description: string | null
already_installed: boolean already_installed: boolean
already_queued: boolean already_queued: boolean
@ -274,6 +334,7 @@ interface QueuedModel {
adapter_recommendation: string | null adapter_recommendation: string | null
role: string | null role: string | null
service: string | null service: string | null
quant_pattern: string | null
} }
interface InstalledModel { interface InstalledModel {
@ -302,6 +363,26 @@ const lookupLoading = ref(false)
const lookupError = ref<string | null>(null) const lookupError = ref<string | null>(null)
const lookupResult = ref<LookupResult | null>(null) const lookupResult = ref<LookupResult | null>(null)
const addingToQueue = ref(false) const addingToQueue = ref(false)
const selectedQuant = ref<string | null>(null)
// Size of the selected GGUF file, or total model size for non-GGUF repos.
const selectedQuantSize = computed<number>(() => {
const r = lookupResult.value
if (!r) return 0
if (r.gguf_files?.length && selectedQuant.value) {
const f = r.gguf_files.find(f => (f.quant_name ?? f.filename) === selectedQuant.value)
return f?.size ?? r.model_size_bytes
}
return r.model_size_bytes
})
// Disable "Add to queue" when a GGUF repo but no quant chosen yet.
const canAddToQueue = computed(() => {
const r = lookupResult.value
if (!r || r.already_installed || r.already_queued || addingToQueue.value) return false
if (r.gguf_files?.length && !selectedQuant.value) return false
return true
})
const queuedModels = ref<QueuedModel[]>([]) const queuedModels = ref<QueuedModel[]>([])
const installedModels = ref<InstalledModel[]>([]) const installedModels = ref<InstalledModel[]>([])
@ -411,6 +492,7 @@ async function doLookup() {
lookupLoading.value = true lookupLoading.value = true
lookupError.value = null lookupError.value = null
lookupResult.value = null lookupResult.value = null
selectedQuant.value = null
try { try {
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`) const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
@ -442,7 +524,15 @@ async function addToQueue() {
const res = await fetch('/api/models/queue', { const res = await fetch('/api/models/queue', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ repo_id, pipeline_tag, adapter_recommendation, role, service }), body: JSON.stringify({
repo_id,
pipeline_tag,
adapter_recommendation,
role,
service,
model_size_bytes: selectedQuantSize.value,
quant_pattern: selectedQuant.value,
}),
}) })
if (res.ok) { if (res.ok) {
lookupResult.value = { ...lookupResult.value, already_queued: true } lookupResult.value = { ...lookupResult.value, already_queued: true }
@ -454,8 +544,16 @@ async function addToQueue() {
} }
} }
async function approveModel(id: string) { async function approveModel(id: string, draft?: { service: string; role: string }) {
try { try {
// If the user picked a service/role for an unrecognized model, patch it first.
if (draft?.service && draft?.role) {
await fetch(`/api/models/queue/${encodeURIComponent(id)}`, {
method: 'PATCH',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ service: draft.service, role: draft.role }),
})
}
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' }) const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
if (res.ok) { if (res.ok) {
await loadQueue() await loadQueue()
@ -774,6 +872,44 @@ onUnmounted(() => {
align-self: flex-start; align-self: flex-start;
} }
/* ── Quant picker ── */
.quant-picker {
display: flex;
flex-direction: column;
gap: 0.35rem;
}
.quant-label {
font-size: 0.8rem;
font-weight: 600;
color: var(--color-text-muted, #4a5c7a);
text-transform: uppercase;
letter-spacing: 0.04em;
}
.quant-select {
padding: 0.4rem 0.6rem;
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
background: var(--color-surface, #f0f4fb);
color: var(--color-text, #1a2338);
font-size: 0.9rem;
font-family: var(--font-mono, monospace);
cursor: pointer;
}
.quant-hint {
font-size: 0.78rem;
color: var(--color-text-muted, #4a5c7a);
}
.chip-quant {
background: color-mix(in srgb, var(--color-primary, #2A6080) 15%, transparent);
color: var(--color-primary, #2A6080);
font-family: var(--font-mono, monospace);
font-size: 0.75rem;
}
/* ── Model cards (queue + downloads) ── */ /* ── Model cards (queue + downloads) ── */
.model-card { .model-card {
border: 1px solid var(--color-border, #a8b8d0); border: 1px solid var(--color-border, #a8b8d0);

View file

@ -0,0 +1,69 @@
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import NodeCard from '../components/nodes/NodeCard.vue'
import type { NodeSummary } from '../types/nodes'
const nodes = ref<NodeSummary[]>([])
const loading = ref(true)
const error = ref('')
async function fetchNodes() {
loading.value = true
error.value = ''
try {
const r = await fetch('/api/nodes-mgmt/nodes')
if (!r.ok) throw new Error(`HTTP ${r.status}`)
nodes.value = (await r.json()) as NodeSummary[]
} catch (e) {
error.value = e instanceof Error ? e.message : 'Failed to load nodes'
} finally {
loading.value = false
}
}
onMounted(fetchNodes)
</script>
<template>
<main class="nodes-page">
<header class="nodes-header">
<h1>Nodes</h1>
<button class="btn-secondary" @click="fetchNodes" :disabled="loading">Refresh</button>
</header>
<div aria-live="polite" aria-atomic="true" class="sr-announce">
<span v-if="loading">Loading nodes...</span>
</div>
<div v-if="error" class="nodes-status nodes-error" role="alert">{{ error }}</div>
<div v-else-if="!loading && nodes.length === 0" class="nodes-status">
No nodes found. Check <code>coordinator_url</code> in config.
</div>
<div v-else-if="!loading" class="nodes-grid">
<NodeCard
v-for="node in nodes"
:key="node.node_id"
:node="node"
@updated="fetchNodes"
/>
</div>
</main>
</template>
<style scoped>
.nodes-page { padding: 1.5rem; }
.nodes-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 1.5rem;
}
.nodes-header h1 { margin: 0; font-size: 1.5rem; }
.nodes-grid { display: flex; flex-direction: column; gap: 1.5rem; }
.nodes-status {
color: var(--text-secondary, #888);
padding: 2rem;
text-align: center;
}
.nodes-error { color: var(--color-error, #fc8181); }
.sr-announce { min-height: 1.2em; }
</style>

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,161 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import TrainJobsView from './TrainJobsView.vue'
const sampleJob = {
id: 'job-abc123',
type: 'classifier',
model_key: 'deberta-v3-small',
status: 'queued',
created_at: '2026-05-01T10:00:00Z',
config: null,
}
function makeFetch(jobs: unknown[] = []) {
return vi.fn().mockImplementation((url: string, opts?: RequestInit) => {
if ((opts?.method ?? 'GET') === 'POST') {
return Promise.resolve({
ok: true,
json: async () => ({ ...sampleJob, id: 'new-job', status: 'queued' }),
text: async () => '',
})
}
if ((opts?.method ?? 'GET') === 'DELETE') {
return Promise.resolve({ ok: true, json: async () => ({}), text: async () => '' })
}
// GET
return Promise.resolve({
ok: true,
json: async () => ({ jobs }),
text: async () => '',
})
})
}
class MockEventSource {
onmessage: ((e: MessageEvent) => void) | null = null
onerror: ((e: Event) => void) | null = null
private _url: string
constructor(url: string) { this._url = url }
close() {}
}
beforeEach(() => {
vi.stubGlobal('fetch', makeFetch([sampleJob]))
vi.stubGlobal('EventSource', MockEventSource)
})
describe('TrainJobsView', () => {
it('renders page title "Training Jobs"', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('h1.page-title').text()).toContain('Training Jobs')
})
it('renders the new job form with type selector and model key input', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('select.job-type-select').exists()).toBe(true)
expect(w.find('input.model-key-input').exists()).toBe(true)
expect(w.find('button.submit-job-btn').exists()).toBe(true)
})
it('type selector has classifier and llm-sft options', async () => {
const w = mount(TrainJobsView)
await flushPromises()
const options = w.findAll('select.job-type-select option')
const values = options.map(o => o.attributes('value') ?? o.element.textContent)
expect(values).toContain('classifier')
expect(values).toContain('llm-sft')
})
it('submit button is disabled when model key is empty', async () => {
const w = mount(TrainJobsView)
await flushPromises()
const btn = w.find('button.submit-job-btn')
expect((btn.element as HTMLButtonElement).disabled).toBe(true)
})
it('submit button is enabled when model key is entered', async () => {
const w = mount(TrainJobsView)
await flushPromises()
await w.find('input.model-key-input').setValue('deberta-v3-small')
const btn = w.find('button.submit-job-btn')
expect((btn.element as HTMLButtonElement).disabled).toBe(false)
})
it('shows job table with existing jobs', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('table.jobs-table').exists()).toBe(true)
expect(w.text()).toContain('deberta-v3-small')
})
it('shows status pill for each job', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('.status-pill').exists()).toBe(true)
expect(w.find('.status-queued').exists()).toBe(true)
})
it('shows cancel button for queued/running jobs', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('button.cancel-btn').exists()).toBe(true)
})
it('submitting new job calls POST /api/train/jobs and refreshes', async () => {
const fetchMock = makeFetch([])
vi.stubGlobal('fetch', fetchMock)
const w = mount(TrainJobsView)
await flushPromises()
await w.find('input.model-key-input').setValue('my-model')
await w.find('button.submit-job-btn').trigger('click')
await flushPromises()
const calls = (fetchMock as ReturnType<typeof vi.fn>).mock.calls as [string, RequestInit?][]
const postCall = calls.find(([, opts]) => (opts?.method ?? 'GET') === 'POST')
expect(postCall).toBeDefined()
expect(postCall![0]).toContain('/api/train/jobs')
})
it('shows View Log button for running jobs', async () => {
vi.stubGlobal('fetch', makeFetch([{ ...sampleJob, status: 'running' }]))
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('button.view-log-btn').exists()).toBe(true)
})
it('shows error when config JSON is invalid', async () => {
const w = mount(TrainJobsView)
await flushPromises()
await w.find('input.model-key-input').setValue('my-model')
await w.find('textarea.config-textarea').setValue('{ not valid json }')
await w.find('button.submit-job-btn').trigger('click')
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
expect(w.find('.error-notice').text()).toContain('not valid')
})
it('shows error notice when jobs load fails', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: false,
status: 500,
json: async () => ({}),
text: async () => '',
}))
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
expect(w.find('table.jobs-table').exists()).toBe(false)
})
it('cancel button optimistically updates job status to cancelled', async () => {
const w = mount(TrainJobsView)
await flushPromises()
await w.find('button.cancel-btn').trigger('click')
await flushPromises()
// After cancel, job should show status-cancelled pill (not status-queued)
expect(w.find('.status-cancelled').exists()).toBe(true)
expect(w.find('.status-queued').exists()).toBe(false)
})
})

View file

@ -0,0 +1,593 @@
<template>
<div class="train-jobs-view">
<header class="view-header">
<h1 class="page-title">🧠 Training Jobs</h1>
</header>
<!-- New Job form -->
<section class="section">
<h2 class="section-title">New Job</h2>
<form class="new-job-form" @submit.prevent="submitJob">
<div class="form-row">
<label class="form-label" for="job-type">Type</label>
<select
id="job-type"
v-model="form.type"
class="job-type-select form-control"
>
<option value="classifier">classifier</option>
<option value="llm-sft">llm-sft</option>
</select>
</div>
<div class="form-row">
<label class="form-label" for="model-key">Model key</label>
<input
id="model-key"
v-model.trim="form.model_key"
type="text"
class="model-key-input form-control"
placeholder="e.g. microsoft/deberta-v3-small"
autocomplete="off"
/>
</div>
<div class="form-row">
<label class="form-label" for="job-config">Config JSON <span class="form-hint">(optional)</span></label>
<textarea
id="job-config"
v-model="form.config_raw"
class="config-textarea form-control"
rows="4"
placeholder='{"learning_rate": 2e-5}'
/>
</div>
<div v-if="submitError" class="error-notice" role="alert">{{ submitError }}</div>
<button
type="submit"
class="submit-job-btn btn-primary"
:disabled="submitting || !form.model_key"
@click.prevent="submitJob"
>
{{ submitting ? 'Queuing…' : 'Queue Job' }}
</button>
</form>
</section>
<!-- Job queue table -->
<section class="section">
<h2 class="section-title">Job Queue</h2>
<div v-if="loadError" class="error-notice" role="alert">
{{ loadError }}
<button class="btn-retry" @click="loadJobs">Retry</button>
</div>
<div v-else-if="jobs.length === 0" class="empty-notice">
No training jobs yet.
</div>
<div v-else class="jobs-table-wrap">
<table class="jobs-table">
<thead>
<tr>
<th>ID</th>
<th>Type</th>
<th>Model</th>
<th>Status</th>
<th>Created</th>
<th></th>
</tr>
</thead>
<tbody>
<tr v-for="job in jobs" :key="job.id">
<td class="td-id" :title="job.id">{{ job.id.slice(0, 8) }}</td>
<td>
<span class="type-chip">{{ job.type }}</span>
</td>
<td class="td-model">{{ job.model_key }}</td>
<td>
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
</td>
<td class="td-date">{{ formatDate(job.created_at) }}</td>
<td class="td-actions">
<button
v-if="job.status === 'running'"
class="view-log-btn btn-sm"
@click="openLog(job.id)"
>
View Log
</button>
<button
v-if="job.status === 'queued' || job.status === 'running'"
class="cancel-btn btn-sm btn-danger-sm"
:disabled="cancellingId === job.id"
@click="cancelJob(job.id)"
>
{{ cancellingId === job.id ? '…' : 'Cancel' }}
</button>
</td>
</tr>
</tbody>
</table>
</div>
<div v-if="cancelError" class="error-notice" role="alert">{{ cancelError }}</div>
</section>
<!-- Log panel (SSE) -->
<section v-if="logJobId" class="section log-section">
<div class="log-header">
<h2 class="section-title">Log {{ logJobId.slice(0, 8) }}</h2>
<button class="btn-close-log" @click="closeLog"> Close</button>
</div>
<div class="log-panel" ref="logPanelEl">
<div
v-for="(line, i) in logLines"
:key="i"
class="log-line"
>{{ line }}</div>
<div v-if="logLines.length === 0" class="log-line log-muted">Connecting</div>
</div>
</section>
</div>
</template>
<script setup lang="ts">
import { ref, nextTick, onUnmounted } from 'vue'
import { useApiSSE } from '../composables/useApi'
interface TrainJob {
id: string
type: 'classifier' | 'llm-sft'
model_key: string
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
created_at: string
config: Record<string, unknown> | null
}
const jobs = ref<TrainJob[]>([])
const loadError = ref<string | null>(null)
const submitError = ref<string | null>(null)
const submitting = ref(false)
const cancellingId = ref<string | null>(null)
const cancelError = ref<string | null>(null)
const form = ref({
type: 'classifier' as 'classifier' | 'llm-sft',
model_key: '',
config_raw: '',
})
// Log panel state
const logJobId = ref<string | null>(null)
const logLines = ref<string[]>([])
const logPanelEl = ref<HTMLElement | null>(null)
let closeSSE: (() => void) | null = null
// Data loading
async function loadJobs() {
loadError.value = null
try {
const res = await fetch('/api/train/jobs')
if (!res.ok) { loadError.value = `Failed to load jobs (HTTP ${res.status}).`; return }
const data = await res.json() as { jobs: TrainJob[] }
jobs.value = data.jobs ?? []
} catch {
loadError.value = 'Network error loading jobs.'
}
}
// Submit
async function submitJob() {
if (!form.value.model_key) return
submitError.value = null
submitting.value = true
let config: Record<string, unknown> | null = null
if (form.value.config_raw.trim()) {
try {
config = JSON.parse(form.value.config_raw) as Record<string, unknown>
} catch {
submitError.value = 'Config JSON is not valid. Fix it before submitting.'
submitting.value = false
return
}
}
try {
const res = await fetch('/api/train/jobs', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
type: form.value.type,
model_key: form.value.model_key,
config_json: config,
}),
})
if (!res.ok) {
const detail = await res.text().catch(() => '')
submitError.value = `Failed to queue job (HTTP ${res.status})${detail ? `: ${detail}` : '.'}`
return
}
const newJob = await res.json() as TrainJob
jobs.value = [newJob, ...jobs.value]
form.value = { type: 'classifier', model_key: '', config_raw: '' }
} catch {
submitError.value = 'Network error submitting job.'
} finally {
submitting.value = false
}
}
// Cancel
async function cancelJob(id: string) {
cancellingId.value = id
cancelError.value = null
try {
const res = await fetch(`/api/train/jobs/${encodeURIComponent(id)}/cancel`, { method: 'DELETE' })
if (res.ok) {
jobs.value = jobs.value.map(j =>
j.id === id ? { ...j, status: 'cancelled' as const } : j
)
} else {
cancelError.value = `Failed to cancel job (HTTP ${res.status}).`
}
} catch {
cancelError.value = 'Network error cancelling job.'
} finally {
cancellingId.value = null
}
}
// Log SSE
function openLog(id: string) {
closeLog()
logJobId.value = id
logLines.value = []
closeSSE = useApiSSE(
`/api/train/jobs/${encodeURIComponent(id)}/run`,
(data) => {
if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
logLines.value = [...logLines.value, String(data.message ?? '')]
nextTick(() => {
if (logPanelEl.value) {
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
}
})
}
if (data.type === 'error') {
logLines.value = [...logLines.value, '--- stream ended with error ---']
nextTick(() => {
if (logPanelEl.value) {
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
}
})
}
},
() => {
logLines.value = [...logLines.value, '--- stream complete ---']
},
() => {
logLines.value = [...logLines.value, '--- connection lost ---']
},
)
}
function closeLog() {
closeSSE?.()
closeSSE = null
logJobId.value = null
logLines.value = []
}
// Helpers
function formatDate(iso: string): string {
const d = new Date(iso)
if (isNaN(d.getTime())) return iso
return d.toLocaleString(undefined, { dateStyle: 'short', timeStyle: 'short' })
}
// Lifecycle
loadJobs()
onUnmounted(() => {
closeSSE?.()
})
</script>
<style scoped>
.train-jobs-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 2rem;
}
.view-header { display: flex; align-items: center; }
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
}
.section {
display: flex;
flex-direction: column;
gap: 0.75rem;
}
.section-title {
font-size: 1rem;
font-weight: 600;
color: var(--color-text, #1a2338);
padding-bottom: 0.4rem;
border-bottom: 1px solid var(--color-border, #a8b8d0);
margin: 0;
}
.new-job-form {
display: flex;
flex-direction: column;
gap: 0.75rem;
max-width: 480px;
}
.form-row {
display: flex;
flex-direction: column;
gap: 0.3rem;
}
.form-label {
font-size: 0.85rem;
font-weight: 600;
color: var(--color-text-muted, #4a5c7a);
}
.form-hint {
font-weight: 400;
font-size: 0.78rem;
}
.form-control {
padding: 0.45rem 0.65rem;
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text, #1a2338);
font-size: 0.9rem;
font-family: var(--font-body, sans-serif);
}
.form-control:focus {
outline: 2px solid var(--app-primary, #2A6080);
outline-offset: -1px;
}
.config-textarea {
resize: vertical;
font-family: var(--font-mono, monospace);
font-size: 0.82rem;
}
.btn-primary {
padding: 0.4rem 0.9rem;
border-radius: var(--radius-md, 0.5rem);
border: 1px solid var(--app-primary, #2A6080);
background: var(--app-primary, #2A6080);
color: #fff;
font-size: 0.88rem;
font-family: var(--font-body, sans-serif);
cursor: pointer;
align-self: flex-start;
transition: opacity 0.15s;
}
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-primary:not(:disabled):hover { opacity: 0.85; }
.btn-sm {
padding: 0.2rem 0.55rem;
font-size: 0.78rem;
border-radius: 0.3rem;
cursor: pointer;
font-family: var(--font-body, sans-serif);
border: 1px solid;
transition: background 0.1s;
}
.view-log-btn {
border-color: var(--color-info, #1e6091);
background: transparent;
color: var(--color-info, #1e6091);
}
.view-log-btn:hover {
background: color-mix(in srgb, var(--color-info, #1e6091) 10%, transparent);
}
.btn-danger-sm {
border-color: var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
}
.btn-danger-sm:hover:not(:disabled) {
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
}
.btn-danger-sm:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-retry {
margin-left: 0.5rem;
padding: 0.2rem 0.55rem;
border-radius: 0.25rem;
border: 1px solid var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
cursor: pointer;
font-size: 0.82rem;
}
.error-notice {
padding: 0.6rem 0.8rem;
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
border-radius: var(--radius-md, 0.5rem);
color: var(--color-error, #c0392b);
font-size: 0.88rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.empty-notice {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
border: 1px dashed var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
}
.jobs-table-wrap { overflow-x: auto; }
.jobs-table {
width: 100%;
border-collapse: collapse;
font-size: 0.875rem;
}
.jobs-table th {
text-align: left;
padding: 0.4rem 0.6rem;
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text-muted, #4a5c7a);
font-size: 0.78rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.03em;
border-bottom: 1px solid var(--color-border, #a8b8d0);
white-space: nowrap;
}
.jobs-table td {
padding: 0.5rem 0.6rem;
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
vertical-align: middle;
}
.td-id {
font-family: var(--font-mono, monospace);
font-size: 0.78rem;
color: var(--color-text-muted, #4a5c7a);
}
.td-model {
font-family: var(--font-mono, monospace);
font-size: 0.82rem;
word-break: break-all;
}
.td-date {
font-size: 0.8rem;
color: var(--color-text-muted, #4a5c7a);
white-space: nowrap;
}
.td-actions {
display: flex;
gap: 0.35rem;
align-items: center;
flex-wrap: wrap;
}
.status-pill {
font-size: 0.68rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.04em;
padding: 0.15rem 0.45rem;
border-radius: var(--radius-full, 9999px);
white-space: nowrap;
}
.status-queued { background: var(--color-surface-alt, #dde4f0); color: var(--color-text-muted, #4a5c7a); }
.status-running { background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent); color: var(--color-info, #1e6091); }
.status-completed { background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent); color: var(--color-success, #3a7a32); }
.status-failed { background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent); color: var(--color-error, #c0392b); }
.status-cancelled { background: color-mix(in srgb, var(--color-warning, #d4891a) 15%, transparent); color: var(--color-warning, #d4891a); }
.type-chip {
font-size: 0.72rem;
font-family: var(--font-mono, monospace);
padding: 0.1rem 0.4rem;
border-radius: 0.25rem;
background: var(--color-surface-alt, #dde4f0);
color: var(--color-text, #1a2338);
}
.log-section { gap: 0.5rem; }
.log-header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 0.5rem;
}
.btn-close-log {
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.25rem;
cursor: pointer;
font-size: 0.8rem;
padding: 0.2rem 0.5rem;
color: var(--color-text-muted, #4a5c7a);
transition: background 0.1s;
}
.btn-close-log:hover { background: var(--color-surface-raised, #e4ebf5); }
.log-panel {
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.5rem;
max-height: 320px;
overflow-y: auto;
padding: 0.5rem 0.75rem;
background: var(--color-surface, #f0f4fc);
font-family: var(--font-mono, monospace);
font-size: 0.78rem;
}
.log-line {
color: var(--color-text, #1a2338);
line-height: 1.5;
white-space: pre-wrap;
word-break: break-all;
}
.log-muted { color: var(--color-text-muted, #4a5c7a); }
@media (max-width: 560px) {
.jobs-table th:nth-child(4),
.jobs-table td:nth-child(4),
.jobs-table th:nth-child(5),
.jobs-table td:nth-child(5) {
display: none;
}
}
</style>

View file

@ -0,0 +1,101 @@
import { mount, flushPromises } from '@vue/test-utils'
import { createRouter, createWebHashHistory } from 'vue-router'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import TrainResultsView from './TrainResultsView.vue'
const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/fleet', component: { template: '<div />' } },
],
})
const sampleResult = {
id: 'run-xyz',
job_id: 'job-abc123',
model_type: 'classifier',
base_model: 'microsoft/deberta-v3-small',
val_macro_f1: 0.847,
val_accuracy: 0.891,
sample_count: 1240,
duration_seconds: 842,
created_at: '2026-05-01T11:30:00Z',
}
function makeFetch(results: unknown[] = []) {
return vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ results }),
text: async () => '',
})
}
beforeEach(() => {
vi.stubGlobal('fetch', makeFetch([sampleResult]))
})
describe('TrainResultsView', () => {
it('renders page title "Training Results"', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('h1.page-title').text()).toContain('Training Results')
})
it('shows empty notice when there are no results', async () => {
vi.stubGlobal('fetch', makeFetch([]))
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.empty-notice').exists()).toBe(true)
})
it('renders results table when results exist', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('table.results-table').exists()).toBe(true)
})
it('shows base_model in table', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('deberta-v3-small')
})
it('shows val_macro_f1 formatted as percentage', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('84.7%')
})
it('shows val_accuracy formatted as percentage', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('89.1%')
})
it('shows duration formatted as minutes and seconds', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
// 842 seconds = 14m 2s
expect(w.text()).toContain('14m')
})
it('shows Register in Fleet button for classifier results', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('a.register-btn').exists()).toBe(true)
})
it('does NOT show Register in Fleet button for llm-sft results', async () => {
vi.stubGlobal('fetch', makeFetch([{ ...sampleResult, model_type: 'llm-sft' }]))
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('a.register-btn').exists()).toBe(false)
})
it('shows error notice when API fails', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 500, text: async () => '' }))
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
})
})

View file

@ -0,0 +1,296 @@
<template>
<div class="train-results-view">
<header class="view-header">
<h1 class="page-title">Training Results</h1>
<button class="refresh-btn" :disabled="loading" @click="loadResults" aria-label="Refresh">&#x1F504;</button>
</header>
<div v-if="error" class="error-notice" role="alert">
{{ error }}
<button class="btn-retry" @click="loadResults">Retry</button>
</div>
<div v-if="loading" class="loading-state" aria-live="polite">Loading</div>
<div v-if="!error && results.length === 0 && !loading" class="empty-notice">
No training results yet. Completed jobs will appear here.
</div>
<div v-if="results.length > 0" class="results-table-wrap">
<table class="results-table">
<thead>
<tr>
<th>Run</th>
<th>Type</th>
<th>Base Model</th>
<th class="th-numeric">Macro F1</th>
<th class="th-numeric">Accuracy</th>
<th class="th-numeric">Samples</th>
<th class="th-numeric">Duration</th>
<th></th>
</tr>
</thead>
<tbody>
<tr v-for="r in results" :key="r.id">
<td class="td-id" :title="r.id">{{ r.id.slice(0, 8) }}</td>
<td>
<span class="type-chip">{{ r.model_type }}</span>
</td>
<td class="td-model" :title="r.base_model">{{ shortModel(r.base_model) }}</td>
<td class="td-numeric">
<span class="metric-val" :class="scoreClass(r.val_macro_f1)">
{{ formatPct(r.val_macro_f1) }}
</span>
</td>
<td class="td-numeric">{{ formatPct(r.val_accuracy) }}</td>
<td class="td-numeric">{{ r.sample_count.toLocaleString() }}</td>
<td class="td-numeric">{{ formatDuration(r.duration_seconds) }}</td>
<td class="td-actions">
<RouterLink
v-if="r.model_type === 'classifier'"
:to="`/fleet?model=${encodeURIComponent(r.base_model)}`"
class="register-btn btn-sm-link"
>
Register in Fleet
</RouterLink>
</td>
</tr>
</tbody>
</table>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { RouterLink } from 'vue-router'
interface TrainResult {
id: string
job_id: string
model_type: string
base_model: string
val_macro_f1: number | null
val_accuracy: number | null
sample_count: number
duration_seconds: number | null
created_at: string
}
const results = ref<TrainResult[]>([])
const loading = ref(false)
const error = ref<string | null>(null)
async function loadResults() {
loading.value = true
error.value = null
try {
const res = await fetch('/api/train/results')
if (!res.ok) {
error.value = `Failed to load results (HTTP ${res.status}).`
return
}
const raw = await res.json() as { results?: TrainResult[] }
results.value = Array.isArray(raw?.results) ? raw.results : []
} catch {
error.value = 'Network error loading results.'
} finally {
loading.value = false
}
}
function formatPct(v: number | null | undefined): string {
if (v == null) return '—'
return `${(v * 100).toFixed(1)}%`
}
function formatDuration(seconds: number | null | undefined): string {
if (seconds == null) return '—'
const mins = Math.floor(seconds / 60)
const secs = seconds % 60
if (mins === 0) return `${secs}s`
return `${mins}m ${secs}s`
}
function shortModel(model: string): string {
const parts = model.split('/')
return parts[parts.length - 1] ?? model
}
function scoreClass(f1: number | null | undefined): string {
if (f1 == null) return ''
if (f1 >= 0.85) return 'score-great'
if (f1 >= 0.75) return 'score-good'
return 'score-fair'
}
onMounted(() => loadResults())
</script>
<style scoped>
.train-results-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 1.75rem;
}
.view-header {
display: flex;
align-items: center;
gap: 0.75rem;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
flex: 1;
}
.refresh-btn {
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
cursor: pointer;
font-size: 1rem;
padding: 0.3rem 0.5rem;
transition: background 0.15s;
}
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
.error-notice {
display: flex;
align-items: center;
gap: 0.75rem;
padding: 0.75rem 1rem;
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
border-radius: var(--radius-md, 0.5rem);
color: var(--color-error, #c0392b);
font-size: 0.88rem;
}
.btn-retry {
padding: 0.2rem 0.55rem;
border-radius: 0.25rem;
border: 1px solid var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
cursor: pointer;
font-size: 0.82rem;
}
.empty-notice {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
border: 1px dashed var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
}
.loading-state {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
}
.results-table-wrap { overflow-x: auto; }
.results-table {
width: 100%;
border-collapse: collapse;
font-size: 0.875rem;
}
.results-table th {
text-align: left;
padding: 0.4rem 0.6rem;
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text-muted, #4a5c7a);
font-size: 0.78rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.03em;
border-bottom: 1px solid var(--color-border, #a8b8d0);
white-space: nowrap;
}
.th-numeric { text-align: right; }
.results-table td {
padding: 0.5rem 0.6rem;
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
vertical-align: middle;
}
.td-id {
font-family: var(--font-mono, monospace);
font-size: 0.78rem;
color: var(--color-text-muted, #4a5c7a);
}
.td-model {
font-family: var(--font-mono, monospace);
font-size: 0.82rem;
max-width: 16ch;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.td-numeric {
text-align: right;
font-family: var(--font-mono, monospace);
font-variant-numeric: tabular-nums;
font-size: 0.82rem;
}
.td-actions { text-align: right; }
.metric-val { font-weight: 600; }
.score-great { color: var(--color-success, #3a7a32); }
.score-good { color: var(--color-warning, #d4891a); }
.score-fair { color: var(--color-text-muted, #4a5c7a); }
.type-chip {
font-size: 0.72rem;
font-family: var(--font-mono, monospace);
padding: 0.1rem 0.4rem;
border-radius: 0.25rem;
background: var(--color-surface-alt, #dde4f0);
color: var(--color-text, #1a2338);
}
.btn-sm-link {
font-size: 0.78rem;
padding: 0.2rem 0.55rem;
border-radius: 0.3rem;
border: 1px solid var(--app-primary, #2A6080);
color: var(--app-primary, #2A6080);
background: transparent;
text-decoration: none;
white-space: nowrap;
display: inline-block;
transition: background 0.1s;
}
.btn-sm-link:hover {
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
}
@media (max-width: 600px) {
.results-table th:nth-child(6),
.results-table td:nth-child(6),
.results-table th:nth-child(7),
.results-table td:nth-child(7) {
display: none;
}
}
</style>