Compare commits
No commits in common. "main" and "feat/benchmark-model-picker" have entirely different histories.
main
...
feat/bench
97 changed files with 2907 additions and 28296 deletions
23
.env.example
23
.env.example
|
|
@ -1,23 +0,0 @@
|
|||
# Avocet — environment variable configuration
|
||||
# Copy to .env and fill in values. All keys are optional.
|
||||
# label_tool.yaml takes precedence over env vars where both exist.
|
||||
|
||||
# ── Local inference (Ollama) ───────────────────────────────────────────────────
|
||||
# OLLAMA_HOST defaults to http://localhost:11434 if unset.
|
||||
OLLAMA_HOST=http://localhost:11434
|
||||
OLLAMA_MODEL=llama3.2:3b
|
||||
|
||||
# ── cf-orch coordinator (paid/premium tiers) ───────────────────────────────────
|
||||
# Required for multi-GPU LLM benchmarking via the cf-orch benchmark harness.
|
||||
# Free-tier users can leave these unset and use Ollama only.
|
||||
CF_ORCH_URL=http://localhost:7700
|
||||
CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx
|
||||
|
||||
# ── Cloud LLM backends (optional — paid/premium) ──────────────────────────────
|
||||
# Set one of these to use a cloud LLM instead of a local model.
|
||||
# ANTHROPIC_API_KEY=sk-ant-...
|
||||
# OPENAI_API_KEY=sk-...
|
||||
|
||||
# ── HuggingFace (required for gated/terms-restricted model downloads) ─────────
|
||||
# Generate at https://huggingface.co/settings/tokens and accept model terms first.
|
||||
# HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
|
|
@ -8,9 +8,6 @@ __pycache__/
|
|||
config/label_tool.yaml
|
||||
|
||||
# Data files (user-generated, not for version control)
|
||||
data/corpus.db
|
||||
data/corpus.db-wal
|
||||
data/corpus.db-shm
|
||||
data/email_score.jsonl
|
||||
data/email_label_queue.jsonl
|
||||
data/email_compare_sample.jsonl
|
||||
|
|
@ -23,7 +20,3 @@ data/sft_approved.jsonl
|
|||
# Claude context — BSL 1.1, keep out of version control
|
||||
CLAUDE.md
|
||||
docs/superpowers/
|
||||
.superpowers/
|
||||
|
||||
# Git worktrees
|
||||
.worktrees/
|
||||
|
|
|
|||
177
README.md
177
README.md
|
|
@ -1,177 +0,0 @@
|
|||
<div align="center">
|
||||
<img src="docs/avocet-logo.svg" alt="Avocet" height="96" />
|
||||
|
||||
# Avocet
|
||||
|
||||
**Email classifier training tool — label, benchmark, fine-tune.**
|
||||
|
||||
[]()
|
||||
[](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet/releases)
|
||||
[](LICENSE)
|
||||
[]()
|
||||
[](https://circuitforge.tech)
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## What is Avocet?
|
||||
|
||||
Avocet is the internal data pipeline Circuit Forge uses to build, evaluate, and fine-tune email classifiers. It implements a three-stage workflow: human labelers review emails one at a time in a drag-to-bucket UI and produce a ground-truth dataset; the benchmark harness scores any number of HuggingFace zero-shot models against that dataset and produces a ranked comparison; and the fine-tune harness adapts the best-scoring base model to the labeled distribution. The output feeds directly into Peregrine's email classification layer. No LLM API key required for the label tool or benchmark — all inference runs locally via HuggingFace Transformers.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
git clone https://git.opensourcesolarpunk.com/Circuit-Forge/avocet.git
|
||||
cd avocet
|
||||
|
||||
# Copy config template and fill in your IMAP credentials
|
||||
cp config/label_tool.yaml.example config/label_tool.yaml
|
||||
|
||||
# Start the label tool (Vue SPA + FastAPI, port 8503)
|
||||
./manage.sh start
|
||||
./manage.sh open
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- **Drag-to-bucket label UI** — ASMR-style card interface; drag emails into labeled buckets or discard without queuing noise into the training set
|
||||
- **Targeted IMAP fetch** — pull emails by date range, sender, or subject filter across multiple accounts without flooding the queue
|
||||
- **Email classifier benchmark** — score any HuggingFace zero-shot model against your labeled JSONL; side-by-side comparison on live IMAP emails
|
||||
- **Planning benchmark** — evaluate LLMs on structured planning tasks; compare models head-to-head with verbose diff output
|
||||
- **Writing style benchmark** — compare Ollama models on writing style coherence; scan local disk for existing outputs
|
||||
- **Fine-tune harness** — HuggingFace Transformers fine-tuning from labeled ground truth; classifier adapter interface for swapping backends at runtime
|
||||
- **Local inference first** — no API key required; GPU optional; designed to run on developer hardware
|
||||
- **Hot-reload dev mode** — uvicorn `--reload` + Vite HMR (hot module replacement) for fast iteration on both API and UI
|
||||
|
||||
---
|
||||
|
||||
## CLI Reference
|
||||
|
||||
All operations go through `manage.sh`.
|
||||
|
||||
### Label Tool
|
||||
|
||||
```bash
|
||||
./manage.sh start # Build Vue SPA and start FastAPI on port 8503
|
||||
./manage.sh stop # Stop FastAPI server
|
||||
./manage.sh restart # Stop, rebuild, and restart
|
||||
./manage.sh status # Show running state and port
|
||||
./manage.sh logs # Tail the API log
|
||||
./manage.sh open # Open http://localhost:8503 in browser
|
||||
./manage.sh dev # Hot-reload: uvicorn --reload + Vite HMR
|
||||
./manage.sh test # Run pytest suite
|
||||
```
|
||||
|
||||
### Email Classifier Benchmark
|
||||
|
||||
```bash
|
||||
./manage.sh benchmark [args] # Run benchmark_classifier.py
|
||||
./manage.sh list-models # List available zero-shot models
|
||||
./manage.sh score # Score models against labeled JSONL
|
||||
./manage.sh score --include-slow # Include large/slow models
|
||||
./manage.sh compare --limit 30 # Side-by-side comparison on live IMAP emails
|
||||
```
|
||||
|
||||
### Planning Benchmark
|
||||
|
||||
```bash
|
||||
./manage.sh plans-bench [args] # Run benchmark_plans.py
|
||||
./manage.sh plans-list # List available models
|
||||
./manage.sh plans-run <model> [args] # Run a single model (verbose)
|
||||
./manage.sh plans-compare <m1> <m2> [...] # Compare models side-by-side
|
||||
```
|
||||
|
||||
### Writing Style Benchmark
|
||||
|
||||
```bash
|
||||
./manage.sh style-bench [args] # Run benchmark_style.py
|
||||
./manage.sh style-list # List available Ollama models
|
||||
./manage.sh style-run [args] # Run writing style benchmark
|
||||
./manage.sh style-last # Print most recent benchmark report
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
IMAP accounts
|
||||
→ fetch (targeted or wide)
|
||||
→ email_label_queue.jsonl
|
||||
|
||||
email_label_queue.jsonl
|
||||
→ label tool drag-to-bucket UI
|
||||
→ email_score.jsonl (ground truth)
|
||||
|
||||
email_score.jsonl
|
||||
→ benchmark harness
|
||||
→ model rankings
|
||||
|
||||
best model
|
||||
→ fine-tune harness
|
||||
→ Peregrine classifier adapter
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Labels
|
||||
|
||||
| Label | Key |
|
||||
|-------|-----|
|
||||
| `interview_scheduled` | 1 |
|
||||
| `offer_received` | 2 |
|
||||
| `rejected` | 3 |
|
||||
| `positive_response` | 4 |
|
||||
| `survey_received` | 5 |
|
||||
| `neutral` | 6 |
|
||||
| `event_rescheduled` | 7 |
|
||||
| `unrelated` | 8 |
|
||||
| `digest` | 9 |
|
||||
|
||||
---
|
||||
|
||||
## Stack
|
||||
|
||||
| Layer | Technology |
|
||||
|-------|-----------|
|
||||
| Label UI | Vue 3 SPA (Vite) |
|
||||
| API | FastAPI + uvicorn (port 8503) |
|
||||
| Benchmark | Python + HuggingFace Transformers |
|
||||
| Email fetch | IMAP (multi-account, targeted date/sender/subject filter) |
|
||||
| Data | JSONL (`data/email_label_queue.jsonl`, `data/email_score.jsonl`) |
|
||||
| Runtime | SQLite |
|
||||
| Config | `config/label_tool.yaml` (gitignored — `.example` committed) |
|
||||
|
||||
---
|
||||
|
||||
## Logo
|
||||
|
||||
The Avocet logo (`avocet_v1_poly.svg`) lives in the shared graphics repo. Copy it to `docs/avocet-logo.svg` to render correctly in this README.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
Avocet is internal CircuitForge infrastructure, open source as a reference implementation. It is not a user-facing product. The primary consumer is [Peregrine](https://git.opensourcesolarpunk.com/Circuit-Forge/peregrine), CircuitForge's job-search pipeline tool.
|
||||
|
||||
Docs: [docs.circuitforge.tech/avocet](https://docs.circuitforge.tech/avocet)
|
||||
|
||||
## Forgejo-primary
|
||||
|
||||
Avocet is developed and maintained on Forgejo at [git.opensourcesolarpunk.com/Circuit-Forge/avocet](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet). GitHub and Codeberg are read-only mirrors.
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
[Business Source License 1.1](LICENSE) — classifier training is an AI feature under the CircuitForge licensing model.
|
||||
|
||||
Free for personal non-commercial self-hosting. Commercial use or SaaS re-hosting requires a paid license. Converts to MIT after 4 years.
|
||||
|
||||
Humans own design, architecture, code review, testing, and verification. LLMs are part of our development workflow. [Our positions on LLM use →](https://circuitforge.tech/positions)
|
||||
|
||||
© 2026 Circuit Forge LLC — Privacy · Safety · Accessibility
|
||||
641
app/api.py
641
app/api.py
|
|
@ -1,95 +1,614 @@
|
|||
"""Avocet -- FastAPI app factory.
|
||||
"""Avocet — FastAPI REST layer.
|
||||
|
||||
Mounts all domain routers and serves the Vue SPA.
|
||||
All business logic lives in the domain modules below.
|
||||
JSONL read/write helpers and FastAPI app instance.
|
||||
Endpoints and static file serving are added in subsequent tasks.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import subprocess as _subprocess
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
|
||||
# Process registry for running jobs — used by cancel endpoints.
|
||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
||||
_running_procs: dict = {}
|
||||
_cancelled_jobs: set = set()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
"""Override data directory — used by tests."""
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return the index of the GPU with the most free VRAM as a string.
|
||||
|
||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
||||
would OOM the GPU with less headroom.
|
||||
"""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
"""Override models directory — used by tests."""
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Override config directory — used by tests."""
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def reset_last_action() -> None:
|
||||
"""Reset undo state — used by tests."""
|
||||
global _last_action
|
||||
_last_action = None
|
||||
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
|
||||
def _score_file() -> Path:
|
||||
return _DATA_DIR / "email_score.jsonl"
|
||||
|
||||
|
||||
def _discarded_file() -> Path:
|
||||
return _DATA_DIR / "discarded.jsonl"
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
if not path.exists():
|
||||
return []
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
return [json.loads(l) for l in lines if l.strip()]
|
||||
|
||||
|
||||
def _write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
text = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
|
||||
path.write_text(text + "\n" if records else "", encoding="utf-8")
|
||||
|
||||
|
||||
def _append_jsonl(path: Path, record: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
"""Stable content-hash ID — matches label_tool.py _entry_key dedup logic."""
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
def _normalize(item: dict) -> dict:
|
||||
"""Normalize JSONL item to the Vue frontend schema.
|
||||
|
||||
label_tool.py stores: subject, body, from_addr, date, account (no id).
|
||||
The Vue app expects: id, subject, body, from, date, source.
|
||||
Both old (from_addr/account) and new (from/source) field names are handled.
|
||||
"""
|
||||
return {
|
||||
"id": item.get("id") or _item_id(item),
|
||||
"subject": item.get("subject", ""),
|
||||
"body": item.get("body", ""),
|
||||
"from": item.get("from") or item.get("from_addr", ""),
|
||||
"date": item.get("date", ""),
|
||||
"source": item.get("source") or item.get("account", ""),
|
||||
}
|
||||
|
||||
|
||||
app = FastAPI(title="Avocet API")
|
||||
|
||||
# -- Domain routers --------------------------------------------------------
|
||||
from app.sft import router as sft_router
|
||||
app.include_router(sft_router, prefix="/api/sft")
|
||||
|
||||
from app.data.label import router as label_router
|
||||
app.include_router(label_router, prefix="/api")
|
||||
|
||||
from app.data.fetch import router as fetch_router
|
||||
app.include_router(fetch_router, prefix="/api")
|
||||
|
||||
from app.data.corrections import router as corrections_router
|
||||
app.include_router(corrections_router, prefix="/api/corrections")
|
||||
|
||||
# Backward-compat alias -- remove when Vue SPA is updated to /api/corrections/*
|
||||
app.include_router(corrections_router, prefix="/api/sft")
|
||||
|
||||
from app.data.imitate import router as imitate_router
|
||||
app.include_router(imitate_router, prefix="/api/imitate")
|
||||
|
||||
from app.eval.cforch import router as eval_router
|
||||
app.include_router(eval_router, prefix="/api")
|
||||
|
||||
from app.train.train import router as train_router
|
||||
app.include_router(train_router, prefix="/api/train")
|
||||
|
||||
from app.plans_bench import router as plans_bench_router
|
||||
app.include_router(plans_bench_router, prefix="/api/plans-bench")
|
||||
from app.models import router as models_router
|
||||
import app.models as _models_module
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
|
||||
# In-memory last-action store (single user, local tool — in-memory is fine)
|
||||
_last_action: dict | None = None
|
||||
|
||||
# -- Backward-compat shims (ClassifierTab still uses old /api/finetune/* paths)
|
||||
# Remove once ClassifierTab fine-tune section is migrated to TrainJobsView.
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.responses import StreamingResponse as _StreamingResponse
|
||||
@app.get("/api/queue")
|
||||
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
|
||||
items = _read_jsonl(_queue_file())
|
||||
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
|
||||
|
||||
|
||||
class LabelRequest(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
|
||||
@app.post("/api/label")
|
||||
def post_label(req: LabelRequest):
|
||||
global _last_action
|
||||
items = _read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": req.label,
|
||||
"labeled_at": datetime.now(timezone.utc).isoformat()}
|
||||
_append_jsonl(_score_file(), record)
|
||||
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "label", "item": match, "label": req.label}
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@app.post("/api/skip")
|
||||
def post_skip(req: SkipRequest):
|
||||
global _last_action
|
||||
items = _read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
|
||||
_write_jsonl(_queue_file(), reordered)
|
||||
_last_action = {"type": "skip", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
class DiscardRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@app.post("/api/discard")
|
||||
def post_discard(req: DiscardRequest):
|
||||
global _last_action
|
||||
items = _read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": "__discarded__",
|
||||
"discarded_at": datetime.now(timezone.utc).isoformat()}
|
||||
_append_jsonl(_discarded_file(), record)
|
||||
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "discard", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.delete("/api/label/undo")
|
||||
def delete_undo():
|
||||
global _last_action
|
||||
if not _last_action:
|
||||
raise HTTPException(404, "No action to undo")
|
||||
action = _last_action
|
||||
item = action["item"] # always the original clean queue item
|
||||
|
||||
# Perform file operations FIRST — only clear _last_action on success
|
||||
if action["type"] == "label":
|
||||
records = _read_jsonl(_score_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Score file is empty — cannot undo label")
|
||||
_write_jsonl(_score_file(), records[:-1])
|
||||
items = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "discard":
|
||||
records = _read_jsonl(_discarded_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Discarded file is empty — cannot undo discard")
|
||||
_write_jsonl(_discarded_file(), records[:-1])
|
||||
items = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "skip":
|
||||
items = _read_jsonl(_queue_file())
|
||||
item_id = _normalize(item)["id"]
|
||||
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
|
||||
_write_jsonl(_queue_file(), items)
|
||||
|
||||
# Clear AFTER all file operations succeed
|
||||
_last_action = None
|
||||
return {"undone": {"type": action["type"], "item": _normalize(item)}}
|
||||
|
||||
|
||||
# Label metadata — 10 labels matching label_tool.py
|
||||
_LABEL_META = [
|
||||
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
|
||||
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
|
||||
{"name": "rejected", "emoji": "\u274c", "color": "#F44336", "key": "3"},
|
||||
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
|
||||
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
|
||||
{"name": "neutral", "emoji": "\u2b1c", "color": "#607D8B", "key": "6"},
|
||||
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
|
||||
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
|
||||
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
|
||||
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
|
||||
]
|
||||
|
||||
|
||||
@app.get("/api/config/labels")
|
||||
def get_labels():
|
||||
return _LABEL_META
|
||||
|
||||
|
||||
@app.get("/api/config")
|
||||
def get_config():
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"accounts": [], "max_per_account": 500}
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
|
||||
|
||||
|
||||
class ConfigPayload(BaseModel):
|
||||
accounts: list[dict]
|
||||
max_per_account: int = 500
|
||||
|
||||
|
||||
@app.post("/api/config")
|
||||
def post_config(payload: ConfigPayload):
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/api/stats")
|
||||
def get_stats():
|
||||
records = _read_jsonl(_score_file())
|
||||
counts: dict[str, int] = {}
|
||||
for r in records:
|
||||
lbl = r.get("label", "")
|
||||
if lbl:
|
||||
counts[lbl] = counts.get(lbl, 0) + 1
|
||||
benchmark_results: dict = {}
|
||||
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"total": len(records),
|
||||
"counts": counts,
|
||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||
"benchmark_results": benchmark_results,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/stats/download")
|
||||
def download_stats():
|
||||
from fastapi.responses import FileResponse
|
||||
if not _score_file().exists():
|
||||
raise HTTPException(404, "No score file")
|
||||
return FileResponse(
|
||||
str(_score_file()),
|
||||
filename="email_score.jsonl",
|
||||
media_type="application/jsonlines",
|
||||
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
|
||||
)
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
|
||||
account: dict
|
||||
|
||||
|
||||
@app.post("/api/accounts/test")
|
||||
def test_account(req: AccountTestRequest):
|
||||
from app.imap_fetch import test_connection
|
||||
ok, message, count = test_connection(req.account)
|
||||
return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/benchmark/models")
|
||||
def get_benchmark_models() -> dict:
|
||||
"""Return installed models grouped by adapter_type category."""
|
||||
models_dir: Path = _models_module._MODELS_DIR
|
||||
categories: dict[str, list[dict]] = {
|
||||
"ZeroShotAdapter": [],
|
||||
"RerankerAdapter": [],
|
||||
"GenerationAdapter": [],
|
||||
"Unknown": [],
|
||||
}
|
||||
if models_dir.exists():
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "model_info.json"
|
||||
adapter_type = "Unknown"
|
||||
repo_id: str | None = None
|
||||
if info_path.exists():
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown"
|
||||
repo_id = info.get("repo_id")
|
||||
except Exception:
|
||||
pass
|
||||
bucket = adapter_type if adapter_type in categories else "Unknown"
|
||||
entry: dict = {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type}
|
||||
categories[bucket].append(entry)
|
||||
return {"categories": categories}
|
||||
|
||||
|
||||
@app.get("/api/benchmark/results")
|
||||
def get_benchmark_results():
|
||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
||||
path = _DATA_DIR / "benchmark_results.json"
|
||||
if not path.exists():
|
||||
return {"models": {}, "sample_count": 0, "timestamp": None}
|
||||
return json.loads(path.read_text())
|
||||
|
||||
|
||||
@app.get("/api/benchmark/run")
|
||||
def run_benchmark(include_slow: bool = False, model_names: str = ""):
|
||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
||||
cmd = [python_bin, script, "--score", "--save"]
|
||||
if include_slow:
|
||||
cmd.append("--include-slow")
|
||||
if model_names:
|
||||
names = [n.strip() for n in model_names.split(",") if n.strip()]
|
||||
if names:
|
||||
cmd.extend(["--models"] + names)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_running_procs["benchmark"] = proc
|
||||
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "benchmark" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("benchmark")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("benchmark", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Finetune endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/finetune/status")
|
||||
def get_finetune_status():
|
||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
results = []
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/finetune/run")
|
||||
def finetune_run_compat(model: str = Query(...), epochs: int = Query(5)) -> _StreamingResponse:
|
||||
"""Shim: create a classifier train job and immediately stream it."""
|
||||
from app.train.train import create_job, run_job, CreateJobRequest
|
||||
job = create_job(CreateJobRequest(type="classifier", model_key=model, config_json={"epochs": epochs}))
|
||||
return run_job(job["id"])
|
||||
def run_finetune_endpoint(
|
||||
model: str = "deberta-small",
|
||||
epochs: int = 5,
|
||||
score: list[str] = Query(default=[]),
|
||||
):
|
||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
||||
data_root = _DATA_DIR.resolve()
|
||||
for score_file in score:
|
||||
resolved = (_DATA_DIR / score_file).resolve()
|
||||
if not str(resolved).startswith(str(data_root)):
|
||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
|
||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
||||
# single device prevents DataParallel from replicating the model across all
|
||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
|
||||
def generate():
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
env=proc_env,
|
||||
)
|
||||
_running_procs["finetune"] = proc
|
||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "finetune" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("finetune")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("finetune", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/benchmark/cancel")
|
||||
def cancel_benchmark():
|
||||
"""Kill the running benchmark subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("benchmark")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No benchmark is running")
|
||||
_cancelled_jobs.add("benchmark")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@app.post("/api/finetune/cancel")
|
||||
def finetune_cancel_compat() -> dict:
|
||||
"""Shim: cancel the most recent running classifier job."""
|
||||
from app.train.train import _db, _init_db, cancel_job
|
||||
from fastapi import HTTPException
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT id FROM jobs WHERE type='classifier' AND status='running' ORDER BY started_at DESC LIMIT 1"
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return {"status": "nothing_running"}
|
||||
return cancel_job(row["id"])
|
||||
def cancel_finetune():
|
||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("finetune")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No finetune is running")
|
||||
_cancelled_jobs.add("finetune")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
from app.data.log_corpus import router as log_corpus_router
|
||||
app.include_router(log_corpus_router, prefix="/api/corpus")
|
||||
|
||||
from app.data.recipe_scan import router as recipe_scan_router
|
||||
app.include_router(recipe_scan_router, prefix="/api/recipe-scan")
|
||||
@app.get("/api/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
days_back: int = Query(default=90, ge=1, le=365),
|
||||
limit: int = Query(default=150, ge=1, le=1000),
|
||||
mode: str = Query(default="wide"),
|
||||
):
|
||||
from app.imap_fetch import fetch_account_stream
|
||||
|
||||
from app.dashboard import router as dashboard_router
|
||||
app.include_router(dashboard_router, prefix="/api")
|
||||
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
|
||||
config = get_config() # reuse existing endpoint logic
|
||||
selected = [a for a in config["accounts"] if a.get("name") in selected_names]
|
||||
|
||||
from app.models import router as models_router
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
def generate():
|
||||
known_keys = {_item_id(x) for x in _read_jsonl(_queue_file())}
|
||||
total_added = 0
|
||||
|
||||
from app.nodes import router as nodes_router
|
||||
app.include_router(nodes_router, prefix="/api/nodes-mgmt")
|
||||
for acc in selected:
|
||||
try:
|
||||
batch_emails: list[dict] = []
|
||||
for event in fetch_account_stream(acc, days_back, limit, known_keys):
|
||||
if event["type"] == "done":
|
||||
batch_emails = event.pop("emails", [])
|
||||
total_added += event["added"]
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
# Write new emails to queue after each account
|
||||
if batch_emails:
|
||||
existing = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), existing + batch_emails)
|
||||
except Exception as exc:
|
||||
error_event = {"type": "error", "account": acc.get("name", "?"),
|
||||
"message": str(exc)}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
# -- Static SPA -- MUST be last (catches all unmatched paths) ---------------
|
||||
queue_size = len(_read_jsonl(_queue_file()))
|
||||
complete = {"type": "complete", "total_added": total_added, "queue_size": queue_size}
|
||||
yield f"data: {json.dumps(complete)}\n\n"
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
# Static SPA — MUST be last (catches all unmatched paths)
|
||||
_DIST = _ROOT / "web" / "dist"
|
||||
if _DIST.exists():
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Serve index.html with no-cache so browsers always fetch fresh HTML after rebuilds.
|
||||
# Hashed assets (/assets/index-abc123.js) can be cached forever — they change names
|
||||
# when content changes (standard Vite cache-busting strategy).
|
||||
_NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"}
|
||||
|
||||
@app.get("/")
|
||||
|
|
|
|||
653
app/cforch.py
653
app/cforch.py
|
|
@ -1,653 +0,0 @@
|
|||
"""Avocet — cf-orch benchmark integration API.
|
||||
|
||||
Wraps cf-orch's benchmark.py script and exposes it via the Avocet API.
|
||||
Config is read from label_tool.yaml under the `cforch:` key.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/cforch".
|
||||
|
||||
Module-level globals (_CONFIG_DIR, _BENCH_RUNNING, _bench_proc) follow the
|
||||
same testability pattern as sft.py — override _CONFIG_DIR via set_config_dir()
|
||||
in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import select as _select
|
||||
import subprocess as _subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import urllib.parse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None # live Popen object while benchmark runs
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section, falling back to environment variables.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. label_tool.yaml cforch: key
|
||||
2. Environment variables (CF_ORCH_URL, CF_LICENSE_KEY, OLLAMA_HOST, OLLAMA_MODEL)
|
||||
"""
|
||||
f = _config_file()
|
||||
file_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
file_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse cforch config %s: %s", f, exc)
|
||||
|
||||
# Env var fallbacks — only used when the yaml key is absent or empty
|
||||
def _coalesce(file_val: str, env_key: str) -> str:
|
||||
return file_val if file_val else os.environ.get(env_key, "")
|
||||
|
||||
return {
|
||||
**file_cfg,
|
||||
"coordinator_url": _coalesce(file_cfg.get("coordinator_url", ""), "CF_ORCH_URL"),
|
||||
"license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"),
|
||||
"ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"),
|
||||
"ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"),
|
||||
"judge_url": _coalesce(file_cfg.get("judge_url", ""), "CF_JUDGE_URL"),
|
||||
"hf_token": _coalesce(file_cfg.get("hf_token", ""), "HF_TOKEN"),
|
||||
}
|
||||
|
||||
|
||||
def _validate_service_url(url: str, param_name: str) -> str:
|
||||
"""Validate that a URL is a well-formed http/https URL with a hostname.
|
||||
|
||||
Guards against SSRF: only http/https is allowed; the URL must have a
|
||||
non-empty host. Does not enforce an allowlist — call sites are internal
|
||||
tooling, not a public API.
|
||||
"""
|
||||
if not url:
|
||||
return url
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
except Exception:
|
||||
raise HTTPException(400, f"{param_name}: not a valid URL")
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise HTTPException(400, f"{param_name}: URL must start with http:// or https://")
|
||||
if not parsed.hostname:
|
||||
raise HTTPException(400, f"{param_name}: URL has no hostname")
|
||||
return url
|
||||
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape codes from a string."""
|
||||
return re.sub(r'\x1b\[[0-9;]*m', '', text)
|
||||
|
||||
|
||||
def _find_latest_summary(results_dir: str | None) -> Path | None:
|
||||
"""Find the newest summary.json under results_dir, or None if not found."""
|
||||
if not results_dir:
|
||||
return None
|
||||
rdir = Path(results_dir)
|
||||
if not rdir.exists():
|
||||
return None
|
||||
# Subdirs are named YYYY-MM-DD-HHMMSS; sort lexicographically for chronological order
|
||||
subdirs = sorted(
|
||||
[d for d in rdir.iterdir() if d.is_dir()],
|
||||
key=lambda d: d.name,
|
||||
)
|
||||
for subdir in reversed(subdirs):
|
||||
summary = subdir / "summary.json"
|
||||
if summary.exists():
|
||||
return summary
|
||||
return None
|
||||
|
||||
|
||||
# ── GET /tasks ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/tasks")
|
||||
def get_tasks() -> dict:
|
||||
"""Return task list from bench_tasks.yaml."""
|
||||
cfg = _load_cforch_config()
|
||||
tasks_path = cfg.get("bench_tasks", "")
|
||||
if not tasks_path:
|
||||
return {"tasks": [], "types": []}
|
||||
|
||||
p = Path(tasks_path)
|
||||
if not p.exists():
|
||||
return {"tasks": [], "types": []}
|
||||
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_tasks.yaml %s: %s", p, exc)
|
||||
return {"tasks": [], "types": []}
|
||||
|
||||
tasks_raw = raw.get("tasks", []) or []
|
||||
tasks: list[dict] = []
|
||||
seen_types: list[str] = []
|
||||
types_set: set[str] = set()
|
||||
|
||||
for t in tasks_raw:
|
||||
if not isinstance(t, dict):
|
||||
continue
|
||||
tasks.append({
|
||||
"id": t.get("id", ""),
|
||||
"name": t.get("name", ""),
|
||||
"type": t.get("type", ""),
|
||||
"prompt": (t.get("prompt") or "").strip(),
|
||||
"system": (t.get("system") or "").strip(),
|
||||
})
|
||||
task_type = t.get("type", "")
|
||||
if task_type and task_type not in types_set:
|
||||
seen_types.append(task_type)
|
||||
types_set.add(task_type)
|
||||
|
||||
return {"tasks": tasks, "types": seen_types}
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
# Services and roles surfaced in the benchmark model picker.
|
||||
# Covers all cf-orch service types that benchmark.py can route tasks to.
|
||||
_BENCH_SERVICES = frozenset({
|
||||
"cf-text", "vllm", # LLM text generation
|
||||
"cf-stt", # speech-to-text
|
||||
"cf-tts", # text-to-speech
|
||||
"cf-vision", # image classification / embedding
|
||||
"cf-voice", # audio context classification
|
||||
})
|
||||
_BENCH_ROLES = frozenset({
|
||||
"generator", "vlm", # LLM roles
|
||||
"stt", "alm", # speech recognition
|
||||
"tts", # speech synthesis
|
||||
"vision", "embedding", # image understanding
|
||||
"classifier", # audio classification (cf-voice)
|
||||
})
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return model list from bench_models.yaml merged with locally installed models.
|
||||
|
||||
bench_models.yaml entries are listed first and take precedence; any installed
|
||||
model whose repo_id is already present in the YAML is skipped. Only models
|
||||
whose service is in _BENCH_SERVICES (cf-text, vllm, cf-stt, cf-tts, cf-vision,
|
||||
cf-voice) are surfaced from the installed registry.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
models_path = cfg.get("bench_models", "")
|
||||
|
||||
models: list[dict] = []
|
||||
bench_ids: set[str] = set()
|
||||
|
||||
if models_path:
|
||||
p = Path(models_path)
|
||||
if p.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
|
||||
raw = {}
|
||||
for m in (raw.get("models", []) or []):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
model_id = m.get("id", "")
|
||||
models.append({
|
||||
"name": m.get("name", ""),
|
||||
"id": model_id,
|
||||
"service": m.get("service", "ollama"),
|
||||
"tags": m.get("tags", []) or [],
|
||||
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
|
||||
})
|
||||
if model_id:
|
||||
bench_ids.add(model_id)
|
||||
|
||||
# Merge installed generator models not already in bench_models.yaml.
|
||||
try:
|
||||
from app.models import list_installed # local import avoids circular dependency at module load
|
||||
for installed in list_installed():
|
||||
model_id: str = installed.get("model_id") or ""
|
||||
service: str = installed.get("service") or ""
|
||||
role: str = installed.get("role") or ""
|
||||
if not model_id:
|
||||
continue
|
||||
if service not in _BENCH_SERVICES or role not in _BENCH_ROLES:
|
||||
continue
|
||||
if model_id in bench_ids:
|
||||
continue
|
||||
display_name = model_id.split("/", 1)[-1] if "/" in model_id else model_id
|
||||
models.append({
|
||||
"name": display_name,
|
||||
"id": model_id,
|
||||
"service": service,
|
||||
"tags": [role],
|
||||
"vram_estimate_mb": installed.get("vram_mb") or 0,
|
||||
})
|
||||
bench_ids.add(model_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into model list: %s", exc)
|
||||
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
|
||||
def get_nodes() -> dict:
|
||||
"""Proxy the coordinator's /api/nodes list, returning node_id + online status.
|
||||
|
||||
Online is inferred from last_heartbeat: any node with a recent heartbeat is online.
|
||||
Returns an empty list if the coordinator is unreachable.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "").rstrip("/")
|
||||
if not coordinator_url:
|
||||
return {"nodes": []}
|
||||
try:
|
||||
import httpx as _httpx
|
||||
resp = _httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
raw_nodes = resp.json().get("nodes", [])
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": n.get("node_id", ""),
|
||||
"online": n.get("last_heartbeat") is not None,
|
||||
"gpus": [
|
||||
{
|
||||
"gpu_id": g.get("gpu_id"),
|
||||
"name": g.get("name", ""),
|
||||
"vram_total_mb": g.get("vram_total_mb", 0),
|
||||
"vram_free_mb": g.get("vram_free_mb", 0),
|
||||
}
|
||||
for g in n.get("gpus", [])
|
||||
],
|
||||
}
|
||||
for n in raw_nodes
|
||||
]
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch nodes from coordinator: %s", exc)
|
||||
return {"nodes": []}
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_benchmark(
|
||||
task_ids: str = "",
|
||||
model_ids: str = "",
|
||||
model_tags: str = "",
|
||||
coordinator_url: str = "",
|
||||
ollama_url: str = "",
|
||||
judge_url: str = "",
|
||||
judge_backend: str = "chat",
|
||||
workers: int = 1,
|
||||
node_ids: str = "",
|
||||
) -> StreamingResponse:
|
||||
"""Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
# Check if the process is actually still alive; reset stale flag if not.
|
||||
if _BENCH_RUNNING:
|
||||
if _bench_proc is not None and _bench_proc.poll() is None:
|
||||
raise HTTPException(409, "A benchmark is already running")
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
|
||||
cfg = _load_cforch_config()
|
||||
bench_script = cfg.get("bench_script", "")
|
||||
bench_tasks = cfg.get("bench_tasks", "")
|
||||
bench_models = cfg.get("bench_models", "")
|
||||
results_dir = cfg.get("results_dir", "")
|
||||
python_bin = cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")
|
||||
cfg_coordinator = cfg.get("coordinator_url", "")
|
||||
cfg_ollama = cfg.get("ollama_url", "")
|
||||
cfg_license_key = cfg.get("license_key", "")
|
||||
cfg_judge_url = cfg.get("judge_url", "")
|
||||
|
||||
# Validate URL params before spawning the subprocess.
|
||||
# _validate_service_url raises HTTPException on bad input (caught by FastAPI before streaming starts).
|
||||
_validate_service_url(coordinator_url, "coordinator_url")
|
||||
_validate_service_url(ollama_url, "ollama_url")
|
||||
_validate_service_url(judge_url, "judge_url")
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not bench_script or not Path(bench_script).exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n"
|
||||
return
|
||||
|
||||
# Build effective models file: bench_models.yaml + any installed models
|
||||
# whose IDs were selected but are absent from the YAML (e.g. downloaded
|
||||
# via the Models view). Written to a temp file so benchmark.py sees one
|
||||
# unified list; cleaned up in the finally block.
|
||||
effective_models_file = bench_models
|
||||
_tmp_models_path: str | None = None
|
||||
|
||||
if model_ids and bench_models and Path(bench_models).exists():
|
||||
requested_ids = set(model_ids.split(","))
|
||||
try:
|
||||
raw_bench = yaml.safe_load(Path(bench_models).read_text(encoding="utf-8")) or {}
|
||||
bench_entries: list[dict] = raw_bench.get("models", []) or []
|
||||
bench_id_set = {m.get("id", "") for m in bench_entries if isinstance(m, dict)}
|
||||
missing_ids = requested_ids - bench_id_set
|
||||
if missing_ids:
|
||||
from app.models import list_installed
|
||||
installed_map = {
|
||||
m["model_id"]: m
|
||||
for m in list_installed()
|
||||
if m.get("model_id") and m.get("service") in _BENCH_SERVICES
|
||||
}
|
||||
extra: list[dict] = []
|
||||
for mid in missing_ids:
|
||||
if mid in installed_map:
|
||||
inst = installed_map[mid]
|
||||
entry: dict[str, Any] = {
|
||||
"id": mid,
|
||||
"name": mid.split("/", 1)[-1] if "/" in mid else mid,
|
||||
"service": inst.get("service", "cf-text"),
|
||||
"vram_estimate_mb": inst.get("vram_mb") or 0,
|
||||
"tags": [inst.get("role", "generator")],
|
||||
"temperature": 0.0,
|
||||
}
|
||||
local_path = inst.get("path", "") or inst.get("local_path", "")
|
||||
if local_path:
|
||||
entry["model_path"] = local_path
|
||||
extra.append(entry)
|
||||
if extra:
|
||||
merged = {"models": bench_entries + extra}
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False,
|
||||
prefix="avocet_bench_models_",
|
||||
)
|
||||
yaml.dump(merged, tf)
|
||||
tf.close()
|
||||
_tmp_models_path = tf.name
|
||||
effective_models_file = _tmp_models_path
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into temp bench file: %s", exc)
|
||||
|
||||
cmd = [
|
||||
python_bin,
|
||||
bench_script,
|
||||
"--tasks", bench_tasks,
|
||||
"--models", effective_models_file,
|
||||
"--output", results_dir,
|
||||
]
|
||||
|
||||
if task_ids:
|
||||
cmd.extend(["--filter-tasks"] + task_ids.split(","))
|
||||
if model_ids:
|
||||
cmd.extend(["--filter-models"] + model_ids.split(","))
|
||||
if model_tags:
|
||||
cmd.extend(["--filter-tags"] + model_tags.split(","))
|
||||
|
||||
# query param overrides config, config overrides env var (already resolved by _load_cforch_config)
|
||||
effective_coordinator = coordinator_url if coordinator_url else cfg_coordinator
|
||||
effective_ollama = ollama_url if ollama_url else cfg_ollama
|
||||
if effective_coordinator:
|
||||
cmd.extend(["--coordinator", effective_coordinator])
|
||||
if effective_ollama:
|
||||
cmd.extend(["--ollama-url", effective_ollama])
|
||||
effective_judge = judge_url if judge_url else cfg_judge_url
|
||||
if effective_judge:
|
||||
cmd.extend(["--judge-url", effective_judge])
|
||||
if judge_backend and judge_backend != "chat":
|
||||
cmd.extend(["--judge-backend", judge_backend])
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
if node_ids:
|
||||
cmd.extend(["--nodes"] + node_ids.split(","))
|
||||
|
||||
# Pass license key as env var so subprocess can authenticate with cf-orch
|
||||
proc_env = {**os.environ}
|
||||
if cfg_license_key:
|
||||
proc_env["CF_LICENSE_KEY"] = cfg_license_key
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=proc_env,
|
||||
)
|
||||
_bench_proc = proc
|
||||
_IDLE_TIMEOUT_S = 120 # kill if no output for 2 minutes (node crash)
|
||||
try:
|
||||
while True:
|
||||
ready = _select.select([proc.stdout], [], [], _IDLE_TIMEOUT_S)
|
||||
if not ready[0]:
|
||||
# No output for IDLE_TIMEOUT_S — node likely crashed
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
msg = f"Benchmark timed out — no output for {_IDLE_TIMEOUT_S}s (cluster node may have crashed)"
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': msg})}\n\n"
|
||||
break
|
||||
line = proc.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
line = _strip_ansi(line.rstrip())
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
summary_path = _find_latest_summary(results_dir)
|
||||
if summary_path is not None:
|
||||
try:
|
||||
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'summary': summary})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read summary.json: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
if _tmp_models_path:
|
||||
try:
|
||||
os.unlink(_tmp_models_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /config ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config")
|
||||
def get_cforch_config() -> dict:
|
||||
"""Return resolved cf-orch connection config (env vars merged with yaml).
|
||||
|
||||
Redacts license_key — only returns whether it is set, not the value.
|
||||
Used by the Settings UI to show current connection state.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
return {
|
||||
"coordinator_url": cfg.get("coordinator_url", ""),
|
||||
"ollama_url": cfg.get("ollama_url", ""),
|
||||
"ollama_model": cfg.get("ollama_model", ""),
|
||||
"judge_url": cfg.get("judge_url", ""),
|
||||
"license_key_set": bool(cfg.get("license_key", "")),
|
||||
"source": "env" if not _config_file().exists() else "yaml+env",
|
||||
}
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def get_results() -> dict:
|
||||
"""Return the latest benchmark summary.json from results_dir."""
|
||||
cfg = _load_cforch_config()
|
||||
results_dir = cfg.get("results_dir", "")
|
||||
summary_path = _find_latest_summary(results_dir)
|
||||
if summary_path is None:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read summary.json: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_benchmark() -> dict:
|
||||
"""Kill the running benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate benchmark process: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
# ── Coordinator proxy helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _coordinator_url() -> str:
|
||||
"""Return coordinator base URL from config, or raise 503 if not configured."""
|
||||
url = _load_cforch_config().get("coordinator_url", "").rstrip("/")
|
||||
if not url:
|
||||
raise HTTPException(503, "cf-orch coordinator_url not configured")
|
||||
return url
|
||||
|
||||
|
||||
def _coordinator_get(path: str) -> Any:
|
||||
"""GET from coordinator, return parsed JSON body. Raises HTTPException on error."""
|
||||
import httpx as _httpx
|
||||
try:
|
||||
resp = _httpx.get(f"{_coordinator_url()}{path}", timeout=10.0)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _coordinator_post(path: str, body: dict) -> Any:
|
||||
import httpx as _httpx
|
||||
try:
|
||||
async with _httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(f"{_coordinator_url()}{path}", json=body)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _coordinator_delete(path: str) -> Any:
|
||||
import httpx as _httpx
|
||||
try:
|
||||
async with _httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.delete(f"{_coordinator_url()}{path}")
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
# ── GET /assignments/deployment-status ───────────────────────────────────────
|
||||
|
||||
@router.get("/assignments/deployment-status")
|
||||
def get_deployment_status() -> Any:
|
||||
return _coordinator_get("/api/assignments/deployment-status")
|
||||
|
||||
|
||||
# ── /assignments ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/assignments")
|
||||
def list_assignments() -> Any:
|
||||
return _coordinator_get("/api/assignments")
|
||||
|
||||
|
||||
class AssignmentBody(BaseModel):
|
||||
product: str
|
||||
task: str
|
||||
model_id: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
@router.post("/assignments")
|
||||
async def upsert_assignment(body: AssignmentBody) -> Any:
|
||||
return await _coordinator_post("/api/assignments", body.model_dump())
|
||||
|
||||
|
||||
@router.delete("/assignments/{product}/{task}")
|
||||
async def delete_assignment(product: str, task: str) -> Any:
|
||||
return await _coordinator_delete(f"/api/assignments/{urllib.parse.quote(product, safe='')}/{urllib.parse.quote(task, safe='')}")
|
||||
|
||||
|
||||
# ── /model-registry ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/model-registry")
|
||||
def list_model_registry() -> Any:
|
||||
return _coordinator_get("/api/model-registry")
|
||||
|
||||
|
||||
class ModelRegistryBody(BaseModel):
|
||||
model_id: str
|
||||
service_type: str
|
||||
vram_mb: int
|
||||
description: str = ""
|
||||
hf_repo: str = ""
|
||||
alias: str = ""
|
||||
|
||||
|
||||
@router.post("/model-registry")
|
||||
async def upsert_model_registry(body: ModelRegistryBody) -> Any:
|
||||
return await _coordinator_post("/api/model-registry", body.model_dump())
|
||||
|
||||
|
||||
@router.delete("/model-registry/{model_id:path}")
|
||||
async def delete_model_registry(model_id: str) -> Any:
|
||||
return await _coordinator_delete(f"/api/model-registry/{urllib.parse.quote(model_id, safe='')}")
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
"""
|
||||
Avocet cloud session — thin wrapper around cf_core.cloud_session.
|
||||
|
||||
Usage in FastAPI routes:
|
||||
|
||||
from app.cloud_session import get_session, require_tier, CloudUser
|
||||
from fastapi import Depends
|
||||
|
||||
@router.get("/api/imitate")
|
||||
def imitate(session: CloudUser = Depends(get_session)):
|
||||
# session.user_id — Directus UUID (cloud) or "local" (self-hosted)
|
||||
# session.tier — free | paid | premium | ultra | local
|
||||
# session.has_byok — True if user has a configured LLM backend
|
||||
...
|
||||
|
||||
@router.post("/api/custom-models")
|
||||
def list_custom_models(session: CloudUser = Depends(require_tier("premium"))):
|
||||
...
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.cloud_session import CloudSessionFactory, CloudUser, detect_byok
|
||||
|
||||
__all__ = ["CloudUser", "get_session", "require_tier"]
|
||||
|
||||
_factory = CloudSessionFactory(
|
||||
product="avocet",
|
||||
byok_detector=detect_byok,
|
||||
)
|
||||
|
||||
get_session = _factory.dependency()
|
||||
require_tier = _factory.require_tier
|
||||
282
app/dashboard.py
282
app/dashboard.py
|
|
@ -1,282 +0,0 @@
|
|||
"""Avocet -- dashboard aggregate API.
|
||||
|
||||
GET /api/dashboard returns the current flywheel state:
|
||||
labeled_since_last_eval -- items labeled after the most recent bench run
|
||||
last_eval_timestamp -- ISO timestamp of newest bench_results summary
|
||||
last_eval_best_score -- best macro_f1 from that summary
|
||||
active_jobs -- jobs with status queued or running
|
||||
corrections_pending -- sft_candidates with status=needs_review
|
||||
corrections_export_ready -- approved sft candidates with non-blank correction
|
||||
recent_bench_runs -- most-recent timestamp + score per bench type
|
||||
signals -- computed booleans for UI nudge indicators
|
||||
|
||||
Thresholds in label_tool.yaml pipeline: section:
|
||||
pipeline:
|
||||
data_eval_threshold: 50 # labeled items since last bench to trigger nudge
|
||||
eval_train_threshold: 0.05 # improvement delta needed before retraining (future)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DEFAULT_DATA_EVAL_THRESHOLD = 50
|
||||
_DEFAULT_EVAL_TRAIN_THRESHOLD = 0.05
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _load_thresholds() -> tuple[int, float]:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
pipeline = raw.get("pipeline", {}) or {}
|
||||
return (
|
||||
int(pipeline.get("data_eval_threshold", _DEFAULT_DATA_EVAL_THRESHOLD)),
|
||||
float(pipeline.get("eval_train_threshold", _DEFAULT_EVAL_TRAIN_THRESHOLD)),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read pipeline thresholds: %s", exc)
|
||||
return _DEFAULT_DATA_EVAL_THRESHOLD, _DEFAULT_EVAL_TRAIN_THRESHOLD
|
||||
|
||||
def _load_score_records() -> list[dict]:
|
||||
path = _DATA_DIR / "email_score.jsonl"
|
||||
if not path.exists():
|
||||
return []
|
||||
records = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return records
|
||||
|
||||
def _find_latest_classifier_bench(results_dir_override: str = "") -> tuple[str | None, float | None]:
|
||||
"""Return (iso_timestamp, best_macro_f1) from the newest bench_results summary.
|
||||
|
||||
Checks results_dir from cforch config if set, then falls back to
|
||||
_ROOT/bench_results/. Returns (None, None) if no results exist.
|
||||
"""
|
||||
candidates = []
|
||||
if results_dir_override:
|
||||
candidates.append(Path(results_dir_override))
|
||||
else:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
candidates.append(Path(rd))
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
|
||||
candidates.append(_ROOT / "bench_results")
|
||||
|
||||
for rdir in candidates:
|
||||
if not rdir.exists():
|
||||
continue
|
||||
subdirs = sorted([d for d in rdir.iterdir() if d.is_dir()], key=lambda d: d.name)
|
||||
for subdir in reversed(subdirs):
|
||||
summary = subdir / "summary.json"
|
||||
if summary.exists():
|
||||
try:
|
||||
data = json.loads(summary.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
continue # cforch LLM-bench summaries are lists; skip
|
||||
ts = data.get("timestamp") or subdir.name
|
||||
score = data.get("best_macro_f1") or data.get("macro_f1")
|
||||
return ts, (float(score) if isinstance(score, (int, float)) else None)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse summary.json at %s: %s", summary, exc)
|
||||
return None, None
|
||||
|
||||
# Keep old name as alias so existing callers in tests still work.
|
||||
_find_latest_eval = _find_latest_classifier_bench
|
||||
|
||||
|
||||
def _count_corrections() -> tuple[int, int]:
|
||||
"""Return (pending_count, export_ready_count)."""
|
||||
pending = 0
|
||||
export_ready = 0
|
||||
candidates_path = _DATA_DIR / "sft_candidates.jsonl"
|
||||
approved_path = _DATA_DIR / "sft_approved.jsonl"
|
||||
if candidates_path.exists():
|
||||
for line in candidates_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if r.get("status") == "needs_review":
|
||||
pending += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if approved_path.exists():
|
||||
for line in approved_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if (r.get("status") == "approved"
|
||||
and r.get("corrected_response")
|
||||
and str(r["corrected_response"]).strip()):
|
||||
export_ready += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return pending, export_ready
|
||||
|
||||
def _get_active_jobs() -> list[dict]:
|
||||
"""Query train SQLite DB for queued/running jobs. Returns [] if DB absent."""
|
||||
try:
|
||||
from app.train.train import _DB_PATH, _db, _init_db
|
||||
if not _DB_PATH.exists():
|
||||
return []
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
|
||||
).fetchall()
|
||||
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to query train jobs DB: %s", exc)
|
||||
return []
|
||||
|
||||
def _count_labeled_since(since_ts: str | None) -> int:
|
||||
records = _load_score_records()
|
||||
if since_ts is None:
|
||||
return len(records)
|
||||
return sum(1 for r in records if r.get("labeled_at", "") > since_ts)
|
||||
|
||||
|
||||
def _get_recent_bench_runs() -> dict:
|
||||
"""Return most-recent run summary for each bench type.
|
||||
|
||||
Each entry: {"timestamp": str|None, "metric": str|None, "score": float|None}
|
||||
"""
|
||||
runs: dict[str, dict] = {
|
||||
"classifier": {"timestamp": None, "metric": "macro_f1", "score": None},
|
||||
"llm": {"timestamp": None, "metric": None, "score": None},
|
||||
"style": {"timestamp": None, "metric": None, "score": None},
|
||||
"plans": {"timestamp": None, "metric": "avg_score", "score": None},
|
||||
}
|
||||
|
||||
# ── Classifier: bench_results/<run>/summary.json ──────────────────────
|
||||
clf_ts, clf_score = _find_latest_classifier_bench()
|
||||
if clf_ts:
|
||||
runs["classifier"]["timestamp"] = clf_ts
|
||||
runs["classifier"]["score"] = clf_score
|
||||
|
||||
# ── LLM bench + Style: benchmark_results/ ─────────────────────────────
|
||||
f = _config_file()
|
||||
bench_dir: Path | None = None
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
bench_dir = Path(rd)
|
||||
except Exception:
|
||||
pass
|
||||
if bench_dir is None:
|
||||
bench_dir = _ROOT / "benchmark_results"
|
||||
|
||||
if bench_dir.exists():
|
||||
llm_files = sorted(
|
||||
[p for p in bench_dir.glob("*.json") if not p.name.startswith("style_")],
|
||||
key=lambda p: p.stat().st_mtime, reverse=True,
|
||||
)
|
||||
if llm_files:
|
||||
try:
|
||||
data = json.loads(llm_files[0].read_text(encoding="utf-8"))
|
||||
runs["llm"]["timestamp"] = data.get("timestamp") or llm_files[0].stem
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
style_files = sorted(bench_dir.glob("style_*.json"), reverse=True)
|
||||
if style_files:
|
||||
try:
|
||||
data = json.loads(style_files[0].read_text(encoding="utf-8"))
|
||||
if isinstance(data, list) and data:
|
||||
runs["style"]["timestamp"] = data[0].get("timestamp") or style_files[0].stem
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Plans bench: data/plans_bench_results/plans_*.json ────────────────
|
||||
plans_dir = _DATA_DIR / "plans_bench_results"
|
||||
if plans_dir.exists():
|
||||
plans_files = sorted(plans_dir.glob("plans_*.json"), reverse=True)
|
||||
if plans_files:
|
||||
run_id = plans_files[0].stem
|
||||
try:
|
||||
d: dict = json.loads(plans_files[0].read_text(encoding="utf-8"))
|
||||
all_scores = [
|
||||
r["total_score"]
|
||||
for results in d.values()
|
||||
for r in results
|
||||
if isinstance(r, dict) and not r.get("error")
|
||||
]
|
||||
avg = round(sum(all_scores) / len(all_scores), 3) if all_scores else None
|
||||
try:
|
||||
date_part = run_id.removeprefix("plans_")
|
||||
date, time_part = date_part.split("_")
|
||||
ts_display = f"{date} {time_part[:2]}:{time_part[2:4]}"
|
||||
except Exception:
|
||||
ts_display = run_id
|
||||
runs["plans"]["timestamp"] = ts_display
|
||||
runs["plans"]["score"] = avg
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/dashboard")
|
||||
def get_dashboard() -> dict:
|
||||
data_threshold, _train_threshold = _load_thresholds()
|
||||
last_ts, last_score = _find_latest_classifier_bench()
|
||||
labeled_since = _count_labeled_since(last_ts)
|
||||
corrections_pending, corrections_export_ready = _count_corrections()
|
||||
active_jobs = _get_active_jobs()
|
||||
recent_bench = _get_recent_bench_runs()
|
||||
return {
|
||||
"labeled_since_last_eval": labeled_since,
|
||||
"last_eval_timestamp": last_ts,
|
||||
"last_eval_best_score": last_score,
|
||||
"active_jobs": active_jobs,
|
||||
"corrections_pending": corrections_pending,
|
||||
"corrections_export_ready": corrections_export_ready,
|
||||
"recent_bench_runs": recent_bench,
|
||||
"signals": {
|
||||
"data_to_eval": labeled_since >= data_threshold,
|
||||
"eval_to_train": False, # future: implement delta-F1 comparison
|
||||
"train_to_fleet": False, # future: implement fleet sync signal
|
||||
},
|
||||
}
|
||||
|
|
@ -1,393 +0,0 @@
|
|||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
|
||||
|
||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||
testability pattern as api.py -- override them via set_data_dir() and
|
||||
set_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams ---------------------------------------------------------
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Internal helpers ----------------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||
if d:
|
||||
return Path(d)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# -- GET /runs -----------------------------------------------------------------
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# -- POST /import --------------------------------------------------------------
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _DATA_DIR)
|
||||
|
||||
|
||||
# -- GET /queue ----------------------------------------------------------------
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
|
||||
# -- POST /submit --------------------------------------------------------------
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /undo ----------------------------------------------------------------
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- GET /export ---------------------------------------------------------------
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- GET /stats ----------------------------------------------------------------
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# -- GET /config ---------------------------------------------------------------
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /ingest --------------------------------------------------------------
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
source: str # e.g. "peregrine", "kiwi"
|
||||
task_type: str # e.g. "email_classification", "recipe_suggestion"
|
||||
prompt: str # the prompt that was sent to the LLM
|
||||
response: str # the LLM's original response
|
||||
correction: str # the human-corrected response
|
||||
label: str | None = None # optional label/category
|
||||
|
||||
|
||||
@router.post("/ingest")
|
||||
def post_ingest(
|
||||
req: IngestRequest,
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""Ingest a correction from a sibling CF product.
|
||||
|
||||
Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>
|
||||
|
||||
Creates a sft_candidates record with status='approved' (pre-approved by
|
||||
the calling product -- human review already happened upstream). Also writes
|
||||
to sft_approved.jsonl so it is immediately included in export counts.
|
||||
|
||||
Returns {"ok": True, "id": "<uuid>"}.
|
||||
"""
|
||||
expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
|
||||
if not expected_secret:
|
||||
raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")
|
||||
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(401, "Missing or malformed Authorization header")
|
||||
|
||||
token = authorization.removeprefix("Bearer ").strip()
|
||||
if token != expected_secret:
|
||||
raise HTTPException(403, "Invalid ingestion secret")
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
record = {
|
||||
"id": record_id,
|
||||
"source": req.source,
|
||||
"task_type": req.task_type,
|
||||
"status": "approved",
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": req.response,
|
||||
"corrected_response": req.correction,
|
||||
"label": req.label,
|
||||
"timestamp": now,
|
||||
"benchmark_run_id": None,
|
||||
}
|
||||
append_jsonl(_candidates_file(), record)
|
||||
append_jsonl(_approved_file(), record)
|
||||
return {"ok": True, "id": record_id}
|
||||
|
|
@ -1,243 +0,0 @@
|
|||
"""Avocet -- IMAP fetch utilities and fetch API routes.
|
||||
|
||||
All IMAP helper functions (from app/imap_fetch.py) plus the
|
||||
/api/accounts/test and /api/fetch/stream endpoints.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import extract_body, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
|
||||
def _get_config_accounts() -> list[dict]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return []
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return raw.get("accounts", [])
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
|
||||
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
|
||||
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
|
||||
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
|
||||
host = acc.get("host", "")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc.get("username", "")
|
||||
password = acc.get("password", "")
|
||||
folder = acc.get("folder", "INBOX")
|
||||
if not host or not username or not password:
|
||||
return False, "Host, username, and password are all required.", None
|
||||
try:
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
_, data = conn.select(folder, readonly=True)
|
||||
count_raw = data[0].decode() if data and data[0] else "0"
|
||||
count = int(count_raw) if count_raw.isdigit() else 0
|
||||
conn.logout()
|
||||
return True, f"Connected — {count:,} message(s) in {folder}.", count
|
||||
except Exception as exc:
|
||||
return False, str(exc), None
|
||||
|
||||
|
||||
def fetch_account_stream(
|
||||
acc: dict,
|
||||
days_back: int,
|
||||
limit: int,
|
||||
known_keys: set[str],
|
||||
) -> Iterator[dict]:
|
||||
"""Generator — yields progress dicts while fetching emails via IMAP.
|
||||
|
||||
Mutates `known_keys` in place for cross-account dedup within one fetch session.
|
||||
|
||||
Yields event dicts with "type" key:
|
||||
{"type": "start", "account": str, "total_uids": int}
|
||||
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
|
||||
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
|
||||
"""
|
||||
name = acc.get("name", acc.get("username", "?"))
|
||||
host = acc.get("host", "imap.gmail.com")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc["username"]
|
||||
password = acc["password"]
|
||||
folder = acc.get("folder", "INBOX")
|
||||
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
conn.select(folder, readonly=True)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
uids = list(seen_uids.keys())[: limit * 3]
|
||||
yield {"type": "start", "account": name, "total_uids": len(uids)}
|
||||
|
||||
emails: list[dict] = []
|
||||
skipped = 0
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if i % 5 == 0:
|
||||
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = extract_body(msg)[:800]
|
||||
entry = {"subject": subj, "body": body, "from_addr": from_addr,
|
||||
"date": date, "account": name}
|
||||
k = entry_key(entry)
|
||||
if k not in known_keys:
|
||||
known_keys.add(k)
|
||||
emails.append(entry)
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception:
|
||||
skipped += 1
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
|
||||
"emails": emails}
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
|
||||
account: dict
|
||||
|
||||
|
||||
@router.post("/accounts/test")
|
||||
def test_account_route(req: AccountTestRequest) -> dict:
|
||||
ok, message, count = test_connection(req.account)
|
||||
return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
@router.get("/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
days_back: int = Query(default=90, ge=1, le=365),
|
||||
limit: int = Query(default=150, ge=1, le=1000),
|
||||
mode: str = Query(default="wide"),
|
||||
) -> StreamingResponse:
|
||||
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
|
||||
all_accounts = _get_config_accounts()
|
||||
selected = [a for a in all_accounts if a.get("name") in selected_names]
|
||||
|
||||
def generate():
|
||||
known_keys = {entry_key(x) for x in read_jsonl(_queue_file())}
|
||||
total_added = 0
|
||||
for acc in selected:
|
||||
try:
|
||||
batch_emails: list[dict] = []
|
||||
for event in fetch_account_stream(acc, days_back, limit, known_keys):
|
||||
if event["type"] == "done":
|
||||
batch_emails = event.pop("emails", [])
|
||||
total_added += event["added"]
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
if batch_emails:
|
||||
existing = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), existing + batch_emails)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'account': acc.get('name', '?'), 'message': str(exc)})}\n\n"
|
||||
queue_size = len(read_jsonl(_queue_file()))
|
||||
yield f"data: {json.dumps({'type': 'complete', 'total_added': total_added, 'queue_size': queue_size})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
|
@ -1,729 +0,0 @@
|
|||
"""Avocet — Imitate tab API.
|
||||
|
||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||
pushed into the SFT corrections queue for human review.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||
|
||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_imitate_config() -> dict:
|
||||
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse imitate config %s: %s", f, exc)
|
||||
return {}
|
||||
return raw.get("imitate", {}) or {}
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
|
||||
"""Read cforch section for ollama_url fallback."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
return {}
|
||||
return raw.get("cforch", {}) or {}
|
||||
|
||||
|
||||
def _ollama_url(cfg: dict) -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
|
||||
|
||||
|
||||
def _cforch_url() -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cforch.get("coordinator_url") or "http://localhost:7700"
|
||||
|
||||
|
||||
def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
|
||||
"""Return {model_id, service_type} for a product.task assignment, or None if not found.
|
||||
|
||||
Calls GET coordinator/api/assignments and filters by product+task.
|
||||
The model registry entry is fetched separately to get service_type.
|
||||
Returns None (not raises) — callers emit a 'model_done' error event instead.
|
||||
"""
|
||||
try:
|
||||
asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
|
||||
asgn_resp.raise_for_status()
|
||||
assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
|
||||
match = next(
|
||||
(a for a in assignments if a.get("product") == product and a.get("task") == task),
|
||||
None,
|
||||
)
|
||||
if match is None:
|
||||
return None
|
||||
model_id: str = match.get("model_id", "")
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
# Look up service_type from model registry
|
||||
reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
|
||||
service_type = "cf-text" # sensible default
|
||||
if reg_resp.is_success:
|
||||
models: list[dict] = reg_resp.json().get("models", []) or []
|
||||
reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
|
||||
if reg_entry:
|
||||
service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
|
||||
|
||||
return {"model_id": model_id, "service_type": service_type}
|
||||
except Exception as exc:
|
||||
logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
||||
"""Fetch the live cf-text catalog from cf-orch.
|
||||
|
||||
Filters out proxy entries (ollama://, vllm://, http://) — those models are
|
||||
served by their own services and should not be allocated via cf-text.
|
||||
Returns only models with real file-system paths that cf-text can load directly.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/catalog",
|
||||
params={"node_id": "heimdall"},
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
result = []
|
||||
for model_id, entry in raw.items():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
path = entry.get("path", "")
|
||||
# Skip proxy entries — they're routed through other services
|
||||
if "://" in path:
|
||||
continue
|
||||
result.append({
|
||||
"id": model_id,
|
||||
"vram_mb": entry.get("vram_mb", 0),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch cf-orch catalog: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
|
||||
"""Fetch JSON from url; raise URLError on failure."""
|
||||
req = Request(url, headers={"Accept": "application/json"})
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
|
||||
|
||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
|
||||
"""Return True if the product's health endpoint responds OK."""
|
||||
try:
|
||||
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
|
||||
return bool(data)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_sample(
|
||||
raw: Any,
|
||||
text_fields: list[str],
|
||||
sample_index: int = 0,
|
||||
sample_key: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Pull one item from a list or dict response and extract text_fields.
|
||||
|
||||
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
||||
Falls back to a set of conventional envelope keys if sample_key is absent.
|
||||
"""
|
||||
item: dict[str, Any]
|
||||
if isinstance(raw, list):
|
||||
if not raw:
|
||||
return {}
|
||||
item = raw[min(sample_index, len(raw) - 1)]
|
||||
elif isinstance(raw, dict):
|
||||
# Use declared sample_key first, then fall back to conventional names.
|
||||
_ENVELOPE_KEYS = (
|
||||
"samples", "items", "results", "data", "jobs", "listings",
|
||||
"pantry", "saved_searches", "entries", "calls", "records",
|
||||
)
|
||||
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
||||
for key in search_keys:
|
||||
if key in raw and isinstance(raw[key], list):
|
||||
lst = raw[key]
|
||||
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
||||
break
|
||||
else:
|
||||
item = raw
|
||||
else:
|
||||
return {}
|
||||
|
||||
parts = []
|
||||
for field in text_fields:
|
||||
val = item.get(field)
|
||||
if val and str(val).strip():
|
||||
parts.append(f"**{field}**: {val}")
|
||||
return {"item": item, "text": "\n\n".join(parts)}
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _fetch_image_b64(image_url: str) -> str:
|
||||
"""Download an image URL and return it as a base64 string for ollama.
|
||||
|
||||
Returns empty string on any failure — a missing image is non-fatal;
|
||||
the model will still run against the text prompt alone.
|
||||
"""
|
||||
try:
|
||||
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
|
||||
with urlopen(req, timeout=10) as resp:
|
||||
return base64.b64encode(resp.read()).decode("ascii")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch image %s: %s", image_url, exc)
|
||||
return ""
|
||||
|
||||
|
||||
def _run_ollama_streaming(
|
||||
ollama_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
system: str = "",
|
||||
images: list[str] | None = None,
|
||||
) -> tuple[str, int]:
|
||||
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
|
||||
|
||||
Blocks until the model finishes; yields nothing — streaming is handled by
|
||||
the SSE generator in run_imitate().
|
||||
|
||||
system: optional system prompt passed as a separate field to ollama.
|
||||
images: list of base64-encoded image strings (vision models only).
|
||||
"""
|
||||
url = f"{ollama_base.rstrip('/')}/api/generate"
|
||||
body: dict = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
if system:
|
||||
body["system"] = system
|
||||
if images:
|
||||
body["images"] = images
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
req = Request(url, data=payload, method="POST",
|
||||
headers={"Content-Type": "application/json"})
|
||||
t0 = time.time()
|
||||
try:
|
||||
with urlopen(req, timeout=120) as resp:
|
||||
body = json.loads(resp.read().decode("utf-8"))
|
||||
elapsed = int((time.time() - t0) * 1000)
|
||||
return body.get("response", ""), elapsed
|
||||
except Exception as exc:
|
||||
elapsed = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _run_cftext(
|
||||
cforch_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str,
|
||||
temperature: float,
|
||||
startup_timeout_s: float = 180.0,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, int, bool]:
|
||||
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
|
||||
|
||||
Raises RuntimeError on allocation failure or generation error.
|
||||
cold_started=True means the service was launched from scratch (caller may log this).
|
||||
|
||||
Cold-start detection uses coordinator state signals (running/stopped) rather than
|
||||
polling the service health endpoint — this fails fast on model load errors instead
|
||||
of waiting out the full timeout.
|
||||
"""
|
||||
# Allocate
|
||||
alloc_resp = httpx.post(
|
||||
f"{cforch_base}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "imitate",
|
||||
**({"user_id": user_id} if user_id else {}),
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
alloc_resp.raise_for_status()
|
||||
data = alloc_resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
cold_started = data.get("started", False) and not data.get("warm", True)
|
||||
|
||||
# Wait for ready using coordinator state signals
|
||||
if cold_started:
|
||||
deadline = time.monotonic() + startup_timeout_s
|
||||
probe_misses = 0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
status = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
|
||||
)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
break
|
||||
elif state == "stopped":
|
||||
if allocation_id:
|
||||
httpx.delete(
|
||||
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
|
||||
timeout=5.0,
|
||||
)
|
||||
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
|
||||
else:
|
||||
probe_misses += 1
|
||||
if probe_misses >= 6:
|
||||
# Coordinator hasn't registered instance yet — fall back to health poll
|
||||
try:
|
||||
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2.0)
|
||||
else:
|
||||
if allocation_id:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
|
||||
|
||||
# Generate
|
||||
messages: list[dict] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
gen_resp = httpx.post(
|
||||
f"{service_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": 300,
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
gen_resp.raise_for_status()
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
content = gen_resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms, cold_started
|
||||
except Exception as exc:
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
finally:
|
||||
if allocation_id:
|
||||
try:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/products")
|
||||
def get_products() -> dict:
|
||||
"""List configured CF products with live online status."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
products = []
|
||||
for p in products_raw:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
base_url = p.get("base_url", "")
|
||||
products.append({
|
||||
"id": p.get("id", ""),
|
||||
"name": p.get("name", ""),
|
||||
"icon": p.get("icon", "📦"),
|
||||
"description": p.get("description", ""),
|
||||
"base_url": base_url,
|
||||
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
||||
})
|
||||
return {"products": products}
|
||||
|
||||
|
||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||
|
||||
@router.get("/products/{product_id}/sample")
|
||||
def get_sample(product_id: str, index: int = 0) -> dict:
|
||||
"""Fetch a real sample from the given product's API."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
|
||||
product: dict | None = None
|
||||
for p in products_raw:
|
||||
if isinstance(p, dict) and p.get("id") == product_id:
|
||||
product = p
|
||||
break
|
||||
|
||||
if product is None:
|
||||
raise HTTPException(404, f"Product '{product_id}' not in config")
|
||||
|
||||
base_url = product.get("base_url", "").rstrip("/")
|
||||
endpoint = product.get("sample_endpoint", "")
|
||||
if not base_url or not endpoint:
|
||||
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
||||
|
||||
url = f"{base_url}{endpoint}"
|
||||
try:
|
||||
raw = _http_get_json(url, timeout=5)
|
||||
except URLError as exc:
|
||||
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
||||
|
||||
text_fields = product.get("text_fields", []) or []
|
||||
sample_key = product.get("sample_key") or None
|
||||
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
|
||||
if not extracted:
|
||||
raise HTTPException(404, "No sample items returned by product API")
|
||||
|
||||
prompt_template = product.get("prompt_template", "{text}")
|
||||
prompt = prompt_template.replace("{text}", extracted["text"])
|
||||
# Also substitute any {field_name} placeholders from the raw item fields.
|
||||
item = extracted.get("item", {})
|
||||
for field, val in item.items():
|
||||
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
|
||||
|
||||
# Expose system_prompt and image_url if the product API returns them.
|
||||
# system_prompt: Peregrine, Snipe (vision analysis instructions)
|
||||
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
|
||||
item = extracted.get("item", {})
|
||||
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
|
||||
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
|
||||
|
||||
return {
|
||||
"product_id": product_id,
|
||||
"sample_index": index,
|
||||
"text": extracted["text"],
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt,
|
||||
"image_url": image_url,
|
||||
"raw_item": item,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
|
||||
def get_catalog() -> dict:
|
||||
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
||||
models = _cforch_catalog(_cforch_url())
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
||||
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
||||
try:
|
||||
from app.cloud_session import get_session
|
||||
return get_session(request, response)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_imitate(
|
||||
prompt: str = "",
|
||||
model_ids: str = "", # comma-separated ollama model IDs
|
||||
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
||||
task_ids: str = "", # comma-separated "product/task" strings — resolved via assignments
|
||||
temperature: float = 0.7,
|
||||
product_id: str = "",
|
||||
system: str = "", # optional system prompt
|
||||
image_url: str = "", # optional image URL for vision models
|
||||
session: "Any" = Depends(_get_imitate_session),
|
||||
) -> StreamingResponse:
|
||||
"""Run a prompt through selected models and stream results as SSE.
|
||||
|
||||
Models can be selected three ways (combinable):
|
||||
- model_ids: explicit ollama model IDs
|
||||
- cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
|
||||
- task_ids: "product/task" strings resolved via the coordinator assignments table
|
||||
|
||||
If image_url is provided, the image is downloaded once and passed to every
|
||||
model as a base64-encoded blob — allowing vision-capable local models to
|
||||
evaluate listing photos the same way Snipe's background task pipeline does.
|
||||
"""
|
||||
|
||||
if not prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
|
||||
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
||||
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
||||
raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
|
||||
|
||||
# Resolve task assignments to concrete model IDs, routing to the right service.
|
||||
# Models that fail to resolve emit an error event at run time (non-fatal).
|
||||
if raw_task_ids:
|
||||
cforch_base = _cforch_url()
|
||||
for task_spec in raw_task_ids:
|
||||
parts = task_spec.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
|
||||
continue
|
||||
product_name, task_name = parts
|
||||
resolved = _resolve_task_model(cforch_base, product_name, task_name)
|
||||
if resolved is None:
|
||||
logger.warning("No assignment found for task %r", task_spec)
|
||||
# Emit error at stream time via a sentinel in cftext_ids with a special label.
|
||||
# We instead store the failed task_spec to emit a model_done error.
|
||||
cftext_ids.append(f"__task_unresolved__:{task_spec}")
|
||||
continue
|
||||
mid = resolved["model_id"]
|
||||
svc = resolved["service_type"]
|
||||
if svc == "ollama":
|
||||
if mid not in ollama_ids:
|
||||
ollama_ids.append(mid)
|
||||
else:
|
||||
# cf-text, vllm, and any other cf-orch-managed service
|
||||
if mid not in cftext_ids:
|
||||
cftext_ids.append(mid)
|
||||
|
||||
if not ollama_ids and not cftext_ids:
|
||||
raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
|
||||
|
||||
cfg = _load_imitate_config()
|
||||
ollama_base = _ollama_url(cfg)
|
||||
cforch_base = _cforch_url()
|
||||
system_ctx = system.strip() or ""
|
||||
total_models = len(ollama_ids) + len(cftext_ids)
|
||||
|
||||
# Download image once before streaming — shared across ollama vision models
|
||||
images: list[str] = []
|
||||
if image_url.strip():
|
||||
b64 = _fetch_image_b64(image_url.strip())
|
||||
if b64:
|
||||
images = [b64]
|
||||
|
||||
def generate():
|
||||
results: list[dict] = []
|
||||
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
|
||||
|
||||
# Ollama models
|
||||
for model_id in ollama_ids:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
|
||||
try:
|
||||
response, elapsed_ms = _run_ollama_streaming(
|
||||
ollama_base, model_id, prompt, temperature,
|
||||
system=system_ctx, images=images or None,
|
||||
)
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
||||
# Partition the list: real cf-text IDs vs unresolved-task sentinels.
|
||||
cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
|
||||
cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
|
||||
for sentinel in cftext_unresolved:
|
||||
task_spec = sentinel.split(":", 1)[1]
|
||||
result = {
|
||||
"model": task_spec,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": f"No assignment configured for task '{task_spec}'",
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
if cftext_real:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Announce all models upfront so the UI can show loading states immediately
|
||||
for model_id in cftext_real:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
||||
|
||||
_user_id: str | None = getattr(session, "user_id", None)
|
||||
# Only forward real cloud user IDs — skip local/anon sessions
|
||||
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
|
||||
_user_id = None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
|
||||
future_to_model = {
|
||||
pool.submit(
|
||||
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
|
||||
180.0, _user_id,
|
||||
): mid
|
||||
for mid in cftext_real
|
||||
}
|
||||
for future in as_completed(future_to_model):
|
||||
model_id = future_to_model[future]
|
||||
try:
|
||||
response, elapsed_ms, cold_started = future.result()
|
||||
if cold_started:
|
||||
yield _sse({"type": "model_coldstart", "model": model_id})
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
yield _sse({"type": "complete", "results": results})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
|
||||
class ImitateResult(BaseModel):
|
||||
model: str
|
||||
response: str
|
||||
elapsed_ms: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PushCorrectionsRequest(BaseModel):
|
||||
product_id: str
|
||||
prompt: str
|
||||
results: list[ImitateResult]
|
||||
|
||||
|
||||
@router.post("/push-corrections")
|
||||
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
||||
"""Append imitate results to sft_candidates.jsonl for human review."""
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
if not req.results:
|
||||
raise HTTPException(422, "results list is empty")
|
||||
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
records = []
|
||||
for r in req.results:
|
||||
if r.error or not r.response.strip():
|
||||
continue
|
||||
records.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"source": "imitate",
|
||||
"product_id": req.product_id,
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": r.response,
|
||||
"model_id": r.model,
|
||||
"elapsed_ms": r.elapsed_ms,
|
||||
"status": "pending",
|
||||
"created_at": ts,
|
||||
})
|
||||
|
||||
if not records:
|
||||
raise HTTPException(422, "No non-error results to push")
|
||||
|
||||
dest = _candidates_file()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
for record in records:
|
||||
append_jsonl(dest, record)
|
||||
|
||||
return {"pushed": len(records)}
|
||||
|
|
@ -1,222 +0,0 @@
|
|||
"""Avocet -- label queue API.
|
||||
|
||||
All label/skip/discard/undo/stats/config endpoints.
|
||||
Extracted from app/api.py as part of the v2 domain split.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_last_action: dict | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def reset_last_action() -> None:
|
||||
global _last_action
|
||||
_last_action = None
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
def _score_file() -> Path:
|
||||
return _DATA_DIR / "email_score.jsonl"
|
||||
|
||||
def _discarded_file() -> Path:
|
||||
return _DATA_DIR / "discarded.jsonl"
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
def _normalize(item: dict) -> dict:
|
||||
return {
|
||||
"id": item.get("id") or _item_id(item),
|
||||
"subject": item.get("subject", ""),
|
||||
"body": item.get("body", ""),
|
||||
"from": item.get("from") or item.get("from_addr", ""),
|
||||
"date": item.get("date", ""),
|
||||
"source": item.get("source") or item.get("account", ""),
|
||||
}
|
||||
|
||||
_LABEL_META = [
|
||||
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
|
||||
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
|
||||
{"name": "rejected", "emoji": "❌", "color": "#F44336", "key": "3"},
|
||||
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
|
||||
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
|
||||
{"name": "neutral", "emoji": "⬜", "color": "#607D8B", "key": "6"},
|
||||
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
|
||||
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
|
||||
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
|
||||
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
|
||||
]
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
|
||||
items = read_jsonl(_queue_file())
|
||||
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
|
||||
|
||||
class LabelRequest(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
@router.post("/label")
|
||||
def post_label(req: LabelRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": req.label,
|
||||
"labeled_at": datetime.now(timezone.utc).isoformat()}
|
||||
append_jsonl(_score_file(), record)
|
||||
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "label", "item": match, "label": req.label}
|
||||
return {"ok": True}
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
@router.post("/skip")
|
||||
def post_skip(req: SkipRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
|
||||
write_jsonl(_queue_file(), reordered)
|
||||
_last_action = {"type": "skip", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
class DiscardRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
@router.post("/discard")
|
||||
def post_discard(req: DiscardRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": "__discarded__",
|
||||
"discarded_at": datetime.now(timezone.utc).isoformat()}
|
||||
append_jsonl(_discarded_file(), record)
|
||||
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "discard", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/label/undo")
|
||||
def delete_undo():
|
||||
global _last_action
|
||||
if not _last_action:
|
||||
raise HTTPException(404, "No action to undo")
|
||||
action = _last_action
|
||||
item = action["item"]
|
||||
if action["type"] == "label":
|
||||
records = read_jsonl(_score_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Score file is empty -- cannot undo label")
|
||||
write_jsonl(_score_file(), records[:-1])
|
||||
items = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "discard":
|
||||
records = read_jsonl(_discarded_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Discarded file is empty -- cannot undo discard")
|
||||
write_jsonl(_discarded_file(), records[:-1])
|
||||
items = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "skip":
|
||||
items = read_jsonl(_queue_file())
|
||||
item_id = _normalize(item)["id"]
|
||||
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
|
||||
write_jsonl(_queue_file(), items)
|
||||
_last_action = None
|
||||
return {"undone": {"type": action["type"], "item": _normalize(item)}}
|
||||
|
||||
@router.get("/config/labels")
|
||||
def get_labels():
|
||||
return _LABEL_META
|
||||
|
||||
@router.get("/config")
|
||||
def get_config():
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"accounts": [], "max_per_account": 500}
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
|
||||
|
||||
class ConfigPayload(BaseModel):
|
||||
accounts: list[dict]
|
||||
max_per_account: int = 500
|
||||
|
||||
@router.post("/config")
|
||||
def post_config(payload: ConfigPayload):
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats():
|
||||
records = read_jsonl(_score_file())
|
||||
counts: dict[str, int] = {}
|
||||
for r in records:
|
||||
lbl = r.get("label", "")
|
||||
if lbl:
|
||||
counts[lbl] = counts.get(lbl, 0) + 1
|
||||
benchmark_results: dict = {}
|
||||
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"total": len(records),
|
||||
"counts": counts,
|
||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||
"benchmark_results": benchmark_results,
|
||||
}
|
||||
|
||||
@router.get("/stats/download")
|
||||
def download_stats():
|
||||
if not _score_file().exists():
|
||||
raise HTTPException(404, "No score file")
|
||||
return FileResponse(
|
||||
str(_score_file()),
|
||||
filename="email_score.jsonl",
|
||||
media_type="application/jsonlines",
|
||||
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
|
||||
)
|
||||
|
|
@ -1,462 +0,0 @@
|
|||
"""Avocet — Log Corpus receiver and labeling API.
|
||||
|
||||
Receives push batches from consented Turnstone nodes, stores entries for labeling,
|
||||
and exports labeled data as JSONL for the logreading fine-tune pipeline.
|
||||
|
||||
DB: data/corpus.db (separate from train_jobs.db — different lifecycle)
|
||||
Auth: Bearer token validated against corpus_sources table (seeded from label_tool.yaml).
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/corpus".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DB_PATH: Path = _ROOT / "data" / "corpus.db"
|
||||
|
||||
_PIPELINE_SOURCE_HOST = "pipeline_scrape"
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS corpus_sources (
|
||||
token TEXT PRIMARY KEY,
|
||||
source_host TEXT NOT NULL,
|
||||
owner TEXT NOT NULL,
|
||||
consent_date TEXT NOT NULL,
|
||||
consent_method TEXT NOT NULL,
|
||||
active INTEGER NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS corpus_batches (
|
||||
id TEXT PRIMARY KEY,
|
||||
source_host TEXT NOT NULL,
|
||||
batch_type TEXT NOT NULL,
|
||||
received_at TEXT NOT NULL,
|
||||
entry_count INTEGER NOT NULL,
|
||||
watermark_from TEXT,
|
||||
watermark_to TEXT,
|
||||
raw_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS corpus_entries (
|
||||
id TEXT PRIMARY KEY,
|
||||
batch_id TEXT NOT NULL REFERENCES corpus_batches(id),
|
||||
source_host TEXT NOT NULL,
|
||||
origin_entry_id TEXT,
|
||||
timestamp_iso TEXT,
|
||||
severity TEXT,
|
||||
source_id TEXT,
|
||||
text TEXT NOT NULL,
|
||||
matched_patterns TEXT DEFAULT '[]',
|
||||
label_state TEXT NOT NULL DEFAULT 'unlabeled',
|
||||
failure_type TEXT,
|
||||
plain_explanation TEXT,
|
||||
known_pattern TEXT,
|
||||
labeled_at TEXT,
|
||||
labeled_by TEXT DEFAULT 'alan',
|
||||
pii_flagged INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_label_state ON corpus_entries(label_state);
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_source ON corpus_entries(source_host);
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_severity ON corpus_entries(severity);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ingested_pipeline_files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
ingested_at TEXT NOT NULL,
|
||||
entry_count INTEGER NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR, _DB_PATH
|
||||
_DATA_DIR = path
|
||||
_DB_PATH = path / "corpus.db"
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
with _db() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
_seed_sources(conn)
|
||||
|
||||
|
||||
def _pipeline_ingest_dir() -> Path | None:
|
||||
"""Return the configured pipeline log ingest directory, or None if unset."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return None
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return None
|
||||
val = raw.get("corpus", {}).get("pipeline_ingest_dir", "") or ""
|
||||
return Path(val) if val else None
|
||||
|
||||
|
||||
def _load_corpus_config() -> list[dict]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return []
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse corpus config: %s", exc)
|
||||
return []
|
||||
return raw.get("corpus", {}).get("sources", []) or []
|
||||
|
||||
|
||||
def _seed_sources(conn: sqlite3.Connection) -> None:
|
||||
for src in _load_corpus_config():
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_sources (token, source_host, owner, consent_date, consent_method) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(src["token"], src["source_host"], src["owner"],
|
||||
src["consent_date"], src["consent_method"]),
|
||||
)
|
||||
|
||||
|
||||
def _validate_token(token: str, conn: sqlite3.Connection) -> str:
|
||||
"""Return source_host for token, or raise 403."""
|
||||
row = conn.execute(
|
||||
"SELECT source_host FROM corpus_sources WHERE token = ? AND active = 1",
|
||||
(token,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=403, detail="Unknown or revoked consent token")
|
||||
return row["source_host"]
|
||||
|
||||
|
||||
def _extract_bearer(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Bearer token required")
|
||||
return auth.removeprefix("Bearer ").strip()
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
# ── Startup ────────────────────────────────────────────────────────────────────
|
||||
|
||||
_init_db()
|
||||
|
||||
|
||||
# ── POST /api/corpus/log-batch ─────────────────────────────────────────────────
|
||||
|
||||
@router.post("/log-batch")
|
||||
def receive_batch(request: Request, payload: dict) -> dict:
|
||||
"""Accept a push batch from a Turnstone node."""
|
||||
token = _extract_bearer(request)
|
||||
|
||||
batch_type = payload.get("batch_type", "raw_entries")
|
||||
entries_raw = payload.get("entries", [])
|
||||
batch_id = payload.get("batch_id") or str(uuid.uuid4())
|
||||
|
||||
with _db() as conn:
|
||||
source_host = _validate_token(token, conn)
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, "
|
||||
"watermark_from, watermark_to, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(batch_id, source_host, batch_type, _now_iso(), len(entries_raw),
|
||||
str(payload.get("watermark_from", "")),
|
||||
str(payload.get("watermark_to", "")),
|
||||
json.dumps(payload)),
|
||||
)
|
||||
|
||||
stored = 0
|
||||
for entry in entries_raw:
|
||||
text = entry.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_entries "
|
||||
"(id, batch_id, source_host, origin_entry_id, timestamp_iso, severity, "
|
||||
"source_id, text, matched_patterns) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), batch_id, source_host,
|
||||
entry.get("entry_id") or entry.get("id"),
|
||||
entry.get("timestamp_iso"),
|
||||
entry.get("severity"),
|
||||
entry.get("source_id"),
|
||||
text,
|
||||
json.dumps(entry.get("matched_patterns", []))),
|
||||
)
|
||||
stored += 1
|
||||
|
||||
logger.info("Received batch %s from %s: %d/%d entries stored",
|
||||
batch_id, source_host, stored, len(entries_raw))
|
||||
return {"received": True, "batch_id": batch_id, "entries_stored": stored}
|
||||
|
||||
|
||||
# ── GET /api/corpus/entries ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/entries")
|
||||
def list_entries(
|
||||
state: str = "unlabeled",
|
||||
source_host: str | None = None,
|
||||
limit: int = 25,
|
||||
) -> dict:
|
||||
"""Return entries for labeling. Default: unlabeled entries, oldest first."""
|
||||
with _db() as conn:
|
||||
query = "SELECT * FROM corpus_entries WHERE label_state = ?"
|
||||
params: list = [state]
|
||||
if source_host:
|
||||
query += " AND source_host = ?"
|
||||
params.append(source_host)
|
||||
query += " ORDER BY rowid LIMIT ?"
|
||||
params.append(min(limit, 100))
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return {"entries": [dict(r) for r in rows], "count": len(rows)}
|
||||
|
||||
|
||||
# ── POST /api/corpus/entries/{id}/label ───────────────────────────────────────
|
||||
|
||||
@router.post("/entries/{entry_id}/label")
|
||||
def label_entry(entry_id: str, body: dict) -> dict:
|
||||
"""Submit a label for a corpus entry."""
|
||||
failure_type = body.get("failure_type")
|
||||
plain_explanation = body.get("plain_explanation", "").strip()
|
||||
known_pattern = body.get("known_pattern")
|
||||
pii_flagged = int(bool(body.get("pii_flagged", False)))
|
||||
|
||||
if not failure_type:
|
||||
raise HTTPException(status_code=422, detail="failure_type is required")
|
||||
valid_types = {"hardware", "software", "network", "security", "application", "none", "other"}
|
||||
if failure_type not in valid_types:
|
||||
raise HTTPException(status_code=422, detail=f"failure_type must be one of {sorted(valid_types)}")
|
||||
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Entry not found")
|
||||
conn.execute(
|
||||
"UPDATE corpus_entries SET label_state='labeled', failure_type=?, plain_explanation=?, "
|
||||
"known_pattern=?, labeled_at=?, pii_flagged=? WHERE id=?",
|
||||
(failure_type, plain_explanation, known_pattern, _now_iso(), pii_flagged, entry_id),
|
||||
)
|
||||
return {"labeled": True, "entry_id": entry_id}
|
||||
|
||||
|
||||
# ── POST /api/corpus/entries/{id}/skip ────────────────────────────────────────
|
||||
|
||||
@router.post("/entries/{entry_id}/skip")
|
||||
def skip_entry(entry_id: str) -> dict:
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Entry not found")
|
||||
conn.execute(
|
||||
"UPDATE corpus_entries SET label_state='skipped' WHERE id=?", (entry_id,)
|
||||
)
|
||||
return {"skipped": True, "entry_id": entry_id}
|
||||
|
||||
|
||||
# ── GET /api/corpus/stats ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict:
|
||||
with _db() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM corpus_entries").fetchone()[0]
|
||||
by_state = {
|
||||
r["label_state"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT label_state, COUNT(*) AS cnt FROM corpus_entries GROUP BY label_state"
|
||||
).fetchall()
|
||||
}
|
||||
by_source = {
|
||||
r["source_host"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT source_host, COUNT(*) AS cnt FROM corpus_entries GROUP BY source_host"
|
||||
).fetchall()
|
||||
}
|
||||
by_severity = {
|
||||
r["severity"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT severity, COUNT(*) AS cnt FROM corpus_entries "
|
||||
"WHERE severity IS NOT NULL GROUP BY severity"
|
||||
).fetchall()
|
||||
}
|
||||
batch_count = conn.execute("SELECT COUNT(*) FROM corpus_batches").fetchone()[0]
|
||||
return {
|
||||
"total_entries": total,
|
||||
"batch_count": batch_count,
|
||||
"by_label_state": by_state,
|
||||
"by_source": by_source,
|
||||
"by_severity": by_severity,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /api/corpus/export ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def export_labeled() -> StreamingResponse:
|
||||
"""Stream labeled, non-PII entries as JSONL for SFT harness."""
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT source_host, source_id, severity, text, failure_type, plain_explanation, known_pattern "
|
||||
"FROM corpus_entries "
|
||||
"WHERE label_state = 'labeled' AND pii_flagged = 0 AND plain_explanation != ''"
|
||||
"ORDER BY rowid"
|
||||
).fetchall()
|
||||
|
||||
def _generate():
|
||||
for row in rows:
|
||||
record = {
|
||||
"input": row["text"],
|
||||
"output": row["plain_explanation"],
|
||||
"metadata": {
|
||||
"failure_type": row["failure_type"],
|
||||
"source": row["source_host"],
|
||||
"source_id": row["source_id"],
|
||||
"severity": row["severity"],
|
||||
"known_pattern": row["known_pattern"],
|
||||
},
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": "attachment; filename=log_corpus_labeled.jsonl"},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /api/corpus/pipeline-ingest ─────────────────────────────────────────
|
||||
|
||||
def _ingest_one_file(conn: sqlite3.Connection, path: Path) -> int:
|
||||
"""Parse a pipeline JSONL file and insert entries. Returns count stored."""
|
||||
batch_id = str(uuid.uuid4())
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
entries_raw: list[dict] = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
entries_raw.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed line in %s", path.name)
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, raw_json) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(batch_id, _PIPELINE_SOURCE_HOST, "pipeline_log", _now_iso(),
|
||||
len(entries_raw), json.dumps({"file": path.name})),
|
||||
)
|
||||
|
||||
stored = 0
|
||||
for entry in entries_raw:
|
||||
text = (entry.get("msg") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_entries "
|
||||
"(id, batch_id, source_host, timestamp_iso, severity, source_id, text, matched_patterns) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), batch_id, _PIPELINE_SOURCE_HOST,
|
||||
entry.get("ts"),
|
||||
entry.get("level"),
|
||||
entry.get("logger"),
|
||||
text,
|
||||
json.dumps([entry["extra"]] if entry.get("extra") else [])),
|
||||
)
|
||||
stored += 1
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO ingested_pipeline_files (filename, ingested_at, entry_count) VALUES (?, ?, ?)",
|
||||
(path.name, _now_iso(), stored),
|
||||
)
|
||||
return stored
|
||||
|
||||
|
||||
@router.post("/pipeline-ingest")
|
||||
def pipeline_ingest() -> dict:
|
||||
"""Walk the configured pipeline log directory and ingest new JSONL files.
|
||||
|
||||
Skips files already recorded in ingested_pipeline_files. Safe to call
|
||||
repeatedly — idempotent by filename.
|
||||
"""
|
||||
ingest_dir = _pipeline_ingest_dir()
|
||||
if ingest_dir is None:
|
||||
raise HTTPException(404, "pipeline_ingest_dir not configured in label_tool.yaml")
|
||||
|
||||
ingested = 0
|
||||
skipped = 0
|
||||
total_stored = 0
|
||||
files_detail: list[dict] = []
|
||||
|
||||
with _db() as conn:
|
||||
already_done: set[str] = {
|
||||
row[0]
|
||||
for row in conn.execute("SELECT filename FROM ingested_pipeline_files").fetchall()
|
||||
}
|
||||
|
||||
for path in sorted(ingest_dir.glob("*.jsonl")):
|
||||
if path.name in already_done:
|
||||
skipped += 1
|
||||
continue
|
||||
stored = _ingest_one_file(conn, path)
|
||||
ingested += 1
|
||||
total_stored += stored
|
||||
files_detail.append({"file": path.name, "entries_stored": stored})
|
||||
|
||||
logger.info("Pipeline ingest: %d files ingested, %d skipped, %d entries stored",
|
||||
ingested, skipped, total_stored)
|
||||
return {
|
||||
"ingested_files": ingested,
|
||||
"skipped_files": skipped,
|
||||
"entries_stored": total_stored,
|
||||
"files": files_detail,
|
||||
}
|
||||
|
|
@ -1,313 +0,0 @@
|
|||
"""Avocet — Recipe scan labeling API (avocet#65).
|
||||
|
||||
Receives recipe scan items from the Kiwi pipeline (scanner/phone image +
|
||||
docuvision OCR extraction + ground-truth structured recipe), presents them
|
||||
for human review, and exports approved/edited pairs in the messages chat
|
||||
format for the vision fine-tune harness.
|
||||
|
||||
DB: data/recipe_scan.db (separate from corpus.db — different lifecycle)
|
||||
No auth required — local admin tool, not a push endpoint.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with
|
||||
prefix="/api/recipe-scan".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "recipe_scan.db"
|
||||
|
||||
_VALID_MODALITIES = {"scanner", "phone", "handwritten"}
|
||||
_VALID_STATUSES = {"pending", "approved", "edited", "rejected"}
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS recipe_scan_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
image_path TEXT NOT NULL,
|
||||
modality TEXT NOT NULL DEFAULT 'scanner',
|
||||
source TEXT NOT NULL DEFAULT 'purple_carrot',
|
||||
extracted TEXT NOT NULL,
|
||||
ground_truth TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
corrected TEXT,
|
||||
labeled_at TEXT,
|
||||
rejected_reason TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_rsi_status ON recipe_scan_items(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_rsi_modality ON recipe_scan_items(modality);
|
||||
"""
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seam ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
with _db() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _build_training_pair(row: sqlite3.Row) -> dict:
|
||||
"""Build a messages-format training pair from a labeled row.
|
||||
|
||||
user message: correction prompt + the docuvision-extracted JSON draft.
|
||||
Trains the model to review and correct an existing extraction, which is
|
||||
more data-efficient than producing from scratch when OCR is usually close.
|
||||
|
||||
assistant message: the approved ground truth (or human-corrected JSON).
|
||||
"""
|
||||
target_str = row["corrected"] if row["corrected"] else row["ground_truth"]
|
||||
extracted = json.loads(row["extracted"])
|
||||
target = json.loads(target_str)
|
||||
user_content = (
|
||||
"Review and correct this recipe extraction. "
|
||||
"Return valid JSON with fields: title, description, ingredients, steps, "
|
||||
"prep_time, cook_time, servings.\n\n"
|
||||
f"Extraction to review:\n{json.dumps(extracted, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
return {
|
||||
"id": row["id"],
|
||||
"modality": row["modality"],
|
||||
"source": row["source"],
|
||||
"image_path": row["image_path"],
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": json.dumps(target, ensure_ascii=False)},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
_init_db()
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────────
|
||||
|
||||
class ImportItem(BaseModel):
|
||||
id: str = ""
|
||||
image_path: str
|
||||
modality: Literal["scanner", "phone", "handwritten"] = "scanner"
|
||||
source: str = "purple_carrot"
|
||||
extracted: dict
|
||||
ground_truth: dict
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def default_id(cls, v: str) -> str:
|
||||
return v or str(uuid.uuid4())
|
||||
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
items: list[ImportItem]
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def import_items(body: ImportRequest) -> dict:
|
||||
"""Bulk-import scan items from the Kiwi pipeline. Idempotent by item id."""
|
||||
stored = 0
|
||||
with _db() as conn:
|
||||
for item in body.items:
|
||||
result = conn.execute(
|
||||
"INSERT OR IGNORE INTO recipe_scan_items "
|
||||
"(id, image_path, modality, source, extracted, ground_truth) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(item.id, item.image_path, item.modality, item.source,
|
||||
json.dumps(item.extracted), json.dumps(item.ground_truth)),
|
||||
)
|
||||
stored += result.rowcount
|
||||
return {"imported": stored, "total_submitted": len(body.items)}
|
||||
|
||||
|
||||
# ── GET /next ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/next")
|
||||
def get_next() -> dict:
|
||||
"""Return the next pending item for review, oldest-first."""
|
||||
with _db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM recipe_scan_items WHERE status = 'pending' ORDER BY rowid LIMIT 1"
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "No pending items in queue")
|
||||
return {
|
||||
**dict(row),
|
||||
"extracted": json.loads(row["extracted"]),
|
||||
"ground_truth": json.loads(row["ground_truth"]),
|
||||
}
|
||||
|
||||
|
||||
# ── POST /items/{id}/approve ──────────────────────────────────────────────────
|
||||
|
||||
@router.post("/items/{item_id}/approve")
|
||||
def approve_item(item_id: str) -> dict:
|
||||
"""Mark item as approved — extracted JSON is close enough to ground truth."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='approved', labeled_at=? WHERE id=?",
|
||||
(_now_iso(), item_id),
|
||||
)
|
||||
return {"status": "approved", "id": item_id}
|
||||
|
||||
|
||||
# ── POST /items/{id}/edit ─────────────────────────────────────────────────────
|
||||
|
||||
class EditBody(BaseModel):
|
||||
corrected: dict
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/edit")
|
||||
def edit_item(item_id: str, body: EditBody) -> dict:
|
||||
"""Approve with a human-corrected JSON. corrected overrides extracted in export."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='edited', corrected=?, labeled_at=? WHERE id=?",
|
||||
(json.dumps(body.corrected), _now_iso(), item_id),
|
||||
)
|
||||
return {"status": "edited", "id": item_id}
|
||||
|
||||
|
||||
# ── POST /items/{id}/reject ───────────────────────────────────────────────────
|
||||
|
||||
class RejectBody(BaseModel):
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/reject")
|
||||
def reject_item(item_id: str, body: RejectBody = RejectBody()) -> dict:
|
||||
"""Reject item — extraction too broken to use for training."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='rejected', rejected_reason=?, labeled_at=? WHERE id=?",
|
||||
(body.reason or None, _now_iso(), item_id),
|
||||
)
|
||||
return {"status": "rejected", "id": item_id}
|
||||
|
||||
|
||||
# ── GET /stats ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict:
|
||||
with _db() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM recipe_scan_items").fetchone()[0]
|
||||
by_status = {
|
||||
r["status"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT status, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY status"
|
||||
).fetchall()
|
||||
}
|
||||
by_modality = {
|
||||
r["modality"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT modality, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY modality"
|
||||
).fetchall()
|
||||
}
|
||||
export_ready = conn.execute(
|
||||
"SELECT COUNT(*) FROM recipe_scan_items WHERE status IN ('approved', 'edited')"
|
||||
).fetchone()[0]
|
||||
return {
|
||||
"total": total,
|
||||
"by_status": by_status,
|
||||
"by_modality": by_modality,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def export_pairs() -> StreamingResponse:
|
||||
"""Stream approved/edited items as JSONL training pairs (messages format)."""
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM recipe_scan_items WHERE status IN ('approved', 'edited') ORDER BY rowid"
|
||||
).fetchall()
|
||||
|
||||
def _generate():
|
||||
for row in rows:
|
||||
yield json.dumps(_build_training_pair(row), ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": "attachment; filename=recipe_scan_pairs.jsonl"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /image ────────────────────────────────────────────────────────────────
|
||||
|
||||
_IMAGE_ROOT = Path("/Library/Assets/kiwi")
|
||||
|
||||
|
||||
@router.get("/image")
|
||||
def serve_image(path: str) -> StreamingResponse:
|
||||
"""Serve a scan image from /Library/Assets/kiwi/.
|
||||
|
||||
path must resolve within /Library/Assets/kiwi/ — rejects traversal attempts.
|
||||
"""
|
||||
try:
|
||||
resolved = Path(path).resolve()
|
||||
_IMAGE_ROOT.resolve() # ensure root itself is valid
|
||||
resolved.relative_to(_IMAGE_ROOT.resolve())
|
||||
except (ValueError, OSError):
|
||||
raise HTTPException(403, "Path outside allowed image directory")
|
||||
|
||||
if not resolved.exists():
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
suffix = resolved.suffix.lower()
|
||||
media_types = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".webp": "image/webp"}
|
||||
media_type = media_types.get(suffix, "application/octet-stream")
|
||||
|
||||
return StreamingResponse(
|
||||
open(resolved, "rb"),
|
||||
media_type=media_type,
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""Avocet -- eval router aggregator.
|
||||
|
||||
Collects benchmark sub-routers into a single importable `router`
|
||||
for the api.py factory. Each sub-router retains its established prefix
|
||||
so no frontend URL changes are needed.
|
||||
|
||||
Route prefixes when mounted at /api in api.py:
|
||||
/api/cforch/* -- cf-orch benchmark routes
|
||||
/api/style/* -- writing style benchmark routes
|
||||
/api/voice/* -- voice benchmark routes
|
||||
/api/plans-bench/* -- plans benchmark routes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.cforch import router as _cforch_router
|
||||
from app.style import router as _style_router
|
||||
from app.voice import router as _voice_router
|
||||
from app.plans_bench import router as _plans_router
|
||||
from app.eval.embed_bench import router as _embed_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(_cforch_router, prefix="/cforch")
|
||||
router.include_router(_style_router, prefix="/style")
|
||||
router.include_router(_voice_router, prefix="/voice")
|
||||
router.include_router(_plans_router, prefix="/plans-bench")
|
||||
router.include_router(_embed_router, prefix="/embed-bench")
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Propagate config dir override to all sub-modules -- used by tests."""
|
||||
import app.cforch as _cforch_mod
|
||||
import app.style as _style_mod
|
||||
import app.voice as _voice_mod
|
||||
import app.plans_bench as _plans_mod
|
||||
import app.eval.embed_bench as _embed_mod
|
||||
_cforch_mod.set_config_dir(path)
|
||||
_style_mod.set_config_dir(path)
|
||||
_voice_mod.set_config_dir(path)
|
||||
_plans_mod.set_config_dir(path)
|
||||
_embed_mod.set_config_dir(path)
|
||||
|
|
@ -1,293 +0,0 @@
|
|||
"""Avocet — embedding model comparison harness.
|
||||
|
||||
Exposes FastAPI routes under /api/embed-bench (mounted via app/eval/cforch.py).
|
||||
All computation is local: no LLM inference, Ollama only. MIT tier throughout.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override via set_config_dir() in tests
|
||||
_RUN_ACTIVE: bool = False
|
||||
_RATINGS_FILE = _ROOT / "data" / "embed_bench_ratings.jsonl"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seam ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
return yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse embed_bench config %s: %s", f, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _ollama_url() -> str:
|
||||
cfg = _load_config()
|
||||
embed_cfg = cfg.get("embed_bench", {}) or {}
|
||||
cforch_cfg = cfg.get("cforch", {}) or {}
|
||||
return (
|
||||
embed_cfg.get("ollama_url")
|
||||
or cforch_cfg.get("ollama_url", "http://localhost:11434")
|
||||
)
|
||||
|
||||
|
||||
def _ratings_path() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "embed_bench_ratings.jsonl"
|
||||
return _RATINGS_FILE
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
if len(a) != len(b):
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch: {len(a)} vs {len(b)}"
|
||||
)
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x * x for x in a))
|
||||
mag_b = math.sqrt(sum(x * x for x in b))
|
||||
if mag_a == 0.0 or mag_b == 0.0:
|
||||
return 0.0
|
||||
return dot / (mag_a * mag_b)
|
||||
|
||||
|
||||
# ── GET /models ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return Ollama embedding models available on the configured instance."""
|
||||
ollama = _ollama_url()
|
||||
models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(f"{ollama}/api/tags", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
for entry in resp.json().get("models", []):
|
||||
models.append({
|
||||
"name": entry.get("name", ""),
|
||||
"size": entry.get("size", 0),
|
||||
})
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.warning("Ollama /api/tags returned HTTP %s: %s", exc.response.status_code, exc)
|
||||
except httpx.RequestError as exc:
|
||||
logger.warning("Failed to reach Ollama for model list: %s", exc)
|
||||
return {"models": models, "ollama_url": ollama}
|
||||
|
||||
|
||||
# ── POST /run ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
corpus: list[str]
|
||||
queries: list[str]
|
||||
models: list[str]
|
||||
top_k: int = 5
|
||||
ollama_url: str = ""
|
||||
|
||||
@field_validator("corpus")
|
||||
@classmethod
|
||||
def corpus_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("corpus must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("queries")
|
||||
@classmethod
|
||||
def queries_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("queries must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("models")
|
||||
@classmethod
|
||||
def models_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("models must contain at least one model name")
|
||||
return v
|
||||
|
||||
|
||||
def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]:
|
||||
"""Batch-embed texts via Ollama /v1/embeddings. Returns one vector per text."""
|
||||
resp = httpx.post(
|
||||
f"{ollama}/v1/embeddings",
|
||||
json={"model": model, "input": texts},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json().get("data", [])
|
||||
return [item["embedding"] for item in data]
|
||||
|
||||
|
||||
def _sse(event: dict) -> str:
|
||||
return f"data: {json.dumps(event)}\n\n"
|
||||
|
||||
|
||||
@router.post("/run")
|
||||
def run_embed_bench(req: RunRequest) -> StreamingResponse:
|
||||
"""Embed corpus + queries with each model; stream SSE results."""
|
||||
global _RUN_ACTIVE
|
||||
|
||||
if _RUN_ACTIVE:
|
||||
raise HTTPException(409, "An embedding benchmark run is already active")
|
||||
|
||||
ollama = req.ollama_url or _ollama_url()
|
||||
|
||||
def _generate():
|
||||
global _RUN_ACTIVE
|
||||
_RUN_ACTIVE = True
|
||||
try:
|
||||
for model_idx, model in enumerate(req.models, start=1):
|
||||
yield _sse({
|
||||
"type": "progress",
|
||||
"msg": f"Indexing corpus with {model} ({model_idx}/{len(req.models)})...",
|
||||
})
|
||||
try:
|
||||
corpus_vecs = _embed_texts(ollama, model, req.corpus)
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"})
|
||||
continue
|
||||
|
||||
yield _sse({
|
||||
"type": "progress",
|
||||
"msg": f"Running queries with {model}...",
|
||||
})
|
||||
|
||||
for q_idx, query in enumerate(req.queries):
|
||||
try:
|
||||
q_vecs = _embed_texts(ollama, model, [query])
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"})
|
||||
continue
|
||||
q_vec = q_vecs[0]
|
||||
scored = sorted(
|
||||
[
|
||||
{"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)}
|
||||
for i, (chunk, cv) in enumerate(zip(req.corpus, corpus_vecs))
|
||||
],
|
||||
key=lambda h: h["score"],
|
||||
reverse=True,
|
||||
)[: req.top_k]
|
||||
yield _sse({
|
||||
"type": "result",
|
||||
"query_idx": q_idx,
|
||||
"query": query,
|
||||
"model": model,
|
||||
"hits": scored,
|
||||
})
|
||||
|
||||
yield _sse({"type": "done"})
|
||||
finally:
|
||||
_RUN_ACTIVE = False
|
||||
|
||||
return StreamingResponse(_generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# ── POST /rate ────────────────────────────────────────────────────────────────
|
||||
|
||||
_VALID_RATINGS = {"relevant", "not_relevant"}
|
||||
|
||||
|
||||
class RatingRequest(BaseModel):
|
||||
query: str
|
||||
model: str
|
||||
chunk_text: str
|
||||
chunk_idx: int
|
||||
rating: str
|
||||
|
||||
@field_validator("rating")
|
||||
@classmethod
|
||||
def rating_valid(cls, v: str) -> str:
|
||||
if v not in _VALID_RATINGS:
|
||||
raise ValueError(f"rating must be one of {_VALID_RATINGS}")
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/rate")
|
||||
def rate_result(req: RatingRequest) -> dict:
|
||||
"""Append one rating to the JSONL ratings file."""
|
||||
entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"query": req.query,
|
||||
"model": req.model,
|
||||
"chunk_idx": req.chunk_idx,
|
||||
"chunk_text": req.chunk_text,
|
||||
"rating": req.rating,
|
||||
}
|
||||
path = _ratings_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(entry) + "\n")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||
|
||||
_CSV_FIELDS = ["timestamp", "query", "model", "chunk_idx", "chunk_text", "rating"]
|
||||
|
||||
|
||||
@router.get("/export")
|
||||
def export_ratings(format: str = "csv") -> Any:
|
||||
"""Download ratings as CSV or JSON."""
|
||||
path = _ratings_path()
|
||||
rows: list[dict] = []
|
||||
if path.exists():
|
||||
for raw in path.read_text(encoding="utf-8").splitlines():
|
||||
raw = raw.strip()
|
||||
if raw:
|
||||
try:
|
||||
rows.append(json.loads(raw))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
if format == "json":
|
||||
content = json.dumps(rows, ensure_ascii=False, indent=2)
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
|
||||
)
|
||||
|
||||
# Default: CSV
|
||||
buf = io.StringIO()
|
||||
writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
return StreamingResponse(
|
||||
iter([buf.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
|
||||
)
|
||||
|
|
@ -1,9 +1,158 @@
|
|||
"""Backward-compat shim -- logic moved to app/data/fetch.py."""
|
||||
import imaplib # noqa: F401 -- re-exported so existing patch("app.imap_fetch.imaplib...") calls still work
|
||||
from app.data.fetch import ( # noqa: F401
|
||||
entry_key,
|
||||
fetch_account_stream,
|
||||
test_connection,
|
||||
_decode_str,
|
||||
_WIDE_TERMS,
|
||||
)
|
||||
"""Avocet — IMAP fetch utilities.
|
||||
|
||||
Shared between app/api.py (FastAPI SSE endpoint) and the label UI.
|
||||
No Streamlit imports here — stdlib + imaplib only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from typing import Any, Iterator
|
||||
|
||||
from app.utils import extract_body, strip_html # noqa: F401 (strip_html re-exported for callers)
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
|
||||
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
|
||||
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
|
||||
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
|
||||
host = acc.get("host", "")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc.get("username", "")
|
||||
password = acc.get("password", "")
|
||||
folder = acc.get("folder", "INBOX")
|
||||
if not host or not username or not password:
|
||||
return False, "Host, username, and password are all required.", None
|
||||
try:
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
_, data = conn.select(folder, readonly=True)
|
||||
count_raw = data[0].decode() if data and data[0] else "0"
|
||||
count = int(count_raw) if count_raw.isdigit() else 0
|
||||
conn.logout()
|
||||
return True, f"Connected — {count:,} message(s) in {folder}.", count
|
||||
except Exception as exc:
|
||||
return False, str(exc), None
|
||||
|
||||
|
||||
def fetch_account_stream(
|
||||
acc: dict,
|
||||
days_back: int,
|
||||
limit: int,
|
||||
known_keys: set[str],
|
||||
) -> Iterator[dict]:
|
||||
"""Generator — yields progress dicts while fetching emails via IMAP.
|
||||
|
||||
Mutates `known_keys` in place for cross-account dedup within one fetch session.
|
||||
|
||||
Yields event dicts with "type" key:
|
||||
{"type": "start", "account": str, "total_uids": int}
|
||||
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
|
||||
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
|
||||
"""
|
||||
name = acc.get("name", acc.get("username", "?"))
|
||||
host = acc.get("host", "imap.gmail.com")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc["username"]
|
||||
password = acc["password"]
|
||||
folder = acc.get("folder", "INBOX")
|
||||
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
conn.select(folder, readonly=True)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
uids = list(seen_uids.keys())[: limit * 3]
|
||||
yield {"type": "start", "account": name, "total_uids": len(uids)}
|
||||
|
||||
emails: list[dict] = []
|
||||
skipped = 0
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if i % 5 == 0:
|
||||
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = extract_body(msg)[:800]
|
||||
entry = {"subject": subj, "body": body, "from_addr": from_addr,
|
||||
"date": date, "account": name}
|
||||
k = entry_key(entry)
|
||||
if k not in known_keys:
|
||||
known_keys.add(k)
|
||||
emails.append(entry)
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception:
|
||||
skipped += 1
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
|
||||
"emails": emails}
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
|
||||
from app.data.imitate import router # noqa: F401
|
||||
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401
|
||||
847
app/models.py
847
app/models.py
File diff suppressed because it is too large
Load diff
535
app/nodes.py
535
app/nodes.py
|
|
@ -1,535 +0,0 @@
|
|||
"""Avocet — Node Management API.
|
||||
|
||||
Proxies cf-orch coordinator and agent APIs to expose per-node GPU state,
|
||||
service affinity management, and Ollama model management.
|
||||
|
||||
Config is read from label_tool.yaml under the `cforch:` key.
|
||||
The `profiles_dir` key (new) points to the cf-orch node profile YAML directory.
|
||||
|
||||
Module-level globals follow the set_config_dir() testability pattern from cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section. Returns empty dict on missing or parse error."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse config %s: %s", f, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _profiles_dir() -> Path | None:
|
||||
"""Return the cf-orch node profiles directory, or None if not configured."""
|
||||
cfg = _load_config()
|
||||
pd = cfg.get("profiles_dir", "") or ""
|
||||
if pd:
|
||||
return Path(pd)
|
||||
bench = cfg.get("bench_script", "") or ""
|
||||
if bench:
|
||||
return Path(bench).parent.parent / "profiles" / "nodes"
|
||||
return None
|
||||
|
||||
|
||||
def _profile_path(node_id: str) -> Path | None:
|
||||
"""Return the path to a node's profile YAML, or None if profiles_dir is unknown."""
|
||||
pd = _profiles_dir()
|
||||
if pd is None:
|
||||
return None
|
||||
return pd / f"{node_id}.yaml"
|
||||
|
||||
|
||||
def _load_profile(node_id: str) -> dict | None:
|
||||
"""Load and parse a node profile YAML. Returns None if not found or malformed."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
return None
|
||||
try:
|
||||
return yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Malformed profile YAML %s: %s", p, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _get_ollama_url(node_id: str) -> str:
|
||||
"""Derive Ollama URL from the node profile's agent_url (same host, port 11434)."""
|
||||
profile = _load_profile(node_id)
|
||||
if profile:
|
||||
nodes_section = profile.get("nodes", {}) or {}
|
||||
node_entry = nodes_section.get(node_id, {}) or {}
|
||||
agent_url = node_entry.get("agent_url", "") or ""
|
||||
if agent_url:
|
||||
parsed = urlparse(agent_url)
|
||||
return f"{parsed.scheme}://{parsed.hostname}:11434"
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Cannot determine Ollama URL for node {node_id}: no agent_url in profile",
|
||||
)
|
||||
|
||||
|
||||
# ── Endpoints ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
|
||||
def list_nodes() -> list:
|
||||
"""Return all nodes with live GPU stats merged with profile YAML."""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
if not coordinator_url:
|
||||
return []
|
||||
|
||||
try:
|
||||
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
r.raise_for_status()
|
||||
coord_nodes: list[dict] = r.json().get("nodes", [])
|
||||
except httpx.HTTPError as exc:
|
||||
logger.warning("Coordinator unreachable: %s", exc)
|
||||
return []
|
||||
|
||||
try:
|
||||
sr = httpx.get(f"{coordinator_url}/api/services", timeout=5.0)
|
||||
sr.raise_for_status()
|
||||
services_data: list[dict] = sr.json().get("services", [])
|
||||
except httpx.HTTPError:
|
||||
logger.warning("Services API unreachable for %s, skipping", coordinator_url)
|
||||
services_data = []
|
||||
|
||||
# Build per-node, per-GPU running services map
|
||||
running: dict[str, dict[int, list[str]]] = {}
|
||||
for svc in services_data:
|
||||
nid = svc.get("node_id", "")
|
||||
gid = svc.get("gpu_id")
|
||||
svc_name = svc.get("service", "")
|
||||
if nid and gid is not None and svc_name:
|
||||
running.setdefault(nid, {}).setdefault(gid, []).append(svc_name)
|
||||
|
||||
result = []
|
||||
for node in coord_nodes:
|
||||
node_id = node.get("node_id", "") or node.get("id", "")
|
||||
profile = _load_profile(node_id) if node_id else None
|
||||
profile_loaded = profile is not None
|
||||
|
||||
gpus = []
|
||||
for gpu in (node.get("gpus", []) or []):
|
||||
gpu_id = gpu.get("gpu_id", gpu.get("id", 0))
|
||||
services_assigned: list[str] = []
|
||||
if profile:
|
||||
node_entry = (profile.get("nodes", {}) or {}).get(node_id, {}) or {}
|
||||
for g in (node_entry.get("gpus", []) or []):
|
||||
if isinstance(g, dict) and g.get("id") == gpu_id:
|
||||
services_assigned = g.get("services", []) or []
|
||||
break
|
||||
gpus.append({
|
||||
"gpu_id": gpu_id,
|
||||
"card": gpu.get("card", ""),
|
||||
"vram_total_mb": gpu.get("vram_total_mb", 0),
|
||||
"vram_used_mb": gpu.get("vram_used_mb", 0),
|
||||
"vram_free_mb": gpu.get("vram_free_mb", 0),
|
||||
"temp_c": gpu.get("temp_c"),
|
||||
"utilization_pct": gpu.get("utilization_pct"),
|
||||
"compute_cap": gpu.get("compute_cap"),
|
||||
"services_assigned": services_assigned,
|
||||
"services_running": running.get(node_id, {}).get(gpu_id, []),
|
||||
})
|
||||
|
||||
services_catalog: dict = {}
|
||||
if profile:
|
||||
for svc_name, svc_info in (profile.get("services", {}) or {}).items():
|
||||
catalog = svc_info.get("catalog", {}) or {}
|
||||
services_catalog[svc_name] = {
|
||||
"min_compute_cap": svc_info.get("min_compute_cap", 0.0),
|
||||
"max_mb": svc_info.get("max_mb", 0),
|
||||
"catalog_size": len(catalog),
|
||||
}
|
||||
|
||||
result.append({
|
||||
"node_id": node_id,
|
||||
"online": node.get("online", True),
|
||||
"agent_url": node.get("agent_url", ""),
|
||||
"gpus": gpus,
|
||||
"profile_loaded": profile_loaded,
|
||||
"services_catalog": services_catalog,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/profile")
|
||||
def get_node_profile(node_id: str) -> dict:
|
||||
"""Return the full parsed profile YAML for a node."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
raise HTTPException(404, f"No profile found for node {node_id}")
|
||||
try:
|
||||
data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise HTTPException(500, f"Malformed profile YAML: {exc}")
|
||||
return data
|
||||
|
||||
|
||||
class UpdateServicesRequest(BaseModel):
|
||||
services: list[str]
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
|
||||
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
|
||||
"""Set service assignment for a GPU with compatibility validation, then atomic write."""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
raise HTTPException(404, f"No profile found for node {node_id}")
|
||||
|
||||
try:
|
||||
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise HTTPException(500, f"Malformed profile YAML: {exc}")
|
||||
|
||||
nodes_section = profile.get("nodes", {}) or {}
|
||||
node_entry = nodes_section.get(node_id, {}) or {}
|
||||
gpu_list = node_entry.get("gpus", []) or []
|
||||
|
||||
gpu_entry = next(
|
||||
(g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if gpu_entry is None:
|
||||
raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")
|
||||
|
||||
gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
|
||||
gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
|
||||
services_def = profile.get("services", {}) or {}
|
||||
|
||||
for svc_name in body.services:
|
||||
if svc_name not in services_def:
|
||||
raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
|
||||
svc = services_def[svc_name]
|
||||
min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
|
||||
if gpu_compute_cap < min_cap:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
|
||||
)
|
||||
catalog = svc.get("catalog", {}) or {}
|
||||
min_catalog_vram = (
|
||||
min((m.get("vram_mb", 0) for m in catalog.values()), default=0)
|
||||
if catalog else svc.get("max_mb", 0)
|
||||
)
|
||||
if gpu_vram_mb < min_catalog_vram:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
|
||||
)
|
||||
|
||||
# Immutable update of GPU services list
|
||||
new_gpu_list = [
|
||||
({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
|
||||
for g in gpu_list
|
||||
]
|
||||
new_profile = {
|
||||
**profile,
|
||||
"nodes": {
|
||||
**nodes_section,
|
||||
node_id: {**node_entry, "gpus": new_gpu_list},
|
||||
},
|
||||
}
|
||||
|
||||
# Atomic write: write to .tmp then rename
|
||||
tmp_yaml = Path(str(p) + ".tmp")
|
||||
tmp_yaml.write_text(yaml.dump(new_profile, default_flow_style=False), encoding="utf-8")
|
||||
os.replace(tmp_yaml, p)
|
||||
|
||||
# Trigger coordinator profile reload
|
||||
reloaded = False
|
||||
if coordinator_url:
|
||||
try:
|
||||
rr = httpx.post(
|
||||
f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
|
||||
)
|
||||
reloaded = rr.status_code < 300
|
||||
except Exception as exc:
|
||||
logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)
|
||||
|
||||
return {"ok": True, "reloaded": reloaded, "warnings": []}
|
||||
|
||||
# ── Profile save / generate ────────────────────────────────────────────────────
|
||||
|
||||
class SaveProfileRequest(BaseModel):
|
||||
profile: dict
|
||||
|
||||
|
||||
@router.put("/nodes/{node_id}/profile", status_code=200)
|
||||
def save_profile(node_id: str, body: SaveProfileRequest) -> dict:
|
||||
"""Write a full profile dict to disk as YAML, then trigger coordinator reload."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None:
|
||||
raise HTTPException(500, "profiles_dir not configured in label_tool.yaml")
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = Path(str(p) + ".tmp")
|
||||
tmp.write_text(
|
||||
yaml.dump(body.profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
os.replace(tmp, p)
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
reloaded = False
|
||||
if coordinator_url:
|
||||
try:
|
||||
import httpx
|
||||
rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
|
||||
reloaded = rr.status_code < 300
|
||||
except Exception as exc:
|
||||
logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
|
||||
return {"ok": True, "reloaded": reloaded}
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/profile/generate")
|
||||
def generate_profile(node_id: str) -> dict:
|
||||
"""Return a profile skeleton seeded from coordinator GPU data.
|
||||
|
||||
If a profile already exists, preserves its services section and only
|
||||
refreshes the nodes hardware section. Never writes to disk — the caller
|
||||
must call PUT /profile to persist.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
if not coordinator_url:
|
||||
raise HTTPException(503, "coordinator_url not configured")
|
||||
|
||||
try:
|
||||
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
r.raise_for_status()
|
||||
coord_nodes: list[dict] = r.json().get("nodes", [])
|
||||
except httpx.HTTPError as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}")
|
||||
|
||||
node = next((n for n in coord_nodes if n.get("node_id") == node_id), None)
|
||||
if node is None:
|
||||
raise HTTPException(404, f"Node {node_id!r} not found in coordinator")
|
||||
|
||||
gpus = [
|
||||
{
|
||||
"id": g.get("gpu_id", i),
|
||||
"vram_mb": g.get("vram_total_mb", 0),
|
||||
"compute_cap": g.get("compute_cap", 0.0),
|
||||
"card": g.get("card", g.get("name", "")),
|
||||
"role": "inference",
|
||||
"services": [],
|
||||
}
|
||||
for i, g in enumerate(node.get("gpus", []))
|
||||
]
|
||||
vram_total = max((g["vram_mb"] for g in gpus), default=0)
|
||||
|
||||
existing = _load_profile(node_id) or {}
|
||||
return {
|
||||
"schema_version": existing.get("schema_version", 1),
|
||||
"name": existing.get("name", f"node-{node_id}"),
|
||||
"vram_total_mb": vram_total,
|
||||
"eviction_timeout_s": existing.get("eviction_timeout_s", 10.0),
|
||||
"services": existing.get("services", {}),
|
||||
"nodes": {
|
||||
node_id: {
|
||||
"local_model_root": (
|
||||
(existing.get("nodes", {}) or {})
|
||||
.get(node_id, {})
|
||||
.get("local_model_root", "")
|
||||
),
|
||||
"gpus": gpus,
|
||||
}
|
||||
},
|
||||
"model_size_hints": existing.get("model_size_hints", {}),
|
||||
}
|
||||
|
||||
|
||||
# ── Ollama model management ────────────────────────────────────────────────────
|
||||
|
||||
class PullRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/models/ollama")
|
||||
def list_ollama_models(node_id: str) -> dict:
|
||||
"""Proxy GET {ollama_url}/api/tags for a specific node."""
|
||||
import httpx
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
try:
|
||||
r = httpx.get(f"{ollama_url}/api/tags", timeout=10.0)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as exc:
|
||||
return {"error": str(exc)}
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/models/ollama/pull")
|
||||
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
|
||||
"""Stream Ollama pull progress as SSE events."""
|
||||
import httpx
|
||||
|
||||
if not body.name:
|
||||
raise HTTPException(400, "name is required")
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
|
||||
def stream():
|
||||
try:
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{ollama_url}/api/pull",
|
||||
json={"name": body.name, "stream": True},
|
||||
timeout=300.0,
|
||||
) as resp:
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
yield f"data: {line}\n\n"
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'error': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
|
||||
def delete_ollama_model(node_id: str, name: str) -> dict:
|
||||
"""Proxy DELETE to Ollama for a specific node."""
|
||||
import httpx
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
try:
|
||||
r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
|
||||
if r.status_code == 404:
|
||||
raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
|
||||
r.raise_for_status()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Ollama unreachable: {exc}")
|
||||
|
||||
|
||||
# ── Model deploy (add catalog entry) ──────────────────────────────────────────
|
||||
|
||||
class DeployModelRequest(BaseModel):
|
||||
model_id: str
|
||||
service_type: str
|
||||
vram_mb: int
|
||||
description: str = ""
|
||||
hf_repo: str = ""
|
||||
path: str = "" # explicit path; if empty, constructed from model_base_path + hf_repo slug
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/models/deploy", status_code=200)
|
||||
def deploy_model(node_id: str, body: DeployModelRequest) -> dict:
|
||||
"""Register a model in the node's service catalog.
|
||||
|
||||
Adds (or updates) the catalog entry for body.model_id under the given
|
||||
service_type in the node's profile YAML, then triggers a coordinator reload.
|
||||
Does not download the model — that is the user's responsibility.
|
||||
Returns the resolved path so the caller can see where the model should land.
|
||||
"""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
raise HTTPException(404, f"No profile found for node {node_id!r}")
|
||||
|
||||
try:
|
||||
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise HTTPException(500, f"Malformed profile YAML: {exc}")
|
||||
|
||||
services_def = profile.get("services", {}) or {}
|
||||
svc = services_def.get(body.service_type)
|
||||
if svc is None:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{body.service_type}' not defined in node '{node_id}' profile; "
|
||||
"add it first via the profile editor",
|
||||
)
|
||||
|
||||
# Resolve path: explicit > model_base_path + hf slug > model_id slug
|
||||
model_path = body.path.strip()
|
||||
if not model_path:
|
||||
base = (svc.get("model_base_path", "") or "").rstrip("/")
|
||||
if not base:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{body.service_type}' has no model_base_path; supply an explicit path",
|
||||
)
|
||||
slug_src = body.hf_repo.strip() if body.hf_repo.strip() else body.model_id
|
||||
hf_slug = slug_src.replace("/", "--")
|
||||
model_path = f"{base}/{hf_slug}"
|
||||
|
||||
# Immutable catalog update — spread, never mutate
|
||||
entry: dict = {"path": model_path, "vram_mb": body.vram_mb}
|
||||
if body.description:
|
||||
entry["description"] = body.description
|
||||
new_catalog = {**(svc.get("catalog") or {}), body.model_id: entry}
|
||||
new_svc = {**svc, "catalog": new_catalog}
|
||||
new_services = {**services_def, body.service_type: new_svc}
|
||||
new_profile = {**profile, "services": new_services}
|
||||
|
||||
# Atomic write
|
||||
tmp = Path(str(p) + ".tmp")
|
||||
tmp.write_text(
|
||||
yaml.dump(new_profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
os.replace(tmp, p)
|
||||
|
||||
# Trigger coordinator reload
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
reloaded = False
|
||||
if coordinator_url:
|
||||
try:
|
||||
import httpx
|
||||
rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
|
||||
reloaded = rr.status_code < 300
|
||||
except Exception as exc:
|
||||
logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
|
||||
|
||||
return {"ok": True, "reloaded": reloaded, "path": model_path}
|
||||
|
|
@ -1,327 +0,0 @@
|
|||
"""Avocet — CF planning benchmark integration API.
|
||||
|
||||
Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
|
||||
Connection config (api_base) is read from label_tool.yaml under the
|
||||
`plans_bench:` key (optional; falls back to localhost:8080).
|
||||
|
||||
All endpoints are registered on `router` (FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/plans-bench".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None
|
||||
|
||||
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
|
||||
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
|
||||
# Kept here so the UI can list them without importing the script.
|
||||
|
||||
MODEL_REGISTRY: dict[str, str] = {
|
||||
"deepseek-r1-1.5b": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
|
||||
"deepseek-r1-7b-4bit": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
|
||||
"deepseek-r1-0528-qwen3-8b-gguf": "DeepSeek R1 0528 Qwen3 8B GGUF (4 nodes)",
|
||||
"deepseek-coder-6.7b-4bit": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
|
||||
"granite-4.1-8b": "IBM Granite 4.1 8B, 4-bit (cf-orch catalog key)",
|
||||
"qwen2.5-3b": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
|
||||
"qwen2.5-7b": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
|
||||
"capybarahermes-2.5-mistral-7b-gguf": "CapybaraHermes 2.5 Mistral 7B GGUF (4 nodes)",
|
||||
"darwin-9b-opus-gguf": "Darwin 9B Opus GGUF -- long-form writing (3 nodes)",
|
||||
}
|
||||
|
||||
RUBRIC_LABELS: dict[str, str] = {
|
||||
"task_structure": "Task structure (checkboxes + commits)",
|
||||
"tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
|
||||
"privacy_pillar": "Privacy pillar (local-first, no logging)",
|
||||
"safety_pillar": "Safety pillar (human approval, reversibility)",
|
||||
"accessibility": "Accessibility (ND/adaptive users)",
|
||||
"license_split": "License awareness (MIT vs BSL)",
|
||||
"file_paths": "File paths (plausible project paths)",
|
||||
"cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
|
||||
"length_ok": "Response length (200–2500 words)",
|
||||
}
|
||||
|
||||
|
||||
# ── Testability seam ───────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
f = _config_file()
|
||||
cforch_cfg: dict = {}
|
||||
bench_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
cforch_cfg = raw.get("cforch", {}) or {}
|
||||
bench_cfg = raw.get("plans_bench", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
|
||||
return {
|
||||
"coordinator_url": cforch_cfg.get("coordinator_url",
|
||||
bench_cfg.get("coordinator_url", "http://10.1.10.71:7700")),
|
||||
"python_bin": cforch_cfg.get("python_bin",
|
||||
bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")),
|
||||
}
|
||||
|
||||
|
||||
def _results_file(run_id: str) -> Path:
|
||||
return _RESULTS_DIR / f"{run_id}.json"
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return registered model shortcuts, live cf-orch catalog, and rubric labels."""
|
||||
cfg = _load_config()
|
||||
|
||||
cforch_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
for model_id, entry in resp.json().items():
|
||||
if isinstance(entry, dict):
|
||||
cforch_models.append({
|
||||
"id": model_id,
|
||||
"name": model_id,
|
||||
"vram_mb": entry.get("vram_mb"),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch cf-orch catalog: %s", exc)
|
||||
|
||||
return {
|
||||
"registry": [
|
||||
{"key": k, "description": v}
|
||||
for k, v in MODEL_REGISTRY.items()
|
||||
],
|
||||
"cforch_models": cforch_models,
|
||||
"coordinator_url": cfg["coordinator_url"],
|
||||
"rubric_labels": RUBRIC_LABELS,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_plans_benchmark(
|
||||
models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
|
||||
prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
|
||||
use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
|
||||
api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
|
||||
workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
|
||||
) -> StreamingResponse:
|
||||
"""Spawn benchmark_plans.py and stream stdout as SSE progress events.
|
||||
|
||||
On successful completion emits a `type: result` event with parsed JSON
|
||||
and saves results to data/plans_bench_results/<run_id>.json.
|
||||
"""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if _BENCH_RUNNING:
|
||||
raise HTTPException(409, "A planning benchmark is already running")
|
||||
|
||||
cfg = _load_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
coordinator_url = cfg["coordinator_url"]
|
||||
|
||||
model_keys = [m.strip() for m in models.split(",") if m.strip()]
|
||||
if not model_keys:
|
||||
raise HTTPException(400, "At least one model key is required")
|
||||
|
||||
run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
|
||||
output_path = _results_file(run_id)
|
||||
_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_SCRIPT.exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})}\n\n"
|
||||
return
|
||||
|
||||
cmd = [python_bin, str(_BENCH_SCRIPT)]
|
||||
if len(model_keys) > 1:
|
||||
cmd.extend(["--compare"] + model_keys)
|
||||
else:
|
||||
cmd.extend(["--model", model_keys[0]])
|
||||
|
||||
if use_cforch:
|
||||
cmd.extend(["--cforch", "--cforch-url", coordinator_url])
|
||||
elif api_base.strip():
|
||||
cmd.extend(["--api-base", api_base.strip()])
|
||||
|
||||
cmd.extend(["--verbose", "--output", str(output_path)])
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
|
||||
if prompt_ids.strip():
|
||||
cmd.extend(["--prompts"] + [p.strip() for p in prompt_ids.split(",") if p.strip()])
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_bench_proc = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0 and output_path.exists():
|
||||
try:
|
||||
results = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'run_id': run_id, 'results': results})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read plans benchmark output: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
"""List past planning benchmark runs, newest first."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
runs: list[dict] = []
|
||||
for f in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
|
||||
run_id = f.stem
|
||||
try:
|
||||
data: dict = json.loads(f.read_text(encoding="utf-8"))
|
||||
model_keys = list(data.keys())
|
||||
# Average total_score across all models and prompts
|
||||
all_scores = [
|
||||
r["total_score"]
|
||||
for results in data.values()
|
||||
for r in results
|
||||
if not r.get("error")
|
||||
]
|
||||
avg_score = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
|
||||
except Exception:
|
||||
model_keys = []
|
||||
avg_score = 0.0
|
||||
|
||||
# Parse display date from run_id (plans_2026-04-27_143022)
|
||||
try:
|
||||
date_part = run_id.removeprefix("plans_") # 2026-04-27_143022
|
||||
date, time = date_part.split("_")
|
||||
display_date = f"{date} {time[:2]}:{time[2:4]}"
|
||||
except Exception:
|
||||
display_date = run_id
|
||||
|
||||
runs.append({
|
||||
"run_id": run_id,
|
||||
"filename": f.name,
|
||||
"date": display_date,
|
||||
"models": model_keys,
|
||||
"avg_score": avg_score,
|
||||
})
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
|
||||
def get_latest_results() -> dict:
|
||||
"""Return the most recent planning benchmark results dict."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
files = sorted(_RESULTS_DIR.glob("plans_*.json"))
|
||||
if not files:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(files[-1].read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{run_id}")
|
||||
def get_results_by_run_id(run_id: str) -> dict:
|
||||
"""Return planning benchmark results for a specific run."""
|
||||
if not run_id.startswith("plans_"):
|
||||
raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")
|
||||
f = _results_file(run_id)
|
||||
if not f.exists():
|
||||
raise HTTPException(404, f"Results not found: {run_id}")
|
||||
try:
|
||||
return json.loads(f.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_plans_benchmark() -> dict:
|
||||
"""Kill the running planning benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No planning benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate plans benchmark: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
334
app/sft.py
334
app/sft.py
|
|
@ -1,8 +1,326 @@
|
|||
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
|
||||
from app.data.corrections import ( # noqa: F401
|
||||
router,
|
||||
set_data_dir as set_sft_data_dir,
|
||||
set_config_dir as set_sft_config_dir,
|
||||
set_default_bench_results_dir,
|
||||
_DEFAULT_BENCH_RESULTS_DIR,
|
||||
)
|
||||
"""Avocet — SFT candidate import and correction API.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
|
||||
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
|
||||
testability pattern as api.py — override them via set_sft_data_dir() and
|
||||
set_sft_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_SFT_DATA_DIR: Path = _ROOT / "data"
|
||||
_SFT_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────
|
||||
|
||||
def set_sft_data_dir(path: Path) -> None:
|
||||
global _SFT_DATA_DIR
|
||||
_SFT_DATA_DIR = path
|
||||
|
||||
|
||||
def set_sft_config_dir(path: Path | None) -> None:
|
||||
global _SFT_CONFIG_DIR
|
||||
_SFT_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _SFT_CONFIG_DIR is not None:
|
||||
return _SFT_CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return Path("/nonexistent-bench-results")
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path("/nonexistent-bench-results")
|
||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||
return Path(d) if d else Path("/nonexistent-bench-results")
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _SFT_DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _SFT_DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# ── GET /runs ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _SFT_DATA_DIR)
|
||||
|
||||
|
||||
# ── GET /queue ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
|
||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── POST /undo ─────────────────────────────────────────────────────────────
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /stats ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /config ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
|
|
|||
427
app/style.py
427
app/style.py
|
|
@ -1,427 +0,0 @@
|
|||
"""Avocet — Writing style benchmark integration API.
|
||||
|
||||
Wraps scripts/benchmark_style.py and exposes it via the Avocet API.
|
||||
Connection config (coordinator_url, ollama_url, python_bin) is read
|
||||
from label_tool.yaml under the `cforch:` key — the same block used
|
||||
by cforch.py, so no new config section is needed.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/style".
|
||||
|
||||
Module-level globals (_BENCH_RUNNING, _bench_proc) follow the same
|
||||
testability pattern as cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None
|
||||
|
||||
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_style.py"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section for coordinator/ollama/python config."""
|
||||
f = _config_file()
|
||||
file_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
file_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse style config %s: %s", f, exc)
|
||||
return {
|
||||
"coordinator_url": file_cfg.get("coordinator_url", "http://10.1.10.71:7700"),
|
||||
"ollama_url": file_cfg.get("ollama_url", "http://localhost:11434"),
|
||||
"python_bin": file_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python"),
|
||||
}
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return available models grouped by source.
|
||||
|
||||
- ollama: fetched live from /api/tags (includes any models downloaded
|
||||
via the Models view — automatically in sync)
|
||||
- cf_text: fetched from cf-orch catalog endpoint (requires node profile
|
||||
entry + coordinator restart when new GGUFs are added)
|
||||
"""
|
||||
cfg = _load_config()
|
||||
|
||||
# Ollama models — live query so newly downloaded models appear immediately
|
||||
ollama_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(f"{cfg['ollama_url']}/api/tags", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
for m in resp.json().get("models", []):
|
||||
name = m.get("name", "")
|
||||
if name:
|
||||
size_bytes = m.get("size", 0)
|
||||
ollama_models.append({
|
||||
"id": name,
|
||||
"name": name,
|
||||
"source": "ollama",
|
||||
"size_mb": round(size_bytes / (1024 * 1024)) if size_bytes else None,
|
||||
"vram_mb": None,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch ollama models: %s", exc)
|
||||
|
||||
# cf-text catalog — fetched from cf-orch coordinator
|
||||
cftext_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
for model_id, entry in resp.json().items():
|
||||
if isinstance(entry, dict):
|
||||
cftext_models.append({
|
||||
"id": model_id,
|
||||
"name": model_id,
|
||||
"source": "cf-text",
|
||||
"vram_mb": entry.get("vram_mb"),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch cf-text catalog: %s", exc)
|
||||
|
||||
return {"ollama": ollama_models, "cf_text": cftext_models}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_style_benchmark(
|
||||
models: str = Query("", description="Comma-separated model IDs (empty = all)"),
|
||||
use_cforch: bool = Query(False),
|
||||
max_vram: int = Query(7200, description="Max VRAM MB for cf-orch OOM filter"),
|
||||
include_large: bool = Query(False, description="Include large (30B+) ollama models"),
|
||||
workers: int = Query(1, description="Parallel workers — run N models simultaneously"),
|
||||
) -> StreamingResponse:
|
||||
"""Spawn benchmark_style.py and stream stdout as SSE progress events.
|
||||
|
||||
On successful completion, emits a final `type: result` event containing
|
||||
the parsed JSON from the newest style_*.json file.
|
||||
"""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if _BENCH_RUNNING:
|
||||
raise HTTPException(409, "A writing style benchmark is already running")
|
||||
|
||||
cfg = _load_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_SCRIPT.exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_style.py not found at {_BENCH_SCRIPT}'})}\n\n"
|
||||
return
|
||||
|
||||
cmd = [python_bin, str(_BENCH_SCRIPT), "run"]
|
||||
|
||||
if models:
|
||||
cmd.extend(["--models", ",".join(m.strip() for m in models.split(",") if m.strip())])
|
||||
if use_cforch:
|
||||
cmd.extend(["--cforch", "--cforch-url", cfg["coordinator_url"],
|
||||
"--max-vram", str(max_vram)])
|
||||
if include_large:
|
||||
cmd.append("--include-large")
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_bench_proc = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
result_files = sorted(_RESULTS_DIR.glob("style_*.json"))
|
||||
if result_files:
|
||||
try:
|
||||
results = json.loads(result_files[-1].read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'results': results, 'filename': result_files[-1].name})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read style results: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
"""List past writing style benchmark runs, newest first.
|
||||
|
||||
Returns lightweight summaries (date, model count, top score).
|
||||
Use /results/{filename} to fetch full model-level detail.
|
||||
"""
|
||||
if not _RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
runs: list[dict] = []
|
||||
for f in sorted(_RESULTS_DIR.glob("style_*.json"), reverse=True):
|
||||
stem = f.stem # style_2026-04-22_1502
|
||||
date_str = stem.removeprefix("style_") # 2026-04-22_1502
|
||||
try:
|
||||
date_part, time_part = date_str.split("_")
|
||||
display_date = f"{date_part} {time_part[:2]}:{time_part[2:]}"
|
||||
except Exception:
|
||||
display_date = date_str
|
||||
|
||||
try:
|
||||
results = json.loads(f.read_text(encoding="utf-8"))
|
||||
top_score = max((r.get("avg_score", 0) for r in results), default=0)
|
||||
model_count = len(results)
|
||||
except Exception:
|
||||
top_score = 0
|
||||
model_count = 0
|
||||
|
||||
runs.append({
|
||||
"filename": f.name,
|
||||
"date": display_date,
|
||||
"model_count": model_count,
|
||||
"top_score": round(top_score, 1),
|
||||
})
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
|
||||
def get_latest_results() -> list[dict]:
|
||||
"""Return the latest writing style benchmark result list."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
files = sorted(_RESULTS_DIR.glob("style_*.json"))
|
||||
if not files:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(files[-1].read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
def get_results_by_filename(filename: str) -> list[dict]:
|
||||
"""Return writing style benchmark results for a specific run file."""
|
||||
if not filename.startswith("style_") or not filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename — expected style_*.json")
|
||||
f = _RESULTS_DIR / filename
|
||||
if not f.exists():
|
||||
raise HTTPException(404, f"Results file not found: {filename}")
|
||||
try:
|
||||
return json.loads(f.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /send-to-corrections ──────────────────────────────────────────────────
|
||||
|
||||
class SendToCorrectionsRequest(BaseModel):
|
||||
filename: str # style_YYYY-MM-DD_HHMM.json — the source run file
|
||||
model_ids: list[str] = [] # empty = all models in the run
|
||||
|
||||
|
||||
@router.post("/send-to-corrections")
|
||||
def send_to_corrections(req: SendToCorrectionsRequest) -> dict:
|
||||
"""Push writing style benchmark outputs into the SFT corrections queue.
|
||||
|
||||
Each prompt_result from the selected models becomes one SFT candidate
|
||||
with status='needs_review'. Duplicates are skipped via the 'id' field
|
||||
(hash of model_id + tag).
|
||||
"""
|
||||
if not req.filename.startswith("style_") or not req.filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename")
|
||||
|
||||
src = _RESULTS_DIR / req.filename
|
||||
if not src.exists():
|
||||
raise HTTPException(404, f"Results file not found: {req.filename}")
|
||||
|
||||
try:
|
||||
run_results: list[dict] = json.loads(src.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
# Resolve sft_candidates.jsonl path (same logic as sft.py)
|
||||
sft_data_dir = _ROOT / "data"
|
||||
sft_file = sft_data_dir / "sft_candidates.jsonl"
|
||||
|
||||
# Load existing IDs to deduplicate
|
||||
existing_ids: set[str] = set()
|
||||
if sft_file.exists():
|
||||
for line in sft_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
existing_ids.add(json.loads(line)["id"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
run_id = req.filename.removesuffix(".json") # style_2026-04-22_1502
|
||||
timestamp = datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
new_candidates: list[dict] = []
|
||||
for model_result in run_results:
|
||||
model_id = model_result.get("model_id", "")
|
||||
if req.model_ids and model_id not in req.model_ids:
|
||||
continue
|
||||
for pr in model_result.get("prompt_results", []):
|
||||
tag = pr.get("tag", "")
|
||||
# Stable id: deterministic hash of run + model + prompt tag
|
||||
candidate_id = str(uuid.uuid5(
|
||||
uuid.NAMESPACE_URL,
|
||||
f"style-benchmark/{run_id}/{model_id}/{tag}",
|
||||
))
|
||||
if candidate_id in existing_ids:
|
||||
continue
|
||||
|
||||
score_pct = pr.get("score", 0.0) / 100.0
|
||||
signals = pr.get("signals", {})
|
||||
|
||||
# Build the prompt message list matching the benchmark's actual request
|
||||
prompt_messages = [
|
||||
{"role": "system", "content": _STYLE_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": pr.get("user_prompt", tag)},
|
||||
]
|
||||
|
||||
new_candidates.append({
|
||||
"id": candidate_id,
|
||||
"source": "style-benchmark",
|
||||
"benchmark_run_id": run_id,
|
||||
"timestamp": timestamp,
|
||||
"status": "needs_review",
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_response": pr.get("output", ""),
|
||||
"corrected_response": None,
|
||||
"quality_score": round(score_pct, 4),
|
||||
"failure_reason": _build_failure_reason(pr, signals),
|
||||
"failure_category": None,
|
||||
"task_id": f"style/{tag}",
|
||||
"task_type": "style-match",
|
||||
"task_name": tag.replace("_", " ").title(),
|
||||
"model_id": model_id,
|
||||
"model_name": model_id,
|
||||
"node_id": "",
|
||||
"gpu_id": 0,
|
||||
"tokens_per_sec": 0,
|
||||
})
|
||||
existing_ids.add(candidate_id)
|
||||
|
||||
if new_candidates:
|
||||
sft_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(sft_file, "a", encoding="utf-8") as fh:
|
||||
for c in new_candidates:
|
||||
fh.write(json.dumps(c) + "\n")
|
||||
|
||||
return {"imported": len(new_candidates), "skipped": 0}
|
||||
|
||||
|
||||
# Excerpt of the system prompt used in benchmark_style.py — reproduced here
|
||||
# so the SFT candidate captures the full generation context.
|
||||
_STYLE_SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No em dashes. No semicolons. No filler phrases.\n"
|
||||
"- Direct. Opinionated. Community-first."
|
||||
)
|
||||
|
||||
|
||||
def _build_failure_reason(pr: dict, signals: dict) -> str | None:
|
||||
"""Return a human-readable failure reason string if there are violations."""
|
||||
reasons = []
|
||||
if signals.get("em_dash_count", 0) > 0:
|
||||
reasons.append(f"{signals['em_dash_count']} em dash(es)")
|
||||
if signals.get("semicolon_count", 0) > 0:
|
||||
reasons.append(f"{signals['semicolon_count']} semicolon(s)")
|
||||
if signals.get("filler_hits"):
|
||||
reasons.append(f"filler phrases: {', '.join(signals['filler_hits'])}")
|
||||
if not pr.get("output", "").strip():
|
||||
reasons.append("empty output")
|
||||
return "; ".join(reasons) if reasons else None
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_style_benchmark() -> dict:
|
||||
"""Kill the running writing style benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No writing style benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate style benchmark: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
|
|
@ -1,339 +0,0 @@
|
|||
"""Avocet -- train job queue API.
|
||||
|
||||
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
|
||||
_running_procs dict in api.py with a persistent, inspectable queue.
|
||||
|
||||
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
|
||||
GET /jobs -- list all jobs, newest first
|
||||
POST /jobs -- create a new job
|
||||
GET /jobs/{id} -- get one job by id
|
||||
DELETE /jobs/{id}/cancel -- cancel a queued or running job
|
||||
GET /jobs/{id}/run -- SSE: run the job, stream stdout
|
||||
GET /results -- list completed models with training_info.json metrics
|
||||
|
||||
SQLite schema:
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
|
||||
Testability seam: _DB_PATH global, override via set_db_path().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_running_procs: dict[str, Any] = {}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams -------------------------------------------------
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
def set_config_dir(path: "Path | None") -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Config helpers ----------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_train_config() -> dict:
|
||||
"""Read python_bin from label_tool.yaml.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. label_tool.yaml train: python_bin
|
||||
2. label_tool.yaml cforch: python_bin
|
||||
3. Hardcoded default (classifiers conda env)
|
||||
"""
|
||||
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
f = _config_file()
|
||||
train_cfg: dict = {}
|
||||
cforch_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
train_cfg = raw.get("train", {}) or {}
|
||||
cforch_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse train config %s: %s", f, exc)
|
||||
return {
|
||||
"python_bin": train_cfg.get(
|
||||
"python_bin",
|
||||
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# -- Database helpers --------------------------------------------------
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
"""Create jobs table if it does not exist. Called lazily per request."""
|
||||
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with _db() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||
return {k: row[k] for k in row.keys()}
|
||||
|
||||
|
||||
# -- GPU selection (copied from api.py) --------------------------------
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return index of GPU with most free VRAM, or empty string."""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True, timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
# -- Pydantic models ---------------------------------------------------
|
||||
|
||||
class CreateJobRequest(BaseModel):
|
||||
type: str # "classifier" | "llm-sft"
|
||||
model_key: str # e.g. "deberta-small"
|
||||
config_json: dict = {}
|
||||
|
||||
|
||||
# -- Routes ------------------------------------------------------------
|
||||
|
||||
@router.get("/jobs")
|
||||
def list_jobs() -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||
return {"jobs": [_row_to_dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("/jobs")
|
||||
def create_job(req: CreateJobRequest) -> dict:
|
||||
if req.type not in ("classifier", "llm-sft"):
|
||||
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
|
||||
_init_db()
|
||||
job_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES (?, ?, ?, 'queued', ?, ?)",
|
||||
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
|
||||
)
|
||||
return {"id": job_id, "type": req.type, "model_key": req.model_key,
|
||||
"status": "queued", "config_json": req.config_json,
|
||||
"created_at": now, "started_at": None, "completed_at": None, "error": None}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
def get_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
return _row_to_dict(row)
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}/cancel")
|
||||
def cancel_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] not in ("queued", "running"):
|
||||
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
|
||||
proc = _running_procs.pop(job_id, None)
|
||||
if proc is not None:
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
try:
|
||||
proc.kill()
|
||||
except OSError:
|
||||
pass
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/run")
|
||||
def run_job(job_id: str) -> StreamingResponse:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] != "queued":
|
||||
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
|
||||
job = _row_to_dict(row)
|
||||
|
||||
def generate():
|
||||
cfg = _load_train_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
config = json.loads(job["config_json"] or "{}")
|
||||
model_key = job["model_key"]
|
||||
epochs = config.get("epochs", 5)
|
||||
|
||||
if job["type"] == "classifier":
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
data_dir = _ROOT / "data"
|
||||
for sf in config.get("score_files", []):
|
||||
resolved = (data_dir / sf).resolve()
|
||||
if resolved.is_relative_to(data_dir.resolve()):
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
elif job["type"] == "llm-sft":
|
||||
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
else:
|
||||
job_type = job["type"]
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
|
||||
return
|
||||
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
|
||||
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
|
||||
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
|
||||
)
|
||||
_running_procs[job_id] = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if proc.returncode == 0:
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
|
||||
(finished_at, job_id))
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
err = f"Process exited with code {proc.returncode}"
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop(job_id, None)
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> dict:
|
||||
if not _MODELS_DIR.exists():
|
||||
return {"results": []}
|
||||
results = []
|
||||
for sub in _MODELS_DIR.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||
return {"results": results}
|
||||
|
|
@ -106,7 +106,7 @@ def read_jsonl(path: Path) -> list[dict]:
|
|||
def write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
"""Write records to a JSONL file, overwriting any existing content."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
|
||||
content = "\n".join(json.dumps(r) for r in records)
|
||||
path.write_text(content + ("\n" if records else ""), encoding="utf-8")
|
||||
|
||||
|
||||
|
|
@ -114,4 +114,4 @@ def append_jsonl(path: Path, record: dict) -> None:
|
|||
"""Append a single record to a JSONL file."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
fh.write(json.dumps(record) + "\n")
|
||||
|
|
|
|||
427
app/voice.py
427
app/voice.py
|
|
@ -1,427 +0,0 @@
|
|||
"""Avocet — Voice benchmark integration API.
|
||||
|
||||
Wraps scripts/benchmark_voice.py and exposes it via the Avocet API.
|
||||
Connection config (coordinator_url, ollama_url, python_bin) is read
|
||||
from label_tool.yaml under the `cforch:` key — the same block used
|
||||
by cforch.py, so no new config section is needed.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/voice".
|
||||
|
||||
Module-level globals (_BENCH_RUNNING, _bench_proc) follow the same
|
||||
testability pattern as cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None
|
||||
|
||||
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_voice.py"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section for coordinator/ollama/python config."""
|
||||
f = _config_file()
|
||||
file_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
file_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse voice config %s: %s", f, exc)
|
||||
return {
|
||||
"coordinator_url": file_cfg.get("coordinator_url", "http://10.1.10.71:7700"),
|
||||
"ollama_url": file_cfg.get("ollama_url", "http://localhost:11434"),
|
||||
"python_bin": file_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python"),
|
||||
}
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return available models grouped by source.
|
||||
|
||||
- ollama: fetched live from /api/tags (includes any models downloaded
|
||||
via the Models view — automatically in sync)
|
||||
- cf_text: fetched from cf-orch catalog endpoint (requires node profile
|
||||
entry + coordinator restart when new GGUFs are added)
|
||||
"""
|
||||
cfg = _load_config()
|
||||
|
||||
# Ollama models — live query so newly downloaded models appear immediately
|
||||
ollama_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(f"{cfg['ollama_url']}/api/tags", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
for m in resp.json().get("models", []):
|
||||
name = m.get("name", "")
|
||||
if name:
|
||||
size_bytes = m.get("size", 0)
|
||||
ollama_models.append({
|
||||
"id": name,
|
||||
"name": name,
|
||||
"source": "ollama",
|
||||
"size_mb": round(size_bytes / (1024 * 1024)) if size_bytes else None,
|
||||
"vram_mb": None,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch ollama models: %s", exc)
|
||||
|
||||
# cf-text catalog — fetched from cf-orch coordinator
|
||||
cftext_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
for model_id, entry in resp.json().items():
|
||||
if isinstance(entry, dict):
|
||||
cftext_models.append({
|
||||
"id": model_id,
|
||||
"name": model_id,
|
||||
"source": "cf-text",
|
||||
"vram_mb": entry.get("vram_mb"),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch cf-text catalog: %s", exc)
|
||||
|
||||
return {"ollama": ollama_models, "cf_text": cftext_models}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_voice_benchmark(
|
||||
models: str = Query("", description="Comma-separated model IDs (empty = all)"),
|
||||
use_cforch: bool = Query(False),
|
||||
max_vram: int = Query(7200, description="Max VRAM MB for cf-orch OOM filter"),
|
||||
include_large: bool = Query(False, description="Include large (30B+) ollama models"),
|
||||
workers: int = Query(1, description="Parallel workers — run N models simultaneously"),
|
||||
) -> StreamingResponse:
|
||||
"""Spawn benchmark_voice.py and stream stdout as SSE progress events.
|
||||
|
||||
On successful completion, emits a final `type: result` event containing
|
||||
the parsed JSON from the newest voice_*.json file.
|
||||
"""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if _BENCH_RUNNING:
|
||||
raise HTTPException(409, "A voice benchmark is already running")
|
||||
|
||||
cfg = _load_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_SCRIPT.exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_voice.py not found at {_BENCH_SCRIPT}'})}\n\n"
|
||||
return
|
||||
|
||||
cmd = [python_bin, str(_BENCH_SCRIPT), "run"]
|
||||
|
||||
if models:
|
||||
cmd.extend(["--models", ",".join(m.strip() for m in models.split(",") if m.strip())])
|
||||
if use_cforch:
|
||||
cmd.extend(["--cforch", "--cforch-url", cfg["coordinator_url"],
|
||||
"--max-vram", str(max_vram)])
|
||||
if include_large:
|
||||
cmd.append("--include-large")
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_bench_proc = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
result_files = sorted(_RESULTS_DIR.glob("voice_*.json"))
|
||||
if result_files:
|
||||
try:
|
||||
results = json.loads(result_files[-1].read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'results': results, 'filename': result_files[-1].name})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read voice results: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
"""List past voice benchmark runs, newest first.
|
||||
|
||||
Returns lightweight summaries (date, model count, top score).
|
||||
Use /results/{filename} to fetch full model-level detail.
|
||||
"""
|
||||
if not _RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
runs: list[dict] = []
|
||||
for f in sorted(_RESULTS_DIR.glob("voice_*.json"), reverse=True):
|
||||
stem = f.stem # voice_2026-04-22_1502
|
||||
date_str = stem.removeprefix("voice_") # 2026-04-22_1502
|
||||
try:
|
||||
date_part, time_part = date_str.split("_")
|
||||
display_date = f"{date_part} {time_part[:2]}:{time_part[2:]}"
|
||||
except Exception:
|
||||
display_date = date_str
|
||||
|
||||
try:
|
||||
results = json.loads(f.read_text(encoding="utf-8"))
|
||||
top_score = max((r.get("avg_score", 0) for r in results), default=0)
|
||||
model_count = len(results)
|
||||
except Exception:
|
||||
top_score = 0
|
||||
model_count = 0
|
||||
|
||||
runs.append({
|
||||
"filename": f.name,
|
||||
"date": display_date,
|
||||
"model_count": model_count,
|
||||
"top_score": round(top_score, 1),
|
||||
})
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
|
||||
def get_latest_results() -> list[dict]:
|
||||
"""Return the latest voice benchmark result list."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
files = sorted(_RESULTS_DIR.glob("voice_*.json"))
|
||||
if not files:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(files[-1].read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
def get_results_by_filename(filename: str) -> list[dict]:
|
||||
"""Return voice benchmark results for a specific run file."""
|
||||
if not filename.startswith("voice_") or not filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename — expected voice_*.json")
|
||||
f = _RESULTS_DIR / filename
|
||||
if not f.exists():
|
||||
raise HTTPException(404, f"Results file not found: {filename}")
|
||||
try:
|
||||
return json.loads(f.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /send-to-corrections ──────────────────────────────────────────────────
|
||||
|
||||
class SendToCorrectionsRequest(BaseModel):
|
||||
filename: str # voice_YYYY-MM-DD_HHMM.json — the source run file
|
||||
model_ids: list[str] = [] # empty = all models in the run
|
||||
|
||||
|
||||
@router.post("/send-to-corrections")
|
||||
def send_to_corrections(req: SendToCorrectionsRequest) -> dict:
|
||||
"""Push voice benchmark outputs into the SFT corrections queue.
|
||||
|
||||
Each prompt_result from the selected models becomes one SFT candidate
|
||||
with status='needs_review'. Duplicates are skipped via the 'id' field
|
||||
(hash of model_id + tag).
|
||||
"""
|
||||
if not req.filename.startswith("voice_") or not req.filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename")
|
||||
|
||||
src = _RESULTS_DIR / req.filename
|
||||
if not src.exists():
|
||||
raise HTTPException(404, f"Results file not found: {req.filename}")
|
||||
|
||||
try:
|
||||
run_results: list[dict] = json.loads(src.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
# Resolve sft_candidates.jsonl path (same logic as sft.py)
|
||||
sft_data_dir = _ROOT / "data"
|
||||
sft_file = sft_data_dir / "sft_candidates.jsonl"
|
||||
|
||||
# Load existing IDs to deduplicate
|
||||
existing_ids: set[str] = set()
|
||||
if sft_file.exists():
|
||||
for line in sft_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
existing_ids.add(json.loads(line)["id"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
run_id = req.filename.removesuffix(".json") # voice_2026-04-22_1502
|
||||
timestamp = datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
new_candidates: list[dict] = []
|
||||
for model_result in run_results:
|
||||
model_id = model_result.get("model_id", "")
|
||||
if req.model_ids and model_id not in req.model_ids:
|
||||
continue
|
||||
for pr in model_result.get("prompt_results", []):
|
||||
tag = pr.get("tag", "")
|
||||
# Stable id: deterministic hash of run + model + prompt tag
|
||||
candidate_id = str(uuid.uuid5(
|
||||
uuid.NAMESPACE_URL,
|
||||
f"voice-benchmark/{run_id}/{model_id}/{tag}",
|
||||
))
|
||||
if candidate_id in existing_ids:
|
||||
continue
|
||||
|
||||
score_pct = pr.get("score", 0.0) / 100.0
|
||||
signals = pr.get("signals", {})
|
||||
|
||||
# Build the prompt message list matching the benchmark's actual request
|
||||
prompt_messages = [
|
||||
{"role": "system", "content": _VOICE_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": pr.get("user_prompt", tag)},
|
||||
]
|
||||
|
||||
new_candidates.append({
|
||||
"id": candidate_id,
|
||||
"source": "voice-benchmark",
|
||||
"benchmark_run_id": run_id,
|
||||
"timestamp": timestamp,
|
||||
"status": "needs_review",
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_response": pr.get("output", ""),
|
||||
"corrected_response": None,
|
||||
"quality_score": round(score_pct, 4),
|
||||
"failure_reason": _build_failure_reason(pr, signals),
|
||||
"failure_category": None,
|
||||
"task_id": f"voice/{tag}",
|
||||
"task_type": "voice-match",
|
||||
"task_name": tag.replace("_", " ").title(),
|
||||
"model_id": model_id,
|
||||
"model_name": model_id,
|
||||
"node_id": "",
|
||||
"gpu_id": 0,
|
||||
"tokens_per_sec": 0,
|
||||
})
|
||||
existing_ids.add(candidate_id)
|
||||
|
||||
if new_candidates:
|
||||
sft_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(sft_file, "a", encoding="utf-8") as fh:
|
||||
for c in new_candidates:
|
||||
fh.write(json.dumps(c) + "\n")
|
||||
|
||||
return {"imported": len(new_candidates), "skipped": 0}
|
||||
|
||||
|
||||
# Excerpt of the system prompt used in benchmark_voice.py — reproduced here
|
||||
# so the SFT candidate captures the full generation context.
|
||||
_VOICE_SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No em dashes. No semicolons. No filler phrases.\n"
|
||||
"- Direct. Opinionated. Community-first."
|
||||
)
|
||||
|
||||
|
||||
def _build_failure_reason(pr: dict, signals: dict) -> str | None:
|
||||
"""Return a human-readable failure reason string if there are violations."""
|
||||
reasons = []
|
||||
if signals.get("em_dash_count", 0) > 0:
|
||||
reasons.append(f"{signals['em_dash_count']} em dash(es)")
|
||||
if signals.get("semicolon_count", 0) > 0:
|
||||
reasons.append(f"{signals['semicolon_count']} semicolon(s)")
|
||||
if signals.get("filler_hits"):
|
||||
reasons.append(f"filler phrases: {', '.join(signals['filler_hits'])}")
|
||||
if not pr.get("output", "").strip():
|
||||
reasons.append("empty output")
|
||||
return "; ".join(reasons) if reasons else None
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_voice_benchmark() -> dict:
|
||||
"""Kill the running voice benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No voice benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate voice benchmark: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
|
|
@ -26,119 +26,3 @@ max_per_account: 500
|
|||
# produced by circuitforge-orch's benchmark harness.
|
||||
sft:
|
||||
bench_results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
||||
|
||||
# cf-orch integration — LLM benchmark harness via cf-orch coordinator.
|
||||
# All keys here override the corresponding environment variables.
|
||||
# Omit any key to fall back to the env var (see .env.example).
|
||||
cforch:
|
||||
# Path to cf-orch's benchmark.py script
|
||||
bench_script: /path/to/circuitforge-orch/scripts/benchmark.py
|
||||
# Task and model definition files (yaml)
|
||||
bench_tasks: /path/to/circuitforge-orch/scripts/bench_tasks.yaml
|
||||
bench_models: /path/to/circuitforge-orch/scripts/bench_models.yaml
|
||||
# Where benchmark results are written (also used for SFT candidate discovery)
|
||||
results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
||||
# Python interpreter with cf-orch installed
|
||||
python_bin: /devl/miniconda3/envs/cf/bin/python
|
||||
|
||||
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST / CF_JUDGE_URL / HF_TOKEN
|
||||
# coordinator_url: http://localhost:7700
|
||||
# license_key: CFG-AVCT-xxxx-xxxx-xxxx
|
||||
# ollama_url: http://localhost:11434
|
||||
# ollama_model: llama3.2:3b
|
||||
# embed_model: nomic-embed-text # Ollama embedding model for EmbeddingKNNAdapter
|
||||
# judge_url: http://10.1.10.158:8008 # Sif cf-text — LLM-as-judge secondary scorer
|
||||
# judge_url: http://10.1.10.71:8008 # Heimdall cf-text (alternative)
|
||||
# Or set CF_JUDGE_URL. Populates the Judge URL field in the LLM Eval UI automatically.
|
||||
# hf_token: hf_xxxxxxxxxxxxxxxxxxxx # HuggingFace token — required for gated/terms-restricted models
|
||||
|
||||
# Directory containing per-node profile YAMLs (cf-orch node profiles).
|
||||
# Default: derived from bench_script location (../../profiles/nodes).
|
||||
# profiles_dir: /Library/Development/CircuitForge/circuitforge-orch/circuitforge_orch/profiles/nodes
|
||||
|
||||
# Imitate tab — pull real samples from sibling CF product APIs and run them
|
||||
# through local LLMs to build a corrections dataset.
|
||||
# ollama_url defaults to cforch.ollama_url if omitted here.
|
||||
imitate:
|
||||
ollama_url: http://localhost:11434 # optional — falls back to cforch.ollama_url
|
||||
|
||||
products:
|
||||
- id: peregrine
|
||||
name: Peregrine
|
||||
icon: "🦅"
|
||||
description: Job search assistant — live job listings
|
||||
base_url: http://localhost:8601
|
||||
health_path: /api/jobs/counts
|
||||
sample_endpoint: /api/jobs?status=pending&limit=5
|
||||
text_fields: [title, company, description]
|
||||
prompt_template: "Analyze this job listing and identify the key requirements, must-have skills, and any culture signals that would help tailor an application:\n\n{text}"
|
||||
|
||||
- id: osprey
|
||||
name: Osprey
|
||||
icon: "📞"
|
||||
description: Gov't hold-line automation — recent call records
|
||||
base_url: http://localhost:8520
|
||||
health_path: /api/health
|
||||
sample_endpoint: /api/calls/recent
|
||||
text_fields: [agency, issue, notes]
|
||||
prompt_template: "Draft a clear, professional follow-up letter for this government hold-line call. Include what was discussed, what action the agency committed to, and a polite deadline for response:\n\n{text}"
|
||||
|
||||
- id: linnet
|
||||
name: Linnet
|
||||
icon: "🐦"
|
||||
description: Real-time tone annotation — Elcor-style subtext for ND users
|
||||
base_url: http://localhost:8522
|
||||
health_path: /health
|
||||
sample_endpoint: /samples
|
||||
text_fields: [text, context]
|
||||
prompt_template: "Annotate the emotional tone and subtext of the following text using explicit Elcor-style markers (e.g. [SINCERELY], [UNCERTAIN], [FRUSTRATED]). Identify implied emotions, potential sarcasm, and any ambiguity that might be misread by neurodivergent readers:\n\n{text}"
|
||||
|
||||
- id: kiwi
|
||||
name: Kiwi
|
||||
icon: "🥝"
|
||||
description: Pantry tracker
|
||||
base_url: http://localhost:8511
|
||||
sample_endpoint: /api/inventory
|
||||
text_fields: [name, category, notes]
|
||||
prompt_template: "Describe this pantry item and estimate how best to use it:\n\n{text}"
|
||||
|
||||
- id: snipe
|
||||
name: Snipe
|
||||
icon: "🎯"
|
||||
description: eBay trust scoring
|
||||
base_url: http://localhost:8509
|
||||
sample_endpoint: /api/listings
|
||||
text_fields: [title, description, seller_info]
|
||||
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
|
||||
|
||||
- id: pagepiper
|
||||
name: Pagepiper
|
||||
icon: "📄"
|
||||
description: "PDF/rulebook RAG tool: page-level text chunks"
|
||||
base_url: http://localhost:8511
|
||||
health_path: /api/health
|
||||
sample_endpoint: /api/library
|
||||
chunk_endpoint: /api/library/sample-chunks?limit=50 # requires pagepiper#6
|
||||
text_fields: [title]
|
||||
prompt_template: "Summarize the key rules described in this passage:\n\n{text}"
|
||||
|
||||
# ── Log corpus (Turnstone training data) ──────────────────────────────────────
|
||||
corpus:
|
||||
# Directory containing pipeline JSONL log files to ingest (pull-side).
|
||||
# Files named <script>_<ts>.jsonl; one structured record per line.
|
||||
# POST /api/corpus/pipeline-ingest walks this dir and imports new files.
|
||||
# NFS-mounted on both Heimdall and Sif at /Library/Assets/
|
||||
pipeline_ingest_dir: /Library/Assets/logs/pipeline/
|
||||
|
||||
# Turnstone push sources (consent-gated, token-authenticated).
|
||||
# sources:
|
||||
# - token: "your-bearer-token"
|
||||
# source_host: "node.local"
|
||||
# owner: YourName
|
||||
# consent_date: "2026-05-17"
|
||||
# consent_method: signal_chat
|
||||
|
||||
# ── Embedding model comparison harness ────────────────────────────────────────
|
||||
embed_bench:
|
||||
# ollama_url: http://localhost:11434 # optional; falls back to cforch.ollama_url
|
||||
# top_k: 5 # default hits per model per query
|
||||
|
|
|
|||
|
|
@ -22,8 +22,5 @@ dependencies:
|
|||
# Optional: BGE reranker adapter
|
||||
# - FlagEmbedding
|
||||
|
||||
# CircuitForge shared core (LLM router, tier system, config)
|
||||
- circuitforge-core>=0.9.0
|
||||
|
||||
# Dev
|
||||
- pytest>=8.0
|
||||
|
|
|
|||
61
manage.sh
61
manage.sh
|
|
@ -90,18 +90,6 @@ usage() {
|
|||
echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]"
|
||||
echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]"
|
||||
echo ""
|
||||
echo " Planning Benchmark:"
|
||||
echo -e " ${GREEN}plans-bench [args]${NC} Run benchmark_plans.py (args passed through)"
|
||||
echo -e " ${GREEN}plans-list${NC} Shortcut: --list-models"
|
||||
echo -e " ${GREEN}plans-run <model> [args]${NC} Run a single model (--verbose auto-added)"
|
||||
echo -e " ${GREEN}plans-compare <m1> <m2> [more]${NC} Compare models side-by-side"
|
||||
echo ""
|
||||
echo " Writing Style Benchmark:"
|
||||
echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)"
|
||||
echo -e " ${GREEN}style-list${NC} List available ollama models for style bench"
|
||||
echo -e " ${GREEN}style-run [args]${NC} Run writing style benchmark (--models, --samples, --include-large, --scan-disk PATH, --cforch)"
|
||||
echo -e " ${GREEN}style-last${NC} Print most recent writing style benchmark report"
|
||||
echo ""
|
||||
echo " Dev:"
|
||||
echo -e " ${GREEN}dev${NC} Hot-reload: uvicorn --reload (:8503) + Vite HMR (:5173)"
|
||||
echo -e " ${GREEN}test${NC} Run pytest suite"
|
||||
|
|
@ -133,8 +121,6 @@ case "$CMD" in
|
|||
fi
|
||||
mkdir -p "$LOG_DIR"
|
||||
API_LOG="${LOG_DIR}/api.log"
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
info "Building Vue SPA…"
|
||||
(cd web && npm run build) >> "$API_LOG" 2>&1
|
||||
info "Starting FastAPI on port ${API_PORT}…"
|
||||
|
|
@ -187,9 +173,6 @@ case "$CMD" in
|
|||
mkdir -p "$LOG_DIR"
|
||||
DEV_API_LOG="${LOG_DIR}/dev-api.log"
|
||||
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
|
||||
if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
|
||||
warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))"
|
||||
else
|
||||
|
|
@ -266,50 +249,6 @@ case "$CMD" in
|
|||
exec "$0" benchmark --compare "$@"
|
||||
;;
|
||||
|
||||
plans-bench)
|
||||
info "Running planning benchmark (${ENV_UI})…"
|
||||
"$PYTHON_UI" scripts/benchmark_plans.py "$@"
|
||||
;;
|
||||
|
||||
plans-list)
|
||||
exec "$0" plans-bench --list-models
|
||||
;;
|
||||
|
||||
plans-run)
|
||||
if [[ $# -lt 1 ]]; then
|
||||
error "Usage: ./manage.sh plans-run <model-key> [extra args]"
|
||||
fi
|
||||
MODEL="$1"; shift
|
||||
exec "$0" plans-bench --model "$MODEL" --verbose "$@"
|
||||
;;
|
||||
|
||||
plans-compare)
|
||||
if [[ $# -lt 2 ]]; then
|
||||
error "Usage: ./manage.sh plans-compare <model1> <model2> [more…]"
|
||||
fi
|
||||
exec "$0" plans-bench --compare "$@" --verbose
|
||||
;;
|
||||
|
||||
style-bench)
|
||||
info "Running writing style benchmark (${ENV_BM})…"
|
||||
if [[ ! -x "$PYTHON_BM" ]]; then
|
||||
error "Python not found in ${ENV_BM} env at ${PYTHON_BM}"
|
||||
fi
|
||||
"$PYTHON_BM" scripts/benchmark_style.py "$@"
|
||||
;;
|
||||
|
||||
style-list)
|
||||
exec "$0" style-bench --list-models
|
||||
;;
|
||||
|
||||
style-run)
|
||||
exec "$0" style-bench --run "$@"
|
||||
;;
|
||||
|
||||
style-last)
|
||||
exec "$0" style-bench --show-last
|
||||
;;
|
||||
|
||||
help|--help|-h)
|
||||
usage
|
||||
;;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,3 @@ testpaths = tests
|
|||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
markers =
|
||||
gpu: requires an idle GPU; excluded from default runs
|
||||
slow: long-running test; excluded from default CI runs
|
||||
|
|
|
|||
|
|
@ -3,4 +3,3 @@ pydantic>=2.0.0
|
|||
uvicorn[standard]>=0.20.0
|
||||
httpx>=0.24.0
|
||||
pytest>=7.0.0
|
||||
pyyaml>=6.0
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ from scripts.classifier_adapters import (
|
|||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
EmbeddingKNNAdapter,
|
||||
FineTunedAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
|
|
@ -131,13 +130,6 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
|
|||
"params": "600M",
|
||||
"default": False,
|
||||
},
|
||||
"embed-knn-nomic": {
|
||||
"adapter": EmbeddingKNNAdapter,
|
||||
"model_id": "nomic-embed-text",
|
||||
"params": "local-embed",
|
||||
"default": False, # requires orch or ollama; use --include-slow
|
||||
"kwargs": {"k": 3},
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -192,42 +184,6 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
|||
return found
|
||||
|
||||
|
||||
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
|
||||
"""Sample up to k_per_label formatted email texts per label from a scored JSONL.
|
||||
|
||||
Formats each row as 'Subject: {subject}\n\n{body[:600]}' — the same format
|
||||
EmbeddingKNNAdapter uses at classify() time. Rows missing the 'label' key
|
||||
are skipped silently.
|
||||
|
||||
Returns dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
|
||||
"""
|
||||
result: dict[str, list[str]] = {}
|
||||
p = Path(path)
|
||||
with p.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except json.JSONDecodeError as exc:
|
||||
print(f"[build_exemplars] WARN: skipping malformed line: {exc}", flush=True)
|
||||
continue
|
||||
label = row.get("label")
|
||||
if not label:
|
||||
continue
|
||||
subject = row.get("subject", "")
|
||||
body = row.get("body", "")
|
||||
if not subject and not body:
|
||||
continue
|
||||
texts = result.setdefault(label, [])
|
||||
if len(texts) < k_per_label:
|
||||
texts.append(
|
||||
f"Subject: {subject}\n\n{body[:600]}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
||||
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
||||
active: dict[str, dict[str, Any]] = {
|
||||
|
|
|
|||
|
|
@ -1,734 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""CF-specific planning benchmark — compare base models before fine-tuning.
|
||||
|
||||
Sends held-out CircuitForge planning prompts to one or more models via the
|
||||
cf-text (local) or cf-orch API, then scores responses against CF-specific
|
||||
rubrics. Use this to select the best base model for SFT.
|
||||
|
||||
Scoring rubrics (each 0-1, summed to total/N):
|
||||
- task_structure : uses checkbox syntax (- [ ]), git commit steps
|
||||
- tier_awareness : mentions Free/Paid/Premium/Ultra tiers
|
||||
- privacy_pillar : mentions privacy/local-inference/no-logging
|
||||
- safety_pillar : mentions safety, human approval, or reversibility
|
||||
- accessibility : mentions ND/accessibility/adaptive needs
|
||||
- license_split : mentions MIT vs BSL or open-core model
|
||||
- file_paths : uses plausible file path references
|
||||
- cf_conventions : uses conda run -n cf, /Library/Development/, or known CF dirs
|
||||
- paired_coherence : (paired only) plan references the design doc's feature name
|
||||
- length_ok : 300–2500 words (under-short = hallucination risk; over-long = padding)
|
||||
|
||||
Usage
|
||||
-----
|
||||
# List available model targets
|
||||
python scripts/benchmark_plans.py --list-models
|
||||
|
||||
# Run all held-out prompts against a single model, print report
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b
|
||||
|
||||
# Compare two models side-by-side
|
||||
python scripts/benchmark_plans.py --compare granite-4.1-8b deepseek-r1-7b-4bit
|
||||
|
||||
# Run with a custom API base (cf-text default: http://localhost:8080/v1)
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b --api-base http://localhost:8080/v1
|
||||
|
||||
# Export detailed results JSON
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b --output data/bench_results.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR = _ROOT / "data"
|
||||
|
||||
CF_TEXT_BASE = "http://localhost:8080/v1"
|
||||
CF_ORCH_BASE = "http://localhost:8090/v1"
|
||||
CF_COORD_URL = "http://10.1.10.71:7700" # cf-orch coordinator (LAN)
|
||||
|
||||
# ── Held-out prompts ───────────────────────────────────────────────────────────
|
||||
# These are NOT in the training export (no matching docs in circuitforge-plans/).
|
||||
# Each prompt exercises a different CF planning domain.
|
||||
|
||||
HELD_OUT_PROMPTS: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "ho_001",
|
||||
"name": "kiwi_barcode_ocr",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Kiwi, a CircuitForge pantry-tracking product. "
|
||||
"Write a detailed implementation plan for adding barcode scanning via device camera "
|
||||
"and receipt OCR to the item-add flow.\n\n"
|
||||
"The plan should include: file structure (create/modify), step-by-step task checklist "
|
||||
"with checkboxes, any DB migrations, and git commit steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_002",
|
||||
"name": "peregrine_ats_scoring",
|
||||
"domain": "feature_design",
|
||||
"prompt": (
|
||||
"Write a design document for Peregrine: ATS keyword scoring for job applications.\n\n"
|
||||
"Context: Peregrine users paste job descriptions and their resume. "
|
||||
"We want to score how well the resume keywords match the JD and suggest rewrites. "
|
||||
"Describe the architecture, data flow, and key design decisions."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "tier_awareness", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_003",
|
||||
"name": "tier_gate_local_llm",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the tier-gating architecture for a new CircuitForge product. "
|
||||
"The product should:\n"
|
||||
"- Default to local LLM inference for all tiers\n"
|
||||
"- Unlock cloud LLM for Paid tier and above\n"
|
||||
"- Keep fine-tuned model weights for Premium/Ultra only\n\n"
|
||||
"Describe how the tier check integrates with the LLM router, "
|
||||
"what happens when a Free user tries a Paid-tier feature, "
|
||||
"and how BYOK (bring-your-own-key) fits in."
|
||||
),
|
||||
"expected_signals": ["tier_awareness", "privacy_pillar", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_004",
|
||||
"name": "heimdall_webhook_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Heimdall feature into a detailed implementation plan with "
|
||||
"file structure and task checkboxes — Stripe webhook handler for subscription lifecycle.\n\n"
|
||||
"Heimdall is the CircuitForge license server (FastAPI + SQLite). "
|
||||
"The webhook needs to handle checkout.session.completed, "
|
||||
"customer.subscription.updated, and customer.subscription.deleted events."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_005",
|
||||
"name": "nd_accessible_onboarding",
|
||||
"domain": "ux_design",
|
||||
"prompt": (
|
||||
"You are a product designer working on Harrier, a CircuitForge tool for "
|
||||
"helping people navigate government benefits applications.\n\n"
|
||||
"Design the onboarding flow for neurodivergent (ND) users. "
|
||||
"Consider: ADHD time-blindness, executive function challenges, demand avoidance, "
|
||||
"and rejection sensitivity. The flow should reduce cognitive load and "
|
||||
"never use urgency or panic patterns."
|
||||
),
|
||||
"expected_signals": ["accessibility", "safety_pillar", "privacy_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_006",
|
||||
"name": "circuitforge_core_extraction",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Produce a CircuitForge-style design document for the following circuitforge-core "
|
||||
"feature — shared ActivityPub federation module.\n\n"
|
||||
"Background: Multiple CF products (Kiwi, Rook, Snipe) want to publish updates "
|
||||
"to ActivityPub. Build it once in cf-core (MIT licensed) so all products can use it. "
|
||||
"Design the module API, describe what belongs in MIT vs BSL, and note federation "
|
||||
"privacy constraints."
|
||||
),
|
||||
"expected_signals": ["license_split", "privacy_pillar", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_007",
|
||||
"name": "snipe_trust_score_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Snipe, a CircuitForge eBay trust-scoring tool. "
|
||||
"Write a step-by-step engineering plan for: seller trust score calculation.\n\n"
|
||||
"The score should combine: feedback ratio, account age, item-specifics completeness, "
|
||||
"listing photo quality, and shipping time accuracy. "
|
||||
"Include file structure, test plan, and migration steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_008",
|
||||
"name": "avocet_training_pipeline",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Avocet feature into a detailed implementation plan — "
|
||||
"end-to-end fine-tuning pipeline from labeled JSONL to deployed GGUF model.\n\n"
|
||||
"Avocet is the CircuitForge email classifier training tool. "
|
||||
"The pipeline should: validate the dataset, run LoRA SFT via unsloth, "
|
||||
"quantize to Q5_K_M GGUF, run the benchmark harness, and register the model "
|
||||
"in the Avocet model queue if it beats the baseline."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_009",
|
||||
"name": "privacy_data_flow",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the data privacy architecture for a CircuitForge cloud product. "
|
||||
"Describe: what PII is collected, how it's stored, retention policy, "
|
||||
"obfuscation strategy for cloud-side logs, and how consent is obtained "
|
||||
"in plain language. The product handles job applications (resumes, cover letters)."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "safety_pillar", "accessibility"],
|
||||
},
|
||||
{
|
||||
"id": "ho_010",
|
||||
"name": "git_workflow_doc",
|
||||
"domain": "process_doc",
|
||||
"prompt": (
|
||||
"Write a developer process document for CircuitForge: conventional commit and "
|
||||
"branch workflow for a BSL 1.1 open-core product.\n\n"
|
||||
"Cover: commit message format (type: description), branch naming, "
|
||||
"when to use feature branches vs direct main commits, "
|
||||
"how the MIT/BSL split affects which commits go in which branch, "
|
||||
"and how CI gates on gitleaks for secret scanning."
|
||||
),
|
||||
"expected_signals": ["license_split", "cf_conventions", "task_structure"],
|
||||
},
|
||||
]
|
||||
|
||||
# ── Rubric scoring ─────────────────────────────────────────────────────────────
|
||||
|
||||
_TASK_STRUCTURE_RE = re.compile(r"- \[ \]", re.MULTILINE)
|
||||
_COMMIT_RE = re.compile(r"git commit|git add", re.IGNORECASE)
|
||||
_TIER_RE = re.compile(r"\b(Free|Paid|Premium|Ultra)\s+tier|\btier\s+(Free|Paid|Premium|Ultra)", re.IGNORECASE)
|
||||
_PRIVACY_RE = re.compile(r"\b(privacy|local.?inference|no.?logging|no.?pii|user.?data|data.?reten|obfuscat)", re.IGNORECASE)
|
||||
_SAFETY_RE = re.compile(r"\b(human.?approv|reversib|safety|safe.?default|fail.?safe|harm)", re.IGNORECASE)
|
||||
_A11Y_RE = re.compile(r"\b(neurodiverg|ND\b|accessib|adaptive|ADHD|autism|executive.?function|demand.?avoid)", re.IGNORECASE)
|
||||
_LICENSE_RE = re.compile(r"\b(MIT|BSL|open.?core|proprietary|commercial.?licens)", re.IGNORECASE)
|
||||
_FILE_PATH_RE = re.compile(r"(app/|tests?/|src/|scripts?/)\w[\w/.-]{3,}", re.IGNORECASE)
|
||||
_CF_CONV_RE = re.compile(r"(conda run -n cf|/Library/Development/CircuitForge|circuitforge-core|manage\.sh)", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RubricScore:
|
||||
task_structure: float = 0.0
|
||||
tier_awareness: float = 0.0
|
||||
privacy_pillar: float = 0.0
|
||||
safety_pillar: float = 0.0
|
||||
accessibility: float = 0.0
|
||||
license_split: float = 0.0
|
||||
file_paths: float = 0.0
|
||||
cf_conventions: float = 0.0
|
||||
length_ok: float = 0.0
|
||||
|
||||
def total(self) -> float:
|
||||
vals = [self.task_structure, self.tier_awareness, self.privacy_pillar,
|
||||
self.safety_pillar, self.accessibility, self.license_split,
|
||||
self.file_paths, self.cf_conventions, self.length_ok]
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
def as_dict(self) -> dict[str, float]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def score_response(response: str, prompt_meta: dict[str, Any]) -> RubricScore:
|
||||
words = len(response.split())
|
||||
s = RubricScore()
|
||||
|
||||
# Task structure: needs checkboxes AND at least one commit step
|
||||
checkbox_hits = len(_TASK_STRUCTURE_RE.findall(response))
|
||||
has_commit = bool(_COMMIT_RE.search(response))
|
||||
s.task_structure = min(1.0, checkbox_hits / 5) * 0.7 + (0.3 if has_commit else 0.0)
|
||||
|
||||
# Tier awareness
|
||||
s.tier_awareness = min(1.0, len(_TIER_RE.findall(response)) / 2)
|
||||
|
||||
# Privacy pillar
|
||||
s.privacy_pillar = min(1.0, len(_PRIVACY_RE.findall(response)) / 3)
|
||||
|
||||
# Safety pillar
|
||||
s.safety_pillar = min(1.0, len(_SAFETY_RE.findall(response)) / 2)
|
||||
|
||||
# Accessibility
|
||||
s.accessibility = min(1.0, len(_A11Y_RE.findall(response)) / 2)
|
||||
|
||||
# License split awareness
|
||||
s.license_split = min(1.0, len(_LICENSE_RE.findall(response)) / 2)
|
||||
|
||||
# File paths: at least 3 plausible path references
|
||||
s.file_paths = min(1.0, len(_FILE_PATH_RE.findall(response)) / 3)
|
||||
|
||||
# CF conventions
|
||||
s.cf_conventions = min(1.0, len(_CF_CONV_RE.findall(response)) / 2)
|
||||
|
||||
# Length: 200–2500 words is healthy; outside = partial credit
|
||||
if 200 <= words <= 2500:
|
||||
s.length_ok = 1.0
|
||||
elif words < 200:
|
||||
s.length_ok = words / 200
|
||||
else:
|
||||
s.length_ok = max(0.0, 1.0 - (words - 2500) / 2500)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
# ── Model client ───────────────────────────────────────────────────────────────
|
||||
|
||||
# Registry of named model targets (shorthand → {api_base, model_name})
|
||||
MODEL_REGISTRY: dict[str, dict[str, str]] = {
|
||||
"deepseek-r1-1.5b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-1.5b",
|
||||
"description": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-7b-4bit",
|
||||
"description": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-0528-qwen3-8b-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-0528-qwen3-8b-gguf",
|
||||
"description": "DeepSeek R1 0528 Qwen3 8B GGUF -- current reasoning model (4 nodes)",
|
||||
},
|
||||
"deepseek-coder-6.7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-coder-6.7b-4bit",
|
||||
"description": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"granite-4.1-8b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "granite-4.1-8b",
|
||||
"description": "IBM Granite 4.1 8B, 4-bit -- safety-trained (cf-orch catalog key)",
|
||||
},
|
||||
"capybarahermes-2.5-mistral-7b-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "capybarahermes-2.5-mistral-7b-gguf",
|
||||
"description": "CapybaraHermes 2.5 Mistral 7B GGUF -- conversational/creative (4 nodes)",
|
||||
},
|
||||
"darwin-9b-opus-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "darwin-9b-opus-gguf",
|
||||
"description": "Darwin 9B Opus GGUF -- high-quality long-form writing (3 nodes)",
|
||||
},
|
||||
"qwen2.5-3b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-3b",
|
||||
"description": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
|
||||
},
|
||||
"qwen2.5-7b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-7b",
|
||||
"description": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── cf-orch allocation ─────────────────────────────────────────────────────────
|
||||
|
||||
def _cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id via the cf-orch coordinator.
|
||||
|
||||
Returns (service_url, allocation_id) on success, None on failure.
|
||||
service_url is the direct node URL exposing /v1/chat/completions.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "plans_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
# Use \n so the SSE generator sees the line immediately
|
||||
print(f" [cold start] loading {model_id!r} — polling every 3s…", flush=True)
|
||||
t0 = time.monotonic()
|
||||
deadline = t0 + startup_timeout_s
|
||||
probe_misses = 0
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
elapsed = time.monotonic() - t0
|
||||
try:
|
||||
status = httpx.get(f"{cforch_url}/api/services/cf-text/status", timeout=5.0)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
print(f" [cold start] ready in {elapsed:.0f}s", flush=True)
|
||||
return service_url, allocation_id
|
||||
elif state == "stopped":
|
||||
print(f" [cold start] failed — service stopped after {elapsed:.0f}s", flush=True)
|
||||
return None
|
||||
else:
|
||||
# still starting — emit keepalive so SSE stream stays alive
|
||||
print(f" [cold start] state={state!r} elapsed={elapsed:.0f}s", flush=True)
|
||||
else:
|
||||
probe_misses += 1
|
||||
print(f" [cold start] waiting… elapsed={elapsed:.0f}s", flush=True)
|
||||
if probe_misses >= 6:
|
||||
try:
|
||||
h = httpx.get(f"{service_url}/health", timeout=3.0)
|
||||
if h.is_success:
|
||||
print(f" [cold start] ready via health check in {elapsed:.0f}s", flush=True)
|
||||
return service_url, allocation_id
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(f" [cold start] status poll returned {status.status_code}, elapsed={elapsed:.0f}s", flush=True)
|
||||
except Exception as poll_exc:
|
||||
print(f" [cold start] poll error: {poll_exc} elapsed={elapsed:.0f}s", flush=True)
|
||||
time.sleep(3.0)
|
||||
|
||||
print(f" [cold start] timed out after {time.monotonic()-t0:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r}: {exc}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def _call_model_direct(service_url: str, model: str, prompt: str, timeout: int = 600) -> tuple[str, float]:
|
||||
"""Call an OpenAI-compatible /v1/chat/completions on a direct service URL."""
|
||||
t0 = time.monotonic()
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
latency = time.monotonic() - t0
|
||||
text = resp.json()["choices"][0]["message"]["content"]
|
||||
return text, latency
|
||||
|
||||
|
||||
def _call_model(api_base: str, model: str, prompt: str, timeout: int = 180) -> tuple[str, float]:
|
||||
"""Call an OpenAI-compatible /chat/completions endpoint. Returns (text, latency_s)."""
|
||||
t0 = time.monotonic()
|
||||
resp = httpx.post(
|
||||
f"{api_base}/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
latency = time.monotonic() - t0
|
||||
text = resp.json()["choices"][0]["message"]["content"]
|
||||
return text, latency
|
||||
|
||||
|
||||
# ── Benchmark runner ───────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class PromptResult:
|
||||
prompt_id: str
|
||||
prompt_name: str
|
||||
model_key: str
|
||||
response: str
|
||||
latency_s: float
|
||||
word_count: int
|
||||
scores: dict[str, float]
|
||||
total_score: float
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_key: str,
|
||||
model_name: str,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
verbose: bool = False,
|
||||
# cf-orch path
|
||||
use_cforch: bool = False,
|
||||
cforch_url: str = CF_COORD_URL,
|
||||
# direct path (used when not cf-orch)
|
||||
api_base: str = CF_TEXT_BASE,
|
||||
) -> list[PromptResult]:
|
||||
"""Run all prompts through one model. Uses cf-orch allocation when use_cforch=True."""
|
||||
if prompts is None:
|
||||
prompts = HELD_OUT_PROMPTS
|
||||
|
||||
# Allocate once per model when using cf-orch
|
||||
service_url: str | None = None
|
||||
if use_cforch:
|
||||
print(f" Allocating {model_name!r} via cf-orch…", flush=True)
|
||||
alloc = _cforch_allocate(model_name, cforch_url)
|
||||
if alloc is None:
|
||||
# Return all prompts as errors
|
||||
return [
|
||||
PromptResult(
|
||||
prompt_id=p["id"], prompt_name=p["name"], model_key=model_key,
|
||||
response="", latency_s=0.0, word_count=0, scores={}, total_score=0.0,
|
||||
error=f"cf-orch allocation failed for {model_name!r}",
|
||||
)
|
||||
for p in prompts
|
||||
]
|
||||
service_url, _alloc_id = alloc
|
||||
|
||||
results: list[PromptResult] = []
|
||||
for p in prompts:
|
||||
if verbose:
|
||||
print(f" [{p['id']}] {p['name']} … ", end="", flush=True)
|
||||
try:
|
||||
if service_url:
|
||||
response, latency = _call_model_direct(service_url, model_name, p["prompt"])
|
||||
else:
|
||||
response, latency = _call_model(api_base, model_name, p["prompt"])
|
||||
rubric = score_response(response, p)
|
||||
result = PromptResult(
|
||||
prompt_id=p["id"],
|
||||
prompt_name=p["name"],
|
||||
model_key=model_key,
|
||||
response=response,
|
||||
latency_s=round(latency, 2),
|
||||
word_count=len(response.split()),
|
||||
scores=rubric.as_dict(),
|
||||
total_score=round(rubric.total(), 3),
|
||||
)
|
||||
if verbose:
|
||||
print(f"score={result.total_score:.3f} ({result.word_count}w, {latency:.1f}s)")
|
||||
except Exception as exc:
|
||||
result = PromptResult(
|
||||
prompt_id=p["id"],
|
||||
prompt_name=p["name"],
|
||||
model_key=model_key,
|
||||
response="",
|
||||
latency_s=0.0,
|
||||
word_count=0,
|
||||
scores={},
|
||||
total_score=0.0,
|
||||
error=str(exc),
|
||||
)
|
||||
if verbose:
|
||||
print(f"ERROR: {exc}")
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
# ── Reporting ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_single_report(results: list[PromptResult], model_key: str) -> None:
|
||||
ok = [r for r in results if not r.error]
|
||||
err = [r for r in results if r.error]
|
||||
if not ok:
|
||||
print(f"\n[{model_key}] All {len(err)} prompts failed.\n")
|
||||
return
|
||||
|
||||
avg_total = sum(r.total_score for r in ok) / len(ok)
|
||||
avg_latency = sum(r.latency_s for r in ok) / len(ok)
|
||||
|
||||
# Aggregate per-rubric averages
|
||||
rubric_keys = list(ok[0].scores.keys())
|
||||
rubric_avgs = {k: sum(r.scores.get(k, 0) for r in ok) / len(ok) for k in rubric_keys}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Model : {model_key}")
|
||||
print(f" Prompts: {len(ok)}/{len(results)} passed ({len(err)} errors)")
|
||||
print(f" Overall score : {avg_total:.3f} (avg latency {avg_latency:.1f}s)")
|
||||
print(f"\n Rubric breakdown:")
|
||||
for k, v in sorted(rubric_avgs.items(), key=lambda x: -x[1]):
|
||||
bar = "█" * int(v * 20)
|
||||
print(f" {k:<22} {v:.3f} {bar}")
|
||||
print(f"\n Per-prompt scores:")
|
||||
for r in sorted(ok, key=lambda x: -x.total_score):
|
||||
flag = "⚠" if r.total_score < 0.3 else " "
|
||||
print(f" {flag} {r.prompt_id} {r.prompt_name:<35} {r.total_score:.3f} ({r.word_count}w)")
|
||||
if err:
|
||||
print(f"\n Errors:")
|
||||
for r in err:
|
||||
print(f" {r.prompt_id} {r.prompt_name}: {r.error}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def _print_comparison_table(all_results: dict[str, list[PromptResult]]) -> None:
|
||||
model_keys = list(all_results.keys())
|
||||
prompt_ids = [p["id"] for p in HELD_OUT_PROMPTS]
|
||||
|
||||
# Scores by (model, prompt_id)
|
||||
score_map: dict[tuple[str, str], float] = {}
|
||||
for mk, results in all_results.items():
|
||||
for r in results:
|
||||
score_map[(mk, r.prompt_id)] = r.total_score if not r.error else 0.0
|
||||
|
||||
col_w = 10
|
||||
header = f"{'Prompt':<35}" + "".join(f"{mk[:col_w-1]:<{col_w}}" for mk in model_keys)
|
||||
print(f"\n{'='*len(header)}")
|
||||
print(" COMPARISON TABLE")
|
||||
print(f"{'='*len(header)}")
|
||||
print(f" {header}")
|
||||
print(f" {'-'*len(header)}")
|
||||
|
||||
for pid in prompt_ids:
|
||||
pname = next(p["name"] for p in HELD_OUT_PROMPTS if p["id"] == pid)
|
||||
row = f" {pname:<35}"
|
||||
best = max(score_map.get((mk, pid), 0.0) for mk in model_keys)
|
||||
for mk in model_keys:
|
||||
v = score_map.get((mk, pid), 0.0)
|
||||
marker = "*" if v == best and len(model_keys) > 1 else " "
|
||||
row += f"{v:.3f}{marker} "
|
||||
print(row)
|
||||
|
||||
print(f" {'-'*len(header)}")
|
||||
avgs_row = f" {'AVERAGE':<35}"
|
||||
best_avg = -1.0
|
||||
avgs: dict[str, float] = {}
|
||||
for mk in model_keys:
|
||||
vals = [score_map.get((mk, pid), 0.0) for pid in prompt_ids]
|
||||
avgs[mk] = sum(vals) / len(vals)
|
||||
best_avg = max(best_avg, avgs[mk])
|
||||
for mk in model_keys:
|
||||
marker = "*" if avgs[mk] == best_avg and len(model_keys) > 1 else " "
|
||||
avgs_row += f"{avgs[mk]:.3f}{marker} "
|
||||
print(avgs_row)
|
||||
print(f"{'='*len(header)}\n")
|
||||
if len(model_keys) > 1:
|
||||
winner = max(avgs, key=lambda k: avgs[k])
|
||||
print(f" Winner: {winner} (avg {avgs[winner]:.3f})\n")
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser.add_argument("--list-models", action="store_true",
|
||||
help="Print registered model shortcuts and exit")
|
||||
parser.add_argument("--model", metavar="KEY",
|
||||
help="Benchmark a single model (registry key or raw model name)")
|
||||
parser.add_argument("--compare", nargs="+", metavar="KEY",
|
||||
help="Compare two or more models side-by-side")
|
||||
parser.add_argument("--cforch", action="store_true",
|
||||
help="Route inference through cf-orch coordinator (allocate per model)")
|
||||
parser.add_argument("--cforch-url", default=CF_COORD_URL, metavar="URL",
|
||||
help=f"cf-orch coordinator URL (default: {CF_COORD_URL})")
|
||||
parser.add_argument("--api-base", default=None,
|
||||
help="Direct API base URL when not using cf-orch")
|
||||
parser.add_argument("--model-name", default=None,
|
||||
help="Override model name sent to API (single-model runs only)")
|
||||
parser.add_argument("--prompts", nargs="+", metavar="ID",
|
||||
help="Run only specific prompt IDs (e.g. ho_001 ho_003)")
|
||||
parser.add_argument("--output", type=Path, default=None,
|
||||
help="Write detailed JSON results to this path")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Run N models concurrently (default 1). Set to number of available nodes.")
|
||||
parser.add_argument("--verbose", "-v", action="store_true",
|
||||
help="Print per-prompt progress")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_models:
|
||||
print("\nRegistered model shortcuts:")
|
||||
for key, info in MODEL_REGISTRY.items():
|
||||
print(f" {key:<20} {info['description']}")
|
||||
print(f"\nDefault endpoints:")
|
||||
print(f" direct {CF_TEXT_BASE}")
|
||||
print(f" cf-orch {CF_COORD_URL}")
|
||||
return
|
||||
|
||||
prompts = HELD_OUT_PROMPTS
|
||||
if args.prompts:
|
||||
ids = set(args.prompts)
|
||||
prompts = [p for p in HELD_OUT_PROMPTS if p["id"] in ids]
|
||||
if not prompts:
|
||||
print(f"No prompts matched IDs: {args.prompts}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
model_keys: list[str] = []
|
||||
if args.compare:
|
||||
model_keys = args.compare
|
||||
elif args.model:
|
||||
model_keys = [args.model]
|
||||
else:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
all_results: dict[str, list[PromptResult]] = {}
|
||||
print_lock = threading.Lock()
|
||||
|
||||
def _run_one(mk: str) -> tuple[str, list[PromptResult]]:
|
||||
if mk in MODEL_REGISTRY:
|
||||
reg = MODEL_REGISTRY[mk]
|
||||
model_name = args.model_name or reg["model"]
|
||||
direct_base = args.api_base or reg["api_base"]
|
||||
else:
|
||||
model_name = args.model_name or mk
|
||||
direct_base = args.api_base or CF_TEXT_BASE
|
||||
|
||||
if args.cforch:
|
||||
with print_lock:
|
||||
print(f"\nRunning [{mk}] via cf-orch ({args.cforch_url}) model={model_name}")
|
||||
results = run_benchmark(
|
||||
mk, model_name, prompts=prompts, verbose=args.verbose,
|
||||
use_cforch=True, cforch_url=args.cforch_url,
|
||||
)
|
||||
else:
|
||||
with print_lock:
|
||||
print(f"\nRunning [{mk}] → {direct_base} model={model_name}")
|
||||
results = run_benchmark(
|
||||
mk, model_name, prompts=prompts, verbose=args.verbose,
|
||||
api_base=direct_base,
|
||||
)
|
||||
|
||||
with print_lock:
|
||||
_print_single_report(results, mk)
|
||||
return mk, results
|
||||
|
||||
workers = max(1, args.workers)
|
||||
if workers == 1 or len(model_keys) == 1:
|
||||
for mk in model_keys:
|
||||
mk_out, results = _run_one(mk)
|
||||
all_results[mk_out] = results
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {pool.submit(_run_one, mk): mk for mk in model_keys}
|
||||
for fut in as_completed(futures):
|
||||
mk_out, results = fut.result()
|
||||
all_results[mk_out] = results
|
||||
|
||||
if len(model_keys) > 1:
|
||||
_print_comparison_table(all_results)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
mk: [asdict(r) for r in results]
|
||||
for mk, results in all_results.items()
|
||||
}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2, ensure_ascii=False)
|
||||
print(f"Wrote detailed results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,952 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Writing style benchmark harness -- score local text-gen models for writing style match.
|
||||
|
||||
Runs each model against a set of test prompts, extracts style signals from the
|
||||
outputs, compares them to a style corpus, and produces a ranked markdown table.
|
||||
|
||||
Usage:
|
||||
# List available ollama models
|
||||
conda run -n cf python scripts/benchmark_style.py --list-models
|
||||
|
||||
# Run against all models with default test prompts
|
||||
conda run -n cf python scripts/benchmark_style.py --run
|
||||
|
||||
# Run specific models only
|
||||
conda run -n cf python scripts/benchmark_style.py --run --models mistral:7b,llama3.1:8b
|
||||
|
||||
# Use a custom corpus directory
|
||||
conda run -n cf python scripts/benchmark_style.py --run --samples data/style_corpus/
|
||||
|
||||
# Print last results table
|
||||
conda run -n cf python scripts/benchmark_style.py --show-last
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CORPUS_DIR = _ROOT / "data" / "style_corpus"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
_OLLAMA_URL = "http://localhost:11434"
|
||||
_CFORCH_URL = "http://localhost:7700"
|
||||
|
||||
# Subdirectories under --scan-disk root that may contain GGUFs
|
||||
_SCAN_SUBDIRS = ["textgen/models", "llama.cpp/models", "cf-text/models", "vllm/models"]
|
||||
|
||||
# ── Filler phrases that should be absent from good style-match output ──────────
|
||||
FILLER_PHRASES: list[str] = [
|
||||
"delve", "certainly", "absolutely", "i apologize", "i'd be happy to",
|
||||
"of course", "great question", "i understand", "let me know if",
|
||||
"feel free to", "it's important to note", "it's worth noting",
|
||||
"in conclusion", "to summarize", "in summary",
|
||||
]
|
||||
|
||||
# ── Test prompts: (thread_title, thread_body, context_tag) ───────────────────
|
||||
# These are representative threads that Magpie might reply to.
|
||||
# Extend this list with real examples as the corpus grows.
|
||||
TEST_PROMPTS: list[dict[str, str]] = [
|
||||
{
|
||||
"tag": "selfhosted_ai_fatigue",
|
||||
"thread_title": "Anyone else getting tired of re-explaining their setup every time an AI model forgets?",
|
||||
"thread_body": (
|
||||
"Every session I start over. My whole hardware setup, what tools I use, "
|
||||
"what I've already tried. It's exhausting. There has to be a better way."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "privacy_local_llm",
|
||||
"thread_title": "What's the point of running local LLMs if the apps still phone home?",
|
||||
"thread_body": (
|
||||
"I went through all the trouble of setting up ollama and now I find out "
|
||||
"the frontend I'm using is sending telemetry. Kind of defeats the purpose."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "solarpunk_tech",
|
||||
"thread_title": "What does solarpunk computing actually look like in practice?",
|
||||
"thread_body": (
|
||||
"I keep seeing the aesthetic but not a lot of concrete examples of "
|
||||
"people living it out with their tech choices. What does it mean day to day?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "nd_tools",
|
||||
"thread_title": "Tools that actually help with executive function vs ones that just add friction",
|
||||
"thread_body": (
|
||||
"I've tried a dozen productivity apps and most of them require more "
|
||||
"executive function to maintain than they save. What actually sticks for you?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "data_ownership",
|
||||
"thread_title": "Who actually owns your data when you use a 'free' AI tool?",
|
||||
"thread_body": (
|
||||
"Read the ToS on three different AI assistants today. In all three cases "
|
||||
"your inputs can be used for training, shared with partners, and retained "
|
||||
"indefinitely. At what point does 'free' just mean you're the product?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "digital_culture",
|
||||
"thread_title": "The internet used to feel like it belonged to everyone. What happened?",
|
||||
"thread_body": (
|
||||
"I grew up on forums, IRC, personal homepages. Now everything is a platform "
|
||||
"owned by someone trying to extract value from the community that built it. "
|
||||
"Is the fediverse / self-hosting movement actually reversing this or just "
|
||||
"a niche hobby?"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
GENERATION_PARAMS: dict[str, Any] = {
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"num_predict": 300,
|
||||
}
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No hype, no buzzwords, no em dashes, no semicolons.\n"
|
||||
"- Community-first perspective. Solarpunk values.\n"
|
||||
"- Direct and opinionated. No throat-clearing or filler.\n"
|
||||
"- When relevant, mention personal experience with real tools.\n\n"
|
||||
"Write ONLY the reply. No preamble, no 'Here is a reply:', no meta-commentary."
|
||||
)
|
||||
|
||||
|
||||
# ── Style signal extraction ───────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class StyleSignals:
|
||||
"""Quantitative style signals extracted from a text sample."""
|
||||
sentence_count: int = 0
|
||||
word_count: int = 0
|
||||
avg_sentence_length: float = 0.0
|
||||
em_dash_count: int = 0
|
||||
semicolon_count: int = 0
|
||||
filler_hits: list[str] = field(default_factory=list)
|
||||
question_ratio: float = 0.0 # fraction of sentences ending in '?'
|
||||
first_person_ratio: float = 0.0 # fraction of sentences starting with 'I'
|
||||
avg_word_length: float = 0.0
|
||||
|
||||
|
||||
def extract_signals(text: str) -> StyleSignals:
|
||||
"""Extract style signals from a text sample."""
|
||||
text = text.strip()
|
||||
if text.startswith("[ERROR:"):
|
||||
return StyleSignals() # zero-score sentinel — caller checks for empty output
|
||||
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
|
||||
words = text.split()
|
||||
|
||||
if not sentences:
|
||||
return StyleSignals()
|
||||
|
||||
avg_sentence_length = len(words) / len(sentences) if sentences else 0.0
|
||||
avg_word_length = (sum(len(w.strip('.,!?;:"\'')) for w in words) / len(words)) if words else 0.0
|
||||
|
||||
em_dash_count = text.count('\u2014') + text.count(' -- ') + text.count('--')
|
||||
semicolon_count = text.count(';')
|
||||
|
||||
filler_hits = [p for p in FILLER_PHRASES if p.lower() in text.lower()]
|
||||
|
||||
question_ratio = sum(1 for s in sentences if s.endswith('?')) / len(sentences)
|
||||
first_person_ratio = sum(1 for s in sentences if re.match(r"^I\b", s)) / len(sentences)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=len(sentences),
|
||||
word_count=len(words),
|
||||
avg_sentence_length=avg_sentence_length,
|
||||
em_dash_count=em_dash_count,
|
||||
semicolon_count=semicolon_count,
|
||||
filler_hits=filler_hits,
|
||||
question_ratio=question_ratio,
|
||||
first_person_ratio=first_person_ratio,
|
||||
avg_word_length=avg_word_length,
|
||||
)
|
||||
|
||||
|
||||
def build_corpus_profile(corpus_dir: Path) -> StyleSignals | None:
|
||||
"""Aggregate style signals across all corpus samples into a target profile."""
|
||||
samples = list(corpus_dir.glob("*.txt"))
|
||||
if not samples:
|
||||
return None
|
||||
|
||||
all_signals = [extract_signals(p.read_text(encoding="utf-8")) for p in samples]
|
||||
n = len(all_signals)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=int(sum(s.sentence_count for s in all_signals) / n),
|
||||
word_count=int(sum(s.word_count for s in all_signals) / n),
|
||||
avg_sentence_length=sum(s.avg_sentence_length for s in all_signals) / n,
|
||||
em_dash_count=int(sum(s.em_dash_count for s in all_signals) / n),
|
||||
semicolon_count=int(sum(s.semicolon_count for s in all_signals) / n),
|
||||
question_ratio=sum(s.question_ratio for s in all_signals) / n,
|
||||
first_person_ratio=sum(s.first_person_ratio for s in all_signals) / n,
|
||||
avg_word_length=sum(s.avg_word_length for s in all_signals) / n,
|
||||
)
|
||||
|
||||
|
||||
def score_against_profile(output_signals: StyleSignals, profile: StyleSignals | None) -> float:
|
||||
"""Score a model output against the corpus profile. Returns 0-100.
|
||||
|
||||
Penalties:
|
||||
- Em dashes / semicolons: -5 each occurrence (hard CF style violation)
|
||||
- Filler phrases: -8 each hit (strong signal of non-style output)
|
||||
- Sentence length delta: proportional penalty (target: close to corpus avg)
|
||||
- Word length delta: smaller penalty
|
||||
|
||||
When no corpus profile is available, falls back to absolute signal scores only.
|
||||
"""
|
||||
score = 100.0
|
||||
|
||||
# Hard violations -- always penalised regardless of corpus
|
||||
score -= output_signals.em_dash_count * 5
|
||||
score -= output_signals.semicolon_count * 3
|
||||
score -= len(output_signals.filler_hits) * 8
|
||||
|
||||
if profile is not None:
|
||||
# Sentence length delta: penalise proportionally
|
||||
length_delta = abs(output_signals.avg_sentence_length - profile.avg_sentence_length)
|
||||
score -= min(length_delta * 2, 20)
|
||||
|
||||
# Question ratio delta
|
||||
question_delta = abs(output_signals.question_ratio - profile.question_ratio)
|
||||
score -= min(question_delta * 10, 10)
|
||||
|
||||
return max(0.0, score)
|
||||
|
||||
|
||||
# ── Ollama generation ─────────────────────────────────────────────────────────
|
||||
|
||||
_CFORCH_NODE_ID = "heimdall"
|
||||
|
||||
|
||||
def cforch_list_catalog(
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
node_id: str = _CFORCH_NODE_ID,
|
||||
) -> dict[str, int]:
|
||||
"""Return the cf-text catalog from cf-orch as {model_id: vram_mb}.
|
||||
|
||||
Uses ?node_id= to request the catalog from a specific node's profile,
|
||||
avoiding cross-node catalog shadowing when multiple nodes define catalogs
|
||||
for the same service.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_url}/api/services/cf-text/catalog",
|
||||
params={"node_id": node_id} if node_id else {},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
return {
|
||||
model_id: (entry.get("vram_mb", 0) if isinstance(entry, dict) else 0)
|
||||
for model_id, entry in raw.items()
|
||||
}
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach cf-orch catalog at {cforch_url}: {exc}", file=sys.stderr)
|
||||
return {}
|
||||
|
||||
|
||||
def _cforch_allocate_service(
|
||||
service: str,
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float,
|
||||
health_path: str,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Generic cf-orch allocate + state-signal wait. Returns (service_url, allocation_id) or None.
|
||||
|
||||
After allocating, waits for the coordinator's service state to reach 'running'.
|
||||
Fails immediately if the state reaches 'stopped' (crashed load) — no waiting out
|
||||
the full timeout for a model that already failed.
|
||||
Falls back to health-polling if the coordinator doesn't expose a matching instance
|
||||
(e.g. older coordinator version or service not yet registered in probe loop).
|
||||
"""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/{service}/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "style_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
print(f" [cold start] waiting for {service} to load {model_id!r}...", end=" ", flush=True)
|
||||
t0 = time.monotonic()
|
||||
deadline = t0 + startup_timeout_s
|
||||
probe_misses = 0 # consecutive polls with no matching instance in status
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
status = httpx.get(
|
||||
f"{cforch_url}/api/services/{service}/status", timeout=5.0
|
||||
)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
# Find our specific instance by node+gpu
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"ready ({elapsed:.0f}s)", flush=True)
|
||||
return service_url, allocation_id
|
||||
elif state == "stopped":
|
||||
print(f"failed (service stopped — model load error)", flush=True)
|
||||
return None
|
||||
# state == "starting" or unknown → keep waiting
|
||||
else:
|
||||
probe_misses += 1
|
||||
# After a grace period with no instance visible, fall back to
|
||||
# direct health-poll (coordinator may not have probed yet)
|
||||
if probe_misses >= 6:
|
||||
try:
|
||||
health = httpx.get(f"{service_url}{health_path}", timeout=3.0)
|
||||
if health.is_success:
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"ready via health ({elapsed:.0f}s)", flush=True)
|
||||
return service_url, allocation_id
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(3.0)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"timed out after {elapsed:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r} ({service}): {exc}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 180.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id. Returns (service_url, allocation_id) or None."""
|
||||
return _cforch_allocate_service("cf-text", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_allocate_vllm(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a vllm instance for model_id. Returns (service_url, allocation_id) or None.
|
||||
|
||||
vllm exposes an OpenAI-compatible API — generate_cftext() works unchanged
|
||||
against the returned service_url. Startup timeout is longer (300s) because
|
||||
vllm loads large model weights from disk before becoming ready.
|
||||
"""
|
||||
return _cforch_allocate_service("vllm", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_release(allocation_id: str, cforch_url: str = _CFORCH_URL) -> None:
|
||||
"""Release a cf-orch allocation."""
|
||||
if not allocation_id:
|
||||
return
|
||||
try:
|
||||
httpx.delete(f"{cforch_url}/api/services/cf-text/allocations/{allocation_id}", timeout=10.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def generate_cftext(
|
||||
service_url: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
) -> tuple[str, float]:
|
||||
"""Call cf-text via OpenAI-compatible /v1/chat/completions. Returns (text, elapsed_ms)."""
|
||||
messages: list[dict[str, str]] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": GENERATION_PARAMS.get("num_predict", 300),
|
||||
"temperature": GENERATION_PARAMS.get("temperature", 0.7),
|
||||
"top_p": GENERATION_PARAMS.get("top_p", 0.9),
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json=payload,
|
||||
timeout=180.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
content = resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def generate(model_id: str, prompt: str, system: str = "") -> tuple[str, float]:
|
||||
"""Call ollama /api/generate. Returns (text, elapsed_ms)."""
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": GENERATION_PARAMS,
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_OLLAMA_URL}/api/generate",
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return resp.json().get("response", "").strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def find_disk_ggufs(llm_root: Path) -> list[Path]:
|
||||
"""Recursively find .gguf files under known subdirs of llm_root.
|
||||
|
||||
Skips vocab-only GGUFs (ggml-vocab-*) which aren't standalone models.
|
||||
"""
|
||||
found: list[Path] = []
|
||||
search_dirs = [llm_root / sub for sub in _SCAN_SUBDIRS] + [llm_root]
|
||||
seen: set[Path] = set()
|
||||
for base in search_dirs:
|
||||
if not base.exists():
|
||||
continue
|
||||
for gguf in base.rglob("*.gguf"):
|
||||
if gguf in seen:
|
||||
continue
|
||||
seen.add(gguf)
|
||||
if gguf.name.startswith("ggml-vocab-"):
|
||||
continue
|
||||
found.append(gguf)
|
||||
return sorted(found)
|
||||
|
||||
|
||||
def gguf_to_ollama_tag(gguf_path: Path) -> str:
|
||||
"""Derive a stable ollama tag from a GGUF path.
|
||||
|
||||
Uses parent dir name + stem to avoid collisions, e.g.:
|
||||
claude-3.7-sonnet-reasoning-gemma3-12B/foo.Q8_0.gguf
|
||||
→ bench-claude-3.7-sonnet-reasoning-gemma3-12b-foo-q8-0
|
||||
"""
|
||||
parent = gguf_path.parent.name.lower()
|
||||
stem = gguf_path.stem.lower()
|
||||
# If stem is contained in parent (common pattern), just use parent
|
||||
slug = parent if stem.replace("-", "").replace("_", "") in parent.replace("-", "").replace("_", "") else f"{parent}-{stem}"
|
||||
slug = re.sub(r"[^a-z0-9]+", "-", slug).strip("-")
|
||||
return f"bench-{slug}:latest"
|
||||
|
||||
|
||||
def register_gguf(gguf_path: Path, tag: str) -> bool:
|
||||
"""Create a temporary ollama model entry from a GGUF file. Returns True on success."""
|
||||
import subprocess
|
||||
import tempfile
|
||||
modelfile = f"FROM {gguf_path.resolve()}\n"
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".Modelfile", delete=False) as f:
|
||||
f.write(modelfile)
|
||||
modelfile_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ollama", "create", tag, "-f", modelfile_path],
|
||||
capture_output=True, text=True, timeout=60,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not register {gguf_path.name}: {exc}", file=sys.stderr)
|
||||
return False
|
||||
finally:
|
||||
Path(modelfile_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def deregister_gguf(tag: str) -> None:
|
||||
"""Remove a temporary ollama model entry."""
|
||||
import subprocess
|
||||
try:
|
||||
subprocess.run(["ollama", "rm", tag], capture_output=True, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def backfill_disk_models(
|
||||
llm_root: Path,
|
||||
existing_tags: set[str],
|
||||
max_vram_mb: int = 0,
|
||||
) -> list[str]:
|
||||
"""Register GGUFs from disk that aren't already in ollama. Returns new tags.
|
||||
|
||||
max_vram_mb: skip files whose size exceeds this threshold (0 = no limit).
|
||||
GGUF file size is a reliable VRAM proxy -- quantized weights load ~1:1.
|
||||
"""
|
||||
ggufs = find_disk_ggufs(llm_root)
|
||||
if not ggufs:
|
||||
print(f"No .gguf files found under {llm_root}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
new_tags: list[str] = []
|
||||
skipped_oom = 0
|
||||
for gguf in ggufs:
|
||||
size_mb = gguf.stat().st_size // (1024 * 1024)
|
||||
if max_vram_mb and size_mb > max_vram_mb:
|
||||
print(f" [skip-oom] {gguf.name} ({size_mb} MB > {max_vram_mb} MB limit)")
|
||||
skipped_oom += 1
|
||||
continue
|
||||
tag = gguf_to_ollama_tag(gguf)
|
||||
if tag in existing_tags:
|
||||
print(f" [skip] {gguf.name} already registered as {tag}")
|
||||
continue
|
||||
print(f" [register] {gguf.name} ({size_mb} MB) → {tag} ...", end=" ", flush=True)
|
||||
if register_gguf(gguf, tag):
|
||||
print("ok")
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
print("failed")
|
||||
|
||||
if skipped_oom:
|
||||
print(f" [info] {skipped_oom} GGUF(s) skipped (exceed {max_vram_mb} MB VRAM limit)")
|
||||
return new_tags
|
||||
|
||||
|
||||
def list_ollama_models() -> list[str]:
|
||||
"""Return model names from ollama /api/tags, filtered to text-gen candidates."""
|
||||
try:
|
||||
resp = httpx.get(f"{_OLLAMA_URL}/api/tags", timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
models = resp.json().get("models", [])
|
||||
# Exclude embedding-only models
|
||||
exclude = {"mxbai-embed-large", "nomic-embed-text", "all-minilm"}
|
||||
return [
|
||||
m["name"] for m in models
|
||||
if not any(x in m["name"].lower() for x in exclude)
|
||||
]
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach ollama: {exc}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
|
||||
# ── Run benchmark ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ModelResult:
|
||||
model_id: str
|
||||
prompt_results: list[dict[str, Any]] = field(default_factory=list)
|
||||
avg_score: float = 0.0
|
||||
avg_latency_ms: float = 0.0
|
||||
total_filler_hits: int = 0
|
||||
total_em_dashes: int = 0
|
||||
total_semicolons: int = 0
|
||||
|
||||
|
||||
def _bench_one_model(
|
||||
model_id: str,
|
||||
prompts: list[dict[str, str]],
|
||||
profile: Any,
|
||||
use_cforch: bool,
|
||||
cforch_url: str,
|
||||
use_vllm: bool = False,
|
||||
) -> "ModelResult | None":
|
||||
"""Run all prompts for a single model. Thread-safe — all output is prefixed with model_id.
|
||||
|
||||
Dispatch priority:
|
||||
use_vllm=True → allocate vllm via cf-orch, then generate_cftext() (OpenAI-compatible)
|
||||
use_cforch=True → allocate cf-text via cf-orch, then generate_cftext()
|
||||
else → direct ollama generate()
|
||||
Both vllm and cf-text expose /v1/chat/completions so generate_cftext() works for both.
|
||||
"""
|
||||
prefix = f"[{model_id}]"
|
||||
result = ModelResult(model_id=model_id)
|
||||
|
||||
service_url: str | None = None
|
||||
allocation_id: str = ""
|
||||
if use_vllm:
|
||||
alloc = cforch_allocate_vllm(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] vllm allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} vllm allocated: {service_url}", flush=True)
|
||||
elif use_cforch:
|
||||
alloc = cforch_allocate(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] cf-orch allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} allocated: {service_url}", flush=True)
|
||||
|
||||
try:
|
||||
for prompt_def in prompts:
|
||||
tag = prompt_def["tag"]
|
||||
user_prompt = (
|
||||
f"Thread: {prompt_def['thread_title']}\n\n"
|
||||
f"{prompt_def['thread_body']}\n\n"
|
||||
f"Write a reply:"
|
||||
)
|
||||
print(f"{prefix} [{tag}] generating...", flush=True)
|
||||
|
||||
if (use_cforch or use_vllm) and service_url:
|
||||
# Both cf-text and vllm expose /v1/chat/completions — same call
|
||||
output, elapsed_ms = generate_cftext(service_url, model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
else:
|
||||
output, elapsed_ms = generate(model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
|
||||
signals = extract_signals(output)
|
||||
score = score_against_profile(signals, profile)
|
||||
|
||||
print(f"{prefix} [{tag}] {score:.0f}/100 ({elapsed_ms:.0f}ms)", flush=True)
|
||||
if signals.filler_hits:
|
||||
print(f"{prefix} ⚠ filler: {signals.filler_hits}", flush=True)
|
||||
if signals.em_dash_count:
|
||||
print(f"{prefix} ⚠ em-dashes: {signals.em_dash_count}", flush=True)
|
||||
|
||||
result.prompt_results.append({
|
||||
"tag": tag,
|
||||
"user_prompt": user_prompt,
|
||||
"output": output,
|
||||
"signals": {
|
||||
"avg_sentence_length": signals.avg_sentence_length,
|
||||
"em_dash_count": signals.em_dash_count,
|
||||
"semicolon_count": signals.semicolon_count,
|
||||
"filler_hits": signals.filler_hits,
|
||||
"question_ratio": signals.question_ratio,
|
||||
"word_count": signals.word_count,
|
||||
},
|
||||
"score": score,
|
||||
"latency_ms": elapsed_ms,
|
||||
})
|
||||
finally:
|
||||
if (use_cforch or use_vllm) and allocation_id:
|
||||
cforch_release(allocation_id, cforch_url)
|
||||
|
||||
if not result.prompt_results:
|
||||
return None
|
||||
|
||||
scores = [r["score"] for r in result.prompt_results]
|
||||
latencies = [r["latency_ms"] for r in result.prompt_results]
|
||||
result.avg_score = sum(scores) / len(scores)
|
||||
result.avg_latency_ms = sum(latencies) / len(latencies)
|
||||
result.total_filler_hits = sum(len(r["signals"]["filler_hits"]) for r in result.prompt_results)
|
||||
result.total_em_dashes = sum(r["signals"]["em_dash_count"] for r in result.prompt_results)
|
||||
result.total_semicolons = sum(r["signals"]["semicolon_count"] for r in result.prompt_results)
|
||||
|
||||
print(f"{prefix} done — avg score {result.avg_score:.0f}/100", flush=True)
|
||||
return result
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_ids: list[str],
|
||||
corpus_dir: Path,
|
||||
prompts: list[dict[str, str]],
|
||||
use_cforch: bool = False,
|
||||
use_vllm: bool = False,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
workers: int = 1,
|
||||
) -> list[ModelResult]:
|
||||
profile = build_corpus_profile(corpus_dir)
|
||||
if profile:
|
||||
print(f"Corpus profile loaded from {corpus_dir} ({len(list(corpus_dir.glob('*.txt')))} samples)")
|
||||
print(f" Target avg sentence length: {profile.avg_sentence_length:.1f} words")
|
||||
else:
|
||||
print(f"[warn] No corpus samples found in {corpus_dir} -- scoring on hard violations only")
|
||||
|
||||
backend = "vllm via cf-orch" if use_vllm else ("cf-text via cf-orch" if use_cforch else "ollama")
|
||||
print(f" Backend: {backend}")
|
||||
|
||||
effective_workers = min(workers, len(model_ids)) if model_ids else 1
|
||||
print(f" Workers: {effective_workers} (of {len(model_ids)} models)", flush=True)
|
||||
|
||||
results: list[ModelResult] = []
|
||||
|
||||
if effective_workers <= 1:
|
||||
# Sequential path — simpler output, easier to follow for single-model runs
|
||||
for model_id in model_ids:
|
||||
print(f"\n{'='*60}\nModel: {model_id}", flush=True)
|
||||
r = _bench_one_model(model_id, prompts, profile, use_cforch, cforch_url, use_vllm)
|
||||
if r:
|
||||
results.append(r)
|
||||
else:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
print(f" Fanning out {len(model_ids)} models across {effective_workers} workers...", flush=True)
|
||||
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(_bench_one_model, mid, prompts, profile, use_cforch, cforch_url, use_vllm): mid
|
||||
for mid in model_ids
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
r = future.result()
|
||||
if r:
|
||||
results.append(r)
|
||||
|
||||
return sorted(results, key=lambda r: r.avg_score, reverse=True)
|
||||
|
||||
|
||||
# ── Markdown report ───────────────────────────────────────────────────────────
|
||||
|
||||
def render_report(results: list[ModelResult], corpus_dir: Path) -> str:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
lines: list[str] = [
|
||||
f"# Writing Style Benchmark Results",
|
||||
f"",
|
||||
f"**Date:** {date_str} ",
|
||||
f"**Corpus:** `{corpus_dir}` ",
|
||||
f"**Models tested:** {len(results)} ",
|
||||
f"**Prompts per model:** {len(TEST_PROMPTS)}",
|
||||
f"",
|
||||
f"## Rankings",
|
||||
f"",
|
||||
f"| Rank | Model | Score | Latency | Em-dashes | Fillers | Semicolons |",
|
||||
f"|------|-------|-------|---------|-----------|---------|------------|",
|
||||
]
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
medal = {1: "🥇", 2: "🥈", 3: "🥉"}.get(i, f"#{i}")
|
||||
lines.append(
|
||||
f"| {medal} | `{r.model_id}` | {r.avg_score:.0f}/100 "
|
||||
f"| {r.avg_latency_ms:.0f}ms "
|
||||
f"| {r.total_em_dashes} "
|
||||
f"| {r.total_filler_hits} "
|
||||
f"| {r.total_semicolons} |"
|
||||
)
|
||||
|
||||
lines += ["", "## Sample Outputs", ""]
|
||||
|
||||
for r in results[:3]: # top 3 only to keep report readable
|
||||
lines += [f"### `{r.model_id}` (avg score: {r.avg_score:.0f})", ""]
|
||||
for pr in r.prompt_results:
|
||||
lines += [
|
||||
f"**Prompt:** {pr['tag']} ",
|
||||
f"**Score:** {pr['score']:.0f}/100 ",
|
||||
f"",
|
||||
f"```",
|
||||
pr["output"],
|
||||
f"```",
|
||||
f"",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def save_report(results: list[ModelResult], corpus_dir: Path) -> Path:
|
||||
_RESULTS_DIR.mkdir(exist_ok=True)
|
||||
date_str = datetime.now().strftime("%Y-%m-%d_%H%M")
|
||||
report_path = _RESULTS_DIR / f"style_{date_str}.md"
|
||||
report_path.write_text(render_report(results, corpus_dir), encoding="utf-8")
|
||||
|
||||
# Also save raw JSON for programmatic use
|
||||
json_path = _RESULTS_DIR / f"style_{date_str}.json"
|
||||
json_path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"model_id": r.model_id,
|
||||
"avg_score": r.avg_score,
|
||||
"avg_latency_ms": r.avg_latency_ms,
|
||||
"total_filler_hits": r.total_filler_hits,
|
||||
"total_em_dashes": r.total_em_dashes,
|
||||
"total_semicolons": r.total_semicolons,
|
||||
"prompt_results": r.prompt_results,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
return report_path
|
||||
|
||||
|
||||
# ── CLI commands ──────────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_list_models(_args: argparse.Namespace) -> None:
|
||||
models = list_ollama_models()
|
||||
if not models:
|
||||
print("No models found (is ollama running?)")
|
||||
return
|
||||
print(f"{len(models)} models available:\n")
|
||||
for m in models:
|
||||
print(f" {m}")
|
||||
|
||||
|
||||
def cmd_run(args: argparse.Namespace) -> None:
|
||||
corpus_dir = Path(args.samples)
|
||||
if not corpus_dir.exists():
|
||||
print(f"[error] Corpus directory not found: {corpus_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
max_vram_mb: int = getattr(args, "max_vram", 7200)
|
||||
use_cforch: bool = getattr(args, "cforch", False)
|
||||
use_vllm: bool = getattr(args, "vllm", False)
|
||||
cforch_url: str = getattr(args, "cforch_url", _CFORCH_URL)
|
||||
registered_tags: list[str] = []
|
||||
|
||||
def _filter_ollama_by_size(ids: list[str], include_large: bool) -> list[str]:
|
||||
"""Apply name-pattern size filter to ollama model list."""
|
||||
if include_large:
|
||||
return ids
|
||||
skip_patterns = ["270b", "70b", "32b", "30b", "21b", "20b", "deepseek-r1"]
|
||||
filtered = [m for m in ids if not any(p in m.lower() for p in skip_patterns)]
|
||||
skipped = len(ids) - len(filtered)
|
||||
if skipped:
|
||||
print(f"[info] Skipped {skipped} large model(s) by name pattern. "
|
||||
"Pass --include-large to include them.")
|
||||
return filtered
|
||||
|
||||
if args.models and args.models != "all":
|
||||
model_ids = [m.strip() for m in args.models.split(",") if m.strip()]
|
||||
elif use_cforch:
|
||||
# cf-orch path: pull model list from catalog, filter by vram_mb
|
||||
catalog = cforch_list_catalog(cforch_url)
|
||||
if not catalog:
|
||||
print("[warn] cf-orch catalog empty or unreachable -- falling back to ollama models")
|
||||
use_cforch = False
|
||||
model_ids = _filter_ollama_by_size(list_ollama_models(), args.include_large)
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
before = list(catalog.items())
|
||||
allowed = {mid: mb for mid, mb in before if mb == 0 or mb <= max_vram_mb}
|
||||
skipped_oom = {mid: mb for mid, mb in before if mid not in allowed}
|
||||
model_ids = list(allowed.keys())
|
||||
print(f"[info] cf-orch catalog: {len(before)} model(s), "
|
||||
f"{len(allowed)} within {max_vram_mb} MB VRAM limit")
|
||||
if skipped_oom:
|
||||
print(f"[info] Skipped (OOM risk): "
|
||||
+ ", ".join(f"{mid} ({mb} MB)" for mid, mb in sorted(skipped_oom.items())))
|
||||
else:
|
||||
# Ollama path
|
||||
model_ids = list_ollama_models()
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Backfill GGUFs from disk before filtering -- skips files that exceed VRAM limit
|
||||
if getattr(args, "scan_disk", None):
|
||||
llm_root = Path(args.scan_disk)
|
||||
print(f"\nScanning {llm_root} for unregistered GGUFs (limit: {max_vram_mb} MB)...")
|
||||
registered_tags = backfill_disk_models(llm_root, set(model_ids), max_vram_mb=max_vram_mb)
|
||||
model_ids = list_ollama_models() # re-fetch with new registrations
|
||||
|
||||
model_ids = _filter_ollama_by_size(model_ids, args.include_large)
|
||||
|
||||
print(f"\nRunning writing style benchmark on {len(model_ids)} model(s)...")
|
||||
try:
|
||||
results = run_benchmark(model_ids, corpus_dir, TEST_PROMPTS, use_cforch=use_cforch, use_vllm=use_vllm, cforch_url=cforch_url, workers=args.workers)
|
||||
report_path = save_report(results, corpus_dir)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results saved to: {report_path}")
|
||||
print(f"\n{render_report(results, corpus_dir)}")
|
||||
finally:
|
||||
if registered_tags:
|
||||
print(f"\nCleaning up {len(registered_tags)} temporary ollama registrations...")
|
||||
for tag in registered_tags:
|
||||
deregister_gguf(tag)
|
||||
|
||||
|
||||
def cmd_show_last(_args: argparse.Namespace) -> None:
|
||||
reports = sorted(_RESULTS_DIR.glob("style_*.md"), reverse=True)
|
||||
if not reports:
|
||||
print("No benchmark results found. Run --run first.")
|
||||
return
|
||||
print(reports[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Writing style benchmark harness for local text-gen models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="cmd")
|
||||
|
||||
sub.add_parser("list-models", help="List available ollama models")
|
||||
|
||||
run_p = sub.add_parser("run", help="Run the benchmark")
|
||||
run_p.add_argument("--models", default="all", help="Comma-separated model IDs, or 'all'")
|
||||
run_p.add_argument("--samples", default=str(_CORPUS_DIR), help="Path to style corpus directory")
|
||||
run_p.add_argument("--include-large", action="store_true", help="Include models >20B params")
|
||||
run_p.add_argument("--scan-disk", metavar="LLM_ROOT", help="Scan directory for GGUFs not yet in ollama (e.g. /Library/Assets/LLM)")
|
||||
run_p.add_argument("--cforch", action="store_true", help="Route generation through cf-orch/cf-text instead of direct ollama")
|
||||
run_p.add_argument("--vllm", action="store_true", help="Route generation through cf-orch/vllm (OpenAI-compatible) instead of ollama")
|
||||
run_p.add_argument("--cforch-url", default=_CFORCH_URL, help=f"cf-orch coordinator URL (default: {_CFORCH_URL})")
|
||||
run_p.add_argument("--max-vram", type=int, default=7200, metavar="MB",
|
||||
help="Skip models whose VRAM footprint exceeds this limit in MB (default: 7200)")
|
||||
run_p.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Parallel workers — run N models simultaneously (default: 1; use 4+ with cf-orch)")
|
||||
|
||||
sub.add_parser("show-last", help="Print the most recent benchmark report")
|
||||
|
||||
# Also support legacy --list-models / --run / --show-last flags for manage.sh compat
|
||||
parser.add_argument("--list-models", action="store_true")
|
||||
parser.add_argument("--run", action="store_true")
|
||||
parser.add_argument("--show-last", action="store_true")
|
||||
parser.add_argument("--models", default="all")
|
||||
parser.add_argument("--samples", default=str(_CORPUS_DIR))
|
||||
parser.add_argument("--include-large", action="store_true")
|
||||
parser.add_argument("--scan-disk", metavar="LLM_ROOT")
|
||||
parser.add_argument("--cforch", action="store_true")
|
||||
parser.add_argument("--vllm", action="store_true")
|
||||
parser.add_argument("--cforch-url", default=_CFORCH_URL)
|
||||
parser.add_argument("--max-vram", type=int, default=7200, metavar="MB")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cmd == "list-models" or args.list_models:
|
||||
cmd_list_models(args)
|
||||
elif args.cmd == "run" or args.run:
|
||||
cmd_run(args)
|
||||
elif args.cmd == "show-last" or args.show_last:
|
||||
cmd_show_last(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,909 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Voice benchmark harness -- score local text-gen models for writing style match.
|
||||
|
||||
Runs each model against a set of test prompts, extracts style signals from the
|
||||
outputs, compares them to a voice corpus, and produces a ranked markdown table.
|
||||
|
||||
Usage:
|
||||
# List available ollama models
|
||||
conda run -n cf python scripts/benchmark_voice.py --list-models
|
||||
|
||||
# Run against all models with default test prompts
|
||||
conda run -n cf python scripts/benchmark_voice.py --run
|
||||
|
||||
# Run specific models only
|
||||
conda run -n cf python scripts/benchmark_voice.py --run --models mistral:7b,llama3.1:8b
|
||||
|
||||
# Use a custom corpus directory
|
||||
conda run -n cf python scripts/benchmark_voice.py --run --samples data/voice_corpus/
|
||||
|
||||
# Print last results table
|
||||
conda run -n cf python scripts/benchmark_voice.py --show-last
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CORPUS_DIR = _ROOT / "data" / "voice_corpus"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
_OLLAMA_URL = "http://localhost:11434"
|
||||
_CFORCH_URL = "http://localhost:7700"
|
||||
|
||||
# Subdirectories under --scan-disk root that may contain GGUFs
|
||||
_SCAN_SUBDIRS = ["textgen/models", "llama.cpp/models", "cf-text/models", "vllm/models"]
|
||||
|
||||
# ── Filler phrases that should be absent from good voice-match output ─────────
|
||||
FILLER_PHRASES: list[str] = [
|
||||
"delve", "certainly", "absolutely", "i apologize", "i'd be happy to",
|
||||
"of course", "great question", "i understand", "let me know if",
|
||||
"feel free to", "it's important to note", "it's worth noting",
|
||||
"in conclusion", "to summarize", "in summary",
|
||||
]
|
||||
|
||||
# ── Test prompts: (thread_title, thread_body, context_tag) ───────────────────
|
||||
# These are representative threads that Magpie might reply to.
|
||||
# Extend this list with real examples as the corpus grows.
|
||||
TEST_PROMPTS: list[dict[str, str]] = [
|
||||
{
|
||||
"tag": "selfhosted_ai_fatigue",
|
||||
"thread_title": "Anyone else getting tired of re-explaining their setup every time an AI model forgets?",
|
||||
"thread_body": (
|
||||
"Every session I start over. My whole hardware setup, what tools I use, "
|
||||
"what I've already tried. It's exhausting. There has to be a better way."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "privacy_local_llm",
|
||||
"thread_title": "What's the point of running local LLMs if the apps still phone home?",
|
||||
"thread_body": (
|
||||
"I went through all the trouble of setting up ollama and now I find out "
|
||||
"the frontend I'm using is sending telemetry. Kind of defeats the purpose."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "solarpunk_tech",
|
||||
"thread_title": "What does solarpunk computing actually look like in practice?",
|
||||
"thread_body": (
|
||||
"I keep seeing the aesthetic but not a lot of concrete examples of "
|
||||
"people living it out with their tech choices. What does it mean day to day?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "nd_tools",
|
||||
"thread_title": "Tools that actually help with executive function vs ones that just add friction",
|
||||
"thread_body": (
|
||||
"I've tried a dozen productivity apps and most of them require more "
|
||||
"executive function to maintain than they save. What actually sticks for you?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "data_ownership",
|
||||
"thread_title": "Who actually owns your data when you use a 'free' AI tool?",
|
||||
"thread_body": (
|
||||
"Read the ToS on three different AI assistants today. In all three cases "
|
||||
"your inputs can be used for training, shared with partners, and retained "
|
||||
"indefinitely. At what point does 'free' just mean you're the product?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "digital_culture",
|
||||
"thread_title": "The internet used to feel like it belonged to everyone. What happened?",
|
||||
"thread_body": (
|
||||
"I grew up on forums, IRC, personal homepages. Now everything is a platform "
|
||||
"owned by someone trying to extract value from the community that built it. "
|
||||
"Is the fediverse / self-hosting movement actually reversing this or just "
|
||||
"a niche hobby?"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
GENERATION_PARAMS: dict[str, Any] = {
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"num_predict": 300,
|
||||
}
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No hype, no buzzwords, no em dashes, no semicolons.\n"
|
||||
"- Community-first perspective. Solarpunk values.\n"
|
||||
"- Direct and opinionated. No throat-clearing or filler.\n"
|
||||
"- When relevant, mention personal experience with real tools.\n\n"
|
||||
"Write ONLY the reply. No preamble, no 'Here is a reply:', no meta-commentary."
|
||||
)
|
||||
|
||||
|
||||
# ── Style signal extraction ───────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class StyleSignals:
|
||||
"""Quantitative style signals extracted from a text sample."""
|
||||
sentence_count: int = 0
|
||||
word_count: int = 0
|
||||
avg_sentence_length: float = 0.0
|
||||
em_dash_count: int = 0
|
||||
semicolon_count: int = 0
|
||||
filler_hits: list[str] = field(default_factory=list)
|
||||
question_ratio: float = 0.0 # fraction of sentences ending in '?'
|
||||
first_person_ratio: float = 0.0 # fraction of sentences starting with 'I'
|
||||
avg_word_length: float = 0.0
|
||||
|
||||
|
||||
def extract_signals(text: str) -> StyleSignals:
|
||||
"""Extract style signals from a text sample."""
|
||||
text = text.strip()
|
||||
if text.startswith("[ERROR:"):
|
||||
return StyleSignals() # zero-score sentinel — caller checks for empty output
|
||||
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
|
||||
words = text.split()
|
||||
|
||||
if not sentences:
|
||||
return StyleSignals()
|
||||
|
||||
avg_sentence_length = len(words) / len(sentences) if sentences else 0.0
|
||||
avg_word_length = (sum(len(w.strip('.,!?;:"\'')) for w in words) / len(words)) if words else 0.0
|
||||
|
||||
em_dash_count = text.count('\u2014') + text.count(' -- ') + text.count('--')
|
||||
semicolon_count = text.count(';')
|
||||
|
||||
filler_hits = [p for p in FILLER_PHRASES if p.lower() in text.lower()]
|
||||
|
||||
question_ratio = sum(1 for s in sentences if s.endswith('?')) / len(sentences)
|
||||
first_person_ratio = sum(1 for s in sentences if re.match(r"^I\b", s)) / len(sentences)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=len(sentences),
|
||||
word_count=len(words),
|
||||
avg_sentence_length=avg_sentence_length,
|
||||
em_dash_count=em_dash_count,
|
||||
semicolon_count=semicolon_count,
|
||||
filler_hits=filler_hits,
|
||||
question_ratio=question_ratio,
|
||||
first_person_ratio=first_person_ratio,
|
||||
avg_word_length=avg_word_length,
|
||||
)
|
||||
|
||||
|
||||
def build_corpus_profile(corpus_dir: Path) -> StyleSignals | None:
|
||||
"""Aggregate style signals across all corpus samples into a target profile."""
|
||||
samples = list(corpus_dir.glob("*.txt"))
|
||||
if not samples:
|
||||
return None
|
||||
|
||||
all_signals = [extract_signals(p.read_text(encoding="utf-8")) for p in samples]
|
||||
n = len(all_signals)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=int(sum(s.sentence_count for s in all_signals) / n),
|
||||
word_count=int(sum(s.word_count for s in all_signals) / n),
|
||||
avg_sentence_length=sum(s.avg_sentence_length for s in all_signals) / n,
|
||||
em_dash_count=int(sum(s.em_dash_count for s in all_signals) / n),
|
||||
semicolon_count=int(sum(s.semicolon_count for s in all_signals) / n),
|
||||
question_ratio=sum(s.question_ratio for s in all_signals) / n,
|
||||
first_person_ratio=sum(s.first_person_ratio for s in all_signals) / n,
|
||||
avg_word_length=sum(s.avg_word_length for s in all_signals) / n,
|
||||
)
|
||||
|
||||
|
||||
def score_against_profile(output_signals: StyleSignals, profile: StyleSignals | None) -> float:
|
||||
"""Score a model output against the corpus profile. Returns 0-100.
|
||||
|
||||
Penalties:
|
||||
- Em dashes / semicolons: -5 each occurrence (hard CF style violation)
|
||||
- Filler phrases: -8 each hit (strong signal of non-voice output)
|
||||
- Sentence length delta: proportional penalty (target: close to corpus avg)
|
||||
- Word length delta: smaller penalty
|
||||
|
||||
When no corpus profile is available, falls back to absolute signal scores only.
|
||||
"""
|
||||
score = 100.0
|
||||
|
||||
# Hard violations -- always penalised regardless of corpus
|
||||
score -= output_signals.em_dash_count * 5
|
||||
score -= output_signals.semicolon_count * 3
|
||||
score -= len(output_signals.filler_hits) * 8
|
||||
|
||||
if profile is not None:
|
||||
# Sentence length delta: penalise proportionally
|
||||
length_delta = abs(output_signals.avg_sentence_length - profile.avg_sentence_length)
|
||||
score -= min(length_delta * 2, 20)
|
||||
|
||||
# Question ratio delta
|
||||
question_delta = abs(output_signals.question_ratio - profile.question_ratio)
|
||||
score -= min(question_delta * 10, 10)
|
||||
|
||||
return max(0.0, score)
|
||||
|
||||
|
||||
# ── Ollama generation ─────────────────────────────────────────────────────────
|
||||
|
||||
_CFORCH_NODE_ID = "heimdall"
|
||||
|
||||
|
||||
def cforch_list_catalog(
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
node_id: str = _CFORCH_NODE_ID,
|
||||
) -> dict[str, int]:
|
||||
"""Return the cf-text catalog from cf-orch as {model_id: vram_mb}.
|
||||
|
||||
Uses ?node_id= to request the catalog from a specific node's profile,
|
||||
avoiding cross-node catalog shadowing when multiple nodes define catalogs
|
||||
for the same service.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_url}/api/services/cf-text/catalog",
|
||||
params={"node_id": node_id} if node_id else {},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
return {
|
||||
model_id: (entry.get("vram_mb", 0) if isinstance(entry, dict) else 0)
|
||||
for model_id, entry in raw.items()
|
||||
}
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach cf-orch catalog at {cforch_url}: {exc}", file=sys.stderr)
|
||||
return {}
|
||||
|
||||
|
||||
def _cforch_allocate_service(
|
||||
service: str,
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float,
|
||||
health_path: str,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Generic cf-orch allocate + health-poll. Returns (service_url, allocation_id) or None."""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/{service}/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "voice_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
label = service
|
||||
print(f" [cold start] waiting for {label} to load {model_id!r}...", end=" ", flush=True)
|
||||
deadline = time.monotonic() + startup_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
health = httpx.get(f"{service_url}{health_path}", timeout=3.0)
|
||||
if health.is_success:
|
||||
print(f"ready ({time.monotonic() - (deadline - startup_timeout_s):.0f}s)", flush=True)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2.0)
|
||||
else:
|
||||
print(f"timed out after {startup_timeout_s:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r} ({service}): {exc}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 180.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id. Returns (service_url, allocation_id) or None."""
|
||||
return _cforch_allocate_service("cf-text", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_allocate_vllm(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a vllm instance for model_id. Returns (service_url, allocation_id) or None.
|
||||
|
||||
vllm exposes an OpenAI-compatible API — generate_cftext() works unchanged
|
||||
against the returned service_url. Startup timeout is longer (300s) because
|
||||
vllm loads large model weights from disk before becoming ready.
|
||||
"""
|
||||
return _cforch_allocate_service("vllm", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_release(allocation_id: str, cforch_url: str = _CFORCH_URL) -> None:
|
||||
"""Release a cf-orch allocation."""
|
||||
if not allocation_id:
|
||||
return
|
||||
try:
|
||||
httpx.post(f"{cforch_url}/api/leases/{allocation_id}/release", timeout=10.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def generate_cftext(
|
||||
service_url: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
) -> tuple[str, float]:
|
||||
"""Call cf-text via OpenAI-compatible /v1/chat/completions. Returns (text, elapsed_ms)."""
|
||||
messages: list[dict[str, str]] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": GENERATION_PARAMS.get("num_predict", 300),
|
||||
"temperature": GENERATION_PARAMS.get("temperature", 0.7),
|
||||
"top_p": GENERATION_PARAMS.get("top_p", 0.9),
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json=payload,
|
||||
timeout=180.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
content = resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def generate(model_id: str, prompt: str, system: str = "") -> tuple[str, float]:
|
||||
"""Call ollama /api/generate. Returns (text, elapsed_ms)."""
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": GENERATION_PARAMS,
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_OLLAMA_URL}/api/generate",
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return resp.json().get("response", "").strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def find_disk_ggufs(llm_root: Path) -> list[Path]:
|
||||
"""Recursively find .gguf files under known subdirs of llm_root.
|
||||
|
||||
Skips vocab-only GGUFs (ggml-vocab-*) which aren't standalone models.
|
||||
"""
|
||||
found: list[Path] = []
|
||||
search_dirs = [llm_root / sub for sub in _SCAN_SUBDIRS] + [llm_root]
|
||||
seen: set[Path] = set()
|
||||
for base in search_dirs:
|
||||
if not base.exists():
|
||||
continue
|
||||
for gguf in base.rglob("*.gguf"):
|
||||
if gguf in seen:
|
||||
continue
|
||||
seen.add(gguf)
|
||||
if gguf.name.startswith("ggml-vocab-"):
|
||||
continue
|
||||
found.append(gguf)
|
||||
return sorted(found)
|
||||
|
||||
|
||||
def gguf_to_ollama_tag(gguf_path: Path) -> str:
|
||||
"""Derive a stable ollama tag from a GGUF path.
|
||||
|
||||
Uses parent dir name + stem to avoid collisions, e.g.:
|
||||
claude-3.7-sonnet-reasoning-gemma3-12B/foo.Q8_0.gguf
|
||||
→ bench-claude-3.7-sonnet-reasoning-gemma3-12b-foo-q8-0
|
||||
"""
|
||||
parent = gguf_path.parent.name.lower()
|
||||
stem = gguf_path.stem.lower()
|
||||
# If stem is contained in parent (common pattern), just use parent
|
||||
slug = parent if stem.replace("-", "").replace("_", "") in parent.replace("-", "").replace("_", "") else f"{parent}-{stem}"
|
||||
slug = re.sub(r"[^a-z0-9]+", "-", slug).strip("-")
|
||||
return f"bench-{slug}:latest"
|
||||
|
||||
|
||||
def register_gguf(gguf_path: Path, tag: str) -> bool:
|
||||
"""Create a temporary ollama model entry from a GGUF file. Returns True on success."""
|
||||
import subprocess
|
||||
import tempfile
|
||||
modelfile = f"FROM {gguf_path.resolve()}\n"
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".Modelfile", delete=False) as f:
|
||||
f.write(modelfile)
|
||||
modelfile_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ollama", "create", tag, "-f", modelfile_path],
|
||||
capture_output=True, text=True, timeout=60,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not register {gguf_path.name}: {exc}", file=sys.stderr)
|
||||
return False
|
||||
finally:
|
||||
Path(modelfile_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def deregister_gguf(tag: str) -> None:
|
||||
"""Remove a temporary ollama model entry."""
|
||||
import subprocess
|
||||
try:
|
||||
subprocess.run(["ollama", "rm", tag], capture_output=True, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def backfill_disk_models(
|
||||
llm_root: Path,
|
||||
existing_tags: set[str],
|
||||
max_vram_mb: int = 0,
|
||||
) -> list[str]:
|
||||
"""Register GGUFs from disk that aren't already in ollama. Returns new tags.
|
||||
|
||||
max_vram_mb: skip files whose size exceeds this threshold (0 = no limit).
|
||||
GGUF file size is a reliable VRAM proxy -- quantized weights load ~1:1.
|
||||
"""
|
||||
ggufs = find_disk_ggufs(llm_root)
|
||||
if not ggufs:
|
||||
print(f"No .gguf files found under {llm_root}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
new_tags: list[str] = []
|
||||
skipped_oom = 0
|
||||
for gguf in ggufs:
|
||||
size_mb = gguf.stat().st_size // (1024 * 1024)
|
||||
if max_vram_mb and size_mb > max_vram_mb:
|
||||
print(f" [skip-oom] {gguf.name} ({size_mb} MB > {max_vram_mb} MB limit)")
|
||||
skipped_oom += 1
|
||||
continue
|
||||
tag = gguf_to_ollama_tag(gguf)
|
||||
if tag in existing_tags:
|
||||
print(f" [skip] {gguf.name} already registered as {tag}")
|
||||
continue
|
||||
print(f" [register] {gguf.name} ({size_mb} MB) → {tag} ...", end=" ", flush=True)
|
||||
if register_gguf(gguf, tag):
|
||||
print("ok")
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
print("failed")
|
||||
|
||||
if skipped_oom:
|
||||
print(f" [info] {skipped_oom} GGUF(s) skipped (exceed {max_vram_mb} MB VRAM limit)")
|
||||
return new_tags
|
||||
|
||||
|
||||
def list_ollama_models() -> list[str]:
|
||||
"""Return model names from ollama /api/tags, filtered to text-gen candidates."""
|
||||
try:
|
||||
resp = httpx.get(f"{_OLLAMA_URL}/api/tags", timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
models = resp.json().get("models", [])
|
||||
# Exclude embedding-only models
|
||||
exclude = {"mxbai-embed-large", "nomic-embed-text", "all-minilm"}
|
||||
return [
|
||||
m["name"] for m in models
|
||||
if not any(x in m["name"].lower() for x in exclude)
|
||||
]
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach ollama: {exc}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
|
||||
# ── Run benchmark ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ModelResult:
|
||||
model_id: str
|
||||
prompt_results: list[dict[str, Any]] = field(default_factory=list)
|
||||
avg_score: float = 0.0
|
||||
avg_latency_ms: float = 0.0
|
||||
total_filler_hits: int = 0
|
||||
total_em_dashes: int = 0
|
||||
total_semicolons: int = 0
|
||||
|
||||
|
||||
def _bench_one_model(
|
||||
model_id: str,
|
||||
prompts: list[dict[str, str]],
|
||||
profile: Any,
|
||||
use_cforch: bool,
|
||||
cforch_url: str,
|
||||
use_vllm: bool = False,
|
||||
) -> "ModelResult | None":
|
||||
"""Run all prompts for a single model. Thread-safe — all output is prefixed with model_id.
|
||||
|
||||
Dispatch priority:
|
||||
use_vllm=True → allocate vllm via cf-orch, then generate_cftext() (OpenAI-compatible)
|
||||
use_cforch=True → allocate cf-text via cf-orch, then generate_cftext()
|
||||
else → direct ollama generate()
|
||||
Both vllm and cf-text expose /v1/chat/completions so generate_cftext() works for both.
|
||||
"""
|
||||
prefix = f"[{model_id}]"
|
||||
result = ModelResult(model_id=model_id)
|
||||
|
||||
service_url: str | None = None
|
||||
allocation_id: str = ""
|
||||
if use_vllm:
|
||||
alloc = cforch_allocate_vllm(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] vllm allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} vllm allocated: {service_url}", flush=True)
|
||||
elif use_cforch:
|
||||
alloc = cforch_allocate(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] cf-orch allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} allocated: {service_url}", flush=True)
|
||||
|
||||
try:
|
||||
for prompt_def in prompts:
|
||||
tag = prompt_def["tag"]
|
||||
user_prompt = (
|
||||
f"Thread: {prompt_def['thread_title']}\n\n"
|
||||
f"{prompt_def['thread_body']}\n\n"
|
||||
f"Write a reply:"
|
||||
)
|
||||
print(f"{prefix} [{tag}] generating...", flush=True)
|
||||
|
||||
if (use_cforch or use_vllm) and service_url:
|
||||
# Both cf-text and vllm expose /v1/chat/completions — same call
|
||||
output, elapsed_ms = generate_cftext(service_url, model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
else:
|
||||
output, elapsed_ms = generate(model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
|
||||
signals = extract_signals(output)
|
||||
score = score_against_profile(signals, profile)
|
||||
|
||||
print(f"{prefix} [{tag}] {score:.0f}/100 ({elapsed_ms:.0f}ms)", flush=True)
|
||||
if signals.filler_hits:
|
||||
print(f"{prefix} ⚠ filler: {signals.filler_hits}", flush=True)
|
||||
if signals.em_dash_count:
|
||||
print(f"{prefix} ⚠ em-dashes: {signals.em_dash_count}", flush=True)
|
||||
|
||||
result.prompt_results.append({
|
||||
"tag": tag,
|
||||
"user_prompt": user_prompt,
|
||||
"output": output,
|
||||
"signals": {
|
||||
"avg_sentence_length": signals.avg_sentence_length,
|
||||
"em_dash_count": signals.em_dash_count,
|
||||
"semicolon_count": signals.semicolon_count,
|
||||
"filler_hits": signals.filler_hits,
|
||||
"question_ratio": signals.question_ratio,
|
||||
"word_count": signals.word_count,
|
||||
},
|
||||
"score": score,
|
||||
"latency_ms": elapsed_ms,
|
||||
})
|
||||
finally:
|
||||
if use_cforch and allocation_id:
|
||||
cforch_release(allocation_id, cforch_url)
|
||||
|
||||
if not result.prompt_results:
|
||||
return None
|
||||
|
||||
scores = [r["score"] for r in result.prompt_results]
|
||||
latencies = [r["latency_ms"] for r in result.prompt_results]
|
||||
result.avg_score = sum(scores) / len(scores)
|
||||
result.avg_latency_ms = sum(latencies) / len(latencies)
|
||||
result.total_filler_hits = sum(len(r["signals"]["filler_hits"]) for r in result.prompt_results)
|
||||
result.total_em_dashes = sum(r["signals"]["em_dash_count"] for r in result.prompt_results)
|
||||
result.total_semicolons = sum(r["signals"]["semicolon_count"] for r in result.prompt_results)
|
||||
|
||||
print(f"{prefix} done — avg score {result.avg_score:.0f}/100", flush=True)
|
||||
return result
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_ids: list[str],
|
||||
corpus_dir: Path,
|
||||
prompts: list[dict[str, str]],
|
||||
use_cforch: bool = False,
|
||||
use_vllm: bool = False,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
workers: int = 1,
|
||||
) -> list[ModelResult]:
|
||||
profile = build_corpus_profile(corpus_dir)
|
||||
if profile:
|
||||
print(f"Corpus profile loaded from {corpus_dir} ({len(list(corpus_dir.glob('*.txt')))} samples)")
|
||||
print(f" Target avg sentence length: {profile.avg_sentence_length:.1f} words")
|
||||
else:
|
||||
print(f"[warn] No corpus samples found in {corpus_dir} -- scoring on hard violations only")
|
||||
|
||||
backend = "vllm via cf-orch" if use_vllm else ("cf-text via cf-orch" if use_cforch else "ollama")
|
||||
print(f" Backend: {backend}")
|
||||
|
||||
effective_workers = min(workers, len(model_ids)) if model_ids else 1
|
||||
print(f" Workers: {effective_workers} (of {len(model_ids)} models)", flush=True)
|
||||
|
||||
results: list[ModelResult] = []
|
||||
|
||||
if effective_workers <= 1:
|
||||
# Sequential path — simpler output, easier to follow for single-model runs
|
||||
for model_id in model_ids:
|
||||
print(f"\n{'='*60}\nModel: {model_id}", flush=True)
|
||||
r = _bench_one_model(model_id, prompts, profile, use_cforch, cforch_url, use_vllm)
|
||||
if r:
|
||||
results.append(r)
|
||||
else:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
print(f" Fanning out {len(model_ids)} models across {effective_workers} workers...", flush=True)
|
||||
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(_bench_one_model, mid, prompts, profile, use_cforch, cforch_url, use_vllm): mid
|
||||
for mid in model_ids
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
r = future.result()
|
||||
if r:
|
||||
results.append(r)
|
||||
|
||||
return sorted(results, key=lambda r: r.avg_score, reverse=True)
|
||||
|
||||
|
||||
# ── Markdown report ───────────────────────────────────────────────────────────
|
||||
|
||||
def render_report(results: list[ModelResult], corpus_dir: Path) -> str:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
lines: list[str] = [
|
||||
f"# Voice Benchmark Results",
|
||||
f"",
|
||||
f"**Date:** {date_str} ",
|
||||
f"**Corpus:** `{corpus_dir}` ",
|
||||
f"**Models tested:** {len(results)} ",
|
||||
f"**Prompts per model:** {len(TEST_PROMPTS)}",
|
||||
f"",
|
||||
f"## Rankings",
|
||||
f"",
|
||||
f"| Rank | Model | Score | Latency | Em-dashes | Fillers | Semicolons |",
|
||||
f"|------|-------|-------|---------|-----------|---------|------------|",
|
||||
]
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
medal = {1: "🥇", 2: "🥈", 3: "🥉"}.get(i, f"#{i}")
|
||||
lines.append(
|
||||
f"| {medal} | `{r.model_id}` | {r.avg_score:.0f}/100 "
|
||||
f"| {r.avg_latency_ms:.0f}ms "
|
||||
f"| {r.total_em_dashes} "
|
||||
f"| {r.total_filler_hits} "
|
||||
f"| {r.total_semicolons} |"
|
||||
)
|
||||
|
||||
lines += ["", "## Sample Outputs", ""]
|
||||
|
||||
for r in results[:3]: # top 3 only to keep report readable
|
||||
lines += [f"### `{r.model_id}` (avg score: {r.avg_score:.0f})", ""]
|
||||
for pr in r.prompt_results:
|
||||
lines += [
|
||||
f"**Prompt:** {pr['tag']} ",
|
||||
f"**Score:** {pr['score']:.0f}/100 ",
|
||||
f"",
|
||||
f"```",
|
||||
pr["output"],
|
||||
f"```",
|
||||
f"",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def save_report(results: list[ModelResult], corpus_dir: Path) -> Path:
|
||||
_RESULTS_DIR.mkdir(exist_ok=True)
|
||||
date_str = datetime.now().strftime("%Y-%m-%d_%H%M")
|
||||
report_path = _RESULTS_DIR / f"voice_{date_str}.md"
|
||||
report_path.write_text(render_report(results, corpus_dir), encoding="utf-8")
|
||||
|
||||
# Also save raw JSON for programmatic use
|
||||
json_path = _RESULTS_DIR / f"voice_{date_str}.json"
|
||||
json_path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"model_id": r.model_id,
|
||||
"avg_score": r.avg_score,
|
||||
"avg_latency_ms": r.avg_latency_ms,
|
||||
"total_filler_hits": r.total_filler_hits,
|
||||
"total_em_dashes": r.total_em_dashes,
|
||||
"total_semicolons": r.total_semicolons,
|
||||
"prompt_results": r.prompt_results,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
return report_path
|
||||
|
||||
|
||||
# ── CLI commands ──────────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_list_models(_args: argparse.Namespace) -> None:
|
||||
models = list_ollama_models()
|
||||
if not models:
|
||||
print("No models found (is ollama running?)")
|
||||
return
|
||||
print(f"{len(models)} models available:\n")
|
||||
for m in models:
|
||||
print(f" {m}")
|
||||
|
||||
|
||||
def cmd_run(args: argparse.Namespace) -> None:
|
||||
corpus_dir = Path(args.samples)
|
||||
if not corpus_dir.exists():
|
||||
print(f"[error] Corpus directory not found: {corpus_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
max_vram_mb: int = getattr(args, "max_vram", 7200)
|
||||
use_cforch: bool = getattr(args, "cforch", False)
|
||||
use_vllm: bool = getattr(args, "vllm", False)
|
||||
cforch_url: str = getattr(args, "cforch_url", _CFORCH_URL)
|
||||
registered_tags: list[str] = []
|
||||
|
||||
def _filter_ollama_by_size(ids: list[str], include_large: bool) -> list[str]:
|
||||
"""Apply name-pattern size filter to ollama model list."""
|
||||
if include_large:
|
||||
return ids
|
||||
skip_patterns = ["270b", "70b", "32b", "30b", "21b", "20b", "deepseek-r1"]
|
||||
filtered = [m for m in ids if not any(p in m.lower() for p in skip_patterns)]
|
||||
skipped = len(ids) - len(filtered)
|
||||
if skipped:
|
||||
print(f"[info] Skipped {skipped} large model(s) by name pattern. "
|
||||
"Pass --include-large to include them.")
|
||||
return filtered
|
||||
|
||||
if args.models and args.models != "all":
|
||||
model_ids = [m.strip() for m in args.models.split(",") if m.strip()]
|
||||
elif use_cforch:
|
||||
# cf-orch path: pull model list from catalog, filter by vram_mb
|
||||
catalog = cforch_list_catalog(cforch_url)
|
||||
if not catalog:
|
||||
print("[warn] cf-orch catalog empty or unreachable -- falling back to ollama models")
|
||||
use_cforch = False
|
||||
model_ids = _filter_ollama_by_size(list_ollama_models(), args.include_large)
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
before = list(catalog.items())
|
||||
allowed = {mid: mb for mid, mb in before if mb == 0 or mb <= max_vram_mb}
|
||||
skipped_oom = {mid: mb for mid, mb in before if mid not in allowed}
|
||||
model_ids = list(allowed.keys())
|
||||
print(f"[info] cf-orch catalog: {len(before)} model(s), "
|
||||
f"{len(allowed)} within {max_vram_mb} MB VRAM limit")
|
||||
if skipped_oom:
|
||||
print(f"[info] Skipped (OOM risk): "
|
||||
+ ", ".join(f"{mid} ({mb} MB)" for mid, mb in sorted(skipped_oom.items())))
|
||||
else:
|
||||
# Ollama path
|
||||
model_ids = list_ollama_models()
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Backfill GGUFs from disk before filtering -- skips files that exceed VRAM limit
|
||||
if getattr(args, "scan_disk", None):
|
||||
llm_root = Path(args.scan_disk)
|
||||
print(f"\nScanning {llm_root} for unregistered GGUFs (limit: {max_vram_mb} MB)...")
|
||||
registered_tags = backfill_disk_models(llm_root, set(model_ids), max_vram_mb=max_vram_mb)
|
||||
model_ids = list_ollama_models() # re-fetch with new registrations
|
||||
|
||||
model_ids = _filter_ollama_by_size(model_ids, args.include_large)
|
||||
|
||||
print(f"\nRunning voice benchmark on {len(model_ids)} model(s)...")
|
||||
try:
|
||||
results = run_benchmark(model_ids, corpus_dir, TEST_PROMPTS, use_cforch=use_cforch, use_vllm=use_vllm, cforch_url=cforch_url, workers=args.workers)
|
||||
report_path = save_report(results, corpus_dir)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results saved to: {report_path}")
|
||||
print(f"\n{render_report(results, corpus_dir)}")
|
||||
finally:
|
||||
if registered_tags:
|
||||
print(f"\nCleaning up {len(registered_tags)} temporary ollama registrations...")
|
||||
for tag in registered_tags:
|
||||
deregister_gguf(tag)
|
||||
|
||||
|
||||
def cmd_show_last(_args: argparse.Namespace) -> None:
|
||||
reports = sorted(_RESULTS_DIR.glob("voice_*.md"), reverse=True)
|
||||
if not reports:
|
||||
print("No benchmark results found. Run --run first.")
|
||||
return
|
||||
print(reports[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Voice benchmark harness for local text-gen models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="cmd")
|
||||
|
||||
sub.add_parser("list-models", help="List available ollama models")
|
||||
|
||||
run_p = sub.add_parser("run", help="Run the benchmark")
|
||||
run_p.add_argument("--models", default="all", help="Comma-separated model IDs, or 'all'")
|
||||
run_p.add_argument("--samples", default=str(_CORPUS_DIR), help="Path to voice corpus directory")
|
||||
run_p.add_argument("--include-large", action="store_true", help="Include models >20B params")
|
||||
run_p.add_argument("--scan-disk", metavar="LLM_ROOT", help="Scan directory for GGUFs not yet in ollama (e.g. /Library/Assets/LLM)")
|
||||
run_p.add_argument("--cforch", action="store_true", help="Route generation through cf-orch/cf-text instead of direct ollama")
|
||||
run_p.add_argument("--vllm", action="store_true", help="Route generation through cf-orch/vllm (OpenAI-compatible) instead of ollama")
|
||||
run_p.add_argument("--cforch-url", default=_CFORCH_URL, help=f"cf-orch coordinator URL (default: {_CFORCH_URL})")
|
||||
run_p.add_argument("--max-vram", type=int, default=7200, metavar="MB",
|
||||
help="Skip models whose VRAM footprint exceeds this limit in MB (default: 7200)")
|
||||
run_p.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Parallel workers — run N models simultaneously (default: 1; use 4+ with cf-orch)")
|
||||
|
||||
sub.add_parser("show-last", help="Print the most recent benchmark report")
|
||||
|
||||
# Also support legacy --list-models / --run / --show-last flags for manage.sh compat
|
||||
parser.add_argument("--list-models", action="store_true")
|
||||
parser.add_argument("--run", action="store_true")
|
||||
parser.add_argument("--show-last", action="store_true")
|
||||
parser.add_argument("--models", default="all")
|
||||
parser.add_argument("--samples", default=str(_CORPUS_DIR))
|
||||
parser.add_argument("--include-large", action="store_true")
|
||||
parser.add_argument("--scan-disk", metavar="LLM_ROOT")
|
||||
parser.add_argument("--cforch", action="store_true")
|
||||
parser.add_argument("--vllm", action="store_true")
|
||||
parser.add_argument("--cforch-url", default=_CFORCH_URL)
|
||||
parser.add_argument("--max-vram", type=int, default=7200, metavar="MB")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cmd == "list-models" or args.list_models:
|
||||
cmd_list_models(args)
|
||||
elif args.cmd == "run" or args.run:
|
||||
cmd_run(args)
|
||||
elif args.cmd == "show-last" or args.show_last:
|
||||
cmd_show_last(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -7,26 +7,19 @@ from __future__ import annotations
|
|||
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
import httpx
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"LABELS",
|
||||
"LABEL_DESCRIPTIONS",
|
||||
"DEFAULT_EXEMPLARS",
|
||||
"compute_metrics",
|
||||
"ClassifierAdapter",
|
||||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
"EmbeddingKNNAdapter",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
LABELS: list[str] = [
|
||||
"interview_scheduled",
|
||||
"offer_received",
|
||||
|
|
@ -124,81 +117,6 @@ def compute_metrics(
|
|||
return result
|
||||
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
|
||||
|
||||
|
||||
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
|
||||
"interview_scheduled": [
|
||||
"Subject: Interview Invitation\n\nWe would like to invite you for a phone screen next week.",
|
||||
"Subject: Schedule a call\n\nCould you be available for a video interview on Tuesday?",
|
||||
"Subject: Next Steps\n\nWe'd like to move forward with a technical interview. Please select a time.",
|
||||
"Subject: Interview Details\n\nHere are the dial-in instructions for your interview tomorrow.",
|
||||
],
|
||||
"offer_received": [
|
||||
"Subject: Offer Letter Enclosed\n\nWe are pleased to extend you an offer of employment.",
|
||||
"Subject: Job Offer\n\nDear candidate, we are excited to offer you the position of Software Engineer.",
|
||||
"Subject: Employment Offer\n\nPlease find attached your formal offer letter and compensation details.",
|
||||
"Subject: Offer of Employment\n\nCongratulations! We would like to offer you a full-time position.",
|
||||
],
|
||||
"rejected": [
|
||||
"Subject: Your Application\n\nAfter careful consideration, we have decided to move forward with other candidates.",
|
||||
"Subject: Application Status\n\nWe regret to inform you that your application has not been selected.",
|
||||
"Subject: Thank you for applying\n\nWe appreciate your interest but have chosen not to proceed.",
|
||||
"Subject: Update on your candidacy\n\nWe will not be moving forward with your application at this time.",
|
||||
],
|
||||
"positive_response": [
|
||||
"Subject: Your profile\n\nI came across your LinkedIn and think you would be a great fit for our team.",
|
||||
"Subject: Exciting opportunity\n\nWe were impressed by your background and would love to connect.",
|
||||
"Subject: Following up\n\nThank you for your interest — we'd like to learn more about your experience.",
|
||||
"Subject: Great fit\n\nYour skills align well with what we are looking for. Let's set up a call.",
|
||||
],
|
||||
"survey_received": [
|
||||
"Subject: Candidate Experience Survey\n\nPlease complete this brief survey about your application experience.",
|
||||
"Subject: Culture Fit Assessment\n\nAs part of our process, we ask all candidates to complete a short assessment.",
|
||||
"Subject: Skills Assessment\n\nWe'd like you to complete our online coding assessment before proceeding.",
|
||||
"Subject: Personality Assessment\n\nPlease complete the following assessment as the next step in our process.",
|
||||
"Subject: Pre-interview questionnaire\n\nBefore we schedule your interview, please complete this brief skills survey.",
|
||||
],
|
||||
"neutral": [
|
||||
"Subject: Application Received\n\nWe have received your application and will be in touch.",
|
||||
"Subject: Thank you for applying\n\nYour application is under review. We will contact you if needed.",
|
||||
"Subject: Confirmation\n\nThis email confirms receipt of your application to our company.",
|
||||
"Subject: Application Confirmation\n\nThank you for your interest. We will review your materials and follow up.",
|
||||
],
|
||||
"event_rescheduled": [
|
||||
"Subject: Interview Rescheduled\n\nDue to a conflict, we need to move your interview to a new time.",
|
||||
"Subject: Change of interview time\n\nWe apologize — your interview has been rescheduled to Thursday.",
|
||||
"Subject: Updated interview details\n\nYour interview has been moved from Monday to Wednesday at 2pm.",
|
||||
"Subject: Reschedule request\n\nWould you be available to reschedule to a different time slot?",
|
||||
"Subject: New interview time\n\nYour phone screen has been moved from tomorrow to next week.",
|
||||
],
|
||||
"digest": [
|
||||
"Subject: 15 new jobs matching your search\n\nHere are the latest job postings that match your profile.",
|
||||
"Subject: Weekly Job Digest\n\nThis week's top opportunities for Software Engineers in your area.",
|
||||
"Subject: Jobs you might like\n\nBased on your profile, here are some positions we recommend.",
|
||||
"Subject: New jobs for you\n\nSee the latest openings from companies on your watchlist.",
|
||||
],
|
||||
"new_lead": [
|
||||
"Subject: Exciting opportunity at our company\n\nHi, I noticed your background and think you'd be a great fit.",
|
||||
"Subject: Are you open to new opportunities?\n\nI'm a recruiter reaching out about a role matching your experience.",
|
||||
"Subject: Quick question\n\nWould you be interested in hearing about a senior engineering role?",
|
||||
"Subject: Recruiting outreach\n\nI came across your profile and wanted to share an exciting opening.",
|
||||
],
|
||||
"hired": [
|
||||
"Subject: Welcome to the team!\n\nWe are thrilled to have you join us. Here are your onboarding details.",
|
||||
"Subject: Onboarding information\n\nCongratulations on accepting our offer. Your start date is confirmed.",
|
||||
"Subject: First day information\n\nWe look forward to your first day. Please arrive at 9am and ask for HR.",
|
||||
"Subject: Background check initiated\n\nAs part of your onboarding, we have initiated a background check.",
|
||||
"Subject: Equipment setup\n\nYour laptop and equipment will be ready for pickup on your first day.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ClassifierAdapter(abc.ABC):
|
||||
"""Abstract base for all email classifier adapters."""
|
||||
|
||||
|
|
@ -386,148 +304,3 @@ class FineTunedAdapter(ClassifierAdapter):
|
|||
text = f"{subject} [SEP] {body[:400]}"
|
||||
result = self._pipeline(text)
|
||||
return result[0]["label"]
|
||||
|
||||
|
||||
class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
|
||||
|
||||
load():
|
||||
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
|
||||
Falls back to ollama_url directly if orch allocation fails or is not configured.
|
||||
2. Pre-embeds all exemplar texts and stores per-label vector lists.
|
||||
|
||||
classify(subject, body):
|
||||
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||
with the highest total similarity score among tied vote counts wins.
|
||||
|
||||
unload():
|
||||
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_id: str,
|
||||
*,
|
||||
k: int = 3,
|
||||
orch_url: str = "",
|
||||
ollama_url: str = "",
|
||||
exemplar_texts: dict[str, list[str]] | None = None,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._k = k
|
||||
self._orch_url = orch_url
|
||||
self._ollama_url = ollama_url
|
||||
self._exemplar_texts: dict[str, list[str]] = (
|
||||
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
|
||||
)
|
||||
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
|
||||
self._node_url: str = ""
|
||||
self._allocation_id: str = ""
|
||||
self._orch_url_used: str = ""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def _resolve_urls(self) -> tuple[str, str]:
|
||||
if self._orch_url or self._ollama_url:
|
||||
return self._orch_url, self._ollama_url
|
||||
import yaml # noqa: PLC0415
|
||||
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
cforch = cfg.get("cforch", {}) or {}
|
||||
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||
|
||||
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
|
||||
resp = httpx.post(
|
||||
f"{node_url}/v1/embeddings",
|
||||
json={"model": self._model_id, "input": texts},
|
||||
timeout=30.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return [item["embedding"] for item in resp.json()["data"]]
|
||||
|
||||
def load(self) -> None:
|
||||
if self._allocation_id or self._exemplar_embeddings:
|
||||
raise RuntimeError(
|
||||
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
|
||||
)
|
||||
orch_url, ollama_url = self._resolve_urls()
|
||||
node_url = ""
|
||||
orch_url_used = ""
|
||||
if orch_url:
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{orch_url}/api/services/ollama/allocate",
|
||||
json={"model": self._model_id},
|
||||
timeout=15.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
node_url = data["url"]
|
||||
self._allocation_id = data["allocation_id"]
|
||||
orch_url_used = orch_url
|
||||
except Exception as exc:
|
||||
_logger.warning(
|
||||
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
|
||||
)
|
||||
if not node_url:
|
||||
node_url = ollama_url
|
||||
self._allocation_id = ""
|
||||
orch_url_used = ""
|
||||
self._node_url = node_url
|
||||
self._orch_url_used = orch_url_used
|
||||
try:
|
||||
embeddings: dict[str, list[list[float]]] = {}
|
||||
for label, texts in self._exemplar_texts.items():
|
||||
embeddings[label] = self._embed(node_url, texts)
|
||||
self._exemplar_embeddings = embeddings
|
||||
except Exception:
|
||||
self.unload()
|
||||
raise
|
||||
|
||||
def unload(self) -> None:
|
||||
if self._allocation_id and self._orch_url_used:
|
||||
try:
|
||||
httpx.request(
|
||||
"DELETE",
|
||||
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
|
||||
timeout=10.0,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._exemplar_embeddings = {}
|
||||
self._node_url = ""
|
||||
self._allocation_id = ""
|
||||
self._orch_url_used = ""
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if not self._exemplar_embeddings:
|
||||
self.load()
|
||||
text = f"Subject: {subject}\n\n{body[:600]}"
|
||||
[query_vec] = self._embed(self._node_url, [text])
|
||||
scored: list[tuple[float, str]] = [
|
||||
(_cosine(query_vec, vec), label)
|
||||
for label, vecs in self._exemplar_embeddings.items()
|
||||
for vec in vecs
|
||||
]
|
||||
top_k = sorted(scored, reverse=True)[: self._k]
|
||||
votes: dict[str, list[float]] = {}
|
||||
for score, label in top_k:
|
||||
votes.setdefault(label, []).append(score)
|
||||
return max(
|
||||
votes,
|
||||
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,458 +0,0 @@
|
|||
"""Export circuitforge-plans/ documents as instruction-tuning JSONL pairs.
|
||||
|
||||
Each record is a HuggingFace chat-format example:
|
||||
|
||||
{
|
||||
"id": "<sha256>",
|
||||
"messages": [
|
||||
{"role": "user", "content": "<reconstructed planning prompt>"},
|
||||
{"role": "assistant", "content": "<cleaned document content>"}
|
||||
],
|
||||
"meta": {
|
||||
"source": "peregrine/2026-03-03-feedback-button-design.md",
|
||||
"product": "peregrine",
|
||||
"doc_type": "design", # design | plan | spec | implementation | other
|
||||
"date": "2026-03-03",
|
||||
"paired_with": "...", # sibling path, or null
|
||||
"word_count": 1847,
|
||||
"pair_role": "context" # "context" | "target" | "standalone"
|
||||
}
|
||||
}
|
||||
|
||||
Pairing strategy
|
||||
----------------
|
||||
When a design doc and a plan doc share the same date + feature-name prefix,
|
||||
they are treated as a pair:
|
||||
- design → plan: instruction = "Given this design doc, write the implementation plan."
|
||||
context appended = full design doc content.
|
||||
- Solo docs get a synthetic instruction from the title + first overview section.
|
||||
|
||||
Usage
|
||||
-----
|
||||
# Preview stats and 5 sample records
|
||||
python scripts/export_plans.py --preview
|
||||
|
||||
# Write full output
|
||||
python scripts/export_plans.py --output data/plan_pairs.jsonl
|
||||
|
||||
# Restrict to specific products
|
||||
python scripts/export_plans.py --products peregrine,kiwi --output data/plan_pairs.jsonl
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_SCRIPT_DIR = Path(__file__).parent
|
||||
_AVOCET_ROOT = _SCRIPT_DIR.parent
|
||||
_DEFAULT_PLANS_DIR = Path("/Library/Development/CircuitForge/circuitforge-plans")
|
||||
_DEFAULT_OUTPUT = _AVOCET_ROOT / "data" / "plan_pairs.jsonl"
|
||||
|
||||
# ── Doc type detection ─────────────────────────────────────────────────────────
|
||||
|
||||
_TYPE_RE = re.compile(
|
||||
r"-(design|plan|spec|implementation|specs|plans)s?$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_SKIP_DIRS = {"__pycache__", ".git", "node_modules"}
|
||||
|
||||
# Boilerplate lines to strip from document content before using as output.
|
||||
_BOILERPLATE_RE = re.compile(
|
||||
r"""
|
||||
^\s*>\s*\*\*For\s+agentic\s+workers.* # superpowers agent hints
|
||||
|^\s*>\s*REQUIRED\s+SUB-SKILL.*
|
||||
|^\s*\*\*Date:\*\*.* # metadata header lines
|
||||
|\*\*Status:\*\*\s*Complete.* # completed-feature noise
|
||||
|\*\*Status:\*\*\s*Done.*
|
||||
|\*\*Product:\*\*.*
|
||||
|\*\*Repo:\*\*.*
|
||||
|\*\*Tech\s+Stack:\*\*.*
|
||||
|\*\*Candidate:\*\*.* # old synthetic personas
|
||||
|^Candidate:.*
|
||||
|^Team:.*
|
||||
""",
|
||||
re.VERBOSE | re.MULTILINE,
|
||||
)
|
||||
|
||||
# Old repo/path names to normalise to current equivalents.
|
||||
_PATH_NORMALIZATIONS: list[tuple[re.Pattern, str]] = [
|
||||
(re.compile(r"/devl/job-seeker", re.IGNORECASE), "/Library/Development/CircuitForge/peregrine"),
|
||||
(re.compile(r"\bjob-seeker\b", re.IGNORECASE), "peregrine"),
|
||||
(re.compile(r"Alex Rivera", re.IGNORECASE), "[user]"),
|
||||
]
|
||||
|
||||
# Instruction paraphrase templates per doc type.
|
||||
# Each entry is (user_prefix, paired_prefix).
|
||||
# {title}, {product}, {type_phrase}, {overview}, {design_context} are substituted.
|
||||
_DESIGN_INSTRUCTIONS = [
|
||||
"Write a design document for {product}: {title}.\n\nContext: {overview}",
|
||||
"You are a software architect working on {product}. Draft a design spec for: {title}.\n\n{overview}",
|
||||
"Produce a CircuitForge-style design document for the following {product} feature — {title}.\n\nBackground: {overview}",
|
||||
]
|
||||
|
||||
_PLAN_INSTRUCTIONS = [
|
||||
"Write an implementation plan for {product}: {title}.\n\nContext: {overview}",
|
||||
"Break the following {product} feature into a detailed implementation plan with file structure and task checkboxes — {title}.\n\n{overview}",
|
||||
"You are a senior engineer on {product}. Produce a step-by-step engineering plan for: {title}.\n\n{overview}",
|
||||
]
|
||||
|
||||
_PAIRED_INSTRUCTIONS = [
|
||||
(
|
||||
"You are a software architect working on {product}, a CircuitForge product. "
|
||||
"Given the following design document, write a detailed implementation plan "
|
||||
"(file structure, task breakdown with checkboxes, migration steps if needed).\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"The following is a design spec for a {product} feature. "
|
||||
"Produce a concrete implementation plan: file list, task checklist, any DB migrations needed.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"Convert this {product} design document into an actionable implementation plan. "
|
||||
"Include all files to create/modify, step-by-step tasks with checkboxes, and migration steps.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _doc_type(stem: str) -> str:
|
||||
m = _TYPE_RE.search(stem)
|
||||
if not m:
|
||||
return "other"
|
||||
raw = m.group(1).lower().rstrip("s")
|
||||
return {"implementation": "plan"}.get(raw, raw)
|
||||
|
||||
|
||||
def _date_feature(stem: str) -> tuple[str, str]:
|
||||
"""Return (date, feature_slug) from '2026-03-03-feedback-button-design'."""
|
||||
m = re.match(r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$", stem, re.I)
|
||||
if m:
|
||||
return m.group(1), m.group(2)
|
||||
return "", stem
|
||||
|
||||
|
||||
# ── Content extraction ─────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_title(content: str) -> str:
|
||||
m = re.search(r"^#\s+(.+)", content, re.MULTILINE)
|
||||
return m.group(1).strip() if m else ""
|
||||
|
||||
|
||||
def _extract_overview(content: str) -> str:
|
||||
"""Return first substantive paragraph or h2 section body (≤300 chars)."""
|
||||
# Superpowers plans have an explicit **Goal:** line — prefer that.
|
||||
goal_m = re.search(r"\*\*Goal:\*\*\s*(.+)", content)
|
||||
if goal_m:
|
||||
return goal_m.group(1).strip()[:300]
|
||||
|
||||
# Otherwise use the body of the first h2 section.
|
||||
h2_m = re.search(
|
||||
r"^##\s+\d*\.?\s*.+\n([\s\S]+?)(?=^##|\Z)",
|
||||
content,
|
||||
re.MULTILINE,
|
||||
)
|
||||
if h2_m:
|
||||
body = h2_m.group(1).strip()
|
||||
# Strip markdown bullet/code noise for the instruction
|
||||
body = re.sub(r"```[\s\S]*?```", "", body)
|
||||
body = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), body)
|
||||
body = re.sub(r"\*\*([^*]+)\*\*", r"\1", body)
|
||||
body = re.sub(r"\s+", " ", body).strip()
|
||||
return body[:300]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _clean_content(content: str) -> str:
|
||||
"""Remove boilerplate, normalize old paths/names, collapse whitespace."""
|
||||
cleaned = _BOILERPLATE_RE.sub("", content)
|
||||
for pattern, replacement in _PATH_NORMALIZATIONS:
|
||||
cleaned = pattern.sub(replacement, cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _quality_flags(content: str) -> list[str]:
|
||||
"""Return a list of quality issue labels found in cleaned content."""
|
||||
flags = []
|
||||
if "Alex Rivera" in content or "[user]" in content:
|
||||
flags.append("persona-residue")
|
||||
if re.search(r"\bStatus:\s*(Complete|Done|Merged)\b", content):
|
||||
flags.append("completed-status")
|
||||
return flags
|
||||
|
||||
|
||||
def _make_instruction(
|
||||
title: str,
|
||||
product: str,
|
||||
doc_type: str,
|
||||
overview: str,
|
||||
design_context: str | None = None,
|
||||
variant: int = 0,
|
||||
) -> str:
|
||||
"""Synthesise a natural planning prompt for this document.
|
||||
|
||||
variant: 0-2 selects which paraphrase template to use. Caller cycles
|
||||
through all three to produce multiple training examples per document.
|
||||
"""
|
||||
product_label = product.replace("-", " ").title() if product else "CircuitForge"
|
||||
idx = variant % 3
|
||||
|
||||
if design_context:
|
||||
tmpl = _PAIRED_INSTRUCTIONS[idx]
|
||||
return tmpl.format(
|
||||
product=product_label,
|
||||
design_context=design_context[:2500],
|
||||
)
|
||||
|
||||
templates = _PLAN_INSTRUCTIONS if doc_type in ("plan",) else _DESIGN_INSTRUCTIONS
|
||||
tmpl = templates[idx]
|
||||
return tmpl.format(
|
||||
product=product_label,
|
||||
title=title,
|
||||
overview=overview or "",
|
||||
type_phrase="planning document",
|
||||
)
|
||||
|
||||
|
||||
def _record_id(content: str, source: str) -> str:
|
||||
return hashlib.sha256(f"{source}:{content}".encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
# ── Pair discovery ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _find_pairs(plans_dir: Path) -> dict[str, list[tuple[str, Path]]]:
|
||||
"""Return {prefix_key → [(doc_type, path), ...]} for docs sharing date+feature."""
|
||||
by_prefix: dict[str, list[tuple[str, Path]]] = {}
|
||||
for path in plans_dir.rglob("*.md"):
|
||||
if any(part in _SKIP_DIRS for part in path.parts):
|
||||
continue
|
||||
if path.name == "README.md":
|
||||
continue
|
||||
stem = path.stem
|
||||
date, feature = _date_feature(stem)
|
||||
if not date:
|
||||
continue
|
||||
key = str(path.parent / f"{date}-{feature}")
|
||||
by_prefix.setdefault(key, []).append((_doc_type(stem), path))
|
||||
return by_prefix
|
||||
|
||||
|
||||
# ── Record generation ──────────────────────────────────────────────────────────
|
||||
|
||||
def _records_for_group(
|
||||
doc_type_paths: list[tuple[str, Path]],
|
||||
plans_dir: Path,
|
||||
) -> Iterator[dict]:
|
||||
"""Yield one or more training records for a group of related docs."""
|
||||
# Separate design vs plan docs within this group
|
||||
designs = [(t, p) for t, p in doc_type_paths if t in ("design", "spec")]
|
||||
plans_ = [(t, p) for t, p in doc_type_paths if t in ("plan",)]
|
||||
others = [(t, p) for t, p in doc_type_paths if t not in ("design", "spec", "plan")]
|
||||
|
||||
all_paths = doc_type_paths
|
||||
|
||||
if designs and plans_:
|
||||
# Paired: yield a design→plan record (3 instruction variants)
|
||||
design_type, design_path = designs[0]
|
||||
plan_type, plan_path = plans_[0]
|
||||
design_content = design_path.read_text(encoding="utf-8")
|
||||
plan_content = plan_path.read_text(encoding="utf-8")
|
||||
|
||||
product = _product_from_path(plan_path, plans_dir)
|
||||
title = _extract_title(plan_content) or plan_path.stem
|
||||
cleaned = _clean_content(plan_content)
|
||||
design_cleaned = _clean_content(design_content)
|
||||
flags = _quality_flags(cleaned)
|
||||
|
||||
if len(cleaned.split()) >= 80:
|
||||
rel_src = str(plan_path.relative_to(plans_dir))
|
||||
rel_design = str(design_path.relative_to(plans_dir))
|
||||
for variant in range(3):
|
||||
instruction = _make_instruction(
|
||||
title=title,
|
||||
product=product,
|
||||
doc_type="plan",
|
||||
overview=_extract_overview(design_content),
|
||||
design_context=design_cleaned,
|
||||
variant=variant,
|
||||
)
|
||||
yield {
|
||||
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
|
||||
"messages": [
|
||||
{"role": "user", "content": instruction},
|
||||
{"role": "assistant", "content": cleaned},
|
||||
],
|
||||
"meta": {
|
||||
"source": rel_src,
|
||||
"product": product,
|
||||
"doc_type": "plan",
|
||||
"date": _date_feature(plan_path.stem)[0],
|
||||
"paired_with": rel_design,
|
||||
"word_count": len(cleaned.split()),
|
||||
"pair_role": "target",
|
||||
"variant": variant,
|
||||
"quality_flags": flags,
|
||||
},
|
||||
}
|
||||
|
||||
# Also yield the design doc as standalone variants
|
||||
all_paths = [(t, p) for t, p in all_paths if p != plan_path]
|
||||
|
||||
# Remaining docs as standalone records (3 instruction variants each)
|
||||
for doc_type, path in all_paths:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
cleaned = _clean_content(content)
|
||||
if len(cleaned.split()) < 80:
|
||||
continue
|
||||
|
||||
product = _product_from_path(path, plans_dir)
|
||||
title = _extract_title(content) or path.stem
|
||||
overview = _extract_overview(content)
|
||||
flags = _quality_flags(cleaned)
|
||||
rel_src = str(path.relative_to(plans_dir))
|
||||
|
||||
for variant in range(3):
|
||||
instruction = _make_instruction(
|
||||
title=title,
|
||||
product=product,
|
||||
doc_type=doc_type,
|
||||
overview=overview,
|
||||
variant=variant,
|
||||
)
|
||||
yield {
|
||||
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
|
||||
"messages": [
|
||||
{"role": "user", "content": instruction},
|
||||
{"role": "assistant", "content": cleaned},
|
||||
],
|
||||
"meta": {
|
||||
"source": rel_src,
|
||||
"product": product,
|
||||
"doc_type": doc_type,
|
||||
"date": _date_feature(path.stem)[0],
|
||||
"paired_with": None,
|
||||
"word_count": len(cleaned.split()),
|
||||
"pair_role": "standalone",
|
||||
"variant": variant,
|
||||
"quality_flags": flags,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _product_from_path(path: Path, plans_dir: Path) -> str:
|
||||
rel = path.relative_to(plans_dir)
|
||||
return rel.parts[0] if len(rel.parts) > 1 else "shared"
|
||||
|
||||
|
||||
# ── Main export ────────────────────────────────────────────────────────────────
|
||||
|
||||
def export(
|
||||
plans_dir: Path,
|
||||
products: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
groups = _find_pairs(plans_dir)
|
||||
records: list[dict] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for group_key, doc_type_paths in groups.items():
|
||||
# Filter by product if requested
|
||||
if products:
|
||||
paths = [p for _, p in doc_type_paths]
|
||||
prods = {_product_from_path(p, plans_dir) for p in paths}
|
||||
if not prods.intersection(products):
|
||||
continue
|
||||
|
||||
for record in _records_for_group(doc_type_paths, plans_dir):
|
||||
if record["id"] not in seen_ids:
|
||||
seen_ids.add(record["id"])
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_stats(records: list[dict]) -> None:
|
||||
from collections import Counter
|
||||
products = Counter(r["meta"]["product"] for r in records)
|
||||
doc_types = Counter(r["meta"]["doc_type"] for r in records)
|
||||
pair_roles = Counter(r["meta"]["pair_role"] for r in records)
|
||||
wc = [r["meta"]["word_count"] for r in records]
|
||||
wc.sort()
|
||||
|
||||
print(f"\n{'='*55}")
|
||||
print(f" Total records: {len(records)}")
|
||||
print(f" Word counts : min={wc[0]}, median={wc[len(wc)//2]}, max={wc[-1]}")
|
||||
print(f"\n By product:")
|
||||
for p, n in products.most_common():
|
||||
print(f" {p:<22} {n}")
|
||||
print(f"\n By doc type:")
|
||||
for t, n in doc_types.most_common():
|
||||
print(f" {t:<22} {n}")
|
||||
print(f"\n Pair roles:")
|
||||
for r, n in pair_roles.most_common():
|
||||
print(f" {r:<22} {n}")
|
||||
print(f"{'='*55}\n")
|
||||
|
||||
|
||||
def _print_sample(records: list[dict], n: int = 3) -> None:
|
||||
import random
|
||||
sample = random.sample(records, min(n, len(records)))
|
||||
for i, rec in enumerate(sample, 1):
|
||||
meta = rec["meta"]
|
||||
user_msg = rec["messages"][0]["content"]
|
||||
asst_msg = rec["messages"][1]["content"]
|
||||
print(f"\n{'─'*55}")
|
||||
print(f"SAMPLE {i}/{n} [{meta['product']} / {meta['doc_type']} / {meta['pair_role']}]")
|
||||
print(f"source: {meta['source']}")
|
||||
print(f"\nUSER ({len(user_msg)} chars):\n{user_msg[:500]}{'...' if len(user_msg)>500 else ''}")
|
||||
print(f"\nASSISTANT ({meta['word_count']} words):\n{asst_msg[:400]}{'...' if len(asst_msg)>400 else ''}")
|
||||
print(f"\n{'─'*55}\n")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser.add_argument("--plans-dir", type=Path, default=_DEFAULT_PLANS_DIR)
|
||||
parser.add_argument("--output", type=Path, default=None,
|
||||
help="Write JSONL to this path (omit for preview-only)")
|
||||
parser.add_argument("--products", default=None,
|
||||
help="Comma-separated product filter, e.g. peregrine,kiwi")
|
||||
parser.add_argument("--preview", action="store_true",
|
||||
help="Print stats + sample records, don't write output")
|
||||
parser.add_argument("--samples", type=int, default=3,
|
||||
help="Number of sample records to show in preview (default 3)")
|
||||
args = parser.parse_args()
|
||||
|
||||
products = [p.strip() for p in args.products.split(",")] if args.products else None
|
||||
|
||||
print(f"Scanning {args.plans_dir} …", file=sys.stderr)
|
||||
records = export(args.plans_dir, products=products)
|
||||
|
||||
_print_stats(records)
|
||||
|
||||
if args.preview or args.output is None:
|
||||
_print_sample(records, n=args.samples)
|
||||
if args.output is None:
|
||||
print("(Pass --output <path> to write JSONL)")
|
||||
return
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for rec in records:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Wrote {len(records)} records to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,355 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Corpus gatherer for the voice benchmark fine-tune pipeline.
|
||||
|
||||
Pulls writing samples from multiple sources and drops .txt files into
|
||||
data/voice_corpus/ in the format expected by benchmark_voice.py.
|
||||
|
||||
Sources:
|
||||
- Reddit: u/pyr0ball post history + comment history (public JSON API)
|
||||
- Campaign copy: claude-bridge/reddit-poster/campaigns/*.py (BODY strings)
|
||||
- Documents: brainmap, homeprojects notes, selected personal writing
|
||||
- Discord: requires manual export (see instructions below)
|
||||
|
||||
Usage:
|
||||
# Full gather (Reddit + local sources)
|
||||
conda run -n cf python scripts/gather_corpus.py
|
||||
|
||||
# Reddit only
|
||||
conda run -n cf python scripts/gather_corpus.py --source reddit
|
||||
|
||||
# Local files only (no network)
|
||||
conda run -n cf python scripts/gather_corpus.py --source local
|
||||
|
||||
# Process a Discord data export zip
|
||||
conda run -n cf python scripts/gather_corpus.py --discord /path/to/discord-export.zip
|
||||
|
||||
Discord export instructions:
|
||||
Discord Settings → Privacy & Safety → Request all my data
|
||||
Wait for email, download zip, then run with --discord flag.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Paths
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CORPUS_DIR = _ROOT / "data" / "style_corpus"
|
||||
_CLAUDE_BRIDGE = Path("/Library/Development/CircuitForge/claude-bridge")
|
||||
_DOCUMENTS = Path("/Library/Documents")
|
||||
|
||||
_REDDIT_USER = "pyr0ball"
|
||||
_USER_AGENT = "Avocet/0.1 corpus-gatherer (CircuitForge; personal research)"
|
||||
_REDDIT_BASE = "https://www.reddit.com"
|
||||
|
||||
# Minimum character length to include a sample (filters out one-liners)
|
||||
_MIN_LENGTH = 80
|
||||
|
||||
# Phrases that suggest AI-generated content — skip these
|
||||
_AI_TELLS = [
|
||||
"certainly!", "absolutely!", "great question", "i'd be happy to",
|
||||
"i apologize for", "it's worth noting", "in conclusion,",
|
||||
"feel free to reach out",
|
||||
]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _is_ai_generated(text: str) -> bool:
|
||||
lower = text.lower()
|
||||
return any(phrase in lower for phrase in _AI_TELLS)
|
||||
|
||||
|
||||
def _clean(text: str) -> str:
|
||||
"""Strip Reddit formatting artifacts and normalize whitespace."""
|
||||
text = re.sub(r"\[deleted\]|\[removed\]", "", text)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def _write_corpus_file(filename: str, samples: list[str], source_label: str) -> None:
|
||||
"""Write samples to a corpus .txt file with minimal separators."""
|
||||
path = _CORPUS_DIR / filename
|
||||
kept = [s for s in samples if len(s) >= _MIN_LENGTH and not _is_ai_generated(s)]
|
||||
if not kept:
|
||||
print(f" [skip] {filename} — no samples passed filters")
|
||||
return
|
||||
separator = "\n\n---\n\n"
|
||||
path.write_text(separator.join(kept), encoding="utf-8")
|
||||
print(f" [ok] {filename} — {len(kept)} samples ({path.stat().st_size // 1024}KB)")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Reddit source
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _reddit_fetch_page(
|
||||
client: httpx.Client,
|
||||
listing_type: str,
|
||||
after: str | None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Fetch one page of a user's submitted posts or comments."""
|
||||
params: dict[str, Any] = {"limit": 100, "raw_json": 1}
|
||||
if after:
|
||||
params["after"] = after
|
||||
url = f"{_REDDIT_BASE}/user/{_REDDIT_USER}/{listing_type}.json"
|
||||
resp = client.get(url, params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
children = data["data"]["children"]
|
||||
new_after = data["data"].get("after")
|
||||
return [c["data"] for c in children], new_after
|
||||
|
||||
|
||||
def _reddit_fetch_all(listing_type: str, max_items: int = 1000) -> list[dict[str, Any]]:
|
||||
"""Paginate through a user listing until exhausted or max_items reached."""
|
||||
items: list[dict[str, Any]] = []
|
||||
after: str | None = None
|
||||
with httpx.Client(
|
||||
headers={"User-Agent": _USER_AGENT},
|
||||
follow_redirects=True,
|
||||
timeout=20.0,
|
||||
) as client:
|
||||
while len(items) < max_items:
|
||||
try:
|
||||
page, after = _reddit_fetch_page(client, listing_type, after)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
# Reddit blocks unauthenticated pagination after the first page;
|
||||
# save what we have rather than crashing.
|
||||
print(f" stopped at {len(items)} {listing_type} (HTTP {exc.response.status_code})")
|
||||
break
|
||||
if not page:
|
||||
break
|
||||
items.extend(page)
|
||||
print(f" fetched {len(items)} {listing_type}...")
|
||||
if not after:
|
||||
break
|
||||
time.sleep(1.0) # respect rate limit
|
||||
return items
|
||||
|
||||
|
||||
def gather_reddit() -> None:
|
||||
print("Fetching Reddit history for u/pyr0ball...")
|
||||
|
||||
# Posts (submitted)
|
||||
print(" Posts:")
|
||||
posts = _reddit_fetch_all("submitted")
|
||||
post_texts: list[str] = []
|
||||
for p in posts:
|
||||
body = _clean(p.get("selftext", "") or "")
|
||||
title = _clean(p.get("title", ""))
|
||||
if len(body) >= _MIN_LENGTH:
|
||||
post_texts.append(f"{title}\n\n{body}")
|
||||
elif len(title) >= 20:
|
||||
# Title-only posts (link posts) — include title as micro-sample
|
||||
post_texts.append(title)
|
||||
_write_corpus_file("social_post_reddit.txt", post_texts, "reddit/submitted")
|
||||
|
||||
# Comments
|
||||
print(" Comments:")
|
||||
comments = _reddit_fetch_all("comments")
|
||||
comment_texts: list[str] = []
|
||||
for c in comments:
|
||||
body = _clean(c.get("body", "") or "")
|
||||
if body and body not in ("[deleted]", "[removed]"):
|
||||
comment_texts.append(body)
|
||||
_write_corpus_file("social_reply_reddit_comments.txt", comment_texts, "reddit/comments")
|
||||
|
||||
print(f" Done. {len(posts)} posts, {len(comments)} comments fetched.")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Campaign copy source (claude-bridge)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _extract_body_from_campaign(py_file: Path) -> str | None:
|
||||
"""
|
||||
Parse a campaign Python file and extract the BODY string literal.
|
||||
Uses AST to handle multi-line strings safely.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(py_file.read_text(encoding="utf-8"))
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == "BODY":
|
||||
if isinstance(node.value, ast.Constant):
|
||||
return str(node.value.value)
|
||||
except (SyntaxError, UnicodeDecodeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def gather_campaigns() -> None:
|
||||
campaigns_dir = _CLAUDE_BRIDGE / "reddit-poster" / "campaigns"
|
||||
if not campaigns_dir.exists():
|
||||
print(f" [skip] campaigns dir not found: {campaigns_dir}")
|
||||
return
|
||||
|
||||
print("Gathering campaign copy from claude-bridge...")
|
||||
samples: list[str] = []
|
||||
for py_file in sorted(campaigns_dir.glob("*.py")):
|
||||
body = _extract_body_from_campaign(py_file)
|
||||
if body:
|
||||
samples.append(body.strip())
|
||||
print(f" {py_file.name} — {len(body)} chars")
|
||||
|
||||
_write_corpus_file("narrative_campaign_copy.txt", samples, "claude-bridge/campaigns")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Documents source
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def gather_documents() -> None:
|
||||
print("Gathering local Documents...")
|
||||
samples: list[str] = []
|
||||
|
||||
# brainmap — personal planning/thinking notes
|
||||
brainmap = _DOCUMENTS / "brainmap_v1.md"
|
||||
if brainmap.exists():
|
||||
text = _clean(brainmap.read_text(encoding="utf-8"))
|
||||
if len(text) >= _MIN_LENGTH:
|
||||
samples.append(text)
|
||||
print(f" brainmap_v1.md — {len(text)} chars")
|
||||
|
||||
# HomeProjects handoff notes — casual technical prose
|
||||
for handoff in sorted((_DOCUMENTS / "HomeProjects").glob("handoff*.md")):
|
||||
text = _clean(handoff.read_text(encoding="utf-8", errors="replace"))
|
||||
if len(text) >= _MIN_LENGTH:
|
||||
samples.append(text)
|
||||
print(f" {handoff.name} — {len(text)} chars")
|
||||
|
||||
# Personal letters (Closet folder) — intimate prose voice
|
||||
closet = _DOCUMENTS / "Closet"
|
||||
if closet.exists():
|
||||
for letter in closet.glob("*.md"):
|
||||
text = _clean(letter.read_text(encoding="utf-8", errors="replace"))
|
||||
if len(text) >= _MIN_LENGTH and not _is_ai_generated(text):
|
||||
samples.append(text)
|
||||
print(f" {letter.name} — {len(text)} chars")
|
||||
|
||||
_write_corpus_file("narrative_personal_docs.txt", samples, "documents")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Discord export source
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def gather_discord(export_zip: Path) -> None:
|
||||
"""
|
||||
Process a Discord data export zip (from Settings → Privacy & Safety → Request all my data).
|
||||
|
||||
Expected zip structure:
|
||||
messages/
|
||||
c{channel_id}/
|
||||
messages.json -- list of {ID, Timestamp, Contents, Attachments}
|
||||
account/
|
||||
user.json -- {username, ...}
|
||||
"""
|
||||
print(f"Processing Discord export: {export_zip}")
|
||||
samples: list[str] = []
|
||||
message_count = 0
|
||||
|
||||
with zipfile.ZipFile(export_zip) as zf:
|
||||
# Find all messages.json files
|
||||
message_files = [n for n in zf.namelist() if n.endswith("/messages.json")]
|
||||
print(f" Found {len(message_files)} channel(s)")
|
||||
|
||||
for mf in message_files:
|
||||
try:
|
||||
data = json.loads(zf.read(mf))
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
for msg in data:
|
||||
content = _clean(msg.get("Contents", "") or "")
|
||||
# Skip system messages, bot commands, very short messages
|
||||
if (
|
||||
len(content) < _MIN_LENGTH
|
||||
or content.startswith("/")
|
||||
or content.startswith("!")
|
||||
or _is_ai_generated(content)
|
||||
):
|
||||
continue
|
||||
# Skip messages that are just URLs or attachments
|
||||
if re.match(r"^https?://\S+$", content):
|
||||
continue
|
||||
samples.append(content)
|
||||
message_count += 1
|
||||
|
||||
print(f" {message_count} messages → {len(samples)} passed filters")
|
||||
_write_corpus_file("social_reply_discord.txt", samples, "discord")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Entrypoint
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Gather writing corpus for voice benchmark")
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
choices=["reddit", "local", "all"],
|
||||
default="all",
|
||||
help="Which sources to gather (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discord",
|
||||
type=Path,
|
||||
metavar="ZIP",
|
||||
help="Path to Discord data export zip",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
_CORPUS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Output: {_CORPUS_DIR}\n")
|
||||
|
||||
if args.source in ("reddit", "all"):
|
||||
gather_reddit()
|
||||
print()
|
||||
|
||||
if args.source in ("local", "all"):
|
||||
gather_campaigns()
|
||||
print()
|
||||
gather_documents()
|
||||
print()
|
||||
|
||||
if args.discord:
|
||||
if not args.discord.exists():
|
||||
print(f"Error: Discord export not found: {args.discord}")
|
||||
else:
|
||||
gather_discord(args.discord)
|
||||
print()
|
||||
|
||||
if not args.discord and args.source in ("local", "all"):
|
||||
print("Discord: manual step required")
|
||||
print(" 1. Discord Settings → Privacy & Safety → Request all my data")
|
||||
print(" 2. Download the zip from the email link")
|
||||
print(" 3. Run: python scripts/gather_corpus.py --discord /path/to/package.zip")
|
||||
print()
|
||||
|
||||
# Summary
|
||||
corpus_files = sorted(_CORPUS_DIR.glob("*.txt"))
|
||||
total_chars = sum(f.stat().st_size for f in corpus_files)
|
||||
print(f"Corpus: {len(corpus_files)} file(s), {total_chars // 1024}KB total")
|
||||
for f in corpus_files:
|
||||
print(f" {f.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,37 +1,23 @@
|
|||
"""Smoke tests for the app factory (app/api.py).
|
||||
import json
|
||||
|
||||
Detailed route tests live in test_data_label.py, test_data_fetch.py,
|
||||
test_data_corrections.py, test_train.py, and test_dashboard.py.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from app import api as api_module # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import api
|
||||
api.set_data_dir(tmp_path)
|
||||
api.reset_last_action()
|
||||
yield
|
||||
api.reset_last_action()
|
||||
|
||||
|
||||
def test_import():
|
||||
from app import api # noqa: F401
|
||||
|
||||
|
||||
def test_app_has_required_routes():
|
||||
from app.api import app
|
||||
paths = {r.path for r in app.routes}
|
||||
# Label routes
|
||||
assert "/api/queue" in paths
|
||||
assert "/api/label" in paths
|
||||
assert "/api/skip" in paths
|
||||
assert "/api/discard" in paths
|
||||
assert "/api/label/undo" in paths
|
||||
assert "/api/config/labels" in paths
|
||||
assert "/api/stats" in paths
|
||||
# Fetch routes
|
||||
assert "/api/accounts/test" in paths
|
||||
assert "/api/fetch/stream" in paths
|
||||
# Train routes
|
||||
assert "/api/train/jobs" in paths
|
||||
assert "/api/train/results" in paths
|
||||
# Dashboard
|
||||
assert "/api/dashboard" in paths
|
||||
# Corrections (new prefix)
|
||||
assert "/api/corrections/ingest" in paths
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -40,8 +26,536 @@ def client():
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
def test_queue_endpoint_reachable(client):
|
||||
@pytest.fixture
|
||||
def queue_with_items():
|
||||
"""Write 3 test emails to the queue file."""
|
||||
from app import api as api_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
|
||||
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert "items" in r.json()
|
||||
assert "total" in r.json()
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = api_module._read_jsonl(api_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part A: POST /api/discard ---
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
# --- Part B: DELETE /api/label/undo ---
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["undone"]["type"] == "label"
|
||||
score = api_module._read_jsonl(api_module._score_file())
|
||||
assert score == []
|
||||
# Item should be restored to front of queue
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert discarded == []
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part C: GET /api/config/labels ---
|
||||
|
||||
def test_config_labels_returns_metadata(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
assert "emoji" in labels[0]
|
||||
assert "color" in labels[0]
|
||||
assert "name" in labels[0]
|
||||
|
||||
|
||||
# ── /api/config ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(tmp_path):
|
||||
"""Give the API a writable config directory."""
|
||||
from app import api as api_module
|
||||
api_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
api_module.set_config_dir(None) # reset to default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir():
|
||||
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
|
||||
from app import api as api_module
|
||||
return api_module._DATA_DIR
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client, config_dir):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, config_dir):
|
||||
import yaml
|
||||
payload = {
|
||||
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}],
|
||||
"max_per_account": 200,
|
||||
}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
cfg_file = config_dir / "label_tool.yaml"
|
||||
assert cfg_file.exists()
|
||||
saved = yaml.safe_load(cfg_file.read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, config_dir):
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
# ── /api/stats ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def score_with_labels(tmp_path, data_dir):
|
||||
"""Write a score file with 3 labels for stats tests."""
|
||||
score_path = data_dir / "email_score.jsonl"
|
||||
records = [
|
||||
{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"},
|
||||
]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
return records
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, score_with_labels):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, score_with_labels):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/accounts/test ───────────────────────────────────────────────────────
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
from unittest.mock import MagicMock, patch
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
"""Parse SSE response body into list of event dicts."""
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, config_dir):
|
||||
"""With no config, stream should immediately complete with 0 added."""
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
||||
"""With one configured account, stream should yield start/done/complete events."""
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Write a config with one account
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
# ---- /api/finetune/status tests ----
|
||||
|
||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
|
||||
"""GET /api/finetune/status must return [] if models/ does not exist."""
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
||||
import json as _json
|
||||
from app import api as api_module
|
||||
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
info = {
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401,
|
||||
}
|
||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
||||
|
||||
api_module.set_models_dir(tmp_path / "models")
|
||||
try:
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
finally:
|
||||
api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
|
||||
"""GET /api/finetune/run must return text/event-stream content type."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "text/event-stream" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_finetune_run_emits_complete_on_success(client):
|
||||
"""GET /api/finetune/run must emit a complete event on clean exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["progress line\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '{"type": "complete"}' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
|
||||
"""GET /api/finetune/run must emit an error event on non-zero exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '"type": "error"' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_passes_score_files_to_subprocess(client):
|
||||
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
captured_cmd = []
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
captured_cmd.extend(cmd)
|
||||
m = MagicMock()
|
||||
m.stdout = iter([])
|
||||
m.returncode = 0
|
||||
m.wait = MagicMock()
|
||||
return m
|
||||
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
|
||||
|
||||
assert "--score" in captured_cmd
|
||||
assert captured_cmd.count("--score") == 2
|
||||
# Paths are resolved to absolute — check filenames are present as substrings
|
||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
||||
|
||||
|
||||
# ---- Cancel endpoint tests ----
|
||||
|
||||
def test_benchmark_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/benchmark/cancel must return 404 if no benchmark is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_finetune_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_benchmark_cancel_terminates_running_process(client):
|
||||
"""POST /api/benchmark/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_cancel_terminates_running_process(client):
|
||||
"""POST /api/finetune/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["finetune"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_cancel_kills_process_on_timeout(client):
|
||||
"""POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
mock_proc.kill.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_run_emits_cancelled_event(client):
|
||||
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15 # SIGTERM
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("finetune")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_run_emits_cancelled_event(client):
|
||||
"""GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("benchmark")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/benchmark/run")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@
|
|||
import pytest
|
||||
|
||||
|
||||
def test_registry_has_thirteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 13
|
||||
|
||||
|
||||
def test_registry_default_count():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
|
||||
|
|
@ -161,95 +166,3 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
|
|||
|
||||
assert "avocet-deberta-small" in models
|
||||
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
||||
|
||||
|
||||
# ---- build_exemplars_from_jsonl() tests ----
|
||||
|
||||
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [{"subject": f"S{i}", "body": f"B{i}", "label": "rejected"} for i in range(15)]
|
||||
rows.append({"subject": "Hire", "body": "Welcome", "label": "hired"})
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f), k_per_label=10)
|
||||
|
||||
assert len(result["rejected"]) == 10
|
||||
assert len(result["hired"]) == 1
|
||||
assert result["rejected"][0].startswith("Subject: S")
|
||||
|
||||
|
||||
def test_build_exemplars_formats_text_correctly(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
row = {"subject": "My Subject", "body": "My Body", "label": "neutral"}
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text(json.dumps(row))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
|
||||
assert result["neutral"][0] == "Subject: My Subject\n\nMy Body"
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_missing_label(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [
|
||||
{"subject": "A", "body": "B", "label": "neutral"},
|
||||
{"subject": "No label here", "body": "Body"},
|
||||
]
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
assert list(result.keys()) == ["neutral"]
|
||||
|
||||
|
||||
def test_build_exemplars_truncates_body_at_600(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
row = {"subject": "S", "body": "x" * 800, "label": "neutral"}
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text(json.dumps(row))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
body_part = result["neutral"][0].split("\n\n", 1)[1]
|
||||
assert len(body_part) == 600
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_with_no_content(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [
|
||||
{"label": "neutral"}, # no subject, no body -> skip
|
||||
{"subject": "S", "body": "B", "label": "neutral"}, # valid -> keep
|
||||
{"label": "rejected", "subject": "", "body": ""}, # empty strings -> skip
|
||||
]
|
||||
f = tmp_path / "score.jsonl"
|
||||
lines = [json.dumps(r) for r in rows]
|
||||
f.write_text("\n".join(lines))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
assert list(result.keys()) == ["neutral"]
|
||||
assert len(result["neutral"]) == 1
|
||||
|
||||
def test_registry_has_fourteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 14
|
||||
|
||||
|
||||
def test_embed_knn_nomic_registry_entry():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
entry = MODEL_REGISTRY["embed-knn-nomic"]
|
||||
assert entry["adapter"] is EmbeddingKNNAdapter
|
||||
assert entry["model_id"] == "nomic-embed-text"
|
||||
assert entry["params"] == "local-embed"
|
||||
assert entry["default"] is False
|
||||
assert entry.get("kwargs", {}).get("k") == 3
|
||||
|
|
|
|||
|
|
@ -1,418 +0,0 @@
|
|||
"""Tests for app/cforch.py — /api/cforch/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cforch_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path, reset running-state globals, and stub
|
||||
list_installed to return [] so real disk model directories don't bleed into
|
||||
tests that don't exercise the installed-model merge path."""
|
||||
from app import cforch as cforch_module
|
||||
|
||||
prev_config_dir = cforch_module._CONFIG_DIR
|
||||
prev_running = cforch_module._BENCH_RUNNING
|
||||
prev_proc = cforch_module._bench_proc
|
||||
|
||||
cforch_module.set_config_dir(tmp_path)
|
||||
cforch_module._BENCH_RUNNING = False
|
||||
cforch_module._bench_proc = None
|
||||
|
||||
with patch("app.models.list_installed", return_value=[]):
|
||||
yield tmp_path
|
||||
|
||||
cforch_module.set_config_dir(prev_config_dir)
|
||||
cforch_module._BENCH_RUNNING = prev_running
|
||||
cforch_module._bench_proc = prev_proc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(reset_cforch_globals):
|
||||
"""Return the tmp config dir (already set as _CONFIG_DIR)."""
|
||||
return reset_cforch_globals
|
||||
|
||||
|
||||
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
|
||||
"""Write a label_tool.yaml with the given cforch block into config_dir."""
|
||||
cfg = {"cforch": cforch_cfg}
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
yaml.dump(cfg), encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
def _write_tasks_yaml(path: Path, tasks: list[dict]) -> None:
|
||||
path.write_text(yaml.dump({"tasks": tasks}), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_models_yaml(path: Path, models: list[dict]) -> None:
|
||||
path.write_text(yaml.dump({"models": models}), encoding="utf-8")
|
||||
|
||||
|
||||
# ── GET /tasks ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_tasks_returns_empty_when_not_configured(client):
|
||||
"""No config file present — endpoint returns empty lists."""
|
||||
r = client.get("/api/cforch/tasks")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data == {"tasks": [], "types": []}
|
||||
|
||||
|
||||
def test_tasks_parses_yaml(client, config_dir, tmp_path):
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
_write_tasks_yaml(tasks_file, [
|
||||
{"id": "t1", "name": "Task One", "type": "instruction"},
|
||||
{"id": "t2", "name": "Task Two", "type": "reasoning"},
|
||||
])
|
||||
_write_config(config_dir, {"bench_tasks": str(tasks_file)})
|
||||
|
||||
r = client.get("/api/cforch/tasks")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["tasks"]) == 2
|
||||
# TaskEntry now includes optional prompt/system fields (default "")
|
||||
t1 = data["tasks"][0]
|
||||
assert t1["id"] == "t1" and t1["name"] == "Task One" and t1["type"] == "instruction"
|
||||
t2 = data["tasks"][1]
|
||||
assert t2["id"] == "t2" and t2["name"] == "Task Two" and t2["type"] == "reasoning"
|
||||
assert "instruction" in data["types"]
|
||||
assert "reasoning" in data["types"]
|
||||
|
||||
|
||||
def test_tasks_returns_types_deduplicated(client, config_dir, tmp_path):
|
||||
"""Multiple tasks sharing a type — types list must not duplicate."""
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
_write_tasks_yaml(tasks_file, [
|
||||
{"id": "t1", "name": "A", "type": "instruction"},
|
||||
{"id": "t2", "name": "B", "type": "instruction"},
|
||||
{"id": "t3", "name": "C", "type": "reasoning"},
|
||||
])
|
||||
_write_config(config_dir, {"bench_tasks": str(tasks_file)})
|
||||
|
||||
r = client.get("/api/cforch/tasks")
|
||||
data = r.json()
|
||||
assert data["types"].count("instruction") == 1
|
||||
assert len(data["types"]) == 2
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_models_returns_empty_when_not_configured(client):
|
||||
"""No config file present — endpoint returns empty model list."""
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"models": []}
|
||||
|
||||
|
||||
def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
_write_models_yaml(models_file, [
|
||||
{
|
||||
"name": "llama3",
|
||||
"id": "llama3:8b",
|
||||
"service": "ollama",
|
||||
"tags": ["fast", "small"],
|
||||
"vram_estimate_mb": 6000,
|
||||
}
|
||||
])
|
||||
_write_config(config_dir, {"bench_models": str(models_file)})
|
||||
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["models"]) == 1
|
||||
m = data["models"][0]
|
||||
assert m["name"] == "llama3"
|
||||
assert m["id"] == "llama3:8b"
|
||||
assert m["service"] == "ollama"
|
||||
assert m["tags"] == ["fast", "small"]
|
||||
assert m["vram_estimate_mb"] == 6000
|
||||
|
||||
|
||||
def test_models_merges_installed_generators(client, config_dir, tmp_path):
|
||||
"""Installed cf-text/vllm generator models appear in the model list,
|
||||
deduplicated against bench_models.yaml entries."""
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
_write_models_yaml(models_file, [
|
||||
{"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
|
||||
{"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
|
||||
])
|
||||
_write_config(config_dir, {"bench_models": str(models_file)})
|
||||
|
||||
fake_installed = [
|
||||
# should be included — cf-text generator not already in YAML
|
||||
{"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000},
|
||||
# should be deduped — repo_id matches a YAML entry
|
||||
{"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000},
|
||||
# should be excluded — classifier, not a generator
|
||||
{"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500},
|
||||
]
|
||||
with patch("app.models.list_installed", return_value=fake_installed):
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
ids = [m["id"] for m in r.json()["models"]]
|
||||
assert "llama3:8b" in ids # from YAML
|
||||
assert "ibm-granite/granite-4.1-8b" in ids # from YAML (not duplicated)
|
||||
assert "meta-llama/Llama-3.1-8B" in ids # merged from installed
|
||||
assert "cross-encoder/ms-marco-MiniLM-L6" not in ids # filtered out (reranker)
|
||||
assert ids.count("ibm-granite/granite-4.1-8b") == 1 # no duplicate
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_returns_409_when_already_running(client):
|
||||
"""If a benchmark subprocess is actively running, GET /run returns 409."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import cforch as cforch_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None # process still alive
|
||||
cforch_module._BENCH_RUNNING = True
|
||||
cforch_module._bench_proc = mock_proc
|
||||
|
||||
r = client.get("/api/cforch/run")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_run_returns_error_when_bench_script_not_configured(client):
|
||||
"""No config at all — SSE stream contains an error event."""
|
||||
r = client.get("/api/cforch/run")
|
||||
assert r.status_code == 200
|
||||
assert '"type": "error"' in r.text
|
||||
assert "bench_script not configured" in r.text
|
||||
|
||||
|
||||
def test_run_streams_progress_events(client, config_dir, tmp_path):
|
||||
"""Mock subprocess — SSE stream emits progress events from stdout."""
|
||||
bench_script = tmp_path / "fake_benchmark.py"
|
||||
bench_script.write_text("# fake", encoding="utf-8")
|
||||
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")
|
||||
results_dir = tmp_path / "results"
|
||||
results_dir.mkdir()
|
||||
|
||||
_write_config(config_dir, {
|
||||
"bench_script": str(bench_script),
|
||||
"bench_tasks": str(tasks_file),
|
||||
"bench_models": str(models_file),
|
||||
"results_dir": str(results_dir),
|
||||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
mock_stdout = MagicMock()
|
||||
mock_stdout.readline.side_effect = ["Running task 1\n", "Running task 2\n", ""]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = mock_stdout
|
||||
mock_proc.returncode = 1 # non-zero so we don't need summary.json
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
|
||||
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
|
||||
r = client.get("/api/cforch/run")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert '"type": "progress"' in r.text
|
||||
assert "Running task 1" in r.text
|
||||
assert "Running task 2" in r.text
|
||||
|
||||
|
||||
def test_run_emits_result_on_success(client, config_dir, tmp_path):
|
||||
"""Mock subprocess exit 0 + write fake summary.json — stream emits result event."""
|
||||
bench_script = tmp_path / "fake_benchmark.py"
|
||||
bench_script.write_text("# fake", encoding="utf-8")
|
||||
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
run_dir = results_dir / "2026-04-08-120000"
|
||||
run_dir.mkdir(parents=True)
|
||||
summary_data = {"score": 0.92, "models_evaluated": 3}
|
||||
(run_dir / "summary.json").write_text(json.dumps(summary_data), encoding="utf-8")
|
||||
|
||||
_write_config(config_dir, {
|
||||
"bench_script": str(bench_script),
|
||||
"bench_tasks": str(tasks_file),
|
||||
"bench_models": str(models_file),
|
||||
"results_dir": str(results_dir),
|
||||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
mock_stdout = MagicMock()
|
||||
mock_stdout.readline.side_effect = [""] # no output lines, immediate EOF
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = mock_stdout
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
|
||||
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
|
||||
r = client.get("/api/cforch/run")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert '"type": "result"' in r.text
|
||||
assert '"score": 0.92' in r.text
|
||||
assert '"type": "complete"' in r.text
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_results_returns_404_when_no_results(client):
|
||||
"""No results_dir configured — endpoint returns 404."""
|
||||
r = client.get("/api/cforch/results")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_results_returns_latest_summary(client, config_dir, tmp_path):
|
||||
"""Write fake results dir with one subdir containing summary.json."""
|
||||
results_dir = tmp_path / "results"
|
||||
run_dir = results_dir / "2026-04-08-150000"
|
||||
run_dir.mkdir(parents=True)
|
||||
summary_data = {"score": 0.88, "run": "test"}
|
||||
(run_dir / "summary.json").write_text(json.dumps(summary_data), encoding="utf-8")
|
||||
|
||||
_write_config(config_dir, {"results_dir": str(results_dir)})
|
||||
|
||||
r = client.get("/api/cforch/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["score"] == 0.88
|
||||
assert data["run"] == "test"
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_cancel_returns_404_when_not_running(client):
|
||||
"""POST /cancel when no benchmark running — returns 404."""
|
||||
r = client.post("/api/cforch/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_terminates_running_benchmark(client):
|
||||
"""POST /cancel when benchmark is running — terminates proc and returns cancelled."""
|
||||
from app import cforch as cforch_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
cforch_module._BENCH_RUNNING = True
|
||||
cforch_module._bench_proc = mock_proc
|
||||
|
||||
r = client.post("/api/cforch/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "cancelled"}
|
||||
mock_proc.terminate.assert_called_once()
|
||||
assert cforch_module._BENCH_RUNNING is False
|
||||
assert cforch_module._bench_proc is None
|
||||
|
||||
|
||||
# ── GET /config ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_config_returns_empty_when_no_yaml_no_env(client, monkeypatch):
|
||||
"""No yaml, no env vars — all fields empty, license_key_set False."""
|
||||
for key in ("CF_ORCH_URL", "CF_LICENSE_KEY", "OLLAMA_HOST", "OLLAMA_MODEL"):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
r = client.get("/api/cforch/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["coordinator_url"] == ""
|
||||
assert data["ollama_url"] == ""
|
||||
assert data["license_key_set"] is False
|
||||
|
||||
|
||||
def test_config_reads_env_vars_when_no_yaml(client, monkeypatch):
|
||||
"""Env vars populate fields when label_tool.yaml has no cforch section."""
|
||||
monkeypatch.setenv("CF_ORCH_URL", "http://orch.example.com:7700")
|
||||
monkeypatch.setenv("CF_LICENSE_KEY", "CFG-AVCT-TEST-TEST-TEST")
|
||||
monkeypatch.setenv("OLLAMA_HOST", "http://ollama.local:11434")
|
||||
monkeypatch.setenv("OLLAMA_MODEL", "mistral:7b")
|
||||
|
||||
r = client.get("/api/cforch/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["coordinator_url"] == "http://orch.example.com:7700"
|
||||
assert data["ollama_url"] == "http://ollama.local:11434"
|
||||
assert data["ollama_model"] == "mistral:7b"
|
||||
assert data["license_key_set"] is True # set, but value not exposed
|
||||
|
||||
|
||||
def test_config_yaml_overrides_env(client, config_dir, monkeypatch):
|
||||
"""label_tool.yaml cforch values take priority over env vars."""
|
||||
monkeypatch.setenv("CF_ORCH_URL", "http://env-orch:7700")
|
||||
monkeypatch.setenv("OLLAMA_HOST", "http://env-ollama:11434")
|
||||
|
||||
_write_config(config_dir, {
|
||||
"coordinator_url": "http://yaml-orch:7700",
|
||||
"ollama_url": "http://yaml-ollama:11434",
|
||||
})
|
||||
|
||||
r = client.get("/api/cforch/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["coordinator_url"] == "http://yaml-orch:7700"
|
||||
assert data["ollama_url"] == "http://yaml-ollama:11434"
|
||||
assert data["source"] == "yaml+env"
|
||||
|
||||
|
||||
def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path, monkeypatch):
|
||||
"""CF_LICENSE_KEY must be forwarded to the benchmark subprocess env."""
|
||||
monkeypatch.setenv("CF_LICENSE_KEY", "CFG-AVCT-ENV-ONLY-KEY")
|
||||
|
||||
bench_script = tmp_path / "benchmark.py"
|
||||
bench_script.write_text("# stub", encoding="utf-8")
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")
|
||||
|
||||
_write_config(config_dir, {
|
||||
"bench_script": str(bench_script),
|
||||
"bench_tasks": str(tasks_file),
|
||||
"bench_models": str(models_file),
|
||||
"results_dir": str(tmp_path / "results"),
|
||||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
captured_env: dict = {}
|
||||
|
||||
def fake_popen(cmd, **kwargs):
|
||||
captured_env.update(kwargs.get("env", {}))
|
||||
mock = MagicMock()
|
||||
mock.stdout = iter([])
|
||||
mock.returncode = 0
|
||||
mock.wait = MagicMock()
|
||||
return mock
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", side_effect=fake_popen):
|
||||
client.get("/api/cforch/run")
|
||||
|
||||
assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
|
||||
|
||||
|
||||
def test_eval_cforch_router_includes_all_sub_routers():
|
||||
"""eval/cforch.py router must include routes from all four sub-routers."""
|
||||
from app.eval.cforch import router
|
||||
paths = {r.path for r in router.routes}
|
||||
assert any("/cforch/" in p for p in paths), f"no /cforch/ routes found in {paths}"
|
||||
assert any("/style/" in p for p in paths), f"no /style/ routes found in {paths}"
|
||||
assert any("/voice/" in p for p in paths), f"no /voice/ routes found in {paths}"
|
||||
assert any("/plans-bench/" in p for p in paths), f"no /plans-bench/ routes found in {paths}"
|
||||
|
|
@ -268,373 +268,3 @@ def test_finetuned_adapter_unload_clears_pipeline():
|
|||
assert adapter._pipeline is not None
|
||||
adapter.unload()
|
||||
assert adapter._pipeline is None
|
||||
|
||||
# ---- _cosine() tests ----
|
||||
|
||||
def test_cosine_identical_unit_vectors():
|
||||
import math
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal_vectors():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_known_value():
|
||||
import math
|
||||
from scripts.classifier_adapters import _cosine
|
||||
# [1,0] vs [1/sqrt(2), 1/sqrt(2)] → dot = 1/sqrt(2), both norms = 1 → 1/sqrt(2)
|
||||
v = [1.0 / math.sqrt(2), 1.0 / math.sqrt(2)]
|
||||
assert _cosine([1.0, 0.0], v) == pytest.approx(1.0 / math.sqrt(2))
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ---- DEFAULT_EXEMPLARS tests ----
|
||||
|
||||
def test_default_exemplars_covers_all_labels():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS, LABELS
|
||||
for label in LABELS:
|
||||
assert label in DEFAULT_EXEMPLARS, f"DEFAULT_EXEMPLARS missing label: {label}"
|
||||
assert len(DEFAULT_EXEMPLARS[label]) >= 4, f"{label} needs >= 4 exemplars for k=3 voting"
|
||||
|
||||
|
||||
def test_default_exemplars_sparse_labels_have_at_least_four():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
|
||||
# These labels have very few real examples; need >= 4 so k=3 vote is meaningful
|
||||
for label in ("hired", "survey_received", "event_rescheduled"):
|
||||
assert len(DEFAULT_EXEMPLARS[label]) >= 4, (
|
||||
f"{label} needs >= 4 exemplars for k=3 voting to work reliably"
|
||||
)
|
||||
|
||||
def test_default_exemplars_strings_are_formatted_correctly():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
|
||||
for label, texts in DEFAULT_EXEMPLARS.items():
|
||||
for text in texts:
|
||||
assert text.startswith("Subject: "), (
|
||||
f"{label!r} exemplar missing 'Subject: ' prefix: {text[:50]!r}"
|
||||
)
|
||||
assert "\n\n" in text, (
|
||||
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
|
||||
)
|
||||
|
||||
# ---- EmbeddingKNNAdapter constructor tests ----
|
||||
|
||||
def test_embedding_knn_is_classifier_adapter():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test-knn", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert isinstance(adapter, ClassifierAdapter)
|
||||
|
||||
|
||||
def test_embedding_knn_name_and_model_id():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"embed-knn-nomic", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter.name == "embed-knn-nomic"
|
||||
assert adapter.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_embedding_knn_uses_default_exemplars_when_none_given():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
|
||||
|
||||
|
||||
def test_embedding_knn_accepts_custom_exemplars():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
custom = {"rejected": ["Sorry, we went with others."]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=custom,
|
||||
)
|
||||
assert adapter._exemplar_texts is custom
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.load() tests ----
|
||||
|
||||
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
|
||||
"""Return a side_effect function for patching httpx.post.
|
||||
|
||||
Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3]
|
||||
embedding per input text.
|
||||
"""
|
||||
def _side_effect(url, *, json=None, timeout=None, **kwargs):
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
|
||||
else:
|
||||
n = len((json or {}).get("input", []))
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n}
|
||||
return resp
|
||||
return _side_effect
|
||||
|
||||
|
||||
def test_load_calls_allocate_then_embeds_each_label():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {
|
||||
"rejected": ["We went with others"],
|
||||
"hired": ["Welcome aboard!", "First day info"],
|
||||
}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
post_urls = []
|
||||
def capturing_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
post_urls.append(url)
|
||||
return _make_post_mock()(url, json=json, timeout=timeout)
|
||||
|
||||
with patch("httpx.post", side_effect=capturing_mock):
|
||||
adapter.load()
|
||||
|
||||
assert any("/allocate" in u for u in post_urls), "expected allocate call"
|
||||
assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
|
||||
assert adapter._allocation_id == "alloc-abc"
|
||||
assert adapter._node_url == "http://navi:11434"
|
||||
assert adapter._orch_url_used == "http://orch:7700"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
assert "hired" in adapter._exemplar_embeddings
|
||||
assert len(adapter._exemplar_embeddings["rejected"]) == 1
|
||||
assert len(adapter._exemplar_embeddings["hired"]) == 2
|
||||
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
|
||||
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_fails():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
resp = MagicMock()
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 503
|
||||
resp.json.return_value = {}
|
||||
else:
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=failing_allocate_mock):
|
||||
adapter.load()
|
||||
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
assert adapter._node_url == "http://ollama:11434"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_raises():
|
||||
from unittest.mock import patch, MagicMock
|
||||
import httpx as _httpx
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
def raising_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
if "/allocate" in url:
|
||||
raise _httpx.ConnectError("connection refused")
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=raising_mock):
|
||||
adapter.load()
|
||||
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
assert adapter._node_url == "http://ollama:11434"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.unload() tests ----
|
||||
|
||||
def test_unload_releases_orch_allocation_and_clears_state():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
|
||||
adapter._node_url = "http://navi:11434"
|
||||
adapter._allocation_id = "alloc-abc"
|
||||
adapter._orch_url_used = "http://orch:7700"
|
||||
|
||||
delete_calls = []
|
||||
def mock_request(method, url, **kwargs):
|
||||
delete_calls.append((method, url))
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
with patch("httpx.request", side_effect=mock_request):
|
||||
adapter.unload()
|
||||
|
||||
assert len(delete_calls) == 1
|
||||
method, url = delete_calls[0]
|
||||
assert method == "DELETE"
|
||||
assert "alloc-abc" in url
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._node_url == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
|
||||
|
||||
def test_unload_skips_delete_on_ollama_fallback_path():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
|
||||
adapter._node_url = "http://ollama:11434"
|
||||
adapter._allocation_id = "" # fallback path: no allocation was made
|
||||
adapter._orch_url_used = ""
|
||||
|
||||
delete_calls = []
|
||||
with patch("httpx.request", side_effect=lambda *a, **k: delete_calls.append(a)):
|
||||
adapter.unload()
|
||||
|
||||
assert len(delete_calls) == 0
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
assert adapter._node_url == ""
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.classify() tests ----
|
||||
|
||||
def _adapter_with_embeddings(exemplar_embeddings, k=3):
|
||||
"""Return a pre-loaded EmbeddingKNNAdapter (bypass load()) with given per-label vectors."""
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=k,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = exemplar_embeddings
|
||||
adapter._node_url = "http://navi:11434"
|
||||
return adapter
|
||||
|
||||
|
||||
def _embed_resp(vec):
|
||||
"""Return a mock httpx response for /v1/embeddings returning a single vector."""
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": vec}]}
|
||||
return resp
|
||||
|
||||
|
||||
def test_classify_returns_majority_vote_label():
|
||||
from unittest.mock import patch
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
|
||||
"neutral": [[0.0, 1.0, 0.0]],
|
||||
}, k=3)
|
||||
|
||||
# Query [1,0,0] is closest to all three "rejected" exemplars
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("We went with others", "Thank you for applying.")
|
||||
|
||||
assert result == "rejected"
|
||||
|
||||
|
||||
def test_classify_tiebreak_by_mean_score():
|
||||
from unittest.mock import patch
|
||||
# k=2: each label gets exactly 1 vote → tie-break by mean similarity
|
||||
# [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] ≈ 0.6 ("neutral")
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[1.0, 0.0]],
|
||||
"neutral": [[0.6, 0.8]],
|
||||
}, k=2)
|
||||
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
|
||||
result = adapter.classify("Rejection", "Sorry")
|
||||
|
||||
assert result == "rejected"
|
||||
|
||||
|
||||
def test_classify_sparse_label_can_win():
|
||||
from unittest.mock import patch
|
||||
# "hired" has only 1 exemplar; with k=1, the single closest match wins
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
|
||||
"hired": [[1.0, 0.0, 0.0]],
|
||||
}, k=1)
|
||||
|
||||
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("Welcome aboard", "Your first day details")
|
||||
|
||||
assert result == "hired"
|
||||
|
||||
|
||||
def test_classify_lazy_loads_when_not_loaded():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=1,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
|
||||
post_urls = []
|
||||
def mock_post(url, *, json=None, timeout=None, **kwargs):
|
||||
post_urls.append(url)
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
|
||||
else:
|
||||
n = len((json or {}).get("input", []))
|
||||
resp.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * n}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=mock_post):
|
||||
result = adapter.classify("Rejection", "Sorry")
|
||||
|
||||
assert result == "rejected"
|
||||
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
|
||||
assert adapter._exemplar_embeddings != {}
|
||||
assert adapter._node_url == "http://navi:11434"
|
||||
|
|
|
|||
|
|
@ -1,122 +0,0 @@
|
|||
"""Tests for app/dashboard.py -- GET /api/dashboard."""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import dashboard as dash_module
|
||||
dash_module.set_data_dir(tmp_path)
|
||||
dash_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _write_score(tmp_path: Path, records: list[dict]) -> None:
|
||||
(tmp_path / "email_score.jsonl").write_text(
|
||||
"\n".join(json.dumps(r) for r in records) + "\n"
|
||||
)
|
||||
|
||||
def _write_summary(tmp_path: Path, run_id: str, ts: str, score: float) -> None:
|
||||
run_dir = tmp_path / "bench_results" / run_id
|
||||
run_dir.mkdir(parents=True)
|
||||
(run_dir / "summary.json").write_text(
|
||||
json.dumps({"timestamp": ts, "best_macro_f1": score})
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_returns_expected_keys(client):
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
for key in ("labeled_since_last_eval", "last_eval_timestamp", "last_eval_best_score",
|
||||
"active_jobs", "corrections_pending", "corrections_export_ready", "signals"):
|
||||
assert key in data, f"missing key: {key}"
|
||||
for sig in ("data_to_eval", "eval_to_train", "train_to_fleet"):
|
||||
assert sig in data["signals"], f"missing signal: {sig}"
|
||||
|
||||
|
||||
def test_dashboard_empty_state(client):
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["labeled_since_last_eval"] == 0
|
||||
assert data["last_eval_timestamp"] is None
|
||||
assert data["last_eval_best_score"] is None
|
||||
assert data["active_jobs"] == []
|
||||
assert data["corrections_pending"] == 0
|
||||
assert data["corrections_export_ready"] == 0
|
||||
|
||||
|
||||
def test_labeled_since_counts_all_when_no_eval(client, tmp_path):
|
||||
_write_score(tmp_path, [
|
||||
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"},
|
||||
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
|
||||
])
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["labeled_since_last_eval"] == 2
|
||||
|
||||
|
||||
def test_labeled_since_filters_by_eval_timestamp(client, tmp_path):
|
||||
_write_summary(tmp_path, "2026-05-01-100000", "2026-05-01T10:00:00+00:00", 0.80)
|
||||
_write_score(tmp_path, [
|
||||
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T09:00:00+00:00"},
|
||||
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
|
||||
])
|
||||
(tmp_path / "label_tool.yaml").write_text(
|
||||
yaml.dump({"cforch": {"results_dir": str(tmp_path / "bench_results")}})
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
data = r.json()
|
||||
assert data["labeled_since_last_eval"] == 1
|
||||
assert abs(data["last_eval_best_score"] - 0.80) < 0.001
|
||||
|
||||
|
||||
def test_data_to_eval_false_below_threshold(client, tmp_path):
|
||||
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
|
||||
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(10)])
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["signals"]["data_to_eval"] is False
|
||||
|
||||
|
||||
def test_data_to_eval_true_at_threshold(client, tmp_path):
|
||||
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
|
||||
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(50)])
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["signals"]["data_to_eval"] is True
|
||||
|
||||
|
||||
def test_corrections_pending_count(client, tmp_path):
|
||||
candidates = [
|
||||
{"id": "c1", "status": "needs_review"},
|
||||
{"id": "c2", "status": "needs_review"},
|
||||
{"id": "c3", "status": "discarded"},
|
||||
]
|
||||
(tmp_path / "sft_candidates.jsonl").write_text(
|
||||
"\n".join(json.dumps(c) for c in candidates) + "\n"
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["corrections_pending"] == 2
|
||||
|
||||
|
||||
def test_corrections_export_ready_count(client, tmp_path):
|
||||
approved = [
|
||||
{"id": "a1", "status": "approved", "corrected_response": "Good answer"},
|
||||
{"id": "a2", "status": "approved", "corrected_response": ""},
|
||||
{"id": "a3", "status": "approved", "corrected_response": "Another answer"},
|
||||
]
|
||||
(tmp_path / "sft_approved.jsonl").write_text(
|
||||
"\n".join(json.dumps(a) for a in approved) + "\n"
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["corrections_export_ready"] == 2
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
"""Tests for app/data/corrections.py -- POST /api/sft/ingest.
|
||||
|
||||
The corrections router is mounted at prefix="/api/sft" via the app/sft.py
|
||||
backward-compat shim, so ingest lives at /api/sft/ingest.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
_VALID_PAYLOAD = {
|
||||
"source": "peregrine",
|
||||
"task_type": "email_classification",
|
||||
"prompt": "Classify this email: ...",
|
||||
"response": "skip",
|
||||
"correction": "action_required",
|
||||
"label": "action_required",
|
||||
}
|
||||
|
||||
_SECRET = "test-secret-abc123"
|
||||
|
||||
|
||||
def test_ingest_503_when_secret_not_configured(client, monkeypatch):
|
||||
monkeypatch.delenv("AVOCET_INGESTION_SECRET", raising=False)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 503
|
||||
|
||||
|
||||
def test_ingest_401_when_no_auth_header(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD)
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_401_when_malformed_header(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": "Token bad-format"})
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_403_when_wrong_secret(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": "Bearer wrong-secret"})
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_ingest_creates_approved_record(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert "id" in data
|
||||
candidates = corr_module.read_jsonl(corr_module._candidates_file())
|
||||
assert len(candidates) == 1
|
||||
rec = candidates[0]
|
||||
assert rec["status"] == "approved"
|
||||
assert rec["source"] == "peregrine"
|
||||
assert rec["corrected_response"] == "action_required"
|
||||
assert rec["id"] == data["id"]
|
||||
|
||||
|
||||
def test_ingest_also_writes_to_approved_file(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
approved = corr_module.read_jsonl(corr_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["id"] == r.json()["id"]
|
||||
|
||||
|
||||
def test_ingest_without_label_is_accepted(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
payload = {**_VALID_PAYLOAD, "label": None}
|
||||
r = client.post("/api/sft/ingest", json=payload,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
"""Tests for app/data/fetch.py"""
|
||||
import json
|
||||
import yaml
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import fetch as fetch_module
|
||||
fetch_module.set_data_dir(tmp_path)
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test",
|
||||
json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, tmp_path):
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, tmp_path):
|
||||
from app.data import fetch as fetch_module
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
def test_entry_key_deterministic():
|
||||
from app.data.fetch import entry_key
|
||||
e = {"subject": "Test", "body": "Hello world"}
|
||||
assert entry_key(e) == entry_key(e)
|
||||
|
||||
|
||||
def test_entry_key_differs_by_subject():
|
||||
from app.data.fetch import entry_key
|
||||
a = {"subject": "A", "body": "same body"}
|
||||
b = {"subject": "B", "body": "same body"}
|
||||
assert entry_key(a) != entry_key(b)
|
||||
|
|
@ -1,219 +0,0 @@
|
|||
"""Tests for app/data/label.py"""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
label_module.set_config_dir(tmp_path)
|
||||
label_module.reset_last_action()
|
||||
yield
|
||||
label_module.reset_last_action()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_items(tmp_path):
|
||||
from app.data import label as label_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
(label_module._DATA_DIR / "email_label_queue.jsonl").write_text(
|
||||
"\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = label_module.read_jsonl(label_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = label_module.read_jsonl(label_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["undone"]["type"] == "label"
|
||||
assert label_module.read_jsonl(label_module._score_file()) == []
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
assert label_module.read_jsonl(label_module._discarded_file()) == []
|
||||
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_config_labels_returns_10_labels(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
for lbl in labels:
|
||||
assert "emoji" in lbl and "color" in lbl and "name" in lbl
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_config_dir(tmp_path)
|
||||
payload = {"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}], "max_per_account": 200}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
saved = yaml.safe_load((tmp_path / "label_tool.yaml").read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_config_dir(tmp_path)
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
score_path = tmp_path / "email_score.jsonl"
|
||||
records = [{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"}]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
(tmp_path / "email_score.jsonl").write_text(json.dumps({"id": "a", "label": "neutral"}) + "\n")
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
|
@ -1,234 +0,0 @@
|
|||
"""Tests for app/eval/embed_bench.py."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_embed_bench_globals(tmp_path):
|
||||
"""Redirect config dir to tmp_path and reset running flag."""
|
||||
from app.eval import embed_bench as mod
|
||||
|
||||
prev_config_dir = mod._CONFIG_DIR
|
||||
prev_running = mod._RUN_ACTIVE
|
||||
|
||||
mod.set_config_dir(tmp_path)
|
||||
mod._RUN_ACTIVE = False
|
||||
|
||||
yield tmp_path
|
||||
|
||||
mod.set_config_dir(prev_config_dir)
|
||||
mod._RUN_ACTIVE = prev_running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ── cosine helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_cosine_identical():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_opposite():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ── models endpoint ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_models_returns_list_with_mock(client, tmp_path):
|
||||
"""GET /api/embed-bench/models returns list from Ollama tags endpoint."""
|
||||
import yaml
|
||||
cfg = {"cforch": {"ollama_url": "http://localhost:11434"}}
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {
|
||||
"models": [
|
||||
{"name": "nomic-embed-text", "size": 274302480},
|
||||
{"name": "mxbai-embed-large", "size": 669000000},
|
||||
]
|
||||
}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/embed-bench/models")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert isinstance(data["models"], list)
|
||||
assert any(m["name"] == "nomic-embed-text" for m in data["models"])
|
||||
|
||||
|
||||
def test_models_returns_empty_on_ollama_error(client, tmp_path):
|
||||
"""GET /api/embed-bench/models returns empty list if Ollama unreachable."""
|
||||
import httpx
|
||||
with patch("app.eval.embed_bench.httpx.get", side_effect=httpx.ConnectError("refused")):
|
||||
r = client.get("/api/embed-bench/models")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["models"] == []
|
||||
|
||||
|
||||
# ── run endpoint ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_empty_corpus_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": [], "queries": ["test"], "models": ["nomic-embed-text"], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_run_empty_queries_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1"], "queries": [], "models": ["nomic-embed-text"], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_run_empty_models_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1"], "queries": ["test"], "models": [], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def _fake_embed_response(texts: list[str]) -> MagicMock:
|
||||
"""Build a mock httpx.post response returning unit vectors for each text."""
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.json.return_value = {
|
||||
"data": [{"embedding": [1.0, 0.0, 0.0] if i % 2 == 0 else [0.0, 1.0, 0.0]}
|
||||
for i, _ in enumerate(texts)]
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
def _collect_sse(raw: bytes) -> list[dict]:
|
||||
"""Parse SSE stream bytes into a list of decoded event dicts."""
|
||||
events = []
|
||||
for line in raw.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_run_single_model_returns_result_and_done(client, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk 1", "chunk 2"])):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1", "chunk 2"],
|
||||
"queries": ["what is chunk one?"],
|
||||
"models": ["nomic-embed-text"],
|
||||
"top_k": 2,
|
||||
})
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _collect_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "result" in types
|
||||
assert types[-1] == "done"
|
||||
result_events = [e for e in events if e["type"] == "result"]
|
||||
assert result_events[0]["model"] == "nomic-embed-text"
|
||||
assert result_events[0]["query_idx"] == 0
|
||||
assert len(result_events[0]["hits"]) <= 2
|
||||
|
||||
|
||||
def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk A", "chunk B"])):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk A", "chunk B"],
|
||||
"queries": ["find it"],
|
||||
"models": ["nomic-embed-text", "mxbai-embed-large"],
|
||||
"top_k": 2,
|
||||
})
|
||||
|
||||
events = _collect_sse(r.content)
|
||||
result_events = [e for e in events if e["type"] == "result"]
|
||||
models_seen = {e["model"] for e in result_events}
|
||||
assert "nomic-embed-text" in models_seen
|
||||
assert "mxbai-embed-large" in models_seen
|
||||
|
||||
|
||||
# ── rate + export ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_rate_appends_jsonl_line(client, tmp_path):
|
||||
r = client.post("/api/embed-bench/rate", json={
|
||||
"query": "test query",
|
||||
"model": "nomic-embed-text",
|
||||
"chunk_text": "some text",
|
||||
"chunk_idx": 2,
|
||||
"rating": "relevant",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
ratings_file = tmp_path / "embed_bench_ratings.jsonl"
|
||||
assert ratings_file.exists()
|
||||
line = json.loads(ratings_file.read_text().strip())
|
||||
assert line["query"] == "test query"
|
||||
assert line["rating"] == "relevant"
|
||||
assert line["chunk_idx"] == 2
|
||||
assert "timestamp" in line
|
||||
|
||||
|
||||
def test_export_csv_two_rows(client, tmp_path):
|
||||
for i in range(2):
|
||||
client.post("/api/embed-bench/rate", json={
|
||||
"query": f"q{i}", "model": "nomic-embed-text",
|
||||
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "relevant",
|
||||
})
|
||||
r = client.get("/api/embed-bench/export?format=csv")
|
||||
assert r.status_code == 200
|
||||
assert "text/csv" in r.headers["content-type"]
|
||||
lines = r.text.strip().splitlines()
|
||||
assert len(lines) == 3 # header + 2 rows
|
||||
assert "query" in lines[0]
|
||||
|
||||
|
||||
def test_export_json_two_entries(client, tmp_path):
|
||||
for i in range(2):
|
||||
client.post("/api/embed-bench/rate", json={
|
||||
"query": f"q{i}", "model": "nomic-embed-text",
|
||||
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "not_relevant",
|
||||
})
|
||||
r = client.get("/api/embed-bench/export?format=json")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
assert data[0]["rating"] == "not_relevant"
|
||||
|
||||
|
||||
def test_export_empty_returns_csv_header_only(client):
|
||||
r = client.get("/api/embed-bench/export?format=csv")
|
||||
assert r.status_code == 200
|
||||
lines = r.text.strip().splitlines()
|
||||
assert len(lines) == 1 # header only
|
||||
assert "query" in lines[0]
|
||||
|
|
@ -321,7 +321,6 @@ def test_load_and_prepare_data_single_path_still_works(tmp_path):
|
|||
|
||||
# ---- Integration test ----
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_integration_finetune_on_example_data(tmp_path):
|
||||
"""Fine-tune deberta-small on example data for 1 epoch.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,242 +0,0 @@
|
|||
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api import app
|
||||
from app.data import imitate as _imitate_module
|
||||
|
||||
|
||||
# -- Fixtures ------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_globals(tmp_path):
|
||||
"""Reset module-level config + data dir globals after each test."""
|
||||
orig_cfg = _imitate_module._CONFIG_DIR
|
||||
orig_data = _imitate_module._DATA_DIR
|
||||
yield
|
||||
_imitate_module._CONFIG_DIR = orig_cfg
|
||||
_imitate_module._DATA_DIR = orig_data
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config_dir(tmp_path) -> Path:
|
||||
_imitate_module.set_config_dir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def data_dir(tmp_path) -> Path:
|
||||
_imitate_module.set_data_dir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def cfg_with_products(config_dir: Path) -> Path:
|
||||
"""Write a label_tool.yaml with two products."""
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
"""
|
||||
imitate:
|
||||
ollama_url: http://localhost:11434
|
||||
products:
|
||||
- id: peregrine
|
||||
name: Peregrine
|
||||
icon: "🦅"
|
||||
description: Job search assistant
|
||||
base_url: http://peregrine.local
|
||||
sample_endpoint: /api/jobs
|
||||
text_fields: [title, description]
|
||||
prompt_template: "Analyze: {text}"
|
||||
- id: kiwi
|
||||
name: Kiwi
|
||||
icon: "🥝"
|
||||
description: Pantry tracker
|
||||
base_url: http://kiwi.local
|
||||
sample_endpoint: /api/inventory
|
||||
text_fields: [name, notes]
|
||||
prompt_template: "Describe: {text}"
|
||||
"""
|
||||
)
|
||||
return config_dir
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client() -> TestClient:
|
||||
return TestClient(app, raise_server_exceptions=True)
|
||||
|
||||
|
||||
# -- GET /products -------------------------------------------------------------
|
||||
|
||||
def test_products_empty_when_no_config(config_dir, client):
|
||||
"""Returns empty list when label_tool.yaml has no imitate section."""
|
||||
(config_dir / "label_tool.yaml").write_text("accounts: []\n")
|
||||
resp = client.get("/api/imitate/products")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["products"] == []
|
||||
|
||||
|
||||
def test_products_listed(cfg_with_products, client):
|
||||
"""All configured products are returned with expected fields."""
|
||||
with patch.object(_imitate_module, "_is_online", return_value=True):
|
||||
resp = client.get("/api/imitate/products")
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()["products"]
|
||||
assert len(products) == 2
|
||||
ids = {p["id"] for p in products}
|
||||
assert ids == {"peregrine", "kiwi"}
|
||||
peregrine = next(p for p in products if p["id"] == "peregrine")
|
||||
assert peregrine["name"] == "Peregrine"
|
||||
assert peregrine["icon"] == "🦅"
|
||||
assert peregrine["online"] is True
|
||||
|
||||
|
||||
def test_products_offline_when_unreachable(cfg_with_products, client):
|
||||
"""Products with unreachable base_url are marked offline."""
|
||||
with patch.object(_imitate_module, "_is_online", return_value=False):
|
||||
resp = client.get("/api/imitate/products")
|
||||
assert all(not p["online"] for p in resp.json()["products"])
|
||||
|
||||
|
||||
# -- GET /products/{id}/sample -------------------------------------------------
|
||||
|
||||
def test_sample_unknown_product(cfg_with_products, client):
|
||||
"""Returns 404 for a product id not in config."""
|
||||
resp = client.get("/api/imitate/products/nonexistent/sample")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_sample_fetched_from_list(cfg_with_products, client):
|
||||
"""Extracts first item from a list API response."""
|
||||
fake_api = [
|
||||
{"title": "Engineer", "description": "Build things"},
|
||||
{"title": "Other", "description": "Ignore me"},
|
||||
]
|
||||
with patch.object(_imitate_module, "_http_get_json", return_value=fake_api):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "Engineer" in body["text"]
|
||||
assert "Build things" in body["text"]
|
||||
assert "Analyze:" in body["prompt"]
|
||||
|
||||
|
||||
def test_sample_fetched_from_dict_with_items_key(cfg_with_products, client):
|
||||
"""Extracts from a wrapper dict with a recognised list key."""
|
||||
fake_api = {"items": [{"title": "Wrapped Job", "description": "In a wrapper"}]}
|
||||
with patch.object(_imitate_module, "_http_get_json", return_value=fake_api):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 200
|
||||
assert "Wrapped Job" in resp.json()["text"]
|
||||
|
||||
|
||||
def test_sample_503_when_api_unreachable(cfg_with_products, client):
|
||||
"""Returns 503 when the product API is not reachable."""
|
||||
from urllib.error import URLError
|
||||
with patch.object(_imitate_module, "_http_get_json", side_effect=URLError("refused")):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
def test_sample_404_on_empty_list(cfg_with_products, client):
|
||||
"""Returns 404 when product API returns an empty list."""
|
||||
with patch.object(_imitate_module, "_http_get_json", return_value=[]):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# -- POST /push-corrections ----------------------------------------------------
|
||||
|
||||
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
|
||||
"""Successful push writes records to sft_candidates.jsonl."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": "Analyze this job:",
|
||||
"results": [
|
||||
{"model": "qwen2.5:0.5b", "response": "It's a good job.", "elapsed_ms": 800, "error": None},
|
||||
{"model": "llama3.1:8b", "response": "Strong candidate.", "elapsed_ms": 1500, "error": None},
|
||||
],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["pushed"] == 2
|
||||
|
||||
candidates = (data_dir / "sft_candidates.jsonl").read_text().splitlines()
|
||||
assert len(candidates) == 2
|
||||
for line in candidates:
|
||||
record = json.loads(line)
|
||||
assert record["source"] == "imitate"
|
||||
assert record["product_id"] == "peregrine"
|
||||
assert record["status"] == "pending"
|
||||
assert record["prompt_messages"][0]["role"] == "user"
|
||||
|
||||
|
||||
def test_push_corrections_skips_errors(cfg_with_products, data_dir, client):
|
||||
"""Results with errors are not written to the corrections file."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": "Analyze:",
|
||||
"results": [
|
||||
{"model": "good-model", "response": "Good answer.", "elapsed_ms": 500, "error": None},
|
||||
{"model": "bad-model", "response": "", "elapsed_ms": 0, "error": "connection refused"},
|
||||
],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["pushed"] == 1
|
||||
|
||||
|
||||
def test_push_corrections_empty_prompt_422(cfg_with_products, data_dir, client):
|
||||
"""Empty prompt returns 422."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": " ",
|
||||
"results": [{"model": "m", "response": "r", "elapsed_ms": 1, "error": None}],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
|
||||
"""422 when every result has an error (nothing to push)."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": "Analyze:",
|
||||
"results": [
|
||||
{"model": "m", "response": "", "elapsed_ms": 0, "error": "timed out"},
|
||||
],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# -- _extract_sample helper ----------------------------------------------------
|
||||
|
||||
def test_extract_sample_list():
|
||||
result = _imitate_module._extract_sample(
|
||||
[{"title": "A", "description": "B"}],
|
||||
text_fields=["title", "description"],
|
||||
)
|
||||
assert "A" in result["text"]
|
||||
assert "B" in result["text"]
|
||||
|
||||
|
||||
def test_extract_sample_empty_list():
|
||||
result = _imitate_module._extract_sample([], text_fields=["title"])
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_extract_sample_respects_index():
|
||||
items = [{"title": "First"}, {"title": "Second"}]
|
||||
result = _imitate_module._extract_sample(items, ["title"], sample_index=1)
|
||||
assert "Second" in result["text"]
|
||||
|
||||
|
||||
def test_extract_sample_clamps_index():
|
||||
items = [{"title": "Only"}]
|
||||
result = _imitate_module._extract_sample(items, ["title"], sample_index=99)
|
||||
assert "Only" in result["text"]
|
||||
|
|
@ -1,454 +0,0 @@
|
|||
"""Tests for app/data/log_corpus.py — corpus receiver and labeling endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.data import log_corpus as lc
|
||||
|
||||
|
||||
VALID_TOKEN = str(uuid.uuid4())
|
||||
VALID_HOST = "testnode.local"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path, monkeypatch):
|
||||
"""Each test gets its own fresh corpus DB and config dir."""
|
||||
monkeypatch.setattr(lc, "_DATA_DIR", tmp_path)
|
||||
monkeypatch.setattr(lc, "_DB_PATH", tmp_path / "corpus.db")
|
||||
# Config dir pointing to a temp yaml with one test source
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n sources:\n"
|
||||
f" - token: \"{VALID_TOKEN}\"\n"
|
||||
f" source_host: \"{VALID_HOST}\"\n"
|
||||
f" owner: TestOwner\n"
|
||||
f" consent_date: \"2026-05-11\"\n"
|
||||
f" consent_method: signal_chat\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
lc._init_db()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(lc.router, prefix="/api/corpus")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _batch(batch_type="raw_entries", entries=None, source_host=VALID_HOST):
|
||||
return {
|
||||
"batch_version": 1,
|
||||
"batch_id": str(uuid.uuid4()),
|
||||
"pushed_at": "2026-05-11T10:00:00Z",
|
||||
"source_host": source_host,
|
||||
"batch_type": batch_type,
|
||||
"watermark_from": 0,
|
||||
"watermark_to": 5,
|
||||
"entries": entries or [
|
||||
{
|
||||
"entry_id": str(uuid.uuid4()),
|
||||
"source_id": "sonarr",
|
||||
"timestamp_iso": "2026-05-11T09:58:00Z",
|
||||
"severity": "ERROR",
|
||||
"text": "Connection refused to indexer",
|
||||
"matched_patterns": [],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ── Receive endpoint ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_receive_missing_auth(client):
|
||||
resp = client.post("/api/corpus/log-batch", json=_batch())
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_receive_invalid_token(client):
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": "Bearer bad-token"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_receive_valid_batch(client):
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["received"] is True
|
||||
assert data["entries_stored"] == 1
|
||||
|
||||
|
||||
def test_receive_stores_source_host_from_token_not_payload(client):
|
||||
"""source_host is always taken from the DB lookup, not the payload."""
|
||||
payload = _batch(source_host="attacker-injected-host")
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
entries_resp = client.get("/api/corpus/entries")
|
||||
entry = entries_resp.json()["entries"][0]
|
||||
assert entry["source_host"] == VALID_HOST
|
||||
|
||||
|
||||
def test_receive_skips_empty_text_entries(client):
|
||||
payload = _batch(entries=[
|
||||
{"entry_id": "e1", "source_id": "svc", "severity": "ERROR", "text": ""},
|
||||
{"entry_id": "e2", "source_id": "svc", "severity": "ERROR", "text": " "},
|
||||
{"entry_id": "e3", "source_id": "svc", "severity": "ERROR", "text": "real error"},
|
||||
])
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.json()["entries_stored"] == 1
|
||||
|
||||
|
||||
def test_receive_incident_bundle(client):
|
||||
payload = _batch(batch_type="incident_bundles", entries=[
|
||||
{"id": "inc-1", "label": "plex crash", "issue_type": "plex",
|
||||
"started_at": "2026-05-11T09:00:00", "ended_at": "2026-05-11T09:30:00",
|
||||
"notes": "audio dropped", "created_at": "2026-05-11T09:35:00",
|
||||
"severity": "high", "text": "plex crash"},
|
||||
])
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries_stored"] == 1
|
||||
|
||||
|
||||
# ── Labeling endpoints ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_label_entry(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Sonarr lost connection to its indexer — restart the service.",
|
||||
"known_pattern": "y",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["labeled"] is True
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"state": "labeled"}).json()["entries"]
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["failure_type"] == "software"
|
||||
|
||||
|
||||
def test_label_entry_invalid_failure_type(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={"failure_type": "aliens"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_label_entry_missing_failure_type(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_label_entry_not_found(client):
|
||||
resp = client.post("/api/corpus/entries/nonexistent/label", json={"failure_type": "software"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_skip_entry(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/skip")
|
||||
assert resp.status_code == 200
|
||||
|
||||
unlabeled = client.get("/api/corpus/entries").json()["entries"]
|
||||
assert len(unlabeled) == 0
|
||||
|
||||
|
||||
# ── Stats ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_stats_empty(client):
|
||||
stats = client.get("/api/corpus/stats").json()
|
||||
assert stats["total_entries"] == 0
|
||||
assert stats["batch_count"] == 0
|
||||
|
||||
|
||||
def test_stats_after_receive(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
stats = client.get("/api/corpus/stats").json()
|
||||
assert stats["total_entries"] == 1
|
||||
assert stats["batch_count"] == 1
|
||||
assert stats["by_label_state"].get("unlabeled", 0) == 1
|
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_excludes_unlabeled(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_includes_labeled(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Sonarr lost connection to indexer.",
|
||||
})
|
||||
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.status_code == 200
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
assert len(lines) == 1
|
||||
record = json.loads(lines[0])
|
||||
assert record["output"] == "Sonarr lost connection to indexer."
|
||||
assert record["metadata"]["failure_type"] == "software"
|
||||
|
||||
|
||||
def test_export_excludes_pii_flagged(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Contains username — should not export.",
|
||||
"pii_flagged": True,
|
||||
})
|
||||
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
# ── Pipeline ingest endpoint ───────────────────────────────────────────────────
|
||||
|
||||
def _make_pipeline_file(directory: Path, name: str, lines: list[dict]) -> Path:
|
||||
"""Write a JSONL pipeline log file to directory."""
|
||||
p = directory / name
|
||||
p.write_text("\n".join(json.dumps(l) for l in lines), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
_PIPELINE_LINE = {
|
||||
"ts": "2026-05-17T10:00:00Z",
|
||||
"level": "INFO",
|
||||
"logger": "scripts.pipeline.purple_carrot_scraper",
|
||||
"msg": "Fetched recipe page",
|
||||
"extra": {"url": "https://example.com/recipe/1", "status": 200},
|
||||
}
|
||||
|
||||
|
||||
def test_pipeline_ingest_returns_404_when_dir_not_configured(client, tmp_path):
|
||||
"""No pipeline_ingest_dir in config — endpoint returns 404."""
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_pipeline_ingest_empty_dir(client, tmp_path, monkeypatch):
|
||||
"""Configured dir exists but is empty — returns zeros, no error."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ingested_files"] == 0
|
||||
assert data["skipped_files"] == 0
|
||||
assert data["entries_stored"] == 0
|
||||
|
||||
|
||||
def test_pipeline_ingest_ingests_valid_file(client, tmp_path, monkeypatch):
|
||||
"""Valid JSONL file is ingested; entries appear in corpus."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [
|
||||
_PIPELINE_LINE,
|
||||
{**_PIPELINE_LINE, "msg": "Saved 3 recipes", "level": "INFO"},
|
||||
])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ingested_files"] == 1
|
||||
assert data["entries_stored"] == 2
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert len(entries) == 2
|
||||
assert all(e["source_host"] == "pipeline_scrape" for e in entries)
|
||||
|
||||
|
||||
def test_pipeline_ingest_source_id_from_logger(client, tmp_path, monkeypatch):
|
||||
"""source_id is populated from the 'logger' field of each log line."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "run_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest")
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert entries[0]["source_id"] == "scripts.pipeline.purple_carrot_scraper"
|
||||
|
||||
|
||||
def test_pipeline_ingest_idempotent(client, tmp_path, monkeypatch):
|
||||
"""Calling the endpoint twice does not re-ingest already-processed files."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest")
|
||||
resp2 = client.post("/api/corpus/pipeline-ingest")
|
||||
|
||||
data = resp2.json()
|
||||
assert data["ingested_files"] == 0
|
||||
assert data["skipped_files"] == 1
|
||||
assert data["entries_stored"] == 0
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert len(entries) == 1 # still just the one from the first ingest
|
||||
|
||||
|
||||
def test_pipeline_ingest_skips_non_jsonl(client, tmp_path, monkeypatch):
|
||||
"""Non-.jsonl files in the dir are silently ignored."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
(ingest_dir / "notes.txt").write_text("this is not a log file")
|
||||
(ingest_dir / "run.csv").write_text("a,b,c\n1,2,3")
|
||||
_make_pipeline_file(ingest_dir, "valid_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.json()["ingested_files"] == 1
|
||||
|
||||
|
||||
def test_pipeline_ingest_skips_malformed_lines(client, tmp_path, monkeypatch):
|
||||
"""Lines that are not valid JSON are skipped; valid lines in the same file still land."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
p = ingest_dir / "mixed_20260517.jsonl"
|
||||
p.write_text(
|
||||
json.dumps(_PIPELINE_LINE) + "\n"
|
||||
"this is not json\n"
|
||||
+ json.dumps({**_PIPELINE_LINE, "msg": "another valid line"}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries_stored"] == 2 # 2 valid lines, 1 skipped
|
||||
|
||||
|
||||
def test_pipeline_ingest_new_file_after_first_run(client, tmp_path, monkeypatch):
|
||||
"""A new file added after the first ingest is picked up on the next call."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "run_a.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest") # ingest run_a.jsonl
|
||||
|
||||
_make_pipeline_file(ingest_dir, "run_b.jsonl", [
|
||||
{**_PIPELINE_LINE, "msg": "Second run line"},
|
||||
])
|
||||
|
||||
resp2 = client.post("/api/corpus/pipeline-ingest")
|
||||
data = resp2.json()
|
||||
assert data["ingested_files"] == 1
|
||||
assert data["skipped_files"] == 1
|
||||
assert data["entries_stored"] == 1
|
||||
|
|
@ -17,7 +17,6 @@ def reset_models_globals(tmp_path):
|
|||
from app import models as models_module
|
||||
|
||||
prev_models = models_module._MODELS_DIR
|
||||
prev_cf_text = models_module._CF_TEXT_MODELS_DIR
|
||||
prev_queue = models_module._QUEUE_DIR
|
||||
prev_progress = dict(models_module._download_progress)
|
||||
|
||||
|
|
@ -27,14 +26,12 @@ def reset_models_globals(tmp_path):
|
|||
queue_dir.mkdir()
|
||||
|
||||
models_module.set_models_dir(models_dir)
|
||||
models_module.set_cf_text_models_dir(tmp_path / "cf-text-models")
|
||||
models_module.set_queue_dir(queue_dir)
|
||||
models_module._download_progress = {}
|
||||
|
||||
yield
|
||||
|
||||
models_module.set_models_dir(prev_models)
|
||||
models_module.set_cf_text_models_dir(prev_cf_text)
|
||||
models_module.set_queue_dir(prev_queue)
|
||||
models_module._download_progress = prev_progress
|
||||
|
||||
|
|
@ -125,88 +122,17 @@ def test_lookup_returns_correct_shape(client):
|
|||
assert data["already_queued"] is False
|
||||
|
||||
|
||||
def test_lookup_unknown_pipeline_tag_returns_null_adapter_and_incompatible(client):
|
||||
"""An unrecognised pipeline_tag yields adapter_recommendation=null and compatible=False."""
|
||||
def test_lookup_unknown_pipeline_tag_returns_null_adapter(client):
|
||||
"""An unrecognised pipeline_tag yields adapter_recommendation=null."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("org/m", "reinforcement-learning")
|
||||
mock_resp.json.return_value = _make_hf_response("org/m", "audio-classification")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "org/m"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["adapter_recommendation"] is None
|
||||
assert data["compatible"] is False
|
||||
assert data["role"] is None
|
||||
assert data["service"] is None
|
||||
assert "CircuitForge model ecosystem" in data["warning"]
|
||||
|
||||
|
||||
def test_lookup_stt_tag_returns_compatible_with_cf_stt_service(client):
|
||||
"""automatic-speech-recognition tag yields compatible=True, service=cf-stt."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("openai/whisper-base", "automatic-speech-recognition")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "openai/whisper-base"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["adapter_recommendation"] is None
|
||||
assert data["role"] == "stt"
|
||||
assert data["service"] == "cf-stt"
|
||||
assert data["warning"] is None
|
||||
|
||||
|
||||
def test_lookup_vision_tag_returns_compatible_with_cf_vision_service(client):
|
||||
"""image-classification tag yields compatible=True, service=cf-vision."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("google/siglip-base", "image-classification")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "google/siglip-base"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["role"] == "vision"
|
||||
assert data["service"] == "cf-vision"
|
||||
|
||||
|
||||
def test_lookup_audio_classification_tag_returns_cf_voice_service(client):
|
||||
"""audio-classification tag yields compatible=True, service=cf-voice."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("org/audio-model", "audio-classification")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "org/audio-model"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["role"] == "classifier"
|
||||
assert data["service"] == "cf-voice"
|
||||
|
||||
|
||||
def test_lookup_embedding_tag_returns_compatible_with_cf_core_service(client):
|
||||
"""feature-extraction tag yields compatible=True, service=cf-core."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("BAAI/bge-small-en", "feature-extraction")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "BAAI/bge-small-en"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["role"] == "embedding"
|
||||
assert data["service"] == "cf-core"
|
||||
assert r.json()["adapter_recommendation"] is None
|
||||
|
||||
|
||||
def test_lookup_already_queued_flag(client):
|
||||
|
|
@ -255,26 +181,6 @@ def test_queue_add_returns_entry_fields(client):
|
|||
assert entry["adapter_recommendation"] == "ZeroShotAdapter"
|
||||
|
||||
|
||||
def test_queue_preserves_role_and_service(client):
|
||||
"""POST /queue with role/service fields round-trips them through GET /queue."""
|
||||
r = client.post("/api/models/queue", json={
|
||||
"repo_id": "openai/whisper-base",
|
||||
"pipeline_tag": "automatic-speech-recognition",
|
||||
"adapter_recommendation": None,
|
||||
"role": "stt",
|
||||
"service": "cf-stt",
|
||||
})
|
||||
assert r.status_code == 201
|
||||
entry = r.json()
|
||||
assert entry["role"] == "stt"
|
||||
assert entry["service"] == "cf-stt"
|
||||
|
||||
r2 = client.get("/api/models/queue")
|
||||
items = r2.json()
|
||||
assert items[0]["role"] == "stt"
|
||||
assert items[0]["service"] == "cf-stt"
|
||||
|
||||
|
||||
# ── POST /queue — 409 duplicate ────────────────────────────────────────────────
|
||||
|
||||
def test_queue_duplicate_returns_409(client):
|
||||
|
|
@ -411,12 +317,7 @@ def test_installed_detects_downloaded_model(client, tmp_path):
|
|||
model_dir.mkdir()
|
||||
(model_dir / "config.json").write_text(json.dumps({"model_type": "bert"}), encoding="utf-8")
|
||||
(model_dir / "model_info.json").write_text(
|
||||
json.dumps({
|
||||
"repo_id": "org/mymodel",
|
||||
"adapter_recommendation": "ZeroShotAdapter",
|
||||
"role": "classifier",
|
||||
"service": "avocet",
|
||||
}),
|
||||
json.dumps({"repo_id": "org/mymodel", "adapter_recommendation": "ZeroShotAdapter"}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
|
@ -428,51 +329,6 @@ def test_installed_detects_downloaded_model(client, tmp_path):
|
|||
assert items[0]["name"] == "org--mymodel"
|
||||
assert items[0]["adapter"] == "ZeroShotAdapter"
|
||||
assert items[0]["model_id"] == "org/mymodel"
|
||||
assert items[0]["role"] == "classifier"
|
||||
assert items[0]["service"] == "avocet"
|
||||
|
||||
|
||||
def test_installed_stt_model_surfaces_role_and_service(client):
|
||||
"""A downloaded STT model's role/service are returned by GET /installed."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "openai--whisper-base"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "config.json").write_text(json.dumps({"model_type": "whisper"}), encoding="utf-8")
|
||||
(model_dir / "model_info.json").write_text(
|
||||
json.dumps({
|
||||
"repo_id": "openai/whisper-base",
|
||||
"adapter_recommendation": None,
|
||||
"role": "stt",
|
||||
"service": "cf-stt",
|
||||
}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert items[0]["role"] == "stt"
|
||||
assert items[0]["service"] == "cf-stt"
|
||||
assert items[0]["adapter"] is None
|
||||
|
||||
|
||||
def test_installed_finetuned_model_defaults_to_avocet_service(client):
|
||||
"""Fine-tuned models with no role/service in training_info default to avocet/classifier."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "my-finetuned-v2"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "training_info.json").write_text(
|
||||
json.dumps({"base_model": "microsoft/deberta-v3-base", "epochs": 3}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert items[0]["role"] == "classifier"
|
||||
assert items[0]["service"] == "avocet"
|
||||
|
||||
|
||||
def test_installed_detects_finetuned_model(client):
|
||||
|
|
@ -515,18 +371,15 @@ def test_delete_installed_not_found_returns_404(client):
|
|||
|
||||
|
||||
def test_delete_installed_path_traversal_blocked(client):
|
||||
"""DELETE /installed/../../etc must be blocked.
|
||||
Path traversal normalises to a different URL (/api/etc); if web/dist exists
|
||||
the StaticFiles mount intercepts it and returns 405 (GET/HEAD only).
|
||||
"""
|
||||
"""DELETE /installed/../../etc must be blocked (400 or 422)."""
|
||||
r = client.delete("/api/models/installed/../../etc")
|
||||
assert r.status_code in (400, 404, 405, 422)
|
||||
assert r.status_code in (400, 404, 422)
|
||||
|
||||
|
||||
def test_delete_installed_dotdot_name_blocked(client):
|
||||
"""A name containing '..' in any form must be rejected."""
|
||||
r = client.delete("/api/models/installed/..%2F..%2Fetc")
|
||||
assert r.status_code in (400, 404, 405, 422)
|
||||
assert r.status_code in (400, 404, 422)
|
||||
|
||||
|
||||
def test_delete_installed_name_with_slash_blocked(client):
|
||||
|
|
@ -544,84 +397,3 @@ def test_delete_installed_name_with_slash_blocked(client):
|
|||
except _HTTPException as exc:
|
||||
assert exc.status_code in (400, 404)
|
||||
raise
|
||||
|
||||
|
||||
# ── Catalog registration ───────────────────────────────────────────────────────
|
||||
|
||||
_MINIMAL_YAML = """\
|
||||
services:
|
||||
cf-text:
|
||||
max_mb: {max_mb}
|
||||
catalog:
|
||||
existing-model:
|
||||
path: /some/path
|
||||
vram_mb: 1000
|
||||
description: "placeholder"
|
||||
"""
|
||||
|
||||
|
||||
def _make_node_yaml(tmp_path: Path, max_mb: int = 8192) -> Path:
|
||||
p = tmp_path / "testnode.yaml"
|
||||
p.write_text(_MINIMAL_YAML.format(max_mb=max_mb), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def test_catalog_registration_fp16_no_env_block(tmp_path):
|
||||
"""When model fits at FP16, no env block should be written."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/SmallModel",
|
||||
local_path=tmp_path / "org--SmallModel",
|
||||
vram_mb_fp16=4000,
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert "testnode" in updated
|
||||
content = node_yaml.read_text()
|
||||
# _catalog_key strips org prefix and lowercases: "org/SmallModel" → "smallmodel"
|
||||
assert "smallmodel:" in content
|
||||
assert "CF_TEXT_4BIT" not in content
|
||||
assert "env:" not in content
|
||||
|
||||
|
||||
def test_catalog_registration_needs_4bit_writes_env_block(tmp_path):
|
||||
"""When model only fits at 4-bit, env: CF_TEXT_4BIT: '1' must be written."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/BigModel",
|
||||
local_path=tmp_path / "org--BigModel",
|
||||
vram_mb_fp16=20000, # won't fit at FP16 on 8 GB
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert "testnode" in updated
|
||||
content = node_yaml.read_text()
|
||||
# _catalog_key: "org/BigModel" → "bigmodel"
|
||||
assert "bigmodel:" in content
|
||||
assert "env:" in content
|
||||
assert 'CF_TEXT_4BIT: "1"' in content
|
||||
assert "CF_TEXT_4BIT=1 required" in content # description note
|
||||
|
||||
|
||||
def test_catalog_registration_too_large_skipped(tmp_path):
|
||||
"""Model too large even at 4-bit should not be registered."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/HugeModel",
|
||||
local_path=tmp_path / "org--HugeModel",
|
||||
vram_mb_fp16=80000, # 4-bit ~22 GB, still won't fit on 8 GB
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert updated == []
|
||||
content = node_yaml.read_text()
|
||||
assert "hugemodel" not in content
|
||||
|
|
|
|||
|
|
@ -1,575 +0,0 @@
|
|||
"""Tests for app/nodes.py — /api/nodes-mgmt/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os as _os
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_nodes_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path so tests never read the real config."""
|
||||
from app import nodes as nodes_module
|
||||
prev = nodes_module._CONFIG_DIR
|
||||
nodes_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
nodes_module.set_config_dir(prev)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
|
||||
cfg = {"cforch": cforch_cfg}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_profile(profiles_dir: Path, node_id: str, profile: dict) -> None:
|
||||
profiles_dir.mkdir(parents=True, exist_ok=True)
|
||||
(profiles_dir / f"{node_id}.yaml").write_text(yaml.dump(profile), encoding="utf-8")
|
||||
|
||||
|
||||
def test_nodes_module_imports():
|
||||
from app import nodes
|
||||
assert hasattr(nodes, "router")
|
||||
assert hasattr(nodes, "set_config_dir")
|
||||
|
||||
|
||||
def test_list_nodes_returns_empty_when_no_coordinator(client):
|
||||
"""No cforch config — endpoint returns empty list, not 500."""
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
|
||||
|
||||
def _fake_nodes_response(nodes_json: list, services_json: list | None = None):
|
||||
"""Build side_effect list for two httpx.get calls: nodes then services."""
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.raise_for_status = MagicMock()
|
||||
mock_nodes.json.return_value = {"nodes": nodes_json}
|
||||
|
||||
mock_services = MagicMock()
|
||||
mock_services.raise_for_status = MagicMock()
|
||||
mock_services.json.return_value = {"services": services_json or []}
|
||||
|
||||
return [mock_nodes, mock_services]
|
||||
|
||||
|
||||
def test_list_nodes_coordinator_unreachable_returns_empty(client, tmp_path):
|
||||
"""Coordinator unreachable — returns [] with no 500."""
|
||||
import httpx
|
||||
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
|
||||
with patch("httpx.get", side_effect=httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_list_nodes_merges_profile_data(client, tmp_path):
|
||||
"""Profile YAML services_assigned merged with live GPU stats."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": ["cf-text"], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
coord_nodes = [{
|
||||
"node_id": "heimdall", "online": True, "agent_url": "http://10.1.10.71:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
|
||||
"vram_used_mb": 4096, "vram_free_mb": 20480,
|
||||
"temp_c": 42.0, "utilization_pct": 15.0, "compute_cap": 8.6}],
|
||||
}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data) == 1
|
||||
node = data[0]
|
||||
assert node["node_id"] == "heimdall"
|
||||
assert node["profile_loaded"] is True
|
||||
assert node["gpus"][0]["services_assigned"] == ["cf-text"]
|
||||
assert node["gpus"][0]["vram_total_mb"] == 24576
|
||||
assert "cf-text" in node["services_catalog"]
|
||||
|
||||
|
||||
def test_list_nodes_no_profile_returns_profile_loaded_false(client, tmp_path):
|
||||
"""Node with no profile YAML — profile_loaded: false, GPU stats still returned."""
|
||||
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
|
||||
|
||||
coord_nodes = [{
|
||||
"node_id": "sif", "online": True, "agent_url": "http://10.1.10.158:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 5060 Ti", "vram_total_mb": 16384,
|
||||
"vram_used_mb": 0, "vram_free_mb": 16384,
|
||||
"temp_c": None, "utilization_pct": None, "compute_cap": 10.0}],
|
||||
}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
node = data[0]
|
||||
assert node["profile_loaded"] is False
|
||||
assert node["gpus"][0]["card"] == "RTX 5060 Ti"
|
||||
assert node["services_catalog"] == {}
|
||||
|
||||
|
||||
def test_list_nodes_marks_running_services(client, tmp_path):
|
||||
"""services_running populated from coordinator /api/services response."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", {
|
||||
"services": {},
|
||||
"nodes": {"heimdall": {"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": ["cf-text"], "role": "p",
|
||||
"card": "RTX 3090", "always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701"}}
|
||||
})
|
||||
|
||||
coord_nodes = [{"node_id": "heimdall", "online": True,
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
|
||||
"vram_used_mb": 8192, "vram_free_mb": 16384,
|
||||
"temp_c": 55.0, "utilization_pct": 80.0, "compute_cap": 8.6}]}]
|
||||
coord_services = [{"service": "cf-text", "node_id": "heimdall", "gpu_id": 0}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes, coord_services)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
data = r.json()
|
||||
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
|
||||
|
||||
|
||||
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
|
||||
|
||||
def test_get_profile_returns_parsed_yaml(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
profile = {
|
||||
"services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
|
||||
"nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
|
||||
}
|
||||
_write_profile(profiles_dir, "heimdall", profile)
|
||||
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "services" in data
|
||||
assert "cf-text" in data["services"]
|
||||
|
||||
|
||||
def test_get_profile_404_when_missing(client, tmp_path):
|
||||
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
|
||||
r = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
profiles_dir.mkdir()
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
(profiles_dir / "bad.yaml").write_text("key: [unclosed", encoding="utf-8")
|
||||
|
||||
r = client.get("/api/nodes-mgmt/nodes/bad/profile")
|
||||
assert r.status_code == 500
|
||||
|
||||
|
||||
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
|
||||
|
||||
|
||||
_BASE_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1,
|
||||
"catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3",
|
||||
"description": "", "multi_gpu": False, "env": {}}}},
|
||||
"ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": [], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _setup_profile(tmp_path, profile=None):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", profile or _BASE_PROFILE)
|
||||
return profiles_dir
|
||||
|
||||
|
||||
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
|
||||
profiles_dir = _setup_profile(tmp_path)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 200
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is True
|
||||
|
||||
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
|
||||
assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"]
|
||||
|
||||
|
||||
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
|
||||
"""YAML must be written to .tmp then renamed — never written directly."""
|
||||
profiles_dir = _setup_profile(tmp_path)
|
||||
renamed_pairs: list[tuple] = []
|
||||
|
||||
original_replace = _os.replace
|
||||
|
||||
def capture(src, dst):
|
||||
renamed_pairs.append((str(src), str(dst)))
|
||||
original_replace(src, dst)
|
||||
|
||||
with patch("os.replace", side_effect=capture), \
|
||||
patch("httpx.post", return_value=MagicMock(status_code=200)):
|
||||
client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["ollama"]},
|
||||
)
|
||||
|
||||
assert any(src.endswith(".tmp") for src, dst in renamed_pairs), \
|
||||
"Expected atomic write via .tmp rename"
|
||||
|
||||
|
||||
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
|
||||
low_cap_profile = {
|
||||
**_BASE_PROFILE,
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 6.0,
|
||||
"services": [], "role": "p", "card": "GTX 1080",
|
||||
"always_on": False}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
_setup_profile(tmp_path, low_cap_profile)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "compute_cap" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
|
||||
tiny_vram_profile = {
|
||||
**_BASE_PROFILE,
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 512, "compute_cap": 8.6,
|
||||
"services": [], "role": "p", "card": "old",
|
||||
"always_on": False}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
_setup_profile(tmp_path, tiny_vram_profile)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "VRAM" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_unknown_service_returns_422(client, tmp_path):
|
||||
_setup_profile(tmp_path)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["not-a-real-service"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
|
||||
"""YAML saved but coordinator reload fails — ok: true, reloaded: false."""
|
||||
_setup_profile(tmp_path)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 500
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["ollama"]},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is False
|
||||
|
||||
# ── Ollama endpoints ───────────────────────────────────────────────────────────
|
||||
|
||||
_OLLAMA_PROFILE = {
|
||||
"services": {},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_list_ollama_models_proxies_tags(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_tags = MagicMock()
|
||||
mock_tags.raise_for_status = MagicMock()
|
||||
mock_tags.json.return_value = {
|
||||
"models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
|
||||
}
|
||||
|
||||
with patch("httpx.get", return_value=mock_tags):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["models"]) == 1
|
||||
assert data["models"][0]["name"] == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
|
||||
import httpx as _httpx
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
with patch("httpx.get", side_effect=_httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "error" in data
|
||||
|
||||
|
||||
def test_pull_ollama_model_streams_sse(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"status": "pulling manifest"}',
|
||||
'{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
|
||||
'{"status": "success"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
body = r.text
|
||||
assert 'data: {"status": "pulling manifest"}' in body
|
||||
assert 'data: {"status": "success"}' in body
|
||||
|
||||
|
||||
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "permission denied" in r.text
|
||||
|
||||
|
||||
def test_delete_ollama_model_proxies_delete(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 200
|
||||
mock_del.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
|
||||
|
||||
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 404
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model")
|
||||
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── Deploy model endpoint ──────────────────────────────────────────────────────
|
||||
|
||||
_DEPLOY_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {
|
||||
"max_mb": 20000,
|
||||
"min_compute_cap": 7.0,
|
||||
"model_base_path": "/devl/Assets/LLM/cf-text/models",
|
||||
"catalog": {},
|
||||
},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_deploy_model_adds_catalog_entry(client, tmp_path):
|
||||
"""Deploy endpoint should add the model to the service catalog."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 200
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
json={
|
||||
"model_id": "fdtn-ai--Foundation-Sec-8B-Q4",
|
||||
"service_type": "cf-text",
|
||||
"vram_mb": 5180,
|
||||
"hf_repo": "fdtn-ai/Foundation-Sec-8B-Q4_K_M-GGUF",
|
||||
},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is True
|
||||
assert "fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF" in data["path"]
|
||||
|
||||
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
|
||||
catalog = saved["services"]["cf-text"]["catalog"]
|
||||
assert "fdtn-ai--Foundation-Sec-8B-Q4" in catalog
|
||||
entry = catalog["fdtn-ai--Foundation-Sec-8B-Q4"]
|
||||
assert entry["vram_mb"] == 5180
|
||||
assert entry["path"].endswith("fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF")
|
||||
|
||||
|
||||
def test_deploy_model_explicit_path_overrides_base(client, tmp_path):
|
||||
"""An explicit path in the request body takes precedence over model_base_path."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
with patch("httpx.post", return_value=MagicMock(status_code=200)):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
json={
|
||||
"model_id": "my-model",
|
||||
"service_type": "cf-text",
|
||||
"vram_mb": 8000,
|
||||
"path": "/custom/path/to/model",
|
||||
},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json()["path"] == "/custom/path/to/model"
|
||||
|
||||
|
||||
def test_deploy_model_unknown_service_returns_422(client, tmp_path):
|
||||
"""Service type not in profile → 422."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
json={"model_id": "x", "service_type": "vllm", "vram_mb": 8000},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "vllm" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_deploy_model_missing_profile_returns_404(client, tmp_path):
|
||||
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/nonexistent/models/deploy",
|
||||
json={"model_id": "x", "service_type": "cf-text", "vram_mb": 100},
|
||||
)
|
||||
assert r.status_code == 404
|
||||
|
|
@ -1,227 +0,0 @@
|
|||
"""Tests for app/data/recipe_scan.py — recipe scan labeling endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.data import recipe_scan as rs
|
||||
|
||||
|
||||
EXTRACTED = {"title": "Shepherd's Pie", "ingredients": ["lamb", "potato"], "steps": ["brown meat", "mash potato"]}
|
||||
GROUND_TRUTH = {"title": "Shepherd's Pie", "ingredients": ["ground lamb", "mashed potato", "peas"], "steps": ["brown meat", "add veg", "mash potato", "bake"]}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(rs, "_DB_PATH", tmp_path / "recipe_scan.db")
|
||||
rs._init_db()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(rs.router, prefix="/api/recipe-scan")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _item(**kwargs) -> dict:
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"image_path": "/Library/Assets/kiwi/scans/pc_test.jpg",
|
||||
"modality": kwargs.get("modality", "scanner"),
|
||||
"source": kwargs.get("source", "purple_carrot"),
|
||||
"extracted": kwargs.get("extracted", EXTRACTED),
|
||||
"ground_truth": kwargs.get("ground_truth", GROUND_TRUTH),
|
||||
}
|
||||
|
||||
|
||||
def _import(client, items: list[dict]) -> None:
|
||||
resp = client.post("/api/recipe-scan/import", json={"items": items})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── Import ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_import_stores_items(client):
|
||||
_import(client, [_item()])
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 1
|
||||
assert stats["by_status"]["pending"] == 1
|
||||
|
||||
|
||||
def test_import_rejects_unknown_modality(client):
|
||||
bad = _item()
|
||||
bad["modality"] = "telepathy"
|
||||
resp = client.post("/api/recipe-scan/import", json={"items": [bad]})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_import_is_idempotent(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
_import(client, [item]) # same id — should not duplicate
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 1
|
||||
|
||||
|
||||
def test_import_multiple_items(client):
|
||||
_import(client, [_item(), _item(), _item()])
|
||||
assert client.get("/api/recipe-scan/stats").json()["total"] == 3
|
||||
|
||||
|
||||
# ── Next ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_next_returns_404_when_queue_empty(client):
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_next_returns_pending_item(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == item["id"]
|
||||
assert data["status"] == "pending"
|
||||
assert "extracted" in data
|
||||
assert "ground_truth" in data
|
||||
|
||||
|
||||
def test_next_skips_non_pending(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Approve ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_approve_marks_item_approved(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/approve")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "approved"
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["by_status"]["approved"] == 1
|
||||
|
||||
|
||||
def test_approve_returns_404_for_unknown_id(client):
|
||||
resp = client.post("/api/recipe-scan/items/no-such-id/approve")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Edit ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_edit_stores_corrected_json(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
corrected = {**GROUND_TRUTH, "servings": 4}
|
||||
resp = client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/edit",
|
||||
json={"corrected": corrected},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "edited"
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["by_status"]["edited"] == 1
|
||||
|
||||
|
||||
def test_edit_requires_corrected_field(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/edit", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── Reject ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_reject_marks_item_rejected(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/reject",
|
||||
json={"reason": "OCR completely unreadable"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "rejected"
|
||||
|
||||
|
||||
def test_reject_without_reason_is_valid(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_empty_when_nothing_approved(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_includes_approved_item(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/approve")
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
assert len(lines) == 1
|
||||
pair = json.loads(lines[0])
|
||||
assert pair["id"] == item["id"]
|
||||
assert pair["modality"] == "scanner"
|
||||
assert "messages" in pair
|
||||
assert len(pair["messages"]) == 2
|
||||
assert pair["messages"][0]["role"] == "user"
|
||||
assert pair["messages"][1]["role"] == "assistant"
|
||||
|
||||
|
||||
def test_export_includes_edited_item_with_correction(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
corrected = {**GROUND_TRUTH, "servings": 4}
|
||||
client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/edit",
|
||||
json={"corrected": corrected},
|
||||
)
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
pair = json.loads(lines[0])
|
||||
assistant_content = json.loads(pair["messages"][1]["content"])
|
||||
assert assistant_content["servings"] == 4
|
||||
|
||||
|
||||
def test_export_excludes_rejected_items(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
# ── Stats ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_stats_counts_all_statuses(client):
|
||||
items = [_item(), _item(), _item(), _item()]
|
||||
_import(client, items)
|
||||
client.post(f"/api/recipe-scan/items/{items[0]['id']}/approve")
|
||||
client.post(f"/api/recipe-scan/items/{items[1]['id']}/edit", json={"corrected": GROUND_TRUTH})
|
||||
client.post(f"/api/recipe-scan/items/{items[2]['id']}/reject")
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 4
|
||||
assert stats["by_status"]["pending"] == 1
|
||||
assert stats["by_status"]["approved"] == 1
|
||||
assert stats["by_status"]["edited"] == 1
|
||||
assert stats["by_status"]["rejected"] == 1
|
||||
assert stats["export_ready"] == 2 # approved + edited
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
|
||||
"""API integration tests for app/sft.py — /api/sft/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -7,17 +7,14 @@ from pathlib import Path
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_sft_globals(tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
_prev_data = corr_module._DATA_DIR
|
||||
_prev_cfg = corr_module._CONFIG_DIR
|
||||
_prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
from app import sft as sft_module
|
||||
_prev_data = sft_module._SFT_DATA_DIR
|
||||
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
||||
sft_module.set_sft_data_dir(tmp_path)
|
||||
sft_module.set_sft_config_dir(tmp_path)
|
||||
yield
|
||||
corr_module.set_data_dir(_prev_data)
|
||||
corr_module.set_config_dir(_prev_cfg)
|
||||
corr_module.set_default_bench_results_dir(_prev_default)
|
||||
sft_module.set_sft_data_dir(_prev_data)
|
||||
sft_module.set_sft_config_dir(_prev_cfg)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -63,7 +60,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
|
|||
)
|
||||
|
||||
|
||||
# -- /api/sft/runs -------------------------------------------------------------
|
||||
# ── /api/sft/runs ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_runs_returns_empty_when_no_config(client):
|
||||
r = client.get("/api/sft/runs")
|
||||
|
|
@ -86,7 +83,7 @@ def test_runs_returns_available_runs(client, tmp_path):
|
|||
def test_runs_marks_already_imported(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
candidates = sft_module._candidates_file()
|
||||
candidates.parent.mkdir(parents=True, exist_ok=True)
|
||||
candidates.write_text(
|
||||
|
|
@ -97,7 +94,7 @@ def test_runs_marks_already_imported(client, tmp_path):
|
|||
assert r.json()[0]["already_imported"] is True
|
||||
|
||||
|
||||
# -- /api/sft/import -----------------------------------------------------------
|
||||
# ── /api/sft/import ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_import_adds_records(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||
|
|
@ -121,10 +118,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
|
|||
assert r.status_code == 404
|
||||
|
||||
|
||||
# -- /api/sft/queue ------------------------------------------------------------
|
||||
# ── /api/sft/queue ──────────────────────────────────────────────────────────
|
||||
|
||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
path = sft_module._candidates_file()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
|
|
@ -164,7 +161,7 @@ def test_queue_empty_when_no_file(client):
|
|||
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
||||
|
||||
|
||||
# -- /api/sft/submit -----------------------------------------------------------
|
||||
# ── /api/sft/submit ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_submit_correct_sets_approved(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
|
|
@ -173,7 +170,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["status"] == "approved"
|
||||
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
||||
|
|
@ -185,7 +182,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
|
|||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
|
|
@ -196,7 +193,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
||||
|
||||
|
||||
|
|
@ -204,7 +201,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
||||
|
||||
|
||||
|
|
@ -243,7 +240,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
|
|||
"failure_category": "style_violation",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] == "style_violation"
|
||||
|
||||
|
|
@ -255,7 +252,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] is None
|
||||
|
||||
|
|
@ -270,14 +267,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
|||
assert r.status_code == 422
|
||||
|
||||
|
||||
# -- /api/sft/undo -------------------------------------------------------------
|
||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
||||
|
||||
|
||||
|
|
@ -288,7 +285,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
client.post("/api/sft/undo", json={"id": "a"})
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert not any(r["id"] == "a" for r in approved)
|
||||
|
|
@ -300,10 +297,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
|
|||
assert r.status_code == 409
|
||||
|
||||
|
||||
# -- /api/sft/export -----------------------------------------------------------
|
||||
# ── /api/sft/export ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
from app.utils import write_jsonl
|
||||
approved = {
|
||||
**_make_record("a"),
|
||||
|
|
@ -331,7 +328,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
|||
|
||||
|
||||
def test_export_excludes_non_approved(client, tmp_path):
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
||||
|
|
@ -348,10 +345,10 @@ def test_export_empty_when_no_approved_file(client):
|
|||
assert r.text.strip() == ""
|
||||
|
||||
|
||||
# -- /api/sft/stats ------------------------------------------------------------
|
||||
# ── /api/sft/stats ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_stats_counts_by_status(client, tmp_path):
|
||||
from app.data import corrections as sft_module
|
||||
from app import sft as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
_make_record("a"),
|
||||
|
|
|
|||
|
|
@ -1,187 +0,0 @@
|
|||
"""Tests for app/train/train.py -- /api/train/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.train import train as train_module
|
||||
train_module.set_db_path(tmp_path / "train_jobs.db")
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
train_module._running_procs.clear()
|
||||
yield
|
||||
train_module._running_procs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_list_jobs_empty(client):
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"jobs": []}
|
||||
|
||||
|
||||
def test_create_job_returns_queued_record(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "classifier", "model_key": "deberta-small",
|
||||
"config_json": {"epochs": 3}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["status"] == "queued"
|
||||
assert data["type"] == "classifier"
|
||||
assert data["model_key"] == "deberta-small"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
def test_create_job_invalid_type_returns_400(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "unknown-type", "model_key": "deberta-small"})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_job_appears_in_list(client):
|
||||
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()["jobs"]) == 1
|
||||
|
||||
|
||||
def test_get_job_returns_record(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["id"] == job_id
|
||||
|
||||
|
||||
def test_get_job_404_for_unknown(client):
|
||||
r = client.get("/api/train/jobs/no-such-id")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_queued_job(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["status"] == "cancelled"
|
||||
r3 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r3.json()["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_cancel_completed_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.delete("/api/train/jobs/abc/cancel")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_cancel_terminates_running_proc(client):
|
||||
from app.train import train as train_module
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
train_module._running_procs[job_id] = mock_proc
|
||||
with train_module._db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
mock_proc.terminate.assert_called_once()
|
||||
|
||||
|
||||
def test_run_job_streams_sse(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}/run")
|
||||
assert r2.status_code == 200
|
||||
assert "text/event-stream" in r2.headers.get("content-type", "")
|
||||
events = _parse_sse(r2.content)
|
||||
assert any(e["type"] == "complete" for e in events)
|
||||
|
||||
|
||||
def test_run_job_marks_completed_in_db(client):
|
||||
from app.train import train as train_module
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "completed"
|
||||
|
||||
|
||||
def test_run_job_marks_failed_on_nonzero_exit(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_nonqueued_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.get("/api/train/jobs/xyz/run")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_run_unknown_job_returns_404(client):
|
||||
r = client.get("/api/train/jobs/no-such/run")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_results_empty_when_no_models_dir(client):
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"results": []}
|
||||
|
||||
|
||||
def test_results_returns_training_info(client, tmp_path):
|
||||
from app.train import train as train_module
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
|
||||
(models_dir / "training_info.json").write_text(json.dumps(info))
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])
|
||||
4
web/.gitignore
vendored
4
web/.gitignore
vendored
|
|
@ -22,7 +22,3 @@ dist-ssr
|
|||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
|
||||
# Local environment overrides
|
||||
.env
|
||||
|
||||
|
|
|
|||
42
web/package-lock.json
generated
42
web/package-lock.json
generated
|
|
@ -2676,9 +2676,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz",
|
||||
"integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==",
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -2890,9 +2890,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/defu": {
|
||||
"version": "6.1.7",
|
||||
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.7.tgz",
|
||||
"integrity": "sha512-7z22QmUWiQ/2d0KkdYmANbRUVABpZ9SNYyH5vx6PZ+nE5bcC0l7uFvEfHlyld/HcGBFTL536ClDt3DEcSlEJAQ==",
|
||||
"version": "6.1.4",
|
||||
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.4.tgz",
|
||||
"integrity": "sha512-mEQCMmwJu317oSz8CwdIOdwf3xMif1ttiM8LTufzc3g6kR+9Pe236twL8j3IYT1F7GfRgGcW6MWxzZjLIkuHIg==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
|
|
@ -3725,9 +3725,9 @@
|
|||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
|
|
@ -3769,9 +3769,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/postcss": {
|
||||
"version": "8.5.14",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
|
||||
"integrity": "sha512-SoSL4+OSEtR99LHFZQiJLkT59C5B1amGO1NzTwj7TT1qCUgUO6hxOvzkOYxD+vMrXBM3XJIKzokoERdqQq/Zmg==",
|
||||
"version": "8.5.8",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
|
||||
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "opencollective",
|
||||
|
|
@ -4325,9 +4325,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/undici": {
|
||||
"version": "7.25.0",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.25.0.tgz",
|
||||
"integrity": "sha512-xXnp4kTyor2Zq+J1FfPI6Eq3ew5h6Vl0F/8d9XU5zZQf1tX9s2Su1/3PiMmUANFULpmksxkClamIZcaUqryHsQ==",
|
||||
"version": "7.22.0",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.22.0.tgz",
|
||||
"integrity": "sha512-RqslV2Us5BrllB+JeiZnK4peryVTndy9Dnqq62S3yYRRTj0tFQCwEniUy2167skdGOy3vqRzEvl1Dm4sV2ReDg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
|
@ -4422,9 +4422,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/vite": {
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz",
|
||||
"integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==",
|
||||
"version": "7.3.1",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz",
|
||||
"integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -4921,9 +4921,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/yaml": {
|
||||
"version": "2.8.4",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.4.tgz",
|
||||
"integrity": "sha512-ml/JPOj9fOQK8RNnWojA67GbZ0ApXAUlN2UQclwv2eVgTgn7O9gg9o7paZWKMp4g0H3nTLtS9LVzhkpOFIKzog==",
|
||||
"version": "2.8.2",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz",
|
||||
"integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==",
|
||||
"license": "ISC",
|
||||
"bin": {
|
||||
"yaml": "bin.mjs"
|
||||
|
|
|
|||
|
|
@ -1,124 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import AppSidebar from './AppSidebar.vue'
|
||||
|
||||
// Minimal router so RouterLink renders without warnings
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: { template: '<div />' } },
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
{ path: '/data/label', component: { template: '<div />' } },
|
||||
{ path: '/data/fetch', component: { template: '<div />' } },
|
||||
{ path: '/data/corrections', component: { template: '<div />' } },
|
||||
{ path: '/data/imitate', component: { template: '<div />' } },
|
||||
{ path: '/eval/benchmark', component: { template: '<div />' } },
|
||||
{ path: '/eval/compare', component: { template: '<div />' } },
|
||||
{ path: '/train/jobs', component: { template: '<div />' } },
|
||||
{ path: '/train/results', component: { template: '<div />' } },
|
||||
{ path: '/settings', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
function makeFetch(signals: Record<string, boolean> = {}) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
labeled_since_last_eval: 0,
|
||||
last_eval_timestamp: null,
|
||||
last_eval_best_score: null,
|
||||
active_jobs: [],
|
||||
corrections_export_ready: 0,
|
||||
signals,
|
||||
}),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
localStorage.clear()
|
||||
vi.stubGlobal('fetch', makeFetch())
|
||||
})
|
||||
|
||||
describe('AppSidebar structure', () => {
|
||||
it('renders section headers for Data, Eval, Train', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const text = w.text()
|
||||
expect(text).toContain('Data')
|
||||
expect(text).toContain('Eval')
|
||||
expect(text).toContain('Train')
|
||||
})
|
||||
|
||||
it('renders all sub-links', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const anchors = w.findAll('a')
|
||||
const hrefs = anchors.map(a => a.attributes('href') ?? '')
|
||||
expect(hrefs.some(h => h.includes('/data/label'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/fetch'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/corrections'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/imitate'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/eval/benchmark'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/eval/compare'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/train/jobs'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/train/results'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/fleet'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/settings'))).toBe(true)
|
||||
})
|
||||
|
||||
it('does NOT render the old /benchmark or /models links', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const anchors = w.findAll('a')
|
||||
const hrefs = anchors.map(a => a.attributes('href') ?? '')
|
||||
// Old paths must not appear as direct links (they're only redirects)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/benchmark'))).toBe(true)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/models'))).toBe(true)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/stats'))).toBe(true)
|
||||
})
|
||||
|
||||
it('shows no signal badges when all signals are false', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.findAll('.signal-badge').length).toBe(0)
|
||||
})
|
||||
|
||||
it('shows signal badge on Data section when data_to_eval is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: true, eval_to_train: false, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const badges = w.findAll('.signal-badge')
|
||||
expect(badges.length).toBe(1)
|
||||
// It should be inside the Data section header
|
||||
const dataHeader = w.find('[data-section="data"]')
|
||||
expect(dataHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows signal badge on Eval section when eval_to_train is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: true, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalHeader = w.find('[data-section="eval"]')
|
||||
expect(evalHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows signal badge on Train section when train_to_fleet is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: true }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainHeader = w.find('[data-section="train"]')
|
||||
expect(trainHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('stow toggle still works', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const nav = w.find('nav')
|
||||
expect(nav.classes()).not.toContain('stowed')
|
||||
await w.find('.stow-btn').trigger('click')
|
||||
expect(nav.classes()).toContain('stowed')
|
||||
})
|
||||
})
|
||||
|
|
@ -28,70 +28,12 @@
|
|||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Nav -->
|
||||
<!-- Nav items -->
|
||||
<ul class="nav-list" role="list">
|
||||
<!-- Top-level links -->
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Dashboard' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">📊</span>
|
||||
<span v-if="!stowed" class="nav-label">Dashboard</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/fleet"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Fleet' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚡</span>
|
||||
<span v-if="!stowed" class="nav-label">Fleet</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/nodes"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Nodes' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">🖥️</span>
|
||||
<span v-if="!stowed" class="nav-label">Nodes</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ① Data section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="data" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">① Data</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge"
|
||||
title="Enough new labels to run eval"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">①</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Eval recommended"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in dataItems" :key="item.path">
|
||||
<li v-for="item in navItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
class="nav-item"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
|
|
@ -99,94 +41,10 @@
|
|||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ② Eval section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="eval" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">② Eval</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge"
|
||||
title="Strong eval result — consider finetuning"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">②</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Finetune recommended"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in evalItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ③ Train section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="train" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">③ Train</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge"
|
||||
title="Trained model ready for fleet registration"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">③</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Fleet registration recommended"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in trainItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- Divider + Settings -->
|
||||
<li class="nav-divider" aria-hidden="true" />
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/settings"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Settings' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚙️</span>
|
||||
<span v-if="!stowed" class="nav-label">Settings</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
</ul>
|
||||
</nav>
|
||||
|
||||
<!-- Mobile hamburger button — visible when sidebar is stowed on mobile -->
|
||||
<!-- Mobile hamburger button rendered outside the sidebar so it's visible when stowed -->
|
||||
<button
|
||||
v-if="isMobile && stowed"
|
||||
class="mobile-hamburger"
|
||||
|
|
@ -203,68 +61,24 @@ import { RouterLink } from 'vue-router'
|
|||
|
||||
const LS_KEY = 'cf-avocet-nav-stowed'
|
||||
|
||||
interface NavItem {
|
||||
path: string
|
||||
icon: string
|
||||
label: string
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
const dataItems: NavItem[] = [
|
||||
{ path: '/data/label', icon: '🏷', label: 'Label' },
|
||||
{ path: '/data/fetch', icon: '📬', label: 'Fetch' },
|
||||
{ path: '/data/corrections', icon: '✏️', label: 'Corrections' },
|
||||
{ path: '/data/imitate', icon: '🪞', label: 'Imitate' },
|
||||
{ path: '/data/recipe-scan', icon: '📷', label: 'Recipe Scan' },
|
||||
const navItems = [
|
||||
{ path: '/', icon: '🃏', label: 'Label' },
|
||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
||||
{ path: '/models', icon: '🤗', label: 'Models' },
|
||||
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
|
||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||
]
|
||||
|
||||
const evalItems: NavItem[] = [
|
||||
{ path: '/eval/benchmark', icon: '📊', label: 'Benchmark' },
|
||||
{ path: '/eval/compare', icon: '🔍', label: 'Compare' },
|
||||
{ path: '/eval/embed-compare', icon: '🧮', label: 'Embed Compare' },
|
||||
]
|
||||
|
||||
const trainItems: NavItem[] = [
|
||||
{ path: '/train/jobs', icon: '🧠', label: 'Jobs' },
|
||||
{ path: '/train/results', icon: '📈', label: 'Results' },
|
||||
]
|
||||
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
|
||||
const signals = ref<DashboardSignals>({
|
||||
data_to_eval: false,
|
||||
eval_to_train: false,
|
||||
train_to_fleet: false,
|
||||
})
|
||||
|
||||
async function loadSignals() {
|
||||
try {
|
||||
const res = await fetch('/api/dashboard')
|
||||
if (res.ok) {
|
||||
const data = await res.json() as { signals?: DashboardSignals }
|
||||
if (data.signals) {
|
||||
signals.value = {
|
||||
data_to_eval: data.signals.data_to_eval ?? false,
|
||||
eval_to_train: data.signals.eval_to_train ?? false,
|
||||
train_to_fleet: data.signals.train_to_fleet ?? false,
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Non-fatal: badges simply stay hidden if API is unreachable
|
||||
}
|
||||
}
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
|
||||
function toggle() {
|
||||
stowed.value = !stowed.value
|
||||
localStorage.setItem(LS_KEY, String(stowed.value))
|
||||
// Update CSS variable on :root so .app-main margin-left syncs
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
}
|
||||
|
||||
|
|
@ -278,12 +92,13 @@ function onResize() { winWidth.value = window.innerWidth }
|
|||
|
||||
onMounted(() => {
|
||||
window.addEventListener('resize', onResize)
|
||||
// Apply persisted sidebar width to :root on mount
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
// On mobile, default to stowed
|
||||
if (isMobile.value && !localStorage.getItem(LS_KEY)) {
|
||||
stowed.value = true
|
||||
document.documentElement.style.setProperty('--sidebar-width', '56px')
|
||||
}
|
||||
loadSignals()
|
||||
})
|
||||
|
||||
onUnmounted(() => window.removeEventListener('resize', onResize))
|
||||
|
|
@ -305,15 +120,18 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar.stowed { width: 56px; }
|
||||
.sidebar.stowed {
|
||||
width: 56px;
|
||||
}
|
||||
|
||||
/* Mobile: slide in/out from left */
|
||||
.sidebar.mobile {
|
||||
box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.sidebar.mobile.stowed {
|
||||
transform: translateX(-100%);
|
||||
width: 200px;
|
||||
width: 200px; /* keep width so slide-in looks right */
|
||||
transition: transform 250ms ease, width 250ms ease;
|
||||
}
|
||||
|
||||
|
|
@ -346,7 +164,10 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.logo-icon { font-size: 1.25rem; flex-shrink: 0; }
|
||||
.logo-icon {
|
||||
font-size: 1.25rem;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.logo-name {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
|
|
@ -371,76 +192,16 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.stow-btn:hover { background: var(--color-border, #d0d7e8); }
|
||||
.stow-btn:hover {
|
||||
background: var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.nav-list {
|
||||
list-style: none;
|
||||
padding: 0.5rem 0;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
/* ── Section headers ── */
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
padding: 0.55rem 0.75rem 0.25rem;
|
||||
margin-top: 0.5rem;
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.07em;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.section-icon {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
width: 24px;
|
||||
text-align: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* ── Signal badges ── */
|
||||
.signal-badge {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: var(--color-warning, #d4891a);
|
||||
flex-shrink: 0;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.signal-badge-stowed {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
}
|
||||
|
||||
/* Make the stowed section header container position:relative for the badge */
|
||||
.sidebar.stowed .section-header {
|
||||
position: relative;
|
||||
justify-content: center;
|
||||
padding: 0.55rem 0 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Nav divider ── */
|
||||
.nav-divider {
|
||||
height: 1px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
margin: 0.5rem 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Nav items ── */
|
||||
.nav-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
|
@ -476,9 +237,6 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
border-radius: 0 2px 2px 0;
|
||||
}
|
||||
|
||||
/* Sub-items are indented slightly in expanded state */
|
||||
.nav-subitem { padding-left: 1.1rem; font-size: 0.875rem; }
|
||||
|
||||
.nav-icon {
|
||||
font-size: 1.1rem;
|
||||
flex-shrink: 0;
|
||||
|
|
@ -486,9 +244,12 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
text-align: center;
|
||||
}
|
||||
|
||||
.nav-label { overflow: hidden; text-overflow: ellipsis; }
|
||||
.nav-label {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
/* Mobile hamburger */
|
||||
/* Mobile hamburger — visible when sidebar is stowed on mobile */
|
||||
.mobile-hamburger {
|
||||
position: fixed;
|
||||
top: 0.75rem;
|
||||
|
|
|
|||
|
|
@ -1,170 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import type { CatalogEntryFull } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
svcName: string
|
||||
modelName?: string
|
||||
entry?: CatalogEntryFull
|
||||
}>()
|
||||
const emit = defineEmits<{
|
||||
save: [svcName: string, modelName: string, entry: CatalogEntryFull]
|
||||
cancel: []
|
||||
}>()
|
||||
|
||||
const name = ref(props.modelName ?? '')
|
||||
const path = ref(props.entry?.path ?? '')
|
||||
const vramMb = ref(props.entry?.vram_mb ?? 0)
|
||||
const description = ref(props.entry?.description ?? '')
|
||||
const multiGpu = ref(props.entry?.multi_gpu ?? false)
|
||||
const envPairs = ref<{ k: string; v: string }[]>(
|
||||
Object.entries(props.entry?.env ?? {}).map(([k, v]) => ({ k, v }))
|
||||
)
|
||||
const formError = ref('')
|
||||
|
||||
watch(() => props.entry, (e) => {
|
||||
name.value = props.modelName ?? ''
|
||||
path.value = e?.path ?? ''
|
||||
vramMb.value = e?.vram_mb ?? 0
|
||||
description.value = e?.description ?? ''
|
||||
multiGpu.value = e?.multi_gpu ?? false
|
||||
envPairs.value = Object.entries(e?.env ?? {}).map(([k, v]) => ({ k, v }))
|
||||
})
|
||||
|
||||
function addEnvPair() {
|
||||
envPairs.value = [...envPairs.value, { k: '', v: '' }]
|
||||
}
|
||||
function removeEnvPair(i: number) {
|
||||
envPairs.value = envPairs.value.filter((_, idx) => idx !== i)
|
||||
}
|
||||
|
||||
function submit() {
|
||||
formError.value = ''
|
||||
if (!name.value.trim()) { formError.value = 'Model name is required.'; return }
|
||||
if (!path.value.trim()) { formError.value = 'Path is required.'; return }
|
||||
if (!vramMb.value || vramMb.value < 0) { formError.value = 'vram_mb must be a positive number.'; return }
|
||||
|
||||
const envObj: Record<string, string> = {}
|
||||
for (const { k, v } of envPairs.value) {
|
||||
if (k.trim()) envObj[k.trim()] = v
|
||||
}
|
||||
|
||||
const entry: CatalogEntryFull = { path: path.value.trim(), vram_mb: vramMb.value }
|
||||
if (description.value.trim()) entry.description = description.value.trim()
|
||||
if (multiGpu.value) entry.multi_gpu = true
|
||||
if (Object.keys(envObj).length) entry.env = envObj
|
||||
|
||||
emit('save', props.svcName, name.value.trim(), entry)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${modelName ? 'Edit' : 'Add'} catalog entry`">
|
||||
<div class="modal-box">
|
||||
<h3 class="modal-title">{{ modelName ? 'Edit' : 'Add' }} Catalog Entry — {{ svcName }}</h3>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-name">Model name</label>
|
||||
<input id="ce-name" v-model="name" class="field-input" :readonly="!!modelName" placeholder="deepseek-r1-7b" />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-path">Path</label>
|
||||
<input id="ce-path" v-model="path" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models/..." />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-vram">VRAM (MB)</label>
|
||||
<input id="ce-vram" v-model.number="vramMb" type="number" min="0" class="field-input field-input--sm" />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-desc">Description</label>
|
||||
<input id="ce-desc" v-model="description" class="field-input" placeholder="Short description" />
|
||||
</div>
|
||||
<div class="field-row field-row--check">
|
||||
<input id="ce-mgpu" v-model="multiGpu" type="checkbox" />
|
||||
<label for="ce-mgpu">Multi-GPU span</label>
|
||||
</div>
|
||||
|
||||
<div class="env-section">
|
||||
<div class="env-header">
|
||||
<span class="field-label">Env vars</span>
|
||||
<button type="button" class="btn-link" @click="addEnvPair">+ Add</button>
|
||||
</div>
|
||||
<div v-for="(pair, i) in envPairs" :key="i" class="env-row">
|
||||
<input v-model="pair.k" class="field-input field-input--sm" placeholder="CF_TEXT_4BIT" />
|
||||
<span>=</span>
|
||||
<input v-model="pair.v" class="field-input field-input--sm" placeholder="1" />
|
||||
<button type="button" class="btn-icon" @click="removeEnvPair(i)" aria-label="Remove">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
|
||||
<button class="btn-primary" @click="submit">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.modal-backdrop {
|
||||
position: fixed; inset: 0;
|
||||
background: rgba(0,0,0,0.5);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
z-index: 200;
|
||||
}
|
||||
.modal-box {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
width: 100%; max-width: 500px;
|
||||
max-height: 90vh; overflow-y: auto;
|
||||
display: flex; flex-direction: column; gap: 0.75rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.field-row { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.field-row--check { gap: 0.4rem; color: var(--color-text); }
|
||||
.field-label { min-width: 8rem; font-size: 0.85rem; color: var(--color-text-muted); }
|
||||
.field-input {
|
||||
flex: 1;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.field-input--sm { flex: 0 0 8rem; }
|
||||
.env-section { display: flex; flex-direction: column; gap: 0.35rem; }
|
||||
.env-header { display: flex; align-items: center; justify-content: space-between; }
|
||||
.env-row { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.btn-link { background: none; border: none; color: var(--app-primary); cursor: pointer; font-size: 0.8rem; padding: 0; }
|
||||
.btn-link:hover { color: var(--app-primary-hover); }
|
||||
.btn-icon { background: none; border: none; color: var(--color-text-muted); cursor: pointer; padding: 0 0.2rem; font-size: 0.85rem; }
|
||||
.btn-icon:hover { color: var(--color-error); }
|
||||
.form-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
|
||||
.btn-primary {
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 1rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--app-primary-hover); }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.75rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
</style>
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import ServiceBadge from './ServiceBadge.vue'
|
||||
import type { GpuEntry, ServiceInfo } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
gpu: GpuEntry
|
||||
nodeId: string
|
||||
profileLoaded: boolean
|
||||
servicesCatalog: Record<string, ServiceInfo>
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const saving = ref(false)
|
||||
const saveError = ref('')
|
||||
|
||||
const vramPct = computed(() => {
|
||||
if (!props.gpu.vram_total_mb) return 0
|
||||
return Math.round((props.gpu.vram_used_mb / props.gpu.vram_total_mb) * 100)
|
||||
})
|
||||
|
||||
function serviceState(svcName: string): 'running' | 'stopped' | 'assigned-only' | 'available' | 'incompatible' | 'unknown' {
|
||||
const svc = props.servicesCatalog[svcName]
|
||||
if (!svc) return 'unknown'
|
||||
const cap = props.gpu.compute_cap ?? 0
|
||||
if (cap < svc.min_compute_cap) return 'incompatible'
|
||||
if (props.gpu.services_running.includes(svcName)) return 'running'
|
||||
if (props.gpu.services_assigned.includes(svcName)) return 'assigned-only'
|
||||
return 'available'
|
||||
}
|
||||
|
||||
async function toggleService(svcName: string) {
|
||||
if (!props.profileLoaded || saving.value) return
|
||||
const current = [...props.gpu.services_assigned]
|
||||
const removing = current.includes(svcName)
|
||||
if (removing && !confirm(`Remove ${svcName} from GPU ${props.gpu.gpu_id}?`)) return
|
||||
const next = removing ? current.filter(s => s !== svcName) : [...current, svcName]
|
||||
|
||||
saving.value = true
|
||||
saveError.value = ''
|
||||
try {
|
||||
const r = await fetch(
|
||||
`/api/nodes-mgmt/nodes/${props.nodeId}/gpu/${props.gpu.gpu_id}/services`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ services: next }),
|
||||
},
|
||||
)
|
||||
if (!r.ok) {
|
||||
const data = await r.json().catch(() => ({}))
|
||||
throw new Error((data as { detail?: string }).detail ?? `HTTP ${r.status}`)
|
||||
}
|
||||
const data = await r.json() as { ok: boolean; reloaded: boolean; warnings: string[] }
|
||||
if (data.warnings?.length) saveError.value = `Saved (warning: ${data.warnings.join(', ')})`
|
||||
emit('updated')
|
||||
} catch (e) {
|
||||
saveError.value = e instanceof Error ? e.message : 'Failed to update services'
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="gpu-row">
|
||||
<div class="gpu-info">
|
||||
<span class="gpu-label">GPU {{ gpu.gpu_id }}: {{ gpu.card }}</span>
|
||||
<span v-if="gpu.compute_cap != null" class="gpu-meta">sm{{ gpu.compute_cap }}</span>
|
||||
<span v-if="gpu.temp_c != null" class="gpu-meta">{{ gpu.temp_c }}°C</span>
|
||||
<span v-if="gpu.utilization_pct != null" class="gpu-meta">{{ gpu.utilization_pct }}%</span>
|
||||
</div>
|
||||
|
||||
<div class="vram-wrap">
|
||||
<div
|
||||
class="vram-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="gpu.vram_used_mb"
|
||||
aria-valuemin="0"
|
||||
:aria-valuemax="gpu.vram_total_mb"
|
||||
:aria-label="`VRAM: ${gpu.vram_used_mb} of ${gpu.vram_total_mb} MB used`"
|
||||
>
|
||||
<div class="vram-fill" :style="{ width: `${vramPct}%` }" />
|
||||
</div>
|
||||
<span class="vram-text">{{ gpu.vram_used_mb }} / {{ gpu.vram_total_mb }} MB ({{ vramPct }}%)</span>
|
||||
</div>
|
||||
|
||||
<div v-if="profileLoaded" class="services-row" aria-label="Service assignments">
|
||||
<ServiceBadge
|
||||
v-for="(_, svcName) in servicesCatalog"
|
||||
:key="String(svcName)"
|
||||
:service-name="String(svcName)"
|
||||
:state="serviceState(String(svcName))"
|
||||
:assigned="gpu.services_assigned.includes(String(svcName))"
|
||||
:disabled="saving"
|
||||
@toggle="toggleService(String(svcName))"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="saveError" class="save-msg" role="alert">{{ saveError }}</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.gpu-row {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-radius: 4px;
|
||||
background: var(--color-surface-alt);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
.gpu-info { display: flex; gap: 0.75rem; align-items: center; flex-wrap: wrap; font-size: 0.875rem; }
|
||||
.gpu-label { font-weight: 500; color: var(--color-text); }
|
||||
.gpu-meta { color: var(--color-text-muted); font-size: 0.8rem; }
|
||||
.vram-wrap { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.vram-bar {
|
||||
flex: 1;
|
||||
height: 8px;
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.vram-fill { height: 100%; background: var(--app-primary); transition: width 0.3s; }
|
||||
.vram-text { font-size: 0.75rem; color: var(--color-text-muted); white-space: nowrap; }
|
||||
.services-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
|
||||
.save-msg { color: var(--color-warning); font-size: 0.8rem; }
|
||||
</style>
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
interface CatalogEntry {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description: string
|
||||
multi_gpu: boolean
|
||||
}
|
||||
|
||||
interface ServiceProfile {
|
||||
catalog: Record<string, CatalogEntry>
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
}
|
||||
|
||||
interface NodeProfile {
|
||||
services: Record<string, ServiceProfile>
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
}>()
|
||||
|
||||
const profile = ref<NodeProfile | null>(null)
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
async function fetchProfile() {
|
||||
fetchAbort?.abort()
|
||||
fetchAbort = new AbortController()
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
|
||||
signal: fetchAbort.signal,
|
||||
})
|
||||
if (r.status === 404) { profile.value = null; return }
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
profile.value = await r.json() as NodeProfile
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
error.value = e instanceof Error ? e.message : 'Failed to load profile'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(fetchProfile)
|
||||
onUnmounted(() => { fetchAbort?.abort() })
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="hf-panel">
|
||||
<h3 class="panel-title">Model Catalog</h3>
|
||||
<p class="hf-hint">
|
||||
To download a new HuggingFace model,
|
||||
<a href="#/fleet" class="hf-link">go to Fleet</a>.
|
||||
Models downloaded there are automatically registered in node catalogs.
|
||||
</p>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading catalog...</span>
|
||||
</div>
|
||||
<div v-if="error" class="panel-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && !profile" class="panel-empty">No profile loaded for this node.</div>
|
||||
<div v-else-if="!loading && profile" class="catalog-body">
|
||||
<div
|
||||
v-for="(svcInfo, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-section"
|
||||
>
|
||||
<h4 class="svc-name">{{ svcName }}</h4>
|
||||
<ul class="catalog-list" role="list">
|
||||
<li
|
||||
v-if="!Object.keys(svcInfo.catalog ?? {}).length"
|
||||
class="catalog-empty"
|
||||
>
|
||||
No models in catalog.
|
||||
</li>
|
||||
<li
|
||||
v-for="(entry, modelName) in (svcInfo.catalog ?? {})"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.hf-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 6px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.panel-title { margin: 0 0 0.5rem; font-size: 0.9rem; color: var(--color-text); }
|
||||
.hf-hint { font-size: 0.8rem; color: var(--color-text-muted); margin: 0 0 0.75rem; }
|
||||
.hf-link { color: var(--app-primary); }
|
||||
.hf-link:hover { color: var(--app-primary-hover); }
|
||||
.svc-section { margin-bottom: 0.75rem; }
|
||||
.svc-name {
|
||||
margin: 0 0 0.25rem;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.2rem; }
|
||||
.catalog-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.catalog-model { font-family: var(--font-mono, monospace); flex: 1; }
|
||||
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--color-text-muted); font-size: 0.75rem; flex: 2; }
|
||||
.catalog-empty, .panel-empty { color: var(--color-text-muted); font-size: 0.875rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
</style>
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import GpuRow from './GpuRow.vue'
|
||||
import OllamaModelPanel from './OllamaModelPanel.vue'
|
||||
import ProfileEditorPanel from './ProfileEditorPanel.vue'
|
||||
import type { NodeSummary, FullProfile } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{ node: NodeSummary }>()
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const showOllama = ref(false)
|
||||
const showEditor = ref(false)
|
||||
const loadedProfile = ref<FullProfile | null>(null)
|
||||
const profileLoading = ref(false)
|
||||
const profileError = ref('')
|
||||
|
||||
async function openEditor() {
|
||||
if (showEditor.value) { showEditor.value = false; return }
|
||||
profileLoading.value = true
|
||||
profileError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.node.node_id}/profile`)
|
||||
if (r.status === 404) {
|
||||
loadedProfile.value = null
|
||||
} else if (!r.ok) {
|
||||
throw new Error(`HTTP ${r.status}`)
|
||||
} else {
|
||||
loadedProfile.value = await r.json() as FullProfile
|
||||
}
|
||||
showEditor.value = true
|
||||
} catch (e) {
|
||||
profileError.value = e instanceof Error ? e.message : 'Failed to load profile'
|
||||
} finally {
|
||||
profileLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function onProfileSaved() {
|
||||
showEditor.value = false
|
||||
emit('updated')
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="node-card" :class="{ offline: !node.online }">
|
||||
<header class="node-card-header">
|
||||
<div class="node-identity">
|
||||
<span
|
||||
class="status-dot"
|
||||
:class="node.online ? 'online' : 'offline'"
|
||||
:aria-label="node.online ? 'Online' : 'Offline'"
|
||||
role="img"
|
||||
/>
|
||||
<h2 class="node-name">{{ node.node_id }}</h2>
|
||||
<span class="node-agent">{{ node.agent_url }}</span>
|
||||
</div>
|
||||
<div class="node-actions">
|
||||
<button
|
||||
v-if="node.profile_loaded"
|
||||
class="btn-secondary btn-sm"
|
||||
@click="showOllama = !showOllama"
|
||||
>
|
||||
{{ showOllama ? 'Hide Ollama' : 'Ollama' }}
|
||||
</button>
|
||||
<button
|
||||
class="btn-secondary btn-sm"
|
||||
:disabled="profileLoading"
|
||||
@click="openEditor"
|
||||
>
|
||||
{{ profileLoading ? 'Loading…' : node.profile_loaded ? (showEditor ? 'Close Editor' : 'Edit Profile') : 'Create Profile' }}
|
||||
</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div v-if="!node.profile_loaded" class="no-profile" role="status">
|
||||
No profile configured for this node. GPU stats are visible; service assignment is disabled.
|
||||
</div>
|
||||
|
||||
<div class="gpu-list">
|
||||
<GpuRow
|
||||
v-for="gpu in node.gpus"
|
||||
:key="gpu.gpu_id"
|
||||
:gpu="gpu"
|
||||
:node-id="node.node_id"
|
||||
:profile-loaded="node.profile_loaded"
|
||||
:services-catalog="node.services_catalog"
|
||||
@updated="emit('updated')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<OllamaModelPanel v-if="showOllama" :node-id="node.node_id" />
|
||||
<div v-if="profileError" class="profile-load-error" role="alert">{{ profileError }}</div>
|
||||
<ProfileEditorPanel
|
||||
v-if="showEditor"
|
||||
:node-id="node.node_id"
|
||||
:initial-profile="loadedProfile"
|
||||
@saved="onProfileSaved"
|
||||
@close="showEditor = false"
|
||||
/>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.node-card {
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
.node-card.offline { opacity: 0.65; }
|
||||
.node-card-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.node-identity { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; }
|
||||
.node-name { margin: 0; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.node-agent { color: var(--color-text-muted); font-size: 0.8rem; font-family: var(--font-mono, monospace); }
|
||||
.status-dot { width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0; }
|
||||
.status-dot.online { background: var(--color-success); }
|
||||
.status-dot.offline { background: var(--color-warning); }
|
||||
.node-actions { display: flex; gap: 0.5rem; flex-shrink: 0; }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.65rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-sm { font-size: 0.8rem; padding: 0.25rem 0.6rem; }
|
||||
.no-profile {
|
||||
padding: 0.6rem 0.75rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text-muted);
|
||||
font-size: 0.875rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.gpu-list { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
.profile-load-error { color: var(--color-error); font-size: 0.8rem; margin-top: 0.5rem; }
|
||||
</style>
|
||||
|
|
@ -1,242 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const props = defineProps<{ nodeId: string }>()
|
||||
|
||||
interface OllamaModel {
|
||||
name: string
|
||||
size: number
|
||||
modified_at: string
|
||||
}
|
||||
|
||||
const models = ref<OllamaModel[]>([])
|
||||
const loading = ref(true)
|
||||
const loadError = ref('')
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
|
||||
// AbortController for the SSE pull stream
|
||||
const abortCtrl = ref<AbortController | null>(null)
|
||||
|
||||
// AbortController for the one-shot fetchModels request
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
async function fetchModels() {
|
||||
fetchAbort?.abort()
|
||||
fetchAbort = new AbortController()
|
||||
loading.value = true
|
||||
loadError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`, {
|
||||
signal: fetchAbort.signal,
|
||||
})
|
||||
const data = await r.json() as { models?: OllamaModel[]; error?: string }
|
||||
if (data.error) { loadError.value = data.error; return }
|
||||
models.value = data.models ?? []
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
loadError.value = e instanceof Error ? e.message : 'Failed to load models'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function doPull() {
|
||||
const name = pullName.value.trim()
|
||||
if (!name || pulling.value) return
|
||||
pulling.value = true
|
||||
pullStatus.value = 'Starting...'
|
||||
pullError.value = ''
|
||||
pullPct.value = 0
|
||||
|
||||
const ctrl = new AbortController()
|
||||
abortCtrl.value = ctrl
|
||||
|
||||
try {
|
||||
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name }),
|
||||
signal: ctrl.signal,
|
||||
})
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
if (!resp.body) throw new Error('No response body')
|
||||
|
||||
const reader = resp.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buf = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += decoder.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue
|
||||
try {
|
||||
const evt = JSON.parse(line.slice(6)) as {
|
||||
status?: string; error?: string; total?: number; completed?: number
|
||||
}
|
||||
if (evt.error) {
|
||||
pullError.value = evt.error
|
||||
break
|
||||
}
|
||||
if (evt.status) pullStatus.value = evt.status
|
||||
if (evt.total && evt.completed) {
|
||||
pullPct.value = Math.round((evt.completed / evt.total) * 100)
|
||||
}
|
||||
if (evt.status === 'success') {
|
||||
pullStatus.value = 'Done!'
|
||||
pullName.value = ''
|
||||
break
|
||||
}
|
||||
} catch { /* skip malformed line */ }
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh model list after the stream closes (success or benign end)
|
||||
await fetchModels()
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
pullError.value = e instanceof Error ? e.message : 'Pull failed'
|
||||
} finally {
|
||||
pulling.value = false
|
||||
abortCtrl.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteModel(name: string) {
|
||||
if (!confirm(`Delete model "${name}" from node ${props.nodeId}?`)) return
|
||||
try {
|
||||
const r = await fetch(
|
||||
`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/${encodeURIComponent(name)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
await fetchModels()
|
||||
} catch (e) {
|
||||
loadError.value = e instanceof Error ? e.message : 'Delete failed'
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes: number): string {
|
||||
return (bytes / 1e9).toFixed(1) + ' GB'
|
||||
}
|
||||
|
||||
onMounted(fetchModels)
|
||||
onUnmounted(() => {
|
||||
abortCtrl.value?.abort()
|
||||
fetchAbort?.abort()
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="ollama-panel">
|
||||
<h3 class="panel-title">Ollama Models</h3>
|
||||
|
||||
<form class="pull-form" @submit.prevent="doPull">
|
||||
<input
|
||||
v-model="pullName"
|
||||
type="text"
|
||||
placeholder="nomic-embed-text, llama3.2:3b, ..."
|
||||
:disabled="pulling"
|
||||
aria-label="Model name to pull from Ollama"
|
||||
class="pull-input"
|
||||
/>
|
||||
<button type="submit" :disabled="pulling || !pullName.trim()" class="btn-primary btn-sm">
|
||||
{{ pulling ? 'Pulling...' : 'Pull' }}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div v-if="pulling || pullStatus" class="pull-progress" aria-live="polite">
|
||||
<div
|
||||
class="progress-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="pullPct"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
:aria-label="`Pull progress: ${pullStatus}`"
|
||||
>
|
||||
<div class="progress-fill" :style="{ width: `${pullPct}%` }" />
|
||||
</div>
|
||||
<span class="progress-label">{{ pullStatus }}{{ pullPct > 0 ? ` (${pullPct}%)` : '' }}</span>
|
||||
</div>
|
||||
|
||||
<div v-if="pullError" class="pull-error" role="alert">
|
||||
{{ pullError }}
|
||||
<span v-if="pullError.includes('permission denied')">
|
||||
— Remove the partial file on the node and retry.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading...</span>
|
||||
</div>
|
||||
<div v-if="loadError" class="panel-error" role="alert">{{ loadError }}</div>
|
||||
<ul v-if="!loading && !loadError" class="model-list" role="list">
|
||||
<li v-if="!models.length" class="model-empty">No Ollama models installed on this node.</li>
|
||||
<li v-for="m in models" :key="m.name" class="model-item">
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span class="model-size">{{ formatSize(m.size) }}</span>
|
||||
<button
|
||||
class="btn-danger btn-xs"
|
||||
@click="deleteModel(m.name)"
|
||||
:aria-label="`Delete ${m.name}`"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.ollama-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 6px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.panel-title { margin: 0 0 0.75rem; font-size: 0.9rem; color: var(--color-text); }
|
||||
.pull-form { display: flex; gap: 0.5rem; margin-bottom: 0.5rem; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.pull-progress { margin-bottom: 0.5rem; }
|
||||
.progress-bar {
|
||||
height: 8px;
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
.progress-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
|
||||
.progress-label { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.pull-error, .panel-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-loading { color: var(--color-text-muted); font-size: 0.875rem; }
|
||||
.model-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.model-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.model-name { flex: 1; font-family: var(--font-mono, monospace); }
|
||||
.model-size { color: var(--color-text-muted); font-size: 0.8rem; }
|
||||
.model-empty { color: var(--color-text-muted); font-size: 0.875rem; padding: 0.25rem 0; }
|
||||
</style>
|
||||
|
|
@ -1,597 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import type { FullProfile, ServiceDefinition, CatalogEntryFull } from '../../types/nodes'
|
||||
import ServiceFormModal from './ServiceFormModal.vue'
|
||||
import CatalogEntryFormModal from './CatalogEntryFormModal.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
initialProfile: FullProfile | null
|
||||
}>()
|
||||
const emit = defineEmits<{ saved: []; close: [] }>()
|
||||
|
||||
// Deep-clone initial profile so edits don't mutate the parent's data
|
||||
const profile = ref<FullProfile>(
|
||||
props.initialProfile
|
||||
? JSON.parse(JSON.stringify(props.initialProfile))
|
||||
: { services: {}, nodes: {} }
|
||||
)
|
||||
|
||||
const saving = ref(false)
|
||||
const generating = ref(false)
|
||||
const opError = ref('')
|
||||
const expandedSvcs = ref<Set<string>>(new Set())
|
||||
|
||||
// Service modal
|
||||
const showSvcModal = ref(false)
|
||||
const editingSvcName = ref<string | undefined>()
|
||||
const editingSvcDef = ref<ServiceDefinition | undefined>()
|
||||
|
||||
// Catalog modal
|
||||
const showCatalogModal = ref(false)
|
||||
const catalogTargetSvc = ref('')
|
||||
const editingModelName = ref<string | undefined>()
|
||||
const editingEntry = ref<CatalogEntryFull | undefined>()
|
||||
|
||||
// ── Generate nodes section from coordinator ────────────────────────────────────
|
||||
|
||||
async function generate() {
|
||||
generating.value = true
|
||||
opError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile/generate`, { method: 'POST' })
|
||||
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
|
||||
const generated = await r.json() as FullProfile
|
||||
// Merge: keep current services edits, replace nodes section
|
||||
profile.value = { ...generated, services: profile.value.services }
|
||||
} catch (e) {
|
||||
opError.value = e instanceof Error ? e.message : 'Generate failed'
|
||||
} finally {
|
||||
generating.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Save full profile ──────────────────────────────────────────────────────────
|
||||
|
||||
async function save() {
|
||||
saving.value = true
|
||||
opError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ profile: profile.value }),
|
||||
})
|
||||
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
|
||||
emit('saved')
|
||||
} catch (e) {
|
||||
opError.value = e instanceof Error ? e.message : 'Save failed'
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Service CRUD ───────────────────────────────────────────────────────────────
|
||||
|
||||
function openAddService() {
|
||||
editingSvcName.value = undefined
|
||||
editingSvcDef.value = undefined
|
||||
showSvcModal.value = true
|
||||
}
|
||||
|
||||
function openEditService(name: string) {
|
||||
editingSvcName.value = name
|
||||
editingSvcDef.value = JSON.parse(JSON.stringify(profile.value.services[name]))
|
||||
showSvcModal.value = true
|
||||
}
|
||||
|
||||
function onServiceSaved(name: string, def: ServiceDefinition) {
|
||||
profile.value = { ...profile.value, services: { ...profile.value.services, [name]: def } }
|
||||
expandedSvcs.value = new Set([...expandedSvcs.value, name])
|
||||
showSvcModal.value = false
|
||||
}
|
||||
|
||||
function deleteService(name: string) {
|
||||
if (!confirm(`Remove service "${name}" from this profile?`)) return
|
||||
const svcs = { ...profile.value.services }
|
||||
delete svcs[name]
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
expandedSvcs.value = new Set([...expandedSvcs.value].filter(s => s !== name))
|
||||
}
|
||||
|
||||
function toggleSvc(name: string) {
|
||||
const s = new Set(expandedSvcs.value)
|
||||
s.has(name) ? s.delete(name) : s.add(name)
|
||||
expandedSvcs.value = s
|
||||
}
|
||||
|
||||
// ── Catalog CRUD ───────────────────────────────────────────────────────────────
|
||||
|
||||
function openAddCatalogEntry(svcName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = undefined
|
||||
editingEntry.value = undefined
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function openEditCatalogEntry(svcName: string, modelName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = modelName
|
||||
editingEntry.value = JSON.parse(JSON.stringify(profile.value.services[svcName].catalog![modelName]))
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function onCatalogEntrySaved(svcName: string, modelName: string, entry: CatalogEntryFull) {
|
||||
const svcs = { ...profile.value.services }
|
||||
const svc = { ...svcs[svcName], catalog: { ...(svcs[svcName].catalog ?? {}), [modelName]: entry } }
|
||||
svcs[svcName] = svc
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
showCatalogModal.value = false
|
||||
}
|
||||
|
||||
function deleteCatalogEntry(svcName: string, modelName: string) {
|
||||
if (!confirm(`Remove model "${modelName}" from ${svcName} catalog?`)) return
|
||||
const svcs = { ...profile.value.services }
|
||||
const catalog = { ...(svcs[svcName].catalog ?? {}) }
|
||||
delete catalog[modelName]
|
||||
svcs[svcName] = { ...svcs[svcName], catalog }
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
}
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
function gpuList() {
|
||||
return (profile.value.nodes[props.nodeId]?.gpus ?? [])
|
||||
}
|
||||
|
||||
function serviceCount() {
|
||||
return Object.keys(profile.value.services).length
|
||||
}
|
||||
|
||||
// ── Ollama model suggestions ───────────────────────────────────────────────────
|
||||
|
||||
interface OllamaModel { name: string; size: number }
|
||||
const ollamaModels = ref<OllamaModel[]>([])
|
||||
const ollamaLoading = ref(false)
|
||||
|
||||
onMounted(async () => {
|
||||
ollamaLoading.value = true
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
|
||||
if (r.ok) {
|
||||
const d = await r.json() as { models?: OllamaModel[] }
|
||||
ollamaModels.value = d.models ?? []
|
||||
}
|
||||
} catch { /* Ollama offline — silently skip */ }
|
||||
finally { ollamaLoading.value = false }
|
||||
})
|
||||
|
||||
function ollamaNotInCatalog(svcName: string): OllamaModel[] {
|
||||
const catalog = profile.value.services[svcName]?.catalog ?? {}
|
||||
return ollamaModels.value.filter(m => !(m.name in catalog))
|
||||
}
|
||||
|
||||
function openAddFromOllama(svcName: string, modelName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = modelName
|
||||
editingEntry.value = {
|
||||
path: profile.value.services[svcName]?.model_base_path
|
||||
? `${profile.value.services[svcName].model_base_path}/${modelName}`
|
||||
: '',
|
||||
vram_mb: 0,
|
||||
}
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function formatMb(bytes: number): string {
|
||||
return bytes >= 1_000_000_000
|
||||
? `${(bytes / 1_073_741_824).toFixed(1)} GB`
|
||||
: `${Math.round(bytes / 1_048_576)} MB`
|
||||
}
|
||||
|
||||
// ── Pull model onto node ───────────────────────────────────────────────────────
|
||||
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
let pullAbort: AbortController | null = null
|
||||
|
||||
async function doPull() {
|
||||
const name = pullName.value.trim()
|
||||
if (!name || pulling.value) return
|
||||
pulling.value = true
|
||||
pullStatus.value = 'Starting…'
|
||||
pullError.value = ''
|
||||
pullPct.value = 0
|
||||
pullAbort?.abort()
|
||||
pullAbort = new AbortController()
|
||||
|
||||
try {
|
||||
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name }),
|
||||
signal: pullAbort.signal,
|
||||
})
|
||||
if (!resp.ok || !resp.body) {
|
||||
pullError.value = `HTTP ${resp.status}`
|
||||
return
|
||||
}
|
||||
const reader = resp.body.getReader()
|
||||
const dec = new TextDecoder()
|
||||
let buf = ''
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += dec.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data:')) continue
|
||||
try {
|
||||
const d = JSON.parse(line.slice(5)) as {
|
||||
status?: string; completed?: number; total?: number; error?: string; done?: boolean
|
||||
}
|
||||
if (d.error) { pullError.value = d.error; return }
|
||||
pullStatus.value = d.status ?? ''
|
||||
if (d.total && d.total > 0) pullPct.value = Math.round((d.completed ?? 0) / d.total * 100)
|
||||
if (d.done) {
|
||||
pullName.value = ''
|
||||
pullPct.value = 100
|
||||
// Refresh Ollama model list so new model appears in suggest chips
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
|
||||
if (r.ok) { const d2 = await r.json() as { models?: OllamaModel[] }; ollamaModels.value = d2.models ?? [] }
|
||||
}
|
||||
} catch { /* skip malformed SSE line */ }
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name !== 'AbortError') pullError.value = e.message
|
||||
} finally {
|
||||
pulling.value = false
|
||||
if (pullPct.value === 100) setTimeout(() => { pullStatus.value = ''; pullPct.value = 0 }, 2000)
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="pep" aria-label="Profile editor">
|
||||
<!-- Header -->
|
||||
<div class="pep-header">
|
||||
<div class="pep-title-row">
|
||||
<h3 class="pep-title">Profile — {{ nodeId }}</h3>
|
||||
<span class="pep-svc-count">{{ serviceCount() }} service{{ serviceCount() === 1 ? '' : 's' }}</span>
|
||||
</div>
|
||||
<div class="pep-actions">
|
||||
<button class="btn-secondary btn-sm" :disabled="generating" @click="generate">
|
||||
{{ generating ? 'Refreshing…' : 'Refresh Hardware' }}
|
||||
</button>
|
||||
<button class="btn-primary btn-sm" :disabled="saving" @click="save">
|
||||
{{ saving ? 'Saving…' : 'Save Profile' }}
|
||||
</button>
|
||||
<button class="btn-icon-lg" aria-label="Close editor" @click="emit('close')">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="opError" class="pep-error" role="alert">{{ opError }}</div>
|
||||
|
||||
<!-- Meta fields -->
|
||||
<div class="pep-meta">
|
||||
<label class="meta-label" for="pep-vram">vram_total_mb</label>
|
||||
<input id="pep-vram" v-model.number="profile.vram_total_mb" type="number" min="0" class="meta-input" />
|
||||
<label class="meta-label" for="pep-evict">eviction_timeout_s</label>
|
||||
<input id="pep-evict" v-model.number="profile.eviction_timeout_s" type="number" min="0" step="0.5" class="meta-input" />
|
||||
</div>
|
||||
|
||||
<!-- Hardware summary -->
|
||||
<div v-if="gpuList().length" class="hw-section">
|
||||
<span class="hw-label">Hardware</span>
|
||||
<span v-for="g in gpuList()" :key="g.id" class="hw-gpu">
|
||||
GPU {{ g.id }}: {{ g.card || 'unknown' }} · {{ g.vram_mb }} MB · sm{{ g.compute_cap ?? '?' }}
|
||||
</span>
|
||||
<span v-if="!gpuList().length" class="hw-none">No hardware data — click Refresh Hardware.</span>
|
||||
</div>
|
||||
<div v-else class="hw-section">
|
||||
<span class="hw-none">No hardware data — click Refresh Hardware to seed from coordinator.</span>
|
||||
</div>
|
||||
|
||||
<!-- Services -->
|
||||
<div class="svcs-header">
|
||||
<span class="svcs-title">Services</span>
|
||||
<button class="btn-secondary btn-sm" @click="openAddService">+ Add Service</button>
|
||||
</div>
|
||||
|
||||
<div v-if="serviceCount() === 0" class="svcs-empty">
|
||||
No services defined. Add a service to configure what can run on this node.
|
||||
</div>
|
||||
|
||||
<ul class="svcs-list" role="list">
|
||||
<li
|
||||
v-for="(def, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-item"
|
||||
>
|
||||
<!-- Service row header -->
|
||||
<div class="svc-row">
|
||||
<button
|
||||
class="svc-toggle"
|
||||
:aria-expanded="expandedSvcs.has(String(svcName))"
|
||||
@click="toggleSvc(String(svcName))"
|
||||
>
|
||||
<span class="svc-arrow">{{ expandedSvcs.has(String(svcName)) ? '▾' : '▸' }}</span>
|
||||
<span class="svc-name">{{ svcName }}</span>
|
||||
</button>
|
||||
<span class="svc-badges">
|
||||
<span class="badge">{{ def.max_mb }} MB</span>
|
||||
<span class="badge">p{{ def.priority }}</span>
|
||||
<span v-if="def.shared" class="badge badge--blue">shared</span>
|
||||
<span v-if="def.managed" class="badge badge--dim">managed</span>
|
||||
<span v-if="def.catalog" class="badge badge--dim">{{ Object.keys(def.catalog).length }} models</span>
|
||||
</span>
|
||||
<div class="svc-btns">
|
||||
<button class="btn-secondary btn-xs" @click="openEditService(String(svcName))">Edit</button>
|
||||
<button class="btn-danger btn-xs" @click="deleteService(String(svcName))">Delete</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Expanded catalog -->
|
||||
<div v-if="expandedSvcs.has(String(svcName))" class="svc-detail">
|
||||
<div class="svc-detail-meta">
|
||||
<span v-if="def.min_compute_cap">min sm{{ def.min_compute_cap }}</span>
|
||||
<span v-if="def.max_concurrent">max_concurrent: {{ def.max_concurrent }}</span>
|
||||
<span v-if="def.idle_stop_after_s">idle_stop: {{ def.idle_stop_after_s }}s</span>
|
||||
<span v-if="def.always_on" class="badge badge--blue">always_on</span>
|
||||
</div>
|
||||
|
||||
<!-- Ollama model suggestions + pull -->
|
||||
<div class="ollama-suggest">
|
||||
<div class="suggest-row">
|
||||
<span class="suggest-label">On node (Ollama):</span>
|
||||
<span v-if="ollamaLoading" class="suggest-loading">loading…</span>
|
||||
<template v-else-if="ollamaNotInCatalog(String(svcName)).length">
|
||||
<button
|
||||
v-for="m in ollamaNotInCatalog(String(svcName))"
|
||||
:key="m.name"
|
||||
class="suggest-chip"
|
||||
:title="`Add ${m.name} (${formatMb(m.size)}) to this service catalog`"
|
||||
@click="openAddFromOllama(String(svcName), m.name)"
|
||||
>
|
||||
+ {{ m.name }} <span class="chip-size">{{ formatMb(m.size) }}</span>
|
||||
</button>
|
||||
</template>
|
||||
<span v-else-if="!ollamaLoading" class="suggest-none">All Ollama models already in catalog.</span>
|
||||
</div>
|
||||
|
||||
<!-- Pull model onto this node -->
|
||||
<div class="pull-row">
|
||||
<input
|
||||
v-model="pullName"
|
||||
class="pull-input"
|
||||
placeholder="Pull model on node (e.g. llama3:8b)"
|
||||
:disabled="pulling"
|
||||
@keyup.enter="doPull"
|
||||
/>
|
||||
<button class="btn-pull" :disabled="pulling || !pullName.trim()" @click="doPull">
|
||||
{{ pulling ? 'Pulling…' : 'Pull' }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="pulling || pullPct > 0" class="pull-progress">
|
||||
<div class="pull-bar"><div class="pull-fill" :style="{ width: pullPct + '%' }" /></div>
|
||||
<span class="pull-status">{{ pullStatus }}</span>
|
||||
</div>
|
||||
<div v-if="pullError" class="pull-err" role="alert">{{ pullError }}</div>
|
||||
</div>
|
||||
|
||||
<div class="catalog-header">
|
||||
<span class="catalog-title">Catalog</span>
|
||||
<button class="btn-link" @click="openAddCatalogEntry(String(svcName))">+ Add Model</button>
|
||||
</div>
|
||||
|
||||
<div v-if="!def.catalog || !Object.keys(def.catalog).length" class="catalog-empty">
|
||||
No catalog entries. Only services like cf-text need a catalog.
|
||||
</div>
|
||||
<ul v-else class="catalog-list" role="list">
|
||||
<li
|
||||
v-for="(entry, modelName) in def.catalog"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.multi_gpu" class="badge badge--dim">multi-gpu</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
<div class="catalog-btns">
|
||||
<button class="btn-secondary btn-xs" @click="openEditCatalogEntry(String(svcName), String(modelName))">Edit</button>
|
||||
<button class="btn-danger btn-xs" @click="deleteCatalogEntry(String(svcName), String(modelName))">✕</button>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<!-- Service form modal -->
|
||||
<ServiceFormModal
|
||||
v-if="showSvcModal"
|
||||
:service-name="editingSvcName"
|
||||
:definition="editingSvcDef"
|
||||
@save="onServiceSaved"
|
||||
@cancel="showSvcModal = false"
|
||||
/>
|
||||
|
||||
<!-- Catalog entry form modal -->
|
||||
<CatalogEntryFormModal
|
||||
v-if="showCatalogModal"
|
||||
:svc-name="catalogTargetSvc"
|
||||
:model-name="editingModelName"
|
||||
:entry="editingEntry"
|
||||
@save="onCatalogEntrySaved"
|
||||
@cancel="showCatalogModal = false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.pep {
|
||||
margin-top: 0.75rem;
|
||||
padding: 1rem;
|
||||
border: 1px solid var(--color-primary);
|
||||
border-radius: 6px;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
.pep-header {
|
||||
display: flex; align-items: center; justify-content: space-between; gap: 0.5rem;
|
||||
margin-bottom: 0.75rem; flex-wrap: wrap;
|
||||
}
|
||||
.pep-title-row { display: flex; align-items: baseline; gap: 0.5rem; }
|
||||
.pep-title { margin: 0; font-size: 0.95rem; font-weight: 600; color: var(--color-text); }
|
||||
.pep-svc-count { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.pep-actions { display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap; }
|
||||
.pep-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.pep-meta {
|
||||
display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap;
|
||||
padding: 0.5rem; background: var(--color-surface-alt); border-radius: 4px; margin-bottom: 0.75rem;
|
||||
}
|
||||
.meta-label { font-size: 0.8rem; color: var(--color-text-muted); }
|
||||
.meta-input {
|
||||
width: 7rem; background: var(--color-surface); border: 1px solid var(--color-border);
|
||||
border-radius: 4px; padding: 0.2rem 0.4rem; color: var(--color-text); font-size: 0.8rem;
|
||||
}
|
||||
.hw-section {
|
||||
display: flex; flex-wrap: wrap; align-items: center; gap: 0.5rem;
|
||||
font-size: 0.8rem; color: var(--color-text-muted);
|
||||
padding: 0.4rem 0.5rem; border-radius: 4px; background: var(--color-surface-alt);
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.hw-label { font-weight: 600; color: var(--color-text); }
|
||||
.hw-gpu { font-family: monospace; color: var(--color-text); }
|
||||
.hw-none { font-style: italic; }
|
||||
.svcs-header {
|
||||
display: flex; align-items: center; justify-content: space-between;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.svcs-title { font-size: 0.85rem; font-weight: 600; color: var(--color-text); }
|
||||
.svcs-empty { color: var(--color-text-muted); font-size: 0.85rem; padding: 0.5rem 0; }
|
||||
.svcs-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.4rem; }
|
||||
.svc-item { border: 1px solid var(--color-border); border-radius: 4px; overflow: hidden; }
|
||||
.svc-row {
|
||||
display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0.5rem;
|
||||
background: var(--color-surface-alt); flex-wrap: wrap;
|
||||
}
|
||||
.svc-toggle {
|
||||
display: flex; align-items: center; gap: 0.35rem;
|
||||
background: none; border: none; cursor: pointer; color: var(--color-text); padding: 0; flex: 1; min-width: 0;
|
||||
}
|
||||
.svc-arrow { font-size: 0.7rem; color: var(--color-text-muted); }
|
||||
.svc-name { font-size: 0.875rem; font-weight: 500; font-family: monospace; }
|
||||
.svc-badges { display: flex; gap: 0.3rem; flex-wrap: wrap; }
|
||||
.svc-btns { display: flex; gap: 0.3rem; margin-left: auto; }
|
||||
.svc-detail { padding: 0.5rem 0.75rem; display: flex; flex-direction: column; gap: 0.5rem; background: var(--color-surface-raised); }
|
||||
.svc-detail-meta {
|
||||
display: flex; gap: 0.5rem; flex-wrap: wrap;
|
||||
font-size: 0.78rem; color: var(--color-text-muted);
|
||||
}
|
||||
.ollama-suggest {
|
||||
display: flex; flex-direction: column; gap: 0.35rem;
|
||||
padding: 0.4rem 0.5rem;
|
||||
background: var(--color-primary-light);
|
||||
border: 1px solid var(--color-border-light);
|
||||
border-radius: 4px;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
.suggest-row { display: flex; flex-wrap: wrap; align-items: center; gap: 0.35rem; }
|
||||
.suggest-label { color: var(--color-text-muted); font-weight: 500; white-space: nowrap; }
|
||||
.suggest-loading { color: var(--color-text-muted); font-style: italic; }
|
||||
.suggest-none { color: var(--color-text-muted); font-style: italic; }
|
||||
.suggest-chip {
|
||||
display: inline-flex; align-items: center; gap: 0.25rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 3px;
|
||||
color: var(--color-text);
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
transition: border-color 0.15s, background 0.15s;
|
||||
}
|
||||
.suggest-chip:hover { border-color: var(--app-primary); background: var(--color-surface-alt); }
|
||||
.chip-size { color: var(--color-text-muted); font-size: 0.72rem; }
|
||||
.pull-row { display: flex; gap: 0.4rem; align-items: center; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text);
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
.pull-input:disabled { opacity: 0.5; }
|
||||
.btn-pull {
|
||||
padding: 0.25rem 0.6rem;
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.btn-pull:hover:not(:disabled) { background: var(--app-primary-hover); }
|
||||
.btn-pull:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.pull-progress { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.pull-bar {
|
||||
flex: 1; height: 6px;
|
||||
background: var(--color-border);
|
||||
border-radius: 3px; overflow: hidden;
|
||||
}
|
||||
.pull-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
|
||||
.pull-status { color: var(--color-text-muted); font-size: 0.72rem; white-space: nowrap; max-width: 14rem; overflow: hidden; text-overflow: ellipsis; }
|
||||
.pull-err { color: var(--color-error); font-size: 0.75rem; }
|
||||
.catalog-header { display: flex; align-items: center; justify-content: space-between; }
|
||||
.catalog-title { font-size: 0.8rem; font-weight: 600; color: var(--color-text-muted); text-transform: uppercase; letter-spacing: 0.05em; }
|
||||
.catalog-empty { font-size: 0.8rem; color: var(--color-text-muted); font-style: italic; }
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.25rem; }
|
||||
.catalog-item {
|
||||
display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap;
|
||||
padding: 0.25rem 0.5rem; background: var(--color-surface-alt); border-radius: 3px; font-size: 0.8rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.catalog-model { font-family: monospace; flex: 1; min-width: 12rem; }
|
||||
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--color-text-muted); flex: 2; font-size: 0.75rem; }
|
||||
.catalog-btns { display: flex; gap: 0.25rem; margin-left: auto; }
|
||||
.badge {
|
||||
padding: 0.1rem 0.4rem; border-radius: 3px; font-size: 0.72rem;
|
||||
background: var(--color-surface); border: 1px solid var(--color-border); color: var(--color-text);
|
||||
}
|
||||
.badge--blue { border-color: var(--color-primary); color: var(--color-primary); background: var(--color-primary-light); }
|
||||
.badge--dim { opacity: 0.75; }
|
||||
.btn-link { background: none; border: none; color: var(--color-accent); cursor: pointer; font-size: 0.8rem; padding: 0; }
|
||||
.btn-link:hover { color: var(--color-accent-hover); }
|
||||
.btn-primary {
|
||||
background: var(--color-primary); color: var(--color-text-inverse); border: none;
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--color-primary-hover); }
|
||||
.btn-primary:disabled { opacity: 0.6; cursor: not-allowed; }
|
||||
.btn-secondary {
|
||||
background: transparent; border: 1px solid var(--color-border); color: var(--color-text);
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
.btn-secondary:disabled { opacity: 0.6; cursor: not-allowed; }
|
||||
.btn-danger {
|
||||
background: transparent; border: 1px solid var(--color-error); color: var(--color-error);
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-danger:hover { background: var(--color-surface-alt); }
|
||||
.btn-sm { padding: 0.3rem 0.6rem; }
|
||||
.btn-xs { padding: 0.15rem 0.4rem; }
|
||||
.btn-icon-lg { background: none; border: none; color: var(--color-text-muted); cursor: pointer; font-size: 1rem; padding: 0.2rem 0.3rem; }
|
||||
.btn-icon-lg:hover { color: var(--color-text); }
|
||||
</style>
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
type ServiceState =
|
||||
| 'running'
|
||||
| 'stopped'
|
||||
| 'assigned-only'
|
||||
| 'available'
|
||||
| 'incompatible'
|
||||
| 'vram-tight'
|
||||
| 'unknown'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName: string
|
||||
state: ServiceState
|
||||
assigned: boolean
|
||||
disabled?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ toggle: [] }>()
|
||||
|
||||
const STATE_LABELS: Record<ServiceState, string> = {
|
||||
running: 'Running',
|
||||
stopped: 'Stopped',
|
||||
'assigned-only': 'Assigned',
|
||||
available: 'Available',
|
||||
incompatible: 'Incompatible',
|
||||
'vram-tight': 'VRAM tight',
|
||||
unknown: 'Unknown',
|
||||
}
|
||||
|
||||
const STATE_ICONS: Record<ServiceState, string> = {
|
||||
running: '▶',
|
||||
stopped: '⏹',
|
||||
'assigned-only': '📌',
|
||||
available: '○',
|
||||
incompatible: '✕',
|
||||
'vram-tight': '⚠',
|
||||
unknown: '?',
|
||||
}
|
||||
|
||||
function handleToggle() {
|
||||
if (!props.disabled && props.state !== 'incompatible') emit('toggle')
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<button
|
||||
class="service-badge"
|
||||
:class="[`state-${state}`, { assigned, 'is-disabled': disabled || state === 'incompatible' }]"
|
||||
:aria-pressed="assigned"
|
||||
:aria-label="`${serviceName}: ${STATE_LABELS[state] ?? state}${assigned ? ' (assigned)' : ''}`"
|
||||
:disabled="disabled || state === 'incompatible'"
|
||||
@click="handleToggle"
|
||||
>
|
||||
<span class="badge-icon" aria-hidden="true">{{ STATE_ICONS[state] ?? '?' }}</span>
|
||||
<span class="badge-name">{{ serviceName }}</span>
|
||||
<span class="badge-state">{{ STATE_LABELS[state] ?? state }}</span>
|
||||
</button>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.service-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--color-border);
|
||||
background: var(--color-surface);
|
||||
color: var(--color-text);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.1s, border-color 0.1s;
|
||||
}
|
||||
.service-badge:hover:not(.is-disabled) { opacity: 0.8; }
|
||||
.service-badge.is-disabled { cursor: not-allowed; opacity: 0.5; }
|
||||
.service-badge.state-running { border-color: var(--color-success); }
|
||||
.service-badge.state-stopped { border-color: var(--color-warning); }
|
||||
.service-badge.state-assigned-only { border-color: var(--color-info); }
|
||||
.service-badge.state-incompatible { border-color: var(--color-error); }
|
||||
.service-badge.state-vram-tight { border-color: var(--color-warning); }
|
||||
.badge-state { color: var(--color-text-muted); }
|
||||
</style>
|
||||
|
|
@ -1,231 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import type { ServiceDefinition } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName?: string
|
||||
definition?: ServiceDefinition
|
||||
}>()
|
||||
const emit = defineEmits<{
|
||||
save: [name: string, def: ServiceDefinition]
|
||||
cancel: []
|
||||
}>()
|
||||
|
||||
const name = ref(props.serviceName ?? '')
|
||||
const maxMb = ref(props.definition?.max_mb ?? 0)
|
||||
const priority = ref(props.definition?.priority ?? 1)
|
||||
const minCap = ref(props.definition?.min_compute_cap ?? 0)
|
||||
const prefCap = ref<number | ''>(props.definition?.preferred_compute_cap ?? '')
|
||||
const shared = ref(props.definition?.shared ?? false)
|
||||
const maxConcurrent = ref<number | ''>(props.definition?.max_concurrent ?? '')
|
||||
const idleStop = ref<number | ''>(props.definition?.idle_stop_after_s ?? '')
|
||||
const alwaysOn = ref(props.definition?.always_on ?? false)
|
||||
const modelBasePath = ref(props.definition?.model_base_path ?? '')
|
||||
const hasManaged = ref(!!props.definition?.managed)
|
||||
const managedJson = ref(
|
||||
props.definition?.managed ? JSON.stringify(props.definition.managed, null, 2) : ''
|
||||
)
|
||||
const formError = ref('')
|
||||
|
||||
watch(() => props.definition, (d) => {
|
||||
name.value = props.serviceName ?? ''
|
||||
maxMb.value = d?.max_mb ?? 0
|
||||
priority.value = d?.priority ?? 1
|
||||
minCap.value = d?.min_compute_cap ?? 0
|
||||
prefCap.value = d?.preferred_compute_cap ?? ''
|
||||
shared.value = d?.shared ?? false
|
||||
maxConcurrent.value = d?.max_concurrent ?? ''
|
||||
idleStop.value = d?.idle_stop_after_s ?? ''
|
||||
alwaysOn.value = d?.always_on ?? false
|
||||
modelBasePath.value = d?.model_base_path ?? ''
|
||||
hasManaged.value = !!d?.managed
|
||||
managedJson.value = d?.managed ? JSON.stringify(d.managed, null, 2) : ''
|
||||
})
|
||||
|
||||
const managedJsonError = computed(() => {
|
||||
if (!hasManaged.value || !managedJson.value.trim()) return ''
|
||||
try { JSON.parse(managedJson.value); return '' }
|
||||
catch { return 'Invalid JSON' }
|
||||
})
|
||||
|
||||
function submit() {
|
||||
formError.value = ''
|
||||
if (!name.value.trim()) { formError.value = 'Service name is required.'; return }
|
||||
if (!maxMb.value || maxMb.value <= 0) { formError.value = 'max_mb must be > 0.'; return }
|
||||
if (managedJsonError.value) { formError.value = 'Fix the managed JSON before saving.'; return }
|
||||
|
||||
const def: ServiceDefinition = { max_mb: maxMb.value, priority: priority.value }
|
||||
if (minCap.value) def.min_compute_cap = minCap.value
|
||||
if (prefCap.value !== '') def.preferred_compute_cap = Number(prefCap.value)
|
||||
if (shared.value) def.shared = true
|
||||
if (maxConcurrent.value !== '') def.max_concurrent = Number(maxConcurrent.value)
|
||||
if (idleStop.value !== '') def.idle_stop_after_s = Number(idleStop.value)
|
||||
if (alwaysOn.value) def.always_on = true
|
||||
if (modelBasePath.value.trim()) def.model_base_path = modelBasePath.value.trim()
|
||||
if (hasManaged.value && managedJson.value.trim()) {
|
||||
def.managed = JSON.parse(managedJson.value)
|
||||
}
|
||||
// Preserve existing catalog when editing
|
||||
if (props.definition?.catalog) def.catalog = props.definition.catalog
|
||||
|
||||
emit('save', name.value.trim(), def)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${serviceName ? 'Edit' : 'Add'} service`">
|
||||
<div class="modal-box">
|
||||
<h3 class="modal-title">{{ serviceName ? 'Edit' : 'Add' }} Service</h3>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-name">Service name</label>
|
||||
<input id="sf-name" v-model="name" class="field-input" :readonly="!!serviceName" placeholder="cf-text" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-maxmb">max_mb</label>
|
||||
<input id="sf-maxmb" v-model.number="maxMb" type="number" min="0" class="field-input field-input--sm" />
|
||||
<span class="field-hint">VRAM ceiling</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-prio">priority</label>
|
||||
<input id="sf-prio" v-model.number="priority" type="number" min="1" max="10" class="field-input field-input--sm" />
|
||||
<span class="field-hint">1 = highest</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-mincap">min_compute_cap</label>
|
||||
<input id="sf-mincap" v-model.number="minCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="0.0" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-prefcap">preferred_cap</label>
|
||||
<input id="sf-prefcap" v-model="prefCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="optional" />
|
||||
</div>
|
||||
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-shared" v-model="shared" type="checkbox" />
|
||||
<label for="sf-shared">shared (multiple concurrent users)</label>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-maxcon">max_concurrent</label>
|
||||
<input id="sf-maxcon" v-model="maxConcurrent" type="number" min="1" class="field-input field-input--sm" placeholder="optional" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-idle">idle_stop_after_s</label>
|
||||
<input id="sf-idle" v-model="idleStop" type="number" min="0" class="field-input field-input--sm" placeholder="optional" />
|
||||
<span class="field-hint">seconds</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-always" v-model="alwaysOn" type="checkbox" />
|
||||
<label for="sf-always">always_on (never evict)</label>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-base">model_base_path</label>
|
||||
<input id="sf-base" v-model="modelBasePath" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models (optional)" />
|
||||
</div>
|
||||
|
||||
<div class="managed-section">
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-has-managed" v-model="hasManaged" type="checkbox" />
|
||||
<label for="sf-has-managed">Has managed process config</label>
|
||||
</div>
|
||||
<div v-if="hasManaged" class="managed-body">
|
||||
<label class="field-label" for="sf-managed">managed (JSON)</label>
|
||||
<textarea
|
||||
id="sf-managed"
|
||||
v-model="managedJson"
|
||||
class="field-textarea"
|
||||
rows="6"
|
||||
spellcheck="false"
|
||||
placeholder='{"type": "process", "exec_path": "...", "args_template": "...", "port": 8008, "host_port": 8008}'
|
||||
/>
|
||||
<span v-if="managedJsonError" class="json-error" role="alert">{{ managedJsonError }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
|
||||
<button class="btn-primary" @click="submit">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.modal-backdrop {
|
||||
position: fixed; inset: 0;
|
||||
background: rgba(0,0,0,0.5);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
z-index: 200;
|
||||
}
|
||||
.modal-box {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
width: 100%; max-width: 540px;
|
||||
max-height: 90vh; overflow-y: auto;
|
||||
display: flex; flex-direction: column; gap: 0.65rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.field-row { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.field-row--check { gap: 0.4rem; font-size: 0.875rem; color: var(--color-text); }
|
||||
.field-label { min-width: 9rem; font-size: 0.85rem; color: var(--color-text-muted); flex-shrink: 0; }
|
||||
.field-hint { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.field-input {
|
||||
flex: 1;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.field-input--sm { flex: 0 0 8rem; }
|
||||
.managed-section { display: flex; flex-direction: column; gap: 0.4rem; border-top: 1px solid var(--color-border); padding-top: 0.5rem; }
|
||||
.managed-body { display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.field-textarea {
|
||||
width: 100%;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
resize: vertical;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.json-error { color: var(--color-error); font-size: 0.78rem; }
|
||||
.form-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
|
||||
.btn-primary {
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 1rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--app-primary-hover); }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.75rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
</style>
|
||||
|
|
@ -1,53 +1,23 @@
|
|||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import LabelView from '../views/LabelView.vue'
|
||||
|
||||
// Lazy-loaded views
|
||||
const DashboardView = () => import('../views/DashboardView.vue')
|
||||
const LabelView = () => import('../views/LabelView.vue')
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||
const ImitateView = () => import('../views/ImitateView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const CompareView = () => import('../views/CompareView.vue')
|
||||
const TrainJobsView = () => import('../views/TrainJobsView.vue')
|
||||
const TrainResultsView = () => import('../views/TrainResultsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const NodeManagementView = () => import('../views/NodeManagementView.vue')
|
||||
|
||||
export const routes = [
|
||||
// ── Top-level ────────────────────────────────────────────
|
||||
{ path: '/', component: DashboardView, meta: { title: 'Dashboard' } },
|
||||
{ path: '/fleet', component: ModelsView, meta: { title: 'Fleet' } },
|
||||
{ path: '/nodes', component: NodeManagementView, meta: { title: 'Nodes' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
|
||||
// ── Data domain ──────────────────────────────────────────
|
||||
{ path: '/data/label', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/data/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/data/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/data/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||
{ path: '/data/recipe-scan', component: () => import('../views/RecipeScanView.vue'), meta: { title: 'Recipe Scan' } },
|
||||
|
||||
// ── Eval domain ──────────────────────────────────────────
|
||||
{ path: '/eval/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/eval/compare', component: CompareView, meta: { title: 'Compare' } },
|
||||
{ path: '/eval/embed-compare', component: () => import('../views/EmbedCompareView.vue'), meta: { title: 'Embed Compare' } },
|
||||
|
||||
// ── Train domain ─────────────────────────────────────────
|
||||
{ path: '/train/jobs', component: TrainJobsView, meta: { title: 'Training Jobs' } },
|
||||
{ path: '/train/results', component: TrainResultsView, meta: { title: 'Training Results' } },
|
||||
|
||||
// ── Backward-compat redirects ────────────────────────────
|
||||
{ path: '/benchmark', redirect: '/eval/benchmark' },
|
||||
{ path: '/models', redirect: '/fleet' },
|
||||
{ path: '/stats', redirect: '/' },
|
||||
{ path: '/label', redirect: '/data/label' },
|
||||
{ path: '/fetch', redirect: '/data/fetch' },
|
||||
{ path: '/corrections', redirect: '/data/corrections' },
|
||||
{ path: '/imitate', redirect: '/data/imitate' },
|
||||
]
|
||||
// Views are lazy-loaded to keep initial bundle small
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const StatsView = () => import('../views/StatsView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes,
|
||||
routes: [
|
||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
|
||||
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
],
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
import { describe, it, expect } from 'vitest'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
|
||||
// Import the raw routes array so we can test structure without mounting App
|
||||
import { routes } from './index'
|
||||
|
||||
describe('router routes', () => {
|
||||
it('exports a routes array', () => {
|
||||
expect(Array.isArray(routes)).toBe(true)
|
||||
})
|
||||
|
||||
it('has / pointing to DashboardView', () => {
|
||||
const root = routes.find(r => r.path === '/')
|
||||
expect(root).toBeDefined()
|
||||
// Component should be async (lazy) or have a name
|
||||
expect(root?.component).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /fleet route', () => {
|
||||
const r = routes.find(r => r.path === '/fleet')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/label route', () => {
|
||||
const r = routes.find(r => r.path === '/data/label')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/fetch route', () => {
|
||||
const r = routes.find(r => r.path === '/data/fetch')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/corrections route', () => {
|
||||
const r = routes.find(r => r.path === '/data/corrections')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/imitate route', () => {
|
||||
const r = routes.find(r => r.path === '/data/imitate')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /eval/benchmark route', () => {
|
||||
const r = routes.find(r => r.path === '/eval/benchmark')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /eval/compare route', () => {
|
||||
const r = routes.find(r => r.path === '/eval/compare')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /train/jobs route', () => {
|
||||
const r = routes.find(r => r.path === '/train/jobs')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /train/results route', () => {
|
||||
const r = routes.find(r => r.path === '/train/results')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /settings route', () => {
|
||||
const r = routes.find(r => r.path === '/settings')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /benchmark to /eval/benchmark', () => {
|
||||
const r = routes.find(r => r.path === '/benchmark')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/eval/benchmark')
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /models to /fleet', () => {
|
||||
const r = routes.find(r => r.path === '/models')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/fleet')
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /stats to /', () => {
|
||||
const r = routes.find(r => r.path === '/stats')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/')
|
||||
})
|
||||
|
||||
it('can create a functional router instance', () => {
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes,
|
||||
})
|
||||
expect(router).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
export interface GpuEntry {
|
||||
gpu_id: number
|
||||
card: string
|
||||
vram_total_mb: number
|
||||
vram_used_mb: number
|
||||
vram_free_mb: number
|
||||
temp_c: number | null
|
||||
utilization_pct: number | null
|
||||
compute_cap: number | null
|
||||
services_assigned: string[]
|
||||
services_running: string[]
|
||||
}
|
||||
|
||||
export interface ServiceInfo {
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
catalog_size: number
|
||||
}
|
||||
|
||||
export interface NodeSummary {
|
||||
node_id: string
|
||||
online: boolean
|
||||
agent_url: string
|
||||
gpus: GpuEntry[]
|
||||
profile_loaded: boolean
|
||||
services_catalog: Record<string, ServiceInfo>
|
||||
}
|
||||
|
||||
// ── Full profile types (for profile editor) ────────────────────────────────────
|
||||
|
||||
export interface ServiceManaged {
|
||||
type: string
|
||||
exec_path?: string
|
||||
args_template?: string
|
||||
port?: number
|
||||
host_port?: number
|
||||
base_port?: number
|
||||
health_path?: string
|
||||
cwd?: string
|
||||
adopt?: boolean
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export interface CatalogEntryFull {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description?: string
|
||||
multi_gpu?: boolean
|
||||
env?: Record<string, string>
|
||||
}
|
||||
|
||||
export interface ServiceDefinition {
|
||||
max_mb: number
|
||||
priority: number
|
||||
min_compute_cap?: number
|
||||
preferred_compute_cap?: number
|
||||
shared?: boolean
|
||||
max_concurrent?: number
|
||||
idle_stop_after_s?: number
|
||||
always_on?: boolean
|
||||
model_base_path?: string
|
||||
managed?: ServiceManaged
|
||||
catalog?: Record<string, CatalogEntryFull>
|
||||
}
|
||||
|
||||
export interface NodeHardwareGpu {
|
||||
id: number
|
||||
vram_mb: number
|
||||
compute_cap?: number
|
||||
card?: string
|
||||
role?: string
|
||||
services?: string[]
|
||||
}
|
||||
|
||||
export interface NodeHardwareEntry {
|
||||
local_model_root?: string
|
||||
agent_url?: string
|
||||
gpus: NodeHardwareGpu[]
|
||||
}
|
||||
|
||||
export interface FullProfile {
|
||||
schema_version?: number
|
||||
name?: string
|
||||
vram_total_mb?: number
|
||||
eviction_timeout_s?: number
|
||||
services: Record<string, ServiceDefinition>
|
||||
nodes: Record<string, NodeHardwareEntry>
|
||||
model_size_hints?: Record<string, string>
|
||||
}
|
||||
|
|
@ -1,987 +0,0 @@
|
|||
<template>
|
||||
<div class="assignments-tab">
|
||||
|
||||
<!-- ── Toast ───────────────────────────────────────────── -->
|
||||
<div v-if="toast" class="toast" :class="toast.type" role="status" aria-live="polite">
|
||||
{{ toast.message }}
|
||||
</div>
|
||||
|
||||
<!-- ── Assignments section ─────────────────────────────── -->
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Task Assignments</h2>
|
||||
<button class="btn-primary btn-sm" @click="openNewAssignment">+ New Assignment</button>
|
||||
</div>
|
||||
|
||||
<div class="filter-row">
|
||||
<label for="product-filter" class="filter-label">Product</label>
|
||||
<select id="product-filter" v-model="productFilter" class="filter-select">
|
||||
<option value="">All products</option>
|
||||
<option v-for="p in allProducts" :key="p" :value="p">{{ p }}</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div v-if="assignmentsLoading" class="empty-state">Loading assignments…</div>
|
||||
<div v-else-if="assignmentsError" class="error-notice" role="alert">{{ assignmentsError }}</div>
|
||||
<div v-else-if="filteredGroups.length === 0" class="empty-state">No assignments yet. Add one above.</div>
|
||||
<div v-else class="product-groups">
|
||||
<div v-for="group in filteredGroups" :key="group.product" class="product-group">
|
||||
<h3 class="product-name">{{ group.product.toUpperCase() }}</h3>
|
||||
<div class="assignment-list">
|
||||
<div v-for="a in group.assignments" :key="`${a.product}/${a.task}`" class="assignment-row">
|
||||
<div class="assignment-main">
|
||||
<span class="task-id">{{ a.task }}</span>
|
||||
<span
|
||||
class="model-name"
|
||||
:title="a.model_id"
|
||||
>{{ displayModelId(a) }}</span>
|
||||
<span v-if="a.vram_mb" class="chip chip-vram">{{ formatVram(a.vram_mb) }}</span>
|
||||
<span v-if="a.service_type" class="chip" :class="serviceChipClass(a.service_type)">{{ a.service_type }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Node deployment status -->
|
||||
<div v-if="deploymentMap[`${a.product}/${a.task}`]" class="node-statuses">
|
||||
<span
|
||||
v-for="ns in deploymentMap[`${a.product}/${a.task}`]"
|
||||
:key="ns.node_id"
|
||||
class="node-badge-wrap"
|
||||
>
|
||||
<span
|
||||
class="node-badge"
|
||||
:class="ns.status"
|
||||
:title="`${ns.node_id}: ${ns.status}`"
|
||||
>
|
||||
<span class="node-icon">{{ nodeIcon(ns.status) }}</span>
|
||||
{{ ns.node_id }}
|
||||
</span>
|
||||
<button
|
||||
v-if="ns.status === 'absent'"
|
||||
class="btn-deploy"
|
||||
:disabled="deploying.has(`${a.product}/${a.task}/${ns.node_id}`)"
|
||||
:title="`Register ${a.model_id} in ${ns.node_id} catalog`"
|
||||
@click="deployModel(a, ns.node_id)"
|
||||
>{{ deploying.has(`${a.product}/${a.task}/${ns.node_id}`) ? '…' : 'Register' }}</button>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="assignment-actions">
|
||||
<button
|
||||
v-if="editingKey !== `${a.product}/${a.task}`"
|
||||
class="btn-ghost btn-sm"
|
||||
@click="startEdit(a)"
|
||||
>Edit</button>
|
||||
<button
|
||||
class="btn-ghost btn-sm btn-danger"
|
||||
@click="deleteAssignment(a.product, a.task)"
|
||||
>Delete</button>
|
||||
</div>
|
||||
|
||||
<!-- Inline edit form -->
|
||||
<div v-if="editingKey === `${a.product}/${a.task}`" class="inline-edit">
|
||||
<select v-model="editDraft.model_id" class="edit-select" aria-label="Model">
|
||||
<option value="" disabled>Select model…</option>
|
||||
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
|
||||
{{ m.alias || truncate(m.model_id, 40) }}
|
||||
</option>
|
||||
</select>
|
||||
<input
|
||||
v-model="editDraft.description"
|
||||
type="text"
|
||||
class="edit-input"
|
||||
placeholder="Description (optional)"
|
||||
/>
|
||||
<div class="inline-edit-btns">
|
||||
<button class="btn-primary btn-sm" :disabled="!editDraft.model_id" @click="saveEdit(a)">Save</button>
|
||||
<button class="btn-ghost btn-sm" @click="editingKey = null">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ── Model Registry section ───────────────────────────── -->
|
||||
<div class="section-header section-header-mt">
|
||||
<h2 class="section-title">Model Registry</h2>
|
||||
<button class="btn-primary btn-sm" @click="showRegisterModal = true">Register Model</button>
|
||||
</div>
|
||||
|
||||
<div v-if="registryLoading" class="empty-state">Loading model registry…</div>
|
||||
<div v-else-if="registryError" class="error-notice" role="alert">{{ registryError }}</div>
|
||||
<div v-else-if="registryModels.length === 0" class="empty-state">No models registered yet.</div>
|
||||
<div v-else class="registry-table-wrap">
|
||||
<table class="registry-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Alias</th>
|
||||
<th>Model ID</th>
|
||||
<th>VRAM</th>
|
||||
<th>Service</th>
|
||||
<th class="col-hf">HF Repo</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="m in registryModels" :key="m.model_id">
|
||||
<td>{{ m.alias || '—' }}</td>
|
||||
<td>
|
||||
<span class="truncated" :title="m.model_id">{{ truncate(m.model_id, 36) }}</span>
|
||||
</td>
|
||||
<td>{{ formatVram(m.vram_mb) }}</td>
|
||||
<td><span class="chip" :class="serviceChipClass(m.service_type)">{{ m.service_type }}</span></td>
|
||||
<td class="col-hf">
|
||||
<a
|
||||
v-if="m.hf_repo"
|
||||
:href="`https://huggingface.co/${m.hf_repo}`"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="hf-link"
|
||||
>{{ truncate(m.hf_repo, 30) }}</a>
|
||||
<span v-else class="text-muted">—</span>
|
||||
</td>
|
||||
<td>
|
||||
<button class="btn-ghost btn-sm btn-danger" @click="deleteModel(m.model_id)">Delete</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- ── New Assignment modal ─────────────────────────────── -->
|
||||
<div v-if="showNewAssignmentModal" class="modal-backdrop" @click.self="showNewAssignmentModal = false">
|
||||
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-new-assignment-title">
|
||||
<h3 id="modal-new-assignment-title" class="modal-title">New Assignment</h3>
|
||||
<label class="form-label">Product</label>
|
||||
<input
|
||||
v-model="newAssignment.product"
|
||||
list="product-list"
|
||||
class="form-input"
|
||||
placeholder="e.g. peregrine"
|
||||
autocomplete="off"
|
||||
/>
|
||||
<datalist id="product-list">
|
||||
<option v-for="p in allProducts" :key="p" :value="p" />
|
||||
</datalist>
|
||||
|
||||
<label class="form-label">Task ID</label>
|
||||
<input
|
||||
v-model="newAssignment.task"
|
||||
type="text"
|
||||
class="form-input"
|
||||
placeholder="e.g. cover_letter"
|
||||
/>
|
||||
|
||||
<label class="form-label">Model</label>
|
||||
<select v-model="newAssignment.model_id" class="form-select">
|
||||
<option value="" disabled>Select from registry…</option>
|
||||
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
|
||||
{{ m.alias || truncate(m.model_id, 50) }}
|
||||
</option>
|
||||
</select>
|
||||
|
||||
<label class="form-label">Description <span class="optional">(optional)</span></label>
|
||||
<input
|
||||
v-model="newAssignment.description"
|
||||
type="text"
|
||||
class="form-input"
|
||||
placeholder="Human-readable note for operators"
|
||||
/>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!newAssignment.product || !newAssignment.task || !newAssignment.model_id || saving"
|
||||
@click="saveNewAssignment"
|
||||
>{{ saving ? 'Saving…' : 'Save' }}</button>
|
||||
<button class="btn-ghost" @click="showNewAssignmentModal = false">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ── Register Model modal ─────────────────────────────── -->
|
||||
<div v-if="showRegisterModal" class="modal-backdrop" @click.self="showRegisterModal = false">
|
||||
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-register-title">
|
||||
<h3 id="modal-register-title" class="modal-title">Register Model</h3>
|
||||
|
||||
<label class="form-label">Model ID <span class="hint">(HuggingFace slug, e.g. ibm-granite/granite-4.1-8b)</span></label>
|
||||
<input v-model="newModel.model_id" type="text" class="form-input" placeholder="org/model-name" />
|
||||
|
||||
<label class="form-label">Alias <span class="optional">(optional, short name for assignments)</span></label>
|
||||
<input v-model="newModel.alias" type="text" class="form-input" placeholder="e.g. granite-8b" />
|
||||
|
||||
<label class="form-label">Service type</label>
|
||||
<select v-model="newModel.service_type" class="form-select">
|
||||
<option value="" disabled>Select service…</option>
|
||||
<option value="cf-text">cf-text — Language Models</option>
|
||||
<option value="cf-stt">cf-stt — Speech Recognition</option>
|
||||
<option value="cf-tts">cf-tts — Text to Speech</option>
|
||||
<option value="cf-vision">cf-vision — Vision / VLM</option>
|
||||
<option value="cf-image">cf-image — Image Generation</option>
|
||||
<option value="cf-voice">cf-voice — Audio Classification</option>
|
||||
<option value="vllm">vllm — vLLM inference</option>
|
||||
<option value="ollama">ollama — Ollama inference</option>
|
||||
</select>
|
||||
|
||||
<label class="form-label">VRAM required (MB)</label>
|
||||
<input v-model.number="newModel.vram_mb" type="number" min="0" class="form-input" placeholder="e.g. 16384" />
|
||||
|
||||
<label class="form-label">HF Repo <span class="optional">(optional)</span></label>
|
||||
<input v-model="newModel.hf_repo" type="text" class="form-input" placeholder="org/repo-name" />
|
||||
|
||||
<label class="form-label">Description <span class="optional">(optional)</span></label>
|
||||
<input v-model="newModel.description" type="text" class="form-input" placeholder="Human-readable note" />
|
||||
|
||||
<div class="modal-actions">
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!newModel.model_id || !newModel.service_type || !newModel.vram_mb || saving"
|
||||
@click="saveNewModel"
|
||||
>{{ saving ? 'Saving…' : 'Register' }}</button>
|
||||
<button class="btn-ghost" @click="showRegisterModal = false">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
// ── Types ──────────────────────────────────────────────
|
||||
|
||||
interface AssignmentNode {
|
||||
node_id: string
|
||||
status: 'present' | 'absent' | 'vram_tight'
|
||||
}
|
||||
|
||||
interface DeployingKey {
|
||||
nodeId: string
|
||||
assignmentKey: string
|
||||
}
|
||||
|
||||
interface Assignment {
|
||||
product: string
|
||||
task: string
|
||||
model_id: string
|
||||
description: string
|
||||
alias?: string
|
||||
service_type?: string
|
||||
vram_mb?: number
|
||||
nodes?: AssignmentNode[]
|
||||
}
|
||||
|
||||
interface RegistryModel {
|
||||
model_id: string
|
||||
alias: string
|
||||
service_type: string
|
||||
vram_mb: number
|
||||
hf_repo: string
|
||||
description: string
|
||||
}
|
||||
|
||||
interface ProductGroup {
|
||||
product: string
|
||||
assignments: Assignment[]
|
||||
}
|
||||
|
||||
interface Toast {
|
||||
message: string
|
||||
type: 'success' | 'error'
|
||||
}
|
||||
|
||||
// ── State ──────────────────────────────────────────────
|
||||
|
||||
const assignments = ref<Assignment[]>([])
|
||||
const assignmentsLoading = ref(false)
|
||||
const assignmentsError = ref<string | null>(null)
|
||||
|
||||
const registryModels = ref<RegistryModel[]>([])
|
||||
const registryLoading = ref(false)
|
||||
const registryError = ref<string | null>(null)
|
||||
|
||||
const productFilter = ref('')
|
||||
const editingKey = ref<string | null>(null)
|
||||
const editDraft = ref({ model_id: '', description: '' })
|
||||
|
||||
const showNewAssignmentModal = ref(false)
|
||||
const newAssignment = ref({ product: '', task: '', model_id: '', description: '' })
|
||||
|
||||
const showRegisterModal = ref(false)
|
||||
const newModel = ref({ model_id: '', alias: '', service_type: '', vram_mb: 0, hf_repo: '', description: '' })
|
||||
|
||||
const saving = ref(false)
|
||||
const toast = ref<Toast | null>(null)
|
||||
let toastTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const deploying = ref<Set<string>>(new Set())
|
||||
|
||||
// ── Derived ────────────────────────────────────────────
|
||||
|
||||
const allProducts = computed(() => {
|
||||
const seen = new Set<string>()
|
||||
for (const a of assignments.value) seen.add(a.product)
|
||||
return [...seen].sort()
|
||||
})
|
||||
|
||||
const deploymentMap = computed(() => {
|
||||
const map: Record<string, AssignmentNode[]> = {}
|
||||
for (const a of assignments.value) {
|
||||
if (a.nodes) map[`${a.product}/${a.task}`] = a.nodes
|
||||
}
|
||||
return map
|
||||
})
|
||||
|
||||
const filteredGroups = computed((): ProductGroup[] => {
|
||||
const filtered = productFilter.value
|
||||
? assignments.value.filter(a => a.product === productFilter.value)
|
||||
: assignments.value
|
||||
|
||||
const byProduct: Record<string, Assignment[]> = {}
|
||||
for (const a of filtered) {
|
||||
if (!byProduct[a.product]) byProduct[a.product] = []
|
||||
byProduct[a.product].push(a)
|
||||
}
|
||||
return Object.keys(byProduct)
|
||||
.sort()
|
||||
.map(product => ({ product, assignments: byProduct[product] }))
|
||||
})
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────
|
||||
|
||||
function truncate(s: string, max: number): string {
|
||||
return s.length > max ? s.slice(0, max - 1) + '…' : s
|
||||
}
|
||||
|
||||
function displayModelId(a: Assignment): string {
|
||||
if (a.alias) return a.alias
|
||||
const id = a.model_id
|
||||
// Show only the model name part (after /) and truncate long slugs
|
||||
const short = id.includes('/') ? id.split('/').slice(1).join('/') : id
|
||||
return truncate(short, 36)
|
||||
}
|
||||
|
||||
function formatVram(mb: number | undefined): string {
|
||||
if (!mb) return ''
|
||||
if (mb >= 1024) return `${(mb / 1024).toFixed(1)} GB`
|
||||
return `${mb} MB`
|
||||
}
|
||||
|
||||
function serviceChipClass(service: string): string {
|
||||
return `chip-service-${service.replace(/[^a-z0-9]/g, '-')}`
|
||||
}
|
||||
|
||||
function nodeIcon(status: string): string {
|
||||
if (status === 'present') return '✓'
|
||||
if (status === 'vram_tight') return '~'
|
||||
return '✗'
|
||||
}
|
||||
|
||||
function showToast(message: string, type: 'success' | 'error' = 'success') {
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
toast.value = { message, type }
|
||||
toastTimer = setTimeout(() => { toast.value = null }, 3500)
|
||||
}
|
||||
|
||||
function openNewAssignment() {
|
||||
newAssignment.value = { product: '', task: '', model_id: '', description: '' }
|
||||
showNewAssignmentModal.value = true
|
||||
}
|
||||
|
||||
function startEdit(a: Assignment) {
|
||||
editingKey.value = `${a.product}/${a.task}`
|
||||
editDraft.value = { model_id: a.model_id, description: a.description }
|
||||
}
|
||||
|
||||
// ── API ────────────────────────────────────────────────
|
||||
|
||||
async function loadAssignments() {
|
||||
assignmentsLoading.value = true
|
||||
assignmentsError.value = null
|
||||
try {
|
||||
// Fetch both list and deployment status in parallel
|
||||
const [listRes, statusRes] = await Promise.all([
|
||||
fetch('/api/cforch/assignments'),
|
||||
fetch('/api/cforch/assignments/deployment-status'),
|
||||
])
|
||||
if (!listRes.ok) throw new Error(`HTTP ${listRes.status}`)
|
||||
const list: Assignment[] = (await listRes.json()).assignments ?? []
|
||||
|
||||
// Merge deployment status into assignments if available
|
||||
if (statusRes.ok) {
|
||||
const statusList: Assignment[] = (await statusRes.json()).deployment_status ?? []
|
||||
const statusMap: Record<string, AssignmentNode[]> = {}
|
||||
for (const s of statusList) {
|
||||
statusMap[`${s.product}/${s.task}`] = s.nodes ?? []
|
||||
}
|
||||
for (const a of list) {
|
||||
a.nodes = statusMap[`${a.product}/${a.task}`] ?? []
|
||||
// Enrich with service_type/vram_mb from status payload
|
||||
const s = statusList.find(x => x.product === a.product && x.task === a.task)
|
||||
if (s) {
|
||||
a.service_type = s.service_type
|
||||
a.vram_mb = s.vram_mb
|
||||
a.alias = s.alias
|
||||
}
|
||||
}
|
||||
}
|
||||
assignments.value = list
|
||||
} catch (e) {
|
||||
assignmentsError.value = `Could not load assignments: ${e}`
|
||||
} finally {
|
||||
assignmentsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadRegistry() {
|
||||
registryLoading.value = true
|
||||
registryError.value = null
|
||||
try {
|
||||
const res = await fetch('/api/cforch/model-registry')
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
registryModels.value = (await res.json()).models ?? []
|
||||
} catch (e) {
|
||||
registryError.value = `Could not load model registry: ${e}`
|
||||
} finally {
|
||||
registryLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function saveNewAssignment() {
|
||||
saving.value = true
|
||||
try {
|
||||
const res = await fetch('/api/cforch/assignments', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(newAssignment.value),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showNewAssignmentModal.value = false
|
||||
showToast('Assignment saved')
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Save failed: ${e}`, 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function saveEdit(a: Assignment) {
|
||||
saving.value = true
|
||||
try {
|
||||
const res = await fetch('/api/cforch/assignments', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
product: a.product,
|
||||
task: a.task,
|
||||
model_id: editDraft.value.model_id,
|
||||
description: editDraft.value.description,
|
||||
}),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
editingKey.value = null
|
||||
showToast('Assignment updated')
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Update failed: ${e}`, 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteAssignment(product: string, task: string) {
|
||||
if (!confirm(`Delete assignment ${product}.${task}?`)) return
|
||||
try {
|
||||
const res = await fetch(
|
||||
`/api/cforch/assignments/${encodeURIComponent(product)}/${encodeURIComponent(task)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showToast('Assignment deleted')
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Delete failed: ${e}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
async function saveNewModel() {
|
||||
saving.value = true
|
||||
try {
|
||||
const res = await fetch('/api/cforch/model-registry', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(newModel.value),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showRegisterModal.value = false
|
||||
showToast('Model registered')
|
||||
await loadRegistry()
|
||||
} catch (e) {
|
||||
showToast(`Register failed: ${e}`, 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteModel(model_id: string) {
|
||||
if (!confirm(`Remove ${model_id} from the registry?`)) return
|
||||
try {
|
||||
const res = await fetch(
|
||||
`/api/cforch/model-registry/${encodeURIComponent(model_id)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showToast('Model removed')
|
||||
await loadRegistry()
|
||||
} catch (e) {
|
||||
showToast(`Delete failed: ${e}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
async function deployModel(a: Assignment, nodeId: string) {
|
||||
const key = `${a.product}/${a.task}/${nodeId}`
|
||||
if (deploying.value.has(key)) return
|
||||
|
||||
// Look up hf_repo from registry for cleaner path construction
|
||||
const regEntry = registryModels.value.find(m => m.model_id === a.model_id)
|
||||
const hf_repo = regEntry?.hf_repo ?? ''
|
||||
const service_type = a.service_type ?? regEntry?.service_type ?? ''
|
||||
const vram_mb = a.vram_mb ?? regEntry?.vram_mb ?? 0
|
||||
const description = regEntry?.alias ? `${regEntry.alias} (via assignments)` : ''
|
||||
|
||||
if (!service_type) {
|
||||
showToast(`No service type for model ${a.model_id}`, 'error')
|
||||
return
|
||||
}
|
||||
|
||||
deploying.value = new Set([...deploying.value, key])
|
||||
try {
|
||||
const res = await fetch(`/api/nodes-mgmt/nodes/${encodeURIComponent(nodeId)}/models/deploy`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model_id: a.model_id, service_type, vram_mb, hf_repo, description }),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
const data = await res.json()
|
||||
showToast(`Registered ${a.model_id} on ${nodeId} at ${data.path}`)
|
||||
|
||||
// Optimistic update: flip node to 'present' immediately so the Register button
|
||||
// disappears before the coordinator reload confirms. loadAssignments() reconciles
|
||||
// with real server state on the next round-trip.
|
||||
assignments.value = assignments.value.map(asgn => {
|
||||
if (asgn.product !== a.product || asgn.task !== a.task) return asgn
|
||||
return {
|
||||
...asgn,
|
||||
nodes: (asgn.nodes ?? []).map(ns =>
|
||||
ns.node_id === nodeId ? { ...ns, status: 'present' as const } : ns
|
||||
),
|
||||
}
|
||||
})
|
||||
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Deploy failed: ${e}`, 'error')
|
||||
} finally {
|
||||
deploying.value = new Set([...deploying.value].filter(k => k !== key))
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadAssignments()
|
||||
loadRegistry()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.assignments-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.25rem;
|
||||
}
|
||||
|
||||
/* ── Toast ── */
|
||||
.toast {
|
||||
position: fixed;
|
||||
bottom: 1.5rem;
|
||||
right: 1.5rem;
|
||||
padding: 0.65rem 1.1rem;
|
||||
border-radius: 0.5rem;
|
||||
font-size: 0.88rem;
|
||||
font-weight: 500;
|
||||
z-index: 200;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
|
||||
}
|
||||
.toast.success {
|
||||
background: var(--color-success, #2a8050);
|
||||
color: #fff;
|
||||
}
|
||||
.toast.error {
|
||||
background: var(--color-danger, #b03030);
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
/* ── Section headers ── */
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 1rem;
|
||||
}
|
||||
.section-header-mt {
|
||||
margin-top: 1.5rem;
|
||||
}
|
||||
.section-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Filter row ── */
|
||||
.filter-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
}
|
||||
.filter-label {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
}
|
||||
.filter-select {
|
||||
padding: 0.3rem 0.6rem;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2030);
|
||||
}
|
||||
|
||||
/* ── Product groups ── */
|
||||
.product-groups {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
.product-group {}
|
||||
.product-name {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.08em;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
margin: 0 0 0.4rem;
|
||||
}
|
||||
.assignment-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
|
||||
/* ── Assignment rows ── */
|
||||
.assignment-row {
|
||||
background: var(--color-surface-raised, #f0f4fa);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.65rem 0.85rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
.assignment-main {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.task-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.88rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2030);
|
||||
min-width: 0;
|
||||
}
|
||||
.model-name {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 280px;
|
||||
cursor: default;
|
||||
}
|
||||
.assignment-actions {
|
||||
display: flex;
|
||||
gap: 0.4rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
/* ── Node status badges ── */
|
||||
.node-statuses {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.node-badge-wrap {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.2rem;
|
||||
}
|
||||
.node-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.2rem;
|
||||
font-size: 0.78rem;
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 0.35rem;
|
||||
font-weight: 500;
|
||||
}
|
||||
.node-badge.present {
|
||||
background: color-mix(in srgb, var(--color-success, #2a8050) 15%, transparent);
|
||||
color: var(--color-success, #2a8050);
|
||||
border: 1px solid color-mix(in srgb, var(--color-success, #2a8050) 30%, transparent);
|
||||
}
|
||||
.node-badge.absent {
|
||||
background: color-mix(in srgb, var(--color-danger, #b03030) 12%, transparent);
|
||||
color: var(--color-danger, #b03030);
|
||||
border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
|
||||
}
|
||||
.node-badge.vram_tight {
|
||||
background: color-mix(in srgb, #c08030 15%, transparent);
|
||||
color: #8a5500;
|
||||
border: 1px solid color-mix(in srgb, #c08030 30%, transparent);
|
||||
}
|
||||
.node-icon {
|
||||
font-size: 0.85em;
|
||||
}
|
||||
.btn-deploy {
|
||||
padding: 0.1rem 0.4rem;
|
||||
font-size: 0.72rem;
|
||||
font-weight: 600;
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
|
||||
color: var(--app-primary, #2A6080);
|
||||
border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 30%, transparent);
|
||||
border-radius: 0.3rem;
|
||||
cursor: pointer;
|
||||
white-space: nowrap;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-deploy:hover:not(:disabled) {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 22%, transparent);
|
||||
}
|
||||
.btn-deploy:disabled { opacity: 0.5; cursor: default; }
|
||||
|
||||
/* ── Inline edit ── */
|
||||
.inline-edit {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.4rem;
|
||||
padding-top: 0.35rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.edit-select,
|
||||
.edit-input {
|
||||
flex: 1;
|
||||
min-width: 160px;
|
||||
padding: 0.35rem 0.55rem;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2030);
|
||||
}
|
||||
.inline-edit-btns {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
/* ── Registry table ── */
|
||||
.registry-table-wrap {
|
||||
overflow-x: auto;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.registry-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.registry-table th {
|
||||
text-align: left;
|
||||
padding: 0.5rem 0.75rem;
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
background: var(--color-surface-raised, #f0f4fa);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
white-space: nowrap;
|
||||
}
|
||||
.registry-table td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
vertical-align: middle;
|
||||
}
|
||||
.registry-table tbody tr:last-child td {
|
||||
border-bottom: none;
|
||||
}
|
||||
.truncated {
|
||||
display: inline-block;
|
||||
max-width: 220px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
vertical-align: bottom;
|
||||
cursor: default;
|
||||
}
|
||||
.hf-link {
|
||||
color: var(--app-primary, #2A6080);
|
||||
text-decoration: none;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
.hf-link:hover { text-decoration: underline; }
|
||||
.text-muted { color: var(--color-text-muted, #6b7a99); }
|
||||
|
||||
/* ── Chips ── */
|
||||
.chip {
|
||||
display: inline-block;
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 0.35rem;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.chip-vram {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
|
||||
color: var(--app-primary, #2A6080);
|
||||
border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 25%, transparent);
|
||||
}
|
||||
/* service chips — match ModelsView convention */
|
||||
.chip-service-cf-text { background: #e8f0fe; color: #1a5276; border: 1px solid #a9c4e8; }
|
||||
.chip-service-cf-stt { background: #eaf6ea; color: #1e6b3a; border: 1px solid #a2d9b1; }
|
||||
.chip-service-cf-tts { background: #fdf3e3; color: #7d4e00; border: 1px solid #e8c98a; }
|
||||
.chip-service-cf-vision { background: #f3e8fd; color: #5b2d8e; border: 1px solid #c8a0e8; }
|
||||
.chip-service-cf-image { background: #fce8f0; color: #8e1a4f; border: 1px solid #e8a0c0; }
|
||||
.chip-service-cf-voice { background: #e8f8fc; color: #0a5c6e; border: 1px solid #88d0e0; }
|
||||
.chip-service-vllm { background: #f5ece0; color: #7a3800; border: 1px solid #d4a87a; }
|
||||
.chip-service-ollama { background: #eeeeee; color: #444; border: 1px solid #ccc; }
|
||||
|
||||
/* ── Buttons ── */
|
||||
.btn-primary {
|
||||
padding: 0.45rem 1rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border: none;
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: default; }
|
||||
.btn-primary:not(:disabled):hover { opacity: 0.88; }
|
||||
|
||||
.btn-ghost {
|
||||
padding: 0.35rem 0.75rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.82rem;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-ghost:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.btn-ghost.btn-danger { color: var(--color-danger, #b03030); border-color: color-mix(in srgb, var(--color-danger, #b03030) 30%, transparent); }
|
||||
.btn-ghost.btn-danger:hover { background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent); }
|
||||
|
||||
.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }
|
||||
|
||||
/* ── Empty / error states ── */
|
||||
.empty-state {
|
||||
padding: 1.5rem;
|
||||
text-align: center;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
font-size: 0.9rem;
|
||||
background: var(--color-surface-raised, #f0f4fa);
|
||||
border: 1px dashed var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
.error-notice {
|
||||
padding: 0.75rem 1rem;
|
||||
background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent);
|
||||
color: var(--color-danger, #b03030);
|
||||
border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.87rem;
|
||||
}
|
||||
|
||||
/* ── Modal ── */
|
||||
.modal-backdrop {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
background: rgba(0,0,0,0.35);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 100;
|
||||
padding: 1rem;
|
||||
}
|
||||
.modal {
|
||||
background: var(--color-surface, #fff);
|
||||
border-radius: 0.65rem;
|
||||
padding: 1.5rem;
|
||||
width: 100%;
|
||||
max-width: 480px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.65rem;
|
||||
box-shadow: 0 8px 32px rgba(0,0,0,0.18);
|
||||
max-height: 90vh;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.modal-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0 0 0.25rem;
|
||||
}
|
||||
.form-label {
|
||||
font-size: 0.82rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
}
|
||||
.form-input,
|
||||
.form-select {
|
||||
padding: 0.4rem 0.65rem;
|
||||
font-size: 0.88rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2030);
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.form-input:focus, .form-select:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: 1px;
|
||||
}
|
||||
.modal-actions {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
justify-content: flex-end;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
.optional, .hint {
|
||||
font-weight: 400;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
/* ── Responsive ── */
|
||||
@media (max-width: 600px) {
|
||||
.assignment-main { flex-direction: column; align-items: flex-start; }
|
||||
.col-hf { display: none; }
|
||||
.model-name { max-width: 100%; }
|
||||
.modal { padding: 1rem; }
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import BenchmarkView from './BenchmarkView.vue'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockImplementation((url: string) => {
|
||||
// LlmEvalTab calls /api/cforch/models and expects { models: CfOrchModel[] }
|
||||
if (url.includes('/api/cforch/models')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ models: [] }),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
// Default: satisfies ClassifierTab (/api/benchmark/results, /api/benchmark/models,
|
||||
// /api/finetune/status), StyleTab (/api/style/models, /api/style/results),
|
||||
// and any other tab that tolerates empty arrays/objects.
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ models: {}, categories: {}, tasks: [], types: [], results: [] }),
|
||||
text: async () => '',
|
||||
})
|
||||
}))
|
||||
vi.stubGlobal('EventSource', class {
|
||||
onmessage = null
|
||||
onerror = null
|
||||
close() {}
|
||||
})
|
||||
})
|
||||
|
||||
describe('BenchmarkView', () => {
|
||||
it('renders page title "Benchmark"', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('Benchmark')
|
||||
})
|
||||
|
||||
it('has mode buttons: Classifier, LLM Eval, Writing Style', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const text = w.text()
|
||||
expect(text).toContain('Classifier')
|
||||
expect(text).toContain('LLM Eval')
|
||||
expect(text).toContain('Writing Style')
|
||||
})
|
||||
|
||||
it('does NOT have a Compare mode button', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const buttons = w.findAll('.mode-btn')
|
||||
const labels = buttons.map(b => b.text())
|
||||
expect(labels.every(l => !l.includes('Compare'))).toBe(true)
|
||||
})
|
||||
|
||||
it('shows Classifier tab by default', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
// ClassifierTab has a .classifier-tab root
|
||||
expect(w.find('.classifier-tab').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('switches to LlmEvalTab when LLM Eval clicked', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const llmBtn = w.findAll('.mode-btn').find(b => b.text().includes('LLM Eval'))!
|
||||
await llmBtn.trigger('click')
|
||||
await flushPromises()
|
||||
expect(w.find('.llm-eval-tab').exists()).toBe(true)
|
||||
expect(w.find('.classifier-tab').exists()).toBe(false)
|
||||
expect(llmBtn.classes()).toContain('active')
|
||||
})
|
||||
|
||||
it('switches to StyleTab when Writing Style clicked', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const styleBtn = w.findAll('.mode-btn').find(b => b.text().includes('Writing Style'))!
|
||||
await styleBtn.trigger('click')
|
||||
await flushPromises()
|
||||
expect(w.find('.style-tab').exists()).toBe(true)
|
||||
expect(w.find('.classifier-tab').exists()).toBe(false)
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -1,722 +0,0 @@
|
|||
<template>
|
||||
<div class="compare-tab">
|
||||
|
||||
<!-- Source toggle -->
|
||||
<div class="source-toggle" role="group" aria-label="Prompt source">
|
||||
<button class="source-btn" :class="{ active: promptSource === 'tasks' }" @click="promptSource = 'tasks'">
|
||||
📋 cf-orch Tasks
|
||||
</button>
|
||||
<button class="source-btn" :class="{ active: promptSource === 'style' }" @click="promptSource = 'style'">
|
||||
✍️ Writing Style Prompts
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Task selector (cf-orch tasks) -->
|
||||
<details v-if="promptSource === 'tasks'" class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">📋 Pick a Task</span>
|
||||
<span class="picker-badge">{{ cmpSelectedTask ? cmpSelectedTask.name : 'None selected' }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks…</div>
|
||||
<div v-else-if="llmTasks.length === 0" class="picker-empty">No tasks found — check cforch config.</div>
|
||||
<template v-else>
|
||||
<div v-for="(tasks, type) in llmTasksByType" :key="type" class="picker-category">
|
||||
<span class="picker-cat-name picker-cat-section">{{ type }}</span>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="t in tasks" :key="t.id" class="picker-model-row">
|
||||
<input
|
||||
type="radio"
|
||||
name="cmp-task"
|
||||
:checked="cmpSelectedTask?.id === t.id"
|
||||
@change="selectCmpTask(t)"
|
||||
/>
|
||||
<span class="picker-model-name" :title="t.name">{{ t.name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Writing style prompt selector -->
|
||||
<details v-if="promptSource === 'style'" class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">✍️ Pick a Writing Style Prompt</span>
|
||||
<span class="picker-badge">{{ selectedVoicePrompt ? selectedVoicePrompt.tag : 'None selected' }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div class="picker-model-list style-prompt-list">
|
||||
<label v-for="vp in STYLE_PROMPTS" :key="vp.tag" class="picker-model-row style-prompt-row">
|
||||
<input
|
||||
type="radio"
|
||||
name="cmp-style-prompt"
|
||||
:checked="selectedVoicePrompt?.tag === vp.tag"
|
||||
@change="selectVoicePrompt(vp)"
|
||||
/>
|
||||
<span class="style-prompt-tag">{{ vp.tag }}</span>
|
||||
<span class="style-prompt-title">{{ vp.thread_title }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Prompt editor + model picker (shown once a prompt source is ready) -->
|
||||
<template v-if="promptSource === 'tasks' ? !!cmpSelectedTask : !!selectedVoicePrompt">
|
||||
<label class="prompt-label" for="cmp-prompt">Prompt</label>
|
||||
<textarea
|
||||
id="cmp-prompt"
|
||||
class="cmp-prompt-editor"
|
||||
v-model="cmpPrompt"
|
||||
rows="6"
|
||||
/>
|
||||
|
||||
<!-- LLM model picker (ollama + vllm + cf-text) -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">🤖 LLM Models</span>
|
||||
<span class="picker-badge">{{ cmpSelectedModels.size }} / {{ llmSelectableModels.length }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<label class="picker-cat-header">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="cmpSelectedModels.size === llmSelectableModels.length"
|
||||
:indeterminate="cmpSelectedModels.size > 0 && cmpSelectedModels.size < llmSelectableModels.length"
|
||||
@change="toggleAllCmpModels(($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-cat-name">All LLM models</span>
|
||||
</label>
|
||||
<div v-for="(models, service) in llmModelsByService" :key="service" class="picker-category">
|
||||
<span class="picker-cat-section">{{ service }}</span>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="m in models" :key="m.id" class="picker-model-row">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="cmpSelectedModels.has(m.id)"
|
||||
@change="toggleCmpModel(m.id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-model-name">{{ m.name }}</span>
|
||||
<span class="picker-adapter-type">{{ m.tags.slice(0, 2).join(', ') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Run controls -->
|
||||
<div class="run-controls">
|
||||
<button
|
||||
class="btn-run"
|
||||
:disabled="cmpRunning || cmpSelectedModels.size === 0"
|
||||
@click="startCompare"
|
||||
>{{ cmpRunning ? '⏳ Running…' : '⚖️ Compare Models' }}</button>
|
||||
<button v-if="cmpRunning" class="btn-cancel" @click="cancelCompare">✕ Cancel</button>
|
||||
</div>
|
||||
|
||||
<!-- Progress log -->
|
||||
<div v-if="cmpLog.length > 0" class="run-log">
|
||||
<div class="log-lines">
|
||||
<div v-for="(line, i) in cmpLog" :key="i" class="log-line">{{ line }}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Side-by-side results -->
|
||||
<template v-if="cmpResults.length > 0">
|
||||
<h2 class="chart-title">Side-by-Side Responses</h2>
|
||||
<div class="cmp-results-grid">
|
||||
<div
|
||||
v-for="r in cmpResults"
|
||||
:key="r.model"
|
||||
class="cmp-result-card"
|
||||
:class="{ 'cmp-error': !!r.error }"
|
||||
>
|
||||
<div class="cmp-result-header">
|
||||
<span class="cmp-model-name">{{ r.model }}</span>
|
||||
<span class="cmp-meta">
|
||||
<template v-if="r.error"><span class="err-badge">error</span></template>
|
||||
<template v-else>{{ (r.elapsed_ms / 1000).toFixed(1) }}s</template>
|
||||
</span>
|
||||
</div>
|
||||
<pre v-if="r.error" class="cmp-error-text">{{ r.error }}</pre>
|
||||
<pre v-else class="cmp-response">{{ r.response }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</template>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useApiFetch } from '../composables/useApi'
|
||||
|
||||
// ── Types ───────────────────────────────────────────────────────────────────
|
||||
interface CfOrchTask {
|
||||
id: string
|
||||
name: string
|
||||
type: string
|
||||
prompt: string
|
||||
system: string
|
||||
}
|
||||
|
||||
interface CfOrchModel {
|
||||
name: string
|
||||
id: string
|
||||
service: string
|
||||
tags: string[]
|
||||
vram_estimate_mb?: number
|
||||
}
|
||||
|
||||
interface CmpResult {
|
||||
model: string
|
||||
response: string
|
||||
elapsed_ms: number
|
||||
error: string | null
|
||||
}
|
||||
|
||||
interface VoicePrompt {
|
||||
tag: string
|
||||
thread_title: string
|
||||
thread_body: string
|
||||
}
|
||||
|
||||
// ── Writing style prompts (mirrors TEST_PROMPTS in benchmark_style.py) ──────
|
||||
const STYLE_SYSTEM = "You are a writing assistant. Your job is to write a Reddit reply that matches the user's voice — casual, direct, community-first. No em dashes. No filler phrases. No semicolons. Short punchy sentences."
|
||||
|
||||
const STYLE_PROMPTS: VoicePrompt[] = [
|
||||
{
|
||||
tag: 'selfhosted_ai_fatigue',
|
||||
thread_title: "Anyone else getting tired of re-explaining their setup every time an AI model forgets?",
|
||||
thread_body: "Every session I start over. My whole hardware setup, what tools I use, what I've already tried. It's exhausting. There has to be a better way.",
|
||||
},
|
||||
{
|
||||
tag: 'privacy_local_llm',
|
||||
thread_title: "What's the point of running local LLMs if the apps still phone home?",
|
||||
thread_body: "I went through all the trouble of setting up ollama and now I find out the frontend I'm using is sending telemetry. Kind of defeats the purpose.",
|
||||
},
|
||||
{
|
||||
tag: 'solarpunk_tech',
|
||||
thread_title: "What does solarpunk computing actually look like in practice?",
|
||||
thread_body: "I keep seeing the aesthetic but not a lot of concrete examples of people living it out with their tech choices. What does it mean day to day?",
|
||||
},
|
||||
{
|
||||
tag: 'nd_tools',
|
||||
thread_title: "Tools that actually help with executive function vs ones that just add friction",
|
||||
thread_body: "I've tried a dozen productivity apps and most of them require more executive function to maintain than they save. What actually sticks for you?",
|
||||
},
|
||||
{
|
||||
tag: 'data_ownership',
|
||||
thread_title: "Who actually owns your data when you use a 'free' AI tool?",
|
||||
thread_body: "Read the ToS on three different AI assistants today. In all three cases your inputs can be used for training, shared with partners, and retained indefinitely. Is this just accepted now?",
|
||||
},
|
||||
{
|
||||
tag: 'digital_culture',
|
||||
thread_title: "The internet used to feel like it belonged to everyone. What happened?",
|
||||
thread_body: "I grew up on forums, IRC, personal homepages. Now everything is a platform owned by someone trying to extract value from the community that built it.",
|
||||
},
|
||||
]
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
const llmTasks = ref<CfOrchTask[]>([])
|
||||
const llmTasksLoading = ref(false)
|
||||
const llmModels = ref<CfOrchModel[]>([])
|
||||
|
||||
const promptSource = ref<'tasks' | 'style'>('tasks')
|
||||
const cmpSelectedTask = ref<CfOrchTask | null>(null)
|
||||
const selectedVoicePrompt = ref<VoicePrompt | null>(null)
|
||||
const cmpSystemPrompt = ref('')
|
||||
const cmpPrompt = ref('')
|
||||
const cmpSelectedModels = ref<Set<string>>(new Set())
|
||||
const cmpRunning = ref(false)
|
||||
const cmpLog = ref<string[]>([])
|
||||
const cmpResults = ref<CmpResult[]>([])
|
||||
const cmpEventSource = ref<EventSource | null>(null)
|
||||
|
||||
// ── Computed ────────────────────────────────────────────────────────────────
|
||||
const LLM_SERVICES = new Set(['ollama', 'vllm', 'cf-text'])
|
||||
|
||||
const llmSelectableModels = computed(() =>
|
||||
llmModels.value.filter(m => LLM_SERVICES.has(m.service))
|
||||
)
|
||||
|
||||
/** Group selectable models by service for the picker UI */
|
||||
const llmModelsByService = computed((): Record<string, CfOrchModel[]> => {
|
||||
const groups: Record<string, CfOrchModel[]> = {}
|
||||
for (const m of llmSelectableModels.value) {
|
||||
if (!groups[m.service]) groups[m.service] = []
|
||||
groups[m.service].push(m)
|
||||
}
|
||||
return groups
|
||||
})
|
||||
|
||||
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
|
||||
const groups: Record<string, CfOrchTask[]> = {}
|
||||
for (const t of llmTasks.value) {
|
||||
if (!groups[t.type]) groups[t.type] = []
|
||||
groups[t.type].push(t)
|
||||
}
|
||||
return groups
|
||||
})
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────────
|
||||
function selectCmpTask(t: CfOrchTask) {
|
||||
cmpSelectedTask.value = t
|
||||
cmpPrompt.value = t.prompt || ''
|
||||
cmpSystemPrompt.value = t.system || ''
|
||||
cmpResults.value = []
|
||||
cmpLog.value = []
|
||||
}
|
||||
|
||||
function selectVoicePrompt(vp: VoicePrompt) {
|
||||
selectedVoicePrompt.value = vp
|
||||
cmpPrompt.value = `Thread: ${vp.thread_title}\n\n${vp.thread_body}\n\nWrite a reply:`
|
||||
cmpSystemPrompt.value = STYLE_SYSTEM
|
||||
cmpResults.value = []
|
||||
cmpLog.value = []
|
||||
}
|
||||
|
||||
function toggleCmpModel(id: string, checked: boolean) {
|
||||
const next = new Set(cmpSelectedModels.value)
|
||||
checked ? next.add(id) : next.delete(id)
|
||||
cmpSelectedModels.value = next
|
||||
}
|
||||
|
||||
function toggleAllCmpModels(checked: boolean) {
|
||||
cmpSelectedModels.value = checked
|
||||
? new Set(llmSelectableModels.value.map(m => m.id))
|
||||
: new Set()
|
||||
}
|
||||
|
||||
// ── Data loaders ──────────────────────────────────────────────────────────────
|
||||
async function loadLlmTasks() {
|
||||
llmTasksLoading.value = true
|
||||
const { data } = await useApiFetch<{ tasks: CfOrchTask[]; types: string[] }>('/api/cforch/tasks')
|
||||
llmTasksLoading.value = false
|
||||
if (data?.tasks) {
|
||||
llmTasks.value = data.tasks
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmModels() {
|
||||
const { data } = await useApiFetch<{ models: CfOrchModel[] }>('/api/cforch/models')
|
||||
if (data?.models) {
|
||||
llmModels.value = data.models
|
||||
cmpSelectedModels.value = new Set(
|
||||
data.models.filter(m => LLM_SERVICES.has(m.service)).map(m => m.id)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run / cancel ──────────────────────────────────────────────────────────────
|
||||
function startCompare() {
|
||||
if (!cmpPrompt.value.trim() || cmpSelectedModels.value.size === 0) return
|
||||
cmpRunning.value = true
|
||||
cmpResults.value = []
|
||||
cmpLog.value = []
|
||||
|
||||
const params = new URLSearchParams({
|
||||
prompt: cmpPrompt.value,
|
||||
model_ids: [...cmpSelectedModels.value].join(','),
|
||||
system: cmpSystemPrompt.value,
|
||||
})
|
||||
|
||||
const es = new EventSource(`/api/imitate/run?${params}`)
|
||||
cmpEventSource.value = es
|
||||
|
||||
es.onmessage = (event: MessageEvent) => {
|
||||
try {
|
||||
const msg = JSON.parse(event.data)
|
||||
if (msg.type === 'start') {
|
||||
cmpLog.value.push(`Comparing ${msg.total_models} models…`)
|
||||
} else if (msg.type === 'model_start') {
|
||||
cmpLog.value.push(`→ ${msg.model}…`)
|
||||
} else if (msg.type === 'model_done') {
|
||||
const status = msg.error
|
||||
? `✕ ${msg.error}`
|
||||
: `✓ ${(msg.elapsed_ms / 1000).toFixed(1)}s`
|
||||
cmpLog.value.push(` ${msg.model}: ${status}`)
|
||||
cmpResults.value.push({
|
||||
model: msg.model,
|
||||
response: msg.response,
|
||||
elapsed_ms: msg.elapsed_ms,
|
||||
error: msg.error ?? null,
|
||||
})
|
||||
} else if (msg.type === 'complete') {
|
||||
cmpRunning.value = false
|
||||
es.close()
|
||||
}
|
||||
} catch { /* ignore malformed frames */ }
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
cmpLog.value.push('Connection error.')
|
||||
cmpRunning.value = false
|
||||
es.close()
|
||||
cmpEventSource.value = null
|
||||
}
|
||||
}
|
||||
|
||||
function cancelCompare() {
|
||||
cmpEventSource.value?.close()
|
||||
cmpEventSource.value = null
|
||||
cmpRunning.value = false
|
||||
cmpLog.value.push('Cancelled.')
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadLlmTasks()
|
||||
loadLlmModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.compare-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
/* ── Source toggle ──────────────────────────────────────── */
|
||||
.source-toggle {
|
||||
display: inline-flex;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
align-self: flex-start;
|
||||
}
|
||||
|
||||
.source-btn {
|
||||
padding: 0.4rem 1rem;
|
||||
font-size: 0.83rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
font-weight: 500;
|
||||
border: none;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
.source-btn:not(:last-child) { border-right: 1px solid var(--color-border, #d0d7e8); }
|
||||
.source-btn.active { background: var(--app-primary, #2A6080); color: #fff; }
|
||||
.source-btn:not(.active):hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
|
||||
/* ── Voice prompt list ──────────────────────────────────── */
|
||||
.style-prompt-list { flex-direction: column !important; flex-wrap: nowrap !important; padding-left: 0 !important; gap: 0.4rem !important; }
|
||||
|
||||
.style-prompt-row {
|
||||
flex-direction: column !important;
|
||||
align-items: flex-start !important;
|
||||
gap: 0.15rem !important;
|
||||
padding: 0.5rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.35rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
cursor: pointer;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
.style-prompt-row:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.style-prompt-row:has(input:checked) {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
.style-prompt-row input { display: none; }
|
||||
|
||||
.style-prompt-tag {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.72rem;
|
||||
color: var(--app-primary, #2A6080);
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.style-prompt-title {
|
||||
font-size: 0.83rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
/* ── Buttons ────────────────────────────────────────────── */
|
||||
.btn-run {
|
||||
padding: 0.45rem 1.1rem;
|
||||
border-radius: 0.375rem;
|
||||
border: none;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-run:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.45rem 0.9rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-text-secondary, #6b7a99);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-cancel:hover {
|
||||
background: color-mix(in srgb, var(--color-text-secondary, #6b7a99) 12%, transparent);
|
||||
}
|
||||
|
||||
/* ── Run controls row ───────────────────────────────────── */
|
||||
.run-controls {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.log-lines {
|
||||
max-height: 160px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #fff);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.1rem;
|
||||
}
|
||||
|
||||
.log-line { color: var(--color-text, #1a2338); line-height: 1.5; }
|
||||
|
||||
/* ── Chart title ────────────────────────────────────────── */
|
||||
.chart-title {
|
||||
font-size: 0.95rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Model Picker ───────────────────────────────────────── */
|
||||
.model-picker {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.picker-summary {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
padding: 0.65rem 0.9rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
list-style: none;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
.picker-summary::-webkit-details-marker { display: none; }
|
||||
.picker-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
||||
details[open] .picker-summary::before { content: '▼ '; }
|
||||
|
||||
.picker-title {
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-badge {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 1rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.picker-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.picker-loading, .picker-empty {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.picker-category {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.picker-cat-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.45rem;
|
||||
font-size: 0.82rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text, #1a2338);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.picker-cat-name { /* inherits from cat-header or section */ }
|
||||
|
||||
.picker-cat-section {
|
||||
font-weight: 600;
|
||||
font-size: 0.82rem;
|
||||
padding: 0.35rem 0;
|
||||
display: block;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-model-list {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.35rem 0.75rem;
|
||||
padding-left: 1.4rem;
|
||||
}
|
||||
|
||||
.picker-model-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
font-size: 0.82rem;
|
||||
cursor: pointer;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-model-name {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
max-width: 18ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.picker-adapter-type {
|
||||
font-size: 0.68rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
padding: 0.05rem 0.3rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
/* ── Prompt editor ──────────────────────────────────────── */
|
||||
.prompt-label {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.cmp-prompt-editor {
|
||||
width: 100%;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.85rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
resize: vertical;
|
||||
line-height: 1.5;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.cmp-prompt-editor:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: -1px;
|
||||
}
|
||||
|
||||
/* ── Results grid ───────────────────────────────────────── */
|
||||
.cmp-results-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
|
||||
gap: 1rem;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.cmp-result-card {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.cmp-result-card.cmp-error {
|
||||
border-color: #fca5a5;
|
||||
}
|
||||
|
||||
.cmp-result-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.cmp-model-name {
|
||||
font-size: 0.82rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.cmp-meta {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
flex-shrink: 0;
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
|
||||
.err-badge {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
border-radius: 9999px;
|
||||
padding: 0.1rem 0.45rem;
|
||||
font-size: 0.7rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.cmp-response, .cmp-error-text {
|
||||
padding: 0.75rem;
|
||||
font-size: 0.82rem;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.cmp-error-text { color: #b91c1c; }
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.picker-model-list { padding-left: 0; }
|
||||
.picker-model-name { max-width: 14ch; }
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import CompareView from './CompareView.vue'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ tasks: [], types: [], models: [] }),
|
||||
text: async () => '',
|
||||
}))
|
||||
vi.stubGlobal('EventSource', class {
|
||||
onmessage = null
|
||||
onerror = null
|
||||
close() {}
|
||||
})
|
||||
})
|
||||
|
||||
describe('CompareView', () => {
|
||||
it('renders page title "Compare"', async () => {
|
||||
const w = mount(CompareView)
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Compare')
|
||||
})
|
||||
|
||||
it('wraps CompareTab component', async () => {
|
||||
const w = mount(CompareView)
|
||||
await flushPromises()
|
||||
// CompareTab renders a .compare-tab root div
|
||||
expect(w.find('.compare-tab').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
<template>
|
||||
<div class="compare-view">
|
||||
<header class="compare-header">
|
||||
<h1 class="page-title">🔍 Compare</h1>
|
||||
</header>
|
||||
<CompareTab />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import CompareTab from './CompareTab.vue'
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.compare-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.compare-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import DashboardView from './DashboardView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: { template: '<div />' } },
|
||||
{ path: '/eval/benchmark', component: { template: '<div />' } },
|
||||
{ path: '/train/jobs', component: { template: '<div />' } },
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
const baseDashboard = {
|
||||
labeled_since_last_eval: 0,
|
||||
last_eval_timestamp: null,
|
||||
last_eval_best_score: null,
|
||||
active_jobs: [],
|
||||
corrections_export_ready: 0,
|
||||
signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false },
|
||||
}
|
||||
|
||||
function mockFetch(overrides: Partial<typeof baseDashboard> = {}) {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ ...baseDashboard, ...overrides }),
|
||||
text: async () => '',
|
||||
}))
|
||||
}
|
||||
|
||||
beforeEach(() => mockFetch())
|
||||
|
||||
describe('DashboardView', () => {
|
||||
it('renders page title', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('Dashboard')
|
||||
})
|
||||
|
||||
it('shows three stage cards: Data, Eval, Train', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.stage-card[data-stage="data"]').exists()).toBe(true)
|
||||
expect(w.find('.stage-card[data-stage="eval"]').exists()).toBe(true)
|
||||
expect(w.find('.stage-card[data-stage="train"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows labeled_since_last_eval count in Data card', async () => {
|
||||
mockFetch({ labeled_since_last_eval: 42 })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.stage-card[data-stage="data"]').text()).toContain('42')
|
||||
})
|
||||
|
||||
it('does NOT show Run Eval CTA when data_to_eval is false', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const dataCard = w.find('.stage-card[data-stage="data"]')
|
||||
expect(dataCard.find('.cta-btn').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('shows Run Eval CTA when data_to_eval is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: true, eval_to_train: false, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const dataCard = w.find('.stage-card[data-stage="data"]')
|
||||
expect(dataCard.find('.cta-btn').exists()).toBe(true)
|
||||
expect(dataCard.find('.cta-btn').text()).toContain('Run Eval')
|
||||
})
|
||||
|
||||
it('shows Queue Finetune CTA when eval_to_train is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: true, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalCard = w.find('.stage-card[data-stage="eval"]')
|
||||
expect(evalCard.find('.cta-btn').text()).toContain('Queue Finetune')
|
||||
})
|
||||
|
||||
it('shows Register in Fleet CTA when train_to_fleet is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: true } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainCard = w.find('.stage-card[data-stage="train"]')
|
||||
expect(trainCard.find('.cta-btn').text()).toContain('Register in Fleet')
|
||||
})
|
||||
|
||||
it('shows active job status pills in Train card', async () => {
|
||||
mockFetch({ active_jobs: [{ id: 'j1', type: 'classifier', model_key: 'deberta-v3', status: 'running' }] })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainCard = w.find('.stage-card[data-stage="train"]')
|
||||
expect(trainCard.find('.status-pill').exists()).toBe(true)
|
||||
expect(trainCard.text()).toContain('deberta-v3')
|
||||
})
|
||||
|
||||
it('shows last eval score in Eval card when present', async () => {
|
||||
mockFetch({ last_eval_best_score: 0.821 })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalCard = w.find('.stage-card[data-stage="eval"]')
|
||||
expect(evalCard.text()).toContain('82.1%')
|
||||
})
|
||||
|
||||
it('shows error state when API call fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 503, text: async () => '' }))
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows refresh button', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.refresh-btn').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
|
@ -1,406 +0,0 @@
|
|||
<template>
|
||||
<div class="dashboard-view">
|
||||
<header class="dashboard-header">
|
||||
<h1 class="page-title">📊 Dashboard</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="load" aria-label="Refresh dashboard">
|
||||
🔄
|
||||
</button>
|
||||
</header>
|
||||
|
||||
<div v-if="loading && !data" class="loading-state">Loading…</div>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="load">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="data" class="flywheel-grid">
|
||||
|
||||
<!-- ① Data card -->
|
||||
<div class="stage-card" data-stage="data">
|
||||
<div class="card-header">
|
||||
<span class="card-step">①</span>
|
||||
<h2 class="card-title">Data</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p class="card-metric">
|
||||
<strong class="metric-value">{{ data.labeled_since_last_eval.toLocaleString() }}</strong>
|
||||
<span class="metric-label"> labeled since last eval</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ② Eval card -->
|
||||
<div class="stage-card" data-stage="eval">
|
||||
<div class="card-header">
|
||||
<span class="card-step">②</span>
|
||||
<h2 class="card-title">Eval</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="bench-run-table">
|
||||
<div
|
||||
v-for="(run, type) in data.recent_bench_runs"
|
||||
:key="type"
|
||||
class="bench-run-row"
|
||||
>
|
||||
<span class="bench-type-label">{{ BENCH_LABELS[type as BenchType] ?? type }}</span>
|
||||
<span class="bench-run-time" :class="{ 'metric-muted': !run.timestamp }">
|
||||
{{ run.timestamp ? formatBenchTs(run.timestamp) : '—' }}
|
||||
</span>
|
||||
<span v-if="run.score != null" class="bench-run-score">
|
||||
{{ formatScore(run.score) }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="data.signals.eval_to_train" class="card-cta">
|
||||
<RouterLink to="/train/jobs" class="cta-btn">Queue Finetune</RouterLink>
|
||||
</div>
|
||||
<div v-if="data.signals.data_to_eval" class="card-cta">
|
||||
<RouterLink to="/eval/benchmark" class="cta-btn">Run Eval</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ③ Train card -->
|
||||
<div class="stage-card" data-stage="train">
|
||||
<div class="card-header">
|
||||
<span class="card-step">③</span>
|
||||
<h2 class="card-title">Train</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<template v-if="data.active_jobs.length > 0">
|
||||
<div
|
||||
v-for="job in data.active_jobs"
|
||||
:key="job.id"
|
||||
class="job-row"
|
||||
>
|
||||
<span class="job-key">{{ job.model_key }}</span>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<p v-else class="card-metric metric-muted">No active jobs</p>
|
||||
|
||||
<p v-if="data.corrections_export_ready > 0" class="card-metric">
|
||||
<strong class="metric-value">{{ data.corrections_export_ready }}</strong>
|
||||
<span class="metric-label"> corrections ready</span>
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="data.signals.train_to_fleet" class="card-cta">
|
||||
<RouterLink to="/fleet" class="cta-btn">Register in Fleet</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
interface ActiveJob {
|
||||
id: string
|
||||
type: string
|
||||
model_key: string
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
interface BenchRun {
|
||||
timestamp: string | null
|
||||
metric: string | null
|
||||
score: number | null
|
||||
}
|
||||
|
||||
type BenchType = 'classifier' | 'llm' | 'style' | 'plans'
|
||||
|
||||
interface DashboardData {
|
||||
labeled_since_last_eval: number
|
||||
last_eval_timestamp: string | null
|
||||
last_eval_best_score: number | null
|
||||
active_jobs: ActiveJob[]
|
||||
corrections_export_ready: number
|
||||
recent_bench_runs: Record<BenchType, BenchRun>
|
||||
signals: DashboardSignals
|
||||
}
|
||||
|
||||
const BENCH_LABELS: Record<BenchType, string> = {
|
||||
classifier: 'Classifier',
|
||||
llm: 'LLM Eval',
|
||||
style: 'Style',
|
||||
plans: 'Planning',
|
||||
}
|
||||
|
||||
const data = ref<DashboardData | null>(null)
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
function formatBenchTs(ts: string): string {
|
||||
const date = new Date(ts)
|
||||
if (!isNaN(date.getTime())) {
|
||||
const diff = Date.now() - date.getTime()
|
||||
const mins = Math.floor(diff / 60000)
|
||||
if (mins < 1) return 'just now'
|
||||
if (mins < 60) return `${mins}m ago`
|
||||
const hrs = Math.floor(mins / 60)
|
||||
if (hrs < 24) return `${hrs}h ago`
|
||||
return `${Math.floor(hrs / 24)}d ago`
|
||||
}
|
||||
// Non-ISO: show as-is (plans bench uses "YYYY-MM-DD HH:MM")
|
||||
return ts.length > 16 ? ts.slice(0, 16) : ts
|
||||
}
|
||||
|
||||
function formatScore(score: number): string {
|
||||
return `${(score * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
async function load() {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const res = await fetch('/api/dashboard')
|
||||
if (!res.ok) {
|
||||
error.value = `Could not load dashboard (HTTP ${res.status}).`
|
||||
return
|
||||
}
|
||||
data.value = await res.json() as DashboardData
|
||||
} catch {
|
||||
error.value = 'Network error. Is the Avocet API running?'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => load())
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.dashboard-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.dashboard-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
/* ── Flywheel grid ── */
|
||||
.flywheel-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
@media (max-width: 680px) {
|
||||
.flywheel-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── Stage cards ── */
|
||||
.stage-card {
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-lg, 1rem);
|
||||
padding: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
padding-bottom: 0.6rem;
|
||||
}
|
||||
|
||||
.card-step {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.card-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.card-metric {
|
||||
margin: 0;
|
||||
font-size: 0.875rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 1.05rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.metric-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.card-cta { margin-top: auto; }
|
||||
|
||||
.cta-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
text-align: center;
|
||||
padding: 0.5rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 0.375rem;
|
||||
text-decoration: none;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.cta-btn:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 85%, black); }
|
||||
|
||||
/* ── Bench run table ── */
|
||||
.bench-run-table {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.bench-run-row {
|
||||
display: grid;
|
||||
grid-template-columns: 6rem 1fr auto;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.bench-type-label {
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.bench-run-time {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.bench-run-score {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
color: var(--app-primary, #2A6080);
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
padding: 0.1rem 0.35rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Job pills ── */
|
||||
.job-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.job-key {
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.75rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: 100px;
|
||||
font-weight: 600;
|
||||
flex-shrink: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.status-pill.status-running { background: #d4f4e0; color: #1a7a3a; }
|
||||
.status-pill.status-queued { background: #fef3cd; color: #856404; }
|
||||
.status-pill.status-failed { background: #fde8e8; color: #842029; }
|
||||
.status-pill.status-completed { background: #e0f0ff; color: #0c5481; }
|
||||
|
||||
/* ── State indicators ── */
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
background: #fde8e8;
|
||||
color: #842029;
|
||||
border: 1px solid #f5c2c7;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.75rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
background: transparent;
|
||||
border: 1px solid currentColor;
|
||||
border-radius: 0.25rem;
|
||||
color: inherit;
|
||||
cursor: pointer;
|
||||
font-size: 0.75rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,705 +0,0 @@
|
|||
<template>
|
||||
<div class="embed-compare-page">
|
||||
<!-- Step indicator (non-interactive) -->
|
||||
<ol class="step-indicator" aria-label="Setup progress">
|
||||
<li :class="{ complete: corpus.length > 0 }">Corpus</li>
|
||||
<li :class="{ complete: queries.length > 0 }">Queries</li>
|
||||
<li :class="{ complete: selectedModels.length > 0 }">Models</li>
|
||||
<li :class="{ complete: hasResults }">Run & Rate</li>
|
||||
</ol>
|
||||
|
||||
<!-- Persistent aria-live region — always in DOM, never v-if -->
|
||||
<div
|
||||
ref="liveRegion"
|
||||
class="sr-live"
|
||||
aria-live="polite"
|
||||
aria-atomic="true"
|
||||
v-text="liveMessage"
|
||||
/>
|
||||
|
||||
<!-- ① Corpus section -->
|
||||
<section class="card" aria-labelledby="corpus-heading">
|
||||
<h2 id="corpus-heading">① Corpus</h2>
|
||||
<div class="corpus-controls">
|
||||
<div class="field">
|
||||
<label for="corpus-paste">Paste chunks (one per line)</label>
|
||||
<textarea
|
||||
id="corpus-paste"
|
||||
v-model="rawCorpus"
|
||||
rows="6"
|
||||
placeholder="Paste one chunk per line, or use Import below..."
|
||||
@change="onCorpusPaste"
|
||||
/>
|
||||
</div>
|
||||
<div class="import-row">
|
||||
<label for="imitate-product-select">Import from product</label>
|
||||
<select id="imitate-product-select" v-model="selectedProduct">
|
||||
<option value="">-- select product --</option>
|
||||
<option
|
||||
v-for="p in imitateProducts"
|
||||
:key="p.id"
|
||||
:value="p.id"
|
||||
>{{ p.name }}</option>
|
||||
</select>
|
||||
<button
|
||||
class="btn-secondary"
|
||||
:disabled="!selectedProduct || importing"
|
||||
@click="importCorpus"
|
||||
>
|
||||
{{ importing ? 'Importing…' : 'Import' }}
|
||||
</button>
|
||||
<span v-if="importError" class="error-text" role="alert">{{ importError }}</span>
|
||||
</div>
|
||||
<p v-if="corpus.length > 0" class="corpus-count">
|
||||
{{ corpus.length }} chunk{{ corpus.length === 1 ? '' : 's' }} loaded.
|
||||
</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- ② Queries section -->
|
||||
<section class="card" aria-labelledby="queries-heading">
|
||||
<h2 id="queries-heading">② Queries</h2>
|
||||
<div class="field">
|
||||
<label for="query-input">Enter queries (one per line)</label>
|
||||
<textarea
|
||||
id="query-input"
|
||||
v-model="rawQueries"
|
||||
rows="4"
|
||||
placeholder="One query per line..."
|
||||
@change="onQueriesChange"
|
||||
/>
|
||||
</div>
|
||||
<p v-if="queries.length > 0" class="query-count">
|
||||
{{ queries.length }} quer{{ queries.length === 1 ? 'y' : 'ies' }}.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- ③ Model selection -->
|
||||
<section class="card" aria-labelledby="models-heading">
|
||||
<h2 id="models-heading">③ Models</h2>
|
||||
<p v-if="loadingModels" class="muted">Loading models from Ollama…</p>
|
||||
<p v-else-if="modelsError" class="error-text" role="alert">{{ modelsError }}</p>
|
||||
<ul v-else class="model-list" role="list">
|
||||
<li v-for="m in availableModels" :key="m.name">
|
||||
<label class="model-checkbox">
|
||||
<input
|
||||
type="checkbox"
|
||||
:value="m.name"
|
||||
v-model="selectedModels"
|
||||
/>
|
||||
{{ m.name }}
|
||||
<span class="model-size muted" aria-label="model size">
|
||||
{{ formatBytes(m.size) }}
|
||||
</span>
|
||||
</label>
|
||||
</li>
|
||||
</ul>
|
||||
<p v-if="availableModels.length === 0 && !loadingModels && !modelsError" class="muted">
|
||||
No Ollama models found. Pull an embedding model first.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- ④ Run controls -->
|
||||
<section class="card run-controls" aria-labelledby="run-heading">
|
||||
<h2 id="run-heading">④ Run</h2>
|
||||
<div class="run-row">
|
||||
<div class="field-inline">
|
||||
<label for="top-k-input">Results per query</label>
|
||||
<input
|
||||
id="top-k-input"
|
||||
type="number"
|
||||
v-model.number="topK"
|
||||
min="1"
|
||||
max="20"
|
||||
style="width: 5rem"
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!canRun || running"
|
||||
@click="startRun"
|
||||
>
|
||||
{{ running ? 'Running…' : 'Run' }}
|
||||
</button>
|
||||
<button
|
||||
v-if="running"
|
||||
class="btn-danger"
|
||||
aria-label="Cancel embedding run"
|
||||
@click="cancelRun"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
<p v-if="!canRun && !running" class="muted">
|
||||
Fill corpus, at least one query, and select at least one model to run.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- Results -->
|
||||
<section
|
||||
v-if="hasResults"
|
||||
class="card results-section"
|
||||
aria-labelledby="results-heading"
|
||||
>
|
||||
<h2 id="results-heading">Results</h2>
|
||||
|
||||
<!-- Query pagination -->
|
||||
<div class="query-nav" role="navigation" aria-label="Query navigation">
|
||||
<button
|
||||
class="btn-secondary"
|
||||
aria-label="Previous query"
|
||||
:disabled="currentQueryIdx === 0"
|
||||
@click="currentQueryIdx--"
|
||||
>‹</button>
|
||||
<span class="query-counter">
|
||||
Query {{ currentQueryIdx + 1 }} of {{ uniqueQueries.length }}:
|
||||
<em>{{ uniqueQueries[currentQueryIdx] }}</em>
|
||||
</span>
|
||||
<button
|
||||
class="btn-secondary"
|
||||
aria-label="Next query"
|
||||
:disabled="currentQueryIdx >= uniqueQueries.length - 1"
|
||||
@click="currentQueryIdx++"
|
||||
>›</button>
|
||||
</div>
|
||||
|
||||
<!-- Results table: one column per model -->
|
||||
<div class="table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th scope="col" class="rank-col">#</th>
|
||||
<th
|
||||
v-for="model in selectedModels"
|
||||
:key="model"
|
||||
scope="col"
|
||||
>{{ model }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="rank in topK" :key="rank">
|
||||
<td class="rank-col muted">{{ rank }}</td>
|
||||
<td
|
||||
v-for="model in selectedModels"
|
||||
:key="model"
|
||||
class="hit-cell"
|
||||
>
|
||||
<template v-if="getHit(currentQueryIdx, model, rank - 1) as hit">
|
||||
<div class="hit-text">{{ hit.text }}</div>
|
||||
<!-- Visual score bar: decorative only -->
|
||||
<div class="score-row">
|
||||
<div class="score-bar-wrap" aria-hidden="true">
|
||||
<div class="score-bar" :style="{ width: `${hit.score * 100}%` }" />
|
||||
</div>
|
||||
<span class="score-label">{{ hit.score.toFixed(3) }}</span>
|
||||
</div>
|
||||
<!-- Rating buttons -->
|
||||
<div class="rating-row">
|
||||
<button
|
||||
class="rate-btn"
|
||||
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant' }"
|
||||
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant'"
|
||||
aria-label="Mark as relevant"
|
||||
@click="rate(currentQueryIdx, model, hit, 'relevant')"
|
||||
>
|
||||
👍 Relevant
|
||||
</button>
|
||||
<button
|
||||
class="rate-btn rate-btn-neg"
|
||||
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant' }"
|
||||
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant'"
|
||||
aria-label="Mark as not relevant"
|
||||
@click="rate(currentQueryIdx, model, hit, 'not_relevant')"
|
||||
>
|
||||
👎 Not relevant
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
<span v-else class="muted">—</span>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Export -->
|
||||
<section
|
||||
v-if="hasResults"
|
||||
class="card export-section"
|
||||
aria-labelledby="export-heading"
|
||||
>
|
||||
<h2 id="export-heading">Export Ratings</h2>
|
||||
<div class="export-row">
|
||||
<fieldset class="export-format-group">
|
||||
<legend>Format</legend>
|
||||
<label><input type="radio" v-model="exportFormat" value="csv" /> CSV</label>
|
||||
<label><input type="radio" v-model="exportFormat" value="json" /> JSON</label>
|
||||
</fieldset>
|
||||
<button class="btn-secondary" @click="exportRatings">Export</button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
// ── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
interface OllamaModel { name: string; size: number }
|
||||
interface ImitateProduct { id: string; name: string }
|
||||
interface HitResult { chunk_idx: number; text: string; score: number }
|
||||
interface ResultEvent {
|
||||
type: 'result'
|
||||
query_idx: number
|
||||
query: string
|
||||
model: string
|
||||
hits: HitResult[]
|
||||
}
|
||||
|
||||
// ── State ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
const rawCorpus = ref('')
|
||||
const corpus = ref<string[]>([])
|
||||
const rawQueries = ref('')
|
||||
const queries = ref<string[]>([])
|
||||
const selectedModels = ref<string[]>([])
|
||||
const topK = ref(5)
|
||||
const availableModels = ref<OllamaModel[]>([])
|
||||
const loadingModels = ref(false)
|
||||
const modelsError = ref('')
|
||||
const imitateProducts = ref<ImitateProduct[]>([])
|
||||
const selectedProduct = ref('')
|
||||
const importing = ref(false)
|
||||
const importError = ref('')
|
||||
const running = ref(false)
|
||||
const liveMessage = ref('')
|
||||
const resultEvents = ref<ResultEvent[]>([])
|
||||
const runController = ref<AbortController | null>(null)
|
||||
|
||||
const currentQueryIdx = ref(0)
|
||||
const exportFormat = ref<'csv' | 'json'>('csv')
|
||||
|
||||
type RatingMap = Record<string, Record<string, Record<number, 'relevant' | 'not_relevant'>>>
|
||||
const ratings = ref<RatingMap>({})
|
||||
|
||||
const uniqueQueries = computed(() => {
|
||||
const seen = new Set<string>()
|
||||
const out: string[] = []
|
||||
for (const e of resultEvents.value) {
|
||||
if (!seen.has(e.query)) { seen.add(e.query); out.push(e.query) }
|
||||
}
|
||||
return out
|
||||
})
|
||||
|
||||
const hasResults = computed(() => resultEvents.value.length > 0)
|
||||
const canRun = computed(
|
||||
() => corpus.value.length > 0 && queries.value.length > 0 && selectedModels.value.length > 0
|
||||
)
|
||||
|
||||
// ── Corpus helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
function onCorpusPaste() {
|
||||
const chunks = rawCorpus.value.split('\n').map(l => l.trim()).filter(Boolean)
|
||||
corpus.value = chunks
|
||||
if (chunks.length > 0) {
|
||||
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded.`
|
||||
}
|
||||
}
|
||||
|
||||
function onQueriesChange() {
|
||||
queries.value = rawQueries.value.split('\n').map(l => l.trim()).filter(Boolean)
|
||||
}
|
||||
|
||||
async function importCorpus() {
|
||||
if (!selectedProduct.value) return
|
||||
importing.value = true
|
||||
importError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/imitate/products/${selectedProduct.value}/sample-chunks`)
|
||||
if (!r.ok) {
|
||||
const text = await r.text()
|
||||
throw new Error(text || `HTTP ${r.status}`)
|
||||
}
|
||||
const data = await r.json() as { chunks?: string[] }
|
||||
const chunks = data.chunks ?? []
|
||||
corpus.value = chunks
|
||||
rawCorpus.value = chunks.join('\n')
|
||||
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded from import.`
|
||||
} catch (err) {
|
||||
importError.value = String(err)
|
||||
} finally {
|
||||
importing.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Model loading ─────────────────────────────────────────────────────────────
|
||||
|
||||
async function loadModels() {
|
||||
loadingModels.value = true
|
||||
modelsError.value = ''
|
||||
try {
|
||||
const r = await fetch('/api/embed-bench/models')
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
const data = await r.json() as { models: OllamaModel[] }
|
||||
availableModels.value = data.models
|
||||
} catch (err) {
|
||||
modelsError.value = `Failed to load models: ${err}`
|
||||
} finally {
|
||||
loadingModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
async function startRun() {
|
||||
if (!canRun.value) return
|
||||
running.value = true
|
||||
resultEvents.value = []
|
||||
liveMessage.value = 'Starting embedding run…'
|
||||
runController.value = new AbortController()
|
||||
|
||||
try {
|
||||
const resp = await fetch('/api/embed-bench/run', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
corpus: corpus.value,
|
||||
queries: queries.value,
|
||||
models: selectedModels.value,
|
||||
top_k: topK.value,
|
||||
}),
|
||||
signal: runController.value.signal,
|
||||
})
|
||||
|
||||
const reader = resp.body!.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buf = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += decoder.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue
|
||||
const event = JSON.parse(line.slice(6))
|
||||
if (event.type === 'progress') {
|
||||
liveMessage.value = event.msg
|
||||
} else if (event.type === 'result') {
|
||||
resultEvents.value.push(event as ResultEvent)
|
||||
} else if (event.type === 'done') {
|
||||
liveMessage.value = 'Run complete.'
|
||||
} else if (event.type === 'error') {
|
||||
liveMessage.value = `Error: ${event.msg}`
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if ((err as Error).name !== 'AbortError') {
|
||||
liveMessage.value = `Run failed: ${err}`
|
||||
}
|
||||
} finally {
|
||||
running.value = false
|
||||
runController.value = null
|
||||
}
|
||||
}
|
||||
|
||||
function cancelRun() {
|
||||
runController.value?.abort()
|
||||
liveMessage.value = 'Run cancelled.'
|
||||
}
|
||||
|
||||
// ── Utilities ─────────────────────────────────────────────────────────────────
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (bytes < 1_000_000) return `${(bytes / 1000).toFixed(0)} KB`
|
||||
if (bytes < 1_000_000_000) return `${(bytes / 1_000_000).toFixed(0)} MB`
|
||||
return `${(bytes / 1_000_000_000).toFixed(1)} GB`
|
||||
}
|
||||
|
||||
function getHit(queryIdx: number, model: string, rank: number): HitResult | null {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
if (!query) return null
|
||||
const ev = resultEvents.value.find(e => e.query === query && e.model === model)
|
||||
return ev?.hits[rank] ?? null
|
||||
}
|
||||
|
||||
function getRating(queryIdx: number, model: string, chunkIdx: number): string | undefined {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
return ratings.value[query]?.[model]?.[chunkIdx]
|
||||
}
|
||||
|
||||
async function rate(
|
||||
queryIdx: number,
|
||||
model: string,
|
||||
hit: HitResult,
|
||||
rating: 'relevant' | 'not_relevant',
|
||||
) {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
// Optimistic update
|
||||
if (!ratings.value[query]) ratings.value[query] = {}
|
||||
if (!ratings.value[query][model]) ratings.value[query][model] = {}
|
||||
ratings.value[query][model][hit.chunk_idx] = rating
|
||||
|
||||
try {
|
||||
await fetch('/api/embed-bench/rate', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
query,
|
||||
model,
|
||||
chunk_text: hit.text,
|
||||
chunk_idx: hit.chunk_idx,
|
||||
rating,
|
||||
}),
|
||||
})
|
||||
liveMessage.value = `Rated chunk ${hit.chunk_idx + 1} as ${rating}.`
|
||||
} catch (err) {
|
||||
liveMessage.value = `Rating failed: ${err}`
|
||||
}
|
||||
}
|
||||
|
||||
async function exportRatings() {
|
||||
const r = await fetch(`/api/embed-bench/export?format=${exportFormat.value}`)
|
||||
if (!r.ok) {
|
||||
liveMessage.value = `Export failed: HTTP ${r.status}`
|
||||
return
|
||||
}
|
||||
const blob = await r.blob()
|
||||
const disposition = r.headers.get('Content-Disposition') ?? ''
|
||||
const filenameMatch = disposition.match(/filename="([^"]+)"/)
|
||||
const filename = filenameMatch ? filenameMatch[1] : `embed_comparison.${exportFormat.value}`
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = filename
|
||||
a.click()
|
||||
URL.revokeObjectURL(url)
|
||||
liveMessage.value = `Exported ${filename}.`
|
||||
}
|
||||
|
||||
// ── Lifecycle ─────────────────────────────────────────────────────────────────
|
||||
|
||||
onMounted(() => {
|
||||
loadModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.embed-compare-page {
|
||||
padding: var(--space-4, 1.5rem);
|
||||
max-width: 1100px;
|
||||
}
|
||||
|
||||
/* Step indicator */
|
||||
.step-indicator {
|
||||
display: flex;
|
||||
gap: 0;
|
||||
list-style: none;
|
||||
margin: 0 0 var(--space-4, 1.5rem);
|
||||
padding: 0;
|
||||
border-bottom: 2px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.step-indicator li {
|
||||
padding: 0.4rem 1rem;
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -2px;
|
||||
}
|
||||
.step-indicator li.complete {
|
||||
color: var(--app-primary, #2A6080);
|
||||
border-bottom-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
/* Accessibility: screen-reader live region — visually hidden but always present */
|
||||
.sr-live {
|
||||
position: absolute;
|
||||
width: 1px; height: 1px;
|
||||
overflow: hidden;
|
||||
clip: rect(0 0 0 0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* Cards */
|
||||
.card {
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
padding: var(--space-4, 1.5rem);
|
||||
margin-bottom: var(--space-4, 1.5rem);
|
||||
}
|
||||
.card h2 {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
margin: 0 0 var(--space-3, 1rem);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.field { display: flex; flex-direction: column; gap: 0.3rem; margin-bottom: 0.75rem; }
|
||||
.field label { font-size: 0.85rem; font-weight: 600; }
|
||||
textarea, input[type="number"] {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
padding: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
.corpus-controls { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
.import-row {
|
||||
display: flex; flex-wrap: wrap; gap: 0.5rem; align-items: center;
|
||||
}
|
||||
.import-row label { font-size: 0.85rem; font-weight: 600; }
|
||||
.corpus-count, .query-count { font-size: 0.875rem; color: var(--app-primary, #2A6080); margin: 0; }
|
||||
|
||||
.model-list { list-style: none; padding: 0; margin: 0; display: flex; flex-wrap: wrap; gap: 0.5rem; }
|
||||
.model-checkbox {
|
||||
display: flex; align-items: center; gap: 0.4rem;
|
||||
font-size: 0.875rem; cursor: pointer;
|
||||
padding: 0.3rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
}
|
||||
.model-size { font-size: 0.75rem; }
|
||||
|
||||
.run-row { display: flex; flex-wrap: wrap; gap: 0.75rem; align-items: flex-end; }
|
||||
.field-inline { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.field-inline label { font-size: 0.85rem; font-weight: 600; white-space: nowrap; }
|
||||
|
||||
.btn-primary, .btn-secondary, .btn-danger {
|
||||
padding: 0.4rem 1rem;
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
border: 1px solid transparent;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-primary { background: var(--app-primary, #2A6080); color: #fff; }
|
||||
.btn-primary:hover:not(:disabled) { filter: brightness(1.1); }
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-secondary { background: var(--color-surface, #f0f4fb); color: var(--color-text, #1a2338); border-color: var(--color-border, #d0d7e8); }
|
||||
.btn-secondary:hover:not(:disabled) { background: var(--color-border, #d0d7e8); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-danger { background: var(--color-error, #c0392b); color: #fff; }
|
||||
|
||||
.muted { color: var(--color-text-muted, #4a5c7a); font-size: 0.875rem; }
|
||||
.error-text { color: var(--color-error, #c0392b); font-size: 0.875rem; }
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.import-row { flex-direction: column; align-items: flex-start; }
|
||||
.run-row { flex-direction: column; }
|
||||
.model-list { flex-direction: column; }
|
||||
}
|
||||
|
||||
/* Results table */
|
||||
.table-wrap { overflow-x: auto; }
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.results-table thead th {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 2px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.5rem 0.75rem;
|
||||
text-align: left;
|
||||
font-weight: 700;
|
||||
white-space: nowrap;
|
||||
z-index: 1;
|
||||
}
|
||||
.results-table td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
vertical-align: top;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.rank-col { width: 2rem; text-align: center; }
|
||||
|
||||
.hit-text { margin-bottom: 0.25rem; line-height: 1.4; }
|
||||
|
||||
.score-row { display: flex; align-items: center; gap: 0.4rem; margin-bottom: 0.25rem; }
|
||||
.score-bar-wrap {
|
||||
flex: 1;
|
||||
height: 6px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
border-radius: 3px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.score-bar {
|
||||
height: 100%;
|
||||
background: var(--app-primary, #2A6080);
|
||||
border-radius: 3px;
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
.score-label { font-size: 0.75rem; color: var(--color-text-muted, #4a5c7a); min-width: 3rem; text-align: right; }
|
||||
|
||||
.rating-row { display: flex; gap: 0.4rem; flex-wrap: wrap; }
|
||||
.rate-btn {
|
||||
padding: 0.2rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, border-color 0.15s;
|
||||
}
|
||||
.rate-btn.active {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 20%, transparent);
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
font-weight: 700;
|
||||
}
|
||||
.rate-btn-neg.active {
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent);
|
||||
border-color: var(--color-error, #c0392b);
|
||||
}
|
||||
|
||||
/* Query nav */
|
||||
.query-nav {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.query-counter { font-size: 0.875rem; flex: 1; }
|
||||
|
||||
/* Export */
|
||||
.export-row { display: flex; gap: 1rem; align-items: center; flex-wrap: wrap; }
|
||||
.export-format-group {
|
||||
border: none;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
.export-format-group legend {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
margin-bottom: 0.25rem;
|
||||
float: left;
|
||||
margin-right: 0.5rem;
|
||||
}
|
||||
.export-format-group label { font-size: 0.875rem; display: flex; align-items: center; gap: 0.3rem; }
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.results-table thead th,
|
||||
.results-table td { padding: 0.35rem 0.4rem; font-size: 0.8rem; }
|
||||
.query-nav { flex-direction: column; align-items: flex-start; }
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.score-bar { transition: none; }
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
<template>
|
||||
<EmbedCompareTab />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import EmbedCompareTab from './EmbedCompareTab.vue'
|
||||
</script>
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,957 +0,0 @@
|
|||
<template>
|
||||
<div class="llm-eval-tab">
|
||||
|
||||
<!-- Task Selection -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">📋 Task Selection</span>
|
||||
<span class="picker-badge">{{ llmTaskBadge }}</span>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="selectAllTasks()">All</button>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="clearAllTasks()">None</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks…</div>
|
||||
<div v-else-if="Object.keys(llmTasksByType).length === 0" class="picker-empty">
|
||||
No tasks found — check API connection.
|
||||
</div>
|
||||
<template v-else>
|
||||
<div v-for="(tasks, type) in llmTasksByType" :key="type" class="picker-category">
|
||||
<label class="picker-cat-header">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isTaskTypeAllSelected(tasks)"
|
||||
:indeterminate="isTaskTypeIndeterminate(tasks)"
|
||||
@change="toggleTaskType(tasks, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-cat-name">{{ type }}</span>
|
||||
<span class="picker-cat-count">({{ tasks.length }})</span>
|
||||
</label>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="t in tasks" :key="t.id" class="picker-model-row">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="selectedLlmTasks.has(t.id)"
|
||||
@change="toggleLlmTask(t.id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-model-name" :title="t.name">{{ t.name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Model Selection -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">🎯 Model Selection</span>
|
||||
<span class="picker-badge">{{ llmModelBadge }}</span>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="selectAllModels()">All</button>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="clearAllModels()">None</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmModelsLoading" class="picker-loading">Loading models…</div>
|
||||
<div v-else-if="Object.keys(llmModelsByService).length === 0" class="picker-empty">
|
||||
No models found — check cf-orch connection.
|
||||
</div>
|
||||
<template v-else>
|
||||
<div v-for="(models, service) in llmModelsByService" :key="service" class="picker-category">
|
||||
<label class="picker-cat-header">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isServiceAllSelected(models)"
|
||||
:indeterminate="isServiceIndeterminate(models)"
|
||||
@change="toggleService(models, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-cat-name">{{ service }}</span>
|
||||
<span class="picker-cat-count">({{ models.length }})</span>
|
||||
</label>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="m in models" :key="m.id" class="picker-model-row">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="selectedLlmModels.has(m.id)"
|
||||
@change="toggleLlmModel(m.id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-model-name" :title="m.name">{{ m.name }}</span>
|
||||
<span class="picker-adapter-type" v-if="m.tags.length">{{ m.tags.join(', ') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Node Selection -->
|
||||
<div class="node-picker" v-if="llmNodes.length > 0">
|
||||
<span class="node-picker-label">Nodes:</span>
|
||||
<label
|
||||
v-for="node in llmNodes"
|
||||
:key="node.node_id"
|
||||
class="node-chip"
|
||||
:class="{ 'node-chip--off': !enabledNodes.has(node.node_id), 'node-chip--offline': !node.online }"
|
||||
:title="node.online ? `${node.node_id} — ${node.gpus.length} GPU(s)` : `${node.node_id} — offline`"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
class="node-chip-check"
|
||||
:checked="enabledNodes.has(node.node_id)"
|
||||
:disabled="!node.online || llmRunning"
|
||||
@change="toggleNode(node.node_id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
{{ node.node_id }}
|
||||
<span class="node-chip-status" v-if="!node.online">offline</span>
|
||||
</label>
|
||||
<span class="node-picker-hint">
|
||||
{{ enabledNodeIds.length === llmNodes.filter(n => n.online).length
|
||||
? 'auto-routing (all nodes)'
|
||||
: `restricted to: ${enabledNodeIds.join(', ')}` }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Run Controls -->
|
||||
<div class="run-controls">
|
||||
<button
|
||||
class="btn-run"
|
||||
:disabled="llmRunning || selectedLlmTasks.size === 0 || selectedLlmModels.size === 0"
|
||||
@click="startLlmBenchmark"
|
||||
>
|
||||
{{ llmRunning ? '⏳ Running…' : '▶ Run LLM Eval' }}
|
||||
</button>
|
||||
<button v-if="llmRunning" class="btn-cancel" @click="cancelLlmBenchmark">✕ Cancel</button>
|
||||
<input
|
||||
v-model="llmJudgeUrl"
|
||||
class="judge-url-input"
|
||||
placeholder="Judge URL — leave empty to skip LLM judge scoring"
|
||||
:disabled="llmRunning"
|
||||
title="Optional: URL of a running cf-text service (e.g. http://10.1.10.158:8008). When set, each LLM response gets a secondary score from the judge model — adds a 'judge' column to results. Empty = primary quality scoring only."
|
||||
/>
|
||||
<label class="workers-label" title="Run this many models concurrently (requires multiple GPUs)">
|
||||
<span class="workers-prefix">workers</span>
|
||||
<input
|
||||
v-model.number="llmWorkers"
|
||||
type="number"
|
||||
min="1"
|
||||
max="8"
|
||||
class="workers-input"
|
||||
:disabled="llmRunning"
|
||||
/>
|
||||
</label>
|
||||
<span v-if="selectedLlmTasks.size === 0 || selectedLlmModels.size === 0" class="run-hint">
|
||||
Select at least one task and one model to run.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Progress log -->
|
||||
<div v-if="llmRunning || llmRunLog.length" class="run-log">
|
||||
<div class="run-log-title">
|
||||
<span>{{ llmRunning ? '⏳ Running LLM eval…' : llmError ? '❌ Failed' : '✅ Done' }}</span>
|
||||
<button class="btn-ghost" @click="llmRunLog = []; llmError = ''">Clear</button>
|
||||
</div>
|
||||
<div class="log-lines" ref="llmLogEl">
|
||||
<div
|
||||
v-for="(line, i) in llmRunLog"
|
||||
:key="i"
|
||||
class="log-line"
|
||||
:class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
|
||||
>{{ line }}</div>
|
||||
</div>
|
||||
<p v-if="llmError" class="run-error">{{ llmError }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Results table -->
|
||||
<template v-if="llmResults.length > 0">
|
||||
<h2 class="chart-title">LLM Eval Results</h2>
|
||||
<div class="heatmap-scroll">
|
||||
<table class="heatmap llm-results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="hm-label-col">Model</th>
|
||||
<th class="hm-model-col">overall</th>
|
||||
<th v-if="llmHasJudge" class="hm-model-col hm-judge-col">judge</th>
|
||||
<th v-for="col in llmTaskTypeCols" :key="col" class="hm-model-col">{{ col }}</th>
|
||||
<th class="hm-model-col">tok/s</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="row in llmResults" :key="row.model_id">
|
||||
<td class="hm-label-cell llm-model-name-cell" :title="row.model_id">{{ row.model_name }}</td>
|
||||
<td
|
||||
class="hm-value-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
|
||||
>{{ pct(row.avg_quality_score) }}</td>
|
||||
<td
|
||||
v-if="llmHasJudge"
|
||||
class="hm-value-cell hm-judge-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['judge'] === row.model_id }"
|
||||
title="LLM-as-judge secondary score"
|
||||
>{{ row.avg_judge_score != null ? pct(row.avg_judge_score) : '—' }}</td>
|
||||
<td
|
||||
v-for="col in llmTaskTypeCols"
|
||||
:key="col"
|
||||
class="hm-value-cell"
|
||||
:class="{ 'bt-best': llmBestByCol[col] === row.model_id }"
|
||||
>{{ row.quality_by_task_type[col] != null ? pct(row.quality_by_task_type[col]) : '—' }}</td>
|
||||
<td class="hm-value-cell llm-tps-cell">{{ row.avg_tokens_per_sec.toFixed(1) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<p class="heatmap-hint">Run LLM Eval to refresh. Green = best per column.</p>
|
||||
</template>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, nextTick } from 'vue'
|
||||
import { useApiFetch } from '../composables/useApi'
|
||||
|
||||
// ── Types ───────────────────────────────────────────────────────────────────
|
||||
interface CfOrchTask {
|
||||
id: string
|
||||
name: string
|
||||
type: string
|
||||
prompt: string
|
||||
system: string
|
||||
}
|
||||
|
||||
interface CfOrchModel {
|
||||
name: string
|
||||
id: string
|
||||
service: string
|
||||
tags: string[]
|
||||
vram_estimate_mb?: number
|
||||
}
|
||||
|
||||
interface CfOrchNode {
|
||||
node_id: string
|
||||
online: boolean
|
||||
gpus: { gpu_id: number; name: string; vram_total_mb: number; vram_free_mb: number }[]
|
||||
}
|
||||
|
||||
interface LlmModelResult {
|
||||
model_name: string
|
||||
model_id: string
|
||||
node_id: string
|
||||
avg_tokens_per_sec: number
|
||||
avg_completion_ms: number
|
||||
avg_quality_score: number
|
||||
avg_judge_score: number | null
|
||||
finetune_candidates: number
|
||||
error_count: number
|
||||
quality_by_task_type: Record<string, number>
|
||||
judge_score_by_task_type?: Record<string, number>
|
||||
}
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
const llmTasks = ref<CfOrchTask[]>([])
|
||||
const llmTasksLoading = ref(false)
|
||||
const llmModels = ref<CfOrchModel[]>([])
|
||||
const llmModelsLoading = ref(false)
|
||||
|
||||
const selectedLlmTasks = ref<Set<string>>(new Set())
|
||||
const selectedLlmModels = ref<Set<string>>(new Set())
|
||||
|
||||
const llmRunning = ref(false)
|
||||
const llmRunLog = ref<string[]>([])
|
||||
const llmError = ref('')
|
||||
const llmResults = ref<LlmModelResult[]>([])
|
||||
const llmEventSource = ref<EventSource | null>(null)
|
||||
const llmLogEl = ref<HTMLElement | null>(null)
|
||||
const llmJudgeUrl = ref('')
|
||||
const llmWorkers = ref(1)
|
||||
const llmNodes = ref<CfOrchNode[]>([])
|
||||
const enabledNodes = ref<Set<string>>(new Set())
|
||||
|
||||
// ── Computed ────────────────────────────────────────────────────────────────
|
||||
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
|
||||
const groups: Record<string, CfOrchTask[]> = {}
|
||||
for (const t of llmTasks.value) {
|
||||
if (!groups[t.type]) groups[t.type] = []
|
||||
groups[t.type].push(t)
|
||||
}
|
||||
return groups
|
||||
})
|
||||
|
||||
const llmModelsByService = computed((): Record<string, CfOrchModel[]> => {
|
||||
const groups: Record<string, CfOrchModel[]> = {}
|
||||
for (const m of llmModels.value) {
|
||||
if (!groups[m.service]) groups[m.service] = []
|
||||
groups[m.service].push(m)
|
||||
}
|
||||
return groups
|
||||
})
|
||||
|
||||
const llmTaskBadge = computed(() => {
|
||||
const total = llmTasks.value.length
|
||||
if (total === 0) return 'No tasks available'
|
||||
const sel = selectedLlmTasks.value.size
|
||||
if (sel === total) return `All tasks (${total})`
|
||||
return `${sel} of ${total} tasks selected`
|
||||
})
|
||||
|
||||
const llmModelBadge = computed(() => {
|
||||
const total = llmModels.value.length
|
||||
if (total === 0) return 'No models available'
|
||||
const sel = selectedLlmModels.value.size
|
||||
if (sel === total) return `All models (${total})`
|
||||
return `${sel} of ${total} selected`
|
||||
})
|
||||
|
||||
const llmTaskTypeCols = computed(() => {
|
||||
const types = new Set<string>()
|
||||
for (const r of llmResults.value) {
|
||||
for (const k of Object.keys(r.quality_by_task_type ?? {})) types.add(k)
|
||||
}
|
||||
return [...types].sort()
|
||||
})
|
||||
|
||||
const llmHasJudge = computed(() =>
|
||||
llmResults.value.some(r => r.avg_judge_score != null)
|
||||
)
|
||||
|
||||
const enabledNodeIds = computed(() =>
|
||||
llmNodes.value.filter(n => n.online && enabledNodes.value.has(n.node_id)).map(n => n.node_id)
|
||||
)
|
||||
|
||||
const llmBestByCol = computed((): Record<string, string> => {
|
||||
const best: Record<string, string> = {}
|
||||
if (llmResults.value.length === 0) return best
|
||||
|
||||
let bestId = '', bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
if (r.avg_quality_score > bestVal) { bestVal = r.avg_quality_score; bestId = r.model_id }
|
||||
}
|
||||
best['overall'] = bestId
|
||||
|
||||
if (llmHasJudge.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
if (r.avg_judge_score != null && r.avg_judge_score > bestVal) {
|
||||
bestVal = r.avg_judge_score; bestId = r.model_id
|
||||
}
|
||||
}
|
||||
best['judge'] = bestId
|
||||
}
|
||||
|
||||
for (const col of llmTaskTypeCols.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
const v = r.quality_by_task_type?.[col]
|
||||
if (v != null && v > bestVal) { bestVal = v; bestId = r.model_id }
|
||||
}
|
||||
best[col] = bestId
|
||||
}
|
||||
return best
|
||||
})
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────────
|
||||
function pct(v: number): string {
|
||||
return `${(v * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
// Task picker helpers
|
||||
function isTaskTypeAllSelected(tasks: CfOrchTask[]): boolean {
|
||||
return tasks.length > 0 && tasks.every(t => selectedLlmTasks.value.has(t.id))
|
||||
}
|
||||
function isTaskTypeIndeterminate(tasks: CfOrchTask[]): boolean {
|
||||
const some = tasks.some(t => selectedLlmTasks.value.has(t.id))
|
||||
return some && !isTaskTypeAllSelected(tasks)
|
||||
}
|
||||
function toggleLlmTask(id: string, checked: boolean) {
|
||||
const next = new Set(selectedLlmTasks.value)
|
||||
checked ? next.add(id) : next.delete(id)
|
||||
selectedLlmTasks.value = next
|
||||
}
|
||||
function toggleTaskType(tasks: CfOrchTask[], checked: boolean) {
|
||||
const next = new Set(selectedLlmTasks.value)
|
||||
for (const t of tasks) {
|
||||
checked ? next.add(t.id) : next.delete(t.id)
|
||||
}
|
||||
selectedLlmTasks.value = next
|
||||
}
|
||||
|
||||
// Model picker helpers
|
||||
function isServiceAllSelected(models: CfOrchModel[]): boolean {
|
||||
return models.length > 0 && models.every(m => selectedLlmModels.value.has(m.id))
|
||||
}
|
||||
function isServiceIndeterminate(models: CfOrchModel[]): boolean {
|
||||
const some = models.some(m => selectedLlmModels.value.has(m.id))
|
||||
return some && !isServiceAllSelected(models)
|
||||
}
|
||||
function toggleLlmModel(id: string, checked: boolean) {
|
||||
const next = new Set(selectedLlmModels.value)
|
||||
checked ? next.add(id) : next.delete(id)
|
||||
selectedLlmModels.value = next
|
||||
}
|
||||
function toggleService(models: CfOrchModel[], checked: boolean) {
|
||||
const next = new Set(selectedLlmModels.value)
|
||||
for (const m of models) {
|
||||
checked ? next.add(m.id) : next.delete(m.id)
|
||||
}
|
||||
selectedLlmModels.value = next
|
||||
}
|
||||
function selectAllTasks() { selectedLlmTasks.value = new Set(llmTasks.value.map(t => t.id)) }
|
||||
function clearAllTasks() { selectedLlmTasks.value = new Set() }
|
||||
function selectAllModels() { selectedLlmModels.value = new Set(llmModels.value.map(m => m.id)) }
|
||||
function clearAllModels() { selectedLlmModels.value = new Set() }
|
||||
function toggleNode(id: string, checked: boolean) {
|
||||
const next = new Set(enabledNodes.value)
|
||||
checked ? next.add(id) : next.delete(id)
|
||||
enabledNodes.value = next
|
||||
}
|
||||
|
||||
// ── Data loaders ─────────────────────────────────────────────────────────────
|
||||
async function loadLlmTasks() {
|
||||
llmTasksLoading.value = true
|
||||
const { data } = await useApiFetch<{ tasks: CfOrchTask[]; types: string[] }>('/api/cforch/tasks')
|
||||
llmTasksLoading.value = false
|
||||
if (data?.tasks) {
|
||||
llmTasks.value = data.tasks
|
||||
selectedLlmTasks.value = new Set(data.tasks.map(t => t.id))
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmModels() {
|
||||
llmModelsLoading.value = true
|
||||
const { data } = await useApiFetch<{ models: CfOrchModel[] }>('/api/cforch/models')
|
||||
llmModelsLoading.value = false
|
||||
if (data?.models) {
|
||||
llmModels.value = data.models
|
||||
selectedLlmModels.value = new Set(data.models.map(m => m.id))
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmResults() {
|
||||
const { data } = await useApiFetch<LlmModelResult[]>('/api/cforch/results')
|
||||
if (Array.isArray(data) && data.length > 0) {
|
||||
llmResults.value = data
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmConfig() {
|
||||
const { data } = await useApiFetch<{ judge_url?: string }>('/api/cforch/config')
|
||||
if (data?.judge_url && !llmJudgeUrl.value) {
|
||||
llmJudgeUrl.value = data.judge_url
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmNodes() {
|
||||
const { data } = await useApiFetch<{ nodes: CfOrchNode[] }>('/api/cforch/nodes')
|
||||
if (data?.nodes) {
|
||||
llmNodes.value = data.nodes
|
||||
enabledNodes.value = new Set(data.nodes.filter(n => n.online).map(n => n.node_id))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run / cancel ──────────────────────────────────────────────────────────────
|
||||
function startLlmBenchmark() {
|
||||
llmRunning.value = true
|
||||
llmRunLog.value = []
|
||||
llmError.value = ''
|
||||
|
||||
const params = new URLSearchParams()
|
||||
const taskIds = [...selectedLlmTasks.value].join(',')
|
||||
if (taskIds) params.set('task_ids', taskIds)
|
||||
const modelIds = [...selectedLlmModels.value].join(',')
|
||||
if (modelIds) params.set('model_ids', modelIds)
|
||||
if (llmJudgeUrl.value.trim()) params.set('judge_url', llmJudgeUrl.value.trim())
|
||||
if (llmWorkers.value > 1) params.set('workers', String(llmWorkers.value))
|
||||
const onlineNodeIds = llmNodes.value.filter(n => n.online).map(n => n.node_id)
|
||||
const isRestricted = enabledNodeIds.value.length < onlineNodeIds.length
|
||||
if (isRestricted && enabledNodeIds.value.length > 0) {
|
||||
params.set('node_ids', enabledNodeIds.value.join(','))
|
||||
}
|
||||
|
||||
const es = new EventSource(`/api/cforch/run?${params}`)
|
||||
llmEventSource.value = es
|
||||
|
||||
es.onmessage = async (e: MessageEvent) => {
|
||||
const msg = JSON.parse(e.data)
|
||||
if (msg.type === 'progress' && typeof msg.message === 'string') {
|
||||
llmRunLog.value.push(msg.message)
|
||||
await nextTick()
|
||||
llmLogEl.value?.scrollTo({ top: llmLogEl.value.scrollHeight, behavior: 'smooth' })
|
||||
} else if (msg.type === 'result' && Array.isArray(msg.summary)) {
|
||||
llmResults.value = msg.summary
|
||||
} else if (msg.type === 'complete') {
|
||||
llmRunning.value = false
|
||||
es.close()
|
||||
llmEventSource.value = null
|
||||
} else if (msg.type === 'error' && typeof msg.message === 'string') {
|
||||
llmError.value = msg.message
|
||||
llmRunning.value = false
|
||||
es.close()
|
||||
llmEventSource.value = null
|
||||
}
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
if (llmRunning.value) llmError.value = 'Connection lost'
|
||||
llmRunning.value = false
|
||||
es.close()
|
||||
llmEventSource.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function cancelLlmBenchmark() {
|
||||
llmEventSource.value?.close()
|
||||
llmEventSource.value = null
|
||||
llmRunning.value = false
|
||||
await fetch('/api/cforch/cancel', { method: 'POST' }).catch(() => {})
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadLlmTasks()
|
||||
loadLlmModels()
|
||||
loadLlmResults()
|
||||
loadLlmConfig()
|
||||
loadLlmNodes()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.llm-eval-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
/* ── Buttons ────────────────────────────────────────────── */
|
||||
.btn-run {
|
||||
padding: 0.45rem 1.1rem;
|
||||
border-radius: 0.375rem;
|
||||
border: none;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-run:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.45rem 0.9rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-text-secondary, #6b7a99);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-cancel:hover {
|
||||
background: color-mix(in srgb, var(--color-text-secondary, #6b7a99) 12%, transparent);
|
||||
}
|
||||
|
||||
.btn-ghost {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
padding: 0.1rem 0.3rem;
|
||||
border-radius: 0.2rem;
|
||||
}
|
||||
.btn-ghost:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
/* ── Run controls row ───────────────────────────────────── */
|
||||
.run-controls {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.run-hint {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.judge-url-input {
|
||||
flex: 1;
|
||||
min-width: 14rem;
|
||||
max-width: 24rem;
|
||||
padding: 0.35rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
.judge-url-input:disabled { opacity: 0.5; }
|
||||
.judge-url-input::placeholder { color: var(--color-text-secondary, #6b7a99); }
|
||||
|
||||
.workers-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
white-space: nowrap;
|
||||
}
|
||||
.workers-prefix { font-family: var(--font-mono, monospace); }
|
||||
.workers-input {
|
||||
width: 3.2rem;
|
||||
padding: 0.35rem 0.4rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
text-align: center;
|
||||
}
|
||||
.workers-input:disabled { opacity: 0.5; }
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.run-log-title {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.log-lines {
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #fff);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.1rem;
|
||||
}
|
||||
|
||||
.log-line { color: var(--color-text, #1a2338); line-height: 1.5; }
|
||||
.log-line.log-error { color: var(--color-error, #ef4444); }
|
||||
|
||||
.run-error {
|
||||
margin: 0;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: color-mix(in srgb, var(--color-error, #ef4444) 10%, transparent);
|
||||
color: var(--color-error, #ef4444);
|
||||
font-size: 0.82rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
/* ── Chart title ────────────────────────────────────────── */
|
||||
.chart-title {
|
||||
font-size: 0.95rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Heatmap ────────────────────────────────────────────── */
|
||||
.heatmap-scroll {
|
||||
overflow-x: auto;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.heatmap {
|
||||
border-collapse: collapse;
|
||||
min-width: 100%;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.hm-label-col {
|
||||
text-align: left;
|
||||
min-width: 11rem;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
font-weight: 600;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.hm-model-col {
|
||||
min-width: 5rem;
|
||||
max-width: 8rem;
|
||||
padding: 0.4rem 0.5rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.7rem;
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.hm-label-cell {
|
||||
padding: 0.35rem 0.6rem;
|
||||
background: var(--color-surface, #fff);
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
white-space: nowrap;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.74rem;
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.hm-value-cell {
|
||||
padding: 0.35rem 0.5rem;
|
||||
text-align: center;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.heatmap-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* LLM-specific table styles */
|
||||
.llm-results-table .bt-best {
|
||||
color: var(--color-success, #3a7a32);
|
||||
font-weight: 700;
|
||||
background: color-mix(in srgb, var(--color-success, #3a7a32) 8%, transparent);
|
||||
}
|
||||
|
||||
.llm-model-name-cell {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
white-space: nowrap;
|
||||
max-width: 16rem;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
background: var(--color-surface, #fff);
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.35rem 0.6rem;
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.llm-tps-cell {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.hm-judge-col {
|
||||
background: color-mix(in srgb, var(--color-surface-raised, #e4ebf5) 80%, #c6d5f5);
|
||||
}
|
||||
.hm-judge-cell {
|
||||
background: color-mix(in srgb, var(--color-surface, #fff) 85%, #c6d5f5);
|
||||
font-style: italic;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
/* ── Model Picker ───────────────────────────────────────── */
|
||||
.model-picker {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.picker-summary {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
padding: 0.65rem 0.9rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
list-style: none;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
.picker-summary::-webkit-details-marker { display: none; }
|
||||
.picker-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
||||
details[open] .picker-summary::before { content: '▼ '; }
|
||||
|
||||
.picker-title {
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-badge {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 1rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.picker-bulk-btn {
|
||||
padding: 0.1rem 0.45rem;
|
||||
font-size: 0.7rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.12s, color 0.12s;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.picker-bulk-btn:hover {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.picker-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.picker-loading, .picker-empty {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.picker-category {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.picker-cat-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.45rem;
|
||||
font-size: 0.82rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text, #1a2338);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.picker-cat-name { /* inherits from cat-header */ }
|
||||
|
||||
.picker-cat-count {
|
||||
font-weight: 400;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
text-transform: none;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
|
||||
.picker-model-list {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.35rem 0.75rem;
|
||||
padding-left: 1.4rem;
|
||||
}
|
||||
|
||||
.picker-model-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
font-size: 0.82rem;
|
||||
cursor: pointer;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-model-name {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
max-width: 18ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.picker-adapter-type {
|
||||
font-size: 0.68rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
padding: 0.05rem 0.3rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.picker-model-list { padding-left: 0; }
|
||||
.picker-model-name { max-width: 14ch; }
|
||||
}
|
||||
|
||||
/* ── Node picker ────────────────────────────────────── */
|
||||
.node-picker {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
padding: 0.5rem 0.75rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
|
||||
.node-picker-label {
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.node-chip {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 1rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
color: var(--color-text, #1a2338);
|
||||
cursor: pointer;
|
||||
transition: background 0.12s, opacity 0.12s;
|
||||
}
|
||||
.node-chip--off {
|
||||
opacity: 0.45;
|
||||
background: transparent;
|
||||
}
|
||||
.node-chip--offline {
|
||||
opacity: 0.35;
|
||||
cursor: not-allowed;
|
||||
font-style: italic;
|
||||
}
|
||||
.node-chip-check { accent-color: var(--app-primary, #2A6080); }
|
||||
.node-chip-status {
|
||||
font-size: 0.66rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.node-picker-hint {
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -2,24 +2,6 @@
|
|||
<div class="models-view">
|
||||
<h1 class="page-title">🤗 Models</h1>
|
||||
|
||||
<!-- ── Fleet tab bar ─────────────────────────────── -->
|
||||
<div class="mode-toggle" role="group" aria-label="Fleet view">
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: fleetTab === 'models' }"
|
||||
@click="fleetTab = 'models'"
|
||||
>Models</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: fleetTab === 'assignments' }"
|
||||
@click="fleetTab = 'assignments'"
|
||||
>Assignments</button>
|
||||
</div>
|
||||
|
||||
<AssignmentsTab v-if="fleetTab === 'assignments'" />
|
||||
|
||||
<template v-if="fleetTab === 'models'">
|
||||
|
||||
<!-- ── 1. HF Lookup ───────────────────────────────── -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">HuggingFace Lookup</h2>
|
||||
|
|
@ -60,40 +42,11 @@
|
|||
<span v-if="lookupResult.pipeline_tag" class="chip chip-pipeline">
|
||||
{{ lookupResult.pipeline_tag }}
|
||||
</span>
|
||||
<span v-if="lookupResult.role" class="chip chip-role">
|
||||
{{ lookupResult.role }}
|
||||
</span>
|
||||
<span v-if="lookupResult.service" class="chip" :class="serviceChipClass(lookupResult.service)">
|
||||
{{ lookupResult.service }}
|
||||
</span>
|
||||
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
|
||||
{{ lookupResult.adapter_recommendation }}
|
||||
</span>
|
||||
<span v-if="selectedQuantSize > 0" class="preview-size">
|
||||
{{ humanBytes(selectedQuantSize) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- GGUF quantization picker — only shown for GGUF repos -->
|
||||
<div v-if="lookupResult.gguf_files?.length" class="quant-picker">
|
||||
<label class="quant-label" for="quant-select">Quantization</label>
|
||||
<select
|
||||
id="quant-select"
|
||||
v-model="selectedQuant"
|
||||
class="quant-select"
|
||||
aria-label="Select quantization variant"
|
||||
>
|
||||
<option :value="null" disabled>Select quantization…</option>
|
||||
<option
|
||||
v-for="f in lookupResult.gguf_files"
|
||||
:key="f.filename"
|
||||
:value="f.quant_name ?? f.filename"
|
||||
>
|
||||
{{ f.quant_name ?? f.filename }} — {{ humanBytes(f.size) }}
|
||||
</option>
|
||||
</select>
|
||||
<span class="quant-hint">
|
||||
Q5_K_M or Q6_K recommended for 8 GB GPUs. Q8_0 for max quality.
|
||||
<span v-if="lookupResult.size != null" class="preview-size">
|
||||
{{ humanBytes(lookupResult.size) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
|
@ -101,14 +54,9 @@
|
|||
{{ lookupResult.description }}
|
||||
</p>
|
||||
|
||||
<div v-if="lookupResult.warning" class="compat-warning" role="alert">
|
||||
<span class="compat-warning-icon">⚠️</span>
|
||||
<span>{{ lookupResult.warning }}</span>
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="btn-primary btn-add-queue"
|
||||
:disabled="!canAddToQueue"
|
||||
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue"
|
||||
@click="addToQueue"
|
||||
>
|
||||
{{ addingToQueue ? 'Adding…' : 'Add to queue' }}
|
||||
|
|
@ -136,43 +84,11 @@
|
|||
</button>
|
||||
</div>
|
||||
<div class="model-meta">
|
||||
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
|
||||
<span v-if="model.role" class="chip chip-role">{{ model.role }}</span>
|
||||
<span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span>
|
||||
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
|
||||
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
|
||||
<span v-if="model.quant_pattern" class="chip chip-quant">{{ model.quant_pattern }}</span>
|
||||
</div>
|
||||
<!-- Allow manual service/role assignment for unrecognized pipeline tags -->
|
||||
<div v-if="!model.service" class="classify-row queue-classify">
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.id]?.service ?? ''"
|
||||
@change="onServiceChange(model.id, ($event.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign service"
|
||||
>
|
||||
<option value="" disabled>Service…</option>
|
||||
<option v-for="svc in CLASSIFIABLE_SERVICES" :key="svc.value" :value="svc.value">{{ svc.label }}</option>
|
||||
</select>
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.id]?.role ?? ''"
|
||||
:disabled="!classifyDraft[model.id]?.service"
|
||||
@change="(e) => setClassifyRole(model.id, (e.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign role"
|
||||
>
|
||||
<option value="" disabled>Role…</option>
|
||||
<option
|
||||
v-for="role in rolesForService(classifyDraft[model.id]?.service ?? '')"
|
||||
:key="role"
|
||||
:value="role"
|
||||
>{{ role }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="model-card-actions">
|
||||
<button
|
||||
class="btn-primary btn-sm"
|
||||
@click="approveModel(model.id, classifyDraft[model.id])"
|
||||
>
|
||||
<button class="btn-primary btn-sm" @click="approveModel(model.id)">
|
||||
Approve download
|
||||
</button>
|
||||
</div>
|
||||
|
|
@ -194,8 +110,6 @@
|
|||
</div>
|
||||
<div class="model-meta">
|
||||
<span v-if="model.pipeline_tag" class="chip chip-pipeline">{{ model.pipeline_tag }}</span>
|
||||
<span v-if="model.role" class="chip chip-role">{{ model.role }}</span>
|
||||
<span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span>
|
||||
</div>
|
||||
|
||||
<div v-if="downloadErrors[model.id]" class="download-error" role="alert">
|
||||
|
|
@ -204,19 +118,14 @@
|
|||
<div v-else class="progress-wrap" :aria-label="`Download progress for ${model.repo_id}`">
|
||||
<div
|
||||
class="progress-bar"
|
||||
:style="{ width: `${downloadProgress[model.repo_id]?.pct ?? 0}%` }"
|
||||
:style="{ width: `${downloadProgress[model.id] ?? 0}%` }"
|
||||
role="progressbar"
|
||||
:aria-valuenow="downloadProgress[model.repo_id]?.pct ?? 0"
|
||||
:aria-valuenow="downloadProgress[model.id] ?? 0"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
/>
|
||||
<span class="progress-label">
|
||||
{{
|
||||
!downloadProgress[model.repo_id] ? 'Preparing…'
|
||||
: downloadProgress[model.repo_id].pct != null ? `${Math.round(downloadProgress[model.repo_id].pct!)}%`
|
||||
: downloadProgress[model.repo_id].bytes > 0 ? `${(downloadProgress[model.repo_id].bytes / 1024 / 1024).toFixed(0)} MB downloaded…`
|
||||
: 'Preparing…'
|
||||
}}
|
||||
{{ downloadProgress[model.id] == null ? 'Preparing…' : `${downloadProgress[model.id]}%` }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -230,121 +139,56 @@
|
|||
No models installed yet.
|
||||
</div>
|
||||
|
||||
<template v-else>
|
||||
<div
|
||||
v-for="group in installedByService"
|
||||
:key="group.service"
|
||||
class="installed-group"
|
||||
>
|
||||
<div class="installed-group-header">
|
||||
<span class="chip" :class="serviceChipClass(group.service)">
|
||||
{{ serviceLabel(group.service) }}
|
||||
</span>
|
||||
<span class="installed-group-count">{{ group.models.length }} model{{ group.models.length !== 1 ? 's' : '' }}</span>
|
||||
</div>
|
||||
|
||||
<div class="installed-table-wrap">
|
||||
<table class="installed-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Type</th>
|
||||
<th>Role</th>
|
||||
<th>Size</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="model in group.models" :key="model.name">
|
||||
<td class="td-name">{{ model.model_id ?? model.name }}</td>
|
||||
<td>
|
||||
<span
|
||||
class="badge"
|
||||
:class="model.type === 'finetuned' ? 'badge-accent' : 'badge-info'"
|
||||
>
|
||||
{{ model.type }}
|
||||
</span>
|
||||
</td>
|
||||
<td>
|
||||
<span v-if="model.role" class="chip chip-role chip-sm">{{ model.role }}</span>
|
||||
<span v-else>—</span>
|
||||
</td>
|
||||
<td>{{ humanBytes(model.size_bytes) }}</td>
|
||||
<td class="td-actions">
|
||||
<div v-if="!model.service" class="classify-row">
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.name]?.service ?? ''"
|
||||
@change="onServiceChange(model.name, ($event.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign service"
|
||||
>
|
||||
<option value="" disabled>Service…</option>
|
||||
<option v-for="svc in CLASSIFIABLE_SERVICES" :key="svc.value" :value="svc.value">{{ svc.label }}</option>
|
||||
</select>
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.name]?.role ?? ''"
|
||||
:disabled="!classifyDraft[model.name]?.service"
|
||||
@change="(e) => setClassifyRole(model.name, (e.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign role"
|
||||
>
|
||||
<option value="" disabled>Role…</option>
|
||||
<option
|
||||
v-for="role in rolesForService(classifyDraft[model.name]?.service ?? '')"
|
||||
:key="role"
|
||||
:value="role"
|
||||
>{{ role }}</option>
|
||||
</select>
|
||||
<button
|
||||
class="btn-primary btn-sm"
|
||||
:disabled="!classifyDraft[model.name]?.service || !classifyDraft[model.name]?.role"
|
||||
@click="saveClassify(model.name)"
|
||||
>Save</button>
|
||||
</div>
|
||||
<button
|
||||
class="btn-danger btn-sm"
|
||||
@click="deleteInstalled(model.name)"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<div v-else class="installed-table-wrap">
|
||||
<table class="installed-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Type</th>
|
||||
<th>Adapter</th>
|
||||
<th>Size</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="model in installedModels" :key="model.name">
|
||||
<td class="td-name">{{ model.name }}</td>
|
||||
<td>
|
||||
<span
|
||||
class="badge"
|
||||
:class="model.type === 'finetuned' ? 'badge-accent' : 'badge-info'"
|
||||
>
|
||||
{{ model.type }}
|
||||
</span>
|
||||
</td>
|
||||
<td>{{ model.adapter ?? '—' }}</td>
|
||||
<td>{{ humanBytes(model.size) }}</td>
|
||||
<td>
|
||||
<button
|
||||
class="btn-danger btn-sm"
|
||||
@click="deleteInstalled(model.name)"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
</template><!-- end fleetTab === 'models' -->
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, onUnmounted } from 'vue'
|
||||
import AssignmentsTab from './AssignmentsTab.vue'
|
||||
|
||||
type FleetTab = 'models' | 'assignments'
|
||||
const fleetTab = ref<FleetTab>('models')
|
||||
|
||||
// ── Type definitions ──────────────────────────────────
|
||||
|
||||
interface GgufFile {
|
||||
filename: string
|
||||
size: number
|
||||
quant_name: string | null
|
||||
}
|
||||
|
||||
interface LookupResult {
|
||||
repo_id: string
|
||||
pipeline_tag: string | null
|
||||
adapter_recommendation: string | null
|
||||
role: string | null
|
||||
service: string | null
|
||||
compatible: boolean
|
||||
warning: string | null
|
||||
model_size_bytes: number
|
||||
gguf_files: GgufFile[] | null
|
||||
size: number | null
|
||||
description: string | null
|
||||
already_installed: boolean
|
||||
already_queued: boolean
|
||||
|
|
@ -356,28 +200,20 @@ interface QueuedModel {
|
|||
status: 'pending' | 'downloading' | 'done' | 'error'
|
||||
pipeline_tag: string | null
|
||||
adapter_recommendation: string | null
|
||||
role: string | null
|
||||
service: string | null
|
||||
quant_pattern: string | null
|
||||
}
|
||||
|
||||
interface InstalledModel {
|
||||
name: string
|
||||
type: 'finetuned' | 'downloaded'
|
||||
adapter: string | null
|
||||
role: string | null
|
||||
service: string | null
|
||||
size_bytes: number
|
||||
model_id: string | null
|
||||
size: number
|
||||
}
|
||||
|
||||
interface SseProgressEvent {
|
||||
type: 'progress' | 'done' | 'error' | 'idle'
|
||||
repo_id?: string
|
||||
pct?: number
|
||||
downloaded_bytes?: number
|
||||
total_bytes?: number
|
||||
error?: string
|
||||
model_id: string
|
||||
pct: number | null
|
||||
status: 'progress' | 'done' | 'error'
|
||||
message?: string
|
||||
}
|
||||
|
||||
// ── State ─────────────────────────────────────────────
|
||||
|
|
@ -387,32 +223,11 @@ const lookupLoading = ref(false)
|
|||
const lookupError = ref<string | null>(null)
|
||||
const lookupResult = ref<LookupResult | null>(null)
|
||||
const addingToQueue = ref(false)
|
||||
const selectedQuant = ref<string | null>(null)
|
||||
|
||||
// Size of the selected GGUF file, or total model size for non-GGUF repos.
|
||||
const selectedQuantSize = computed<number>(() => {
|
||||
const r = lookupResult.value
|
||||
if (!r) return 0
|
||||
if (r.gguf_files?.length && selectedQuant.value) {
|
||||
const f = r.gguf_files.find(f => (f.quant_name ?? f.filename) === selectedQuant.value)
|
||||
return f?.size ?? r.model_size_bytes
|
||||
}
|
||||
return r.model_size_bytes
|
||||
})
|
||||
|
||||
// Disable "Add to queue" when a GGUF repo but no quant chosen yet.
|
||||
const canAddToQueue = computed(() => {
|
||||
const r = lookupResult.value
|
||||
if (!r || r.already_installed || r.already_queued || addingToQueue.value) return false
|
||||
if (r.gguf_files?.length && !selectedQuant.value) return false
|
||||
return true
|
||||
})
|
||||
|
||||
const queuedModels = ref<QueuedModel[]>([])
|
||||
const installedModels = ref<InstalledModel[]>([])
|
||||
|
||||
const downloadProgress = ref<Record<string, { pct: number | null; bytes: number }>>({})
|
||||
const classifyDraft = ref<Record<string, { service: string; role: string }>>({})
|
||||
const downloadProgress = ref<Record<string, number>>({})
|
||||
const downloadErrors = ref<Record<string, string>>({})
|
||||
|
||||
let pollInterval: ReturnType<typeof setInterval> | null = null
|
||||
|
|
@ -428,69 +243,8 @@ const downloadingModels = computed(() =>
|
|||
queuedModels.value.filter(m => m.status === 'downloading')
|
||||
)
|
||||
|
||||
const SERVICE_ORDER = ['avocet', 'cf-text', 'cf-stt', 'cf-tts', 'cf-vision', 'cf-image', 'cf-core', 'cf-voice', 'other']
|
||||
|
||||
const CLASSIFIABLE_SERVICES = [
|
||||
{ value: 'avocet', label: 'Avocet — Email Classifiers' },
|
||||
{ value: 'cf-text', label: 'cf-text — Language Models' },
|
||||
{ value: 'cf-stt', label: 'cf-stt — Speech Recognition' },
|
||||
{ value: 'cf-tts', label: 'cf-tts — Text to Speech' },
|
||||
{ value: 'cf-vision', label: 'cf-vision — Vision / VLM' },
|
||||
{ value: 'cf-image', label: 'cf-image — Image Generation' },
|
||||
{ value: 'cf-core', label: 'cf-core — Embeddings' },
|
||||
{ value: 'cf-voice', label: 'cf-voice — Audio Classification' },
|
||||
]
|
||||
|
||||
const SERVICE_ROLES: Record<string, string[]> = {
|
||||
'avocet': ['classifier', 'reranker'],
|
||||
'cf-text': ['generator'],
|
||||
'cf-stt': ['stt', 'alm'],
|
||||
'cf-tts': ['tts'],
|
||||
'cf-vision': ['vision', 'vlm', 'embedding'],
|
||||
'cf-image': ['image-gen'],
|
||||
'cf-core': ['embedding'],
|
||||
'cf-voice': ['classifier'],
|
||||
}
|
||||
|
||||
function rolesForService(service: string): string[] {
|
||||
return SERVICE_ROLES[service] ?? []
|
||||
}
|
||||
|
||||
const installedByService = computed(() => {
|
||||
const grouped: Record<string, InstalledModel[]> = {}
|
||||
for (const model of installedModels.value) {
|
||||
const key = model.service ?? 'other'
|
||||
if (!grouped[key]) grouped[key] = []
|
||||
grouped[key].push(model)
|
||||
}
|
||||
// Return ordered sections: known services first, then anything else
|
||||
const keys = [...SERVICE_ORDER.filter(s => grouped[s]), ...Object.keys(grouped).filter(k => !SERVICE_ORDER.includes(k))]
|
||||
return keys.map(key => ({ service: key, models: grouped[key] }))
|
||||
})
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────
|
||||
|
||||
const SERVICE_LABELS: Record<string, string> = {
|
||||
'avocet': 'Avocet — Email Classifiers',
|
||||
'cf-text': 'cf-text — Language Models',
|
||||
'cf-stt': 'cf-stt — Speech Recognition',
|
||||
'cf-tts': 'cf-tts — Text to Speech',
|
||||
'cf-vision': 'cf-vision — Vision / VLM',
|
||||
'cf-image': 'cf-image — Image Generation',
|
||||
'cf-core': 'cf-core — Embeddings',
|
||||
'cf-voice': 'cf-voice — Audio Classification',
|
||||
'other': 'Other — Unclassified',
|
||||
}
|
||||
|
||||
function serviceLabel(service: string): string {
|
||||
return SERVICE_LABELS[service] ?? service
|
||||
}
|
||||
|
||||
function serviceChipClass(service: string | null): string {
|
||||
if (!service) return 'chip-service-other'
|
||||
return `chip-service-${service.replace(/[^a-z0-9]/g, '-')}`
|
||||
}
|
||||
|
||||
function humanBytes(bytes: number | null): string {
|
||||
if (bytes == null) return '—'
|
||||
const units = ['B', 'KB', 'MB', 'GB', 'TB']
|
||||
|
|
@ -516,7 +270,6 @@ async function doLookup() {
|
|||
lookupLoading.value = true
|
||||
lookupError.value = null
|
||||
lookupResult.value = null
|
||||
selectedQuant.value = null
|
||||
|
||||
try {
|
||||
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
|
||||
|
|
@ -544,19 +297,10 @@ async function addToQueue() {
|
|||
if (!lookupResult.value) return
|
||||
addingToQueue.value = true
|
||||
try {
|
||||
const { repo_id, pipeline_tag, adapter_recommendation, role, service } = lookupResult.value
|
||||
const res = await fetch('/api/models/queue', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
repo_id,
|
||||
pipeline_tag,
|
||||
adapter_recommendation,
|
||||
role,
|
||||
service,
|
||||
model_size_bytes: selectedQuantSize.value,
|
||||
quant_pattern: selectedQuant.value,
|
||||
}),
|
||||
body: JSON.stringify({ repo_id: lookupResult.value.repo_id }),
|
||||
})
|
||||
if (res.ok) {
|
||||
lookupResult.value = { ...lookupResult.value, already_queued: true }
|
||||
|
|
@ -568,16 +312,8 @@ async function addToQueue() {
|
|||
}
|
||||
}
|
||||
|
||||
async function approveModel(id: string, draft?: { service: string; role: string }) {
|
||||
async function approveModel(id: string) {
|
||||
try {
|
||||
// If the user picked a service/role for an unrecognized model, patch it first.
|
||||
if (draft?.service && draft?.role) {
|
||||
await fetch(`/api/models/queue/${encodeURIComponent(id)}`, {
|
||||
method: 'PATCH',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ service: draft.service, role: draft.role }),
|
||||
})
|
||||
}
|
||||
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
|
||||
if (res.ok) {
|
||||
await loadQueue()
|
||||
|
|
@ -595,50 +331,12 @@ async function dismissModel(id: string) {
|
|||
} catch { /* ignore */ }
|
||||
}
|
||||
|
||||
function onServiceChange(name: string, service: string) {
|
||||
const roles = SERVICE_ROLES[service] ?? []
|
||||
classifyDraft.value = {
|
||||
...classifyDraft.value,
|
||||
[name]: { service, role: roles.length === 1 ? roles[0] : '' },
|
||||
}
|
||||
}
|
||||
|
||||
function setClassifyRole(name: string, role: string) {
|
||||
classifyDraft.value = {
|
||||
...classifyDraft.value,
|
||||
[name]: { ...classifyDraft.value[name], role },
|
||||
}
|
||||
}
|
||||
|
||||
async function saveClassify(name: string) {
|
||||
const draft = classifyDraft.value[name]
|
||||
if (!draft?.service || !draft?.role) return
|
||||
try {
|
||||
const res = await fetch(`/api/models/installed/${encodeURIComponent(name)}`, {
|
||||
method: 'PATCH',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ service: draft.service, role: draft.role }),
|
||||
})
|
||||
if (res.ok) {
|
||||
// Update in-place so the model moves to the correct service group
|
||||
installedModels.value = installedModels.value.map(m =>
|
||||
m.name === name ? { ...m, service: draft.service, role: draft.role } : m
|
||||
)
|
||||
const updated = { ...classifyDraft.value }
|
||||
delete updated[name]
|
||||
classifyDraft.value = updated
|
||||
await loadQueue()
|
||||
}
|
||||
} catch { /* non-fatal */ }
|
||||
}
|
||||
|
||||
async function deleteInstalled(name: string) {
|
||||
if (!window.confirm(`Delete installed model "${name}"? This cannot be undone.`)) return
|
||||
try {
|
||||
const res = await fetch(`/api/models/installed/${encodeURIComponent(name)}`, { method: 'DELETE' })
|
||||
if (res.ok) {
|
||||
installedModels.value = installedModels.value.filter(m => m.name !== name)
|
||||
await loadQueue()
|
||||
}
|
||||
} catch { /* ignore */ }
|
||||
}
|
||||
|
|
@ -672,28 +370,21 @@ function startSse() {
|
|||
return
|
||||
}
|
||||
|
||||
const { type, repo_id, pct, downloaded_bytes, error } = event
|
||||
if (!repo_id) return
|
||||
const { model_id, pct, status, message } = event
|
||||
|
||||
if (type === 'progress') {
|
||||
const bytes = downloaded_bytes ?? 0
|
||||
// pct stays null when total_bytes is unknown so we can show "X MB" instead
|
||||
const progress = (pct != null && pct > 0) ? pct : (bytes > 0 ? null : undefined)
|
||||
downloadProgress.value = { ...downloadProgress.value, [repo_id]: { pct: progress ?? null, bytes } }
|
||||
} else if (type === 'done') {
|
||||
if (status === 'progress' && pct != null) {
|
||||
downloadProgress.value = { ...downloadProgress.value, [model_id]: pct }
|
||||
} else if (status === 'done') {
|
||||
const updated = { ...downloadProgress.value }
|
||||
delete updated[repo_id]
|
||||
delete updated[model_id]
|
||||
downloadProgress.value = updated
|
||||
|
||||
queuedModels.value = queuedModels.value.filter(m => m.repo_id !== repo_id)
|
||||
queuedModels.value = queuedModels.value.filter(m => m.id !== model_id)
|
||||
loadInstalled()
|
||||
} else if (type === 'error') {
|
||||
const entry = queuedModels.value.find(m => m.repo_id === repo_id)
|
||||
if (entry) {
|
||||
downloadErrors.value = {
|
||||
...downloadErrors.value,
|
||||
[entry.id]: error ?? 'Download failed.',
|
||||
}
|
||||
} else if (status === 'error') {
|
||||
downloadErrors.value = {
|
||||
...downloadErrors.value,
|
||||
[model_id]: message ?? 'Download failed.',
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -762,39 +453,6 @@ onUnmounted(() => {
|
|||
color: var(--color-primary, #2d5a27);
|
||||
}
|
||||
|
||||
/* ── Fleet tab bar (mode-toggle pattern from BenchmarkView) ── */
|
||||
.mode-toggle {
|
||||
display: inline-flex;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
align-self: flex-start;
|
||||
}
|
||||
.mode-btn {
|
||||
padding: 0.4rem 1.1rem;
|
||||
font-size: 0.85rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
font-weight: 500;
|
||||
border: none;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
.mode-btn:not(:last-child) {
|
||||
border-right: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.mode-btn.active {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
}
|
||||
.mode-btn:not(.active):hover {
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
@media (max-width: 600px) {
|
||||
.mode-btn { padding: 0.4rem 0.65rem; font-size: 0.78rem; }
|
||||
}
|
||||
|
||||
/* ── Sections ── */
|
||||
.section {
|
||||
display: flex;
|
||||
|
|
@ -907,66 +565,10 @@ onUnmounted(() => {
|
|||
overflow: hidden;
|
||||
}
|
||||
|
||||
.compat-warning {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 0.5rem;
|
||||
padding: 0.6rem 0.75rem;
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: color-mix(in srgb, var(--color-warning, #f59e0b) 12%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-warning, #f59e0b) 40%, transparent);
|
||||
font-size: 0.82rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
line-height: 1.45;
|
||||
}
|
||||
|
||||
.compat-warning-icon {
|
||||
flex-shrink: 0;
|
||||
line-height: 1.45;
|
||||
}
|
||||
|
||||
.btn-add-queue {
|
||||
align-self: flex-start;
|
||||
}
|
||||
|
||||
/* ── Quant picker ── */
|
||||
.quant-picker {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.quant-label {
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.quant-select {
|
||||
padding: 0.4rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.9rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.quant-hint {
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.chip-quant {
|
||||
background: color-mix(in srgb, var(--color-primary, #2A6080) 15%, transparent);
|
||||
color: var(--color-primary, #2A6080);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Model cards (queue + downloads) ── */
|
||||
.model-card {
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
|
|
@ -1081,35 +683,6 @@ onUnmounted(() => {
|
|||
word-break: break-all;
|
||||
}
|
||||
|
||||
.td-actions {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.classify-row {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.classify-select {
|
||||
font-size: 0.78rem;
|
||||
padding: 0.2rem 0.4rem;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--color-border, #444);
|
||||
background: var(--color-surface, #1e1e2e);
|
||||
color: var(--color-text, #cdd6f4);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.classify-select:disabled {
|
||||
opacity: 0.4;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* ── Badges ── */
|
||||
.badge-group {
|
||||
display: flex;
|
||||
|
|
@ -1172,76 +745,6 @@ onUnmounted(() => {
|
|||
background: color-mix(in srgb, var(--color-accent, #c4732a) 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-role {
|
||||
color: var(--color-info, #1e6091);
|
||||
background: color-mix(in srgb, var(--color-info, #1e6091) 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-sm {
|
||||
font-size: 0.68rem;
|
||||
padding: 0.1rem 0.4rem;
|
||||
}
|
||||
|
||||
/* Service chips — one colour per CF service */
|
||||
.chip-service-avocet {
|
||||
color: var(--color-primary, #2d5a27);
|
||||
background: color-mix(in srgb, var(--color-primary, #2d5a27) 15%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-cf-text {
|
||||
color: #c2410c;
|
||||
background: color-mix(in srgb, #c2410c 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-cf-stt {
|
||||
color: #5e35b1;
|
||||
background: color-mix(in srgb, #5e35b1 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-cf-tts {
|
||||
color: #0277bd;
|
||||
background: color-mix(in srgb, #0277bd 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-cf-vision {
|
||||
color: #00695c;
|
||||
background: color-mix(in srgb, #00695c 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-cf-core {
|
||||
color: #6d4c41;
|
||||
background: color-mix(in srgb, #6d4c41 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-cf-voice {
|
||||
color: #ad1457;
|
||||
background: color-mix(in srgb, #ad1457 12%, var(--color-surface-alt, #dde4f0));
|
||||
}
|
||||
|
||||
.chip-service-other {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
}
|
||||
|
||||
/* ── Installed group ── */
|
||||
.installed-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.installed-group-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.25rem 0;
|
||||
}
|
||||
|
||||
.installed-group-count {
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
/* ── Buttons ── */
|
||||
.btn-primary, .btn-danger {
|
||||
padding: 0.4rem 0.9rem;
|
||||
|
|
@ -1317,7 +820,7 @@ onUnmounted(() => {
|
|||
|
||||
.installed-table th:nth-child(3),
|
||||
.installed-table td:nth-child(3) {
|
||||
display: none; /* hide Role column on very narrow screens */
|
||||
display: none; /* hide Adapter column on very narrow screens */
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -1,165 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import NodeCard from '../components/nodes/NodeCard.vue'
|
||||
import AssignmentsTab from './AssignmentsTab.vue'
|
||||
import type { NodeSummary } from '../types/nodes'
|
||||
|
||||
type Tab = 'nodes' | 'assignments'
|
||||
|
||||
const activeTab = ref<Tab>('nodes')
|
||||
const nodes = ref<NodeSummary[]>([])
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
async function fetchNodes() {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
try {
|
||||
const r = await fetch('/api/nodes-mgmt/nodes')
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
nodes.value = (await r.json()) as NodeSummary[]
|
||||
} catch (e) {
|
||||
error.value = e instanceof Error ? e.message : 'Failed to load nodes'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(fetchNodes)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<main class="fleet-page">
|
||||
<header class="fleet-header">
|
||||
<h1 class="fleet-title">Fleet</h1>
|
||||
</header>
|
||||
|
||||
<!-- Tab bar -->
|
||||
<nav class="tab-bar" role="tablist" aria-label="Fleet sections">
|
||||
<button
|
||||
id="tab-nodes"
|
||||
role="tab"
|
||||
:aria-selected="activeTab === 'nodes'"
|
||||
:class="['tab', { active: activeTab === 'nodes' }]"
|
||||
@click="activeTab = 'nodes'"
|
||||
>Nodes</button>
|
||||
<button
|
||||
id="tab-assignments"
|
||||
role="tab"
|
||||
:aria-selected="activeTab === 'assignments'"
|
||||
:class="['tab', { active: activeTab === 'assignments' }]"
|
||||
@click="activeTab = 'assignments'"
|
||||
>Assignments</button>
|
||||
</nav>
|
||||
|
||||
<!-- Nodes tab -->
|
||||
<section
|
||||
v-if="activeTab === 'nodes'"
|
||||
role="tabpanel"
|
||||
aria-labelledby="tab-nodes"
|
||||
class="tab-panel"
|
||||
>
|
||||
<div class="nodes-toolbar">
|
||||
<button class="btn-secondary btn-sm" @click="fetchNodes" :disabled="loading">Refresh</button>
|
||||
</div>
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading nodes...</span>
|
||||
</div>
|
||||
<div v-if="error" class="nodes-status nodes-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && nodes.length === 0" class="nodes-status">
|
||||
No nodes found. Check <code>coordinator_url</code> in config.
|
||||
</div>
|
||||
<div v-else-if="!loading" class="nodes-grid">
|
||||
<NodeCard
|
||||
v-for="node in nodes"
|
||||
:key="node.node_id"
|
||||
:node="node"
|
||||
@updated="fetchNodes"
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Assignments tab -->
|
||||
<section
|
||||
v-else-if="activeTab === 'assignments'"
|
||||
role="tabpanel"
|
||||
aria-labelledby="tab-assignments"
|
||||
class="tab-panel"
|
||||
>
|
||||
<AssignmentsTab />
|
||||
</section>
|
||||
</main>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.fleet-page { padding: 1.5rem; }
|
||||
|
||||
.fleet-header {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.fleet-title {
|
||||
margin: 0;
|
||||
font-size: 1.5rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
/* ── Tab bar ── */
|
||||
.tab-bar {
|
||||
display: flex;
|
||||
gap: 0;
|
||||
border-bottom: 2px solid var(--color-border);
|
||||
margin-bottom: 1.25rem;
|
||||
}
|
||||
.tab {
|
||||
padding: 0.55rem 1.1rem;
|
||||
font-size: 0.88rem;
|
||||
font-weight: 600;
|
||||
background: none;
|
||||
border: none;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -2px;
|
||||
cursor: pointer;
|
||||
color: var(--color-text-muted);
|
||||
transition: color 0.15s, border-color 0.15s;
|
||||
}
|
||||
.tab:hover { color: var(--color-text); }
|
||||
.tab.active {
|
||||
color: var(--app-primary);
|
||||
border-bottom-color: var(--app-primary);
|
||||
}
|
||||
|
||||
/* ── Tab panel ── */
|
||||
.tab-panel { min-height: 200px; }
|
||||
|
||||
/* ── Nodes toolbar ── */
|
||||
.nodes-toolbar {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
/* ── Nodes grid / status ── */
|
||||
.nodes-grid { display: flex; flex-direction: column; gap: 1.5rem; }
|
||||
.nodes-status {
|
||||
color: var(--color-text-muted);
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
}
|
||||
.nodes-error { color: var(--color-error); }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
|
||||
/* ── Shared button ── */
|
||||
.btn-secondary {
|
||||
padding: 0.4rem 0.9rem;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-secondary:hover:not(:disabled) { background: var(--color-surface-raised); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: default; }
|
||||
.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }
|
||||
</style>
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,536 +0,0 @@
|
|||
<template>
|
||||
<div class="rsv">
|
||||
<!-- Header -->
|
||||
<header class="rsv-header">
|
||||
<h1 class="rsv-title">Recipe Scan Review</h1>
|
||||
<div class="rsv-stats" v-if="stats">
|
||||
<span class="stat-chip">{{ stats.by_status?.pending ?? 0 }} pending</span>
|
||||
<span class="stat-chip stat-chip--ok">{{ stats.by_status?.approved ?? 0 }} approved</span>
|
||||
<span class="stat-chip stat-chip--edited">{{ stats.by_status?.edited ?? 0 }} edited</span>
|
||||
<span class="stat-chip stat-chip--bad">{{ stats.by_status?.rejected ?? 0 }} rejected</span>
|
||||
<a
|
||||
v-if="(stats.export_ready ?? 0) > 0"
|
||||
:href="`${apiBase}/api/recipe-scan/export`"
|
||||
download
|
||||
class="btn-export"
|
||||
>
|
||||
⬇ Export {{ stats.export_ready }} pairs
|
||||
</a>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Loading -->
|
||||
<div v-if="loading" class="rsv-state" aria-label="Loading">
|
||||
<div class="skeleton-block" />
|
||||
</div>
|
||||
|
||||
<!-- Error -->
|
||||
<div v-else-if="apiError" class="rsv-state rsv-error" role="alert">
|
||||
<p>{{ apiError }}</p>
|
||||
<button class="btn-action" @click="fetchNext">Retry</button>
|
||||
</div>
|
||||
|
||||
<!-- Queue empty -->
|
||||
<div v-else-if="!item" class="rsv-state rsv-empty">
|
||||
<p>Queue is empty — all items reviewed.</p>
|
||||
<p class="rsv-hint">Import items from the Kiwi pipeline to continue.</p>
|
||||
</div>
|
||||
|
||||
<!-- Review panel -->
|
||||
<div v-else class="rsv-workspace">
|
||||
<!-- Left: image -->
|
||||
<section class="rsv-image-panel" aria-label="Scan image">
|
||||
<div class="rsv-panel-label">
|
||||
<span class="modality-badge">{{ item.modality }}</span>
|
||||
<span class="source-badge">{{ item.source }}</span>
|
||||
</div>
|
||||
<div class="rsv-image-wrap">
|
||||
<img
|
||||
v-if="imageUrl"
|
||||
:src="imageUrl"
|
||||
:alt="`Recipe scan — ${item.source}`"
|
||||
class="rsv-image"
|
||||
/>
|
||||
<div v-else class="rsv-image-placeholder">
|
||||
<span>Image not available</span>
|
||||
<code class="rsv-path">{{ item.image_path }}</code>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Right: JSON comparison -->
|
||||
<section class="rsv-json-panel" aria-label="Extraction review">
|
||||
|
||||
<!-- Ground truth (read-only reference) -->
|
||||
<div class="rsv-json-block">
|
||||
<h2 class="rsv-json-label">Ground truth <span class="label-tag">reference</span></h2>
|
||||
<pre class="rsv-json rsv-json--ground-truth" tabindex="0" aria-label="Ground truth JSON">{{ prettyJson(item.ground_truth) }}</pre>
|
||||
</div>
|
||||
|
||||
<!-- Extracted / editable -->
|
||||
<div class="rsv-json-block">
|
||||
<h2 class="rsv-json-label">
|
||||
Extracted
|
||||
<span class="label-tag label-tag--edit">edit before approving</span>
|
||||
</h2>
|
||||
<textarea
|
||||
v-model="draftJson"
|
||||
class="rsv-json rsv-json--edit"
|
||||
spellcheck="false"
|
||||
aria-label="Extracted JSON — edit to correct"
|
||||
:class="{ 'rsv-json--invalid': jsonError }"
|
||||
/>
|
||||
<p v-if="jsonError" class="rsv-json-error" role="alert">{{ jsonError }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Actions -->
|
||||
<div class="rsv-actions" role="group" aria-label="Review actions">
|
||||
<button
|
||||
class="btn-approve"
|
||||
:disabled="acting"
|
||||
@click="handleApprove"
|
||||
title="Extracted JSON is accurate — approve as-is (A)"
|
||||
>
|
||||
✓ Approve
|
||||
</button>
|
||||
<button
|
||||
class="btn-edit"
|
||||
:disabled="acting || !!jsonError"
|
||||
@click="handleEdit"
|
||||
title="Approve the edited JSON in the text area (E)"
|
||||
>
|
||||
✎ Approve edited
|
||||
</button>
|
||||
<button
|
||||
class="btn-reject"
|
||||
:disabled="acting"
|
||||
@click="handleReject"
|
||||
title="Extraction too broken to use — reject (R)"
|
||||
>
|
||||
✕ Reject
|
||||
</button>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
</div>
|
||||
|
||||
<!-- Feedback toast -->
|
||||
<Transition name="toast">
|
||||
<div v-if="toast" class="rsv-toast" role="status" aria-live="polite">
|
||||
{{ toast }}
|
||||
</div>
|
||||
</Transition>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const apiBase = window.location.origin
|
||||
|
||||
interface RecipeScanItem {
|
||||
id: string
|
||||
image_path: string
|
||||
modality: string
|
||||
source: string
|
||||
extracted: Record<string, unknown>
|
||||
ground_truth: Record<string, unknown>
|
||||
status: string
|
||||
}
|
||||
|
||||
interface Stats {
|
||||
total: number
|
||||
by_status: Record<string, number>
|
||||
by_modality: Record<string, number>
|
||||
export_ready: number
|
||||
}
|
||||
|
||||
const item = ref<RecipeScanItem | null>(null)
|
||||
const stats = ref<Stats | null>(null)
|
||||
const loading = ref(true)
|
||||
const acting = ref(false)
|
||||
const apiError = ref('')
|
||||
const draftJson = ref('')
|
||||
const toast = ref('')
|
||||
let toastTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const jsonError = computed(() => {
|
||||
if (!draftJson.value.trim()) return ''
|
||||
try {
|
||||
JSON.parse(draftJson.value)
|
||||
return ''
|
||||
} catch (e) {
|
||||
return 'Invalid JSON — fix before approving'
|
||||
}
|
||||
})
|
||||
|
||||
const imageUrl = computed(() => {
|
||||
if (!item.value) return ''
|
||||
const encoded = encodeURIComponent(item.value.image_path)
|
||||
return `${apiBase}/api/recipe-scan/image?path=${encoded}`
|
||||
})
|
||||
|
||||
function prettyJson(obj: unknown): string {
|
||||
return JSON.stringify(obj, null, 2)
|
||||
}
|
||||
|
||||
function showToast(msg: string) {
|
||||
toast.value = msg
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
toastTimer = setTimeout(() => { toast.value = '' }, 2500)
|
||||
}
|
||||
|
||||
async function fetchNext() {
|
||||
loading.value = true
|
||||
apiError.value = ''
|
||||
try {
|
||||
const r = await fetch(`${apiBase}/api/recipe-scan/next`)
|
||||
if (r.status === 404) {
|
||||
item.value = null
|
||||
} else if (!r.ok) {
|
||||
throw new Error(`API error ${r.status}`)
|
||||
} else {
|
||||
item.value = await r.json()
|
||||
draftJson.value = prettyJson(item.value!.extracted)
|
||||
}
|
||||
} catch (e) {
|
||||
apiError.value = e instanceof Error ? e.message : 'Could not reach API'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchStats() {
|
||||
try {
|
||||
const r = await fetch(`${apiBase}/api/recipe-scan/stats`)
|
||||
if (r.ok) stats.value = await r.json()
|
||||
} catch { /* non-critical */ }
|
||||
}
|
||||
|
||||
async function act(endpoint: string, body?: unknown) {
|
||||
if (!item.value || acting.value) return
|
||||
acting.value = true
|
||||
try {
|
||||
const r = await fetch(`${apiBase}/api/recipe-scan/items/${item.value.id}/${endpoint}`, {
|
||||
method: 'POST',
|
||||
headers: body ? { 'Content-Type': 'application/json' } : {},
|
||||
body: body ? JSON.stringify(body) : undefined,
|
||||
})
|
||||
if (!r.ok) throw new Error(`API error ${r.status}`)
|
||||
} catch (e) {
|
||||
showToast(e instanceof Error ? e.message : 'Action failed')
|
||||
acting.value = false
|
||||
return
|
||||
}
|
||||
acting.value = false
|
||||
await Promise.all([fetchNext(), fetchStats()])
|
||||
}
|
||||
|
||||
async function handleApprove() {
|
||||
showToast('Approved')
|
||||
await act('approve')
|
||||
}
|
||||
|
||||
async function handleEdit() {
|
||||
if (jsonError.value) return
|
||||
let corrected: unknown
|
||||
try {
|
||||
corrected = JSON.parse(draftJson.value)
|
||||
} catch {
|
||||
return
|
||||
}
|
||||
showToast('Saved edit')
|
||||
await act('edit', { corrected })
|
||||
}
|
||||
|
||||
async function handleReject() {
|
||||
showToast('Rejected')
|
||||
await act('reject')
|
||||
}
|
||||
|
||||
// Keyboard shortcuts: A = approve, E = edit+approve, R = reject
|
||||
function handleKey(e: KeyboardEvent) {
|
||||
const tag = (e.target as HTMLElement)?.tagName?.toLowerCase()
|
||||
if (tag === 'textarea' || tag === 'input') return
|
||||
if (e.key === 'a' || e.key === 'A') handleApprove()
|
||||
if (e.key === 'e' || e.key === 'E') handleEdit()
|
||||
if (e.key === 'r' || e.key === 'R') handleReject()
|
||||
}
|
||||
|
||||
watch(item, (newItem) => {
|
||||
if (newItem) draftJson.value = prettyJson(newItem.extracted)
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
fetchNext()
|
||||
fetchStats()
|
||||
window.addEventListener('keydown', handleKey)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
window.removeEventListener('keydown', handleKey)
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.rsv {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
padding: var(--space-md, 1rem);
|
||||
gap: var(--space-md, 1rem);
|
||||
box-sizing: border-box;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* Header */
|
||||
.rsv-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--space-md, 1rem);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.rsv-title {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 600;
|
||||
margin: 0;
|
||||
color: var(--color-text, #fff);
|
||||
}
|
||||
.rsv-stats {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.stat-chip {
|
||||
font-size: 0.75rem;
|
||||
padding: 2px 8px;
|
||||
border-radius: 12px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
color: var(--color-text-muted, #aaa);
|
||||
}
|
||||
.stat-chip--ok { background: #1a3a1a; color: #6fcf97; }
|
||||
.stat-chip--edited { background: #2a2a00; color: #f2c94c; }
|
||||
.stat-chip--bad { background: #3a1a1a; color: #eb5757; }
|
||||
.btn-export {
|
||||
font-size: 0.8rem;
|
||||
padding: 4px 12px;
|
||||
border-radius: 6px;
|
||||
background: var(--color-accent, #4a9eff);
|
||||
color: #fff;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* State panels */
|
||||
.rsv-state {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
color: var(--color-text-muted, #aaa);
|
||||
}
|
||||
.rsv-error { color: var(--color-danger, #eb5757); }
|
||||
.rsv-empty { font-size: 1rem; }
|
||||
.rsv-hint { font-size: 0.85rem; opacity: 0.7; margin: 0; }
|
||||
.skeleton-block {
|
||||
width: 100%; height: 300px;
|
||||
border-radius: 8px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
|
||||
|
||||
/* Workspace: two-column layout */
|
||||
.rsv-workspace {
|
||||
flex: 1;
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: var(--space-md, 1rem);
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
@media (max-width: 900px) {
|
||||
.rsv-workspace {
|
||||
grid-template-columns: 1fr;
|
||||
overflow-y: auto;
|
||||
}
|
||||
}
|
||||
|
||||
/* Image panel */
|
||||
.rsv-image-panel {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
min-height: 0;
|
||||
}
|
||||
.rsv-panel-label {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.modality-badge, .source-badge {
|
||||
font-size: 0.72rem;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
color: var(--color-text-muted, #aaa);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
.rsv-image-wrap {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: var(--color-surface-alt, #111);
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
min-height: 200px;
|
||||
}
|
||||
.rsv-image {
|
||||
max-width: 100%;
|
||||
max-height: 100%;
|
||||
object-fit: contain;
|
||||
}
|
||||
.rsv-image-placeholder {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
color: var(--color-text-muted, #666);
|
||||
font-size: 0.85rem;
|
||||
padding: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
.rsv-path {
|
||||
font-size: 0.7rem;
|
||||
word-break: break-all;
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
/* JSON panel */
|
||||
.rsv-json-panel {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
min-height: 0;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.rsv-json-block {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.25rem;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
.rsv-json-label {
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #aaa);
|
||||
margin: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.label-tag {
|
||||
font-size: 0.68rem;
|
||||
font-weight: 400;
|
||||
padding: 1px 6px;
|
||||
border-radius: 8px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
color: var(--color-text-muted, #888);
|
||||
}
|
||||
.label-tag--edit {
|
||||
background: #2a2a00;
|
||||
color: #f2c94c;
|
||||
}
|
||||
.rsv-json {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
line-height: 1.5;
|
||||
padding: 0.75rem;
|
||||
border-radius: 6px;
|
||||
min-height: 120px;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
resize: vertical;
|
||||
white-space: pre;
|
||||
}
|
||||
.rsv-json--ground-truth {
|
||||
background: var(--color-surface-alt, #111);
|
||||
color: var(--color-text, #ccc);
|
||||
border: 1px solid var(--color-border, #333);
|
||||
}
|
||||
.rsv-json--edit {
|
||||
background: var(--color-surface, #1a1a1a);
|
||||
color: var(--color-text, #e0e0e0);
|
||||
border: 1px solid var(--color-border, #444);
|
||||
caret-color: var(--color-accent, #4a9eff);
|
||||
outline: none;
|
||||
transition: border-color 0.15s;
|
||||
}
|
||||
.rsv-json--edit:focus {
|
||||
border-color: var(--color-accent, #4a9eff);
|
||||
}
|
||||
.rsv-json--invalid {
|
||||
border-color: var(--color-danger, #eb5757) !important;
|
||||
}
|
||||
.rsv-json-error {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-danger, #eb5757);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* Action buttons */
|
||||
.rsv-actions {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
padding-top: 0.25rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.btn-approve, .btn-edit, .btn-reject {
|
||||
flex: 1;
|
||||
min-width: 80px;
|
||||
padding: 0.5rem 0.75rem;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-approve, .btn-edit, .btn-reject {
|
||||
opacity: 1;
|
||||
}
|
||||
.btn-approve:disabled, .btn-edit:disabled, .btn-reject:disabled {
|
||||
opacity: 0.4;
|
||||
cursor: default;
|
||||
}
|
||||
.btn-approve { background: #1e6e1e; color: #6fcf97; }
|
||||
.btn-approve:hover:not(:disabled) { background: #256325; }
|
||||
.btn-edit { background: #4a4a00; color: #f2c94c; }
|
||||
.btn-edit:hover:not(:disabled) { background: #606000; }
|
||||
.btn-reject { background: #6e1e1e; color: #eb8f8f; }
|
||||
.btn-reject:hover:not(:disabled) { background: #7a2222; }
|
||||
|
||||
/* Toast */
|
||||
.rsv-toast {
|
||||
position: fixed;
|
||||
bottom: 1.5rem;
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
background: var(--color-surface, #222);
|
||||
color: var(--color-text, #fff);
|
||||
border: 1px solid var(--color-border, #444);
|
||||
border-radius: 8px;
|
||||
padding: 0.5rem 1.25rem;
|
||||
font-size: 0.85rem;
|
||||
box-shadow: 0 4px 20px rgba(0,0,0,0.4);
|
||||
pointer-events: none;
|
||||
z-index: 100;
|
||||
}
|
||||
.toast-enter-active, .toast-leave-active { transition: opacity 0.2s, transform 0.2s; }
|
||||
.toast-enter-from, .toast-leave-to { opacity: 0; transform: translateX(-50%) translateY(8px); }
|
||||
</style>
|
||||
|
|
@ -115,18 +115,8 @@
|
|||
<h2 class="section-title">cf-orch Integration</h2>
|
||||
<p class="section-desc">
|
||||
Import SFT (supervised fine-tuning) candidates from cf-orch benchmark runs.
|
||||
Connection settings fall back to environment variables
|
||||
(<code>CF_ORCH_URL</code>, <code>CF_LICENSE_KEY</code>, <code>OLLAMA_HOST</code>)
|
||||
when not set here.
|
||||
</p>
|
||||
|
||||
<!-- Connection status pill -->
|
||||
<div v-if="orchConfig" class="orch-status-row">
|
||||
<span class="orch-status-pill" :class="orchStatusClass">{{ orchStatusLabel }}</span>
|
||||
<span v-if="orchConfig.source === 'env'" class="orch-source-note">via env vars</span>
|
||||
<span v-else class="orch-source-note">via label_tool.yaml</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field field-grow">
|
||||
<span>bench_results_dir</span>
|
||||
|
|
@ -191,7 +181,7 @@
|
|||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { useApiFetch } from '../composables/useApi'
|
||||
|
||||
interface Account {
|
||||
|
|
@ -209,27 +199,12 @@ const saveOk = ref(true)
|
|||
const richMotion = ref(localStorage.getItem('cf-avocet-rich-motion') !== 'false')
|
||||
const keyHints = ref(localStorage.getItem('cf-avocet-key-hints') !== 'false')
|
||||
|
||||
// SFT / cf-orch integration state
|
||||
// SFT integration state
|
||||
const benchResultsDir = ref('')
|
||||
const runs = ref<Array<{ run_id: string; timestamp: string; candidate_count: number; already_imported: boolean }>>([])
|
||||
const importingRunId = ref<string | null>(null)
|
||||
const importResult = ref<{ imported: number; skipped: number } | null>(null)
|
||||
const saveStatus = ref('')
|
||||
const orchConfig = ref<{ coordinator_url: string; ollama_url: string; ollama_model: string; license_key_set: boolean; source: string } | null>(null)
|
||||
|
||||
const orchStatusClass = computed(() => {
|
||||
if (!orchConfig.value) return 'status-unknown'
|
||||
if (orchConfig.value.coordinator_url) return 'status-connected'
|
||||
if (orchConfig.value.ollama_url) return 'status-local'
|
||||
return 'status-unconfigured'
|
||||
})
|
||||
|
||||
const orchStatusLabel = computed(() => {
|
||||
if (!orchConfig.value) return 'Unknown'
|
||||
if (orchConfig.value.coordinator_url) return '● cf-orch coordinator'
|
||||
if (orchConfig.value.ollama_url) return '● Ollama (local)'
|
||||
return '○ Not configured'
|
||||
})
|
||||
|
||||
async function loadSftConfig() {
|
||||
try {
|
||||
|
|
@ -243,15 +218,6 @@ async function loadSftConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
async function loadOrchConfig() {
|
||||
try {
|
||||
const res = await fetch('/api/cforch/config')
|
||||
if (res.ok) orchConfig.value = await res.json()
|
||||
} catch {
|
||||
// non-fatal
|
||||
}
|
||||
}
|
||||
|
||||
async function saveSftConfig() {
|
||||
saveStatus.value = 'Saving…'
|
||||
try {
|
||||
|
|
@ -371,7 +337,6 @@ function onKeyHintsChange() {
|
|||
onMounted(() => {
|
||||
reload()
|
||||
loadSftConfig()
|
||||
loadOrchConfig()
|
||||
})
|
||||
</script>
|
||||
|
||||
|
|
@ -599,31 +564,6 @@ onMounted(() => {
|
|||
width: 100%;
|
||||
}
|
||||
|
||||
.orch-status-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--space-2);
|
||||
margin-bottom: var(--space-3);
|
||||
}
|
||||
|
||||
.orch-status-pill {
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
padding: var(--space-1) var(--space-3);
|
||||
border-radius: var(--radius-full);
|
||||
}
|
||||
|
||||
.status-connected { background: color-mix(in srgb, var(--color-success, #3a7a32) 12%, transparent); color: var(--color-success, #3a7a32); }
|
||||
.status-local { background: color-mix(in srgb, var(--color-primary) 12%, transparent); color: var(--color-primary); }
|
||||
.status-unconfigured { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||
.status-unknown { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||
|
||||
.orch-source-note {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.runs-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
|
|
|
|||
|
|
@ -68,44 +68,6 @@
|
|||
<p class="bench-hint">Highlighted cells are the best-scoring model per metric.</p>
|
||||
</template>
|
||||
|
||||
<!-- LLM Benchmark Results -->
|
||||
<template v-if="llmResults.length > 0">
|
||||
<h2 class="section-title">🤖 LLM Benchmark</h2>
|
||||
<div class="bench-table-wrap">
|
||||
<table class="bench-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="bt-model-col">Model</th>
|
||||
<th class="bt-metric-col">overall</th>
|
||||
<th
|
||||
v-for="col in llmTaskTypeCols"
|
||||
:key="col"
|
||||
class="bt-metric-col"
|
||||
>{{ col }}</th>
|
||||
<th class="bt-metric-col">tok/s</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="row in llmResults" :key="row.model_id">
|
||||
<td class="bt-model-cell" :title="row.model_id">{{ row.model_name }}</td>
|
||||
<td
|
||||
class="bt-metric-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
|
||||
>{{ llmPct(row.avg_quality_score) }}</td>
|
||||
<td
|
||||
v-for="col in llmTaskTypeCols"
|
||||
:key="col"
|
||||
class="bt-metric-cell"
|
||||
:class="{ 'bt-best': llmBestByCol[col] === row.model_id }"
|
||||
>{{ row.quality_by_task_type[col] != null ? llmPct(row.quality_by_task_type[col]) : '—' }}</td>
|
||||
<td class="bt-metric-cell">{{ row.avg_tokens_per_sec.toFixed(1) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<p class="bench-hint">Run LLM Eval on the Benchmark tab to refresh. Highlighted = best per column.</p>
|
||||
</template>
|
||||
|
||||
<div class="file-info">
|
||||
<span class="file-path">Score file: <code>data/email_score.jsonl</code></span>
|
||||
<span class="file-size">{{ fileSizeLabel }}</span>
|
||||
|
|
@ -132,18 +94,6 @@ interface BenchmarkModelResult {
|
|||
[key: string]: number | undefined
|
||||
}
|
||||
|
||||
interface LlmModelResult {
|
||||
model_name: string
|
||||
model_id: string
|
||||
node_id: string
|
||||
avg_tokens_per_sec: number
|
||||
avg_completion_ms: number
|
||||
avg_quality_score: number
|
||||
finetune_candidates: number
|
||||
error_count: number
|
||||
quality_by_task_type: Record<string, number>
|
||||
}
|
||||
|
||||
interface StatsResponse {
|
||||
total: number
|
||||
counts: Record<string, number>
|
||||
|
|
@ -235,49 +185,6 @@ function formatMetric(v: number | undefined): string {
|
|||
return `${v.toFixed(1)}%`
|
||||
}
|
||||
|
||||
// ── LLM Benchmark results ────────────────────────────────────────────────────
|
||||
const llmResults = ref<LlmModelResult[]>([])
|
||||
|
||||
const llmTaskTypeCols = computed(() => {
|
||||
const types = new Set<string>()
|
||||
for (const r of llmResults.value) {
|
||||
for (const k of Object.keys(r.quality_by_task_type)) types.add(k)
|
||||
}
|
||||
return [...types].sort()
|
||||
})
|
||||
|
||||
const llmBestByCol = computed((): Record<string, string> => {
|
||||
const best: Record<string, string> = {}
|
||||
if (llmResults.value.length === 0) return best
|
||||
|
||||
let bestId = '', bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
if (r.avg_quality_score > bestVal) { bestVal = r.avg_quality_score; bestId = r.model_id }
|
||||
}
|
||||
best['overall'] = bestId
|
||||
|
||||
for (const col of llmTaskTypeCols.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
const v = r.quality_by_task_type[col]
|
||||
if (v != null && v > bestVal) { bestVal = v; bestId = r.model_id }
|
||||
}
|
||||
best[col] = bestId
|
||||
}
|
||||
return best
|
||||
})
|
||||
|
||||
function llmPct(v: number): string {
|
||||
return `${(v * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
async function loadLlmResults() {
|
||||
const { data } = await useApiFetch<LlmModelResult[]>('/api/cforch/results')
|
||||
if (Array.isArray(data) && data.length > 0) {
|
||||
llmResults.value = data
|
||||
}
|
||||
}
|
||||
|
||||
async function load() {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
|
|
@ -290,10 +197,7 @@ async function load() {
|
|||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
load()
|
||||
loadLlmResults()
|
||||
})
|
||||
onMounted(load)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
|
|||
|
|
@ -1,919 +0,0 @@
|
|||
<template>
|
||||
<div class="style-tab">
|
||||
|
||||
<!-- ── Controls row ──────────────────────────────────────────────────── -->
|
||||
<div class="style-controls">
|
||||
|
||||
<!-- Model picker -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">✍️ Models</span>
|
||||
<span class="picker-badge">{{ selectedCount }} selected</span>
|
||||
<button class="btn-refresh" :disabled="modelsLoading" @click.stop="loadModels" title="Refresh model list">
|
||||
{{ modelsLoading ? '⏳' : '🔄' }}
|
||||
</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="modelsLoading" class="picker-loading">Loading models…</div>
|
||||
<div v-else-if="loadError" class="picker-error">{{ loadError }}</div>
|
||||
<template v-else>
|
||||
|
||||
<!-- Ollama group -->
|
||||
<div class="picker-group" v-if="ollamaModels.length">
|
||||
<div class="group-header">
|
||||
<label class="group-check">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isGroupAllSelected('ollama')"
|
||||
:indeterminate="isGroupIndeterminate('ollama')"
|
||||
@change="toggleGroup('ollama', ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="group-label">Ollama</span>
|
||||
<span class="group-count">({{ ollamaModels.length }})</span>
|
||||
</label>
|
||||
<span class="group-note">auto-synced with Models view</span>
|
||||
</div>
|
||||
<div class="model-list">
|
||||
<label v-for="m in ollamaModels" :key="m.id" class="model-item">
|
||||
<input type="checkbox" :value="m.id" v-model="selectedModels" />
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span v-if="m.size_mb" class="model-meta">{{ formatMb(m.size_mb) }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- cf-text group -->
|
||||
<div class="picker-group" v-if="cftextModels.length">
|
||||
<div class="group-header">
|
||||
<label class="group-check">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isGroupAllSelected('cf-text')"
|
||||
:indeterminate="isGroupIndeterminate('cf-text')"
|
||||
@change="toggleGroup('cf-text', ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="group-label">cf-text (cf-orch)</span>
|
||||
<span class="group-count">({{ cftextModels.length }})</span>
|
||||
</label>
|
||||
<span class="group-note">GGUFs via coordinator — enable cf-orch below</span>
|
||||
</div>
|
||||
<div class="model-list">
|
||||
<label v-for="m in cftextModels" :key="m.id" class="model-item">
|
||||
<input type="checkbox" :value="m.id" v-model="selectedModels" />
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span v-if="m.vram_mb" class="model-meta">{{ formatMb(m.vram_mb) }} VRAM</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="!ollamaModels.length && !cftextModels.length" class="picker-empty">
|
||||
No models available — check Ollama and cf-orch connections.
|
||||
</div>
|
||||
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Options panel -->
|
||||
<details class="options-panel">
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">⚙️ Options</span>
|
||||
</summary>
|
||||
<div class="options-body">
|
||||
<label class="option-row">
|
||||
<input type="checkbox" v-model="useCforch" :disabled="running" />
|
||||
<span class="option-label">Use cf-orch backend</span>
|
||||
<span class="option-hint">Routes generation through cf-text instead of ollama</span>
|
||||
</label>
|
||||
<label class="option-row" :class="{ dimmed: !useCforch }">
|
||||
<span class="option-label">Max VRAM (MB)</span>
|
||||
<input
|
||||
type="number"
|
||||
v-model.number="maxVram"
|
||||
:disabled="running || !useCforch"
|
||||
min="1024"
|
||||
max="24576"
|
||||
step="512"
|
||||
class="option-number"
|
||||
/>
|
||||
<span class="option-hint">Skip models exceeding this VRAM limit</span>
|
||||
</label>
|
||||
<label class="option-row">
|
||||
<span class="option-label">Parallel workers</span>
|
||||
<input
|
||||
type="number"
|
||||
v-model.number="workers"
|
||||
:disabled="running"
|
||||
min="1"
|
||||
max="16"
|
||||
step="1"
|
||||
class="option-number"
|
||||
/>
|
||||
<span class="option-hint">Models to score simultaneously (1 = sequential)</span>
|
||||
</label>
|
||||
<label class="option-row">
|
||||
<input type="checkbox" v-model="includeLarge" :disabled="running" />
|
||||
<span class="option-label">Include large models (30B+)</span>
|
||||
<span class="option-hint">Off by default — these take much longer</span>
|
||||
</label>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- ── Run controls ──────────────────────────────────────────────────── -->
|
||||
<div class="run-bar">
|
||||
<button class="btn-run" :disabled="running || selectedCount === 0" @click="startBenchmark">
|
||||
{{ running ? '⏳ Running…' : results.length ? '🔄 Re-run' : '▶ Run Benchmark' }}
|
||||
</button>
|
||||
<button v-if="running" class="btn-cancel" @click="cancelBenchmark">✕ Cancel</button>
|
||||
<span v-if="selectedCount === 0 && !running" class="run-hint">Select at least one model above</span>
|
||||
</div>
|
||||
|
||||
<!-- ── Progress log ──────────────────────────────────────────────────── -->
|
||||
<div v-if="runLog.length" class="run-log">
|
||||
<div class="run-log-header">
|
||||
<span class="run-log-title">Run log</span>
|
||||
<button class="btn-clear" @click="runLog = []">Clear</button>
|
||||
</div>
|
||||
<pre class="run-log-body" ref="logEl">{{ runLog.join('\n') }}</pre>
|
||||
</div>
|
||||
|
||||
<!-- ── Past runs picker ─────────────────────────────────────────────── -->
|
||||
<div class="history-bar" v-if="pastRuns.length">
|
||||
<label class="history-label">📂 Past runs:</label>
|
||||
<select class="history-select" v-model="selectedRun" @change="loadRun(selectedRun)">
|
||||
<option value="">— select a past run —</option>
|
||||
<option v-for="r in pastRuns" :key="r.filename" :value="r.filename">
|
||||
{{ r.date }} · {{ r.model_count }} model{{ r.model_count !== 1 ? 's' : '' }} · top {{ r.top_score }}/100
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- ── Results table ─────────────────────────────────────────────────── -->
|
||||
<div v-if="results.length" class="results-section">
|
||||
<div class="results-header">
|
||||
<h2 class="results-title">Rankings</h2>
|
||||
<button
|
||||
class="btn-corrections"
|
||||
:disabled="sendingCorrections"
|
||||
@click="sendToCorrections"
|
||||
title="Push all outputs from this run into the Corrections review queue"
|
||||
>
|
||||
{{ sendingCorrections ? '⏳ Sending…' : correctionsMsg || '✍️ Send to Corrections' }}
|
||||
</button>
|
||||
</div>
|
||||
<div class="results-table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Rank</th>
|
||||
<th>Model</th>
|
||||
<th>Score</th>
|
||||
<th>Latency</th>
|
||||
<th title="Em-dash count">—</th>
|
||||
<th title="Filler phrase hits">Fillers</th>
|
||||
<th title="Semicolons">;</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="(r, i) in results"
|
||||
:key="r.model_id"
|
||||
class="result-row"
|
||||
:class="{ 'top-row': i === 0 }"
|
||||
@click="toggleExpanded(r.model_id)"
|
||||
>
|
||||
<td class="rank-cell">{{ medal(i) }}</td>
|
||||
<td class="model-cell">
|
||||
<span class="model-name-text">{{ r.model_id }}</span>
|
||||
</td>
|
||||
<td class="score-cell">
|
||||
<span class="score-pill" :style="scorePillStyle(r.avg_score)">
|
||||
{{ r.avg_score.toFixed(0) }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="latency-cell">{{ formatLatency(r.avg_latency_ms) }}</td>
|
||||
<td class="violation-cell" :class="{ 'has-violation': r.total_em_dashes > 0 }">
|
||||
{{ r.total_em_dashes }}
|
||||
</td>
|
||||
<td class="violation-cell" :class="{ 'has-violation': r.total_filler_hits > 0 }">
|
||||
{{ r.total_filler_hits }}
|
||||
</td>
|
||||
<td class="violation-cell" :class="{ 'has-violation': r.total_semicolons > 0 }">
|
||||
{{ r.total_semicolons }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- Expandable sample outputs -->
|
||||
<div v-for="r in results" :key="'exp-' + r.model_id">
|
||||
<div v-if="expandedModels.has(r.model_id)" class="sample-outputs">
|
||||
<div class="sample-header">
|
||||
<strong>{{ r.model_id }}</strong>
|
||||
<button class="btn-collapse" @click="toggleExpanded(r.model_id)">✕ Close</button>
|
||||
</div>
|
||||
<div v-for="pr in r.prompt_results" :key="pr.tag" class="sample-prompt">
|
||||
<div class="sample-tag">
|
||||
<span class="tag-name">{{ pr.tag }}</span>
|
||||
<span class="tag-score">{{ pr.score.toFixed(0) }}/100</span>
|
||||
<span class="tag-latency">{{ formatLatency(pr.latency_ms) }}</span>
|
||||
</div>
|
||||
<pre class="sample-text">{{ pr.output || '(no output)' }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, nextTick, watch } from 'vue'
|
||||
|
||||
// ── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
interface StyleModel {
|
||||
id: string
|
||||
name: string
|
||||
source: 'ollama' | 'cf-text'
|
||||
size_mb?: number | null
|
||||
vram_mb?: number | null
|
||||
description?: string
|
||||
}
|
||||
|
||||
interface PromptResult {
|
||||
tag: string
|
||||
output: string
|
||||
score: number
|
||||
latency_ms: number
|
||||
signals: Record<string, unknown>
|
||||
}
|
||||
|
||||
interface ModelResult {
|
||||
model_id: string
|
||||
avg_score: number
|
||||
avg_latency_ms: number
|
||||
total_filler_hits: number
|
||||
total_em_dashes: number
|
||||
total_semicolons: number
|
||||
prompt_results: PromptResult[]
|
||||
}
|
||||
|
||||
interface PastRun {
|
||||
filename: string
|
||||
date: string
|
||||
model_count: number
|
||||
top_score: number
|
||||
}
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
|
||||
const ollamaModels = ref<StyleModel[]>([])
|
||||
const cftextModels = ref<StyleModel[]>([])
|
||||
const selectedModels = ref<string[]>([])
|
||||
const modelsLoading = ref(false)
|
||||
const loadError = ref('')
|
||||
|
||||
const useCforch = ref(false)
|
||||
const maxVram = ref(7200)
|
||||
const workers = ref(1)
|
||||
const includeLarge = ref(false)
|
||||
|
||||
const running = ref(false)
|
||||
const runLog = ref<string[]>([])
|
||||
const logEl = ref<HTMLPreElement | null>(null)
|
||||
|
||||
const results = ref<ModelResult[]>([])
|
||||
const pastRuns = ref<PastRun[]>([])
|
||||
const selectedRun = ref('')
|
||||
const expandedModels = ref(new Set<string>())
|
||||
const sendingCorrections = ref(false)
|
||||
const correctionsMsg = ref('')
|
||||
|
||||
// ── Computed ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const selectedCount = computed(() => selectedModels.value.length)
|
||||
|
||||
function isGroupAllSelected(source: string): boolean {
|
||||
const group = source === 'ollama' ? ollamaModels.value : cftextModels.value
|
||||
return group.length > 0 && group.every(m => selectedModels.value.includes(m.id))
|
||||
}
|
||||
|
||||
function isGroupIndeterminate(source: string): boolean {
|
||||
const group = source === 'ollama' ? ollamaModels.value : cftextModels.value
|
||||
const count = group.filter(m => selectedModels.value.includes(m.id)).length
|
||||
return count > 0 && count < group.length
|
||||
}
|
||||
|
||||
// ── Actions ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async function loadModels() {
|
||||
modelsLoading.value = true
|
||||
loadError.value = ''
|
||||
try {
|
||||
const resp = await fetch('/api/style/models')
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
const data = await resp.json()
|
||||
ollamaModels.value = data.ollama ?? []
|
||||
cftextModels.value = data.cf_text ?? []
|
||||
} catch (e: unknown) {
|
||||
loadError.value = `Failed to load models: ${e instanceof Error ? e.message : String(e)}`
|
||||
} finally {
|
||||
modelsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadPastRuns() {
|
||||
try {
|
||||
const resp = await fetch('/api/style/results')
|
||||
if (resp.ok) pastRuns.value = await resp.json()
|
||||
} catch { /* non-fatal */ }
|
||||
}
|
||||
|
||||
async function loadRun(filename: string) {
|
||||
if (!filename) return
|
||||
try {
|
||||
const resp = await fetch(`/api/style/results/${filename}`)
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
results.value = await resp.json()
|
||||
expandedModels.value.clear()
|
||||
} catch (e: unknown) {
|
||||
runLog.value.push(`[error] Failed to load ${filename}: ${e instanceof Error ? e.message : String(e)}`)
|
||||
}
|
||||
}
|
||||
|
||||
function toggleGroup(source: string, checked: boolean) {
|
||||
const group = source === 'ollama' ? ollamaModels.value : cftextModels.value
|
||||
const ids = group.map(m => m.id)
|
||||
if (checked) {
|
||||
const newSet = new Set([...selectedModels.value, ...ids])
|
||||
selectedModels.value = [...newSet]
|
||||
} else {
|
||||
selectedModels.value = selectedModels.value.filter(id => !ids.includes(id))
|
||||
}
|
||||
}
|
||||
|
||||
function toggleExpanded(modelId: string) {
|
||||
if (expandedModels.value.has(modelId)) {
|
||||
expandedModels.value.delete(modelId)
|
||||
} else {
|
||||
expandedModels.value.add(modelId)
|
||||
}
|
||||
expandedModels.value = new Set(expandedModels.value)
|
||||
}
|
||||
|
||||
function startBenchmark() {
|
||||
if (running.value || selectedCount.value === 0) return
|
||||
running.value = true
|
||||
runLog.value = []
|
||||
results.value = []
|
||||
expandedModels.value.clear()
|
||||
|
||||
const params = new URLSearchParams({
|
||||
models: selectedModels.value.join(','),
|
||||
use_cforch: String(useCforch.value),
|
||||
max_vram: String(maxVram.value),
|
||||
workers: String(workers.value),
|
||||
include_large: String(includeLarge.value),
|
||||
})
|
||||
|
||||
const es = new EventSource(`/api/style/run?${params}`)
|
||||
|
||||
es.onmessage = async (ev) => {
|
||||
try {
|
||||
const msg = JSON.parse(ev.data)
|
||||
if (msg.type === 'progress') {
|
||||
runLog.value.push(msg.message)
|
||||
await nextTick()
|
||||
if (logEl.value) logEl.value.scrollTop = logEl.value.scrollHeight
|
||||
} else if (msg.type === 'result') {
|
||||
results.value = msg.results ?? []
|
||||
await loadPastRuns()
|
||||
} else if (msg.type === 'complete') {
|
||||
running.value = false
|
||||
es.close()
|
||||
} else if (msg.type === 'error') {
|
||||
runLog.value.push(`[error] ${msg.message}`)
|
||||
running.value = false
|
||||
es.close()
|
||||
}
|
||||
} catch { /* ignore parse errors */ }
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
if (running.value) {
|
||||
runLog.value.push('[error] Connection lost')
|
||||
running.value = false
|
||||
}
|
||||
es.close()
|
||||
}
|
||||
}
|
||||
|
||||
async function cancelBenchmark() {
|
||||
try {
|
||||
await fetch('/api/style/cancel', { method: 'POST' })
|
||||
} finally {
|
||||
running.value = false
|
||||
runLog.value.push('[cancelled]')
|
||||
}
|
||||
}
|
||||
|
||||
async function sendToCorrections() {
|
||||
if (!selectedRun.value || sendingCorrections.value) return
|
||||
sendingCorrections.value = true
|
||||
correctionsMsg.value = ''
|
||||
try {
|
||||
const resp = await fetch('/api/style/send-to-corrections', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ filename: selectedRun.value, model_ids: [] }),
|
||||
})
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
const data = await resp.json()
|
||||
correctionsMsg.value = `✓ ${data.imported} added to Corrections`
|
||||
} catch (e: unknown) {
|
||||
correctionsMsg.value = `Error: ${e instanceof Error ? e.message : String(e)}`
|
||||
} finally {
|
||||
sendingCorrections.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Formatting helpers ────────────────────────────────────────────────────────
|
||||
|
||||
function formatMb(mb: number): string {
|
||||
return mb >= 1024 ? `${(mb / 1024).toFixed(1)} GB` : `${mb} MB`
|
||||
}
|
||||
|
||||
function formatLatency(ms: number): string {
|
||||
return ms >= 1000 ? `${(ms / 1000).toFixed(1)}s` : `${Math.round(ms)}ms`
|
||||
}
|
||||
|
||||
function medal(index: number): string {
|
||||
return ['🥇', '🥈', '🥉'][index] ?? `#${index + 1}`
|
||||
}
|
||||
|
||||
function scorePillStyle(score: number): Record<string, string> {
|
||||
const hue = Math.round((score / 100) * 120) // 0=red, 120=green
|
||||
return {
|
||||
background: `hsl(${hue} 60% 88%)`,
|
||||
color: `hsl(${hue} 60% 28%)`,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Lifecycle ─────────────────────────────────────────────────────────────────
|
||||
|
||||
// Auto-enable cf-orch when cf-text models are selected
|
||||
watch(selectedModels, (ids) => {
|
||||
const hasCftext = ids.some(id => cftextModels.value.find(m => m.id === id))
|
||||
if (hasCftext) useCforch.value = true
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
await Promise.all([loadModels(), loadPastRuns()])
|
||||
// Auto-load the latest results if any exist
|
||||
if (pastRuns.value.length) {
|
||||
selectedRun.value = pastRuns.value[0].filename
|
||||
await loadRun(pastRuns.value[0].filename)
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.style-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
padding: 1rem 0;
|
||||
}
|
||||
|
||||
/* ── Controls ─────────────────────────────────────────────────────────────── */
|
||||
|
||||
.style-controls {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.75rem;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.model-picker,
|
||||
.options-panel {
|
||||
flex: 1;
|
||||
min-width: 280px;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.picker-summary {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.65rem 0.85rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.picker-summary::-webkit-details-marker { display: none; }
|
||||
|
||||
.picker-title { flex: 1; color: var(--color-text, #1a2338); }
|
||||
.picker-badge {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 9999px;
|
||||
padding: 0.1rem 0.5rem;
|
||||
font-size: 0.72rem;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.btn-refresh {
|
||||
border: none;
|
||||
background: transparent;
|
||||
cursor: pointer;
|
||||
font-size: 0.85rem;
|
||||
padding: 0.1rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
.btn-refresh:hover { background: var(--color-border, #d0d7e8); }
|
||||
.btn-refresh:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.picker-body,
|
||||
.options-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.picker-loading, .picker-empty {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.85rem;
|
||||
padding: 0.25rem 0;
|
||||
}
|
||||
|
||||
.picker-error {
|
||||
color: #b91c1c;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
/* ── Model groups ──────────────────────────────────────────────────────────── */
|
||||
|
||||
.picker-group {
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.picker-group:last-child { margin-bottom: 0; }
|
||||
|
||||
.group-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.4rem;
|
||||
}
|
||||
|
||||
.group-check {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.group-count {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-weight: 400;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
.group-note {
|
||||
margin-left: auto;
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.model-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.2rem;
|
||||
padding-left: 1.25rem;
|
||||
max-height: 220px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.model-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
font-size: 0.82rem;
|
||||
cursor: pointer;
|
||||
padding: 0.15rem 0;
|
||||
}
|
||||
|
||||
.model-name { flex: 1; font-family: var(--font-mono, monospace); }
|
||||
|
||||
.model-meta {
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
/* ── Options ──────────────────────────────────────────────────────────────── */
|
||||
|
||||
.option-row {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 0.5rem;
|
||||
padding: 0.35rem 0;
|
||||
cursor: pointer;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.option-label { font-weight: 500; white-space: nowrap; }
|
||||
|
||||
.option-hint {
|
||||
flex: 1;
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin-left: auto;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.option-number {
|
||||
width: 90px;
|
||||
padding: 0.2rem 0.4rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
font-size: 0.85rem;
|
||||
background: var(--color-bg, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.option-row.dimmed { opacity: 0.45; pointer-events: none; }
|
||||
|
||||
/* ── Run bar ──────────────────────────────────────────────────────────────── */
|
||||
|
||||
.run-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.65rem;
|
||||
}
|
||||
|
||||
.btn-run {
|
||||
padding: 0.5rem 1.25rem;
|
||||
border: none;
|
||||
border-radius: 0.375rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-run:hover:not(:disabled) { background: color-mix(in srgb, var(--app-primary, #2A6080) 80%, #000); }
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.5rem 0.9rem;
|
||||
border: 1px solid #f85149;
|
||||
border-radius: 0.375rem;
|
||||
background: transparent;
|
||||
color: #b91c1c;
|
||||
font-size: 0.85rem;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-cancel:hover { background: #fee2e2; }
|
||||
|
||||
.run-hint {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
/* ── Run log ──────────────────────────────────────────────────────────────── */
|
||||
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.run-log-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.run-log-title { text-transform: uppercase; letter-spacing: 0.05em; }
|
||||
|
||||
.btn-clear {
|
||||
border: none;
|
||||
background: transparent;
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
padding: 0.1rem 0.3rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
.btn-clear:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
.run-log-body {
|
||||
margin: 0;
|
||||
padding: 0.65rem 0.85rem;
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
max-height: 260px;
|
||||
overflow-y: auto;
|
||||
background: var(--color-bg, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
/* ── History bar ──────────────────────────────────────────────────────────── */
|
||||
|
||||
.history-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.history-label { font-weight: 500; white-space: nowrap; }
|
||||
|
||||
.history-select {
|
||||
flex: 1;
|
||||
padding: 0.3rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
/* ── Results table ────────────────────────────────────────────────────────── */
|
||||
|
||||
.results-section { display: flex; flex-direction: column; gap: 0.75rem; }
|
||||
|
||||
.results-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.results-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.btn-corrections {
|
||||
padding: 0.4rem 0.9rem;
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
border-radius: 0.375rem;
|
||||
background: transparent;
|
||||
color: var(--app-primary, #2A6080);
|
||||
font-size: 0.83rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
white-space: nowrap;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
.btn-corrections:hover:not(:disabled) {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
}
|
||||
.btn-corrections:disabled { opacity: 0.55; cursor: not-allowed; }
|
||||
|
||||
.results-table-wrap {
|
||||
overflow-x: auto;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.results-table th {
|
||||
padding: 0.5rem 0.75rem;
|
||||
text-align: left;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.result-row {
|
||||
cursor: pointer;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
.result-row:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 6%, transparent); }
|
||||
.result-row.top-row { font-weight: 600; }
|
||||
|
||||
.result-row td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.result-row:last-child td { border-bottom: none; }
|
||||
|
||||
.rank-cell { width: 2.5rem; text-align: center; font-size: 1.1rem; }
|
||||
.model-cell { font-family: var(--font-mono, monospace); word-break: break-all; }
|
||||
.score-cell { width: 5rem; text-align: center; }
|
||||
.latency-cell { width: 5rem; text-align: right; color: var(--color-text-secondary, #6b7a99); }
|
||||
.violation-cell { width: 4rem; text-align: center; color: var(--color-text-secondary, #6b7a99); }
|
||||
.violation-cell.has-violation { color: #b91c1c; font-weight: 700; }
|
||||
|
||||
.score-pill {
|
||||
display: inline-block;
|
||||
padding: 0.15rem 0.55rem;
|
||||
border-radius: 9999px;
|
||||
font-weight: 700;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
/* ── Sample outputs ───────────────────────────────────────────────────────── */
|
||||
|
||||
.sample-outputs {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sample-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0.5rem 0.85rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.btn-collapse {
|
||||
border: none;
|
||||
background: transparent;
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.sample-prompt {
|
||||
padding: 0.65rem 0.85rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.sample-prompt:last-child { border-bottom: none; }
|
||||
|
||||
.sample-tag {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.35rem;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
.tag-name { font-weight: 600; color: var(--color-text, #1a2338); }
|
||||
.tag-score { color: var(--app-primary, #2A6080); font-weight: 700; }
|
||||
.tag-latency { color: var(--color-text-secondary, #6b7a99); margin-left: auto; }
|
||||
|
||||
.sample-text {
|
||||
margin: 0;
|
||||
font-size: 0.82rem;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
background: var(--color-bg, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.35rem;
|
||||
padding: 0.5rem 0.65rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.style-controls { flex-direction: column; }
|
||||
.model-picker, .options-panel { min-width: 0; }
|
||||
.option-hint { display: none; }
|
||||
.group-note { display: none; }
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,161 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import TrainJobsView from './TrainJobsView.vue'
|
||||
|
||||
const sampleJob = {
|
||||
id: 'job-abc123',
|
||||
type: 'classifier',
|
||||
model_key: 'deberta-v3-small',
|
||||
status: 'queued',
|
||||
created_at: '2026-05-01T10:00:00Z',
|
||||
config: null,
|
||||
}
|
||||
|
||||
function makeFetch(jobs: unknown[] = []) {
|
||||
return vi.fn().mockImplementation((url: string, opts?: RequestInit) => {
|
||||
if ((opts?.method ?? 'GET') === 'POST') {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ ...sampleJob, id: 'new-job', status: 'queued' }),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
if ((opts?.method ?? 'GET') === 'DELETE') {
|
||||
return Promise.resolve({ ok: true, json: async () => ({}), text: async () => '' })
|
||||
}
|
||||
// GET
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ jobs }),
|
||||
text: async () => '',
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
class MockEventSource {
|
||||
onmessage: ((e: MessageEvent) => void) | null = null
|
||||
onerror: ((e: Event) => void) | null = null
|
||||
private _url: string
|
||||
constructor(url: string) { this._url = url }
|
||||
close() {}
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', makeFetch([sampleJob]))
|
||||
vi.stubGlobal('EventSource', MockEventSource)
|
||||
})
|
||||
|
||||
describe('TrainJobsView', () => {
|
||||
it('renders page title "Training Jobs"', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Training Jobs')
|
||||
})
|
||||
|
||||
it('renders the new job form with type selector and model key input', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('select.job-type-select').exists()).toBe(true)
|
||||
expect(w.find('input.model-key-input').exists()).toBe(true)
|
||||
expect(w.find('button.submit-job-btn').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('type selector has classifier and llm-sft options', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
const options = w.findAll('select.job-type-select option')
|
||||
const values = options.map(o => o.attributes('value') ?? o.element.textContent)
|
||||
expect(values).toContain('classifier')
|
||||
expect(values).toContain('llm-sft')
|
||||
})
|
||||
|
||||
it('submit button is disabled when model key is empty', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
const btn = w.find('button.submit-job-btn')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('submit button is enabled when model key is entered', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('input.model-key-input').setValue('deberta-v3-small')
|
||||
const btn = w.find('button.submit-job-btn')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(false)
|
||||
})
|
||||
|
||||
it('shows job table with existing jobs', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('table.jobs-table').exists()).toBe(true)
|
||||
expect(w.text()).toContain('deberta-v3-small')
|
||||
})
|
||||
|
||||
it('shows status pill for each job', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('.status-pill').exists()).toBe(true)
|
||||
expect(w.find('.status-queued').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows cancel button for queued/running jobs', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('button.cancel-btn').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('submitting new job calls POST /api/train/jobs and refreshes', async () => {
|
||||
const fetchMock = makeFetch([])
|
||||
vi.stubGlobal('fetch', fetchMock)
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('input.model-key-input').setValue('my-model')
|
||||
await w.find('button.submit-job-btn').trigger('click')
|
||||
await flushPromises()
|
||||
const calls = (fetchMock as ReturnType<typeof vi.fn>).mock.calls as [string, RequestInit?][]
|
||||
const postCall = calls.find(([, opts]) => (opts?.method ?? 'GET') === 'POST')
|
||||
expect(postCall).toBeDefined()
|
||||
expect(postCall![0]).toContain('/api/train/jobs')
|
||||
})
|
||||
|
||||
it('shows View Log button for running jobs', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch([{ ...sampleJob, status: 'running' }]))
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('button.view-log-btn').exists()).toBe(true)
|
||||
})
|
||||
it('shows error when config JSON is invalid', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('input.model-key-input').setValue('my-model')
|
||||
await w.find('textarea.config-textarea').setValue('{ not valid json }')
|
||||
await w.find('button.submit-job-btn').trigger('click')
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
expect(w.find('.error-notice').text()).toContain('not valid')
|
||||
})
|
||||
|
||||
it('shows error notice when jobs load fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
json: async () => ({}),
|
||||
text: async () => '',
|
||||
}))
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
expect(w.find('table.jobs-table').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('cancel button optimistically updates job status to cancelled', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('button.cancel-btn').trigger('click')
|
||||
await flushPromises()
|
||||
// After cancel, job should show status-cancelled pill (not status-queued)
|
||||
expect(w.find('.status-cancelled').exists()).toBe(true)
|
||||
expect(w.find('.status-queued').exists()).toBe(false)
|
||||
})
|
||||
|
||||
})
|
||||
|
|
@ -1,593 +0,0 @@
|
|||
<template>
|
||||
<div class="train-jobs-view">
|
||||
<header class="view-header">
|
||||
<h1 class="page-title">🧠 Training Jobs</h1>
|
||||
</header>
|
||||
|
||||
<!-- New Job form -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">New Job</h2>
|
||||
<form class="new-job-form" @submit.prevent="submitJob">
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="job-type">Type</label>
|
||||
<select
|
||||
id="job-type"
|
||||
v-model="form.type"
|
||||
class="job-type-select form-control"
|
||||
>
|
||||
<option value="classifier">classifier</option>
|
||||
<option value="llm-sft">llm-sft</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="model-key">Model key</label>
|
||||
<input
|
||||
id="model-key"
|
||||
v-model.trim="form.model_key"
|
||||
type="text"
|
||||
class="model-key-input form-control"
|
||||
placeholder="e.g. microsoft/deberta-v3-small"
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="job-config">Config JSON <span class="form-hint">(optional)</span></label>
|
||||
<textarea
|
||||
id="job-config"
|
||||
v-model="form.config_raw"
|
||||
class="config-textarea form-control"
|
||||
rows="4"
|
||||
placeholder='{"learning_rate": 2e-5}'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="submitError" class="error-notice" role="alert">{{ submitError }}</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
class="submit-job-btn btn-primary"
|
||||
:disabled="submitting || !form.model_key"
|
||||
@click.prevent="submitJob"
|
||||
>
|
||||
{{ submitting ? 'Queuing…' : 'Queue Job' }}
|
||||
</button>
|
||||
</form>
|
||||
</section>
|
||||
|
||||
<!-- Job queue table -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">Job Queue</h2>
|
||||
|
||||
<div v-if="loadError" class="error-notice" role="alert">
|
||||
{{ loadError }}
|
||||
<button class="btn-retry" @click="loadJobs">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-else-if="jobs.length === 0" class="empty-notice">
|
||||
No training jobs yet.
|
||||
</div>
|
||||
|
||||
<div v-else class="jobs-table-wrap">
|
||||
<table class="jobs-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>Type</th>
|
||||
<th>Model</th>
|
||||
<th>Status</th>
|
||||
<th>Created</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="job in jobs" :key="job.id">
|
||||
<td class="td-id" :title="job.id">{{ job.id.slice(0, 8) }}</td>
|
||||
<td>
|
||||
<span class="type-chip">{{ job.type }}</span>
|
||||
</td>
|
||||
<td class="td-model">{{ job.model_key }}</td>
|
||||
<td>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</td>
|
||||
<td class="td-date">{{ formatDate(job.created_at) }}</td>
|
||||
<td class="td-actions">
|
||||
<button
|
||||
v-if="job.status === 'running'"
|
||||
class="view-log-btn btn-sm"
|
||||
@click="openLog(job.id)"
|
||||
>
|
||||
View Log
|
||||
</button>
|
||||
<button
|
||||
v-if="job.status === 'queued' || job.status === 'running'"
|
||||
class="cancel-btn btn-sm btn-danger-sm"
|
||||
:disabled="cancellingId === job.id"
|
||||
@click="cancelJob(job.id)"
|
||||
>
|
||||
{{ cancellingId === job.id ? '…' : 'Cancel' }}
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div v-if="cancelError" class="error-notice" role="alert">{{ cancelError }}</div>
|
||||
</section>
|
||||
|
||||
<!-- Log panel (SSE) -->
|
||||
<section v-if="logJobId" class="section log-section">
|
||||
<div class="log-header">
|
||||
<h2 class="section-title">Log — {{ logJobId.slice(0, 8) }}</h2>
|
||||
<button class="btn-close-log" @click="closeLog">✕ Close</button>
|
||||
</div>
|
||||
<div class="log-panel" ref="logPanelEl">
|
||||
<div
|
||||
v-for="(line, i) in logLines"
|
||||
:key="i"
|
||||
class="log-line"
|
||||
>{{ line }}</div>
|
||||
<div v-if="logLines.length === 0" class="log-line log-muted">Connecting…</div>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, nextTick, onUnmounted } from 'vue'
|
||||
import { useApiSSE } from '../composables/useApi'
|
||||
|
||||
interface TrainJob {
|
||||
id: string
|
||||
type: 'classifier' | 'llm-sft'
|
||||
model_key: string
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
created_at: string
|
||||
config: Record<string, unknown> | null
|
||||
}
|
||||
|
||||
const jobs = ref<TrainJob[]>([])
|
||||
const loadError = ref<string | null>(null)
|
||||
const submitError = ref<string | null>(null)
|
||||
const submitting = ref(false)
|
||||
const cancellingId = ref<string | null>(null)
|
||||
const cancelError = ref<string | null>(null)
|
||||
|
||||
const form = ref({
|
||||
type: 'classifier' as 'classifier' | 'llm-sft',
|
||||
model_key: '',
|
||||
config_raw: '',
|
||||
})
|
||||
|
||||
// ── Log panel state ──
|
||||
const logJobId = ref<string | null>(null)
|
||||
const logLines = ref<string[]>([])
|
||||
const logPanelEl = ref<HTMLElement | null>(null)
|
||||
let closeSSE: (() => void) | null = null
|
||||
|
||||
// ── Data loading ──
|
||||
|
||||
async function loadJobs() {
|
||||
loadError.value = null
|
||||
try {
|
||||
const res = await fetch('/api/train/jobs')
|
||||
if (!res.ok) { loadError.value = `Failed to load jobs (HTTP ${res.status}).`; return }
|
||||
const data = await res.json() as { jobs: TrainJob[] }
|
||||
jobs.value = data.jobs ?? []
|
||||
} catch {
|
||||
loadError.value = 'Network error loading jobs.'
|
||||
}
|
||||
}
|
||||
|
||||
// ── Submit ──
|
||||
|
||||
async function submitJob() {
|
||||
if (!form.value.model_key) return
|
||||
submitError.value = null
|
||||
submitting.value = true
|
||||
|
||||
let config: Record<string, unknown> | null = null
|
||||
if (form.value.config_raw.trim()) {
|
||||
try {
|
||||
config = JSON.parse(form.value.config_raw) as Record<string, unknown>
|
||||
} catch {
|
||||
submitError.value = 'Config JSON is not valid. Fix it before submitting.'
|
||||
submitting.value = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/train/jobs', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
type: form.value.type,
|
||||
model_key: form.value.model_key,
|
||||
config_json: config,
|
||||
}),
|
||||
})
|
||||
if (!res.ok) {
|
||||
const detail = await res.text().catch(() => '')
|
||||
submitError.value = `Failed to queue job (HTTP ${res.status})${detail ? `: ${detail}` : '.'}`
|
||||
return
|
||||
}
|
||||
const newJob = await res.json() as TrainJob
|
||||
jobs.value = [newJob, ...jobs.value]
|
||||
form.value = { type: 'classifier', model_key: '', config_raw: '' }
|
||||
} catch {
|
||||
submitError.value = 'Network error submitting job.'
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Cancel ──
|
||||
|
||||
async function cancelJob(id: string) {
|
||||
cancellingId.value = id
|
||||
cancelError.value = null
|
||||
try {
|
||||
const res = await fetch(`/api/train/jobs/${encodeURIComponent(id)}/cancel`, { method: 'DELETE' })
|
||||
if (res.ok) {
|
||||
jobs.value = jobs.value.map(j =>
|
||||
j.id === id ? { ...j, status: 'cancelled' as const } : j
|
||||
)
|
||||
} else {
|
||||
cancelError.value = `Failed to cancel job (HTTP ${res.status}).`
|
||||
}
|
||||
} catch {
|
||||
cancelError.value = 'Network error cancelling job.'
|
||||
} finally {
|
||||
cancellingId.value = null
|
||||
}
|
||||
}
|
||||
|
||||
// ── Log SSE ──
|
||||
|
||||
function openLog(id: string) {
|
||||
closeLog()
|
||||
logJobId.value = id
|
||||
logLines.value = []
|
||||
|
||||
closeSSE = useApiSSE(
|
||||
`/api/train/jobs/${encodeURIComponent(id)}/run`,
|
||||
(data) => {
|
||||
if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
|
||||
logLines.value = [...logLines.value, String(data.message ?? '')]
|
||||
nextTick(() => {
|
||||
if (logPanelEl.value) {
|
||||
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
|
||||
}
|
||||
})
|
||||
}
|
||||
if (data.type === 'error') {
|
||||
logLines.value = [...logLines.value, '--- stream ended with error ---']
|
||||
nextTick(() => {
|
||||
if (logPanelEl.value) {
|
||||
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
|
||||
}
|
||||
})
|
||||
}
|
||||
},
|
||||
() => {
|
||||
logLines.value = [...logLines.value, '--- stream complete ---']
|
||||
},
|
||||
() => {
|
||||
logLines.value = [...logLines.value, '--- connection lost ---']
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
function closeLog() {
|
||||
closeSSE?.()
|
||||
closeSSE = null
|
||||
logJobId.value = null
|
||||
logLines.value = []
|
||||
}
|
||||
|
||||
// ── Helpers ──
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
const d = new Date(iso)
|
||||
if (isNaN(d.getTime())) return iso
|
||||
return d.toLocaleString(undefined, { dateStyle: 'short', timeStyle: 'short' })
|
||||
}
|
||||
|
||||
// ── Lifecycle ──
|
||||
|
||||
loadJobs()
|
||||
|
||||
onUnmounted(() => {
|
||||
closeSSE?.()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.train-jobs-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2rem;
|
||||
}
|
||||
|
||||
.view-header { display: flex; align-items: center; }
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
padding-bottom: 0.4rem;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.new-job-form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
max-width: 480px;
|
||||
}
|
||||
|
||||
.form-row {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.form-label {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.form-hint {
|
||||
font-weight: 400;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.form-control {
|
||||
padding: 0.45rem 0.65rem;
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.9rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
}
|
||||
|
||||
.form-control:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: -1px;
|
||||
}
|
||||
|
||||
.config-textarea {
|
||||
resize: vertical;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
padding: 0.4rem 0.9rem;
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
align-self: flex-start;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-primary:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-sm {
|
||||
padding: 0.2rem 0.55rem;
|
||||
font-size: 0.78rem;
|
||||
border-radius: 0.3rem;
|
||||
cursor: pointer;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
border: 1px solid;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.view-log-btn {
|
||||
border-color: var(--color-info, #1e6091);
|
||||
background: transparent;
|
||||
color: var(--color-info, #1e6091);
|
||||
}
|
||||
|
||||
.view-log-btn:hover {
|
||||
background: color-mix(in srgb, var(--color-info, #1e6091) 10%, transparent);
|
||||
}
|
||||
|
||||
.btn-danger-sm {
|
||||
border-color: var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
}
|
||||
|
||||
.btn-danger-sm:hover:not(:disabled) {
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
}
|
||||
|
||||
.btn-danger-sm:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.btn-retry {
|
||||
margin-left: 0.5rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
cursor: pointer;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
padding: 0.6rem 0.8rem;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
color: var(--color-error, #c0392b);
|
||||
font-size: 0.88rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.empty-notice {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px dashed var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
}
|
||||
|
||||
.jobs-table-wrap { overflow-x: auto; }
|
||||
|
||||
.jobs-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.jobs-table th {
|
||||
text-align: left;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.jobs-table td {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.td-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.td-model {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.td-date {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.td-actions {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.68rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: var(--radius-full, 9999px);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.status-queued { background: var(--color-surface-alt, #dde4f0); color: var(--color-text-muted, #4a5c7a); }
|
||||
.status-running { background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent); color: var(--color-info, #1e6091); }
|
||||
.status-completed { background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent); color: var(--color-success, #3a7a32); }
|
||||
.status-failed { background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent); color: var(--color-error, #c0392b); }
|
||||
.status-cancelled { background: color-mix(in srgb, var(--color-warning, #d4891a) 15%, transparent); color: var(--color-warning, #d4891a); }
|
||||
|
||||
.type-chip {
|
||||
font-size: 0.72rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 0.25rem;
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.log-section { gap: 0.5rem; }
|
||||
|
||||
.log-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.btn-close-log {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.8rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.btn-close-log:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
|
||||
.log-panel {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
max-height: 320px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.log-line {
|
||||
color: var(--color-text, #1a2338);
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.log-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
@media (max-width: 560px) {
|
||||
.jobs-table th:nth-child(4),
|
||||
.jobs-table td:nth-child(4),
|
||||
.jobs-table th:nth-child(5),
|
||||
.jobs-table td:nth-child(5) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import TrainResultsView from './TrainResultsView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
const sampleResult = {
|
||||
id: 'run-xyz',
|
||||
job_id: 'job-abc123',
|
||||
model_type: 'classifier',
|
||||
base_model: 'microsoft/deberta-v3-small',
|
||||
val_macro_f1: 0.847,
|
||||
val_accuracy: 0.891,
|
||||
sample_count: 1240,
|
||||
duration_seconds: 842,
|
||||
created_at: '2026-05-01T11:30:00Z',
|
||||
}
|
||||
|
||||
function makeFetch(results: unknown[] = []) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ results }),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', makeFetch([sampleResult]))
|
||||
})
|
||||
|
||||
describe('TrainResultsView', () => {
|
||||
it('renders page title "Training Results"', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Training Results')
|
||||
})
|
||||
|
||||
it('shows empty notice when there are no results', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch([]))
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.empty-notice').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('renders results table when results exist', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('table.results-table').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows base_model in table', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('deberta-v3-small')
|
||||
})
|
||||
|
||||
it('shows val_macro_f1 formatted as percentage', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('84.7%')
|
||||
})
|
||||
|
||||
it('shows val_accuracy formatted as percentage', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('89.1%')
|
||||
})
|
||||
|
||||
it('shows duration formatted as minutes and seconds', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
// 842 seconds = 14m 2s
|
||||
expect(w.text()).toContain('14m')
|
||||
})
|
||||
|
||||
it('shows Register in Fleet button for classifier results', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('a.register-btn').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('does NOT show Register in Fleet button for llm-sft results', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch([{ ...sampleResult, model_type: 'llm-sft' }]))
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('a.register-btn').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('shows error notice when API fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 500, text: async () => '' }))
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
|
@ -1,296 +0,0 @@
|
|||
<template>
|
||||
<div class="train-results-view">
|
||||
<header class="view-header">
|
||||
<h1 class="page-title">Training Results</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="loadResults" aria-label="Refresh">🔄</button>
|
||||
</header>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="loadResults">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="loading" class="loading-state" aria-live="polite">Loading…</div>
|
||||
|
||||
<div v-if="!error && results.length === 0 && !loading" class="empty-notice">
|
||||
No training results yet. Completed jobs will appear here.
|
||||
</div>
|
||||
|
||||
<div v-if="results.length > 0" class="results-table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Run</th>
|
||||
<th>Type</th>
|
||||
<th>Base Model</th>
|
||||
<th class="th-numeric">Macro F1</th>
|
||||
<th class="th-numeric">Accuracy</th>
|
||||
<th class="th-numeric">Samples</th>
|
||||
<th class="th-numeric">Duration</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="r in results" :key="r.id">
|
||||
<td class="td-id" :title="r.id">{{ r.id.slice(0, 8) }}</td>
|
||||
<td>
|
||||
<span class="type-chip">{{ r.model_type }}</span>
|
||||
</td>
|
||||
<td class="td-model" :title="r.base_model">{{ shortModel(r.base_model) }}</td>
|
||||
<td class="td-numeric">
|
||||
<span class="metric-val" :class="scoreClass(r.val_macro_f1)">
|
||||
{{ formatPct(r.val_macro_f1) }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="td-numeric">{{ formatPct(r.val_accuracy) }}</td>
|
||||
<td class="td-numeric">{{ r.sample_count.toLocaleString() }}</td>
|
||||
<td class="td-numeric">{{ formatDuration(r.duration_seconds) }}</td>
|
||||
<td class="td-actions">
|
||||
<RouterLink
|
||||
v-if="r.model_type === 'classifier'"
|
||||
:to="`/fleet?model=${encodeURIComponent(r.base_model)}`"
|
||||
class="register-btn btn-sm-link"
|
||||
>
|
||||
Register in Fleet
|
||||
</RouterLink>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
interface TrainResult {
|
||||
id: string
|
||||
job_id: string
|
||||
model_type: string
|
||||
base_model: string
|
||||
val_macro_f1: number | null
|
||||
val_accuracy: number | null
|
||||
sample_count: number
|
||||
duration_seconds: number | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
const results = ref<TrainResult[]>([])
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
async function loadResults() {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const res = await fetch('/api/train/results')
|
||||
if (!res.ok) {
|
||||
error.value = `Failed to load results (HTTP ${res.status}).`
|
||||
return
|
||||
}
|
||||
const raw = await res.json() as { results?: TrainResult[] }
|
||||
results.value = Array.isArray(raw?.results) ? raw.results : []
|
||||
} catch {
|
||||
error.value = 'Network error loading results.'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function formatPct(v: number | null | undefined): string {
|
||||
if (v == null) return '—'
|
||||
return `${(v * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
function formatDuration(seconds: number | null | undefined): string {
|
||||
if (seconds == null) return '—'
|
||||
const mins = Math.floor(seconds / 60)
|
||||
const secs = seconds % 60
|
||||
if (mins === 0) return `${secs}s`
|
||||
return `${mins}m ${secs}s`
|
||||
}
|
||||
|
||||
function shortModel(model: string): string {
|
||||
const parts = model.split('/')
|
||||
return parts[parts.length - 1] ?? model
|
||||
}
|
||||
|
||||
function scoreClass(f1: number | null | undefined): string {
|
||||
if (f1 == null) return ''
|
||||
if (f1 >= 0.85) return 'score-great'
|
||||
if (f1 >= 0.75) return 'score-good'
|
||||
return 'score-fair'
|
||||
}
|
||||
|
||||
onMounted(() => loadResults())
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.train-results-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.view-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.error-notice {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 0.75rem 1rem;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
color: var(--color-error, #c0392b);
|
||||
font-size: 0.88rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
cursor: pointer;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.empty-notice {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px dashed var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
}
|
||||
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
}
|
||||
|
||||
.results-table-wrap { overflow-x: auto; }
|
||||
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.results-table th {
|
||||
text-align: left;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.th-numeric { text-align: right; }
|
||||
|
||||
.results-table td {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.td-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.td-model {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
max-width: 16ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.td-numeric {
|
||||
text-align: right;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.td-actions { text-align: right; }
|
||||
|
||||
.metric-val { font-weight: 600; }
|
||||
.score-great { color: var(--color-success, #3a7a32); }
|
||||
.score-good { color: var(--color-warning, #d4891a); }
|
||||
.score-fair { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.type-chip {
|
||||
font-size: 0.72rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 0.25rem;
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.btn-sm-link {
|
||||
font-size: 0.78rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.3rem;
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
color: var(--app-primary, #2A6080);
|
||||
background: transparent;
|
||||
text-decoration: none;
|
||||
white-space: nowrap;
|
||||
display: inline-block;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.btn-sm-link:hover {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.results-table th:nth-child(6),
|
||||
.results-table td:nth-child(6),
|
||||
.results-table th:nth-child(7),
|
||||
.results-table td:nth-child(7) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,919 +0,0 @@
|
|||
<template>
|
||||
<div class="voice-tab">
|
||||
|
||||
<!-- ── Controls row ──────────────────────────────────────────────────── -->
|
||||
<div class="voice-controls">
|
||||
|
||||
<!-- Model picker -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">🎙 Models</span>
|
||||
<span class="picker-badge">{{ selectedCount }} selected</span>
|
||||
<button class="btn-refresh" :disabled="modelsLoading" @click.stop="loadModels" title="Refresh model list">
|
||||
{{ modelsLoading ? '⏳' : '🔄' }}
|
||||
</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="modelsLoading" class="picker-loading">Loading models…</div>
|
||||
<div v-else-if="loadError" class="picker-error">{{ loadError }}</div>
|
||||
<template v-else>
|
||||
|
||||
<!-- Ollama group -->
|
||||
<div class="picker-group" v-if="ollamaModels.length">
|
||||
<div class="group-header">
|
||||
<label class="group-check">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isGroupAllSelected('ollama')"
|
||||
:indeterminate="isGroupIndeterminate('ollama')"
|
||||
@change="toggleGroup('ollama', ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="group-label">Ollama</span>
|
||||
<span class="group-count">({{ ollamaModels.length }})</span>
|
||||
</label>
|
||||
<span class="group-note">auto-synced with Models view</span>
|
||||
</div>
|
||||
<div class="model-list">
|
||||
<label v-for="m in ollamaModels" :key="m.id" class="model-item">
|
||||
<input type="checkbox" :value="m.id" v-model="selectedModels" />
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span v-if="m.size_mb" class="model-meta">{{ formatMb(m.size_mb) }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- cf-text group -->
|
||||
<div class="picker-group" v-if="cftextModels.length">
|
||||
<div class="group-header">
|
||||
<label class="group-check">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isGroupAllSelected('cf-text')"
|
||||
:indeterminate="isGroupIndeterminate('cf-text')"
|
||||
@change="toggleGroup('cf-text', ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="group-label">cf-text (cf-orch)</span>
|
||||
<span class="group-count">({{ cftextModels.length }})</span>
|
||||
</label>
|
||||
<span class="group-note">GGUFs via coordinator — enable cf-orch below</span>
|
||||
</div>
|
||||
<div class="model-list">
|
||||
<label v-for="m in cftextModels" :key="m.id" class="model-item">
|
||||
<input type="checkbox" :value="m.id" v-model="selectedModels" />
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span v-if="m.vram_mb" class="model-meta">{{ formatMb(m.vram_mb) }} VRAM</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="!ollamaModels.length && !cftextModels.length" class="picker-empty">
|
||||
No models available — check Ollama and cf-orch connections.
|
||||
</div>
|
||||
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Options panel -->
|
||||
<details class="options-panel">
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">⚙️ Options</span>
|
||||
</summary>
|
||||
<div class="options-body">
|
||||
<label class="option-row">
|
||||
<input type="checkbox" v-model="useCforch" :disabled="running" />
|
||||
<span class="option-label">Use cf-orch backend</span>
|
||||
<span class="option-hint">Routes generation through cf-text instead of ollama</span>
|
||||
</label>
|
||||
<label class="option-row" :class="{ dimmed: !useCforch }">
|
||||
<span class="option-label">Max VRAM (MB)</span>
|
||||
<input
|
||||
type="number"
|
||||
v-model.number="maxVram"
|
||||
:disabled="running || !useCforch"
|
||||
min="1024"
|
||||
max="24576"
|
||||
step="512"
|
||||
class="option-number"
|
||||
/>
|
||||
<span class="option-hint">Skip models exceeding this VRAM limit</span>
|
||||
</label>
|
||||
<label class="option-row">
|
||||
<span class="option-label">Parallel workers</span>
|
||||
<input
|
||||
type="number"
|
||||
v-model.number="workers"
|
||||
:disabled="running"
|
||||
min="1"
|
||||
max="16"
|
||||
step="1"
|
||||
class="option-number"
|
||||
/>
|
||||
<span class="option-hint">Models to score simultaneously (1 = sequential)</span>
|
||||
</label>
|
||||
<label class="option-row">
|
||||
<input type="checkbox" v-model="includeLarge" :disabled="running" />
|
||||
<span class="option-label">Include large models (30B+)</span>
|
||||
<span class="option-hint">Off by default — these take much longer</span>
|
||||
</label>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- ── Run controls ──────────────────────────────────────────────────── -->
|
||||
<div class="run-bar">
|
||||
<button class="btn-run" :disabled="running || selectedCount === 0" @click="startBenchmark">
|
||||
{{ running ? '⏳ Running…' : results.length ? '🔄 Re-run' : '▶ Run Benchmark' }}
|
||||
</button>
|
||||
<button v-if="running" class="btn-cancel" @click="cancelBenchmark">✕ Cancel</button>
|
||||
<span v-if="selectedCount === 0 && !running" class="run-hint">Select at least one model above</span>
|
||||
</div>
|
||||
|
||||
<!-- ── Progress log ──────────────────────────────────────────────────── -->
|
||||
<div v-if="runLog.length" class="run-log">
|
||||
<div class="run-log-header">
|
||||
<span class="run-log-title">Run log</span>
|
||||
<button class="btn-clear" @click="runLog = []">Clear</button>
|
||||
</div>
|
||||
<pre class="run-log-body" ref="logEl">{{ runLog.join('\n') }}</pre>
|
||||
</div>
|
||||
|
||||
<!-- ── Past runs picker ─────────────────────────────────────────────── -->
|
||||
<div class="history-bar" v-if="pastRuns.length">
|
||||
<label class="history-label">📂 Past runs:</label>
|
||||
<select class="history-select" v-model="selectedRun" @change="loadRun(selectedRun)">
|
||||
<option value="">— select a past run —</option>
|
||||
<option v-for="r in pastRuns" :key="r.filename" :value="r.filename">
|
||||
{{ r.date }} · {{ r.model_count }} model{{ r.model_count !== 1 ? 's' : '' }} · top {{ r.top_score }}/100
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- ── Results table ─────────────────────────────────────────────────── -->
|
||||
<div v-if="results.length" class="results-section">
|
||||
<div class="results-header">
|
||||
<h2 class="results-title">Rankings</h2>
|
||||
<button
|
||||
class="btn-corrections"
|
||||
:disabled="sendingCorrections"
|
||||
@click="sendToCorrections"
|
||||
title="Push all outputs from this run into the Corrections review queue"
|
||||
>
|
||||
{{ sendingCorrections ? '⏳ Sending…' : correctionsMsg || '✍️ Send to Corrections' }}
|
||||
</button>
|
||||
</div>
|
||||
<div class="results-table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Rank</th>
|
||||
<th>Model</th>
|
||||
<th>Score</th>
|
||||
<th>Latency</th>
|
||||
<th title="Em-dash count">—</th>
|
||||
<th title="Filler phrase hits">Fillers</th>
|
||||
<th title="Semicolons">;</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="(r, i) in results"
|
||||
:key="r.model_id"
|
||||
class="result-row"
|
||||
:class="{ 'top-row': i === 0 }"
|
||||
@click="toggleExpanded(r.model_id)"
|
||||
>
|
||||
<td class="rank-cell">{{ medal(i) }}</td>
|
||||
<td class="model-cell">
|
||||
<span class="model-name-text">{{ r.model_id }}</span>
|
||||
</td>
|
||||
<td class="score-cell">
|
||||
<span class="score-pill" :style="scorePillStyle(r.avg_score)">
|
||||
{{ r.avg_score.toFixed(0) }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="latency-cell">{{ formatLatency(r.avg_latency_ms) }}</td>
|
||||
<td class="violation-cell" :class="{ 'has-violation': r.total_em_dashes > 0 }">
|
||||
{{ r.total_em_dashes }}
|
||||
</td>
|
||||
<td class="violation-cell" :class="{ 'has-violation': r.total_filler_hits > 0 }">
|
||||
{{ r.total_filler_hits }}
|
||||
</td>
|
||||
<td class="violation-cell" :class="{ 'has-violation': r.total_semicolons > 0 }">
|
||||
{{ r.total_semicolons }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- Expandable sample outputs -->
|
||||
<div v-for="r in results" :key="'exp-' + r.model_id">
|
||||
<div v-if="expandedModels.has(r.model_id)" class="sample-outputs">
|
||||
<div class="sample-header">
|
||||
<strong>{{ r.model_id }}</strong>
|
||||
<button class="btn-collapse" @click="toggleExpanded(r.model_id)">✕ Close</button>
|
||||
</div>
|
||||
<div v-for="pr in r.prompt_results" :key="pr.tag" class="sample-prompt">
|
||||
<div class="sample-tag">
|
||||
<span class="tag-name">{{ pr.tag }}</span>
|
||||
<span class="tag-score">{{ pr.score.toFixed(0) }}/100</span>
|
||||
<span class="tag-latency">{{ formatLatency(pr.latency_ms) }}</span>
|
||||
</div>
|
||||
<pre class="sample-text">{{ pr.output || '(no output)' }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, nextTick, watch } from 'vue'
|
||||
|
||||
// ── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
interface VoiceModel {
|
||||
id: string
|
||||
name: string
|
||||
source: 'ollama' | 'cf-text'
|
||||
size_mb?: number | null
|
||||
vram_mb?: number | null
|
||||
description?: string
|
||||
}
|
||||
|
||||
interface PromptResult {
|
||||
tag: string
|
||||
output: string
|
||||
score: number
|
||||
latency_ms: number
|
||||
signals: Record<string, unknown>
|
||||
}
|
||||
|
||||
interface ModelResult {
|
||||
model_id: string
|
||||
avg_score: number
|
||||
avg_latency_ms: number
|
||||
total_filler_hits: number
|
||||
total_em_dashes: number
|
||||
total_semicolons: number
|
||||
prompt_results: PromptResult[]
|
||||
}
|
||||
|
||||
interface PastRun {
|
||||
filename: string
|
||||
date: string
|
||||
model_count: number
|
||||
top_score: number
|
||||
}
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
|
||||
const ollamaModels = ref<VoiceModel[]>([])
|
||||
const cftextModels = ref<VoiceModel[]>([])
|
||||
const selectedModels = ref<string[]>([])
|
||||
const modelsLoading = ref(false)
|
||||
const loadError = ref('')
|
||||
|
||||
const useCforch = ref(false)
|
||||
const maxVram = ref(7200)
|
||||
const workers = ref(1)
|
||||
const includeLarge = ref(false)
|
||||
|
||||
const running = ref(false)
|
||||
const runLog = ref<string[]>([])
|
||||
const logEl = ref<HTMLPreElement | null>(null)
|
||||
|
||||
const results = ref<ModelResult[]>([])
|
||||
const pastRuns = ref<PastRun[]>([])
|
||||
const selectedRun = ref('')
|
||||
const expandedModels = ref(new Set<string>())
|
||||
const sendingCorrections = ref(false)
|
||||
const correctionsMsg = ref('')
|
||||
|
||||
// ── Computed ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const selectedCount = computed(() => selectedModels.value.length)
|
||||
|
||||
function isGroupAllSelected(source: string): boolean {
|
||||
const group = source === 'ollama' ? ollamaModels.value : cftextModels.value
|
||||
return group.length > 0 && group.every(m => selectedModels.value.includes(m.id))
|
||||
}
|
||||
|
||||
function isGroupIndeterminate(source: string): boolean {
|
||||
const group = source === 'ollama' ? ollamaModels.value : cftextModels.value
|
||||
const count = group.filter(m => selectedModels.value.includes(m.id)).length
|
||||
return count > 0 && count < group.length
|
||||
}
|
||||
|
||||
// ── Actions ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async function loadModels() {
|
||||
modelsLoading.value = true
|
||||
loadError.value = ''
|
||||
try {
|
||||
const resp = await fetch('/api/voice/models')
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
const data = await resp.json()
|
||||
ollamaModels.value = data.ollama ?? []
|
||||
cftextModels.value = data.cf_text ?? []
|
||||
} catch (e: unknown) {
|
||||
loadError.value = `Failed to load models: ${e instanceof Error ? e.message : String(e)}`
|
||||
} finally {
|
||||
modelsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadPastRuns() {
|
||||
try {
|
||||
const resp = await fetch('/api/voice/results')
|
||||
if (resp.ok) pastRuns.value = await resp.json()
|
||||
} catch { /* non-fatal */ }
|
||||
}
|
||||
|
||||
async function loadRun(filename: string) {
|
||||
if (!filename) return
|
||||
try {
|
||||
const resp = await fetch(`/api/voice/results/${filename}`)
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
results.value = await resp.json()
|
||||
expandedModels.value.clear()
|
||||
} catch (e: unknown) {
|
||||
runLog.value.push(`[error] Failed to load ${filename}: ${e instanceof Error ? e.message : String(e)}`)
|
||||
}
|
||||
}
|
||||
|
||||
function toggleGroup(source: string, checked: boolean) {
|
||||
const group = source === 'ollama' ? ollamaModels.value : cftextModels.value
|
||||
const ids = group.map(m => m.id)
|
||||
if (checked) {
|
||||
const newSet = new Set([...selectedModels.value, ...ids])
|
||||
selectedModels.value = [...newSet]
|
||||
} else {
|
||||
selectedModels.value = selectedModels.value.filter(id => !ids.includes(id))
|
||||
}
|
||||
}
|
||||
|
||||
function toggleExpanded(modelId: string) {
|
||||
if (expandedModels.value.has(modelId)) {
|
||||
expandedModels.value.delete(modelId)
|
||||
} else {
|
||||
expandedModels.value.add(modelId)
|
||||
}
|
||||
expandedModels.value = new Set(expandedModels.value)
|
||||
}
|
||||
|
||||
function startBenchmark() {
|
||||
if (running.value || selectedCount.value === 0) return
|
||||
running.value = true
|
||||
runLog.value = []
|
||||
results.value = []
|
||||
expandedModels.value.clear()
|
||||
|
||||
const params = new URLSearchParams({
|
||||
models: selectedModels.value.join(','),
|
||||
use_cforch: String(useCforch.value),
|
||||
max_vram: String(maxVram.value),
|
||||
workers: String(workers.value),
|
||||
include_large: String(includeLarge.value),
|
||||
})
|
||||
|
||||
const es = new EventSource(`/api/voice/run?${params}`)
|
||||
|
||||
es.onmessage = async (ev) => {
|
||||
try {
|
||||
const msg = JSON.parse(ev.data)
|
||||
if (msg.type === 'progress') {
|
||||
runLog.value.push(msg.message)
|
||||
await nextTick()
|
||||
if (logEl.value) logEl.value.scrollTop = logEl.value.scrollHeight
|
||||
} else if (msg.type === 'result') {
|
||||
results.value = msg.results ?? []
|
||||
await loadPastRuns()
|
||||
} else if (msg.type === 'complete') {
|
||||
running.value = false
|
||||
es.close()
|
||||
} else if (msg.type === 'error') {
|
||||
runLog.value.push(`[error] ${msg.message}`)
|
||||
running.value = false
|
||||
es.close()
|
||||
}
|
||||
} catch { /* ignore parse errors */ }
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
if (running.value) {
|
||||
runLog.value.push('[error] Connection lost')
|
||||
running.value = false
|
||||
}
|
||||
es.close()
|
||||
}
|
||||
}
|
||||
|
||||
async function cancelBenchmark() {
|
||||
try {
|
||||
await fetch('/api/voice/cancel', { method: 'POST' })
|
||||
} finally {
|
||||
running.value = false
|
||||
runLog.value.push('[cancelled]')
|
||||
}
|
||||
}
|
||||
|
||||
async function sendToCorrections() {
|
||||
if (!selectedRun.value || sendingCorrections.value) return
|
||||
sendingCorrections.value = true
|
||||
correctionsMsg.value = ''
|
||||
try {
|
||||
const resp = await fetch('/api/voice/send-to-corrections', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ filename: selectedRun.value, model_ids: [] }),
|
||||
})
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
const data = await resp.json()
|
||||
correctionsMsg.value = `✓ ${data.imported} added to Corrections`
|
||||
} catch (e: unknown) {
|
||||
correctionsMsg.value = `Error: ${e instanceof Error ? e.message : String(e)}`
|
||||
} finally {
|
||||
sendingCorrections.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Formatting helpers ────────────────────────────────────────────────────────
|
||||
|
||||
function formatMb(mb: number): string {
|
||||
return mb >= 1024 ? `${(mb / 1024).toFixed(1)} GB` : `${mb} MB`
|
||||
}
|
||||
|
||||
function formatLatency(ms: number): string {
|
||||
return ms >= 1000 ? `${(ms / 1000).toFixed(1)}s` : `${Math.round(ms)}ms`
|
||||
}
|
||||
|
||||
function medal(index: number): string {
|
||||
return ['🥇', '🥈', '🥉'][index] ?? `#${index + 1}`
|
||||
}
|
||||
|
||||
function scorePillStyle(score: number): Record<string, string> {
|
||||
const hue = Math.round((score / 100) * 120) // 0=red, 120=green
|
||||
return {
|
||||
background: `hsl(${hue} 60% 88%)`,
|
||||
color: `hsl(${hue} 60% 28%)`,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Lifecycle ─────────────────────────────────────────────────────────────────
|
||||
|
||||
// Auto-enable cf-orch when cf-text models are selected
|
||||
watch(selectedModels, (ids) => {
|
||||
const hasCftext = ids.some(id => cftextModels.value.find(m => m.id === id))
|
||||
if (hasCftext) useCforch.value = true
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
await Promise.all([loadModels(), loadPastRuns()])
|
||||
// Auto-load the latest results if any exist
|
||||
if (pastRuns.value.length) {
|
||||
selectedRun.value = pastRuns.value[0].filename
|
||||
await loadRun(pastRuns.value[0].filename)
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.voice-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
padding: 1rem 0;
|
||||
}
|
||||
|
||||
/* ── Controls ─────────────────────────────────────────────────────────────── */
|
||||
|
||||
.voice-controls {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.75rem;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.model-picker,
|
||||
.options-panel {
|
||||
flex: 1;
|
||||
min-width: 280px;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.picker-summary {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.65rem 0.85rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.picker-summary::-webkit-details-marker { display: none; }
|
||||
|
||||
.picker-title { flex: 1; color: var(--color-text, #1a2338); }
|
||||
.picker-badge {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 9999px;
|
||||
padding: 0.1rem 0.5rem;
|
||||
font-size: 0.72rem;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.btn-refresh {
|
||||
border: none;
|
||||
background: transparent;
|
||||
cursor: pointer;
|
||||
font-size: 0.85rem;
|
||||
padding: 0.1rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
.btn-refresh:hover { background: var(--color-border, #d0d7e8); }
|
||||
.btn-refresh:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.picker-body,
|
||||
.options-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.picker-loading, .picker-empty {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.85rem;
|
||||
padding: 0.25rem 0;
|
||||
}
|
||||
|
||||
.picker-error {
|
||||
color: #b91c1c;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
/* ── Model groups ──────────────────────────────────────────────────────────── */
|
||||
|
||||
.picker-group {
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.picker-group:last-child { margin-bottom: 0; }
|
||||
|
||||
.group-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.4rem;
|
||||
}
|
||||
|
||||
.group-check {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.group-count {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-weight: 400;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
.group-note {
|
||||
margin-left: auto;
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.model-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.2rem;
|
||||
padding-left: 1.25rem;
|
||||
max-height: 220px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.model-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
font-size: 0.82rem;
|
||||
cursor: pointer;
|
||||
padding: 0.15rem 0;
|
||||
}
|
||||
|
||||
.model-name { flex: 1; font-family: var(--font-mono, monospace); }
|
||||
|
||||
.model-meta {
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
/* ── Options ──────────────────────────────────────────────────────────────── */
|
||||
|
||||
.option-row {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 0.5rem;
|
||||
padding: 0.35rem 0;
|
||||
cursor: pointer;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.option-label { font-weight: 500; white-space: nowrap; }
|
||||
|
||||
.option-hint {
|
||||
flex: 1;
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin-left: auto;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.option-number {
|
||||
width: 90px;
|
||||
padding: 0.2rem 0.4rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
font-size: 0.85rem;
|
||||
background: var(--color-bg, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.option-row.dimmed { opacity: 0.45; pointer-events: none; }
|
||||
|
||||
/* ── Run bar ──────────────────────────────────────────────────────────────── */
|
||||
|
||||
.run-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.65rem;
|
||||
}
|
||||
|
||||
.btn-run {
|
||||
padding: 0.5rem 1.25rem;
|
||||
border: none;
|
||||
border-radius: 0.375rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-run:hover:not(:disabled) { background: color-mix(in srgb, var(--app-primary, #2A6080) 80%, #000); }
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.5rem 0.9rem;
|
||||
border: 1px solid #f85149;
|
||||
border-radius: 0.375rem;
|
||||
background: transparent;
|
||||
color: #b91c1c;
|
||||
font-size: 0.85rem;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-cancel:hover { background: #fee2e2; }
|
||||
|
||||
.run-hint {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
/* ── Run log ──────────────────────────────────────────────────────────────── */
|
||||
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.run-log-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.run-log-title { text-transform: uppercase; letter-spacing: 0.05em; }
|
||||
|
||||
.btn-clear {
|
||||
border: none;
|
||||
background: transparent;
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
padding: 0.1rem 0.3rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
.btn-clear:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
.run-log-body {
|
||||
margin: 0;
|
||||
padding: 0.65rem 0.85rem;
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
max-height: 260px;
|
||||
overflow-y: auto;
|
||||
background: var(--color-bg, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
/* ── History bar ──────────────────────────────────────────────────────────── */
|
||||
|
||||
.history-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.history-label { font-weight: 500; white-space: nowrap; }
|
||||
|
||||
.history-select {
|
||||
flex: 1;
|
||||
padding: 0.3rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
/* ── Results table ────────────────────────────────────────────────────────── */
|
||||
|
||||
.results-section { display: flex; flex-direction: column; gap: 0.75rem; }
|
||||
|
||||
.results-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.results-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.btn-corrections {
|
||||
padding: 0.4rem 0.9rem;
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
border-radius: 0.375rem;
|
||||
background: transparent;
|
||||
color: var(--app-primary, #2A6080);
|
||||
font-size: 0.83rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
white-space: nowrap;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
.btn-corrections:hover:not(:disabled) {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
}
|
||||
.btn-corrections:disabled { opacity: 0.55; cursor: not-allowed; }
|
||||
|
||||
.results-table-wrap {
|
||||
overflow-x: auto;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.results-table th {
|
||||
padding: 0.5rem 0.75rem;
|
||||
text-align: left;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.result-row {
|
||||
cursor: pointer;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
.result-row:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 6%, transparent); }
|
||||
.result-row.top-row { font-weight: 600; }
|
||||
|
||||
.result-row td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.result-row:last-child td { border-bottom: none; }
|
||||
|
||||
.rank-cell { width: 2.5rem; text-align: center; font-size: 1.1rem; }
|
||||
.model-cell { font-family: var(--font-mono, monospace); word-break: break-all; }
|
||||
.score-cell { width: 5rem; text-align: center; }
|
||||
.latency-cell { width: 5rem; text-align: right; color: var(--color-text-secondary, #6b7a99); }
|
||||
.violation-cell { width: 4rem; text-align: center; color: var(--color-text-secondary, #6b7a99); }
|
||||
.violation-cell.has-violation { color: #b91c1c; font-weight: 700; }
|
||||
|
||||
.score-pill {
|
||||
display: inline-block;
|
||||
padding: 0.15rem 0.55rem;
|
||||
border-radius: 9999px;
|
||||
font-weight: 700;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
/* ── Sample outputs ───────────────────────────────────────────────────────── */
|
||||
|
||||
.sample-outputs {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sample-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0.5rem 0.85rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.btn-collapse {
|
||||
border: none;
|
||||
background: transparent;
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.sample-prompt {
|
||||
padding: 0.65rem 0.85rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.sample-prompt:last-child { border-bottom: none; }
|
||||
|
||||
.sample-tag {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.35rem;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
.tag-name { font-weight: 600; color: var(--color-text, #1a2338); }
|
||||
.tag-score { color: var(--app-primary, #2A6080); font-weight: 700; }
|
||||
.tag-latency { color: var(--color-text-secondary, #6b7a99); margin-left: auto; }
|
||||
|
||||
.sample-text {
|
||||
margin: 0;
|
||||
font-size: 0.82rem;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
background: var(--color-bg, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.35rem;
|
||||
padding: 0.5rem 0.65rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.voice-controls { flex-direction: column; }
|
||||
.model-picker, .options-panel { min-width: 0; }
|
||||
.option-hint { display: none; }
|
||||
.group-note { display: none; }
|
||||
}
|
||||
</style>
|
||||
Loading…
Reference in a new issue