Compare commits
75 commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c123492a1c | |||
| 391ebb3cd1 | |||
| 9bb88b168f | |||
| 13ca082a43 | |||
| d416ef8aa4 | |||
| 79b9ccbd3d | |||
| e93afec271 | |||
| cac91dd8a2 | |||
| 2b990a603a | |||
| 9fdaeeb3d6 | |||
| 71bf88d09b | |||
| bc4ca1095c | |||
| b6aed3dd1b | |||
| 1ad7ba322a | |||
| 32e3b2a0dd | |||
| 12117ad0c6 | |||
| 5939c67b9f | |||
| 5ea77da97d | |||
| 276bdadb92 | |||
| 6f9aad126e | |||
| 258bbdc0af | |||
| 32872d1ec6 | |||
| 1521198cb1 | |||
| 8dda040480 | |||
| bf675ed1f6 | |||
| 0efd1aedbe | |||
| 4c225b94f5 | |||
| 1cd9c5d455 | |||
| 5702a7190b | |||
| 55b017ba3b | |||
| f952ec8971 | |||
| fd8cb622a1 | |||
| 47cb9f661f | |||
| c2de9e53da | |||
| c039ea4698 | |||
| 95afddb772 | |||
| cbe8c0f03e | |||
| 5df33b0f41 | |||
| 41584de5df | |||
| 1d4c07e4a0 | |||
| e823b5e76d | |||
| 88bc6bed67 | |||
| 4a64a6686d | |||
| f2f150b4fb | |||
| 72449561cf | |||
| c177fb1628 | |||
| 3be5055e31 | |||
| 78b64d007d | |||
| bce932461a | |||
| e11db5ccd9 | |||
| 13d1a394d5 | |||
| b077371107 | |||
| 53b25b27ab | |||
| e014da2dec | |||
| c48db45d91 | |||
| d0ba75b995 | |||
| a134af8b7b | |||
| 6ef6f06023 | |||
| 5bdb095235 | |||
| 0904967320 | |||
| 8fda821e15 | |||
| 0853ed7d56 | |||
| aa742bcfc0 | |||
| 32d3436bbd | |||
| 766fbafa02 | |||
| d432026fd7 | |||
| bccb385f61 | |||
| d74ad3f972 | |||
| 99ea39fe38 | |||
| 2054866ff1 | |||
| cbec776ef1 | |||
| 167d7351e3 | |||
| 6689ff07b1 | |||
| 0745bc3f70 | |||
| 2891606765 |
85 changed files with 17511 additions and 2544 deletions
|
|
@ -17,3 +17,7 @@ CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx
|
|||
# Set one of these to use a cloud LLM instead of a local model.
|
||||
# ANTHROPIC_API_KEY=sk-ant-...
|
||||
# OPENAI_API_KEY=sk-...
|
||||
|
||||
# ── HuggingFace (required for gated/terms-restricted model downloads) ─────────
|
||||
# Generate at https://huggingface.co/settings/tokens and accept model terms first.
|
||||
# HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
|
||||
|
|
|
|||
7
.gitignore
vendored
7
.gitignore
vendored
|
|
@ -8,6 +8,9 @@ __pycache__/
|
|||
config/label_tool.yaml
|
||||
|
||||
# Data files (user-generated, not for version control)
|
||||
data/corpus.db
|
||||
data/corpus.db-wal
|
||||
data/corpus.db-shm
|
||||
data/email_score.jsonl
|
||||
data/email_label_queue.jsonl
|
||||
data/email_compare_sample.jsonl
|
||||
|
|
@ -20,3 +23,7 @@ data/sft_approved.jsonl
|
|||
# Claude context — BSL 1.1, keep out of version control
|
||||
CLAUDE.md
|
||||
docs/superpowers/
|
||||
.superpowers/
|
||||
|
||||
# Git worktrees
|
||||
.worktrees/
|
||||
|
|
|
|||
183
README.md
183
README.md
|
|
@ -1,22 +1,120 @@
|
|||
# Avocet — Email Classifier Training Tool
|
||||
<div align="center">
|
||||
<img src="docs/avocet-logo.svg" alt="Avocet" height="96" />
|
||||
|
||||
> *Part of the CircuitForge LLC internal infrastructure suite.*
|
||||
# Avocet
|
||||
|
||||
**Status:** Internal beta — label tool and benchmark harness complete. Used to build training data for Peregrine's email classifier.
|
||||
**Email classifier training tool — label, benchmark, fine-tune.**
|
||||
|
||||
[]()
|
||||
[](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet/releases)
|
||||
[](LICENSE)
|
||||
[]()
|
||||
[](https://circuitforge.tech)
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## What it does
|
||||
## What is Avocet?
|
||||
|
||||
Avocet is the data pipeline for building and benchmarking email classifiers. It has two layers:
|
||||
Avocet is the internal data pipeline Circuit Forge uses to build, evaluate, and fine-tune email classifiers. It implements a three-stage workflow: human labelers review emails one at a time in a drag-to-bucket UI and produce a ground-truth dataset; the benchmark harness scores any number of HuggingFace zero-shot models against that dataset and produces a ranked comparison; and the fine-tune harness adapts the best-scoring base model to the labeled distribution. The output feeds directly into Peregrine's email classification layer. No LLM API key required for the label tool or benchmark — all inference runs locally via HuggingFace Transformers.
|
||||
|
||||
**No LLM required.** Avocet uses zero-shot HuggingFace classification models — no API key, no cloud inference, no GPU required for the label tool. The benchmark harness can optionally export LLM-labeled emails from a Peregrine staging DB, but human labeling via the card-stack UI is the primary workflow.
|
||||
---
|
||||
|
||||
**Layer 1 — Label tool**
|
||||
Card-stack UI for building ground-truth classifier benchmark data. Fetch emails from one or more IMAP accounts (with targeted date-range and sender/subject filters), review them card-by-card, and label each with a job-search category. Labeled output feeds the benchmark harness.
|
||||
## Quick Start
|
||||
|
||||
**Layer 2 — Benchmark harness**
|
||||
Scores HuggingFace zero-shot classification models against the labeled dataset. Supports slow/large model inclusion, visual side-by-side comparison on live emails, and export of LLM-labeled emails from a Peregrine staging DB.
|
||||
```bash
|
||||
git clone https://git.opensourcesolarpunk.com/Circuit-Forge/avocet.git
|
||||
cd avocet
|
||||
|
||||
# Copy config template and fill in your IMAP credentials
|
||||
cp config/label_tool.yaml.example config/label_tool.yaml
|
||||
|
||||
# Start the label tool (Vue SPA + FastAPI, port 8503)
|
||||
./manage.sh start
|
||||
./manage.sh open
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- **Drag-to-bucket label UI** — ASMR-style card interface; drag emails into labeled buckets or discard without queuing noise into the training set
|
||||
- **Targeted IMAP fetch** — pull emails by date range, sender, or subject filter across multiple accounts without flooding the queue
|
||||
- **Email classifier benchmark** — score any HuggingFace zero-shot model against your labeled JSONL; side-by-side comparison on live IMAP emails
|
||||
- **Planning benchmark** — evaluate LLMs on structured planning tasks; compare models head-to-head with verbose diff output
|
||||
- **Writing style benchmark** — compare Ollama models on writing style coherence; scan local disk for existing outputs
|
||||
- **Fine-tune harness** — HuggingFace Transformers fine-tuning from labeled ground truth; classifier adapter interface for swapping backends at runtime
|
||||
- **Local inference first** — no API key required; GPU optional; designed to run on developer hardware
|
||||
- **Hot-reload dev mode** — uvicorn `--reload` + Vite HMR (hot module replacement) for fast iteration on both API and UI
|
||||
|
||||
---
|
||||
|
||||
## CLI Reference
|
||||
|
||||
All operations go through `manage.sh`.
|
||||
|
||||
### Label Tool
|
||||
|
||||
```bash
|
||||
./manage.sh start # Build Vue SPA and start FastAPI on port 8503
|
||||
./manage.sh stop # Stop FastAPI server
|
||||
./manage.sh restart # Stop, rebuild, and restart
|
||||
./manage.sh status # Show running state and port
|
||||
./manage.sh logs # Tail the API log
|
||||
./manage.sh open # Open http://localhost:8503 in browser
|
||||
./manage.sh dev # Hot-reload: uvicorn --reload + Vite HMR
|
||||
./manage.sh test # Run pytest suite
|
||||
```
|
||||
|
||||
### Email Classifier Benchmark
|
||||
|
||||
```bash
|
||||
./manage.sh benchmark [args] # Run benchmark_classifier.py
|
||||
./manage.sh list-models # List available zero-shot models
|
||||
./manage.sh score # Score models against labeled JSONL
|
||||
./manage.sh score --include-slow # Include large/slow models
|
||||
./manage.sh compare --limit 30 # Side-by-side comparison on live IMAP emails
|
||||
```
|
||||
|
||||
### Planning Benchmark
|
||||
|
||||
```bash
|
||||
./manage.sh plans-bench [args] # Run benchmark_plans.py
|
||||
./manage.sh plans-list # List available models
|
||||
./manage.sh plans-run <model> [args] # Run a single model (verbose)
|
||||
./manage.sh plans-compare <m1> <m2> [...] # Compare models side-by-side
|
||||
```
|
||||
|
||||
### Writing Style Benchmark
|
||||
|
||||
```bash
|
||||
./manage.sh style-bench [args] # Run benchmark_style.py
|
||||
./manage.sh style-list # List available Ollama models
|
||||
./manage.sh style-run [args] # Run writing style benchmark
|
||||
./manage.sh style-last # Print most recent benchmark report
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
IMAP accounts
|
||||
→ fetch (targeted or wide)
|
||||
→ email_label_queue.jsonl
|
||||
|
||||
email_label_queue.jsonl
|
||||
→ label tool drag-to-bucket UI
|
||||
→ email_score.jsonl (ground truth)
|
||||
|
||||
email_score.jsonl
|
||||
→ benchmark harness
|
||||
→ model rankings
|
||||
|
||||
best model
|
||||
→ fine-tune harness
|
||||
→ Peregrine classifier adapter
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -38,69 +136,42 @@ Scores HuggingFace zero-shot classification models against the labeled dataset.
|
|||
|
||||
## Stack
|
||||
|
||||
| Layer | Tech |
|
||||
|-------|------|
|
||||
| Label UI | Streamlit (port 8503, auto-increments on collision) |
|
||||
| Layer | Technology |
|
||||
|-------|-----------|
|
||||
| Label UI | Vue 3 SPA (Vite) |
|
||||
| API | FastAPI + uvicorn (port 8503) |
|
||||
| Benchmark | Python + HuggingFace Transformers |
|
||||
| Email fetch | IMAP (multi-account, targeted date/sender/subject filter) |
|
||||
| Data | JSONL (`data/email_label_queue.jsonl`, `data/email_score.jsonl`) |
|
||||
| Config | `config/label_tool.yaml` (gitignored — see `.example`) |
|
||||
|
||||
Conda environments:
|
||||
- `job-seeker` — label tool UI
|
||||
- `job-seeker-classifiers` — benchmark harness (separate env for heavy deps)
|
||||
| Runtime | SQLite |
|
||||
| Config | `config/label_tool.yaml` (gitignored — `.example` committed) |
|
||||
|
||||
---
|
||||
|
||||
## Running
|
||||
## Logo
|
||||
|
||||
```bash
|
||||
./manage.sh start # start label tool UI (port collision-safe from 8503)
|
||||
./manage.sh stop # stop
|
||||
./manage.sh restart # restart
|
||||
./manage.sh status # show running state and port
|
||||
./manage.sh logs # tail label tool log
|
||||
./manage.sh open # open in browser
|
||||
```
|
||||
|
||||
Benchmark:
|
||||
```bash
|
||||
./manage.sh benchmark --list-models # list available zero-shot models
|
||||
./manage.sh score # score models against labeled JSONL
|
||||
./manage.sh score --include-slow # include large/slow models
|
||||
./manage.sh compare --limit 30 # visual comparison on live IMAP emails
|
||||
```
|
||||
|
||||
Dev:
|
||||
```bash
|
||||
./manage.sh test # run pytest suite
|
||||
```
|
||||
The Avocet logo (`avocet_v1_poly.svg`) lives in the shared graphics repo. Copy it to `docs/avocet-logo.svg` to render correctly in this README.
|
||||
|
||||
---
|
||||
|
||||
## Data flow
|
||||
## About
|
||||
|
||||
```
|
||||
IMAP accounts → fetch (targeted or wide) → email_label_queue.jsonl
|
||||
→ label tool card UI → email_score.jsonl
|
||||
→ benchmark harness → model rankings
|
||||
→ best model → Peregrine classifier adapter
|
||||
```
|
||||
Avocet is internal CircuitForge infrastructure, open source as a reference implementation. It is not a user-facing product. The primary consumer is [Peregrine](https://git.opensourcesolarpunk.com/Circuit-Forge/peregrine), CircuitForge's job-search pipeline tool.
|
||||
|
||||
Targeted fetch: date range + sender/subject filter for pulling historical emails on specific senders or topics without flooding the queue.
|
||||
Docs: [docs.circuitforge.tech/avocet](https://docs.circuitforge.tech/avocet)
|
||||
|
||||
Discard: removes an email from the queue without writing to the score file — for emails that don't belong in the training set.
|
||||
## Forgejo-primary
|
||||
|
||||
---
|
||||
|
||||
## Classifier adapters
|
||||
|
||||
`app/classifier_adapters.py` provides a common interface for swapping classifier backends. Falls back to the label name when no `LABEL_DESCRIPTIONS` entry is configured for a label (RerankerAdapter).
|
||||
Avocet is developed and maintained on Forgejo at [git.opensourcesolarpunk.com/Circuit-Forge/avocet](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet). GitHub and Codeberg are read-only mirrors.
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
BSL 1.1 — internal tool, not user-facing.
|
||||
[Business Source License 1.1](LICENSE) — classifier training is an AI feature under the CircuitForge licensing model.
|
||||
|
||||
© 2026 Circuit Forge LLC
|
||||
Free for personal non-commercial self-hosting. Commercial use or SaaS re-hosting requires a paid license. Converts to MIT after 4 years.
|
||||
|
||||
Humans own design, architecture, code review, testing, and verification. LLMs are part of our development workflow. [Our positions on LLM use →](https://circuitforge.tech/positions)
|
||||
|
||||
© 2026 Circuit Forge LLC — Privacy · Safety · Accessibility
|
||||
|
|
|
|||
642
app/api.py
642
app/api.py
|
|
@ -1,623 +1,95 @@
|
|||
"""Avocet — FastAPI REST layer.
|
||||
"""Avocet -- FastAPI app factory.
|
||||
|
||||
JSONL read/write helpers and FastAPI app instance.
|
||||
Endpoints and static file serving are added in subsequent tasks.
|
||||
Mounts all domain routers and serves the Vue SPA.
|
||||
All business logic lives in the domain modules below.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import subprocess as _subprocess
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
|
||||
# Process registry for running jobs — used by cancel endpoints.
|
||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
||||
_running_procs: dict = {}
|
||||
_cancelled_jobs: set = set()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
"""Override data directory — used by tests."""
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return the index of the GPU with the most free VRAM as a string.
|
||||
|
||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
||||
would OOM the GPU with less headroom.
|
||||
"""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
"""Override models directory — used by tests."""
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Override config directory — used by tests."""
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def reset_last_action() -> None:
|
||||
"""Reset undo state — used by tests."""
|
||||
global _last_action
|
||||
_last_action = None
|
||||
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
|
||||
def _score_file() -> Path:
|
||||
return _DATA_DIR / "email_score.jsonl"
|
||||
|
||||
|
||||
def _discarded_file() -> Path:
|
||||
return _DATA_DIR / "discarded.jsonl"
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
if not path.exists():
|
||||
return []
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
return [json.loads(l) for l in lines if l.strip()]
|
||||
|
||||
|
||||
def _write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
text = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
|
||||
path.write_text(text + "\n" if records else "", encoding="utf-8")
|
||||
|
||||
|
||||
def _append_jsonl(path: Path, record: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
"""Stable content-hash ID — matches label_tool.py _entry_key dedup logic."""
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
def _normalize(item: dict) -> dict:
|
||||
"""Normalize JSONL item to the Vue frontend schema.
|
||||
|
||||
label_tool.py stores: subject, body, from_addr, date, account (no id).
|
||||
The Vue app expects: id, subject, body, from, date, source.
|
||||
Both old (from_addr/account) and new (from/source) field names are handled.
|
||||
"""
|
||||
return {
|
||||
"id": item.get("id") or _item_id(item),
|
||||
"subject": item.get("subject", ""),
|
||||
"body": item.get("body", ""),
|
||||
"from": item.get("from") or item.get("from_addr", ""),
|
||||
"date": item.get("date", ""),
|
||||
"source": item.get("source") or item.get("account", ""),
|
||||
}
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI(title="Avocet API")
|
||||
|
||||
from app.sft import router as sft_router
|
||||
app.include_router(sft_router, prefix="/api/sft")
|
||||
# -- Domain routers --------------------------------------------------------
|
||||
|
||||
from app.models import router as models_router
|
||||
import app.models as _models_module
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
from app.data.label import router as label_router
|
||||
app.include_router(label_router, prefix="/api")
|
||||
|
||||
from app.cforch import router as cforch_router
|
||||
app.include_router(cforch_router, prefix="/api/cforch")
|
||||
from app.data.fetch import router as fetch_router
|
||||
app.include_router(fetch_router, prefix="/api")
|
||||
|
||||
from app.imitate import router as imitate_router
|
||||
from app.data.corrections import router as corrections_router
|
||||
app.include_router(corrections_router, prefix="/api/corrections")
|
||||
|
||||
# Backward-compat alias -- remove when Vue SPA is updated to /api/corrections/*
|
||||
app.include_router(corrections_router, prefix="/api/sft")
|
||||
|
||||
from app.data.imitate import router as imitate_router
|
||||
app.include_router(imitate_router, prefix="/api/imitate")
|
||||
|
||||
from app.style import router as style_router
|
||||
app.include_router(style_router, prefix="/api/style")
|
||||
from app.eval.cforch import router as eval_router
|
||||
app.include_router(eval_router, prefix="/api")
|
||||
|
||||
from app.train.train import router as train_router
|
||||
app.include_router(train_router, prefix="/api/train")
|
||||
|
||||
from app.plans_bench import router as plans_bench_router
|
||||
app.include_router(plans_bench_router, prefix="/api/plans-bench")
|
||||
|
||||
# In-memory last-action store (single user, local tool — in-memory is fine)
|
||||
_last_action: dict | None = None
|
||||
|
||||
# -- Backward-compat shims (ClassifierTab still uses old /api/finetune/* paths)
|
||||
# Remove once ClassifierTab fine-tune section is migrated to TrainJobsView.
|
||||
|
||||
@app.get("/api/queue")
|
||||
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
|
||||
items = _read_jsonl(_queue_file())
|
||||
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
|
||||
|
||||
|
||||
class LabelRequest(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
|
||||
@app.post("/api/label")
|
||||
def post_label(req: LabelRequest):
|
||||
global _last_action
|
||||
items = _read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": req.label,
|
||||
"labeled_at": datetime.now(timezone.utc).isoformat()}
|
||||
_append_jsonl(_score_file(), record)
|
||||
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "label", "item": match, "label": req.label}
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@app.post("/api/skip")
|
||||
def post_skip(req: SkipRequest):
|
||||
global _last_action
|
||||
items = _read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
|
||||
_write_jsonl(_queue_file(), reordered)
|
||||
_last_action = {"type": "skip", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
class DiscardRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@app.post("/api/discard")
|
||||
def post_discard(req: DiscardRequest):
|
||||
global _last_action
|
||||
items = _read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": "__discarded__",
|
||||
"discarded_at": datetime.now(timezone.utc).isoformat()}
|
||||
_append_jsonl(_discarded_file(), record)
|
||||
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "discard", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.delete("/api/label/undo")
|
||||
def delete_undo():
|
||||
global _last_action
|
||||
if not _last_action:
|
||||
raise HTTPException(404, "No action to undo")
|
||||
action = _last_action
|
||||
item = action["item"] # always the original clean queue item
|
||||
|
||||
# Perform file operations FIRST — only clear _last_action on success
|
||||
if action["type"] == "label":
|
||||
records = _read_jsonl(_score_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Score file is empty — cannot undo label")
|
||||
_write_jsonl(_score_file(), records[:-1])
|
||||
items = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "discard":
|
||||
records = _read_jsonl(_discarded_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Discarded file is empty — cannot undo discard")
|
||||
_write_jsonl(_discarded_file(), records[:-1])
|
||||
items = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "skip":
|
||||
items = _read_jsonl(_queue_file())
|
||||
item_id = _normalize(item)["id"]
|
||||
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
|
||||
_write_jsonl(_queue_file(), items)
|
||||
|
||||
# Clear AFTER all file operations succeed
|
||||
_last_action = None
|
||||
return {"undone": {"type": action["type"], "item": _normalize(item)}}
|
||||
|
||||
|
||||
# Label metadata — 10 labels matching label_tool.py
|
||||
_LABEL_META = [
|
||||
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
|
||||
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
|
||||
{"name": "rejected", "emoji": "\u274c", "color": "#F44336", "key": "3"},
|
||||
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
|
||||
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
|
||||
{"name": "neutral", "emoji": "\u2b1c", "color": "#607D8B", "key": "6"},
|
||||
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
|
||||
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
|
||||
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
|
||||
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
|
||||
]
|
||||
|
||||
|
||||
@app.get("/api/config/labels")
|
||||
def get_labels():
|
||||
return _LABEL_META
|
||||
|
||||
|
||||
@app.get("/api/config")
|
||||
def get_config():
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"accounts": [], "max_per_account": 500}
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
|
||||
|
||||
|
||||
class ConfigPayload(BaseModel):
|
||||
accounts: list[dict]
|
||||
max_per_account: int = 500
|
||||
|
||||
|
||||
@app.post("/api/config")
|
||||
def post_config(payload: ConfigPayload):
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/api/stats")
|
||||
def get_stats():
|
||||
records = _read_jsonl(_score_file())
|
||||
counts: dict[str, int] = {}
|
||||
for r in records:
|
||||
lbl = r.get("label", "")
|
||||
if lbl:
|
||||
counts[lbl] = counts.get(lbl, 0) + 1
|
||||
benchmark_results: dict = {}
|
||||
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"total": len(records),
|
||||
"counts": counts,
|
||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||
"benchmark_results": benchmark_results,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/stats/download")
|
||||
def download_stats():
|
||||
from fastapi.responses import FileResponse
|
||||
if not _score_file().exists():
|
||||
raise HTTPException(404, "No score file")
|
||||
return FileResponse(
|
||||
str(_score_file()),
|
||||
filename="email_score.jsonl",
|
||||
media_type="application/jsonlines",
|
||||
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
|
||||
)
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
|
||||
account: dict
|
||||
|
||||
|
||||
@app.post("/api/accounts/test")
|
||||
def test_account(req: AccountTestRequest):
|
||||
from app.imap_fetch import test_connection
|
||||
ok, message, count = test_connection(req.account)
|
||||
return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/benchmark/models")
|
||||
def get_benchmark_models() -> dict:
|
||||
"""Return installed models grouped by adapter_type category."""
|
||||
models_dir: Path = _models_module._MODELS_DIR
|
||||
categories: dict[str, list[dict]] = {
|
||||
"ZeroShotAdapter": [],
|
||||
"RerankerAdapter": [],
|
||||
"GenerationAdapter": [],
|
||||
"Unknown": [],
|
||||
}
|
||||
if models_dir.exists():
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "model_info.json"
|
||||
adapter_type = "Unknown"
|
||||
repo_id: str | None = None
|
||||
if info_path.exists():
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown"
|
||||
repo_id = info.get("repo_id")
|
||||
except Exception:
|
||||
pass
|
||||
bucket = adapter_type if adapter_type in categories else "Unknown"
|
||||
entry: dict = {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type}
|
||||
categories[bucket].append(entry)
|
||||
return {"categories": categories}
|
||||
|
||||
|
||||
@app.get("/api/benchmark/results")
|
||||
def get_benchmark_results():
|
||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
||||
path = _DATA_DIR / "benchmark_results.json"
|
||||
if not path.exists():
|
||||
return {"models": {}, "sample_count": 0, "timestamp": None}
|
||||
return json.loads(path.read_text())
|
||||
|
||||
|
||||
@app.get("/api/benchmark/run")
|
||||
def run_benchmark(include_slow: bool = False, model_names: str = ""):
|
||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
||||
cmd = [python_bin, script, "--score", "--save"]
|
||||
if include_slow:
|
||||
cmd.append("--include-slow")
|
||||
if model_names:
|
||||
names = [n.strip() for n in model_names.split(",") if n.strip()]
|
||||
if names:
|
||||
cmd.extend(["--models"] + names)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_running_procs["benchmark"] = proc
|
||||
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "benchmark" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("benchmark")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("benchmark", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Finetune endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/finetune/status")
|
||||
def get_finetune_status():
|
||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
results = []
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.responses import StreamingResponse as _StreamingResponse
|
||||
|
||||
@app.get("/api/finetune/run")
|
||||
def run_finetune_endpoint(
|
||||
model: str = "deberta-small",
|
||||
epochs: int = 5,
|
||||
score: list[str] = Query(default=[]),
|
||||
):
|
||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
||||
data_root = _DATA_DIR.resolve()
|
||||
for score_file in score:
|
||||
resolved = (_DATA_DIR / score_file).resolve()
|
||||
if not str(resolved).startswith(str(data_root)):
|
||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
|
||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
||||
# single device prevents DataParallel from replicating the model across all
|
||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
|
||||
def generate():
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
env=proc_env,
|
||||
)
|
||||
_running_procs["finetune"] = proc
|
||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "finetune" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("finetune")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("finetune", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/benchmark/cancel")
|
||||
def cancel_benchmark():
|
||||
"""Kill the running benchmark subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("benchmark")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No benchmark is running")
|
||||
_cancelled_jobs.add("benchmark")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
def finetune_run_compat(model: str = Query(...), epochs: int = Query(5)) -> _StreamingResponse:
|
||||
"""Shim: create a classifier train job and immediately stream it."""
|
||||
from app.train.train import create_job, run_job, CreateJobRequest
|
||||
job = create_job(CreateJobRequest(type="classifier", model_key=model, config_json={"epochs": epochs}))
|
||||
return run_job(job["id"])
|
||||
|
||||
@app.post("/api/finetune/cancel")
|
||||
def cancel_finetune():
|
||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("finetune")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No finetune is running")
|
||||
_cancelled_jobs.add("finetune")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
def finetune_cancel_compat() -> dict:
|
||||
"""Shim: cancel the most recent running classifier job."""
|
||||
from app.train.train import _db, _init_db, cancel_job
|
||||
from fastapi import HTTPException
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT id FROM jobs WHERE type='classifier' AND status='running' ORDER BY started_at DESC LIMIT 1"
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return {"status": "nothing_running"}
|
||||
return cancel_job(row["id"])
|
||||
|
||||
from app.data.log_corpus import router as log_corpus_router
|
||||
app.include_router(log_corpus_router, prefix="/api/corpus")
|
||||
|
||||
@app.get("/api/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
days_back: int = Query(default=90, ge=1, le=365),
|
||||
limit: int = Query(default=150, ge=1, le=1000),
|
||||
mode: str = Query(default="wide"),
|
||||
):
|
||||
from app.imap_fetch import fetch_account_stream
|
||||
from app.data.recipe_scan import router as recipe_scan_router
|
||||
app.include_router(recipe_scan_router, prefix="/api/recipe-scan")
|
||||
|
||||
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
|
||||
config = get_config() # reuse existing endpoint logic
|
||||
selected = [a for a in config["accounts"] if a.get("name") in selected_names]
|
||||
from app.dashboard import router as dashboard_router
|
||||
app.include_router(dashboard_router, prefix="/api")
|
||||
|
||||
def generate():
|
||||
known_keys = {_item_id(x) for x in _read_jsonl(_queue_file())}
|
||||
total_added = 0
|
||||
from app.models import router as models_router
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
|
||||
for acc in selected:
|
||||
try:
|
||||
batch_emails: list[dict] = []
|
||||
for event in fetch_account_stream(acc, days_back, limit, known_keys):
|
||||
if event["type"] == "done":
|
||||
batch_emails = event.pop("emails", [])
|
||||
total_added += event["added"]
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
# Write new emails to queue after each account
|
||||
if batch_emails:
|
||||
existing = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), existing + batch_emails)
|
||||
except Exception as exc:
|
||||
error_event = {"type": "error", "account": acc.get("name", "?"),
|
||||
"message": str(exc)}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
from app.nodes import router as nodes_router
|
||||
app.include_router(nodes_router, prefix="/api/nodes-mgmt")
|
||||
|
||||
queue_size = len(_read_jsonl(_queue_file()))
|
||||
complete = {"type": "complete", "total_added": total_added, "queue_size": queue_size}
|
||||
yield f"data: {json.dumps(complete)}\n\n"
|
||||
# -- Static SPA -- MUST be last (catches all unmatched paths) ---------------
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
# Static SPA — MUST be last (catches all unmatched paths)
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DIST = _ROOT / "web" / "dist"
|
||||
if _DIST.exists():
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Serve index.html with no-cache so browsers always fetch fresh HTML after rebuilds.
|
||||
# Hashed assets (/assets/index-abc123.js) can be cached forever — they change names
|
||||
# when content changes (standard Vite cache-busting strategy).
|
||||
_NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"}
|
||||
|
||||
@app.get("/")
|
||||
|
|
|
|||
348
app/cforch.py
348
app/cforch.py
|
|
@ -16,13 +16,18 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import select as _select
|
||||
import subprocess as _subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import urllib.parse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -75,9 +80,31 @@ def _load_cforch_config() -> dict:
|
|||
"license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"),
|
||||
"ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"),
|
||||
"ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"),
|
||||
"judge_url": _coalesce(file_cfg.get("judge_url", ""), "CF_JUDGE_URL"),
|
||||
"hf_token": _coalesce(file_cfg.get("hf_token", ""), "HF_TOKEN"),
|
||||
}
|
||||
|
||||
|
||||
def _validate_service_url(url: str, param_name: str) -> str:
|
||||
"""Validate that a URL is a well-formed http/https URL with a hostname.
|
||||
|
||||
Guards against SSRF: only http/https is allowed; the URL must have a
|
||||
non-empty host. Does not enforce an allowlist — call sites are internal
|
||||
tooling, not a public API.
|
||||
"""
|
||||
if not url:
|
||||
return url
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
except Exception:
|
||||
raise HTTPException(400, f"{param_name}: not a valid URL")
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise HTTPException(400, f"{param_name}: URL must start with http:// or https://")
|
||||
if not parsed.hostname:
|
||||
raise HTTPException(400, f"{param_name}: URL has no hostname")
|
||||
return url
|
||||
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape codes from a string."""
|
||||
return re.sub(r'\x1b\[[0-9;]*m', '', text)
|
||||
|
|
@ -147,54 +174,151 @@ def get_tasks() -> dict:
|
|||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
# Services and roles surfaced in the benchmark model picker.
|
||||
# Covers all cf-orch service types that benchmark.py can route tasks to.
|
||||
_BENCH_SERVICES = frozenset({
|
||||
"cf-text", "vllm", # LLM text generation
|
||||
"cf-stt", # speech-to-text
|
||||
"cf-tts", # text-to-speech
|
||||
"cf-vision", # image classification / embedding
|
||||
"cf-voice", # audio context classification
|
||||
})
|
||||
_BENCH_ROLES = frozenset({
|
||||
"generator", "vlm", # LLM roles
|
||||
"stt", "alm", # speech recognition
|
||||
"tts", # speech synthesis
|
||||
"vision", "embedding", # image understanding
|
||||
"classifier", # audio classification (cf-voice)
|
||||
})
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return model list from bench_models.yaml."""
|
||||
"""Return model list from bench_models.yaml merged with locally installed models.
|
||||
|
||||
bench_models.yaml entries are listed first and take precedence; any installed
|
||||
model whose repo_id is already present in the YAML is skipped. Only models
|
||||
whose service is in _BENCH_SERVICES (cf-text, vllm, cf-stt, cf-tts, cf-vision,
|
||||
cf-voice) are surfaced from the installed registry.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
models_path = cfg.get("bench_models", "")
|
||||
if not models_path:
|
||||
return {"models": []}
|
||||
|
||||
models: list[dict] = []
|
||||
bench_ids: set[str] = set()
|
||||
|
||||
if models_path:
|
||||
p = Path(models_path)
|
||||
if not p.exists():
|
||||
return {"models": []}
|
||||
|
||||
if p.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
|
||||
return {"models": []}
|
||||
|
||||
models_raw = raw.get("models", []) or []
|
||||
models: list[dict] = []
|
||||
for m in models_raw:
|
||||
raw = {}
|
||||
for m in (raw.get("models", []) or []):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
model_id = m.get("id", "")
|
||||
models.append({
|
||||
"name": m.get("name", ""),
|
||||
"id": m.get("id", ""),
|
||||
"id": model_id,
|
||||
"service": m.get("service", "ollama"),
|
||||
"tags": m.get("tags", []) or [],
|
||||
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
|
||||
})
|
||||
if model_id:
|
||||
bench_ids.add(model_id)
|
||||
|
||||
# Merge installed generator models not already in bench_models.yaml.
|
||||
try:
|
||||
from app.models import list_installed # local import avoids circular dependency at module load
|
||||
for installed in list_installed():
|
||||
model_id: str = installed.get("model_id") or ""
|
||||
service: str = installed.get("service") or ""
|
||||
role: str = installed.get("role") or ""
|
||||
if not model_id:
|
||||
continue
|
||||
if service not in _BENCH_SERVICES or role not in _BENCH_ROLES:
|
||||
continue
|
||||
if model_id in bench_ids:
|
||||
continue
|
||||
display_name = model_id.split("/", 1)[-1] if "/" in model_id else model_id
|
||||
models.append({
|
||||
"name": display_name,
|
||||
"id": model_id,
|
||||
"service": service,
|
||||
"tags": [role],
|
||||
"vram_estimate_mb": installed.get("vram_mb") or 0,
|
||||
})
|
||||
bench_ids.add(model_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into model list: %s", exc)
|
||||
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
|
||||
def get_nodes() -> dict:
|
||||
"""Proxy the coordinator's /api/nodes list, returning node_id + online status.
|
||||
|
||||
Online is inferred from last_heartbeat: any node with a recent heartbeat is online.
|
||||
Returns an empty list if the coordinator is unreachable.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "").rstrip("/")
|
||||
if not coordinator_url:
|
||||
return {"nodes": []}
|
||||
try:
|
||||
import httpx as _httpx
|
||||
resp = _httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
raw_nodes = resp.json().get("nodes", [])
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": n.get("node_id", ""),
|
||||
"online": n.get("last_heartbeat") is not None,
|
||||
"gpus": [
|
||||
{
|
||||
"gpu_id": g.get("gpu_id"),
|
||||
"name": g.get("name", ""),
|
||||
"vram_total_mb": g.get("vram_total_mb", 0),
|
||||
"vram_free_mb": g.get("vram_free_mb", 0),
|
||||
}
|
||||
for g in n.get("gpus", [])
|
||||
],
|
||||
}
|
||||
for n in raw_nodes
|
||||
]
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch nodes from coordinator: %s", exc)
|
||||
return {"nodes": []}
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_benchmark(
|
||||
task_ids: str = "",
|
||||
model_ids: str = "",
|
||||
model_tags: str = "",
|
||||
coordinator_url: str = "",
|
||||
ollama_url: str = "",
|
||||
judge_url: str = "",
|
||||
judge_backend: str = "chat",
|
||||
workers: int = 1,
|
||||
node_ids: str = "",
|
||||
) -> StreamingResponse:
|
||||
"""Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
# Check if the process is actually still alive; reset stale flag if not.
|
||||
if _BENCH_RUNNING:
|
||||
if _bench_proc is not None and _bench_proc.poll() is None:
|
||||
raise HTTPException(409, "A benchmark is already running")
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
|
||||
cfg = _load_cforch_config()
|
||||
bench_script = cfg.get("bench_script", "")
|
||||
|
|
@ -205,6 +329,13 @@ def run_benchmark(
|
|||
cfg_coordinator = cfg.get("coordinator_url", "")
|
||||
cfg_ollama = cfg.get("ollama_url", "")
|
||||
cfg_license_key = cfg.get("license_key", "")
|
||||
cfg_judge_url = cfg.get("judge_url", "")
|
||||
|
||||
# Validate URL params before spawning the subprocess.
|
||||
# _validate_service_url raises HTTPException on bad input (caught by FastAPI before streaming starts).
|
||||
_validate_service_url(coordinator_url, "coordinator_url")
|
||||
_validate_service_url(ollama_url, "ollama_url")
|
||||
_validate_service_url(judge_url, "judge_url")
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
|
@ -213,16 +344,68 @@ def run_benchmark(
|
|||
yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n"
|
||||
return
|
||||
|
||||
# Build effective models file: bench_models.yaml + any installed models
|
||||
# whose IDs were selected but are absent from the YAML (e.g. downloaded
|
||||
# via the Models view). Written to a temp file so benchmark.py sees one
|
||||
# unified list; cleaned up in the finally block.
|
||||
effective_models_file = bench_models
|
||||
_tmp_models_path: str | None = None
|
||||
|
||||
if model_ids and bench_models and Path(bench_models).exists():
|
||||
requested_ids = set(model_ids.split(","))
|
||||
try:
|
||||
raw_bench = yaml.safe_load(Path(bench_models).read_text(encoding="utf-8")) or {}
|
||||
bench_entries: list[dict] = raw_bench.get("models", []) or []
|
||||
bench_id_set = {m.get("id", "") for m in bench_entries if isinstance(m, dict)}
|
||||
missing_ids = requested_ids - bench_id_set
|
||||
if missing_ids:
|
||||
from app.models import list_installed
|
||||
installed_map = {
|
||||
m["model_id"]: m
|
||||
for m in list_installed()
|
||||
if m.get("model_id") and m.get("service") in _BENCH_SERVICES
|
||||
}
|
||||
extra: list[dict] = []
|
||||
for mid in missing_ids:
|
||||
if mid in installed_map:
|
||||
inst = installed_map[mid]
|
||||
entry: dict[str, Any] = {
|
||||
"id": mid,
|
||||
"name": mid.split("/", 1)[-1] if "/" in mid else mid,
|
||||
"service": inst.get("service", "cf-text"),
|
||||
"vram_estimate_mb": inst.get("vram_mb") or 0,
|
||||
"tags": [inst.get("role", "generator")],
|
||||
"temperature": 0.0,
|
||||
}
|
||||
local_path = inst.get("path", "") or inst.get("local_path", "")
|
||||
if local_path:
|
||||
entry["model_path"] = local_path
|
||||
extra.append(entry)
|
||||
if extra:
|
||||
merged = {"models": bench_entries + extra}
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False,
|
||||
prefix="avocet_bench_models_",
|
||||
)
|
||||
yaml.dump(merged, tf)
|
||||
tf.close()
|
||||
_tmp_models_path = tf.name
|
||||
effective_models_file = _tmp_models_path
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into temp bench file: %s", exc)
|
||||
|
||||
cmd = [
|
||||
python_bin,
|
||||
bench_script,
|
||||
"--tasks", bench_tasks,
|
||||
"--models", bench_models,
|
||||
"--models", effective_models_file,
|
||||
"--output", results_dir,
|
||||
]
|
||||
|
||||
if task_ids:
|
||||
cmd.extend(["--filter-tasks"] + task_ids.split(","))
|
||||
if model_ids:
|
||||
cmd.extend(["--filter-models"] + model_ids.split(","))
|
||||
if model_tags:
|
||||
cmd.extend(["--filter-tags"] + model_tags.split(","))
|
||||
|
||||
|
|
@ -233,6 +416,15 @@ def run_benchmark(
|
|||
cmd.extend(["--coordinator", effective_coordinator])
|
||||
if effective_ollama:
|
||||
cmd.extend(["--ollama-url", effective_ollama])
|
||||
effective_judge = judge_url if judge_url else cfg_judge_url
|
||||
if effective_judge:
|
||||
cmd.extend(["--judge-url", effective_judge])
|
||||
if judge_backend and judge_backend != "chat":
|
||||
cmd.extend(["--judge-backend", judge_backend])
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
if node_ids:
|
||||
cmd.extend(["--nodes"] + node_ids.split(","))
|
||||
|
||||
# Pass license key as env var so subprocess can authenticate with cf-orch
|
||||
proc_env = {**os.environ}
|
||||
|
|
@ -250,8 +442,23 @@ def run_benchmark(
|
|||
env=proc_env,
|
||||
)
|
||||
_bench_proc = proc
|
||||
_IDLE_TIMEOUT_S = 120 # kill if no output for 2 minutes (node crash)
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
while True:
|
||||
ready = _select.select([proc.stdout], [], [], _IDLE_TIMEOUT_S)
|
||||
if not ready[0]:
|
||||
# No output for IDLE_TIMEOUT_S — node likely crashed
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
msg = f"Benchmark timed out — no output for {_IDLE_TIMEOUT_S}s (cluster node may have crashed)"
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': msg})}\n\n"
|
||||
break
|
||||
line = proc.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
line = _strip_ansi(line.rstrip())
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
|
|
@ -273,6 +480,11 @@ def run_benchmark(
|
|||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
if _tmp_models_path:
|
||||
try:
|
||||
os.unlink(_tmp_models_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
|
|
@ -295,6 +507,7 @@ def get_cforch_config() -> dict:
|
|||
"coordinator_url": cfg.get("coordinator_url", ""),
|
||||
"ollama_url": cfg.get("ollama_url", ""),
|
||||
"ollama_model": cfg.get("ollama_model", ""),
|
||||
"judge_url": cfg.get("judge_url", ""),
|
||||
"license_key_set": bool(cfg.get("license_key", "")),
|
||||
"source": "env" if not _config_file().exists() else "yaml+env",
|
||||
}
|
||||
|
|
@ -335,3 +548,106 @@ def cancel_benchmark() -> dict:
|
|||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
# ── Coordinator proxy helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _coordinator_url() -> str:
|
||||
"""Return coordinator base URL from config, or raise 503 if not configured."""
|
||||
url = _load_cforch_config().get("coordinator_url", "").rstrip("/")
|
||||
if not url:
|
||||
raise HTTPException(503, "cf-orch coordinator_url not configured")
|
||||
return url
|
||||
|
||||
|
||||
def _coordinator_get(path: str) -> Any:
|
||||
"""GET from coordinator, return parsed JSON body. Raises HTTPException on error."""
|
||||
import httpx as _httpx
|
||||
try:
|
||||
resp = _httpx.get(f"{_coordinator_url()}{path}", timeout=10.0)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _coordinator_post(path: str, body: dict) -> Any:
|
||||
import httpx as _httpx
|
||||
try:
|
||||
async with _httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(f"{_coordinator_url()}{path}", json=body)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _coordinator_delete(path: str) -> Any:
|
||||
import httpx as _httpx
|
||||
try:
|
||||
async with _httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.delete(f"{_coordinator_url()}{path}")
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
# ── GET /assignments/deployment-status ───────────────────────────────────────
|
||||
|
||||
@router.get("/assignments/deployment-status")
|
||||
def get_deployment_status() -> Any:
|
||||
return _coordinator_get("/api/assignments/deployment-status")
|
||||
|
||||
|
||||
# ── /assignments ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/assignments")
|
||||
def list_assignments() -> Any:
|
||||
return _coordinator_get("/api/assignments")
|
||||
|
||||
|
||||
class AssignmentBody(BaseModel):
|
||||
product: str
|
||||
task: str
|
||||
model_id: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
@router.post("/assignments")
|
||||
async def upsert_assignment(body: AssignmentBody) -> Any:
|
||||
return await _coordinator_post("/api/assignments", body.model_dump())
|
||||
|
||||
|
||||
@router.delete("/assignments/{product}/{task}")
|
||||
async def delete_assignment(product: str, task: str) -> Any:
|
||||
return await _coordinator_delete(f"/api/assignments/{urllib.parse.quote(product, safe='')}/{urllib.parse.quote(task, safe='')}")
|
||||
|
||||
|
||||
# ── /model-registry ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/model-registry")
|
||||
def list_model_registry() -> Any:
|
||||
return _coordinator_get("/api/model-registry")
|
||||
|
||||
|
||||
class ModelRegistryBody(BaseModel):
|
||||
model_id: str
|
||||
service_type: str
|
||||
vram_mb: int
|
||||
description: str = ""
|
||||
hf_repo: str = ""
|
||||
alias: str = ""
|
||||
|
||||
|
||||
@router.post("/model-registry")
|
||||
async def upsert_model_registry(body: ModelRegistryBody) -> Any:
|
||||
return await _coordinator_post("/api/model-registry", body.model_dump())
|
||||
|
||||
|
||||
@router.delete("/model-registry/{model_id:path}")
|
||||
async def delete_model_registry(model_id: str) -> Any:
|
||||
return await _coordinator_delete(f"/api/model-registry/{urllib.parse.quote(model_id, safe='')}")
|
||||
|
|
|
|||
34
app/cloud_session.py
Normal file
34
app/cloud_session.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""
|
||||
Avocet cloud session — thin wrapper around cf_core.cloud_session.
|
||||
|
||||
Usage in FastAPI routes:
|
||||
|
||||
from app.cloud_session import get_session, require_tier, CloudUser
|
||||
from fastapi import Depends
|
||||
|
||||
@router.get("/api/imitate")
|
||||
def imitate(session: CloudUser = Depends(get_session)):
|
||||
# session.user_id — Directus UUID (cloud) or "local" (self-hosted)
|
||||
# session.tier — free | paid | premium | ultra | local
|
||||
# session.has_byok — True if user has a configured LLM backend
|
||||
...
|
||||
|
||||
@router.post("/api/custom-models")
|
||||
def list_custom_models(session: CloudUser = Depends(require_tier("premium"))):
|
||||
...
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.cloud_session import CloudSessionFactory, CloudUser, detect_byok
|
||||
|
||||
__all__ = ["CloudUser", "get_session", "require_tier"]
|
||||
|
||||
_factory = CloudSessionFactory(
|
||||
product="avocet",
|
||||
byok_detector=detect_byok,
|
||||
)
|
||||
|
||||
get_session = _factory.dependency()
|
||||
require_tier = _factory.require_tier
|
||||
282
app/dashboard.py
Normal file
282
app/dashboard.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
"""Avocet -- dashboard aggregate API.
|
||||
|
||||
GET /api/dashboard returns the current flywheel state:
|
||||
labeled_since_last_eval -- items labeled after the most recent bench run
|
||||
last_eval_timestamp -- ISO timestamp of newest bench_results summary
|
||||
last_eval_best_score -- best macro_f1 from that summary
|
||||
active_jobs -- jobs with status queued or running
|
||||
corrections_pending -- sft_candidates with status=needs_review
|
||||
corrections_export_ready -- approved sft candidates with non-blank correction
|
||||
recent_bench_runs -- most-recent timestamp + score per bench type
|
||||
signals -- computed booleans for UI nudge indicators
|
||||
|
||||
Thresholds in label_tool.yaml pipeline: section:
|
||||
pipeline:
|
||||
data_eval_threshold: 50 # labeled items since last bench to trigger nudge
|
||||
eval_train_threshold: 0.05 # improvement delta needed before retraining (future)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DEFAULT_DATA_EVAL_THRESHOLD = 50
|
||||
_DEFAULT_EVAL_TRAIN_THRESHOLD = 0.05
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _load_thresholds() -> tuple[int, float]:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
pipeline = raw.get("pipeline", {}) or {}
|
||||
return (
|
||||
int(pipeline.get("data_eval_threshold", _DEFAULT_DATA_EVAL_THRESHOLD)),
|
||||
float(pipeline.get("eval_train_threshold", _DEFAULT_EVAL_TRAIN_THRESHOLD)),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read pipeline thresholds: %s", exc)
|
||||
return _DEFAULT_DATA_EVAL_THRESHOLD, _DEFAULT_EVAL_TRAIN_THRESHOLD
|
||||
|
||||
def _load_score_records() -> list[dict]:
|
||||
path = _DATA_DIR / "email_score.jsonl"
|
||||
if not path.exists():
|
||||
return []
|
||||
records = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return records
|
||||
|
||||
def _find_latest_classifier_bench(results_dir_override: str = "") -> tuple[str | None, float | None]:
|
||||
"""Return (iso_timestamp, best_macro_f1) from the newest bench_results summary.
|
||||
|
||||
Checks results_dir from cforch config if set, then falls back to
|
||||
_ROOT/bench_results/. Returns (None, None) if no results exist.
|
||||
"""
|
||||
candidates = []
|
||||
if results_dir_override:
|
||||
candidates.append(Path(results_dir_override))
|
||||
else:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
candidates.append(Path(rd))
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
|
||||
candidates.append(_ROOT / "bench_results")
|
||||
|
||||
for rdir in candidates:
|
||||
if not rdir.exists():
|
||||
continue
|
||||
subdirs = sorted([d for d in rdir.iterdir() if d.is_dir()], key=lambda d: d.name)
|
||||
for subdir in reversed(subdirs):
|
||||
summary = subdir / "summary.json"
|
||||
if summary.exists():
|
||||
try:
|
||||
data = json.loads(summary.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
continue # cforch LLM-bench summaries are lists; skip
|
||||
ts = data.get("timestamp") or subdir.name
|
||||
score = data.get("best_macro_f1") or data.get("macro_f1")
|
||||
return ts, (float(score) if isinstance(score, (int, float)) else None)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse summary.json at %s: %s", summary, exc)
|
||||
return None, None
|
||||
|
||||
# Keep old name as alias so existing callers in tests still work.
|
||||
_find_latest_eval = _find_latest_classifier_bench
|
||||
|
||||
|
||||
def _count_corrections() -> tuple[int, int]:
|
||||
"""Return (pending_count, export_ready_count)."""
|
||||
pending = 0
|
||||
export_ready = 0
|
||||
candidates_path = _DATA_DIR / "sft_candidates.jsonl"
|
||||
approved_path = _DATA_DIR / "sft_approved.jsonl"
|
||||
if candidates_path.exists():
|
||||
for line in candidates_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if r.get("status") == "needs_review":
|
||||
pending += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if approved_path.exists():
|
||||
for line in approved_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if (r.get("status") == "approved"
|
||||
and r.get("corrected_response")
|
||||
and str(r["corrected_response"]).strip()):
|
||||
export_ready += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return pending, export_ready
|
||||
|
||||
def _get_active_jobs() -> list[dict]:
|
||||
"""Query train SQLite DB for queued/running jobs. Returns [] if DB absent."""
|
||||
try:
|
||||
from app.train.train import _DB_PATH, _db, _init_db
|
||||
if not _DB_PATH.exists():
|
||||
return []
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
|
||||
).fetchall()
|
||||
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to query train jobs DB: %s", exc)
|
||||
return []
|
||||
|
||||
def _count_labeled_since(since_ts: str | None) -> int:
|
||||
records = _load_score_records()
|
||||
if since_ts is None:
|
||||
return len(records)
|
||||
return sum(1 for r in records if r.get("labeled_at", "") > since_ts)
|
||||
|
||||
|
||||
def _get_recent_bench_runs() -> dict:
|
||||
"""Return most-recent run summary for each bench type.
|
||||
|
||||
Each entry: {"timestamp": str|None, "metric": str|None, "score": float|None}
|
||||
"""
|
||||
runs: dict[str, dict] = {
|
||||
"classifier": {"timestamp": None, "metric": "macro_f1", "score": None},
|
||||
"llm": {"timestamp": None, "metric": None, "score": None},
|
||||
"style": {"timestamp": None, "metric": None, "score": None},
|
||||
"plans": {"timestamp": None, "metric": "avg_score", "score": None},
|
||||
}
|
||||
|
||||
# ── Classifier: bench_results/<run>/summary.json ──────────────────────
|
||||
clf_ts, clf_score = _find_latest_classifier_bench()
|
||||
if clf_ts:
|
||||
runs["classifier"]["timestamp"] = clf_ts
|
||||
runs["classifier"]["score"] = clf_score
|
||||
|
||||
# ── LLM bench + Style: benchmark_results/ ─────────────────────────────
|
||||
f = _config_file()
|
||||
bench_dir: Path | None = None
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
bench_dir = Path(rd)
|
||||
except Exception:
|
||||
pass
|
||||
if bench_dir is None:
|
||||
bench_dir = _ROOT / "benchmark_results"
|
||||
|
||||
if bench_dir.exists():
|
||||
llm_files = sorted(
|
||||
[p for p in bench_dir.glob("*.json") if not p.name.startswith("style_")],
|
||||
key=lambda p: p.stat().st_mtime, reverse=True,
|
||||
)
|
||||
if llm_files:
|
||||
try:
|
||||
data = json.loads(llm_files[0].read_text(encoding="utf-8"))
|
||||
runs["llm"]["timestamp"] = data.get("timestamp") or llm_files[0].stem
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
style_files = sorted(bench_dir.glob("style_*.json"), reverse=True)
|
||||
if style_files:
|
||||
try:
|
||||
data = json.loads(style_files[0].read_text(encoding="utf-8"))
|
||||
if isinstance(data, list) and data:
|
||||
runs["style"]["timestamp"] = data[0].get("timestamp") or style_files[0].stem
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Plans bench: data/plans_bench_results/plans_*.json ────────────────
|
||||
plans_dir = _DATA_DIR / "plans_bench_results"
|
||||
if plans_dir.exists():
|
||||
plans_files = sorted(plans_dir.glob("plans_*.json"), reverse=True)
|
||||
if plans_files:
|
||||
run_id = plans_files[0].stem
|
||||
try:
|
||||
d: dict = json.loads(plans_files[0].read_text(encoding="utf-8"))
|
||||
all_scores = [
|
||||
r["total_score"]
|
||||
for results in d.values()
|
||||
for r in results
|
||||
if isinstance(r, dict) and not r.get("error")
|
||||
]
|
||||
avg = round(sum(all_scores) / len(all_scores), 3) if all_scores else None
|
||||
try:
|
||||
date_part = run_id.removeprefix("plans_")
|
||||
date, time_part = date_part.split("_")
|
||||
ts_display = f"{date} {time_part[:2]}:{time_part[2:4]}"
|
||||
except Exception:
|
||||
ts_display = run_id
|
||||
runs["plans"]["timestamp"] = ts_display
|
||||
runs["plans"]["score"] = avg
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/dashboard")
|
||||
def get_dashboard() -> dict:
|
||||
data_threshold, _train_threshold = _load_thresholds()
|
||||
last_ts, last_score = _find_latest_classifier_bench()
|
||||
labeled_since = _count_labeled_since(last_ts)
|
||||
corrections_pending, corrections_export_ready = _count_corrections()
|
||||
active_jobs = _get_active_jobs()
|
||||
recent_bench = _get_recent_bench_runs()
|
||||
return {
|
||||
"labeled_since_last_eval": labeled_since,
|
||||
"last_eval_timestamp": last_ts,
|
||||
"last_eval_best_score": last_score,
|
||||
"active_jobs": active_jobs,
|
||||
"corrections_pending": corrections_pending,
|
||||
"corrections_export_ready": corrections_export_ready,
|
||||
"recent_bench_runs": recent_bench,
|
||||
"signals": {
|
||||
"data_to_eval": labeled_since >= data_threshold,
|
||||
"eval_to_train": False, # future: implement delta-F1 comparison
|
||||
"train_to_fleet": False, # future: implement fleet sync signal
|
||||
},
|
||||
}
|
||||
0
app/data/__init__.py
Normal file
0
app/data/__init__.py
Normal file
393
app/data/corrections.py
Normal file
393
app/data/corrections.py
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
|
||||
|
||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||
testability pattern as api.py -- override them via set_data_dir() and
|
||||
set_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams ---------------------------------------------------------
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Internal helpers ----------------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||
if d:
|
||||
return Path(d)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# -- GET /runs -----------------------------------------------------------------
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# -- POST /import --------------------------------------------------------------
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _DATA_DIR)
|
||||
|
||||
|
||||
# -- GET /queue ----------------------------------------------------------------
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
|
||||
# -- POST /submit --------------------------------------------------------------
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /undo ----------------------------------------------------------------
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- GET /export ---------------------------------------------------------------
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- GET /stats ----------------------------------------------------------------
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# -- GET /config ---------------------------------------------------------------
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /ingest --------------------------------------------------------------
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
source: str # e.g. "peregrine", "kiwi"
|
||||
task_type: str # e.g. "email_classification", "recipe_suggestion"
|
||||
prompt: str # the prompt that was sent to the LLM
|
||||
response: str # the LLM's original response
|
||||
correction: str # the human-corrected response
|
||||
label: str | None = None # optional label/category
|
||||
|
||||
|
||||
@router.post("/ingest")
|
||||
def post_ingest(
|
||||
req: IngestRequest,
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""Ingest a correction from a sibling CF product.
|
||||
|
||||
Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>
|
||||
|
||||
Creates a sft_candidates record with status='approved' (pre-approved by
|
||||
the calling product -- human review already happened upstream). Also writes
|
||||
to sft_approved.jsonl so it is immediately included in export counts.
|
||||
|
||||
Returns {"ok": True, "id": "<uuid>"}.
|
||||
"""
|
||||
expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
|
||||
if not expected_secret:
|
||||
raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")
|
||||
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(401, "Missing or malformed Authorization header")
|
||||
|
||||
token = authorization.removeprefix("Bearer ").strip()
|
||||
if token != expected_secret:
|
||||
raise HTTPException(403, "Invalid ingestion secret")
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
record = {
|
||||
"id": record_id,
|
||||
"source": req.source,
|
||||
"task_type": req.task_type,
|
||||
"status": "approved",
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": req.response,
|
||||
"corrected_response": req.correction,
|
||||
"label": req.label,
|
||||
"timestamp": now,
|
||||
"benchmark_run_id": None,
|
||||
}
|
||||
append_jsonl(_candidates_file(), record)
|
||||
append_jsonl(_approved_file(), record)
|
||||
return {"ok": True, "id": record_id}
|
||||
243
app/data/fetch.py
Normal file
243
app/data/fetch.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Avocet -- IMAP fetch utilities and fetch API routes.
|
||||
|
||||
All IMAP helper functions (from app/imap_fetch.py) plus the
|
||||
/api/accounts/test and /api/fetch/stream endpoints.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import extract_body, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
|
||||
def _get_config_accounts() -> list[dict]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return []
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return raw.get("accounts", [])
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
|
||||
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
|
||||
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
|
||||
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
|
||||
host = acc.get("host", "")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc.get("username", "")
|
||||
password = acc.get("password", "")
|
||||
folder = acc.get("folder", "INBOX")
|
||||
if not host or not username or not password:
|
||||
return False, "Host, username, and password are all required.", None
|
||||
try:
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
_, data = conn.select(folder, readonly=True)
|
||||
count_raw = data[0].decode() if data and data[0] else "0"
|
||||
count = int(count_raw) if count_raw.isdigit() else 0
|
||||
conn.logout()
|
||||
return True, f"Connected — {count:,} message(s) in {folder}.", count
|
||||
except Exception as exc:
|
||||
return False, str(exc), None
|
||||
|
||||
|
||||
def fetch_account_stream(
|
||||
acc: dict,
|
||||
days_back: int,
|
||||
limit: int,
|
||||
known_keys: set[str],
|
||||
) -> Iterator[dict]:
|
||||
"""Generator — yields progress dicts while fetching emails via IMAP.
|
||||
|
||||
Mutates `known_keys` in place for cross-account dedup within one fetch session.
|
||||
|
||||
Yields event dicts with "type" key:
|
||||
{"type": "start", "account": str, "total_uids": int}
|
||||
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
|
||||
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
|
||||
"""
|
||||
name = acc.get("name", acc.get("username", "?"))
|
||||
host = acc.get("host", "imap.gmail.com")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc["username"]
|
||||
password = acc["password"]
|
||||
folder = acc.get("folder", "INBOX")
|
||||
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
conn.select(folder, readonly=True)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
uids = list(seen_uids.keys())[: limit * 3]
|
||||
yield {"type": "start", "account": name, "total_uids": len(uids)}
|
||||
|
||||
emails: list[dict] = []
|
||||
skipped = 0
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if i % 5 == 0:
|
||||
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = extract_body(msg)[:800]
|
||||
entry = {"subject": subj, "body": body, "from_addr": from_addr,
|
||||
"date": date, "account": name}
|
||||
k = entry_key(entry)
|
||||
if k not in known_keys:
|
||||
known_keys.add(k)
|
||||
emails.append(entry)
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception:
|
||||
skipped += 1
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
|
||||
"emails": emails}
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
|
||||
account: dict
|
||||
|
||||
|
||||
@router.post("/accounts/test")
|
||||
def test_account_route(req: AccountTestRequest) -> dict:
|
||||
ok, message, count = test_connection(req.account)
|
||||
return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
@router.get("/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
days_back: int = Query(default=90, ge=1, le=365),
|
||||
limit: int = Query(default=150, ge=1, le=1000),
|
||||
mode: str = Query(default="wide"),
|
||||
) -> StreamingResponse:
|
||||
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
|
||||
all_accounts = _get_config_accounts()
|
||||
selected = [a for a in all_accounts if a.get("name") in selected_names]
|
||||
|
||||
def generate():
|
||||
known_keys = {entry_key(x) for x in read_jsonl(_queue_file())}
|
||||
total_added = 0
|
||||
for acc in selected:
|
||||
try:
|
||||
batch_emails: list[dict] = []
|
||||
for event in fetch_account_stream(acc, days_back, limit, known_keys):
|
||||
if event["type"] == "done":
|
||||
batch_emails = event.pop("emails", [])
|
||||
total_added += event["added"]
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
if batch_emails:
|
||||
existing = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), existing + batch_emails)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'account': acc.get('name', '?'), 'message': str(exc)})}\n\n"
|
||||
queue_size = len(read_jsonl(_queue_file()))
|
||||
yield f"data: {json.dumps({'type': 'complete', 'total_added': total_added, 'queue_size': queue_size})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
729
app/data/imitate.py
Normal file
729
app/data/imitate.py
Normal file
|
|
@ -0,0 +1,729 @@
|
|||
"""Avocet — Imitate tab API.
|
||||
|
||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||
pushed into the SFT corrections queue for human review.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||
|
||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_imitate_config() -> dict:
|
||||
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse imitate config %s: %s", f, exc)
|
||||
return {}
|
||||
return raw.get("imitate", {}) or {}
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
|
||||
"""Read cforch section for ollama_url fallback."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
return {}
|
||||
return raw.get("cforch", {}) or {}
|
||||
|
||||
|
||||
def _ollama_url(cfg: dict) -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
|
||||
|
||||
|
||||
def _cforch_url() -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cforch.get("coordinator_url") or "http://localhost:7700"
|
||||
|
||||
|
||||
def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
|
||||
"""Return {model_id, service_type} for a product.task assignment, or None if not found.
|
||||
|
||||
Calls GET coordinator/api/assignments and filters by product+task.
|
||||
The model registry entry is fetched separately to get service_type.
|
||||
Returns None (not raises) — callers emit a 'model_done' error event instead.
|
||||
"""
|
||||
try:
|
||||
asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
|
||||
asgn_resp.raise_for_status()
|
||||
assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
|
||||
match = next(
|
||||
(a for a in assignments if a.get("product") == product and a.get("task") == task),
|
||||
None,
|
||||
)
|
||||
if match is None:
|
||||
return None
|
||||
model_id: str = match.get("model_id", "")
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
# Look up service_type from model registry
|
||||
reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
|
||||
service_type = "cf-text" # sensible default
|
||||
if reg_resp.is_success:
|
||||
models: list[dict] = reg_resp.json().get("models", []) or []
|
||||
reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
|
||||
if reg_entry:
|
||||
service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
|
||||
|
||||
return {"model_id": model_id, "service_type": service_type}
|
||||
except Exception as exc:
|
||||
logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
||||
"""Fetch the live cf-text catalog from cf-orch.
|
||||
|
||||
Filters out proxy entries (ollama://, vllm://, http://) — those models are
|
||||
served by their own services and should not be allocated via cf-text.
|
||||
Returns only models with real file-system paths that cf-text can load directly.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/catalog",
|
||||
params={"node_id": "heimdall"},
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
result = []
|
||||
for model_id, entry in raw.items():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
path = entry.get("path", "")
|
||||
# Skip proxy entries — they're routed through other services
|
||||
if "://" in path:
|
||||
continue
|
||||
result.append({
|
||||
"id": model_id,
|
||||
"vram_mb": entry.get("vram_mb", 0),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch cf-orch catalog: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
|
||||
"""Fetch JSON from url; raise URLError on failure."""
|
||||
req = Request(url, headers={"Accept": "application/json"})
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
|
||||
|
||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
|
||||
"""Return True if the product's health endpoint responds OK."""
|
||||
try:
|
||||
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
|
||||
return bool(data)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_sample(
|
||||
raw: Any,
|
||||
text_fields: list[str],
|
||||
sample_index: int = 0,
|
||||
sample_key: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Pull one item from a list or dict response and extract text_fields.
|
||||
|
||||
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
||||
Falls back to a set of conventional envelope keys if sample_key is absent.
|
||||
"""
|
||||
item: dict[str, Any]
|
||||
if isinstance(raw, list):
|
||||
if not raw:
|
||||
return {}
|
||||
item = raw[min(sample_index, len(raw) - 1)]
|
||||
elif isinstance(raw, dict):
|
||||
# Use declared sample_key first, then fall back to conventional names.
|
||||
_ENVELOPE_KEYS = (
|
||||
"samples", "items", "results", "data", "jobs", "listings",
|
||||
"pantry", "saved_searches", "entries", "calls", "records",
|
||||
)
|
||||
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
||||
for key in search_keys:
|
||||
if key in raw and isinstance(raw[key], list):
|
||||
lst = raw[key]
|
||||
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
||||
break
|
||||
else:
|
||||
item = raw
|
||||
else:
|
||||
return {}
|
||||
|
||||
parts = []
|
||||
for field in text_fields:
|
||||
val = item.get(field)
|
||||
if val and str(val).strip():
|
||||
parts.append(f"**{field}**: {val}")
|
||||
return {"item": item, "text": "\n\n".join(parts)}
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _fetch_image_b64(image_url: str) -> str:
|
||||
"""Download an image URL and return it as a base64 string for ollama.
|
||||
|
||||
Returns empty string on any failure — a missing image is non-fatal;
|
||||
the model will still run against the text prompt alone.
|
||||
"""
|
||||
try:
|
||||
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
|
||||
with urlopen(req, timeout=10) as resp:
|
||||
return base64.b64encode(resp.read()).decode("ascii")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch image %s: %s", image_url, exc)
|
||||
return ""
|
||||
|
||||
|
||||
def _run_ollama_streaming(
|
||||
ollama_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
system: str = "",
|
||||
images: list[str] | None = None,
|
||||
) -> tuple[str, int]:
|
||||
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
|
||||
|
||||
Blocks until the model finishes; yields nothing — streaming is handled by
|
||||
the SSE generator in run_imitate().
|
||||
|
||||
system: optional system prompt passed as a separate field to ollama.
|
||||
images: list of base64-encoded image strings (vision models only).
|
||||
"""
|
||||
url = f"{ollama_base.rstrip('/')}/api/generate"
|
||||
body: dict = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
if system:
|
||||
body["system"] = system
|
||||
if images:
|
||||
body["images"] = images
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
req = Request(url, data=payload, method="POST",
|
||||
headers={"Content-Type": "application/json"})
|
||||
t0 = time.time()
|
||||
try:
|
||||
with urlopen(req, timeout=120) as resp:
|
||||
body = json.loads(resp.read().decode("utf-8"))
|
||||
elapsed = int((time.time() - t0) * 1000)
|
||||
return body.get("response", ""), elapsed
|
||||
except Exception as exc:
|
||||
elapsed = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _run_cftext(
|
||||
cforch_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str,
|
||||
temperature: float,
|
||||
startup_timeout_s: float = 180.0,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, int, bool]:
|
||||
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
|
||||
|
||||
Raises RuntimeError on allocation failure or generation error.
|
||||
cold_started=True means the service was launched from scratch (caller may log this).
|
||||
|
||||
Cold-start detection uses coordinator state signals (running/stopped) rather than
|
||||
polling the service health endpoint — this fails fast on model load errors instead
|
||||
of waiting out the full timeout.
|
||||
"""
|
||||
# Allocate
|
||||
alloc_resp = httpx.post(
|
||||
f"{cforch_base}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "imitate",
|
||||
**({"user_id": user_id} if user_id else {}),
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
alloc_resp.raise_for_status()
|
||||
data = alloc_resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
cold_started = data.get("started", False) and not data.get("warm", True)
|
||||
|
||||
# Wait for ready using coordinator state signals
|
||||
if cold_started:
|
||||
deadline = time.monotonic() + startup_timeout_s
|
||||
probe_misses = 0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
status = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
|
||||
)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
break
|
||||
elif state == "stopped":
|
||||
if allocation_id:
|
||||
httpx.delete(
|
||||
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
|
||||
timeout=5.0,
|
||||
)
|
||||
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
|
||||
else:
|
||||
probe_misses += 1
|
||||
if probe_misses >= 6:
|
||||
# Coordinator hasn't registered instance yet — fall back to health poll
|
||||
try:
|
||||
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2.0)
|
||||
else:
|
||||
if allocation_id:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
|
||||
|
||||
# Generate
|
||||
messages: list[dict] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
gen_resp = httpx.post(
|
||||
f"{service_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": 300,
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
gen_resp.raise_for_status()
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
content = gen_resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms, cold_started
|
||||
except Exception as exc:
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
finally:
|
||||
if allocation_id:
|
||||
try:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/products")
|
||||
def get_products() -> dict:
|
||||
"""List configured CF products with live online status."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
products = []
|
||||
for p in products_raw:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
base_url = p.get("base_url", "")
|
||||
products.append({
|
||||
"id": p.get("id", ""),
|
||||
"name": p.get("name", ""),
|
||||
"icon": p.get("icon", "📦"),
|
||||
"description": p.get("description", ""),
|
||||
"base_url": base_url,
|
||||
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
||||
})
|
||||
return {"products": products}
|
||||
|
||||
|
||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||
|
||||
@router.get("/products/{product_id}/sample")
|
||||
def get_sample(product_id: str, index: int = 0) -> dict:
|
||||
"""Fetch a real sample from the given product's API."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
|
||||
product: dict | None = None
|
||||
for p in products_raw:
|
||||
if isinstance(p, dict) and p.get("id") == product_id:
|
||||
product = p
|
||||
break
|
||||
|
||||
if product is None:
|
||||
raise HTTPException(404, f"Product '{product_id}' not in config")
|
||||
|
||||
base_url = product.get("base_url", "").rstrip("/")
|
||||
endpoint = product.get("sample_endpoint", "")
|
||||
if not base_url or not endpoint:
|
||||
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
||||
|
||||
url = f"{base_url}{endpoint}"
|
||||
try:
|
||||
raw = _http_get_json(url, timeout=5)
|
||||
except URLError as exc:
|
||||
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
||||
|
||||
text_fields = product.get("text_fields", []) or []
|
||||
sample_key = product.get("sample_key") or None
|
||||
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
|
||||
if not extracted:
|
||||
raise HTTPException(404, "No sample items returned by product API")
|
||||
|
||||
prompt_template = product.get("prompt_template", "{text}")
|
||||
prompt = prompt_template.replace("{text}", extracted["text"])
|
||||
# Also substitute any {field_name} placeholders from the raw item fields.
|
||||
item = extracted.get("item", {})
|
||||
for field, val in item.items():
|
||||
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
|
||||
|
||||
# Expose system_prompt and image_url if the product API returns them.
|
||||
# system_prompt: Peregrine, Snipe (vision analysis instructions)
|
||||
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
|
||||
item = extracted.get("item", {})
|
||||
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
|
||||
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
|
||||
|
||||
return {
|
||||
"product_id": product_id,
|
||||
"sample_index": index,
|
||||
"text": extracted["text"],
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt,
|
||||
"image_url": image_url,
|
||||
"raw_item": item,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
|
||||
def get_catalog() -> dict:
|
||||
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
||||
models = _cforch_catalog(_cforch_url())
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
||||
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
||||
try:
|
||||
from app.cloud_session import get_session
|
||||
return get_session(request, response)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_imitate(
|
||||
prompt: str = "",
|
||||
model_ids: str = "", # comma-separated ollama model IDs
|
||||
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
||||
task_ids: str = "", # comma-separated "product/task" strings — resolved via assignments
|
||||
temperature: float = 0.7,
|
||||
product_id: str = "",
|
||||
system: str = "", # optional system prompt
|
||||
image_url: str = "", # optional image URL for vision models
|
||||
session: "Any" = Depends(_get_imitate_session),
|
||||
) -> StreamingResponse:
|
||||
"""Run a prompt through selected models and stream results as SSE.
|
||||
|
||||
Models can be selected three ways (combinable):
|
||||
- model_ids: explicit ollama model IDs
|
||||
- cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
|
||||
- task_ids: "product/task" strings resolved via the coordinator assignments table
|
||||
|
||||
If image_url is provided, the image is downloaded once and passed to every
|
||||
model as a base64-encoded blob — allowing vision-capable local models to
|
||||
evaluate listing photos the same way Snipe's background task pipeline does.
|
||||
"""
|
||||
|
||||
if not prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
|
||||
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
||||
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
||||
raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
|
||||
|
||||
# Resolve task assignments to concrete model IDs, routing to the right service.
|
||||
# Models that fail to resolve emit an error event at run time (non-fatal).
|
||||
if raw_task_ids:
|
||||
cforch_base = _cforch_url()
|
||||
for task_spec in raw_task_ids:
|
||||
parts = task_spec.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
|
||||
continue
|
||||
product_name, task_name = parts
|
||||
resolved = _resolve_task_model(cforch_base, product_name, task_name)
|
||||
if resolved is None:
|
||||
logger.warning("No assignment found for task %r", task_spec)
|
||||
# Emit error at stream time via a sentinel in cftext_ids with a special label.
|
||||
# We instead store the failed task_spec to emit a model_done error.
|
||||
cftext_ids.append(f"__task_unresolved__:{task_spec}")
|
||||
continue
|
||||
mid = resolved["model_id"]
|
||||
svc = resolved["service_type"]
|
||||
if svc == "ollama":
|
||||
if mid not in ollama_ids:
|
||||
ollama_ids.append(mid)
|
||||
else:
|
||||
# cf-text, vllm, and any other cf-orch-managed service
|
||||
if mid not in cftext_ids:
|
||||
cftext_ids.append(mid)
|
||||
|
||||
if not ollama_ids and not cftext_ids:
|
||||
raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
|
||||
|
||||
cfg = _load_imitate_config()
|
||||
ollama_base = _ollama_url(cfg)
|
||||
cforch_base = _cforch_url()
|
||||
system_ctx = system.strip() or ""
|
||||
total_models = len(ollama_ids) + len(cftext_ids)
|
||||
|
||||
# Download image once before streaming — shared across ollama vision models
|
||||
images: list[str] = []
|
||||
if image_url.strip():
|
||||
b64 = _fetch_image_b64(image_url.strip())
|
||||
if b64:
|
||||
images = [b64]
|
||||
|
||||
def generate():
|
||||
results: list[dict] = []
|
||||
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
|
||||
|
||||
# Ollama models
|
||||
for model_id in ollama_ids:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
|
||||
try:
|
||||
response, elapsed_ms = _run_ollama_streaming(
|
||||
ollama_base, model_id, prompt, temperature,
|
||||
system=system_ctx, images=images or None,
|
||||
)
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
||||
# Partition the list: real cf-text IDs vs unresolved-task sentinels.
|
||||
cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
|
||||
cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
|
||||
for sentinel in cftext_unresolved:
|
||||
task_spec = sentinel.split(":", 1)[1]
|
||||
result = {
|
||||
"model": task_spec,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": f"No assignment configured for task '{task_spec}'",
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
if cftext_real:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Announce all models upfront so the UI can show loading states immediately
|
||||
for model_id in cftext_real:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
||||
|
||||
_user_id: str | None = getattr(session, "user_id", None)
|
||||
# Only forward real cloud user IDs — skip local/anon sessions
|
||||
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
|
||||
_user_id = None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
|
||||
future_to_model = {
|
||||
pool.submit(
|
||||
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
|
||||
180.0, _user_id,
|
||||
): mid
|
||||
for mid in cftext_real
|
||||
}
|
||||
for future in as_completed(future_to_model):
|
||||
model_id = future_to_model[future]
|
||||
try:
|
||||
response, elapsed_ms, cold_started = future.result()
|
||||
if cold_started:
|
||||
yield _sse({"type": "model_coldstart", "model": model_id})
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
yield _sse({"type": "complete", "results": results})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
|
||||
class ImitateResult(BaseModel):
|
||||
model: str
|
||||
response: str
|
||||
elapsed_ms: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PushCorrectionsRequest(BaseModel):
|
||||
product_id: str
|
||||
prompt: str
|
||||
results: list[ImitateResult]
|
||||
|
||||
|
||||
@router.post("/push-corrections")
|
||||
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
||||
"""Append imitate results to sft_candidates.jsonl for human review."""
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
if not req.results:
|
||||
raise HTTPException(422, "results list is empty")
|
||||
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
records = []
|
||||
for r in req.results:
|
||||
if r.error or not r.response.strip():
|
||||
continue
|
||||
records.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"source": "imitate",
|
||||
"product_id": req.product_id,
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": r.response,
|
||||
"model_id": r.model,
|
||||
"elapsed_ms": r.elapsed_ms,
|
||||
"status": "pending",
|
||||
"created_at": ts,
|
||||
})
|
||||
|
||||
if not records:
|
||||
raise HTTPException(422, "No non-error results to push")
|
||||
|
||||
dest = _candidates_file()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
for record in records:
|
||||
append_jsonl(dest, record)
|
||||
|
||||
return {"pushed": len(records)}
|
||||
222
app/data/label.py
Normal file
222
app/data/label.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
"""Avocet -- label queue API.
|
||||
|
||||
All label/skip/discard/undo/stats/config endpoints.
|
||||
Extracted from app/api.py as part of the v2 domain split.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_last_action: dict | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def reset_last_action() -> None:
|
||||
global _last_action
|
||||
_last_action = None
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
def _score_file() -> Path:
|
||||
return _DATA_DIR / "email_score.jsonl"
|
||||
|
||||
def _discarded_file() -> Path:
|
||||
return _DATA_DIR / "discarded.jsonl"
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
def _normalize(item: dict) -> dict:
|
||||
return {
|
||||
"id": item.get("id") or _item_id(item),
|
||||
"subject": item.get("subject", ""),
|
||||
"body": item.get("body", ""),
|
||||
"from": item.get("from") or item.get("from_addr", ""),
|
||||
"date": item.get("date", ""),
|
||||
"source": item.get("source") or item.get("account", ""),
|
||||
}
|
||||
|
||||
_LABEL_META = [
|
||||
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
|
||||
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
|
||||
{"name": "rejected", "emoji": "❌", "color": "#F44336", "key": "3"},
|
||||
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
|
||||
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
|
||||
{"name": "neutral", "emoji": "⬜", "color": "#607D8B", "key": "6"},
|
||||
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
|
||||
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
|
||||
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
|
||||
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
|
||||
]
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
|
||||
items = read_jsonl(_queue_file())
|
||||
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
|
||||
|
||||
class LabelRequest(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
@router.post("/label")
|
||||
def post_label(req: LabelRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": req.label,
|
||||
"labeled_at": datetime.now(timezone.utc).isoformat()}
|
||||
append_jsonl(_score_file(), record)
|
||||
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "label", "item": match, "label": req.label}
|
||||
return {"ok": True}
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
@router.post("/skip")
|
||||
def post_skip(req: SkipRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
|
||||
write_jsonl(_queue_file(), reordered)
|
||||
_last_action = {"type": "skip", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
class DiscardRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
@router.post("/discard")
|
||||
def post_discard(req: DiscardRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": "__discarded__",
|
||||
"discarded_at": datetime.now(timezone.utc).isoformat()}
|
||||
append_jsonl(_discarded_file(), record)
|
||||
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "discard", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/label/undo")
|
||||
def delete_undo():
|
||||
global _last_action
|
||||
if not _last_action:
|
||||
raise HTTPException(404, "No action to undo")
|
||||
action = _last_action
|
||||
item = action["item"]
|
||||
if action["type"] == "label":
|
||||
records = read_jsonl(_score_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Score file is empty -- cannot undo label")
|
||||
write_jsonl(_score_file(), records[:-1])
|
||||
items = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "discard":
|
||||
records = read_jsonl(_discarded_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Discarded file is empty -- cannot undo discard")
|
||||
write_jsonl(_discarded_file(), records[:-1])
|
||||
items = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "skip":
|
||||
items = read_jsonl(_queue_file())
|
||||
item_id = _normalize(item)["id"]
|
||||
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
|
||||
write_jsonl(_queue_file(), items)
|
||||
_last_action = None
|
||||
return {"undone": {"type": action["type"], "item": _normalize(item)}}
|
||||
|
||||
@router.get("/config/labels")
|
||||
def get_labels():
|
||||
return _LABEL_META
|
||||
|
||||
@router.get("/config")
|
||||
def get_config():
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"accounts": [], "max_per_account": 500}
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
|
||||
|
||||
class ConfigPayload(BaseModel):
|
||||
accounts: list[dict]
|
||||
max_per_account: int = 500
|
||||
|
||||
@router.post("/config")
|
||||
def post_config(payload: ConfigPayload):
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats():
|
||||
records = read_jsonl(_score_file())
|
||||
counts: dict[str, int] = {}
|
||||
for r in records:
|
||||
lbl = r.get("label", "")
|
||||
if lbl:
|
||||
counts[lbl] = counts.get(lbl, 0) + 1
|
||||
benchmark_results: dict = {}
|
||||
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"total": len(records),
|
||||
"counts": counts,
|
||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||
"benchmark_results": benchmark_results,
|
||||
}
|
||||
|
||||
@router.get("/stats/download")
|
||||
def download_stats():
|
||||
if not _score_file().exists():
|
||||
raise HTTPException(404, "No score file")
|
||||
return FileResponse(
|
||||
str(_score_file()),
|
||||
filename="email_score.jsonl",
|
||||
media_type="application/jsonlines",
|
||||
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
|
||||
)
|
||||
462
app/data/log_corpus.py
Normal file
462
app/data/log_corpus.py
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
"""Avocet — Log Corpus receiver and labeling API.
|
||||
|
||||
Receives push batches from consented Turnstone nodes, stores entries for labeling,
|
||||
and exports labeled data as JSONL for the logreading fine-tune pipeline.
|
||||
|
||||
DB: data/corpus.db (separate from train_jobs.db — different lifecycle)
|
||||
Auth: Bearer token validated against corpus_sources table (seeded from label_tool.yaml).
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/corpus".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DB_PATH: Path = _ROOT / "data" / "corpus.db"
|
||||
|
||||
_PIPELINE_SOURCE_HOST = "pipeline_scrape"
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS corpus_sources (
|
||||
token TEXT PRIMARY KEY,
|
||||
source_host TEXT NOT NULL,
|
||||
owner TEXT NOT NULL,
|
||||
consent_date TEXT NOT NULL,
|
||||
consent_method TEXT NOT NULL,
|
||||
active INTEGER NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS corpus_batches (
|
||||
id TEXT PRIMARY KEY,
|
||||
source_host TEXT NOT NULL,
|
||||
batch_type TEXT NOT NULL,
|
||||
received_at TEXT NOT NULL,
|
||||
entry_count INTEGER NOT NULL,
|
||||
watermark_from TEXT,
|
||||
watermark_to TEXT,
|
||||
raw_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS corpus_entries (
|
||||
id TEXT PRIMARY KEY,
|
||||
batch_id TEXT NOT NULL REFERENCES corpus_batches(id),
|
||||
source_host TEXT NOT NULL,
|
||||
origin_entry_id TEXT,
|
||||
timestamp_iso TEXT,
|
||||
severity TEXT,
|
||||
source_id TEXT,
|
||||
text TEXT NOT NULL,
|
||||
matched_patterns TEXT DEFAULT '[]',
|
||||
label_state TEXT NOT NULL DEFAULT 'unlabeled',
|
||||
failure_type TEXT,
|
||||
plain_explanation TEXT,
|
||||
known_pattern TEXT,
|
||||
labeled_at TEXT,
|
||||
labeled_by TEXT DEFAULT 'alan',
|
||||
pii_flagged INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_label_state ON corpus_entries(label_state);
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_source ON corpus_entries(source_host);
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_severity ON corpus_entries(severity);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ingested_pipeline_files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
ingested_at TEXT NOT NULL,
|
||||
entry_count INTEGER NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR, _DB_PATH
|
||||
_DATA_DIR = path
|
||||
_DB_PATH = path / "corpus.db"
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
with _db() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
_seed_sources(conn)
|
||||
|
||||
|
||||
def _pipeline_ingest_dir() -> Path | None:
|
||||
"""Return the configured pipeline log ingest directory, or None if unset."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return None
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return None
|
||||
val = raw.get("corpus", {}).get("pipeline_ingest_dir", "") or ""
|
||||
return Path(val) if val else None
|
||||
|
||||
|
||||
def _load_corpus_config() -> list[dict]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return []
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse corpus config: %s", exc)
|
||||
return []
|
||||
return raw.get("corpus", {}).get("sources", []) or []
|
||||
|
||||
|
||||
def _seed_sources(conn: sqlite3.Connection) -> None:
|
||||
for src in _load_corpus_config():
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_sources (token, source_host, owner, consent_date, consent_method) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(src["token"], src["source_host"], src["owner"],
|
||||
src["consent_date"], src["consent_method"]),
|
||||
)
|
||||
|
||||
|
||||
def _validate_token(token: str, conn: sqlite3.Connection) -> str:
|
||||
"""Return source_host for token, or raise 403."""
|
||||
row = conn.execute(
|
||||
"SELECT source_host FROM corpus_sources WHERE token = ? AND active = 1",
|
||||
(token,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=403, detail="Unknown or revoked consent token")
|
||||
return row["source_host"]
|
||||
|
||||
|
||||
def _extract_bearer(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Bearer token required")
|
||||
return auth.removeprefix("Bearer ").strip()
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
# ── Startup ────────────────────────────────────────────────────────────────────
|
||||
|
||||
_init_db()
|
||||
|
||||
|
||||
# ── POST /api/corpus/log-batch ─────────────────────────────────────────────────
|
||||
|
||||
@router.post("/log-batch")
|
||||
def receive_batch(request: Request, payload: dict) -> dict:
|
||||
"""Accept a push batch from a Turnstone node."""
|
||||
token = _extract_bearer(request)
|
||||
|
||||
batch_type = payload.get("batch_type", "raw_entries")
|
||||
entries_raw = payload.get("entries", [])
|
||||
batch_id = payload.get("batch_id") or str(uuid.uuid4())
|
||||
|
||||
with _db() as conn:
|
||||
source_host = _validate_token(token, conn)
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, "
|
||||
"watermark_from, watermark_to, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(batch_id, source_host, batch_type, _now_iso(), len(entries_raw),
|
||||
str(payload.get("watermark_from", "")),
|
||||
str(payload.get("watermark_to", "")),
|
||||
json.dumps(payload)),
|
||||
)
|
||||
|
||||
stored = 0
|
||||
for entry in entries_raw:
|
||||
text = entry.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_entries "
|
||||
"(id, batch_id, source_host, origin_entry_id, timestamp_iso, severity, "
|
||||
"source_id, text, matched_patterns) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), batch_id, source_host,
|
||||
entry.get("entry_id") or entry.get("id"),
|
||||
entry.get("timestamp_iso"),
|
||||
entry.get("severity"),
|
||||
entry.get("source_id"),
|
||||
text,
|
||||
json.dumps(entry.get("matched_patterns", []))),
|
||||
)
|
||||
stored += 1
|
||||
|
||||
logger.info("Received batch %s from %s: %d/%d entries stored",
|
||||
batch_id, source_host, stored, len(entries_raw))
|
||||
return {"received": True, "batch_id": batch_id, "entries_stored": stored}
|
||||
|
||||
|
||||
# ── GET /api/corpus/entries ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/entries")
|
||||
def list_entries(
|
||||
state: str = "unlabeled",
|
||||
source_host: str | None = None,
|
||||
limit: int = 25,
|
||||
) -> dict:
|
||||
"""Return entries for labeling. Default: unlabeled entries, oldest first."""
|
||||
with _db() as conn:
|
||||
query = "SELECT * FROM corpus_entries WHERE label_state = ?"
|
||||
params: list = [state]
|
||||
if source_host:
|
||||
query += " AND source_host = ?"
|
||||
params.append(source_host)
|
||||
query += " ORDER BY rowid LIMIT ?"
|
||||
params.append(min(limit, 100))
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return {"entries": [dict(r) for r in rows], "count": len(rows)}
|
||||
|
||||
|
||||
# ── POST /api/corpus/entries/{id}/label ───────────────────────────────────────
|
||||
|
||||
@router.post("/entries/{entry_id}/label")
|
||||
def label_entry(entry_id: str, body: dict) -> dict:
|
||||
"""Submit a label for a corpus entry."""
|
||||
failure_type = body.get("failure_type")
|
||||
plain_explanation = body.get("plain_explanation", "").strip()
|
||||
known_pattern = body.get("known_pattern")
|
||||
pii_flagged = int(bool(body.get("pii_flagged", False)))
|
||||
|
||||
if not failure_type:
|
||||
raise HTTPException(status_code=422, detail="failure_type is required")
|
||||
valid_types = {"hardware", "software", "network", "security", "application", "none", "other"}
|
||||
if failure_type not in valid_types:
|
||||
raise HTTPException(status_code=422, detail=f"failure_type must be one of {sorted(valid_types)}")
|
||||
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Entry not found")
|
||||
conn.execute(
|
||||
"UPDATE corpus_entries SET label_state='labeled', failure_type=?, plain_explanation=?, "
|
||||
"known_pattern=?, labeled_at=?, pii_flagged=? WHERE id=?",
|
||||
(failure_type, plain_explanation, known_pattern, _now_iso(), pii_flagged, entry_id),
|
||||
)
|
||||
return {"labeled": True, "entry_id": entry_id}
|
||||
|
||||
|
||||
# ── POST /api/corpus/entries/{id}/skip ────────────────────────────────────────
|
||||
|
||||
@router.post("/entries/{entry_id}/skip")
|
||||
def skip_entry(entry_id: str) -> dict:
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Entry not found")
|
||||
conn.execute(
|
||||
"UPDATE corpus_entries SET label_state='skipped' WHERE id=?", (entry_id,)
|
||||
)
|
||||
return {"skipped": True, "entry_id": entry_id}
|
||||
|
||||
|
||||
# ── GET /api/corpus/stats ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict:
|
||||
with _db() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM corpus_entries").fetchone()[0]
|
||||
by_state = {
|
||||
r["label_state"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT label_state, COUNT(*) AS cnt FROM corpus_entries GROUP BY label_state"
|
||||
).fetchall()
|
||||
}
|
||||
by_source = {
|
||||
r["source_host"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT source_host, COUNT(*) AS cnt FROM corpus_entries GROUP BY source_host"
|
||||
).fetchall()
|
||||
}
|
||||
by_severity = {
|
||||
r["severity"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT severity, COUNT(*) AS cnt FROM corpus_entries "
|
||||
"WHERE severity IS NOT NULL GROUP BY severity"
|
||||
).fetchall()
|
||||
}
|
||||
batch_count = conn.execute("SELECT COUNT(*) FROM corpus_batches").fetchone()[0]
|
||||
return {
|
||||
"total_entries": total,
|
||||
"batch_count": batch_count,
|
||||
"by_label_state": by_state,
|
||||
"by_source": by_source,
|
||||
"by_severity": by_severity,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /api/corpus/export ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def export_labeled() -> StreamingResponse:
|
||||
"""Stream labeled, non-PII entries as JSONL for SFT harness."""
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT source_host, source_id, severity, text, failure_type, plain_explanation, known_pattern "
|
||||
"FROM corpus_entries "
|
||||
"WHERE label_state = 'labeled' AND pii_flagged = 0 AND plain_explanation != ''"
|
||||
"ORDER BY rowid"
|
||||
).fetchall()
|
||||
|
||||
def _generate():
|
||||
for row in rows:
|
||||
record = {
|
||||
"input": row["text"],
|
||||
"output": row["plain_explanation"],
|
||||
"metadata": {
|
||||
"failure_type": row["failure_type"],
|
||||
"source": row["source_host"],
|
||||
"source_id": row["source_id"],
|
||||
"severity": row["severity"],
|
||||
"known_pattern": row["known_pattern"],
|
||||
},
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": "attachment; filename=log_corpus_labeled.jsonl"},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /api/corpus/pipeline-ingest ─────────────────────────────────────────
|
||||
|
||||
def _ingest_one_file(conn: sqlite3.Connection, path: Path) -> int:
|
||||
"""Parse a pipeline JSONL file and insert entries. Returns count stored."""
|
||||
batch_id = str(uuid.uuid4())
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
entries_raw: list[dict] = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
entries_raw.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed line in %s", path.name)
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, raw_json) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(batch_id, _PIPELINE_SOURCE_HOST, "pipeline_log", _now_iso(),
|
||||
len(entries_raw), json.dumps({"file": path.name})),
|
||||
)
|
||||
|
||||
stored = 0
|
||||
for entry in entries_raw:
|
||||
text = (entry.get("msg") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_entries "
|
||||
"(id, batch_id, source_host, timestamp_iso, severity, source_id, text, matched_patterns) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), batch_id, _PIPELINE_SOURCE_HOST,
|
||||
entry.get("ts"),
|
||||
entry.get("level"),
|
||||
entry.get("logger"),
|
||||
text,
|
||||
json.dumps([entry["extra"]] if entry.get("extra") else [])),
|
||||
)
|
||||
stored += 1
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO ingested_pipeline_files (filename, ingested_at, entry_count) VALUES (?, ?, ?)",
|
||||
(path.name, _now_iso(), stored),
|
||||
)
|
||||
return stored
|
||||
|
||||
|
||||
@router.post("/pipeline-ingest")
|
||||
def pipeline_ingest() -> dict:
|
||||
"""Walk the configured pipeline log directory and ingest new JSONL files.
|
||||
|
||||
Skips files already recorded in ingested_pipeline_files. Safe to call
|
||||
repeatedly — idempotent by filename.
|
||||
"""
|
||||
ingest_dir = _pipeline_ingest_dir()
|
||||
if ingest_dir is None:
|
||||
raise HTTPException(404, "pipeline_ingest_dir not configured in label_tool.yaml")
|
||||
|
||||
ingested = 0
|
||||
skipped = 0
|
||||
total_stored = 0
|
||||
files_detail: list[dict] = []
|
||||
|
||||
with _db() as conn:
|
||||
already_done: set[str] = {
|
||||
row[0]
|
||||
for row in conn.execute("SELECT filename FROM ingested_pipeline_files").fetchall()
|
||||
}
|
||||
|
||||
for path in sorted(ingest_dir.glob("*.jsonl")):
|
||||
if path.name in already_done:
|
||||
skipped += 1
|
||||
continue
|
||||
stored = _ingest_one_file(conn, path)
|
||||
ingested += 1
|
||||
total_stored += stored
|
||||
files_detail.append({"file": path.name, "entries_stored": stored})
|
||||
|
||||
logger.info("Pipeline ingest: %d files ingested, %d skipped, %d entries stored",
|
||||
ingested, skipped, total_stored)
|
||||
return {
|
||||
"ingested_files": ingested,
|
||||
"skipped_files": skipped,
|
||||
"entries_stored": total_stored,
|
||||
"files": files_detail,
|
||||
}
|
||||
313
app/data/recipe_scan.py
Normal file
313
app/data/recipe_scan.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""Avocet — Recipe scan labeling API (avocet#65).
|
||||
|
||||
Receives recipe scan items from the Kiwi pipeline (scanner/phone image +
|
||||
docuvision OCR extraction + ground-truth structured recipe), presents them
|
||||
for human review, and exports approved/edited pairs in the messages chat
|
||||
format for the vision fine-tune harness.
|
||||
|
||||
DB: data/recipe_scan.db (separate from corpus.db — different lifecycle)
|
||||
No auth required — local admin tool, not a push endpoint.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with
|
||||
prefix="/api/recipe-scan".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "recipe_scan.db"
|
||||
|
||||
_VALID_MODALITIES = {"scanner", "phone", "handwritten"}
|
||||
_VALID_STATUSES = {"pending", "approved", "edited", "rejected"}
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS recipe_scan_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
image_path TEXT NOT NULL,
|
||||
modality TEXT NOT NULL DEFAULT 'scanner',
|
||||
source TEXT NOT NULL DEFAULT 'purple_carrot',
|
||||
extracted TEXT NOT NULL,
|
||||
ground_truth TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
corrected TEXT,
|
||||
labeled_at TEXT,
|
||||
rejected_reason TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_rsi_status ON recipe_scan_items(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_rsi_modality ON recipe_scan_items(modality);
|
||||
"""
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seam ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
with _db() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _build_training_pair(row: sqlite3.Row) -> dict:
|
||||
"""Build a messages-format training pair from a labeled row.
|
||||
|
||||
user message: correction prompt + the docuvision-extracted JSON draft.
|
||||
Trains the model to review and correct an existing extraction, which is
|
||||
more data-efficient than producing from scratch when OCR is usually close.
|
||||
|
||||
assistant message: the approved ground truth (or human-corrected JSON).
|
||||
"""
|
||||
target_str = row["corrected"] if row["corrected"] else row["ground_truth"]
|
||||
extracted = json.loads(row["extracted"])
|
||||
target = json.loads(target_str)
|
||||
user_content = (
|
||||
"Review and correct this recipe extraction. "
|
||||
"Return valid JSON with fields: title, description, ingredients, steps, "
|
||||
"prep_time, cook_time, servings.\n\n"
|
||||
f"Extraction to review:\n{json.dumps(extracted, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
return {
|
||||
"id": row["id"],
|
||||
"modality": row["modality"],
|
||||
"source": row["source"],
|
||||
"image_path": row["image_path"],
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": json.dumps(target, ensure_ascii=False)},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
_init_db()
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────────
|
||||
|
||||
class ImportItem(BaseModel):
|
||||
id: str = ""
|
||||
image_path: str
|
||||
modality: Literal["scanner", "phone", "handwritten"] = "scanner"
|
||||
source: str = "purple_carrot"
|
||||
extracted: dict
|
||||
ground_truth: dict
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def default_id(cls, v: str) -> str:
|
||||
return v or str(uuid.uuid4())
|
||||
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
items: list[ImportItem]
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def import_items(body: ImportRequest) -> dict:
|
||||
"""Bulk-import scan items from the Kiwi pipeline. Idempotent by item id."""
|
||||
stored = 0
|
||||
with _db() as conn:
|
||||
for item in body.items:
|
||||
result = conn.execute(
|
||||
"INSERT OR IGNORE INTO recipe_scan_items "
|
||||
"(id, image_path, modality, source, extracted, ground_truth) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(item.id, item.image_path, item.modality, item.source,
|
||||
json.dumps(item.extracted), json.dumps(item.ground_truth)),
|
||||
)
|
||||
stored += result.rowcount
|
||||
return {"imported": stored, "total_submitted": len(body.items)}
|
||||
|
||||
|
||||
# ── GET /next ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/next")
|
||||
def get_next() -> dict:
|
||||
"""Return the next pending item for review, oldest-first."""
|
||||
with _db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM recipe_scan_items WHERE status = 'pending' ORDER BY rowid LIMIT 1"
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "No pending items in queue")
|
||||
return {
|
||||
**dict(row),
|
||||
"extracted": json.loads(row["extracted"]),
|
||||
"ground_truth": json.loads(row["ground_truth"]),
|
||||
}
|
||||
|
||||
|
||||
# ── POST /items/{id}/approve ──────────────────────────────────────────────────
|
||||
|
||||
@router.post("/items/{item_id}/approve")
|
||||
def approve_item(item_id: str) -> dict:
|
||||
"""Mark item as approved — extracted JSON is close enough to ground truth."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='approved', labeled_at=? WHERE id=?",
|
||||
(_now_iso(), item_id),
|
||||
)
|
||||
return {"status": "approved", "id": item_id}
|
||||
|
||||
|
||||
# ── POST /items/{id}/edit ─────────────────────────────────────────────────────
|
||||
|
||||
class EditBody(BaseModel):
|
||||
corrected: dict
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/edit")
|
||||
def edit_item(item_id: str, body: EditBody) -> dict:
|
||||
"""Approve with a human-corrected JSON. corrected overrides extracted in export."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='edited', corrected=?, labeled_at=? WHERE id=?",
|
||||
(json.dumps(body.corrected), _now_iso(), item_id),
|
||||
)
|
||||
return {"status": "edited", "id": item_id}
|
||||
|
||||
|
||||
# ── POST /items/{id}/reject ───────────────────────────────────────────────────
|
||||
|
||||
class RejectBody(BaseModel):
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/reject")
|
||||
def reject_item(item_id: str, body: RejectBody = RejectBody()) -> dict:
|
||||
"""Reject item — extraction too broken to use for training."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='rejected', rejected_reason=?, labeled_at=? WHERE id=?",
|
||||
(body.reason or None, _now_iso(), item_id),
|
||||
)
|
||||
return {"status": "rejected", "id": item_id}
|
||||
|
||||
|
||||
# ── GET /stats ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict:
|
||||
with _db() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM recipe_scan_items").fetchone()[0]
|
||||
by_status = {
|
||||
r["status"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT status, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY status"
|
||||
).fetchall()
|
||||
}
|
||||
by_modality = {
|
||||
r["modality"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT modality, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY modality"
|
||||
).fetchall()
|
||||
}
|
||||
export_ready = conn.execute(
|
||||
"SELECT COUNT(*) FROM recipe_scan_items WHERE status IN ('approved', 'edited')"
|
||||
).fetchone()[0]
|
||||
return {
|
||||
"total": total,
|
||||
"by_status": by_status,
|
||||
"by_modality": by_modality,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def export_pairs() -> StreamingResponse:
|
||||
"""Stream approved/edited items as JSONL training pairs (messages format)."""
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM recipe_scan_items WHERE status IN ('approved', 'edited') ORDER BY rowid"
|
||||
).fetchall()
|
||||
|
||||
def _generate():
|
||||
for row in rows:
|
||||
yield json.dumps(_build_training_pair(row), ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": "attachment; filename=recipe_scan_pairs.jsonl"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /image ────────────────────────────────────────────────────────────────
|
||||
|
||||
_IMAGE_ROOT = Path("/Library/Assets/kiwi")
|
||||
|
||||
|
||||
@router.get("/image")
|
||||
def serve_image(path: str) -> StreamingResponse:
|
||||
"""Serve a scan image from /Library/Assets/kiwi/.
|
||||
|
||||
path must resolve within /Library/Assets/kiwi/ — rejects traversal attempts.
|
||||
"""
|
||||
try:
|
||||
resolved = Path(path).resolve()
|
||||
_IMAGE_ROOT.resolve() # ensure root itself is valid
|
||||
resolved.relative_to(_IMAGE_ROOT.resolve())
|
||||
except (ValueError, OSError):
|
||||
raise HTTPException(403, "Path outside allowed image directory")
|
||||
|
||||
if not resolved.exists():
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
suffix = resolved.suffix.lower()
|
||||
media_types = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".webp": "image/webp"}
|
||||
media_type = media_types.get(suffix, "application/octet-stream")
|
||||
|
||||
return StreamingResponse(
|
||||
open(resolved, "rb"),
|
||||
media_type=media_type,
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
|
||||
0
app/eval/__init__.py
Normal file
0
app/eval/__init__.py
Normal file
44
app/eval/cforch.py
Normal file
44
app/eval/cforch.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Avocet -- eval router aggregator.
|
||||
|
||||
Collects benchmark sub-routers into a single importable `router`
|
||||
for the api.py factory. Each sub-router retains its established prefix
|
||||
so no frontend URL changes are needed.
|
||||
|
||||
Route prefixes when mounted at /api in api.py:
|
||||
/api/cforch/* -- cf-orch benchmark routes
|
||||
/api/style/* -- writing style benchmark routes
|
||||
/api/voice/* -- voice benchmark routes
|
||||
/api/plans-bench/* -- plans benchmark routes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.cforch import router as _cforch_router
|
||||
from app.style import router as _style_router
|
||||
from app.voice import router as _voice_router
|
||||
from app.plans_bench import router as _plans_router
|
||||
from app.eval.embed_bench import router as _embed_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(_cforch_router, prefix="/cforch")
|
||||
router.include_router(_style_router, prefix="/style")
|
||||
router.include_router(_voice_router, prefix="/voice")
|
||||
router.include_router(_plans_router, prefix="/plans-bench")
|
||||
router.include_router(_embed_router, prefix="/embed-bench")
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Propagate config dir override to all sub-modules -- used by tests."""
|
||||
import app.cforch as _cforch_mod
|
||||
import app.style as _style_mod
|
||||
import app.voice as _voice_mod
|
||||
import app.plans_bench as _plans_mod
|
||||
import app.eval.embed_bench as _embed_mod
|
||||
_cforch_mod.set_config_dir(path)
|
||||
_style_mod.set_config_dir(path)
|
||||
_voice_mod.set_config_dir(path)
|
||||
_plans_mod.set_config_dir(path)
|
||||
_embed_mod.set_config_dir(path)
|
||||
293
app/eval/embed_bench.py
Normal file
293
app/eval/embed_bench.py
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
"""Avocet — embedding model comparison harness.
|
||||
|
||||
Exposes FastAPI routes under /api/embed-bench (mounted via app/eval/cforch.py).
|
||||
All computation is local: no LLM inference, Ollama only. MIT tier throughout.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override via set_config_dir() in tests
|
||||
_RUN_ACTIVE: bool = False
|
||||
_RATINGS_FILE = _ROOT / "data" / "embed_bench_ratings.jsonl"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seam ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
return yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse embed_bench config %s: %s", f, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _ollama_url() -> str:
|
||||
cfg = _load_config()
|
||||
embed_cfg = cfg.get("embed_bench", {}) or {}
|
||||
cforch_cfg = cfg.get("cforch", {}) or {}
|
||||
return (
|
||||
embed_cfg.get("ollama_url")
|
||||
or cforch_cfg.get("ollama_url", "http://localhost:11434")
|
||||
)
|
||||
|
||||
|
||||
def _ratings_path() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "embed_bench_ratings.jsonl"
|
||||
return _RATINGS_FILE
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
if len(a) != len(b):
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch: {len(a)} vs {len(b)}"
|
||||
)
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x * x for x in a))
|
||||
mag_b = math.sqrt(sum(x * x for x in b))
|
||||
if mag_a == 0.0 or mag_b == 0.0:
|
||||
return 0.0
|
||||
return dot / (mag_a * mag_b)
|
||||
|
||||
|
||||
# ── GET /models ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return Ollama embedding models available on the configured instance."""
|
||||
ollama = _ollama_url()
|
||||
models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(f"{ollama}/api/tags", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
for entry in resp.json().get("models", []):
|
||||
models.append({
|
||||
"name": entry.get("name", ""),
|
||||
"size": entry.get("size", 0),
|
||||
})
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.warning("Ollama /api/tags returned HTTP %s: %s", exc.response.status_code, exc)
|
||||
except httpx.RequestError as exc:
|
||||
logger.warning("Failed to reach Ollama for model list: %s", exc)
|
||||
return {"models": models, "ollama_url": ollama}
|
||||
|
||||
|
||||
# ── POST /run ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
corpus: list[str]
|
||||
queries: list[str]
|
||||
models: list[str]
|
||||
top_k: int = 5
|
||||
ollama_url: str = ""
|
||||
|
||||
@field_validator("corpus")
|
||||
@classmethod
|
||||
def corpus_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("corpus must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("queries")
|
||||
@classmethod
|
||||
def queries_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("queries must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("models")
|
||||
@classmethod
|
||||
def models_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("models must contain at least one model name")
|
||||
return v
|
||||
|
||||
|
||||
def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]:
|
||||
"""Batch-embed texts via Ollama /v1/embeddings. Returns one vector per text."""
|
||||
resp = httpx.post(
|
||||
f"{ollama}/v1/embeddings",
|
||||
json={"model": model, "input": texts},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json().get("data", [])
|
||||
return [item["embedding"] for item in data]
|
||||
|
||||
|
||||
def _sse(event: dict) -> str:
|
||||
return f"data: {json.dumps(event)}\n\n"
|
||||
|
||||
|
||||
@router.post("/run")
|
||||
def run_embed_bench(req: RunRequest) -> StreamingResponse:
|
||||
"""Embed corpus + queries with each model; stream SSE results."""
|
||||
global _RUN_ACTIVE
|
||||
|
||||
if _RUN_ACTIVE:
|
||||
raise HTTPException(409, "An embedding benchmark run is already active")
|
||||
|
||||
ollama = req.ollama_url or _ollama_url()
|
||||
|
||||
def _generate():
|
||||
global _RUN_ACTIVE
|
||||
_RUN_ACTIVE = True
|
||||
try:
|
||||
for model_idx, model in enumerate(req.models, start=1):
|
||||
yield _sse({
|
||||
"type": "progress",
|
||||
"msg": f"Indexing corpus with {model} ({model_idx}/{len(req.models)})...",
|
||||
})
|
||||
try:
|
||||
corpus_vecs = _embed_texts(ollama, model, req.corpus)
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"})
|
||||
continue
|
||||
|
||||
yield _sse({
|
||||
"type": "progress",
|
||||
"msg": f"Running queries with {model}...",
|
||||
})
|
||||
|
||||
for q_idx, query in enumerate(req.queries):
|
||||
try:
|
||||
q_vecs = _embed_texts(ollama, model, [query])
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"})
|
||||
continue
|
||||
q_vec = q_vecs[0]
|
||||
scored = sorted(
|
||||
[
|
||||
{"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)}
|
||||
for i, (chunk, cv) in enumerate(zip(req.corpus, corpus_vecs))
|
||||
],
|
||||
key=lambda h: h["score"],
|
||||
reverse=True,
|
||||
)[: req.top_k]
|
||||
yield _sse({
|
||||
"type": "result",
|
||||
"query_idx": q_idx,
|
||||
"query": query,
|
||||
"model": model,
|
||||
"hits": scored,
|
||||
})
|
||||
|
||||
yield _sse({"type": "done"})
|
||||
finally:
|
||||
_RUN_ACTIVE = False
|
||||
|
||||
return StreamingResponse(_generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# ── POST /rate ────────────────────────────────────────────────────────────────
|
||||
|
||||
_VALID_RATINGS = {"relevant", "not_relevant"}
|
||||
|
||||
|
||||
class RatingRequest(BaseModel):
|
||||
query: str
|
||||
model: str
|
||||
chunk_text: str
|
||||
chunk_idx: int
|
||||
rating: str
|
||||
|
||||
@field_validator("rating")
|
||||
@classmethod
|
||||
def rating_valid(cls, v: str) -> str:
|
||||
if v not in _VALID_RATINGS:
|
||||
raise ValueError(f"rating must be one of {_VALID_RATINGS}")
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/rate")
|
||||
def rate_result(req: RatingRequest) -> dict:
|
||||
"""Append one rating to the JSONL ratings file."""
|
||||
entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"query": req.query,
|
||||
"model": req.model,
|
||||
"chunk_idx": req.chunk_idx,
|
||||
"chunk_text": req.chunk_text,
|
||||
"rating": req.rating,
|
||||
}
|
||||
path = _ratings_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(entry) + "\n")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||
|
||||
_CSV_FIELDS = ["timestamp", "query", "model", "chunk_idx", "chunk_text", "rating"]
|
||||
|
||||
|
||||
@router.get("/export")
|
||||
def export_ratings(format: str = "csv") -> Any:
|
||||
"""Download ratings as CSV or JSON."""
|
||||
path = _ratings_path()
|
||||
rows: list[dict] = []
|
||||
if path.exists():
|
||||
for raw in path.read_text(encoding="utf-8").splitlines():
|
||||
raw = raw.strip()
|
||||
if raw:
|
||||
try:
|
||||
rows.append(json.loads(raw))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
if format == "json":
|
||||
content = json.dumps(rows, ensure_ascii=False, indent=2)
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
|
||||
)
|
||||
|
||||
# Default: CSV
|
||||
buf = io.StringIO()
|
||||
writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
return StreamingResponse(
|
||||
iter([buf.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
|
||||
)
|
||||
|
|
@ -1,158 +1,9 @@
|
|||
"""Avocet — IMAP fetch utilities.
|
||||
|
||||
Shared between app/api.py (FastAPI SSE endpoint) and the label UI.
|
||||
No Streamlit imports here — stdlib + imaplib only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from typing import Any, Iterator
|
||||
|
||||
from app.utils import extract_body, strip_html # noqa: F401 (strip_html re-exported for callers)
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
|
||||
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
|
||||
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
|
||||
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
|
||||
host = acc.get("host", "")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc.get("username", "")
|
||||
password = acc.get("password", "")
|
||||
folder = acc.get("folder", "INBOX")
|
||||
if not host or not username or not password:
|
||||
return False, "Host, username, and password are all required.", None
|
||||
try:
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
_, data = conn.select(folder, readonly=True)
|
||||
count_raw = data[0].decode() if data and data[0] else "0"
|
||||
count = int(count_raw) if count_raw.isdigit() else 0
|
||||
conn.logout()
|
||||
return True, f"Connected — {count:,} message(s) in {folder}.", count
|
||||
except Exception as exc:
|
||||
return False, str(exc), None
|
||||
|
||||
|
||||
def fetch_account_stream(
|
||||
acc: dict,
|
||||
days_back: int,
|
||||
limit: int,
|
||||
known_keys: set[str],
|
||||
) -> Iterator[dict]:
|
||||
"""Generator — yields progress dicts while fetching emails via IMAP.
|
||||
|
||||
Mutates `known_keys` in place for cross-account dedup within one fetch session.
|
||||
|
||||
Yields event dicts with "type" key:
|
||||
{"type": "start", "account": str, "total_uids": int}
|
||||
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
|
||||
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
|
||||
"""
|
||||
name = acc.get("name", acc.get("username", "?"))
|
||||
host = acc.get("host", "imap.gmail.com")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc["username"]
|
||||
password = acc["password"]
|
||||
folder = acc.get("folder", "INBOX")
|
||||
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
conn.select(folder, readonly=True)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
uids = list(seen_uids.keys())[: limit * 3]
|
||||
yield {"type": "start", "account": name, "total_uids": len(uids)}
|
||||
|
||||
emails: list[dict] = []
|
||||
skipped = 0
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if i % 5 == 0:
|
||||
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = extract_body(msg)[:800]
|
||||
entry = {"subject": subj, "body": body, "from_addr": from_addr,
|
||||
"date": date, "account": name}
|
||||
k = entry_key(entry)
|
||||
if k not in known_keys:
|
||||
known_keys.add(k)
|
||||
emails.append(entry)
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception:
|
||||
skipped += 1
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
|
||||
"emails": emails}
|
||||
"""Backward-compat shim -- logic moved to app/data/fetch.py."""
|
||||
import imaplib # noqa: F401 -- re-exported so existing patch("app.imap_fetch.imaplib...") calls still work
|
||||
from app.data.fetch import ( # noqa: F401
|
||||
entry_key,
|
||||
fetch_account_stream,
|
||||
test_connection,
|
||||
_decode_str,
|
||||
_WIDE_TERMS,
|
||||
)
|
||||
|
|
|
|||
627
app/imitate.py
627
app/imitate.py
|
|
@ -1,624 +1,3 @@
|
|||
"""Avocet — Imitate tab API.
|
||||
|
||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||
pushed into the SFT corrections queue for human review.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||
|
||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_imitate_config() -> dict:
|
||||
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse imitate config %s: %s", f, exc)
|
||||
return {}
|
||||
return raw.get("imitate", {}) or {}
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
|
||||
"""Read cforch section for ollama_url fallback."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
return {}
|
||||
return raw.get("cforch", {}) or {}
|
||||
|
||||
|
||||
def _ollama_url(cfg: dict) -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
|
||||
|
||||
|
||||
def _cforch_url() -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cforch.get("coordinator_url") or "http://localhost:7700"
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
||||
"""Fetch the live cf-text catalog from cf-orch.
|
||||
|
||||
Filters out proxy entries (ollama://, vllm://, http://) — those models are
|
||||
served by their own services and should not be allocated via cf-text.
|
||||
Returns only models with real file-system paths that cf-text can load directly.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/catalog",
|
||||
params={"node_id": "heimdall"},
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
result = []
|
||||
for model_id, entry in raw.items():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
path = entry.get("path", "")
|
||||
# Skip proxy entries — they're routed through other services
|
||||
if "://" in path:
|
||||
continue
|
||||
result.append({
|
||||
"id": model_id,
|
||||
"vram_mb": entry.get("vram_mb", 0),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch cf-orch catalog: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
|
||||
"""Fetch JSON from url; raise URLError on failure."""
|
||||
req = Request(url, headers={"Accept": "application/json"})
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
|
||||
|
||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
|
||||
"""Return True if the product's health endpoint responds OK."""
|
||||
try:
|
||||
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
|
||||
return bool(data)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_sample(
|
||||
raw: Any,
|
||||
text_fields: list[str],
|
||||
sample_index: int = 0,
|
||||
sample_key: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Pull one item from a list or dict response and extract text_fields.
|
||||
|
||||
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
||||
Falls back to a set of conventional envelope keys if sample_key is absent.
|
||||
"""
|
||||
item: dict[str, Any]
|
||||
if isinstance(raw, list):
|
||||
if not raw:
|
||||
return {}
|
||||
item = raw[min(sample_index, len(raw) - 1)]
|
||||
elif isinstance(raw, dict):
|
||||
# Use declared sample_key first, then fall back to conventional names.
|
||||
_ENVELOPE_KEYS = (
|
||||
"samples", "items", "results", "data", "jobs", "listings",
|
||||
"pantry", "saved_searches", "entries", "calls", "records",
|
||||
)
|
||||
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
||||
for key in search_keys:
|
||||
if key in raw and isinstance(raw[key], list):
|
||||
lst = raw[key]
|
||||
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
||||
break
|
||||
else:
|
||||
item = raw
|
||||
else:
|
||||
return {}
|
||||
|
||||
parts = []
|
||||
for field in text_fields:
|
||||
val = item.get(field)
|
||||
if val and str(val).strip():
|
||||
parts.append(f"**{field}**: {val}")
|
||||
return {"item": item, "text": "\n\n".join(parts)}
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _fetch_image_b64(image_url: str) -> str:
|
||||
"""Download an image URL and return it as a base64 string for ollama.
|
||||
|
||||
Returns empty string on any failure — a missing image is non-fatal;
|
||||
the model will still run against the text prompt alone.
|
||||
"""
|
||||
try:
|
||||
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
|
||||
with urlopen(req, timeout=10) as resp:
|
||||
return base64.b64encode(resp.read()).decode("ascii")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch image %s: %s", image_url, exc)
|
||||
return ""
|
||||
|
||||
|
||||
def _run_ollama_streaming(
|
||||
ollama_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
system: str = "",
|
||||
images: list[str] | None = None,
|
||||
) -> tuple[str, int]:
|
||||
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
|
||||
|
||||
Blocks until the model finishes; yields nothing — streaming is handled by
|
||||
the SSE generator in run_imitate().
|
||||
|
||||
system: optional system prompt passed as a separate field to ollama.
|
||||
images: list of base64-encoded image strings (vision models only).
|
||||
"""
|
||||
url = f"{ollama_base.rstrip('/')}/api/generate"
|
||||
body: dict = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
if system:
|
||||
body["system"] = system
|
||||
if images:
|
||||
body["images"] = images
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
req = Request(url, data=payload, method="POST",
|
||||
headers={"Content-Type": "application/json"})
|
||||
t0 = time.time()
|
||||
try:
|
||||
with urlopen(req, timeout=120) as resp:
|
||||
body = json.loads(resp.read().decode("utf-8"))
|
||||
elapsed = int((time.time() - t0) * 1000)
|
||||
return body.get("response", ""), elapsed
|
||||
except Exception as exc:
|
||||
elapsed = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _run_cftext(
|
||||
cforch_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str,
|
||||
temperature: float,
|
||||
startup_timeout_s: float = 180.0,
|
||||
) -> tuple[str, int, bool]:
|
||||
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
|
||||
|
||||
Raises RuntimeError on allocation failure or generation error.
|
||||
cold_started=True means the service was launched from scratch (caller may log this).
|
||||
|
||||
Cold-start detection uses coordinator state signals (running/stopped) rather than
|
||||
polling the service health endpoint — this fails fast on model load errors instead
|
||||
of waiting out the full timeout.
|
||||
"""
|
||||
# Allocate
|
||||
alloc_resp = httpx.post(
|
||||
f"{cforch_base}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "imitate",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
alloc_resp.raise_for_status()
|
||||
data = alloc_resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
cold_started = data.get("started", False) and not data.get("warm", True)
|
||||
|
||||
# Wait for ready using coordinator state signals
|
||||
if cold_started:
|
||||
deadline = time.monotonic() + startup_timeout_s
|
||||
probe_misses = 0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
status = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
|
||||
)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
break
|
||||
elif state == "stopped":
|
||||
if allocation_id:
|
||||
httpx.delete(
|
||||
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
|
||||
timeout=5.0,
|
||||
)
|
||||
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
|
||||
else:
|
||||
probe_misses += 1
|
||||
if probe_misses >= 6:
|
||||
# Coordinator hasn't registered instance yet — fall back to health poll
|
||||
try:
|
||||
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2.0)
|
||||
else:
|
||||
if allocation_id:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
|
||||
|
||||
# Generate
|
||||
messages: list[dict] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
gen_resp = httpx.post(
|
||||
f"{service_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": 300,
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
gen_resp.raise_for_status()
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
content = gen_resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms, cold_started
|
||||
except Exception as exc:
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
finally:
|
||||
if allocation_id:
|
||||
try:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/products")
|
||||
def get_products() -> dict:
|
||||
"""List configured CF products with live online status."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
products = []
|
||||
for p in products_raw:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
base_url = p.get("base_url", "")
|
||||
products.append({
|
||||
"id": p.get("id", ""),
|
||||
"name": p.get("name", ""),
|
||||
"icon": p.get("icon", "📦"),
|
||||
"description": p.get("description", ""),
|
||||
"base_url": base_url,
|
||||
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
||||
})
|
||||
return {"products": products}
|
||||
|
||||
|
||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||
|
||||
@router.get("/products/{product_id}/sample")
|
||||
def get_sample(product_id: str, index: int = 0) -> dict:
|
||||
"""Fetch a real sample from the given product's API."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
|
||||
product: dict | None = None
|
||||
for p in products_raw:
|
||||
if isinstance(p, dict) and p.get("id") == product_id:
|
||||
product = p
|
||||
break
|
||||
|
||||
if product is None:
|
||||
raise HTTPException(404, f"Product '{product_id}' not in config")
|
||||
|
||||
base_url = product.get("base_url", "").rstrip("/")
|
||||
endpoint = product.get("sample_endpoint", "")
|
||||
if not base_url or not endpoint:
|
||||
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
||||
|
||||
url = f"{base_url}{endpoint}"
|
||||
try:
|
||||
raw = _http_get_json(url, timeout=5)
|
||||
except URLError as exc:
|
||||
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
||||
|
||||
text_fields = product.get("text_fields", []) or []
|
||||
sample_key = product.get("sample_key") or None
|
||||
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
|
||||
if not extracted:
|
||||
raise HTTPException(404, "No sample items returned by product API")
|
||||
|
||||
prompt_template = product.get("prompt_template", "{text}")
|
||||
prompt = prompt_template.replace("{text}", extracted["text"])
|
||||
# Also substitute any {field_name} placeholders from the raw item fields.
|
||||
item = extracted.get("item", {})
|
||||
for field, val in item.items():
|
||||
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
|
||||
|
||||
# Expose system_prompt and image_url if the product API returns them.
|
||||
# system_prompt: Peregrine, Snipe (vision analysis instructions)
|
||||
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
|
||||
item = extracted.get("item", {})
|
||||
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
|
||||
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
|
||||
|
||||
return {
|
||||
"product_id": product_id,
|
||||
"sample_index": index,
|
||||
"text": extracted["text"],
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt,
|
||||
"image_url": image_url,
|
||||
"raw_item": item,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
|
||||
def get_catalog() -> dict:
|
||||
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
||||
models = _cforch_catalog(_cforch_url())
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_imitate(
|
||||
prompt: str = "",
|
||||
model_ids: str = "", # comma-separated ollama model IDs
|
||||
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
||||
temperature: float = 0.7,
|
||||
product_id: str = "",
|
||||
system: str = "", # optional system prompt
|
||||
image_url: str = "", # optional image URL for vision models
|
||||
) -> StreamingResponse:
|
||||
"""Run a prompt through selected ollama models and stream results as SSE.
|
||||
|
||||
If image_url is provided, the image is downloaded once and passed to every
|
||||
model as a base64-encoded blob — allowing vision-capable local models to
|
||||
evaluate listing photos the same way Snipe's background task pipeline does.
|
||||
"""
|
||||
|
||||
if not prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
|
||||
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
||||
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
||||
if not ollama_ids and not cftext_ids:
|
||||
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
|
||||
|
||||
cfg = _load_imitate_config()
|
||||
ollama_base = _ollama_url(cfg)
|
||||
cforch_base = _cforch_url()
|
||||
system_ctx = system.strip() or ""
|
||||
total_models = len(ollama_ids) + len(cftext_ids)
|
||||
|
||||
# Download image once before streaming — shared across ollama vision models
|
||||
images: list[str] = []
|
||||
if image_url.strip():
|
||||
b64 = _fetch_image_b64(image_url.strip())
|
||||
if b64:
|
||||
images = [b64]
|
||||
|
||||
def generate():
|
||||
results: list[dict] = []
|
||||
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
|
||||
|
||||
# Ollama models
|
||||
for model_id in ollama_ids:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
|
||||
try:
|
||||
response, elapsed_ms = _run_ollama_streaming(
|
||||
ollama_base, model_id, prompt, temperature,
|
||||
system=system_ctx, images=images or None,
|
||||
)
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
||||
if cftext_ids:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Announce all models upfront so the UI can show loading states immediately
|
||||
for model_id in cftext_ids:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
|
||||
future_to_model = {
|
||||
pool.submit(_run_cftext, cforch_base, mid, prompt, system_ctx, temperature): mid
|
||||
for mid in cftext_ids
|
||||
}
|
||||
for future in as_completed(future_to_model):
|
||||
model_id = future_to_model[future]
|
||||
try:
|
||||
response, elapsed_ms, cold_started = future.result()
|
||||
if cold_started:
|
||||
yield _sse({"type": "model_coldstart", "model": model_id})
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
yield _sse({"type": "complete", "results": results})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
|
||||
class ImitateResult(BaseModel):
|
||||
model: str
|
||||
response: str
|
||||
elapsed_ms: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PushCorrectionsRequest(BaseModel):
|
||||
product_id: str
|
||||
prompt: str
|
||||
results: list[ImitateResult]
|
||||
|
||||
|
||||
@router.post("/push-corrections")
|
||||
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
||||
"""Append imitate results to sft_candidates.jsonl for human review."""
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
if not req.results:
|
||||
raise HTTPException(422, "results list is empty")
|
||||
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
records = []
|
||||
for r in req.results:
|
||||
if r.error or not r.response.strip():
|
||||
continue
|
||||
records.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"source": "imitate",
|
||||
"product_id": req.product_id,
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": r.response,
|
||||
"model_id": r.model,
|
||||
"elapsed_ms": r.elapsed_ms,
|
||||
"status": "pending",
|
||||
"created_at": ts,
|
||||
})
|
||||
|
||||
if not records:
|
||||
raise HTTPException(422, "No non-error results to push")
|
||||
|
||||
dest = _candidates_file()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
for record in records:
|
||||
append_jsonl(dest, record)
|
||||
|
||||
return {"pushed": len(records)}
|
||||
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
|
||||
from app.data.imitate import router # noqa: F401
|
||||
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401
|
||||
|
|
|
|||
163
app/models.py
163
app/models.py
|
|
@ -15,6 +15,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -37,13 +38,17 @@ except ImportError: # pragma: no cover
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_MODELS_DIR: Path = Path(
|
||||
os.environ.get("AVOCET_MODELS_DIR", str(_ROOT / "models"))
|
||||
)
|
||||
_QUEUE_DIR: Path = _ROOT / "data"
|
||||
|
||||
# Service-specific model destinations.
|
||||
# cf-text models land on the NFS-mounted shared asset store so every cluster
|
||||
# node can reach them without a separate download. Avocet classifiers stay local
|
||||
# because they are fine-tuned in-place and are only consumed by avocet itself.
|
||||
# node can reach them without a separate download. Avocet classifiers default
|
||||
# to a local path but can be redirected via AVOCET_MODELS_DIR — set this to
|
||||
# /Library/Assets/LLM/avocet/models on NFS-connected nodes to keep all model
|
||||
# weights out of the repo directory.
|
||||
# Override via CF_TEXT_MODELS_DIR env var (useful for dev / non-NFS setups).
|
||||
_CF_TEXT_MODELS_DIR: Path = Path(
|
||||
os.environ.get("CF_TEXT_MODELS_DIR", "/Library/Assets/LLM/cf-text/models")
|
||||
|
|
@ -60,6 +65,30 @@ _CF_ORCH_PROFILES_DIR: Path = Path(
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# ── HuggingFace auth ─────────────────────────────────────────────────────────
|
||||
|
||||
def _get_hf_token() -> str | None:
|
||||
"""Return HF token from label_tool.yaml, then HF_TOKEN / HUGGING_FACE_HUB_TOKEN env vars."""
|
||||
config_file = _ROOT / "config" / "label_tool.yaml"
|
||||
if config_file.exists():
|
||||
try:
|
||||
import yaml as _yaml
|
||||
raw = _yaml.safe_load(config_file.read_text(encoding="utf-8")) or {}
|
||||
token = (raw.get("hf_token") or raw.get("cforch", {}).get("hf_token") or "").strip()
|
||||
if token:
|
||||
return token
|
||||
except Exception:
|
||||
pass
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None
|
||||
|
||||
|
||||
# ── GGUF quantization detection ───────────────────────────────────────────────
|
||||
# Matches quant identifiers in GGUF filenames: Q4_K_M, Q8_0, F16, IQ3_M, etc.
|
||||
_QUANT_RE = re.compile(
|
||||
r'[._-]((?:IQ\d|Q\d)[A-Z0-9_]*|F16|BF16)\.gguf$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ── Download progress shared state ────────────────────────────────────────────
|
||||
# Updated by the background download thread; read by GET /download/stream.
|
||||
_download_progress: dict[str, Any] = {}
|
||||
|
|
@ -91,12 +120,16 @@ _TAG_TO_INFO: dict[str, _TagInfo] = {
|
|||
"audio-classification": {"adapter": None, "role": "classifier", "service": "cf-voice"},
|
||||
# TTS — cf-tts text-to-speech service
|
||||
"text-to-speech": {"adapter": None, "role": "tts", "service": "cf-tts"},
|
||||
# Vision — cf-vision image classification / embedding / VLM service
|
||||
# Vision classifiers / embedders — cf-vision (SigLIP/CLIP-style models)
|
||||
"image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
|
||||
"zero-shot-image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
|
||||
"image-feature-extraction": {"adapter": None, "role": "embedding", "service": "cf-vision"},
|
||||
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "cf-vision"},
|
||||
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "cf-vision"},
|
||||
# Generative VLMs (image+text → text) — GGUF quants run via llama.cpp (cf-text).
|
||||
# cf-vision is a classifier/embedder service; generative VLMs like Qwen2-VL
|
||||
# and LLaVA accept image inputs but are textgen at the backend level.
|
||||
# Full-precision HF-format VLMs would use vllm, but our fleet uses GGUF quants.
|
||||
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "cf-text"},
|
||||
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "cf-text"},
|
||||
# Image generation — cf-image (text → image; distinct from cf-vision image understanding)
|
||||
"text-to-image": {"adapter": None, "role": "image-gen", "service": "cf-image"},
|
||||
# Embedding — cf-core shared embedding layer
|
||||
|
|
@ -111,6 +144,11 @@ def set_models_dir(path: Path) -> None:
|
|||
_MODELS_DIR = path
|
||||
|
||||
|
||||
def set_cf_text_models_dir(path: Path) -> None:
|
||||
global _CF_TEXT_MODELS_DIR
|
||||
_CF_TEXT_MODELS_DIR = path
|
||||
|
||||
|
||||
def set_queue_dir(path: Path) -> None:
|
||||
global _QUEUE_DIR
|
||||
_QUEUE_DIR = path
|
||||
|
|
@ -197,8 +235,15 @@ def _catalog_key(repo_id: str) -> str:
|
|||
|
||||
ibm-granite/granite-4.1-8b → granite-4.1-8b
|
||||
facebook/bart-large-cnn → bart-large-cnn
|
||||
WithinUsAI/Opus4.7-GODs.Ghost.Codex-4B.GGuF → opus4.7-gods.ghost.codex-4b
|
||||
|
||||
The coordinator skips catalog lookup for keys ending in ".gguf" (treats them
|
||||
as direct file paths). Strip the suffix so GGUF repo names produce valid keys.
|
||||
"""
|
||||
return repo_id.split("/", 1)[-1].lower()
|
||||
key = repo_id.split("/", 1)[-1].lower()
|
||||
if key.endswith(".gguf"):
|
||||
key = key[:-5]
|
||||
return key
|
||||
|
||||
|
||||
def _insert_catalog_entry(content: str, entry_lines: str) -> str:
|
||||
|
|
@ -290,6 +335,15 @@ def _register_in_node_catalogs(
|
|||
max_mb: int = cf_text.get("max_mb", 0)
|
||||
catalog: dict = cf_text.get("catalog") or {}
|
||||
|
||||
# If the node has a different local model dir, remap the NFS path.
|
||||
model_base = cf_text.get("model_base_path", "").rstrip("/")
|
||||
if model_base:
|
||||
nfs_base = str(_CF_TEXT_MODELS_DIR).rstrip("/")
|
||||
model_name = local_path.name
|
||||
effective_path_str = f"{model_base}/{model_name}"
|
||||
else:
|
||||
effective_path_str = local_path_str
|
||||
|
||||
# Skip if key already exists
|
||||
if model_key in catalog:
|
||||
logger.debug("Key %r already in %s — skipping", model_key, yaml_file.name)
|
||||
|
|
@ -301,10 +355,10 @@ def _register_in_node_catalogs(
|
|||
for entry in catalog.values()
|
||||
if isinstance(entry, dict)
|
||||
}
|
||||
if local_path_str in registered_paths or any(
|
||||
p.startswith(local_path_str + "/") for p in registered_paths
|
||||
if effective_path_str in registered_paths or any(
|
||||
p.startswith(effective_path_str + "/") for p in registered_paths
|
||||
):
|
||||
logger.debug("Path %s already registered in %s — skipping", local_path_str, yaml_file.name)
|
||||
logger.debug("Path %s already registered in %s — skipping", effective_path_str, yaml_file.name)
|
||||
continue
|
||||
|
||||
# Determine whether model fits at FP16 or needs 4-bit
|
||||
|
|
@ -330,12 +384,18 @@ def _register_in_node_catalogs(
|
|||
if needs_4bit
|
||||
else f" # FP16 file-size estimate"
|
||||
)
|
||||
env_block = (
|
||||
f" env:\n"
|
||||
f" CF_TEXT_4BIT: \"1\"\n"
|
||||
if needs_4bit else ""
|
||||
)
|
||||
entry_block = (
|
||||
f" # auto-registered by avocet on download\n"
|
||||
f" {model_key}:\n"
|
||||
f" path: {local_path_str}\n"
|
||||
f" path: {effective_path_str}\n"
|
||||
f" vram_mb: {vram_for_node}{vram_comment}\n"
|
||||
f" description: \"{desc}\"\n"
|
||||
f"{env_block}"
|
||||
)
|
||||
|
||||
new_content = _insert_catalog_entry(content, entry_block)
|
||||
|
|
@ -388,12 +448,17 @@ def _run_download(
|
|||
role: str | None = None,
|
||||
service: str | None = None,
|
||||
model_size_bytes: int = 0,
|
||||
quant_pattern: str | None = None,
|
||||
) -> None:
|
||||
"""Background thread: download model via huggingface_hub.snapshot_download.
|
||||
|
||||
model_size_bytes is the sum of file sizes reported by the HF API (siblings).
|
||||
It is used to estimate vram_mb and written to model_info.json so cf-orch can
|
||||
budget VRAM when allocating a cf-text instance for this model.
|
||||
|
||||
quant_pattern: when set, restricts snapshot_download to only files matching
|
||||
*{quant_pattern}*.gguf (plus metadata). Avoids downloading every quant variant
|
||||
from GGUF-only repos like bartowski/*.
|
||||
"""
|
||||
global _download_progress
|
||||
local_dir = _model_dir_for(repo_id, service)
|
||||
|
|
@ -422,10 +487,20 @@ def _run_download(
|
|||
|
||||
local_dir.mkdir(parents=True, exist_ok=True)
|
||||
poll_thread.start()
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
local_dir=str(local_dir),
|
||||
)
|
||||
|
||||
dl_kwargs: dict[str, Any] = {"repo_id": repo_id, "local_dir": str(local_dir)}
|
||||
hf_token = _get_hf_token()
|
||||
if hf_token:
|
||||
dl_kwargs["token"] = hf_token
|
||||
if quant_pattern:
|
||||
# Include both cases: repos use mixed conventions (Q6_K vs q6_k).
|
||||
dl_kwargs["allow_patterns"] = [
|
||||
f"*{quant_pattern.upper()}*.gguf",
|
||||
f"*{quant_pattern.lower()}*.gguf",
|
||||
"*.json",
|
||||
"README.md",
|
||||
]
|
||||
snapshot_download(**dl_kwargs)
|
||||
|
||||
# Estimate VRAM from reported file size.
|
||||
# HF siblings sizes are pre-quantisation file sizes; add 10% for KV cache
|
||||
|
|
@ -531,9 +606,31 @@ def lookup_model(repo_id: str) -> dict:
|
|||
)
|
||||
logger.warning("Unsupported pipeline_tag %r for %s", pipeline_tag, repo_id)
|
||||
|
||||
# Estimate model size from siblings list
|
||||
# Detect GGUF files and parse quant names from siblings list.
|
||||
# For GGUF-only repos (bartowski, TheBloke, etc.) this lets the UI show
|
||||
# a per-quant size picker instead of downloading every variant.
|
||||
siblings = data.get("siblings") or []
|
||||
model_size_bytes: int = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
|
||||
gguf_files: list[dict] = []
|
||||
for s in siblings:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
fname: str = s.get("rfilename", "")
|
||||
if not fname.lower().endswith(".gguf"):
|
||||
continue
|
||||
m = _QUANT_RE.search(fname)
|
||||
gguf_files.append({
|
||||
"filename": fname,
|
||||
"size": s.get("size", 0) or 0,
|
||||
"quant_name": m.group(1).upper() if m else None,
|
||||
})
|
||||
gguf_files.sort(key=lambda f: f["size"])
|
||||
|
||||
# model_size_bytes: total of all siblings (for non-GGUF repos) or all GGUFs only.
|
||||
# For GGUF repos the frontend will substitute the selected quant's size on submit.
|
||||
if gguf_files:
|
||||
model_size_bytes: int = sum(f["size"] for f in gguf_files)
|
||||
else:
|
||||
model_size_bytes = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
|
||||
|
||||
# Description: first 300 chars of card data (modelId field used as fallback)
|
||||
card_data = data.get("cardData") or {}
|
||||
|
|
@ -549,6 +646,7 @@ def lookup_model(repo_id: str) -> dict:
|
|||
"compatible": compatible,
|
||||
"warning": warning,
|
||||
"model_size_bytes": model_size_bytes,
|
||||
"gguf_files": gguf_files if gguf_files else None,
|
||||
"description": description,
|
||||
"tags": data.get("tags") or [],
|
||||
"downloads": data.get("downloads") or 0,
|
||||
|
|
@ -579,6 +677,9 @@ class QueueAddRequest(BaseModel):
|
|||
# Stored in the queue entry so approve can pass it to _run_download
|
||||
# without a second HF API round-trip.
|
||||
model_size_bytes: int = 0
|
||||
# GGUF quantization pattern (e.g. "Q5_K_M"). When set, snapshot_download
|
||||
# restricts to *{quant_pattern}*.gguf instead of fetching all variants.
|
||||
quant_pattern: str | None = None
|
||||
|
||||
|
||||
@router.post("/queue", status_code=201)
|
||||
|
|
@ -597,6 +698,7 @@ def add_to_queue(req: QueueAddRequest) -> dict:
|
|||
"role": req.role,
|
||||
"service": req.service,
|
||||
"model_size_bytes": req.model_size_bytes,
|
||||
"quant_pattern": req.quant_pattern,
|
||||
"status": "pending",
|
||||
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
|
@ -629,6 +731,7 @@ def approve_queue_entry(entry_id: str) -> dict:
|
|||
entry.get("role"),
|
||||
entry.get("service"),
|
||||
entry.get("model_size_bytes", 0),
|
||||
entry.get("quant_pattern"),
|
||||
),
|
||||
daemon=True,
|
||||
name=f"model-download-{entry_id}",
|
||||
|
|
@ -638,6 +741,32 @@ def approve_queue_entry(entry_id: str) -> dict:
|
|||
return {"ok": True}
|
||||
|
||||
|
||||
# ── PATCH /queue/{id} ─────────────────────────────────────────────────────────
|
||||
|
||||
class QueuePatchRequest(BaseModel):
|
||||
service: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
|
||||
@router.patch("/queue/{entry_id}")
|
||||
def patch_queue_entry(entry_id: str, body: QueuePatchRequest) -> dict:
|
||||
"""Update mutable fields (service, role) on a pending queue entry."""
|
||||
entry = _get_queue_entry(entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
|
||||
if entry.get("status") != "pending":
|
||||
raise HTTPException(409, f"Only pending entries can be patched (current: {entry.get('status')!r})")
|
||||
|
||||
updates: dict = {}
|
||||
if body.service is not None:
|
||||
updates["service"] = body.service
|
||||
if body.role is not None:
|
||||
updates["role"] = body.role
|
||||
|
||||
updated = _update_queue_entry(entry_id, updates)
|
||||
return updated or {}
|
||||
|
||||
|
||||
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.delete("/queue/{entry_id}")
|
||||
|
|
|
|||
535
app/nodes.py
Normal file
535
app/nodes.py
Normal file
|
|
@ -0,0 +1,535 @@
|
|||
"""Avocet — Node Management API.
|
||||
|
||||
Proxies cf-orch coordinator and agent APIs to expose per-node GPU state,
|
||||
service affinity management, and Ollama model management.
|
||||
|
||||
Config is read from label_tool.yaml under the `cforch:` key.
|
||||
The `profiles_dir` key (new) points to the cf-orch node profile YAML directory.
|
||||
|
||||
Module-level globals follow the set_config_dir() testability pattern from cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section. Returns empty dict on missing or parse error."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse config %s: %s", f, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _profiles_dir() -> Path | None:
|
||||
"""Return the cf-orch node profiles directory, or None if not configured."""
|
||||
cfg = _load_config()
|
||||
pd = cfg.get("profiles_dir", "") or ""
|
||||
if pd:
|
||||
return Path(pd)
|
||||
bench = cfg.get("bench_script", "") or ""
|
||||
if bench:
|
||||
return Path(bench).parent.parent / "profiles" / "nodes"
|
||||
return None
|
||||
|
||||
|
||||
def _profile_path(node_id: str) -> Path | None:
|
||||
"""Return the path to a node's profile YAML, or None if profiles_dir is unknown."""
|
||||
pd = _profiles_dir()
|
||||
if pd is None:
|
||||
return None
|
||||
return pd / f"{node_id}.yaml"
|
||||
|
||||
|
||||
def _load_profile(node_id: str) -> dict | None:
|
||||
"""Load and parse a node profile YAML. Returns None if not found or malformed."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
return None
|
||||
try:
|
||||
return yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Malformed profile YAML %s: %s", p, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _get_ollama_url(node_id: str) -> str:
|
||||
"""Derive Ollama URL from the node profile's agent_url (same host, port 11434)."""
|
||||
profile = _load_profile(node_id)
|
||||
if profile:
|
||||
nodes_section = profile.get("nodes", {}) or {}
|
||||
node_entry = nodes_section.get(node_id, {}) or {}
|
||||
agent_url = node_entry.get("agent_url", "") or ""
|
||||
if agent_url:
|
||||
parsed = urlparse(agent_url)
|
||||
return f"{parsed.scheme}://{parsed.hostname}:11434"
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Cannot determine Ollama URL for node {node_id}: no agent_url in profile",
|
||||
)
|
||||
|
||||
|
||||
# ── Endpoints ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
|
||||
def list_nodes() -> list:
|
||||
"""Return all nodes with live GPU stats merged with profile YAML."""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
if not coordinator_url:
|
||||
return []
|
||||
|
||||
try:
|
||||
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
r.raise_for_status()
|
||||
coord_nodes: list[dict] = r.json().get("nodes", [])
|
||||
except httpx.HTTPError as exc:
|
||||
logger.warning("Coordinator unreachable: %s", exc)
|
||||
return []
|
||||
|
||||
try:
|
||||
sr = httpx.get(f"{coordinator_url}/api/services", timeout=5.0)
|
||||
sr.raise_for_status()
|
||||
services_data: list[dict] = sr.json().get("services", [])
|
||||
except httpx.HTTPError:
|
||||
logger.warning("Services API unreachable for %s, skipping", coordinator_url)
|
||||
services_data = []
|
||||
|
||||
# Build per-node, per-GPU running services map
|
||||
running: dict[str, dict[int, list[str]]] = {}
|
||||
for svc in services_data:
|
||||
nid = svc.get("node_id", "")
|
||||
gid = svc.get("gpu_id")
|
||||
svc_name = svc.get("service", "")
|
||||
if nid and gid is not None and svc_name:
|
||||
running.setdefault(nid, {}).setdefault(gid, []).append(svc_name)
|
||||
|
||||
result = []
|
||||
for node in coord_nodes:
|
||||
node_id = node.get("node_id", "") or node.get("id", "")
|
||||
profile = _load_profile(node_id) if node_id else None
|
||||
profile_loaded = profile is not None
|
||||
|
||||
gpus = []
|
||||
for gpu in (node.get("gpus", []) or []):
|
||||
gpu_id = gpu.get("gpu_id", gpu.get("id", 0))
|
||||
services_assigned: list[str] = []
|
||||
if profile:
|
||||
node_entry = (profile.get("nodes", {}) or {}).get(node_id, {}) or {}
|
||||
for g in (node_entry.get("gpus", []) or []):
|
||||
if isinstance(g, dict) and g.get("id") == gpu_id:
|
||||
services_assigned = g.get("services", []) or []
|
||||
break
|
||||
gpus.append({
|
||||
"gpu_id": gpu_id,
|
||||
"card": gpu.get("card", ""),
|
||||
"vram_total_mb": gpu.get("vram_total_mb", 0),
|
||||
"vram_used_mb": gpu.get("vram_used_mb", 0),
|
||||
"vram_free_mb": gpu.get("vram_free_mb", 0),
|
||||
"temp_c": gpu.get("temp_c"),
|
||||
"utilization_pct": gpu.get("utilization_pct"),
|
||||
"compute_cap": gpu.get("compute_cap"),
|
||||
"services_assigned": services_assigned,
|
||||
"services_running": running.get(node_id, {}).get(gpu_id, []),
|
||||
})
|
||||
|
||||
services_catalog: dict = {}
|
||||
if profile:
|
||||
for svc_name, svc_info in (profile.get("services", {}) or {}).items():
|
||||
catalog = svc_info.get("catalog", {}) or {}
|
||||
services_catalog[svc_name] = {
|
||||
"min_compute_cap": svc_info.get("min_compute_cap", 0.0),
|
||||
"max_mb": svc_info.get("max_mb", 0),
|
||||
"catalog_size": len(catalog),
|
||||
}
|
||||
|
||||
result.append({
|
||||
"node_id": node_id,
|
||||
"online": node.get("online", True),
|
||||
"agent_url": node.get("agent_url", ""),
|
||||
"gpus": gpus,
|
||||
"profile_loaded": profile_loaded,
|
||||
"services_catalog": services_catalog,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/profile")
|
||||
def get_node_profile(node_id: str) -> dict:
|
||||
"""Return the full parsed profile YAML for a node."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
raise HTTPException(404, f"No profile found for node {node_id}")
|
||||
try:
|
||||
data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise HTTPException(500, f"Malformed profile YAML: {exc}")
|
||||
return data
|
||||
|
||||
|
||||
class UpdateServicesRequest(BaseModel):
|
||||
services: list[str]
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
|
||||
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
|
||||
"""Set service assignment for a GPU with compatibility validation, then atomic write."""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
raise HTTPException(404, f"No profile found for node {node_id}")
|
||||
|
||||
try:
|
||||
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise HTTPException(500, f"Malformed profile YAML: {exc}")
|
||||
|
||||
nodes_section = profile.get("nodes", {}) or {}
|
||||
node_entry = nodes_section.get(node_id, {}) or {}
|
||||
gpu_list = node_entry.get("gpus", []) or []
|
||||
|
||||
gpu_entry = next(
|
||||
(g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if gpu_entry is None:
|
||||
raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")
|
||||
|
||||
gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
|
||||
gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
|
||||
services_def = profile.get("services", {}) or {}
|
||||
|
||||
for svc_name in body.services:
|
||||
if svc_name not in services_def:
|
||||
raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
|
||||
svc = services_def[svc_name]
|
||||
min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
|
||||
if gpu_compute_cap < min_cap:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
|
||||
)
|
||||
catalog = svc.get("catalog", {}) or {}
|
||||
min_catalog_vram = (
|
||||
min((m.get("vram_mb", 0) for m in catalog.values()), default=0)
|
||||
if catalog else svc.get("max_mb", 0)
|
||||
)
|
||||
if gpu_vram_mb < min_catalog_vram:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
|
||||
)
|
||||
|
||||
# Immutable update of GPU services list
|
||||
new_gpu_list = [
|
||||
({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
|
||||
for g in gpu_list
|
||||
]
|
||||
new_profile = {
|
||||
**profile,
|
||||
"nodes": {
|
||||
**nodes_section,
|
||||
node_id: {**node_entry, "gpus": new_gpu_list},
|
||||
},
|
||||
}
|
||||
|
||||
# Atomic write: write to .tmp then rename
|
||||
tmp_yaml = Path(str(p) + ".tmp")
|
||||
tmp_yaml.write_text(yaml.dump(new_profile, default_flow_style=False), encoding="utf-8")
|
||||
os.replace(tmp_yaml, p)
|
||||
|
||||
# Trigger coordinator profile reload
|
||||
reloaded = False
|
||||
if coordinator_url:
|
||||
try:
|
||||
rr = httpx.post(
|
||||
f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
|
||||
)
|
||||
reloaded = rr.status_code < 300
|
||||
except Exception as exc:
|
||||
logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)
|
||||
|
||||
return {"ok": True, "reloaded": reloaded, "warnings": []}
|
||||
|
||||
# ── Profile save / generate ────────────────────────────────────────────────────
|
||||
|
||||
class SaveProfileRequest(BaseModel):
|
||||
profile: dict
|
||||
|
||||
|
||||
@router.put("/nodes/{node_id}/profile", status_code=200)
|
||||
def save_profile(node_id: str, body: SaveProfileRequest) -> dict:
|
||||
"""Write a full profile dict to disk as YAML, then trigger coordinator reload."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None:
|
||||
raise HTTPException(500, "profiles_dir not configured in label_tool.yaml")
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = Path(str(p) + ".tmp")
|
||||
tmp.write_text(
|
||||
yaml.dump(body.profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
os.replace(tmp, p)
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
reloaded = False
|
||||
if coordinator_url:
|
||||
try:
|
||||
import httpx
|
||||
rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
|
||||
reloaded = rr.status_code < 300
|
||||
except Exception as exc:
|
||||
logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
|
||||
return {"ok": True, "reloaded": reloaded}
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/profile/generate")
|
||||
def generate_profile(node_id: str) -> dict:
|
||||
"""Return a profile skeleton seeded from coordinator GPU data.
|
||||
|
||||
If a profile already exists, preserves its services section and only
|
||||
refreshes the nodes hardware section. Never writes to disk — the caller
|
||||
must call PUT /profile to persist.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
if not coordinator_url:
|
||||
raise HTTPException(503, "coordinator_url not configured")
|
||||
|
||||
try:
|
||||
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
r.raise_for_status()
|
||||
coord_nodes: list[dict] = r.json().get("nodes", [])
|
||||
except httpx.HTTPError as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}")
|
||||
|
||||
node = next((n for n in coord_nodes if n.get("node_id") == node_id), None)
|
||||
if node is None:
|
||||
raise HTTPException(404, f"Node {node_id!r} not found in coordinator")
|
||||
|
||||
gpus = [
|
||||
{
|
||||
"id": g.get("gpu_id", i),
|
||||
"vram_mb": g.get("vram_total_mb", 0),
|
||||
"compute_cap": g.get("compute_cap", 0.0),
|
||||
"card": g.get("card", g.get("name", "")),
|
||||
"role": "inference",
|
||||
"services": [],
|
||||
}
|
||||
for i, g in enumerate(node.get("gpus", []))
|
||||
]
|
||||
vram_total = max((g["vram_mb"] for g in gpus), default=0)
|
||||
|
||||
existing = _load_profile(node_id) or {}
|
||||
return {
|
||||
"schema_version": existing.get("schema_version", 1),
|
||||
"name": existing.get("name", f"node-{node_id}"),
|
||||
"vram_total_mb": vram_total,
|
||||
"eviction_timeout_s": existing.get("eviction_timeout_s", 10.0),
|
||||
"services": existing.get("services", {}),
|
||||
"nodes": {
|
||||
node_id: {
|
||||
"local_model_root": (
|
||||
(existing.get("nodes", {}) or {})
|
||||
.get(node_id, {})
|
||||
.get("local_model_root", "")
|
||||
),
|
||||
"gpus": gpus,
|
||||
}
|
||||
},
|
||||
"model_size_hints": existing.get("model_size_hints", {}),
|
||||
}
|
||||
|
||||
|
||||
# ── Ollama model management ────────────────────────────────────────────────────
|
||||
|
||||
class PullRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/models/ollama")
|
||||
def list_ollama_models(node_id: str) -> dict:
|
||||
"""Proxy GET {ollama_url}/api/tags for a specific node."""
|
||||
import httpx
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
try:
|
||||
r = httpx.get(f"{ollama_url}/api/tags", timeout=10.0)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as exc:
|
||||
return {"error": str(exc)}
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/models/ollama/pull")
|
||||
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
|
||||
"""Stream Ollama pull progress as SSE events."""
|
||||
import httpx
|
||||
|
||||
if not body.name:
|
||||
raise HTTPException(400, "name is required")
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
|
||||
def stream():
|
||||
try:
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{ollama_url}/api/pull",
|
||||
json={"name": body.name, "stream": True},
|
||||
timeout=300.0,
|
||||
) as resp:
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
yield f"data: {line}\n\n"
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'error': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
|
||||
def delete_ollama_model(node_id: str, name: str) -> dict:
|
||||
"""Proxy DELETE to Ollama for a specific node."""
|
||||
import httpx
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
try:
|
||||
r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
|
||||
if r.status_code == 404:
|
||||
raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
|
||||
r.raise_for_status()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Ollama unreachable: {exc}")
|
||||
|
||||
|
||||
# ── Model deploy (add catalog entry) ──────────────────────────────────────────
|
||||
|
||||
class DeployModelRequest(BaseModel):
|
||||
model_id: str
|
||||
service_type: str
|
||||
vram_mb: int
|
||||
description: str = ""
|
||||
hf_repo: str = ""
|
||||
path: str = "" # explicit path; if empty, constructed from model_base_path + hf_repo slug
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/models/deploy", status_code=200)
|
||||
def deploy_model(node_id: str, body: DeployModelRequest) -> dict:
|
||||
"""Register a model in the node's service catalog.
|
||||
|
||||
Adds (or updates) the catalog entry for body.model_id under the given
|
||||
service_type in the node's profile YAML, then triggers a coordinator reload.
|
||||
Does not download the model — that is the user's responsibility.
|
||||
Returns the resolved path so the caller can see where the model should land.
|
||||
"""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
raise HTTPException(404, f"No profile found for node {node_id!r}")
|
||||
|
||||
try:
|
||||
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise HTTPException(500, f"Malformed profile YAML: {exc}")
|
||||
|
||||
services_def = profile.get("services", {}) or {}
|
||||
svc = services_def.get(body.service_type)
|
||||
if svc is None:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{body.service_type}' not defined in node '{node_id}' profile; "
|
||||
"add it first via the profile editor",
|
||||
)
|
||||
|
||||
# Resolve path: explicit > model_base_path + hf slug > model_id slug
|
||||
model_path = body.path.strip()
|
||||
if not model_path:
|
||||
base = (svc.get("model_base_path", "") or "").rstrip("/")
|
||||
if not base:
|
||||
raise HTTPException(
|
||||
422,
|
||||
f"Service '{body.service_type}' has no model_base_path; supply an explicit path",
|
||||
)
|
||||
slug_src = body.hf_repo.strip() if body.hf_repo.strip() else body.model_id
|
||||
hf_slug = slug_src.replace("/", "--")
|
||||
model_path = f"{base}/{hf_slug}"
|
||||
|
||||
# Immutable catalog update — spread, never mutate
|
||||
entry: dict = {"path": model_path, "vram_mb": body.vram_mb}
|
||||
if body.description:
|
||||
entry["description"] = body.description
|
||||
new_catalog = {**(svc.get("catalog") or {}), body.model_id: entry}
|
||||
new_svc = {**svc, "catalog": new_catalog}
|
||||
new_services = {**services_def, body.service_type: new_svc}
|
||||
new_profile = {**profile, "services": new_services}
|
||||
|
||||
# Atomic write
|
||||
tmp = Path(str(p) + ".tmp")
|
||||
tmp.write_text(
|
||||
yaml.dump(new_profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
os.replace(tmp, p)
|
||||
|
||||
# Trigger coordinator reload
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
reloaded = False
|
||||
if coordinator_url:
|
||||
try:
|
||||
import httpx
|
||||
rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
|
||||
reloaded = rr.status_code < 300
|
||||
except Exception as exc:
|
||||
logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
|
||||
|
||||
return {"ok": True, "reloaded": reloaded, "path": model_path}
|
||||
327
app/plans_bench.py
Normal file
327
app/plans_bench.py
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
"""Avocet — CF planning benchmark integration API.
|
||||
|
||||
Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
|
||||
Connection config (api_base) is read from label_tool.yaml under the
|
||||
`plans_bench:` key (optional; falls back to localhost:8080).
|
||||
|
||||
All endpoints are registered on `router` (FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/plans-bench".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None
|
||||
|
||||
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
|
||||
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
|
||||
# Kept here so the UI can list them without importing the script.
|
||||
|
||||
MODEL_REGISTRY: dict[str, str] = {
|
||||
"deepseek-r1-1.5b": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
|
||||
"deepseek-r1-7b-4bit": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
|
||||
"deepseek-r1-0528-qwen3-8b-gguf": "DeepSeek R1 0528 Qwen3 8B GGUF (4 nodes)",
|
||||
"deepseek-coder-6.7b-4bit": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
|
||||
"granite-4.1-8b": "IBM Granite 4.1 8B, 4-bit (cf-orch catalog key)",
|
||||
"qwen2.5-3b": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
|
||||
"qwen2.5-7b": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
|
||||
"capybarahermes-2.5-mistral-7b-gguf": "CapybaraHermes 2.5 Mistral 7B GGUF (4 nodes)",
|
||||
"darwin-9b-opus-gguf": "Darwin 9B Opus GGUF -- long-form writing (3 nodes)",
|
||||
}
|
||||
|
||||
RUBRIC_LABELS: dict[str, str] = {
|
||||
"task_structure": "Task structure (checkboxes + commits)",
|
||||
"tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
|
||||
"privacy_pillar": "Privacy pillar (local-first, no logging)",
|
||||
"safety_pillar": "Safety pillar (human approval, reversibility)",
|
||||
"accessibility": "Accessibility (ND/adaptive users)",
|
||||
"license_split": "License awareness (MIT vs BSL)",
|
||||
"file_paths": "File paths (plausible project paths)",
|
||||
"cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
|
||||
"length_ok": "Response length (200–2500 words)",
|
||||
}
|
||||
|
||||
|
||||
# ── Testability seam ───────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
f = _config_file()
|
||||
cforch_cfg: dict = {}
|
||||
bench_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
cforch_cfg = raw.get("cforch", {}) or {}
|
||||
bench_cfg = raw.get("plans_bench", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
|
||||
return {
|
||||
"coordinator_url": cforch_cfg.get("coordinator_url",
|
||||
bench_cfg.get("coordinator_url", "http://10.1.10.71:7700")),
|
||||
"python_bin": cforch_cfg.get("python_bin",
|
||||
bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")),
|
||||
}
|
||||
|
||||
|
||||
def _results_file(run_id: str) -> Path:
|
||||
return _RESULTS_DIR / f"{run_id}.json"
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return registered model shortcuts, live cf-orch catalog, and rubric labels."""
|
||||
cfg = _load_config()
|
||||
|
||||
cforch_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
for model_id, entry in resp.json().items():
|
||||
if isinstance(entry, dict):
|
||||
cforch_models.append({
|
||||
"id": model_id,
|
||||
"name": model_id,
|
||||
"vram_mb": entry.get("vram_mb"),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch cf-orch catalog: %s", exc)
|
||||
|
||||
return {
|
||||
"registry": [
|
||||
{"key": k, "description": v}
|
||||
for k, v in MODEL_REGISTRY.items()
|
||||
],
|
||||
"cforch_models": cforch_models,
|
||||
"coordinator_url": cfg["coordinator_url"],
|
||||
"rubric_labels": RUBRIC_LABELS,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_plans_benchmark(
|
||||
models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
|
||||
prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
|
||||
use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
|
||||
api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
|
||||
workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
|
||||
) -> StreamingResponse:
|
||||
"""Spawn benchmark_plans.py and stream stdout as SSE progress events.
|
||||
|
||||
On successful completion emits a `type: result` event with parsed JSON
|
||||
and saves results to data/plans_bench_results/<run_id>.json.
|
||||
"""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if _BENCH_RUNNING:
|
||||
raise HTTPException(409, "A planning benchmark is already running")
|
||||
|
||||
cfg = _load_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
coordinator_url = cfg["coordinator_url"]
|
||||
|
||||
model_keys = [m.strip() for m in models.split(",") if m.strip()]
|
||||
if not model_keys:
|
||||
raise HTTPException(400, "At least one model key is required")
|
||||
|
||||
run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
|
||||
output_path = _results_file(run_id)
|
||||
_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_SCRIPT.exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})}\n\n"
|
||||
return
|
||||
|
||||
cmd = [python_bin, str(_BENCH_SCRIPT)]
|
||||
if len(model_keys) > 1:
|
||||
cmd.extend(["--compare"] + model_keys)
|
||||
else:
|
||||
cmd.extend(["--model", model_keys[0]])
|
||||
|
||||
if use_cforch:
|
||||
cmd.extend(["--cforch", "--cforch-url", coordinator_url])
|
||||
elif api_base.strip():
|
||||
cmd.extend(["--api-base", api_base.strip()])
|
||||
|
||||
cmd.extend(["--verbose", "--output", str(output_path)])
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
|
||||
if prompt_ids.strip():
|
||||
cmd.extend(["--prompts"] + [p.strip() for p in prompt_ids.split(",") if p.strip()])
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_bench_proc = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0 and output_path.exists():
|
||||
try:
|
||||
results = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'run_id': run_id, 'results': results})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read plans benchmark output: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
"""List past planning benchmark runs, newest first."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
runs: list[dict] = []
|
||||
for f in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
|
||||
run_id = f.stem
|
||||
try:
|
||||
data: dict = json.loads(f.read_text(encoding="utf-8"))
|
||||
model_keys = list(data.keys())
|
||||
# Average total_score across all models and prompts
|
||||
all_scores = [
|
||||
r["total_score"]
|
||||
for results in data.values()
|
||||
for r in results
|
||||
if not r.get("error")
|
||||
]
|
||||
avg_score = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
|
||||
except Exception:
|
||||
model_keys = []
|
||||
avg_score = 0.0
|
||||
|
||||
# Parse display date from run_id (plans_2026-04-27_143022)
|
||||
try:
|
||||
date_part = run_id.removeprefix("plans_") # 2026-04-27_143022
|
||||
date, time = date_part.split("_")
|
||||
display_date = f"{date} {time[:2]}:{time[2:4]}"
|
||||
except Exception:
|
||||
display_date = run_id
|
||||
|
||||
runs.append({
|
||||
"run_id": run_id,
|
||||
"filename": f.name,
|
||||
"date": display_date,
|
||||
"models": model_keys,
|
||||
"avg_score": avg_score,
|
||||
})
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
|
||||
def get_latest_results() -> dict:
|
||||
"""Return the most recent planning benchmark results dict."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
files = sorted(_RESULTS_DIR.glob("plans_*.json"))
|
||||
if not files:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(files[-1].read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{run_id}")
|
||||
def get_results_by_run_id(run_id: str) -> dict:
|
||||
"""Return planning benchmark results for a specific run."""
|
||||
if not run_id.startswith("plans_"):
|
||||
raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")
|
||||
f = _results_file(run_id)
|
||||
if not f.exists():
|
||||
raise HTTPException(404, f"Results not found: {run_id}")
|
||||
try:
|
||||
return json.loads(f.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_plans_benchmark() -> dict:
|
||||
"""Kill the running planning benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No planning benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate plans benchmark: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
341
app/sft.py
341
app/sft.py
|
|
@ -1,335 +1,8 @@
|
|||
"""Avocet — SFT candidate import and correction API.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
|
||||
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
|
||||
testability pattern as api.py — override them via set_sft_data_dir() and
|
||||
set_sft_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_SFT_DATA_DIR: Path = _ROOT / "data"
|
||||
_SFT_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────
|
||||
|
||||
def set_sft_data_dir(path: Path) -> None:
|
||||
global _SFT_DATA_DIR
|
||||
_SFT_DATA_DIR = path
|
||||
|
||||
|
||||
def set_sft_config_dir(path: Path | None) -> None:
|
||||
global _SFT_CONFIG_DIR
|
||||
_SFT_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _SFT_CONFIG_DIR is not None:
|
||||
return _SFT_CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||
if d:
|
||||
return Path(d)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _SFT_DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _SFT_DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
|
||||
from app.data.corrections import ( # noqa: F401
|
||||
router,
|
||||
set_data_dir as set_sft_data_dir,
|
||||
set_config_dir as set_sft_config_dir,
|
||||
set_default_bench_results_dir,
|
||||
_DEFAULT_BENCH_RESULTS_DIR,
|
||||
)
|
||||
|
||||
|
||||
# ── GET /runs ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _SFT_DATA_DIR)
|
||||
|
||||
|
||||
# ── GET /queue ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
|
||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── POST /undo ─────────────────────────────────────────────────────────────
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /stats ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /config ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
|
|
|||
0
app/train/__init__.py
Normal file
0
app/train/__init__.py
Normal file
339
app/train/train.py
Normal file
339
app/train/train.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
"""Avocet -- train job queue API.
|
||||
|
||||
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
|
||||
_running_procs dict in api.py with a persistent, inspectable queue.
|
||||
|
||||
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
|
||||
GET /jobs -- list all jobs, newest first
|
||||
POST /jobs -- create a new job
|
||||
GET /jobs/{id} -- get one job by id
|
||||
DELETE /jobs/{id}/cancel -- cancel a queued or running job
|
||||
GET /jobs/{id}/run -- SSE: run the job, stream stdout
|
||||
GET /results -- list completed models with training_info.json metrics
|
||||
|
||||
SQLite schema:
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
|
||||
Testability seam: _DB_PATH global, override via set_db_path().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_running_procs: dict[str, Any] = {}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams -------------------------------------------------
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
def set_config_dir(path: "Path | None") -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Config helpers ----------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_train_config() -> dict:
|
||||
"""Read python_bin from label_tool.yaml.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. label_tool.yaml train: python_bin
|
||||
2. label_tool.yaml cforch: python_bin
|
||||
3. Hardcoded default (classifiers conda env)
|
||||
"""
|
||||
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
f = _config_file()
|
||||
train_cfg: dict = {}
|
||||
cforch_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
train_cfg = raw.get("train", {}) or {}
|
||||
cforch_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse train config %s: %s", f, exc)
|
||||
return {
|
||||
"python_bin": train_cfg.get(
|
||||
"python_bin",
|
||||
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# -- Database helpers --------------------------------------------------
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
"""Create jobs table if it does not exist. Called lazily per request."""
|
||||
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with _db() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||
return {k: row[k] for k in row.keys()}
|
||||
|
||||
|
||||
# -- GPU selection (copied from api.py) --------------------------------
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return index of GPU with most free VRAM, or empty string."""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True, timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
# -- Pydantic models ---------------------------------------------------
|
||||
|
||||
class CreateJobRequest(BaseModel):
|
||||
type: str # "classifier" | "llm-sft"
|
||||
model_key: str # e.g. "deberta-small"
|
||||
config_json: dict = {}
|
||||
|
||||
|
||||
# -- Routes ------------------------------------------------------------
|
||||
|
||||
@router.get("/jobs")
|
||||
def list_jobs() -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||
return {"jobs": [_row_to_dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("/jobs")
|
||||
def create_job(req: CreateJobRequest) -> dict:
|
||||
if req.type not in ("classifier", "llm-sft"):
|
||||
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
|
||||
_init_db()
|
||||
job_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES (?, ?, ?, 'queued', ?, ?)",
|
||||
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
|
||||
)
|
||||
return {"id": job_id, "type": req.type, "model_key": req.model_key,
|
||||
"status": "queued", "config_json": req.config_json,
|
||||
"created_at": now, "started_at": None, "completed_at": None, "error": None}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
def get_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
return _row_to_dict(row)
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}/cancel")
|
||||
def cancel_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] not in ("queued", "running"):
|
||||
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
|
||||
proc = _running_procs.pop(job_id, None)
|
||||
if proc is not None:
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
try:
|
||||
proc.kill()
|
||||
except OSError:
|
||||
pass
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/run")
|
||||
def run_job(job_id: str) -> StreamingResponse:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] != "queued":
|
||||
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
|
||||
job = _row_to_dict(row)
|
||||
|
||||
def generate():
|
||||
cfg = _load_train_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
config = json.loads(job["config_json"] or "{}")
|
||||
model_key = job["model_key"]
|
||||
epochs = config.get("epochs", 5)
|
||||
|
||||
if job["type"] == "classifier":
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
data_dir = _ROOT / "data"
|
||||
for sf in config.get("score_files", []):
|
||||
resolved = (data_dir / sf).resolve()
|
||||
if resolved.is_relative_to(data_dir.resolve()):
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
elif job["type"] == "llm-sft":
|
||||
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
else:
|
||||
job_type = job["type"]
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
|
||||
return
|
||||
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
|
||||
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
|
||||
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
|
||||
)
|
||||
_running_procs[job_id] = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if proc.returncode == 0:
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
|
||||
(finished_at, job_id))
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
err = f"Process exited with code {proc.returncode}"
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop(job_id, None)
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> dict:
|
||||
if not _MODELS_DIR.exists():
|
||||
return {"results": []}
|
||||
results = []
|
||||
for sub in _MODELS_DIR.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||
return {"results": results}
|
||||
|
|
@ -106,7 +106,7 @@ def read_jsonl(path: Path) -> list[dict]:
|
|||
def write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
"""Write records to a JSONL file, overwriting any existing content."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = "\n".join(json.dumps(r) for r in records)
|
||||
content = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
|
||||
path.write_text(content + ("\n" if records else ""), encoding="utf-8")
|
||||
|
||||
|
||||
|
|
@ -114,4 +114,4 @@ def append_jsonl(path: Path, record: dict) -> None:
|
|||
"""Append a single record to a JSONL file."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(record) + "\n")
|
||||
fh.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
|
|
|||
|
|
@ -41,11 +41,20 @@ cforch:
|
|||
# Python interpreter with cf-orch installed
|
||||
python_bin: /devl/miniconda3/envs/cf/bin/python
|
||||
|
||||
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST
|
||||
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST / CF_JUDGE_URL / HF_TOKEN
|
||||
# coordinator_url: http://localhost:7700
|
||||
# license_key: CFG-AVCT-xxxx-xxxx-xxxx
|
||||
# ollama_url: http://localhost:11434
|
||||
# ollama_model: llama3.2:3b
|
||||
# embed_model: nomic-embed-text # Ollama embedding model for EmbeddingKNNAdapter
|
||||
# judge_url: http://10.1.10.158:8008 # Sif cf-text — LLM-as-judge secondary scorer
|
||||
# judge_url: http://10.1.10.71:8008 # Heimdall cf-text (alternative)
|
||||
# Or set CF_JUDGE_URL. Populates the Judge URL field in the LLM Eval UI automatically.
|
||||
# hf_token: hf_xxxxxxxxxxxxxxxxxxxx # HuggingFace token — required for gated/terms-restricted models
|
||||
|
||||
# Directory containing per-node profile YAMLs (cf-orch node profiles).
|
||||
# Default: derived from bench_script location (../../profiles/nodes).
|
||||
# profiles_dir: /Library/Development/CircuitForge/circuitforge-orch/circuitforge_orch/profiles/nodes
|
||||
|
||||
# Imitate tab — pull real samples from sibling CF product APIs and run them
|
||||
# through local LLMs to build a corrections dataset.
|
||||
|
|
@ -102,11 +111,34 @@ imitate:
|
|||
text_fields: [title, description, seller_info]
|
||||
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
|
||||
|
||||
- id: osprey
|
||||
name: Osprey
|
||||
icon: "📞"
|
||||
description: Gov't hold-line automation
|
||||
base_url: http://localhost:8520
|
||||
sample_endpoint: /api/calls/recent
|
||||
text_fields: [agency, issue, notes]
|
||||
prompt_template: "Draft a concise summary of this government call record:\n\n{text}"
|
||||
- id: pagepiper
|
||||
name: Pagepiper
|
||||
icon: "📄"
|
||||
description: "PDF/rulebook RAG tool: page-level text chunks"
|
||||
base_url: http://localhost:8511
|
||||
health_path: /api/health
|
||||
sample_endpoint: /api/library
|
||||
chunk_endpoint: /api/library/sample-chunks?limit=50 # requires pagepiper#6
|
||||
text_fields: [title]
|
||||
prompt_template: "Summarize the key rules described in this passage:\n\n{text}"
|
||||
|
||||
# ── Log corpus (Turnstone training data) ──────────────────────────────────────
|
||||
corpus:
|
||||
# Directory containing pipeline JSONL log files to ingest (pull-side).
|
||||
# Files named <script>_<ts>.jsonl; one structured record per line.
|
||||
# POST /api/corpus/pipeline-ingest walks this dir and imports new files.
|
||||
# NFS-mounted on both Heimdall and Sif at /Library/Assets/
|
||||
pipeline_ingest_dir: /Library/Assets/logs/pipeline/
|
||||
|
||||
# Turnstone push sources (consent-gated, token-authenticated).
|
||||
# sources:
|
||||
# - token: "your-bearer-token"
|
||||
# source_host: "node.local"
|
||||
# owner: YourName
|
||||
# consent_date: "2026-05-17"
|
||||
# consent_method: signal_chat
|
||||
|
||||
# ── Embedding model comparison harness ────────────────────────────────────────
|
||||
embed_bench:
|
||||
# ollama_url: http://localhost:11434 # optional; falls back to cforch.ollama_url
|
||||
# top_k: 5 # default hits per model per query
|
||||
|
|
|
|||
35
manage.sh
35
manage.sh
|
|
@ -90,6 +90,12 @@ usage() {
|
|||
echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]"
|
||||
echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]"
|
||||
echo ""
|
||||
echo " Planning Benchmark:"
|
||||
echo -e " ${GREEN}plans-bench [args]${NC} Run benchmark_plans.py (args passed through)"
|
||||
echo -e " ${GREEN}plans-list${NC} Shortcut: --list-models"
|
||||
echo -e " ${GREEN}plans-run <model> [args]${NC} Run a single model (--verbose auto-added)"
|
||||
echo -e " ${GREEN}plans-compare <m1> <m2> [more]${NC} Compare models side-by-side"
|
||||
echo ""
|
||||
echo " Writing Style Benchmark:"
|
||||
echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)"
|
||||
echo -e " ${GREEN}style-list${NC} List available ollama models for style bench"
|
||||
|
|
@ -127,6 +133,8 @@ case "$CMD" in
|
|||
fi
|
||||
mkdir -p "$LOG_DIR"
|
||||
API_LOG="${LOG_DIR}/api.log"
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
info "Building Vue SPA…"
|
||||
(cd web && npm run build) >> "$API_LOG" 2>&1
|
||||
info "Starting FastAPI on port ${API_PORT}…"
|
||||
|
|
@ -179,6 +187,9 @@ case "$CMD" in
|
|||
mkdir -p "$LOG_DIR"
|
||||
DEV_API_LOG="${LOG_DIR}/dev-api.log"
|
||||
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
|
||||
if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
|
||||
warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))"
|
||||
else
|
||||
|
|
@ -255,6 +266,30 @@ case "$CMD" in
|
|||
exec "$0" benchmark --compare "$@"
|
||||
;;
|
||||
|
||||
plans-bench)
|
||||
info "Running planning benchmark (${ENV_UI})…"
|
||||
"$PYTHON_UI" scripts/benchmark_plans.py "$@"
|
||||
;;
|
||||
|
||||
plans-list)
|
||||
exec "$0" plans-bench --list-models
|
||||
;;
|
||||
|
||||
plans-run)
|
||||
if [[ $# -lt 1 ]]; then
|
||||
error "Usage: ./manage.sh plans-run <model-key> [extra args]"
|
||||
fi
|
||||
MODEL="$1"; shift
|
||||
exec "$0" plans-bench --model "$MODEL" --verbose "$@"
|
||||
;;
|
||||
|
||||
plans-compare)
|
||||
if [[ $# -lt 2 ]]; then
|
||||
error "Usage: ./manage.sh plans-compare <model1> <model2> [more…]"
|
||||
fi
|
||||
exec "$0" plans-bench --compare "$@" --verbose
|
||||
;;
|
||||
|
||||
style-bench)
|
||||
info "Running writing style benchmark (${ENV_BM})…"
|
||||
if [[ ! -x "$PYTHON_BM" ]]; then
|
||||
|
|
|
|||
|
|
@ -3,3 +3,6 @@ testpaths = tests
|
|||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
markers =
|
||||
gpu: requires an idle GPU; excluded from default runs
|
||||
slow: long-running test; excluded from default CI runs
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ pydantic>=2.0.0
|
|||
uvicorn[standard]>=0.20.0
|
||||
httpx>=0.24.0
|
||||
pytest>=7.0.0
|
||||
pyyaml>=6.0
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from scripts.classifier_adapters import (
|
|||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
EmbeddingKNNAdapter,
|
||||
FineTunedAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
|
|
@ -130,6 +131,13 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
|
|||
"params": "600M",
|
||||
"default": False,
|
||||
},
|
||||
"embed-knn-nomic": {
|
||||
"adapter": EmbeddingKNNAdapter,
|
||||
"model_id": "nomic-embed-text",
|
||||
"params": "local-embed",
|
||||
"default": False, # requires orch or ollama; use --include-slow
|
||||
"kwargs": {"k": 3},
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -184,6 +192,42 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
|||
return found
|
||||
|
||||
|
||||
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
|
||||
"""Sample up to k_per_label formatted email texts per label from a scored JSONL.
|
||||
|
||||
Formats each row as 'Subject: {subject}\n\n{body[:600]}' — the same format
|
||||
EmbeddingKNNAdapter uses at classify() time. Rows missing the 'label' key
|
||||
are skipped silently.
|
||||
|
||||
Returns dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
|
||||
"""
|
||||
result: dict[str, list[str]] = {}
|
||||
p = Path(path)
|
||||
with p.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except json.JSONDecodeError as exc:
|
||||
print(f"[build_exemplars] WARN: skipping malformed line: {exc}", flush=True)
|
||||
continue
|
||||
label = row.get("label")
|
||||
if not label:
|
||||
continue
|
||||
subject = row.get("subject", "")
|
||||
body = row.get("body", "")
|
||||
if not subject and not body:
|
||||
continue
|
||||
texts = result.setdefault(label, [])
|
||||
if len(texts) < k_per_label:
|
||||
texts.append(
|
||||
f"Subject: {subject}\n\n{body[:600]}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
||||
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
||||
active: dict[str, dict[str, Any]] = {
|
||||
|
|
|
|||
734
scripts/benchmark_plans.py
Normal file
734
scripts/benchmark_plans.py
Normal file
|
|
@ -0,0 +1,734 @@
|
|||
#!/usr/bin/env python
|
||||
"""CF-specific planning benchmark — compare base models before fine-tuning.
|
||||
|
||||
Sends held-out CircuitForge planning prompts to one or more models via the
|
||||
cf-text (local) or cf-orch API, then scores responses against CF-specific
|
||||
rubrics. Use this to select the best base model for SFT.
|
||||
|
||||
Scoring rubrics (each 0-1, summed to total/N):
|
||||
- task_structure : uses checkbox syntax (- [ ]), git commit steps
|
||||
- tier_awareness : mentions Free/Paid/Premium/Ultra tiers
|
||||
- privacy_pillar : mentions privacy/local-inference/no-logging
|
||||
- safety_pillar : mentions safety, human approval, or reversibility
|
||||
- accessibility : mentions ND/accessibility/adaptive needs
|
||||
- license_split : mentions MIT vs BSL or open-core model
|
||||
- file_paths : uses plausible file path references
|
||||
- cf_conventions : uses conda run -n cf, /Library/Development/, or known CF dirs
|
||||
- paired_coherence : (paired only) plan references the design doc's feature name
|
||||
- length_ok : 300–2500 words (under-short = hallucination risk; over-long = padding)
|
||||
|
||||
Usage
|
||||
-----
|
||||
# List available model targets
|
||||
python scripts/benchmark_plans.py --list-models
|
||||
|
||||
# Run all held-out prompts against a single model, print report
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b
|
||||
|
||||
# Compare two models side-by-side
|
||||
python scripts/benchmark_plans.py --compare granite-4.1-8b deepseek-r1-7b-4bit
|
||||
|
||||
# Run with a custom API base (cf-text default: http://localhost:8080/v1)
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b --api-base http://localhost:8080/v1
|
||||
|
||||
# Export detailed results JSON
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b --output data/bench_results.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR = _ROOT / "data"
|
||||
|
||||
CF_TEXT_BASE = "http://localhost:8080/v1"
|
||||
CF_ORCH_BASE = "http://localhost:8090/v1"
|
||||
CF_COORD_URL = "http://10.1.10.71:7700" # cf-orch coordinator (LAN)
|
||||
|
||||
# ── Held-out prompts ───────────────────────────────────────────────────────────
|
||||
# These are NOT in the training export (no matching docs in circuitforge-plans/).
|
||||
# Each prompt exercises a different CF planning domain.
|
||||
|
||||
HELD_OUT_PROMPTS: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "ho_001",
|
||||
"name": "kiwi_barcode_ocr",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Kiwi, a CircuitForge pantry-tracking product. "
|
||||
"Write a detailed implementation plan for adding barcode scanning via device camera "
|
||||
"and receipt OCR to the item-add flow.\n\n"
|
||||
"The plan should include: file structure (create/modify), step-by-step task checklist "
|
||||
"with checkboxes, any DB migrations, and git commit steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_002",
|
||||
"name": "peregrine_ats_scoring",
|
||||
"domain": "feature_design",
|
||||
"prompt": (
|
||||
"Write a design document for Peregrine: ATS keyword scoring for job applications.\n\n"
|
||||
"Context: Peregrine users paste job descriptions and their resume. "
|
||||
"We want to score how well the resume keywords match the JD and suggest rewrites. "
|
||||
"Describe the architecture, data flow, and key design decisions."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "tier_awareness", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_003",
|
||||
"name": "tier_gate_local_llm",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the tier-gating architecture for a new CircuitForge product. "
|
||||
"The product should:\n"
|
||||
"- Default to local LLM inference for all tiers\n"
|
||||
"- Unlock cloud LLM for Paid tier and above\n"
|
||||
"- Keep fine-tuned model weights for Premium/Ultra only\n\n"
|
||||
"Describe how the tier check integrates with the LLM router, "
|
||||
"what happens when a Free user tries a Paid-tier feature, "
|
||||
"and how BYOK (bring-your-own-key) fits in."
|
||||
),
|
||||
"expected_signals": ["tier_awareness", "privacy_pillar", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_004",
|
||||
"name": "heimdall_webhook_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Heimdall feature into a detailed implementation plan with "
|
||||
"file structure and task checkboxes — Stripe webhook handler for subscription lifecycle.\n\n"
|
||||
"Heimdall is the CircuitForge license server (FastAPI + SQLite). "
|
||||
"The webhook needs to handle checkout.session.completed, "
|
||||
"customer.subscription.updated, and customer.subscription.deleted events."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_005",
|
||||
"name": "nd_accessible_onboarding",
|
||||
"domain": "ux_design",
|
||||
"prompt": (
|
||||
"You are a product designer working on Harrier, a CircuitForge tool for "
|
||||
"helping people navigate government benefits applications.\n\n"
|
||||
"Design the onboarding flow for neurodivergent (ND) users. "
|
||||
"Consider: ADHD time-blindness, executive function challenges, demand avoidance, "
|
||||
"and rejection sensitivity. The flow should reduce cognitive load and "
|
||||
"never use urgency or panic patterns."
|
||||
),
|
||||
"expected_signals": ["accessibility", "safety_pillar", "privacy_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_006",
|
||||
"name": "circuitforge_core_extraction",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Produce a CircuitForge-style design document for the following circuitforge-core "
|
||||
"feature — shared ActivityPub federation module.\n\n"
|
||||
"Background: Multiple CF products (Kiwi, Rook, Snipe) want to publish updates "
|
||||
"to ActivityPub. Build it once in cf-core (MIT licensed) so all products can use it. "
|
||||
"Design the module API, describe what belongs in MIT vs BSL, and note federation "
|
||||
"privacy constraints."
|
||||
),
|
||||
"expected_signals": ["license_split", "privacy_pillar", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_007",
|
||||
"name": "snipe_trust_score_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Snipe, a CircuitForge eBay trust-scoring tool. "
|
||||
"Write a step-by-step engineering plan for: seller trust score calculation.\n\n"
|
||||
"The score should combine: feedback ratio, account age, item-specifics completeness, "
|
||||
"listing photo quality, and shipping time accuracy. "
|
||||
"Include file structure, test plan, and migration steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_008",
|
||||
"name": "avocet_training_pipeline",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Avocet feature into a detailed implementation plan — "
|
||||
"end-to-end fine-tuning pipeline from labeled JSONL to deployed GGUF model.\n\n"
|
||||
"Avocet is the CircuitForge email classifier training tool. "
|
||||
"The pipeline should: validate the dataset, run LoRA SFT via unsloth, "
|
||||
"quantize to Q5_K_M GGUF, run the benchmark harness, and register the model "
|
||||
"in the Avocet model queue if it beats the baseline."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_009",
|
||||
"name": "privacy_data_flow",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the data privacy architecture for a CircuitForge cloud product. "
|
||||
"Describe: what PII is collected, how it's stored, retention policy, "
|
||||
"obfuscation strategy for cloud-side logs, and how consent is obtained "
|
||||
"in plain language. The product handles job applications (resumes, cover letters)."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "safety_pillar", "accessibility"],
|
||||
},
|
||||
{
|
||||
"id": "ho_010",
|
||||
"name": "git_workflow_doc",
|
||||
"domain": "process_doc",
|
||||
"prompt": (
|
||||
"Write a developer process document for CircuitForge: conventional commit and "
|
||||
"branch workflow for a BSL 1.1 open-core product.\n\n"
|
||||
"Cover: commit message format (type: description), branch naming, "
|
||||
"when to use feature branches vs direct main commits, "
|
||||
"how the MIT/BSL split affects which commits go in which branch, "
|
||||
"and how CI gates on gitleaks for secret scanning."
|
||||
),
|
||||
"expected_signals": ["license_split", "cf_conventions", "task_structure"],
|
||||
},
|
||||
]
|
||||
|
||||
# ── Rubric scoring ─────────────────────────────────────────────────────────────
|
||||
|
||||
_TASK_STRUCTURE_RE = re.compile(r"- \[ \]", re.MULTILINE)
|
||||
_COMMIT_RE = re.compile(r"git commit|git add", re.IGNORECASE)
|
||||
_TIER_RE = re.compile(r"\b(Free|Paid|Premium|Ultra)\s+tier|\btier\s+(Free|Paid|Premium|Ultra)", re.IGNORECASE)
|
||||
_PRIVACY_RE = re.compile(r"\b(privacy|local.?inference|no.?logging|no.?pii|user.?data|data.?reten|obfuscat)", re.IGNORECASE)
|
||||
_SAFETY_RE = re.compile(r"\b(human.?approv|reversib|safety|safe.?default|fail.?safe|harm)", re.IGNORECASE)
|
||||
_A11Y_RE = re.compile(r"\b(neurodiverg|ND\b|accessib|adaptive|ADHD|autism|executive.?function|demand.?avoid)", re.IGNORECASE)
|
||||
_LICENSE_RE = re.compile(r"\b(MIT|BSL|open.?core|proprietary|commercial.?licens)", re.IGNORECASE)
|
||||
_FILE_PATH_RE = re.compile(r"(app/|tests?/|src/|scripts?/)\w[\w/.-]{3,}", re.IGNORECASE)
|
||||
_CF_CONV_RE = re.compile(r"(conda run -n cf|/Library/Development/CircuitForge|circuitforge-core|manage\.sh)", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RubricScore:
|
||||
task_structure: float = 0.0
|
||||
tier_awareness: float = 0.0
|
||||
privacy_pillar: float = 0.0
|
||||
safety_pillar: float = 0.0
|
||||
accessibility: float = 0.0
|
||||
license_split: float = 0.0
|
||||
file_paths: float = 0.0
|
||||
cf_conventions: float = 0.0
|
||||
length_ok: float = 0.0
|
||||
|
||||
def total(self) -> float:
|
||||
vals = [self.task_structure, self.tier_awareness, self.privacy_pillar,
|
||||
self.safety_pillar, self.accessibility, self.license_split,
|
||||
self.file_paths, self.cf_conventions, self.length_ok]
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
def as_dict(self) -> dict[str, float]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def score_response(response: str, prompt_meta: dict[str, Any]) -> RubricScore:
|
||||
words = len(response.split())
|
||||
s = RubricScore()
|
||||
|
||||
# Task structure: needs checkboxes AND at least one commit step
|
||||
checkbox_hits = len(_TASK_STRUCTURE_RE.findall(response))
|
||||
has_commit = bool(_COMMIT_RE.search(response))
|
||||
s.task_structure = min(1.0, checkbox_hits / 5) * 0.7 + (0.3 if has_commit else 0.0)
|
||||
|
||||
# Tier awareness
|
||||
s.tier_awareness = min(1.0, len(_TIER_RE.findall(response)) / 2)
|
||||
|
||||
# Privacy pillar
|
||||
s.privacy_pillar = min(1.0, len(_PRIVACY_RE.findall(response)) / 3)
|
||||
|
||||
# Safety pillar
|
||||
s.safety_pillar = min(1.0, len(_SAFETY_RE.findall(response)) / 2)
|
||||
|
||||
# Accessibility
|
||||
s.accessibility = min(1.0, len(_A11Y_RE.findall(response)) / 2)
|
||||
|
||||
# License split awareness
|
||||
s.license_split = min(1.0, len(_LICENSE_RE.findall(response)) / 2)
|
||||
|
||||
# File paths: at least 3 plausible path references
|
||||
s.file_paths = min(1.0, len(_FILE_PATH_RE.findall(response)) / 3)
|
||||
|
||||
# CF conventions
|
||||
s.cf_conventions = min(1.0, len(_CF_CONV_RE.findall(response)) / 2)
|
||||
|
||||
# Length: 200–2500 words is healthy; outside = partial credit
|
||||
if 200 <= words <= 2500:
|
||||
s.length_ok = 1.0
|
||||
elif words < 200:
|
||||
s.length_ok = words / 200
|
||||
else:
|
||||
s.length_ok = max(0.0, 1.0 - (words - 2500) / 2500)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
# ── Model client ───────────────────────────────────────────────────────────────
|
||||
|
||||
# Registry of named model targets (shorthand → {api_base, model_name})
|
||||
MODEL_REGISTRY: dict[str, dict[str, str]] = {
|
||||
"deepseek-r1-1.5b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-1.5b",
|
||||
"description": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-7b-4bit",
|
||||
"description": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-0528-qwen3-8b-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-0528-qwen3-8b-gguf",
|
||||
"description": "DeepSeek R1 0528 Qwen3 8B GGUF -- current reasoning model (4 nodes)",
|
||||
},
|
||||
"deepseek-coder-6.7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-coder-6.7b-4bit",
|
||||
"description": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"granite-4.1-8b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "granite-4.1-8b",
|
||||
"description": "IBM Granite 4.1 8B, 4-bit -- safety-trained (cf-orch catalog key)",
|
||||
},
|
||||
"capybarahermes-2.5-mistral-7b-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "capybarahermes-2.5-mistral-7b-gguf",
|
||||
"description": "CapybaraHermes 2.5 Mistral 7B GGUF -- conversational/creative (4 nodes)",
|
||||
},
|
||||
"darwin-9b-opus-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "darwin-9b-opus-gguf",
|
||||
"description": "Darwin 9B Opus GGUF -- high-quality long-form writing (3 nodes)",
|
||||
},
|
||||
"qwen2.5-3b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-3b",
|
||||
"description": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
|
||||
},
|
||||
"qwen2.5-7b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-7b",
|
||||
"description": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── cf-orch allocation ─────────────────────────────────────────────────────────
|
||||
|
||||
def _cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id via the cf-orch coordinator.
|
||||
|
||||
Returns (service_url, allocation_id) on success, None on failure.
|
||||
service_url is the direct node URL exposing /v1/chat/completions.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "plans_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
# Use \n so the SSE generator sees the line immediately
|
||||
print(f" [cold start] loading {model_id!r} — polling every 3s…", flush=True)
|
||||
t0 = time.monotonic()
|
||||
deadline = t0 + startup_timeout_s
|
||||
probe_misses = 0
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
elapsed = time.monotonic() - t0
|
||||
try:
|
||||
status = httpx.get(f"{cforch_url}/api/services/cf-text/status", timeout=5.0)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
print(f" [cold start] ready in {elapsed:.0f}s", flush=True)
|
||||
return service_url, allocation_id
|
||||
elif state == "stopped":
|
||||
print(f" [cold start] failed — service stopped after {elapsed:.0f}s", flush=True)
|
||||
return None
|
||||
else:
|
||||
# still starting — emit keepalive so SSE stream stays alive
|
||||
print(f" [cold start] state={state!r} elapsed={elapsed:.0f}s", flush=True)
|
||||
else:
|
||||
probe_misses += 1
|
||||
print(f" [cold start] waiting… elapsed={elapsed:.0f}s", flush=True)
|
||||
if probe_misses >= 6:
|
||||
try:
|
||||
h = httpx.get(f"{service_url}/health", timeout=3.0)
|
||||
if h.is_success:
|
||||
print(f" [cold start] ready via health check in {elapsed:.0f}s", flush=True)
|
||||
return service_url, allocation_id
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(f" [cold start] status poll returned {status.status_code}, elapsed={elapsed:.0f}s", flush=True)
|
||||
except Exception as poll_exc:
|
||||
print(f" [cold start] poll error: {poll_exc} elapsed={elapsed:.0f}s", flush=True)
|
||||
time.sleep(3.0)
|
||||
|
||||
print(f" [cold start] timed out after {time.monotonic()-t0:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r}: {exc}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def _call_model_direct(service_url: str, model: str, prompt: str, timeout: int = 600) -> tuple[str, float]:
|
||||
"""Call an OpenAI-compatible /v1/chat/completions on a direct service URL."""
|
||||
t0 = time.monotonic()
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
latency = time.monotonic() - t0
|
||||
text = resp.json()["choices"][0]["message"]["content"]
|
||||
return text, latency
|
||||
|
||||
|
||||
def _call_model(api_base: str, model: str, prompt: str, timeout: int = 180) -> tuple[str, float]:
|
||||
"""Call an OpenAI-compatible /chat/completions endpoint. Returns (text, latency_s)."""
|
||||
t0 = time.monotonic()
|
||||
resp = httpx.post(
|
||||
f"{api_base}/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
latency = time.monotonic() - t0
|
||||
text = resp.json()["choices"][0]["message"]["content"]
|
||||
return text, latency
|
||||
|
||||
|
||||
# ── Benchmark runner ───────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class PromptResult:
|
||||
prompt_id: str
|
||||
prompt_name: str
|
||||
model_key: str
|
||||
response: str
|
||||
latency_s: float
|
||||
word_count: int
|
||||
scores: dict[str, float]
|
||||
total_score: float
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_key: str,
|
||||
model_name: str,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
verbose: bool = False,
|
||||
# cf-orch path
|
||||
use_cforch: bool = False,
|
||||
cforch_url: str = CF_COORD_URL,
|
||||
# direct path (used when not cf-orch)
|
||||
api_base: str = CF_TEXT_BASE,
|
||||
) -> list[PromptResult]:
|
||||
"""Run all prompts through one model. Uses cf-orch allocation when use_cforch=True."""
|
||||
if prompts is None:
|
||||
prompts = HELD_OUT_PROMPTS
|
||||
|
||||
# Allocate once per model when using cf-orch
|
||||
service_url: str | None = None
|
||||
if use_cforch:
|
||||
print(f" Allocating {model_name!r} via cf-orch…", flush=True)
|
||||
alloc = _cforch_allocate(model_name, cforch_url)
|
||||
if alloc is None:
|
||||
# Return all prompts as errors
|
||||
return [
|
||||
PromptResult(
|
||||
prompt_id=p["id"], prompt_name=p["name"], model_key=model_key,
|
||||
response="", latency_s=0.0, word_count=0, scores={}, total_score=0.0,
|
||||
error=f"cf-orch allocation failed for {model_name!r}",
|
||||
)
|
||||
for p in prompts
|
||||
]
|
||||
service_url, _alloc_id = alloc
|
||||
|
||||
results: list[PromptResult] = []
|
||||
for p in prompts:
|
||||
if verbose:
|
||||
print(f" [{p['id']}] {p['name']} … ", end="", flush=True)
|
||||
try:
|
||||
if service_url:
|
||||
response, latency = _call_model_direct(service_url, model_name, p["prompt"])
|
||||
else:
|
||||
response, latency = _call_model(api_base, model_name, p["prompt"])
|
||||
rubric = score_response(response, p)
|
||||
result = PromptResult(
|
||||
prompt_id=p["id"],
|
||||
prompt_name=p["name"],
|
||||
model_key=model_key,
|
||||
response=response,
|
||||
latency_s=round(latency, 2),
|
||||
word_count=len(response.split()),
|
||||
scores=rubric.as_dict(),
|
||||
total_score=round(rubric.total(), 3),
|
||||
)
|
||||
if verbose:
|
||||
print(f"score={result.total_score:.3f} ({result.word_count}w, {latency:.1f}s)")
|
||||
except Exception as exc:
|
||||
result = PromptResult(
|
||||
prompt_id=p["id"],
|
||||
prompt_name=p["name"],
|
||||
model_key=model_key,
|
||||
response="",
|
||||
latency_s=0.0,
|
||||
word_count=0,
|
||||
scores={},
|
||||
total_score=0.0,
|
||||
error=str(exc),
|
||||
)
|
||||
if verbose:
|
||||
print(f"ERROR: {exc}")
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
# ── Reporting ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_single_report(results: list[PromptResult], model_key: str) -> None:
|
||||
ok = [r for r in results if not r.error]
|
||||
err = [r for r in results if r.error]
|
||||
if not ok:
|
||||
print(f"\n[{model_key}] All {len(err)} prompts failed.\n")
|
||||
return
|
||||
|
||||
avg_total = sum(r.total_score for r in ok) / len(ok)
|
||||
avg_latency = sum(r.latency_s for r in ok) / len(ok)
|
||||
|
||||
# Aggregate per-rubric averages
|
||||
rubric_keys = list(ok[0].scores.keys())
|
||||
rubric_avgs = {k: sum(r.scores.get(k, 0) for r in ok) / len(ok) for k in rubric_keys}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Model : {model_key}")
|
||||
print(f" Prompts: {len(ok)}/{len(results)} passed ({len(err)} errors)")
|
||||
print(f" Overall score : {avg_total:.3f} (avg latency {avg_latency:.1f}s)")
|
||||
print(f"\n Rubric breakdown:")
|
||||
for k, v in sorted(rubric_avgs.items(), key=lambda x: -x[1]):
|
||||
bar = "█" * int(v * 20)
|
||||
print(f" {k:<22} {v:.3f} {bar}")
|
||||
print(f"\n Per-prompt scores:")
|
||||
for r in sorted(ok, key=lambda x: -x.total_score):
|
||||
flag = "⚠" if r.total_score < 0.3 else " "
|
||||
print(f" {flag} {r.prompt_id} {r.prompt_name:<35} {r.total_score:.3f} ({r.word_count}w)")
|
||||
if err:
|
||||
print(f"\n Errors:")
|
||||
for r in err:
|
||||
print(f" {r.prompt_id} {r.prompt_name}: {r.error}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def _print_comparison_table(all_results: dict[str, list[PromptResult]]) -> None:
|
||||
model_keys = list(all_results.keys())
|
||||
prompt_ids = [p["id"] for p in HELD_OUT_PROMPTS]
|
||||
|
||||
# Scores by (model, prompt_id)
|
||||
score_map: dict[tuple[str, str], float] = {}
|
||||
for mk, results in all_results.items():
|
||||
for r in results:
|
||||
score_map[(mk, r.prompt_id)] = r.total_score if not r.error else 0.0
|
||||
|
||||
col_w = 10
|
||||
header = f"{'Prompt':<35}" + "".join(f"{mk[:col_w-1]:<{col_w}}" for mk in model_keys)
|
||||
print(f"\n{'='*len(header)}")
|
||||
print(" COMPARISON TABLE")
|
||||
print(f"{'='*len(header)}")
|
||||
print(f" {header}")
|
||||
print(f" {'-'*len(header)}")
|
||||
|
||||
for pid in prompt_ids:
|
||||
pname = next(p["name"] for p in HELD_OUT_PROMPTS if p["id"] == pid)
|
||||
row = f" {pname:<35}"
|
||||
best = max(score_map.get((mk, pid), 0.0) for mk in model_keys)
|
||||
for mk in model_keys:
|
||||
v = score_map.get((mk, pid), 0.0)
|
||||
marker = "*" if v == best and len(model_keys) > 1 else " "
|
||||
row += f"{v:.3f}{marker} "
|
||||
print(row)
|
||||
|
||||
print(f" {'-'*len(header)}")
|
||||
avgs_row = f" {'AVERAGE':<35}"
|
||||
best_avg = -1.0
|
||||
avgs: dict[str, float] = {}
|
||||
for mk in model_keys:
|
||||
vals = [score_map.get((mk, pid), 0.0) for pid in prompt_ids]
|
||||
avgs[mk] = sum(vals) / len(vals)
|
||||
best_avg = max(best_avg, avgs[mk])
|
||||
for mk in model_keys:
|
||||
marker = "*" if avgs[mk] == best_avg and len(model_keys) > 1 else " "
|
||||
avgs_row += f"{avgs[mk]:.3f}{marker} "
|
||||
print(avgs_row)
|
||||
print(f"{'='*len(header)}\n")
|
||||
if len(model_keys) > 1:
|
||||
winner = max(avgs, key=lambda k: avgs[k])
|
||||
print(f" Winner: {winner} (avg {avgs[winner]:.3f})\n")
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser.add_argument("--list-models", action="store_true",
|
||||
help="Print registered model shortcuts and exit")
|
||||
parser.add_argument("--model", metavar="KEY",
|
||||
help="Benchmark a single model (registry key or raw model name)")
|
||||
parser.add_argument("--compare", nargs="+", metavar="KEY",
|
||||
help="Compare two or more models side-by-side")
|
||||
parser.add_argument("--cforch", action="store_true",
|
||||
help="Route inference through cf-orch coordinator (allocate per model)")
|
||||
parser.add_argument("--cforch-url", default=CF_COORD_URL, metavar="URL",
|
||||
help=f"cf-orch coordinator URL (default: {CF_COORD_URL})")
|
||||
parser.add_argument("--api-base", default=None,
|
||||
help="Direct API base URL when not using cf-orch")
|
||||
parser.add_argument("--model-name", default=None,
|
||||
help="Override model name sent to API (single-model runs only)")
|
||||
parser.add_argument("--prompts", nargs="+", metavar="ID",
|
||||
help="Run only specific prompt IDs (e.g. ho_001 ho_003)")
|
||||
parser.add_argument("--output", type=Path, default=None,
|
||||
help="Write detailed JSON results to this path")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Run N models concurrently (default 1). Set to number of available nodes.")
|
||||
parser.add_argument("--verbose", "-v", action="store_true",
|
||||
help="Print per-prompt progress")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_models:
|
||||
print("\nRegistered model shortcuts:")
|
||||
for key, info in MODEL_REGISTRY.items():
|
||||
print(f" {key:<20} {info['description']}")
|
||||
print(f"\nDefault endpoints:")
|
||||
print(f" direct {CF_TEXT_BASE}")
|
||||
print(f" cf-orch {CF_COORD_URL}")
|
||||
return
|
||||
|
||||
prompts = HELD_OUT_PROMPTS
|
||||
if args.prompts:
|
||||
ids = set(args.prompts)
|
||||
prompts = [p for p in HELD_OUT_PROMPTS if p["id"] in ids]
|
||||
if not prompts:
|
||||
print(f"No prompts matched IDs: {args.prompts}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
model_keys: list[str] = []
|
||||
if args.compare:
|
||||
model_keys = args.compare
|
||||
elif args.model:
|
||||
model_keys = [args.model]
|
||||
else:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
all_results: dict[str, list[PromptResult]] = {}
|
||||
print_lock = threading.Lock()
|
||||
|
||||
def _run_one(mk: str) -> tuple[str, list[PromptResult]]:
|
||||
if mk in MODEL_REGISTRY:
|
||||
reg = MODEL_REGISTRY[mk]
|
||||
model_name = args.model_name or reg["model"]
|
||||
direct_base = args.api_base or reg["api_base"]
|
||||
else:
|
||||
model_name = args.model_name or mk
|
||||
direct_base = args.api_base or CF_TEXT_BASE
|
||||
|
||||
if args.cforch:
|
||||
with print_lock:
|
||||
print(f"\nRunning [{mk}] via cf-orch ({args.cforch_url}) model={model_name}")
|
||||
results = run_benchmark(
|
||||
mk, model_name, prompts=prompts, verbose=args.verbose,
|
||||
use_cforch=True, cforch_url=args.cforch_url,
|
||||
)
|
||||
else:
|
||||
with print_lock:
|
||||
print(f"\nRunning [{mk}] → {direct_base} model={model_name}")
|
||||
results = run_benchmark(
|
||||
mk, model_name, prompts=prompts, verbose=args.verbose,
|
||||
api_base=direct_base,
|
||||
)
|
||||
|
||||
with print_lock:
|
||||
_print_single_report(results, mk)
|
||||
return mk, results
|
||||
|
||||
workers = max(1, args.workers)
|
||||
if workers == 1 or len(model_keys) == 1:
|
||||
for mk in model_keys:
|
||||
mk_out, results = _run_one(mk)
|
||||
all_results[mk_out] = results
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {pool.submit(_run_one, mk): mk for mk in model_keys}
|
||||
for fut in as_completed(futures):
|
||||
mk_out, results = fut.result()
|
||||
all_results[mk_out] = results
|
||||
|
||||
if len(model_keys) > 1:
|
||||
_print_comparison_table(all_results)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
mk: [asdict(r) for r in results]
|
||||
for mk, results in all_results.items()
|
||||
}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2, ensure_ascii=False)
|
||||
print(f"Wrote detailed results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -7,19 +7,26 @@ from __future__ import annotations
|
|||
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
import httpx
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"LABELS",
|
||||
"LABEL_DESCRIPTIONS",
|
||||
"DEFAULT_EXEMPLARS",
|
||||
"compute_metrics",
|
||||
"ClassifierAdapter",
|
||||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
"EmbeddingKNNAdapter",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
LABELS: list[str] = [
|
||||
"interview_scheduled",
|
||||
"offer_received",
|
||||
|
|
@ -117,6 +124,81 @@ def compute_metrics(
|
|||
return result
|
||||
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
|
||||
|
||||
|
||||
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
|
||||
"interview_scheduled": [
|
||||
"Subject: Interview Invitation\n\nWe would like to invite you for a phone screen next week.",
|
||||
"Subject: Schedule a call\n\nCould you be available for a video interview on Tuesday?",
|
||||
"Subject: Next Steps\n\nWe'd like to move forward with a technical interview. Please select a time.",
|
||||
"Subject: Interview Details\n\nHere are the dial-in instructions for your interview tomorrow.",
|
||||
],
|
||||
"offer_received": [
|
||||
"Subject: Offer Letter Enclosed\n\nWe are pleased to extend you an offer of employment.",
|
||||
"Subject: Job Offer\n\nDear candidate, we are excited to offer you the position of Software Engineer.",
|
||||
"Subject: Employment Offer\n\nPlease find attached your formal offer letter and compensation details.",
|
||||
"Subject: Offer of Employment\n\nCongratulations! We would like to offer you a full-time position.",
|
||||
],
|
||||
"rejected": [
|
||||
"Subject: Your Application\n\nAfter careful consideration, we have decided to move forward with other candidates.",
|
||||
"Subject: Application Status\n\nWe regret to inform you that your application has not been selected.",
|
||||
"Subject: Thank you for applying\n\nWe appreciate your interest but have chosen not to proceed.",
|
||||
"Subject: Update on your candidacy\n\nWe will not be moving forward with your application at this time.",
|
||||
],
|
||||
"positive_response": [
|
||||
"Subject: Your profile\n\nI came across your LinkedIn and think you would be a great fit for our team.",
|
||||
"Subject: Exciting opportunity\n\nWe were impressed by your background and would love to connect.",
|
||||
"Subject: Following up\n\nThank you for your interest — we'd like to learn more about your experience.",
|
||||
"Subject: Great fit\n\nYour skills align well with what we are looking for. Let's set up a call.",
|
||||
],
|
||||
"survey_received": [
|
||||
"Subject: Candidate Experience Survey\n\nPlease complete this brief survey about your application experience.",
|
||||
"Subject: Culture Fit Assessment\n\nAs part of our process, we ask all candidates to complete a short assessment.",
|
||||
"Subject: Skills Assessment\n\nWe'd like you to complete our online coding assessment before proceeding.",
|
||||
"Subject: Personality Assessment\n\nPlease complete the following assessment as the next step in our process.",
|
||||
"Subject: Pre-interview questionnaire\n\nBefore we schedule your interview, please complete this brief skills survey.",
|
||||
],
|
||||
"neutral": [
|
||||
"Subject: Application Received\n\nWe have received your application and will be in touch.",
|
||||
"Subject: Thank you for applying\n\nYour application is under review. We will contact you if needed.",
|
||||
"Subject: Confirmation\n\nThis email confirms receipt of your application to our company.",
|
||||
"Subject: Application Confirmation\n\nThank you for your interest. We will review your materials and follow up.",
|
||||
],
|
||||
"event_rescheduled": [
|
||||
"Subject: Interview Rescheduled\n\nDue to a conflict, we need to move your interview to a new time.",
|
||||
"Subject: Change of interview time\n\nWe apologize — your interview has been rescheduled to Thursday.",
|
||||
"Subject: Updated interview details\n\nYour interview has been moved from Monday to Wednesday at 2pm.",
|
||||
"Subject: Reschedule request\n\nWould you be available to reschedule to a different time slot?",
|
||||
"Subject: New interview time\n\nYour phone screen has been moved from tomorrow to next week.",
|
||||
],
|
||||
"digest": [
|
||||
"Subject: 15 new jobs matching your search\n\nHere are the latest job postings that match your profile.",
|
||||
"Subject: Weekly Job Digest\n\nThis week's top opportunities for Software Engineers in your area.",
|
||||
"Subject: Jobs you might like\n\nBased on your profile, here are some positions we recommend.",
|
||||
"Subject: New jobs for you\n\nSee the latest openings from companies on your watchlist.",
|
||||
],
|
||||
"new_lead": [
|
||||
"Subject: Exciting opportunity at our company\n\nHi, I noticed your background and think you'd be a great fit.",
|
||||
"Subject: Are you open to new opportunities?\n\nI'm a recruiter reaching out about a role matching your experience.",
|
||||
"Subject: Quick question\n\nWould you be interested in hearing about a senior engineering role?",
|
||||
"Subject: Recruiting outreach\n\nI came across your profile and wanted to share an exciting opening.",
|
||||
],
|
||||
"hired": [
|
||||
"Subject: Welcome to the team!\n\nWe are thrilled to have you join us. Here are your onboarding details.",
|
||||
"Subject: Onboarding information\n\nCongratulations on accepting our offer. Your start date is confirmed.",
|
||||
"Subject: First day information\n\nWe look forward to your first day. Please arrive at 9am and ask for HR.",
|
||||
"Subject: Background check initiated\n\nAs part of your onboarding, we have initiated a background check.",
|
||||
"Subject: Equipment setup\n\nYour laptop and equipment will be ready for pickup on your first day.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ClassifierAdapter(abc.ABC):
|
||||
"""Abstract base for all email classifier adapters."""
|
||||
|
||||
|
|
@ -304,3 +386,148 @@ class FineTunedAdapter(ClassifierAdapter):
|
|||
text = f"{subject} [SEP] {body[:400]}"
|
||||
result = self._pipeline(text)
|
||||
return result[0]["label"]
|
||||
|
||||
|
||||
class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
|
||||
|
||||
load():
|
||||
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
|
||||
Falls back to ollama_url directly if orch allocation fails or is not configured.
|
||||
2. Pre-embeds all exemplar texts and stores per-label vector lists.
|
||||
|
||||
classify(subject, body):
|
||||
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||
with the highest total similarity score among tied vote counts wins.
|
||||
|
||||
unload():
|
||||
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_id: str,
|
||||
*,
|
||||
k: int = 3,
|
||||
orch_url: str = "",
|
||||
ollama_url: str = "",
|
||||
exemplar_texts: dict[str, list[str]] | None = None,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._k = k
|
||||
self._orch_url = orch_url
|
||||
self._ollama_url = ollama_url
|
||||
self._exemplar_texts: dict[str, list[str]] = (
|
||||
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
|
||||
)
|
||||
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
|
||||
self._node_url: str = ""
|
||||
self._allocation_id: str = ""
|
||||
self._orch_url_used: str = ""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def _resolve_urls(self) -> tuple[str, str]:
|
||||
if self._orch_url or self._ollama_url:
|
||||
return self._orch_url, self._ollama_url
|
||||
import yaml # noqa: PLC0415
|
||||
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
cforch = cfg.get("cforch", {}) or {}
|
||||
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||
|
||||
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
|
||||
resp = httpx.post(
|
||||
f"{node_url}/v1/embeddings",
|
||||
json={"model": self._model_id, "input": texts},
|
||||
timeout=30.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return [item["embedding"] for item in resp.json()["data"]]
|
||||
|
||||
def load(self) -> None:
|
||||
if self._allocation_id or self._exemplar_embeddings:
|
||||
raise RuntimeError(
|
||||
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
|
||||
)
|
||||
orch_url, ollama_url = self._resolve_urls()
|
||||
node_url = ""
|
||||
orch_url_used = ""
|
||||
if orch_url:
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{orch_url}/api/services/ollama/allocate",
|
||||
json={"model": self._model_id},
|
||||
timeout=15.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
node_url = data["url"]
|
||||
self._allocation_id = data["allocation_id"]
|
||||
orch_url_used = orch_url
|
||||
except Exception as exc:
|
||||
_logger.warning(
|
||||
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
|
||||
)
|
||||
if not node_url:
|
||||
node_url = ollama_url
|
||||
self._allocation_id = ""
|
||||
orch_url_used = ""
|
||||
self._node_url = node_url
|
||||
self._orch_url_used = orch_url_used
|
||||
try:
|
||||
embeddings: dict[str, list[list[float]]] = {}
|
||||
for label, texts in self._exemplar_texts.items():
|
||||
embeddings[label] = self._embed(node_url, texts)
|
||||
self._exemplar_embeddings = embeddings
|
||||
except Exception:
|
||||
self.unload()
|
||||
raise
|
||||
|
||||
def unload(self) -> None:
|
||||
if self._allocation_id and self._orch_url_used:
|
||||
try:
|
||||
httpx.request(
|
||||
"DELETE",
|
||||
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
|
||||
timeout=10.0,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._exemplar_embeddings = {}
|
||||
self._node_url = ""
|
||||
self._allocation_id = ""
|
||||
self._orch_url_used = ""
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if not self._exemplar_embeddings:
|
||||
self.load()
|
||||
text = f"Subject: {subject}\n\n{body[:600]}"
|
||||
[query_vec] = self._embed(self._node_url, [text])
|
||||
scored: list[tuple[float, str]] = [
|
||||
(_cosine(query_vec, vec), label)
|
||||
for label, vecs in self._exemplar_embeddings.items()
|
||||
for vec in vecs
|
||||
]
|
||||
top_k = sorted(scored, reverse=True)[: self._k]
|
||||
votes: dict[str, list[float]] = {}
|
||||
for score, label in top_k:
|
||||
votes.setdefault(label, []).append(score)
|
||||
return max(
|
||||
votes,
|
||||
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
|
||||
)
|
||||
|
|
|
|||
458
scripts/export_plans.py
Normal file
458
scripts/export_plans.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""Export circuitforge-plans/ documents as instruction-tuning JSONL pairs.
|
||||
|
||||
Each record is a HuggingFace chat-format example:
|
||||
|
||||
{
|
||||
"id": "<sha256>",
|
||||
"messages": [
|
||||
{"role": "user", "content": "<reconstructed planning prompt>"},
|
||||
{"role": "assistant", "content": "<cleaned document content>"}
|
||||
],
|
||||
"meta": {
|
||||
"source": "peregrine/2026-03-03-feedback-button-design.md",
|
||||
"product": "peregrine",
|
||||
"doc_type": "design", # design | plan | spec | implementation | other
|
||||
"date": "2026-03-03",
|
||||
"paired_with": "...", # sibling path, or null
|
||||
"word_count": 1847,
|
||||
"pair_role": "context" # "context" | "target" | "standalone"
|
||||
}
|
||||
}
|
||||
|
||||
Pairing strategy
|
||||
----------------
|
||||
When a design doc and a plan doc share the same date + feature-name prefix,
|
||||
they are treated as a pair:
|
||||
- design → plan: instruction = "Given this design doc, write the implementation plan."
|
||||
context appended = full design doc content.
|
||||
- Solo docs get a synthetic instruction from the title + first overview section.
|
||||
|
||||
Usage
|
||||
-----
|
||||
# Preview stats and 5 sample records
|
||||
python scripts/export_plans.py --preview
|
||||
|
||||
# Write full output
|
||||
python scripts/export_plans.py --output data/plan_pairs.jsonl
|
||||
|
||||
# Restrict to specific products
|
||||
python scripts/export_plans.py --products peregrine,kiwi --output data/plan_pairs.jsonl
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_SCRIPT_DIR = Path(__file__).parent
|
||||
_AVOCET_ROOT = _SCRIPT_DIR.parent
|
||||
_DEFAULT_PLANS_DIR = Path("/Library/Development/CircuitForge/circuitforge-plans")
|
||||
_DEFAULT_OUTPUT = _AVOCET_ROOT / "data" / "plan_pairs.jsonl"
|
||||
|
||||
# ── Doc type detection ─────────────────────────────────────────────────────────
|
||||
|
||||
_TYPE_RE = re.compile(
|
||||
r"-(design|plan|spec|implementation|specs|plans)s?$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_SKIP_DIRS = {"__pycache__", ".git", "node_modules"}
|
||||
|
||||
# Boilerplate lines to strip from document content before using as output.
|
||||
_BOILERPLATE_RE = re.compile(
|
||||
r"""
|
||||
^\s*>\s*\*\*For\s+agentic\s+workers.* # superpowers agent hints
|
||||
|^\s*>\s*REQUIRED\s+SUB-SKILL.*
|
||||
|^\s*\*\*Date:\*\*.* # metadata header lines
|
||||
|\*\*Status:\*\*\s*Complete.* # completed-feature noise
|
||||
|\*\*Status:\*\*\s*Done.*
|
||||
|\*\*Product:\*\*.*
|
||||
|\*\*Repo:\*\*.*
|
||||
|\*\*Tech\s+Stack:\*\*.*
|
||||
|\*\*Candidate:\*\*.* # old synthetic personas
|
||||
|^Candidate:.*
|
||||
|^Team:.*
|
||||
""",
|
||||
re.VERBOSE | re.MULTILINE,
|
||||
)
|
||||
|
||||
# Old repo/path names to normalise to current equivalents.
|
||||
_PATH_NORMALIZATIONS: list[tuple[re.Pattern, str]] = [
|
||||
(re.compile(r"/devl/job-seeker", re.IGNORECASE), "/Library/Development/CircuitForge/peregrine"),
|
||||
(re.compile(r"\bjob-seeker\b", re.IGNORECASE), "peregrine"),
|
||||
(re.compile(r"Alex Rivera", re.IGNORECASE), "[user]"),
|
||||
]
|
||||
|
||||
# Instruction paraphrase templates per doc type.
|
||||
# Each entry is (user_prefix, paired_prefix).
|
||||
# {title}, {product}, {type_phrase}, {overview}, {design_context} are substituted.
|
||||
_DESIGN_INSTRUCTIONS = [
|
||||
"Write a design document for {product}: {title}.\n\nContext: {overview}",
|
||||
"You are a software architect working on {product}. Draft a design spec for: {title}.\n\n{overview}",
|
||||
"Produce a CircuitForge-style design document for the following {product} feature — {title}.\n\nBackground: {overview}",
|
||||
]
|
||||
|
||||
_PLAN_INSTRUCTIONS = [
|
||||
"Write an implementation plan for {product}: {title}.\n\nContext: {overview}",
|
||||
"Break the following {product} feature into a detailed implementation plan with file structure and task checkboxes — {title}.\n\n{overview}",
|
||||
"You are a senior engineer on {product}. Produce a step-by-step engineering plan for: {title}.\n\n{overview}",
|
||||
]
|
||||
|
||||
_PAIRED_INSTRUCTIONS = [
|
||||
(
|
||||
"You are a software architect working on {product}, a CircuitForge product. "
|
||||
"Given the following design document, write a detailed implementation plan "
|
||||
"(file structure, task breakdown with checkboxes, migration steps if needed).\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"The following is a design spec for a {product} feature. "
|
||||
"Produce a concrete implementation plan: file list, task checklist, any DB migrations needed.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"Convert this {product} design document into an actionable implementation plan. "
|
||||
"Include all files to create/modify, step-by-step tasks with checkboxes, and migration steps.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _doc_type(stem: str) -> str:
|
||||
m = _TYPE_RE.search(stem)
|
||||
if not m:
|
||||
return "other"
|
||||
raw = m.group(1).lower().rstrip("s")
|
||||
return {"implementation": "plan"}.get(raw, raw)
|
||||
|
||||
|
||||
def _date_feature(stem: str) -> tuple[str, str]:
|
||||
"""Return (date, feature_slug) from '2026-03-03-feedback-button-design'."""
|
||||
m = re.match(r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$", stem, re.I)
|
||||
if m:
|
||||
return m.group(1), m.group(2)
|
||||
return "", stem
|
||||
|
||||
|
||||
# ── Content extraction ─────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_title(content: str) -> str:
|
||||
m = re.search(r"^#\s+(.+)", content, re.MULTILINE)
|
||||
return m.group(1).strip() if m else ""
|
||||
|
||||
|
||||
def _extract_overview(content: str) -> str:
|
||||
"""Return first substantive paragraph or h2 section body (≤300 chars)."""
|
||||
# Superpowers plans have an explicit **Goal:** line — prefer that.
|
||||
goal_m = re.search(r"\*\*Goal:\*\*\s*(.+)", content)
|
||||
if goal_m:
|
||||
return goal_m.group(1).strip()[:300]
|
||||
|
||||
# Otherwise use the body of the first h2 section.
|
||||
h2_m = re.search(
|
||||
r"^##\s+\d*\.?\s*.+\n([\s\S]+?)(?=^##|\Z)",
|
||||
content,
|
||||
re.MULTILINE,
|
||||
)
|
||||
if h2_m:
|
||||
body = h2_m.group(1).strip()
|
||||
# Strip markdown bullet/code noise for the instruction
|
||||
body = re.sub(r"```[\s\S]*?```", "", body)
|
||||
body = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), body)
|
||||
body = re.sub(r"\*\*([^*]+)\*\*", r"\1", body)
|
||||
body = re.sub(r"\s+", " ", body).strip()
|
||||
return body[:300]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _clean_content(content: str) -> str:
|
||||
"""Remove boilerplate, normalize old paths/names, collapse whitespace."""
|
||||
cleaned = _BOILERPLATE_RE.sub("", content)
|
||||
for pattern, replacement in _PATH_NORMALIZATIONS:
|
||||
cleaned = pattern.sub(replacement, cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _quality_flags(content: str) -> list[str]:
|
||||
"""Return a list of quality issue labels found in cleaned content."""
|
||||
flags = []
|
||||
if "Alex Rivera" in content or "[user]" in content:
|
||||
flags.append("persona-residue")
|
||||
if re.search(r"\bStatus:\s*(Complete|Done|Merged)\b", content):
|
||||
flags.append("completed-status")
|
||||
return flags
|
||||
|
||||
|
||||
def _make_instruction(
|
||||
title: str,
|
||||
product: str,
|
||||
doc_type: str,
|
||||
overview: str,
|
||||
design_context: str | None = None,
|
||||
variant: int = 0,
|
||||
) -> str:
|
||||
"""Synthesise a natural planning prompt for this document.
|
||||
|
||||
variant: 0-2 selects which paraphrase template to use. Caller cycles
|
||||
through all three to produce multiple training examples per document.
|
||||
"""
|
||||
product_label = product.replace("-", " ").title() if product else "CircuitForge"
|
||||
idx = variant % 3
|
||||
|
||||
if design_context:
|
||||
tmpl = _PAIRED_INSTRUCTIONS[idx]
|
||||
return tmpl.format(
|
||||
product=product_label,
|
||||
design_context=design_context[:2500],
|
||||
)
|
||||
|
||||
templates = _PLAN_INSTRUCTIONS if doc_type in ("plan",) else _DESIGN_INSTRUCTIONS
|
||||
tmpl = templates[idx]
|
||||
return tmpl.format(
|
||||
product=product_label,
|
||||
title=title,
|
||||
overview=overview or "",
|
||||
type_phrase="planning document",
|
||||
)
|
||||
|
||||
|
||||
def _record_id(content: str, source: str) -> str:
|
||||
return hashlib.sha256(f"{source}:{content}".encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
# ── Pair discovery ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _find_pairs(plans_dir: Path) -> dict[str, list[tuple[str, Path]]]:
|
||||
"""Return {prefix_key → [(doc_type, path), ...]} for docs sharing date+feature."""
|
||||
by_prefix: dict[str, list[tuple[str, Path]]] = {}
|
||||
for path in plans_dir.rglob("*.md"):
|
||||
if any(part in _SKIP_DIRS for part in path.parts):
|
||||
continue
|
||||
if path.name == "README.md":
|
||||
continue
|
||||
stem = path.stem
|
||||
date, feature = _date_feature(stem)
|
||||
if not date:
|
||||
continue
|
||||
key = str(path.parent / f"{date}-{feature}")
|
||||
by_prefix.setdefault(key, []).append((_doc_type(stem), path))
|
||||
return by_prefix
|
||||
|
||||
|
||||
# ── Record generation ──────────────────────────────────────────────────────────
|
||||
|
||||
def _records_for_group(
|
||||
doc_type_paths: list[tuple[str, Path]],
|
||||
plans_dir: Path,
|
||||
) -> Iterator[dict]:
|
||||
"""Yield one or more training records for a group of related docs."""
|
||||
# Separate design vs plan docs within this group
|
||||
designs = [(t, p) for t, p in doc_type_paths if t in ("design", "spec")]
|
||||
plans_ = [(t, p) for t, p in doc_type_paths if t in ("plan",)]
|
||||
others = [(t, p) for t, p in doc_type_paths if t not in ("design", "spec", "plan")]
|
||||
|
||||
all_paths = doc_type_paths
|
||||
|
||||
if designs and plans_:
|
||||
# Paired: yield a design→plan record (3 instruction variants)
|
||||
design_type, design_path = designs[0]
|
||||
plan_type, plan_path = plans_[0]
|
||||
design_content = design_path.read_text(encoding="utf-8")
|
||||
plan_content = plan_path.read_text(encoding="utf-8")
|
||||
|
||||
product = _product_from_path(plan_path, plans_dir)
|
||||
title = _extract_title(plan_content) or plan_path.stem
|
||||
cleaned = _clean_content(plan_content)
|
||||
design_cleaned = _clean_content(design_content)
|
||||
flags = _quality_flags(cleaned)
|
||||
|
||||
if len(cleaned.split()) >= 80:
|
||||
rel_src = str(plan_path.relative_to(plans_dir))
|
||||
rel_design = str(design_path.relative_to(plans_dir))
|
||||
for variant in range(3):
|
||||
instruction = _make_instruction(
|
||||
title=title,
|
||||
product=product,
|
||||
doc_type="plan",
|
||||
overview=_extract_overview(design_content),
|
||||
design_context=design_cleaned,
|
||||
variant=variant,
|
||||
)
|
||||
yield {
|
||||
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
|
||||
"messages": [
|
||||
{"role": "user", "content": instruction},
|
||||
{"role": "assistant", "content": cleaned},
|
||||
],
|
||||
"meta": {
|
||||
"source": rel_src,
|
||||
"product": product,
|
||||
"doc_type": "plan",
|
||||
"date": _date_feature(plan_path.stem)[0],
|
||||
"paired_with": rel_design,
|
||||
"word_count": len(cleaned.split()),
|
||||
"pair_role": "target",
|
||||
"variant": variant,
|
||||
"quality_flags": flags,
|
||||
},
|
||||
}
|
||||
|
||||
# Also yield the design doc as standalone variants
|
||||
all_paths = [(t, p) for t, p in all_paths if p != plan_path]
|
||||
|
||||
# Remaining docs as standalone records (3 instruction variants each)
|
||||
for doc_type, path in all_paths:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
cleaned = _clean_content(content)
|
||||
if len(cleaned.split()) < 80:
|
||||
continue
|
||||
|
||||
product = _product_from_path(path, plans_dir)
|
||||
title = _extract_title(content) or path.stem
|
||||
overview = _extract_overview(content)
|
||||
flags = _quality_flags(cleaned)
|
||||
rel_src = str(path.relative_to(plans_dir))
|
||||
|
||||
for variant in range(3):
|
||||
instruction = _make_instruction(
|
||||
title=title,
|
||||
product=product,
|
||||
doc_type=doc_type,
|
||||
overview=overview,
|
||||
variant=variant,
|
||||
)
|
||||
yield {
|
||||
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
|
||||
"messages": [
|
||||
{"role": "user", "content": instruction},
|
||||
{"role": "assistant", "content": cleaned},
|
||||
],
|
||||
"meta": {
|
||||
"source": rel_src,
|
||||
"product": product,
|
||||
"doc_type": doc_type,
|
||||
"date": _date_feature(path.stem)[0],
|
||||
"paired_with": None,
|
||||
"word_count": len(cleaned.split()),
|
||||
"pair_role": "standalone",
|
||||
"variant": variant,
|
||||
"quality_flags": flags,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _product_from_path(path: Path, plans_dir: Path) -> str:
|
||||
rel = path.relative_to(plans_dir)
|
||||
return rel.parts[0] if len(rel.parts) > 1 else "shared"
|
||||
|
||||
|
||||
# ── Main export ────────────────────────────────────────────────────────────────
|
||||
|
||||
def export(
|
||||
plans_dir: Path,
|
||||
products: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
groups = _find_pairs(plans_dir)
|
||||
records: list[dict] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for group_key, doc_type_paths in groups.items():
|
||||
# Filter by product if requested
|
||||
if products:
|
||||
paths = [p for _, p in doc_type_paths]
|
||||
prods = {_product_from_path(p, plans_dir) for p in paths}
|
||||
if not prods.intersection(products):
|
||||
continue
|
||||
|
||||
for record in _records_for_group(doc_type_paths, plans_dir):
|
||||
if record["id"] not in seen_ids:
|
||||
seen_ids.add(record["id"])
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_stats(records: list[dict]) -> None:
|
||||
from collections import Counter
|
||||
products = Counter(r["meta"]["product"] for r in records)
|
||||
doc_types = Counter(r["meta"]["doc_type"] for r in records)
|
||||
pair_roles = Counter(r["meta"]["pair_role"] for r in records)
|
||||
wc = [r["meta"]["word_count"] for r in records]
|
||||
wc.sort()
|
||||
|
||||
print(f"\n{'='*55}")
|
||||
print(f" Total records: {len(records)}")
|
||||
print(f" Word counts : min={wc[0]}, median={wc[len(wc)//2]}, max={wc[-1]}")
|
||||
print(f"\n By product:")
|
||||
for p, n in products.most_common():
|
||||
print(f" {p:<22} {n}")
|
||||
print(f"\n By doc type:")
|
||||
for t, n in doc_types.most_common():
|
||||
print(f" {t:<22} {n}")
|
||||
print(f"\n Pair roles:")
|
||||
for r, n in pair_roles.most_common():
|
||||
print(f" {r:<22} {n}")
|
||||
print(f"{'='*55}\n")
|
||||
|
||||
|
||||
def _print_sample(records: list[dict], n: int = 3) -> None:
|
||||
import random
|
||||
sample = random.sample(records, min(n, len(records)))
|
||||
for i, rec in enumerate(sample, 1):
|
||||
meta = rec["meta"]
|
||||
user_msg = rec["messages"][0]["content"]
|
||||
asst_msg = rec["messages"][1]["content"]
|
||||
print(f"\n{'─'*55}")
|
||||
print(f"SAMPLE {i}/{n} [{meta['product']} / {meta['doc_type']} / {meta['pair_role']}]")
|
||||
print(f"source: {meta['source']}")
|
||||
print(f"\nUSER ({len(user_msg)} chars):\n{user_msg[:500]}{'...' if len(user_msg)>500 else ''}")
|
||||
print(f"\nASSISTANT ({meta['word_count']} words):\n{asst_msg[:400]}{'...' if len(asst_msg)>400 else ''}")
|
||||
print(f"\n{'─'*55}\n")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser.add_argument("--plans-dir", type=Path, default=_DEFAULT_PLANS_DIR)
|
||||
parser.add_argument("--output", type=Path, default=None,
|
||||
help="Write JSONL to this path (omit for preview-only)")
|
||||
parser.add_argument("--products", default=None,
|
||||
help="Comma-separated product filter, e.g. peregrine,kiwi")
|
||||
parser.add_argument("--preview", action="store_true",
|
||||
help="Print stats + sample records, don't write output")
|
||||
parser.add_argument("--samples", type=int, default=3,
|
||||
help="Number of sample records to show in preview (default 3)")
|
||||
args = parser.parse_args()
|
||||
|
||||
products = [p.strip() for p in args.products.split(",")] if args.products else None
|
||||
|
||||
print(f"Scanning {args.plans_dir} …", file=sys.stderr)
|
||||
records = export(args.plans_dir, products=products)
|
||||
|
||||
_print_stats(records)
|
||||
|
||||
if args.preview or args.output is None:
|
||||
_print_sample(records, n=args.samples)
|
||||
if args.output is None:
|
||||
print("(Pass --output <path> to write JSONL)")
|
||||
return
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for rec in records:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Wrote {len(records)} records to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,23 +1,37 @@
|
|||
import json
|
||||
"""Smoke tests for the app factory (app/api.py).
|
||||
|
||||
Detailed route tests live in test_data_label.py, test_data_fetch.py,
|
||||
test_data_corrections.py, test_train.py, and test_dashboard.py.
|
||||
"""
|
||||
import pytest
|
||||
from app import api as api_module # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import api
|
||||
api.set_data_dir(tmp_path)
|
||||
api.reset_last_action()
|
||||
yield
|
||||
api.reset_last_action()
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_import():
|
||||
from app import api # noqa: F401
|
||||
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
def test_app_has_required_routes():
|
||||
from app.api import app
|
||||
paths = {r.path for r in app.routes}
|
||||
# Label routes
|
||||
assert "/api/queue" in paths
|
||||
assert "/api/label" in paths
|
||||
assert "/api/skip" in paths
|
||||
assert "/api/discard" in paths
|
||||
assert "/api/label/undo" in paths
|
||||
assert "/api/config/labels" in paths
|
||||
assert "/api/stats" in paths
|
||||
# Fetch routes
|
||||
assert "/api/accounts/test" in paths
|
||||
assert "/api/fetch/stream" in paths
|
||||
# Train routes
|
||||
assert "/api/train/jobs" in paths
|
||||
assert "/api/train/results" in paths
|
||||
# Dashboard
|
||||
assert "/api/dashboard" in paths
|
||||
# Corrections (new prefix)
|
||||
assert "/api/corrections/ingest" in paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -26,536 +40,8 @@ def client():
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_items():
|
||||
"""Write 3 test emails to the queue file."""
|
||||
from app import api as api_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
|
||||
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
def test_queue_endpoint_reachable(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = api_module._read_jsonl(api_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part A: POST /api/discard ---
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
# --- Part B: DELETE /api/label/undo ---
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["undone"]["type"] == "label"
|
||||
score = api_module._read_jsonl(api_module._score_file())
|
||||
assert score == []
|
||||
# Item should be restored to front of queue
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert discarded == []
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part C: GET /api/config/labels ---
|
||||
|
||||
def test_config_labels_returns_metadata(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
assert "emoji" in labels[0]
|
||||
assert "color" in labels[0]
|
||||
assert "name" in labels[0]
|
||||
|
||||
|
||||
# ── /api/config ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(tmp_path):
|
||||
"""Give the API a writable config directory."""
|
||||
from app import api as api_module
|
||||
api_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
api_module.set_config_dir(None) # reset to default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir():
|
||||
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
|
||||
from app import api as api_module
|
||||
return api_module._DATA_DIR
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client, config_dir):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, config_dir):
|
||||
import yaml
|
||||
payload = {
|
||||
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}],
|
||||
"max_per_account": 200,
|
||||
}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
cfg_file = config_dir / "label_tool.yaml"
|
||||
assert cfg_file.exists()
|
||||
saved = yaml.safe_load(cfg_file.read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, config_dir):
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
# ── /api/stats ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def score_with_labels(tmp_path, data_dir):
|
||||
"""Write a score file with 3 labels for stats tests."""
|
||||
score_path = data_dir / "email_score.jsonl"
|
||||
records = [
|
||||
{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"},
|
||||
]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
return records
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, score_with_labels):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, score_with_labels):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/accounts/test ───────────────────────────────────────────────────────
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
from unittest.mock import MagicMock, patch
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
"""Parse SSE response body into list of event dicts."""
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, config_dir):
|
||||
"""With no config, stream should immediately complete with 0 added."""
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
||||
"""With one configured account, stream should yield start/done/complete events."""
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Write a config with one account
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
# ---- /api/finetune/status tests ----
|
||||
|
||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
|
||||
"""GET /api/finetune/status must return [] if models/ does not exist."""
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
||||
import json as _json
|
||||
from app import api as api_module
|
||||
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
info = {
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401,
|
||||
}
|
||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
||||
|
||||
api_module.set_models_dir(tmp_path / "models")
|
||||
try:
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
finally:
|
||||
api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
|
||||
"""GET /api/finetune/run must return text/event-stream content type."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "text/event-stream" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_finetune_run_emits_complete_on_success(client):
|
||||
"""GET /api/finetune/run must emit a complete event on clean exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["progress line\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '{"type": "complete"}' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
|
||||
"""GET /api/finetune/run must emit an error event on non-zero exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '"type": "error"' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_passes_score_files_to_subprocess(client):
|
||||
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
captured_cmd = []
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
captured_cmd.extend(cmd)
|
||||
m = MagicMock()
|
||||
m.stdout = iter([])
|
||||
m.returncode = 0
|
||||
m.wait = MagicMock()
|
||||
return m
|
||||
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
|
||||
|
||||
assert "--score" in captured_cmd
|
||||
assert captured_cmd.count("--score") == 2
|
||||
# Paths are resolved to absolute — check filenames are present as substrings
|
||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
||||
|
||||
|
||||
# ---- Cancel endpoint tests ----
|
||||
|
||||
def test_benchmark_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/benchmark/cancel must return 404 if no benchmark is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_finetune_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_benchmark_cancel_terminates_running_process(client):
|
||||
"""POST /api/benchmark/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_cancel_terminates_running_process(client):
|
||||
"""POST /api/finetune/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["finetune"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_cancel_kills_process_on_timeout(client):
|
||||
"""POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
mock_proc.kill.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_run_emits_cancelled_event(client):
|
||||
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15 # SIGTERM
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("finetune")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_run_emits_cancelled_event(client):
|
||||
"""GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("benchmark")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/benchmark/run")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
assert "items" in r.json()
|
||||
assert "total" in r.json()
|
||||
|
|
|
|||
|
|
@ -2,11 +2,6 @@
|
|||
import pytest
|
||||
|
||||
|
||||
def test_registry_has_thirteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 13
|
||||
|
||||
|
||||
def test_registry_default_count():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
|
||||
|
|
@ -166,3 +161,95 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
|
|||
|
||||
assert "avocet-deberta-small" in models
|
||||
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
||||
|
||||
|
||||
# ---- build_exemplars_from_jsonl() tests ----
|
||||
|
||||
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [{"subject": f"S{i}", "body": f"B{i}", "label": "rejected"} for i in range(15)]
|
||||
rows.append({"subject": "Hire", "body": "Welcome", "label": "hired"})
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f), k_per_label=10)
|
||||
|
||||
assert len(result["rejected"]) == 10
|
||||
assert len(result["hired"]) == 1
|
||||
assert result["rejected"][0].startswith("Subject: S")
|
||||
|
||||
|
||||
def test_build_exemplars_formats_text_correctly(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
row = {"subject": "My Subject", "body": "My Body", "label": "neutral"}
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text(json.dumps(row))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
|
||||
assert result["neutral"][0] == "Subject: My Subject\n\nMy Body"
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_missing_label(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [
|
||||
{"subject": "A", "body": "B", "label": "neutral"},
|
||||
{"subject": "No label here", "body": "Body"},
|
||||
]
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
assert list(result.keys()) == ["neutral"]
|
||||
|
||||
|
||||
def test_build_exemplars_truncates_body_at_600(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
row = {"subject": "S", "body": "x" * 800, "label": "neutral"}
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text(json.dumps(row))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
body_part = result["neutral"][0].split("\n\n", 1)[1]
|
||||
assert len(body_part) == 600
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_with_no_content(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [
|
||||
{"label": "neutral"}, # no subject, no body -> skip
|
||||
{"subject": "S", "body": "B", "label": "neutral"}, # valid -> keep
|
||||
{"label": "rejected", "subject": "", "body": ""}, # empty strings -> skip
|
||||
]
|
||||
f = tmp_path / "score.jsonl"
|
||||
lines = [json.dumps(r) for r in rows]
|
||||
f.write_text("\n".join(lines))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
assert list(result.keys()) == ["neutral"]
|
||||
assert len(result["neutral"]) == 1
|
||||
|
||||
def test_registry_has_fourteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 14
|
||||
|
||||
|
||||
def test_embed_knn_nomic_registry_entry():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
entry = MODEL_REGISTRY["embed-knn-nomic"]
|
||||
assert entry["adapter"] is EmbeddingKNNAdapter
|
||||
assert entry["model_id"] == "nomic-embed-text"
|
||||
assert entry["params"] == "local-embed"
|
||||
assert entry["default"] is False
|
||||
assert entry.get("kwargs", {}).get("k") == 3
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ from fastapi.testclient import TestClient
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cforch_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path and reset running-state globals."""
|
||||
"""Redirect _CONFIG_DIR to tmp_path, reset running-state globals, and stub
|
||||
list_installed to return [] so real disk model directories don't bleed into
|
||||
tests that don't exercise the installed-model merge path."""
|
||||
from app import cforch as cforch_module
|
||||
|
||||
prev_config_dir = cforch_module._CONFIG_DIR
|
||||
|
|
@ -25,6 +27,7 @@ def reset_cforch_globals(tmp_path):
|
|||
cforch_module._BENCH_RUNNING = False
|
||||
cforch_module._bench_proc = None
|
||||
|
||||
with patch("app.models.list_installed", return_value=[]):
|
||||
yield tmp_path
|
||||
|
||||
cforch_module.set_config_dir(prev_config_dir)
|
||||
|
|
@ -141,12 +144,46 @@ def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
|
|||
assert m["vram_estimate_mb"] == 6000
|
||||
|
||||
|
||||
def test_models_merges_installed_generators(client, config_dir, tmp_path):
|
||||
"""Installed cf-text/vllm generator models appear in the model list,
|
||||
deduplicated against bench_models.yaml entries."""
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
_write_models_yaml(models_file, [
|
||||
{"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
|
||||
{"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
|
||||
])
|
||||
_write_config(config_dir, {"bench_models": str(models_file)})
|
||||
|
||||
fake_installed = [
|
||||
# should be included — cf-text generator not already in YAML
|
||||
{"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000},
|
||||
# should be deduped — repo_id matches a YAML entry
|
||||
{"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000},
|
||||
# should be excluded — classifier, not a generator
|
||||
{"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500},
|
||||
]
|
||||
with patch("app.models.list_installed", return_value=fake_installed):
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
ids = [m["id"] for m in r.json()["models"]]
|
||||
assert "llama3:8b" in ids # from YAML
|
||||
assert "ibm-granite/granite-4.1-8b" in ids # from YAML (not duplicated)
|
||||
assert "meta-llama/Llama-3.1-8B" in ids # merged from installed
|
||||
assert "cross-encoder/ms-marco-MiniLM-L6" not in ids # filtered out (reranker)
|
||||
assert ids.count("ibm-granite/granite-4.1-8b") == 1 # no duplicate
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_returns_409_when_already_running(client):
|
||||
"""If _BENCH_RUNNING is True, GET /run returns 409."""
|
||||
"""If a benchmark subprocess is actively running, GET /run returns 409."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import cforch as cforch_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None # process still alive
|
||||
cforch_module._BENCH_RUNNING = True
|
||||
cforch_module._bench_proc = mock_proc
|
||||
|
||||
r = client.get("/api/cforch/run")
|
||||
assert r.status_code == 409
|
||||
|
|
@ -180,16 +217,15 @@ def test_run_streams_progress_events(client, config_dir, tmp_path):
|
|||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
mock_stdout = MagicMock()
|
||||
mock_stdout.readline.side_effect = ["Running task 1\n", "Running task 2\n", ""]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Running task 1\n", "Running task 2\n"])
|
||||
mock_proc.stdout = mock_stdout
|
||||
mock_proc.returncode = 1 # non-zero so we don't need summary.json
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
def mock_wait():
|
||||
pass
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc):
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
|
||||
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
|
||||
r = client.get("/api/cforch/run")
|
||||
|
||||
assert r.status_code == 200
|
||||
|
|
@ -222,12 +258,15 @@ def test_run_emits_result_on_success(client, config_dir, tmp_path):
|
|||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
mock_stdout = MagicMock()
|
||||
mock_stdout.readline.side_effect = [""] # no output lines, immediate EOF
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.stdout = mock_stdout
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc):
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
|
||||
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
|
||||
r = client.get("/api/cforch/run")
|
||||
|
||||
assert r.status_code == 200
|
||||
|
|
@ -367,3 +406,13 @@ def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path,
|
|||
client.get("/api/cforch/run")
|
||||
|
||||
assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
|
||||
|
||||
|
||||
def test_eval_cforch_router_includes_all_sub_routers():
|
||||
"""eval/cforch.py router must include routes from all four sub-routers."""
|
||||
from app.eval.cforch import router
|
||||
paths = {r.path for r in router.routes}
|
||||
assert any("/cforch/" in p for p in paths), f"no /cforch/ routes found in {paths}"
|
||||
assert any("/style/" in p for p in paths), f"no /style/ routes found in {paths}"
|
||||
assert any("/voice/" in p for p in paths), f"no /voice/ routes found in {paths}"
|
||||
assert any("/plans-bench/" in p for p in paths), f"no /plans-bench/ routes found in {paths}"
|
||||
|
|
|
|||
|
|
@ -268,3 +268,373 @@ def test_finetuned_adapter_unload_clears_pipeline():
|
|||
assert adapter._pipeline is not None
|
||||
adapter.unload()
|
||||
assert adapter._pipeline is None
|
||||
|
||||
# ---- _cosine() tests ----
|
||||
|
||||
def test_cosine_identical_unit_vectors():
|
||||
import math
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal_vectors():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_known_value():
|
||||
import math
|
||||
from scripts.classifier_adapters import _cosine
|
||||
# [1,0] vs [1/sqrt(2), 1/sqrt(2)] → dot = 1/sqrt(2), both norms = 1 → 1/sqrt(2)
|
||||
v = [1.0 / math.sqrt(2), 1.0 / math.sqrt(2)]
|
||||
assert _cosine([1.0, 0.0], v) == pytest.approx(1.0 / math.sqrt(2))
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ---- DEFAULT_EXEMPLARS tests ----
|
||||
|
||||
def test_default_exemplars_covers_all_labels():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS, LABELS
|
||||
for label in LABELS:
|
||||
assert label in DEFAULT_EXEMPLARS, f"DEFAULT_EXEMPLARS missing label: {label}"
|
||||
assert len(DEFAULT_EXEMPLARS[label]) >= 4, f"{label} needs >= 4 exemplars for k=3 voting"
|
||||
|
||||
|
||||
def test_default_exemplars_sparse_labels_have_at_least_four():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
|
||||
# These labels have very few real examples; need >= 4 so k=3 vote is meaningful
|
||||
for label in ("hired", "survey_received", "event_rescheduled"):
|
||||
assert len(DEFAULT_EXEMPLARS[label]) >= 4, (
|
||||
f"{label} needs >= 4 exemplars for k=3 voting to work reliably"
|
||||
)
|
||||
|
||||
def test_default_exemplars_strings_are_formatted_correctly():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
|
||||
for label, texts in DEFAULT_EXEMPLARS.items():
|
||||
for text in texts:
|
||||
assert text.startswith("Subject: "), (
|
||||
f"{label!r} exemplar missing 'Subject: ' prefix: {text[:50]!r}"
|
||||
)
|
||||
assert "\n\n" in text, (
|
||||
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
|
||||
)
|
||||
|
||||
# ---- EmbeddingKNNAdapter constructor tests ----
|
||||
|
||||
def test_embedding_knn_is_classifier_adapter():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test-knn", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert isinstance(adapter, ClassifierAdapter)
|
||||
|
||||
|
||||
def test_embedding_knn_name_and_model_id():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"embed-knn-nomic", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter.name == "embed-knn-nomic"
|
||||
assert adapter.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_embedding_knn_uses_default_exemplars_when_none_given():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
|
||||
|
||||
|
||||
def test_embedding_knn_accepts_custom_exemplars():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
custom = {"rejected": ["Sorry, we went with others."]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=custom,
|
||||
)
|
||||
assert adapter._exemplar_texts is custom
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.load() tests ----
|
||||
|
||||
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
|
||||
"""Return a side_effect function for patching httpx.post.
|
||||
|
||||
Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3]
|
||||
embedding per input text.
|
||||
"""
|
||||
def _side_effect(url, *, json=None, timeout=None, **kwargs):
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
|
||||
else:
|
||||
n = len((json or {}).get("input", []))
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n}
|
||||
return resp
|
||||
return _side_effect
|
||||
|
||||
|
||||
def test_load_calls_allocate_then_embeds_each_label():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {
|
||||
"rejected": ["We went with others"],
|
||||
"hired": ["Welcome aboard!", "First day info"],
|
||||
}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
post_urls = []
|
||||
def capturing_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
post_urls.append(url)
|
||||
return _make_post_mock()(url, json=json, timeout=timeout)
|
||||
|
||||
with patch("httpx.post", side_effect=capturing_mock):
|
||||
adapter.load()
|
||||
|
||||
assert any("/allocate" in u for u in post_urls), "expected allocate call"
|
||||
assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
|
||||
assert adapter._allocation_id == "alloc-abc"
|
||||
assert adapter._node_url == "http://navi:11434"
|
||||
assert adapter._orch_url_used == "http://orch:7700"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
assert "hired" in adapter._exemplar_embeddings
|
||||
assert len(adapter._exemplar_embeddings["rejected"]) == 1
|
||||
assert len(adapter._exemplar_embeddings["hired"]) == 2
|
||||
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
|
||||
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_fails():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
resp = MagicMock()
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 503
|
||||
resp.json.return_value = {}
|
||||
else:
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=failing_allocate_mock):
|
||||
adapter.load()
|
||||
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
assert adapter._node_url == "http://ollama:11434"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_raises():
|
||||
from unittest.mock import patch, MagicMock
|
||||
import httpx as _httpx
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
def raising_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
if "/allocate" in url:
|
||||
raise _httpx.ConnectError("connection refused")
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=raising_mock):
|
||||
adapter.load()
|
||||
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
assert adapter._node_url == "http://ollama:11434"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.unload() tests ----
|
||||
|
||||
def test_unload_releases_orch_allocation_and_clears_state():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
|
||||
adapter._node_url = "http://navi:11434"
|
||||
adapter._allocation_id = "alloc-abc"
|
||||
adapter._orch_url_used = "http://orch:7700"
|
||||
|
||||
delete_calls = []
|
||||
def mock_request(method, url, **kwargs):
|
||||
delete_calls.append((method, url))
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
with patch("httpx.request", side_effect=mock_request):
|
||||
adapter.unload()
|
||||
|
||||
assert len(delete_calls) == 1
|
||||
method, url = delete_calls[0]
|
||||
assert method == "DELETE"
|
||||
assert "alloc-abc" in url
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._node_url == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
|
||||
|
||||
def test_unload_skips_delete_on_ollama_fallback_path():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
|
||||
adapter._node_url = "http://ollama:11434"
|
||||
adapter._allocation_id = "" # fallback path: no allocation was made
|
||||
adapter._orch_url_used = ""
|
||||
|
||||
delete_calls = []
|
||||
with patch("httpx.request", side_effect=lambda *a, **k: delete_calls.append(a)):
|
||||
adapter.unload()
|
||||
|
||||
assert len(delete_calls) == 0
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
assert adapter._node_url == ""
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.classify() tests ----
|
||||
|
||||
def _adapter_with_embeddings(exemplar_embeddings, k=3):
|
||||
"""Return a pre-loaded EmbeddingKNNAdapter (bypass load()) with given per-label vectors."""
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=k,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = exemplar_embeddings
|
||||
adapter._node_url = "http://navi:11434"
|
||||
return adapter
|
||||
|
||||
|
||||
def _embed_resp(vec):
|
||||
"""Return a mock httpx response for /v1/embeddings returning a single vector."""
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": vec}]}
|
||||
return resp
|
||||
|
||||
|
||||
def test_classify_returns_majority_vote_label():
|
||||
from unittest.mock import patch
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
|
||||
"neutral": [[0.0, 1.0, 0.0]],
|
||||
}, k=3)
|
||||
|
||||
# Query [1,0,0] is closest to all three "rejected" exemplars
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("We went with others", "Thank you for applying.")
|
||||
|
||||
assert result == "rejected"
|
||||
|
||||
|
||||
def test_classify_tiebreak_by_mean_score():
|
||||
from unittest.mock import patch
|
||||
# k=2: each label gets exactly 1 vote → tie-break by mean similarity
|
||||
# [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] ≈ 0.6 ("neutral")
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[1.0, 0.0]],
|
||||
"neutral": [[0.6, 0.8]],
|
||||
}, k=2)
|
||||
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
|
||||
result = adapter.classify("Rejection", "Sorry")
|
||||
|
||||
assert result == "rejected"
|
||||
|
||||
|
||||
def test_classify_sparse_label_can_win():
|
||||
from unittest.mock import patch
|
||||
# "hired" has only 1 exemplar; with k=1, the single closest match wins
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
|
||||
"hired": [[1.0, 0.0, 0.0]],
|
||||
}, k=1)
|
||||
|
||||
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("Welcome aboard", "Your first day details")
|
||||
|
||||
assert result == "hired"
|
||||
|
||||
|
||||
def test_classify_lazy_loads_when_not_loaded():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=1,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
|
||||
post_urls = []
|
||||
def mock_post(url, *, json=None, timeout=None, **kwargs):
|
||||
post_urls.append(url)
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
|
||||
else:
|
||||
n = len((json or {}).get("input", []))
|
||||
resp.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * n}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=mock_post):
|
||||
result = adapter.classify("Rejection", "Sorry")
|
||||
|
||||
assert result == "rejected"
|
||||
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
|
||||
assert adapter._exemplar_embeddings != {}
|
||||
assert adapter._node_url == "http://navi:11434"
|
||||
|
|
|
|||
122
tests/test_dashboard.py
Normal file
122
tests/test_dashboard.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Tests for app/dashboard.py -- GET /api/dashboard."""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import dashboard as dash_module
|
||||
dash_module.set_data_dir(tmp_path)
|
||||
dash_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _write_score(tmp_path: Path, records: list[dict]) -> None:
|
||||
(tmp_path / "email_score.jsonl").write_text(
|
||||
"\n".join(json.dumps(r) for r in records) + "\n"
|
||||
)
|
||||
|
||||
def _write_summary(tmp_path: Path, run_id: str, ts: str, score: float) -> None:
|
||||
run_dir = tmp_path / "bench_results" / run_id
|
||||
run_dir.mkdir(parents=True)
|
||||
(run_dir / "summary.json").write_text(
|
||||
json.dumps({"timestamp": ts, "best_macro_f1": score})
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_returns_expected_keys(client):
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
for key in ("labeled_since_last_eval", "last_eval_timestamp", "last_eval_best_score",
|
||||
"active_jobs", "corrections_pending", "corrections_export_ready", "signals"):
|
||||
assert key in data, f"missing key: {key}"
|
||||
for sig in ("data_to_eval", "eval_to_train", "train_to_fleet"):
|
||||
assert sig in data["signals"], f"missing signal: {sig}"
|
||||
|
||||
|
||||
def test_dashboard_empty_state(client):
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["labeled_since_last_eval"] == 0
|
||||
assert data["last_eval_timestamp"] is None
|
||||
assert data["last_eval_best_score"] is None
|
||||
assert data["active_jobs"] == []
|
||||
assert data["corrections_pending"] == 0
|
||||
assert data["corrections_export_ready"] == 0
|
||||
|
||||
|
||||
def test_labeled_since_counts_all_when_no_eval(client, tmp_path):
|
||||
_write_score(tmp_path, [
|
||||
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"},
|
||||
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
|
||||
])
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["labeled_since_last_eval"] == 2
|
||||
|
||||
|
||||
def test_labeled_since_filters_by_eval_timestamp(client, tmp_path):
|
||||
_write_summary(tmp_path, "2026-05-01-100000", "2026-05-01T10:00:00+00:00", 0.80)
|
||||
_write_score(tmp_path, [
|
||||
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T09:00:00+00:00"},
|
||||
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
|
||||
])
|
||||
(tmp_path / "label_tool.yaml").write_text(
|
||||
yaml.dump({"cforch": {"results_dir": str(tmp_path / "bench_results")}})
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
data = r.json()
|
||||
assert data["labeled_since_last_eval"] == 1
|
||||
assert abs(data["last_eval_best_score"] - 0.80) < 0.001
|
||||
|
||||
|
||||
def test_data_to_eval_false_below_threshold(client, tmp_path):
|
||||
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
|
||||
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(10)])
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["signals"]["data_to_eval"] is False
|
||||
|
||||
|
||||
def test_data_to_eval_true_at_threshold(client, tmp_path):
|
||||
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
|
||||
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(50)])
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["signals"]["data_to_eval"] is True
|
||||
|
||||
|
||||
def test_corrections_pending_count(client, tmp_path):
|
||||
candidates = [
|
||||
{"id": "c1", "status": "needs_review"},
|
||||
{"id": "c2", "status": "needs_review"},
|
||||
{"id": "c3", "status": "discarded"},
|
||||
]
|
||||
(tmp_path / "sft_candidates.jsonl").write_text(
|
||||
"\n".join(json.dumps(c) for c in candidates) + "\n"
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["corrections_pending"] == 2
|
||||
|
||||
|
||||
def test_corrections_export_ready_count(client, tmp_path):
|
||||
approved = [
|
||||
{"id": "a1", "status": "approved", "corrected_response": "Good answer"},
|
||||
{"id": "a2", "status": "approved", "corrected_response": ""},
|
||||
{"id": "a3", "status": "approved", "corrected_response": "Another answer"},
|
||||
]
|
||||
(tmp_path / "sft_approved.jsonl").write_text(
|
||||
"\n".join(json.dumps(a) for a in approved) + "\n"
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["corrections_export_ready"] == 2
|
||||
102
tests/test_data_corrections.py
Normal file
102
tests/test_data_corrections.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Tests for app/data/corrections.py -- POST /api/sft/ingest.
|
||||
|
||||
The corrections router is mounted at prefix="/api/sft" via the app/sft.py
|
||||
backward-compat shim, so ingest lives at /api/sft/ingest.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
_VALID_PAYLOAD = {
|
||||
"source": "peregrine",
|
||||
"task_type": "email_classification",
|
||||
"prompt": "Classify this email: ...",
|
||||
"response": "skip",
|
||||
"correction": "action_required",
|
||||
"label": "action_required",
|
||||
}
|
||||
|
||||
_SECRET = "test-secret-abc123"
|
||||
|
||||
|
||||
def test_ingest_503_when_secret_not_configured(client, monkeypatch):
|
||||
monkeypatch.delenv("AVOCET_INGESTION_SECRET", raising=False)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 503
|
||||
|
||||
|
||||
def test_ingest_401_when_no_auth_header(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD)
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_401_when_malformed_header(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": "Token bad-format"})
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_403_when_wrong_secret(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": "Bearer wrong-secret"})
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_ingest_creates_approved_record(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert "id" in data
|
||||
candidates = corr_module.read_jsonl(corr_module._candidates_file())
|
||||
assert len(candidates) == 1
|
||||
rec = candidates[0]
|
||||
assert rec["status"] == "approved"
|
||||
assert rec["source"] == "peregrine"
|
||||
assert rec["corrected_response"] == "action_required"
|
||||
assert rec["id"] == data["id"]
|
||||
|
||||
|
||||
def test_ingest_also_writes_to_approved_file(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
approved = corr_module.read_jsonl(corr_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["id"] == r.json()["id"]
|
||||
|
||||
|
||||
def test_ingest_without_label_is_accepted(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
payload = {**_VALID_PAYLOAD, "label": None}
|
||||
r = client.post("/api/sft/ingest", json=payload,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
95
tests/test_data_fetch.py
Normal file
95
tests/test_data_fetch.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Tests for app/data/fetch.py"""
|
||||
import json
|
||||
import yaml
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import fetch as fetch_module
|
||||
fetch_module.set_data_dir(tmp_path)
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test",
|
||||
json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, tmp_path):
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, tmp_path):
|
||||
from app.data import fetch as fetch_module
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
def test_entry_key_deterministic():
|
||||
from app.data.fetch import entry_key
|
||||
e = {"subject": "Test", "body": "Hello world"}
|
||||
assert entry_key(e) == entry_key(e)
|
||||
|
||||
|
||||
def test_entry_key_differs_by_subject():
|
||||
from app.data.fetch import entry_key
|
||||
a = {"subject": "A", "body": "same body"}
|
||||
b = {"subject": "B", "body": "same body"}
|
||||
assert entry_key(a) != entry_key(b)
|
||||
219
tests/test_data_label.py
Normal file
219
tests/test_data_label.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""Tests for app/data/label.py"""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
label_module.set_config_dir(tmp_path)
|
||||
label_module.reset_last_action()
|
||||
yield
|
||||
label_module.reset_last_action()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_items(tmp_path):
|
||||
from app.data import label as label_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
(label_module._DATA_DIR / "email_label_queue.jsonl").write_text(
|
||||
"\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = label_module.read_jsonl(label_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = label_module.read_jsonl(label_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["undone"]["type"] == "label"
|
||||
assert label_module.read_jsonl(label_module._score_file()) == []
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
assert label_module.read_jsonl(label_module._discarded_file()) == []
|
||||
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_config_labels_returns_10_labels(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
for lbl in labels:
|
||||
assert "emoji" in lbl and "color" in lbl and "name" in lbl
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_config_dir(tmp_path)
|
||||
payload = {"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}], "max_per_account": 200}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
saved = yaml.safe_load((tmp_path / "label_tool.yaml").read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_config_dir(tmp_path)
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
score_path = tmp_path / "email_score.jsonl"
|
||||
records = [{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"}]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
(tmp_path / "email_score.jsonl").write_text(json.dumps({"id": "a", "label": "neutral"}) + "\n")
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
234
tests/test_embed_bench.py
Normal file
234
tests/test_embed_bench.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
"""Tests for app/eval/embed_bench.py."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_embed_bench_globals(tmp_path):
|
||||
"""Redirect config dir to tmp_path and reset running flag."""
|
||||
from app.eval import embed_bench as mod
|
||||
|
||||
prev_config_dir = mod._CONFIG_DIR
|
||||
prev_running = mod._RUN_ACTIVE
|
||||
|
||||
mod.set_config_dir(tmp_path)
|
||||
mod._RUN_ACTIVE = False
|
||||
|
||||
yield tmp_path
|
||||
|
||||
mod.set_config_dir(prev_config_dir)
|
||||
mod._RUN_ACTIVE = prev_running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ── cosine helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_cosine_identical():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_opposite():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ── models endpoint ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_models_returns_list_with_mock(client, tmp_path):
|
||||
"""GET /api/embed-bench/models returns list from Ollama tags endpoint."""
|
||||
import yaml
|
||||
cfg = {"cforch": {"ollama_url": "http://localhost:11434"}}
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {
|
||||
"models": [
|
||||
{"name": "nomic-embed-text", "size": 274302480},
|
||||
{"name": "mxbai-embed-large", "size": 669000000},
|
||||
]
|
||||
}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/embed-bench/models")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert isinstance(data["models"], list)
|
||||
assert any(m["name"] == "nomic-embed-text" for m in data["models"])
|
||||
|
||||
|
||||
def test_models_returns_empty_on_ollama_error(client, tmp_path):
|
||||
"""GET /api/embed-bench/models returns empty list if Ollama unreachable."""
|
||||
import httpx
|
||||
with patch("app.eval.embed_bench.httpx.get", side_effect=httpx.ConnectError("refused")):
|
||||
r = client.get("/api/embed-bench/models")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["models"] == []
|
||||
|
||||
|
||||
# ── run endpoint ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_empty_corpus_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": [], "queries": ["test"], "models": ["nomic-embed-text"], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_run_empty_queries_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1"], "queries": [], "models": ["nomic-embed-text"], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_run_empty_models_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1"], "queries": ["test"], "models": [], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def _fake_embed_response(texts: list[str]) -> MagicMock:
|
||||
"""Build a mock httpx.post response returning unit vectors for each text."""
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.json.return_value = {
|
||||
"data": [{"embedding": [1.0, 0.0, 0.0] if i % 2 == 0 else [0.0, 1.0, 0.0]}
|
||||
for i, _ in enumerate(texts)]
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
def _collect_sse(raw: bytes) -> list[dict]:
|
||||
"""Parse SSE stream bytes into a list of decoded event dicts."""
|
||||
events = []
|
||||
for line in raw.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_run_single_model_returns_result_and_done(client, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk 1", "chunk 2"])):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1", "chunk 2"],
|
||||
"queries": ["what is chunk one?"],
|
||||
"models": ["nomic-embed-text"],
|
||||
"top_k": 2,
|
||||
})
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _collect_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "result" in types
|
||||
assert types[-1] == "done"
|
||||
result_events = [e for e in events if e["type"] == "result"]
|
||||
assert result_events[0]["model"] == "nomic-embed-text"
|
||||
assert result_events[0]["query_idx"] == 0
|
||||
assert len(result_events[0]["hits"]) <= 2
|
||||
|
||||
|
||||
def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk A", "chunk B"])):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk A", "chunk B"],
|
||||
"queries": ["find it"],
|
||||
"models": ["nomic-embed-text", "mxbai-embed-large"],
|
||||
"top_k": 2,
|
||||
})
|
||||
|
||||
events = _collect_sse(r.content)
|
||||
result_events = [e for e in events if e["type"] == "result"]
|
||||
models_seen = {e["model"] for e in result_events}
|
||||
assert "nomic-embed-text" in models_seen
|
||||
assert "mxbai-embed-large" in models_seen
|
||||
|
||||
|
||||
# ── rate + export ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_rate_appends_jsonl_line(client, tmp_path):
|
||||
r = client.post("/api/embed-bench/rate", json={
|
||||
"query": "test query",
|
||||
"model": "nomic-embed-text",
|
||||
"chunk_text": "some text",
|
||||
"chunk_idx": 2,
|
||||
"rating": "relevant",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
ratings_file = tmp_path / "embed_bench_ratings.jsonl"
|
||||
assert ratings_file.exists()
|
||||
line = json.loads(ratings_file.read_text().strip())
|
||||
assert line["query"] == "test query"
|
||||
assert line["rating"] == "relevant"
|
||||
assert line["chunk_idx"] == 2
|
||||
assert "timestamp" in line
|
||||
|
||||
|
||||
def test_export_csv_two_rows(client, tmp_path):
|
||||
for i in range(2):
|
||||
client.post("/api/embed-bench/rate", json={
|
||||
"query": f"q{i}", "model": "nomic-embed-text",
|
||||
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "relevant",
|
||||
})
|
||||
r = client.get("/api/embed-bench/export?format=csv")
|
||||
assert r.status_code == 200
|
||||
assert "text/csv" in r.headers["content-type"]
|
||||
lines = r.text.strip().splitlines()
|
||||
assert len(lines) == 3 # header + 2 rows
|
||||
assert "query" in lines[0]
|
||||
|
||||
|
||||
def test_export_json_two_entries(client, tmp_path):
|
||||
for i in range(2):
|
||||
client.post("/api/embed-bench/rate", json={
|
||||
"query": f"q{i}", "model": "nomic-embed-text",
|
||||
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "not_relevant",
|
||||
})
|
||||
r = client.get("/api/embed-bench/export?format=json")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
assert data[0]["rating"] == "not_relevant"
|
||||
|
||||
|
||||
def test_export_empty_returns_csv_header_only(client):
|
||||
r = client.get("/api/embed-bench/export?format=csv")
|
||||
assert r.status_code == 200
|
||||
lines = r.text.strip().splitlines()
|
||||
assert len(lines) == 1 # header only
|
||||
assert "query" in lines[0]
|
||||
|
|
@ -321,6 +321,7 @@ def test_load_and_prepare_data_single_path_still_works(tmp_path):
|
|||
|
||||
# ---- Integration test ----
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_integration_finetune_on_example_data(tmp_path):
|
||||
"""Fine-tune deberta-small on example data for 1 epoch.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for app/imitate.py — product registry, sample extraction, corrections push."""
|
||||
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
|
@ -9,10 +9,10 @@ import pytest
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api import app
|
||||
from app import imitate as _imitate_module
|
||||
from app.data import imitate as _imitate_module
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||
# -- Fixtures ------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_globals(tmp_path):
|
||||
|
|
@ -70,7 +70,7 @@ def client() -> TestClient:
|
|||
return TestClient(app, raise_server_exceptions=True)
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
# -- GET /products -------------------------------------------------------------
|
||||
|
||||
def test_products_empty_when_no_config(config_dir, client):
|
||||
"""Returns empty list when label_tool.yaml has no imitate section."""
|
||||
|
|
@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client):
|
|||
assert all(not p["online"] for p in resp.json()["products"])
|
||||
|
||||
|
||||
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
|
||||
# -- GET /products/{id}/sample -------------------------------------------------
|
||||
|
||||
def test_sample_unknown_product(cfg_with_products, client):
|
||||
"""Returns 404 for a product id not in config."""
|
||||
|
|
@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client):
|
|||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
# -- POST /push-corrections ----------------------------------------------------
|
||||
|
||||
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
|
||||
"""Successful push writes records to sft_candidates.jsonl."""
|
||||
|
|
@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
|
|||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── _extract_sample helper ─────────────────────────────────────────────────────
|
||||
# -- _extract_sample helper ----------------------------------------------------
|
||||
|
||||
def test_extract_sample_list():
|
||||
result = _imitate_module._extract_sample(
|
||||
|
|
|
|||
454
tests/test_log_corpus.py
Normal file
454
tests/test_log_corpus.py
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
"""Tests for app/data/log_corpus.py — corpus receiver and labeling endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.data import log_corpus as lc
|
||||
|
||||
|
||||
VALID_TOKEN = str(uuid.uuid4())
|
||||
VALID_HOST = "testnode.local"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path, monkeypatch):
|
||||
"""Each test gets its own fresh corpus DB and config dir."""
|
||||
monkeypatch.setattr(lc, "_DATA_DIR", tmp_path)
|
||||
monkeypatch.setattr(lc, "_DB_PATH", tmp_path / "corpus.db")
|
||||
# Config dir pointing to a temp yaml with one test source
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n sources:\n"
|
||||
f" - token: \"{VALID_TOKEN}\"\n"
|
||||
f" source_host: \"{VALID_HOST}\"\n"
|
||||
f" owner: TestOwner\n"
|
||||
f" consent_date: \"2026-05-11\"\n"
|
||||
f" consent_method: signal_chat\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
lc._init_db()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(lc.router, prefix="/api/corpus")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _batch(batch_type="raw_entries", entries=None, source_host=VALID_HOST):
|
||||
return {
|
||||
"batch_version": 1,
|
||||
"batch_id": str(uuid.uuid4()),
|
||||
"pushed_at": "2026-05-11T10:00:00Z",
|
||||
"source_host": source_host,
|
||||
"batch_type": batch_type,
|
||||
"watermark_from": 0,
|
||||
"watermark_to": 5,
|
||||
"entries": entries or [
|
||||
{
|
||||
"entry_id": str(uuid.uuid4()),
|
||||
"source_id": "sonarr",
|
||||
"timestamp_iso": "2026-05-11T09:58:00Z",
|
||||
"severity": "ERROR",
|
||||
"text": "Connection refused to indexer",
|
||||
"matched_patterns": [],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ── Receive endpoint ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_receive_missing_auth(client):
|
||||
resp = client.post("/api/corpus/log-batch", json=_batch())
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_receive_invalid_token(client):
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": "Bearer bad-token"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_receive_valid_batch(client):
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["received"] is True
|
||||
assert data["entries_stored"] == 1
|
||||
|
||||
|
||||
def test_receive_stores_source_host_from_token_not_payload(client):
|
||||
"""source_host is always taken from the DB lookup, not the payload."""
|
||||
payload = _batch(source_host="attacker-injected-host")
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
entries_resp = client.get("/api/corpus/entries")
|
||||
entry = entries_resp.json()["entries"][0]
|
||||
assert entry["source_host"] == VALID_HOST
|
||||
|
||||
|
||||
def test_receive_skips_empty_text_entries(client):
|
||||
payload = _batch(entries=[
|
||||
{"entry_id": "e1", "source_id": "svc", "severity": "ERROR", "text": ""},
|
||||
{"entry_id": "e2", "source_id": "svc", "severity": "ERROR", "text": " "},
|
||||
{"entry_id": "e3", "source_id": "svc", "severity": "ERROR", "text": "real error"},
|
||||
])
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.json()["entries_stored"] == 1
|
||||
|
||||
|
||||
def test_receive_incident_bundle(client):
|
||||
payload = _batch(batch_type="incident_bundles", entries=[
|
||||
{"id": "inc-1", "label": "plex crash", "issue_type": "plex",
|
||||
"started_at": "2026-05-11T09:00:00", "ended_at": "2026-05-11T09:30:00",
|
||||
"notes": "audio dropped", "created_at": "2026-05-11T09:35:00",
|
||||
"severity": "high", "text": "plex crash"},
|
||||
])
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries_stored"] == 1
|
||||
|
||||
|
||||
# ── Labeling endpoints ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_label_entry(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Sonarr lost connection to its indexer — restart the service.",
|
||||
"known_pattern": "y",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["labeled"] is True
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"state": "labeled"}).json()["entries"]
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["failure_type"] == "software"
|
||||
|
||||
|
||||
def test_label_entry_invalid_failure_type(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={"failure_type": "aliens"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_label_entry_missing_failure_type(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_label_entry_not_found(client):
|
||||
resp = client.post("/api/corpus/entries/nonexistent/label", json={"failure_type": "software"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_skip_entry(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/skip")
|
||||
assert resp.status_code == 200
|
||||
|
||||
unlabeled = client.get("/api/corpus/entries").json()["entries"]
|
||||
assert len(unlabeled) == 0
|
||||
|
||||
|
||||
# ── Stats ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_stats_empty(client):
|
||||
stats = client.get("/api/corpus/stats").json()
|
||||
assert stats["total_entries"] == 0
|
||||
assert stats["batch_count"] == 0
|
||||
|
||||
|
||||
def test_stats_after_receive(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
stats = client.get("/api/corpus/stats").json()
|
||||
assert stats["total_entries"] == 1
|
||||
assert stats["batch_count"] == 1
|
||||
assert stats["by_label_state"].get("unlabeled", 0) == 1
|
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_excludes_unlabeled(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_includes_labeled(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Sonarr lost connection to indexer.",
|
||||
})
|
||||
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.status_code == 200
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
assert len(lines) == 1
|
||||
record = json.loads(lines[0])
|
||||
assert record["output"] == "Sonarr lost connection to indexer."
|
||||
assert record["metadata"]["failure_type"] == "software"
|
||||
|
||||
|
||||
def test_export_excludes_pii_flagged(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Contains username — should not export.",
|
||||
"pii_flagged": True,
|
||||
})
|
||||
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
# ── Pipeline ingest endpoint ───────────────────────────────────────────────────
|
||||
|
||||
def _make_pipeline_file(directory: Path, name: str, lines: list[dict]) -> Path:
|
||||
"""Write a JSONL pipeline log file to directory."""
|
||||
p = directory / name
|
||||
p.write_text("\n".join(json.dumps(l) for l in lines), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
_PIPELINE_LINE = {
|
||||
"ts": "2026-05-17T10:00:00Z",
|
||||
"level": "INFO",
|
||||
"logger": "scripts.pipeline.purple_carrot_scraper",
|
||||
"msg": "Fetched recipe page",
|
||||
"extra": {"url": "https://example.com/recipe/1", "status": 200},
|
||||
}
|
||||
|
||||
|
||||
def test_pipeline_ingest_returns_404_when_dir_not_configured(client, tmp_path):
|
||||
"""No pipeline_ingest_dir in config — endpoint returns 404."""
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_pipeline_ingest_empty_dir(client, tmp_path, monkeypatch):
|
||||
"""Configured dir exists but is empty — returns zeros, no error."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ingested_files"] == 0
|
||||
assert data["skipped_files"] == 0
|
||||
assert data["entries_stored"] == 0
|
||||
|
||||
|
||||
def test_pipeline_ingest_ingests_valid_file(client, tmp_path, monkeypatch):
|
||||
"""Valid JSONL file is ingested; entries appear in corpus."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [
|
||||
_PIPELINE_LINE,
|
||||
{**_PIPELINE_LINE, "msg": "Saved 3 recipes", "level": "INFO"},
|
||||
])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ingested_files"] == 1
|
||||
assert data["entries_stored"] == 2
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert len(entries) == 2
|
||||
assert all(e["source_host"] == "pipeline_scrape" for e in entries)
|
||||
|
||||
|
||||
def test_pipeline_ingest_source_id_from_logger(client, tmp_path, monkeypatch):
|
||||
"""source_id is populated from the 'logger' field of each log line."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "run_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest")
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert entries[0]["source_id"] == "scripts.pipeline.purple_carrot_scraper"
|
||||
|
||||
|
||||
def test_pipeline_ingest_idempotent(client, tmp_path, monkeypatch):
|
||||
"""Calling the endpoint twice does not re-ingest already-processed files."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest")
|
||||
resp2 = client.post("/api/corpus/pipeline-ingest")
|
||||
|
||||
data = resp2.json()
|
||||
assert data["ingested_files"] == 0
|
||||
assert data["skipped_files"] == 1
|
||||
assert data["entries_stored"] == 0
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert len(entries) == 1 # still just the one from the first ingest
|
||||
|
||||
|
||||
def test_pipeline_ingest_skips_non_jsonl(client, tmp_path, monkeypatch):
|
||||
"""Non-.jsonl files in the dir are silently ignored."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
(ingest_dir / "notes.txt").write_text("this is not a log file")
|
||||
(ingest_dir / "run.csv").write_text("a,b,c\n1,2,3")
|
||||
_make_pipeline_file(ingest_dir, "valid_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.json()["ingested_files"] == 1
|
||||
|
||||
|
||||
def test_pipeline_ingest_skips_malformed_lines(client, tmp_path, monkeypatch):
|
||||
"""Lines that are not valid JSON are skipped; valid lines in the same file still land."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
p = ingest_dir / "mixed_20260517.jsonl"
|
||||
p.write_text(
|
||||
json.dumps(_PIPELINE_LINE) + "\n"
|
||||
"this is not json\n"
|
||||
+ json.dumps({**_PIPELINE_LINE, "msg": "another valid line"}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries_stored"] == 2 # 2 valid lines, 1 skipped
|
||||
|
||||
|
||||
def test_pipeline_ingest_new_file_after_first_run(client, tmp_path, monkeypatch):
|
||||
"""A new file added after the first ingest is picked up on the next call."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "run_a.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest") # ingest run_a.jsonl
|
||||
|
||||
_make_pipeline_file(ingest_dir, "run_b.jsonl", [
|
||||
{**_PIPELINE_LINE, "msg": "Second run line"},
|
||||
])
|
||||
|
||||
resp2 = client.post("/api/corpus/pipeline-ingest")
|
||||
data = resp2.json()
|
||||
assert data["ingested_files"] == 1
|
||||
assert data["skipped_files"] == 1
|
||||
assert data["entries_stored"] == 1
|
||||
|
|
@ -17,6 +17,7 @@ def reset_models_globals(tmp_path):
|
|||
from app import models as models_module
|
||||
|
||||
prev_models = models_module._MODELS_DIR
|
||||
prev_cf_text = models_module._CF_TEXT_MODELS_DIR
|
||||
prev_queue = models_module._QUEUE_DIR
|
||||
prev_progress = dict(models_module._download_progress)
|
||||
|
||||
|
|
@ -26,12 +27,14 @@ def reset_models_globals(tmp_path):
|
|||
queue_dir.mkdir()
|
||||
|
||||
models_module.set_models_dir(models_dir)
|
||||
models_module.set_cf_text_models_dir(tmp_path / "cf-text-models")
|
||||
models_module.set_queue_dir(queue_dir)
|
||||
models_module._download_progress = {}
|
||||
|
||||
yield
|
||||
|
||||
models_module.set_models_dir(prev_models)
|
||||
models_module.set_cf_text_models_dir(prev_cf_text)
|
||||
models_module.set_queue_dir(prev_queue)
|
||||
models_module._download_progress = prev_progress
|
||||
|
||||
|
|
@ -541,3 +544,84 @@ def test_delete_installed_name_with_slash_blocked(client):
|
|||
except _HTTPException as exc:
|
||||
assert exc.status_code in (400, 404)
|
||||
raise
|
||||
|
||||
|
||||
# ── Catalog registration ───────────────────────────────────────────────────────
|
||||
|
||||
_MINIMAL_YAML = """\
|
||||
services:
|
||||
cf-text:
|
||||
max_mb: {max_mb}
|
||||
catalog:
|
||||
existing-model:
|
||||
path: /some/path
|
||||
vram_mb: 1000
|
||||
description: "placeholder"
|
||||
"""
|
||||
|
||||
|
||||
def _make_node_yaml(tmp_path: Path, max_mb: int = 8192) -> Path:
|
||||
p = tmp_path / "testnode.yaml"
|
||||
p.write_text(_MINIMAL_YAML.format(max_mb=max_mb), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def test_catalog_registration_fp16_no_env_block(tmp_path):
|
||||
"""When model fits at FP16, no env block should be written."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/SmallModel",
|
||||
local_path=tmp_path / "org--SmallModel",
|
||||
vram_mb_fp16=4000,
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert "testnode" in updated
|
||||
content = node_yaml.read_text()
|
||||
# _catalog_key strips org prefix and lowercases: "org/SmallModel" → "smallmodel"
|
||||
assert "smallmodel:" in content
|
||||
assert "CF_TEXT_4BIT" not in content
|
||||
assert "env:" not in content
|
||||
|
||||
|
||||
def test_catalog_registration_needs_4bit_writes_env_block(tmp_path):
|
||||
"""When model only fits at 4-bit, env: CF_TEXT_4BIT: '1' must be written."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/BigModel",
|
||||
local_path=tmp_path / "org--BigModel",
|
||||
vram_mb_fp16=20000, # won't fit at FP16 on 8 GB
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert "testnode" in updated
|
||||
content = node_yaml.read_text()
|
||||
# _catalog_key: "org/BigModel" → "bigmodel"
|
||||
assert "bigmodel:" in content
|
||||
assert "env:" in content
|
||||
assert 'CF_TEXT_4BIT: "1"' in content
|
||||
assert "CF_TEXT_4BIT=1 required" in content # description note
|
||||
|
||||
|
||||
def test_catalog_registration_too_large_skipped(tmp_path):
|
||||
"""Model too large even at 4-bit should not be registered."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/HugeModel",
|
||||
local_path=tmp_path / "org--HugeModel",
|
||||
vram_mb_fp16=80000, # 4-bit ~22 GB, still won't fit on 8 GB
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert updated == []
|
||||
content = node_yaml.read_text()
|
||||
assert "hugemodel" not in content
|
||||
|
|
|
|||
575
tests/test_nodes.py
Normal file
575
tests/test_nodes.py
Normal file
|
|
@ -0,0 +1,575 @@
|
|||
"""Tests for app/nodes.py — /api/nodes-mgmt/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os as _os
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_nodes_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path so tests never read the real config."""
|
||||
from app import nodes as nodes_module
|
||||
prev = nodes_module._CONFIG_DIR
|
||||
nodes_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
nodes_module.set_config_dir(prev)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
|
||||
cfg = {"cforch": cforch_cfg}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_profile(profiles_dir: Path, node_id: str, profile: dict) -> None:
|
||||
profiles_dir.mkdir(parents=True, exist_ok=True)
|
||||
(profiles_dir / f"{node_id}.yaml").write_text(yaml.dump(profile), encoding="utf-8")
|
||||
|
||||
|
||||
def test_nodes_module_imports():
|
||||
from app import nodes
|
||||
assert hasattr(nodes, "router")
|
||||
assert hasattr(nodes, "set_config_dir")
|
||||
|
||||
|
||||
def test_list_nodes_returns_empty_when_no_coordinator(client):
|
||||
"""No cforch config — endpoint returns empty list, not 500."""
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
|
||||
|
||||
def _fake_nodes_response(nodes_json: list, services_json: list | None = None):
|
||||
"""Build side_effect list for two httpx.get calls: nodes then services."""
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.raise_for_status = MagicMock()
|
||||
mock_nodes.json.return_value = {"nodes": nodes_json}
|
||||
|
||||
mock_services = MagicMock()
|
||||
mock_services.raise_for_status = MagicMock()
|
||||
mock_services.json.return_value = {"services": services_json or []}
|
||||
|
||||
return [mock_nodes, mock_services]
|
||||
|
||||
|
||||
def test_list_nodes_coordinator_unreachable_returns_empty(client, tmp_path):
|
||||
"""Coordinator unreachable — returns [] with no 500."""
|
||||
import httpx
|
||||
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
|
||||
with patch("httpx.get", side_effect=httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_list_nodes_merges_profile_data(client, tmp_path):
|
||||
"""Profile YAML services_assigned merged with live GPU stats."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": ["cf-text"], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
coord_nodes = [{
|
||||
"node_id": "heimdall", "online": True, "agent_url": "http://10.1.10.71:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
|
||||
"vram_used_mb": 4096, "vram_free_mb": 20480,
|
||||
"temp_c": 42.0, "utilization_pct": 15.0, "compute_cap": 8.6}],
|
||||
}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data) == 1
|
||||
node = data[0]
|
||||
assert node["node_id"] == "heimdall"
|
||||
assert node["profile_loaded"] is True
|
||||
assert node["gpus"][0]["services_assigned"] == ["cf-text"]
|
||||
assert node["gpus"][0]["vram_total_mb"] == 24576
|
||||
assert "cf-text" in node["services_catalog"]
|
||||
|
||||
|
||||
def test_list_nodes_no_profile_returns_profile_loaded_false(client, tmp_path):
|
||||
"""Node with no profile YAML — profile_loaded: false, GPU stats still returned."""
|
||||
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
|
||||
|
||||
coord_nodes = [{
|
||||
"node_id": "sif", "online": True, "agent_url": "http://10.1.10.158:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 5060 Ti", "vram_total_mb": 16384,
|
||||
"vram_used_mb": 0, "vram_free_mb": 16384,
|
||||
"temp_c": None, "utilization_pct": None, "compute_cap": 10.0}],
|
||||
}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
node = data[0]
|
||||
assert node["profile_loaded"] is False
|
||||
assert node["gpus"][0]["card"] == "RTX 5060 Ti"
|
||||
assert node["services_catalog"] == {}
|
||||
|
||||
|
||||
def test_list_nodes_marks_running_services(client, tmp_path):
|
||||
"""services_running populated from coordinator /api/services response."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", {
|
||||
"services": {},
|
||||
"nodes": {"heimdall": {"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": ["cf-text"], "role": "p",
|
||||
"card": "RTX 3090", "always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701"}}
|
||||
})
|
||||
|
||||
coord_nodes = [{"node_id": "heimdall", "online": True,
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
|
||||
"vram_used_mb": 8192, "vram_free_mb": 16384,
|
||||
"temp_c": 55.0, "utilization_pct": 80.0, "compute_cap": 8.6}]}]
|
||||
coord_services = [{"service": "cf-text", "node_id": "heimdall", "gpu_id": 0}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes, coord_services)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
data = r.json()
|
||||
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
|
||||
|
||||
|
||||
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
|
||||
|
||||
def test_get_profile_returns_parsed_yaml(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
profile = {
|
||||
"services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
|
||||
"nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
|
||||
}
|
||||
_write_profile(profiles_dir, "heimdall", profile)
|
||||
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "services" in data
|
||||
assert "cf-text" in data["services"]
|
||||
|
||||
|
||||
def test_get_profile_404_when_missing(client, tmp_path):
|
||||
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
|
||||
r = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
profiles_dir.mkdir()
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
(profiles_dir / "bad.yaml").write_text("key: [unclosed", encoding="utf-8")
|
||||
|
||||
r = client.get("/api/nodes-mgmt/nodes/bad/profile")
|
||||
assert r.status_code == 500
|
||||
|
||||
|
||||
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
|
||||
|
||||
|
||||
_BASE_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1,
|
||||
"catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3",
|
||||
"description": "", "multi_gpu": False, "env": {}}}},
|
||||
"ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": [], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _setup_profile(tmp_path, profile=None):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", profile or _BASE_PROFILE)
|
||||
return profiles_dir
|
||||
|
||||
|
||||
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
|
||||
profiles_dir = _setup_profile(tmp_path)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 200
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is True
|
||||
|
||||
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
|
||||
assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"]
|
||||
|
||||
|
||||
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
|
||||
"""YAML must be written to .tmp then renamed — never written directly."""
|
||||
profiles_dir = _setup_profile(tmp_path)
|
||||
renamed_pairs: list[tuple] = []
|
||||
|
||||
original_replace = _os.replace
|
||||
|
||||
def capture(src, dst):
|
||||
renamed_pairs.append((str(src), str(dst)))
|
||||
original_replace(src, dst)
|
||||
|
||||
with patch("os.replace", side_effect=capture), \
|
||||
patch("httpx.post", return_value=MagicMock(status_code=200)):
|
||||
client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["ollama"]},
|
||||
)
|
||||
|
||||
assert any(src.endswith(".tmp") for src, dst in renamed_pairs), \
|
||||
"Expected atomic write via .tmp rename"
|
||||
|
||||
|
||||
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
|
||||
low_cap_profile = {
|
||||
**_BASE_PROFILE,
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 6.0,
|
||||
"services": [], "role": "p", "card": "GTX 1080",
|
||||
"always_on": False}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
_setup_profile(tmp_path, low_cap_profile)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "compute_cap" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
|
||||
tiny_vram_profile = {
|
||||
**_BASE_PROFILE,
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 512, "compute_cap": 8.6,
|
||||
"services": [], "role": "p", "card": "old",
|
||||
"always_on": False}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
_setup_profile(tmp_path, tiny_vram_profile)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "VRAM" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_unknown_service_returns_422(client, tmp_path):
|
||||
_setup_profile(tmp_path)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["not-a-real-service"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
|
||||
"""YAML saved but coordinator reload fails — ok: true, reloaded: false."""
|
||||
_setup_profile(tmp_path)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 500
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["ollama"]},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is False
|
||||
|
||||
# ── Ollama endpoints ───────────────────────────────────────────────────────────
|
||||
|
||||
_OLLAMA_PROFILE = {
|
||||
"services": {},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_list_ollama_models_proxies_tags(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_tags = MagicMock()
|
||||
mock_tags.raise_for_status = MagicMock()
|
||||
mock_tags.json.return_value = {
|
||||
"models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
|
||||
}
|
||||
|
||||
with patch("httpx.get", return_value=mock_tags):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["models"]) == 1
|
||||
assert data["models"][0]["name"] == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
|
||||
import httpx as _httpx
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
with patch("httpx.get", side_effect=_httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "error" in data
|
||||
|
||||
|
||||
def test_pull_ollama_model_streams_sse(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"status": "pulling manifest"}',
|
||||
'{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
|
||||
'{"status": "success"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
body = r.text
|
||||
assert 'data: {"status": "pulling manifest"}' in body
|
||||
assert 'data: {"status": "success"}' in body
|
||||
|
||||
|
||||
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "permission denied" in r.text
|
||||
|
||||
|
||||
def test_delete_ollama_model_proxies_delete(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 200
|
||||
mock_del.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
|
||||
|
||||
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 404
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model")
|
||||
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── Deploy model endpoint ──────────────────────────────────────────────────────
|
||||
|
||||
_DEPLOY_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {
|
||||
"max_mb": 20000,
|
||||
"min_compute_cap": 7.0,
|
||||
"model_base_path": "/devl/Assets/LLM/cf-text/models",
|
||||
"catalog": {},
|
||||
},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_deploy_model_adds_catalog_entry(client, tmp_path):
|
||||
"""Deploy endpoint should add the model to the service catalog."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 200
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
json={
|
||||
"model_id": "fdtn-ai--Foundation-Sec-8B-Q4",
|
||||
"service_type": "cf-text",
|
||||
"vram_mb": 5180,
|
||||
"hf_repo": "fdtn-ai/Foundation-Sec-8B-Q4_K_M-GGUF",
|
||||
},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is True
|
||||
assert "fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF" in data["path"]
|
||||
|
||||
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
|
||||
catalog = saved["services"]["cf-text"]["catalog"]
|
||||
assert "fdtn-ai--Foundation-Sec-8B-Q4" in catalog
|
||||
entry = catalog["fdtn-ai--Foundation-Sec-8B-Q4"]
|
||||
assert entry["vram_mb"] == 5180
|
||||
assert entry["path"].endswith("fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF")
|
||||
|
||||
|
||||
def test_deploy_model_explicit_path_overrides_base(client, tmp_path):
|
||||
"""An explicit path in the request body takes precedence over model_base_path."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
with patch("httpx.post", return_value=MagicMock(status_code=200)):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
json={
|
||||
"model_id": "my-model",
|
||||
"service_type": "cf-text",
|
||||
"vram_mb": 8000,
|
||||
"path": "/custom/path/to/model",
|
||||
},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json()["path"] == "/custom/path/to/model"
|
||||
|
||||
|
||||
def test_deploy_model_unknown_service_returns_422(client, tmp_path):
|
||||
"""Service type not in profile → 422."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
json={"model_id": "x", "service_type": "vllm", "vram_mb": 8000},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "vllm" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_deploy_model_missing_profile_returns_404(client, tmp_path):
|
||||
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/nonexistent/models/deploy",
|
||||
json={"model_id": "x", "service_type": "cf-text", "vram_mb": 100},
|
||||
)
|
||||
assert r.status_code == 404
|
||||
227
tests/test_recipe_scan.py
Normal file
227
tests/test_recipe_scan.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
"""Tests for app/data/recipe_scan.py — recipe scan labeling endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.data import recipe_scan as rs
|
||||
|
||||
|
||||
EXTRACTED = {"title": "Shepherd's Pie", "ingredients": ["lamb", "potato"], "steps": ["brown meat", "mash potato"]}
|
||||
GROUND_TRUTH = {"title": "Shepherd's Pie", "ingredients": ["ground lamb", "mashed potato", "peas"], "steps": ["brown meat", "add veg", "mash potato", "bake"]}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(rs, "_DB_PATH", tmp_path / "recipe_scan.db")
|
||||
rs._init_db()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(rs.router, prefix="/api/recipe-scan")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _item(**kwargs) -> dict:
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"image_path": "/Library/Assets/kiwi/scans/pc_test.jpg",
|
||||
"modality": kwargs.get("modality", "scanner"),
|
||||
"source": kwargs.get("source", "purple_carrot"),
|
||||
"extracted": kwargs.get("extracted", EXTRACTED),
|
||||
"ground_truth": kwargs.get("ground_truth", GROUND_TRUTH),
|
||||
}
|
||||
|
||||
|
||||
def _import(client, items: list[dict]) -> None:
|
||||
resp = client.post("/api/recipe-scan/import", json={"items": items})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── Import ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_import_stores_items(client):
|
||||
_import(client, [_item()])
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 1
|
||||
assert stats["by_status"]["pending"] == 1
|
||||
|
||||
|
||||
def test_import_rejects_unknown_modality(client):
|
||||
bad = _item()
|
||||
bad["modality"] = "telepathy"
|
||||
resp = client.post("/api/recipe-scan/import", json={"items": [bad]})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_import_is_idempotent(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
_import(client, [item]) # same id — should not duplicate
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 1
|
||||
|
||||
|
||||
def test_import_multiple_items(client):
|
||||
_import(client, [_item(), _item(), _item()])
|
||||
assert client.get("/api/recipe-scan/stats").json()["total"] == 3
|
||||
|
||||
|
||||
# ── Next ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_next_returns_404_when_queue_empty(client):
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_next_returns_pending_item(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == item["id"]
|
||||
assert data["status"] == "pending"
|
||||
assert "extracted" in data
|
||||
assert "ground_truth" in data
|
||||
|
||||
|
||||
def test_next_skips_non_pending(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Approve ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_approve_marks_item_approved(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/approve")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "approved"
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["by_status"]["approved"] == 1
|
||||
|
||||
|
||||
def test_approve_returns_404_for_unknown_id(client):
|
||||
resp = client.post("/api/recipe-scan/items/no-such-id/approve")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Edit ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_edit_stores_corrected_json(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
corrected = {**GROUND_TRUTH, "servings": 4}
|
||||
resp = client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/edit",
|
||||
json={"corrected": corrected},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "edited"
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["by_status"]["edited"] == 1
|
||||
|
||||
|
||||
def test_edit_requires_corrected_field(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/edit", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── Reject ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_reject_marks_item_rejected(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/reject",
|
||||
json={"reason": "OCR completely unreadable"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "rejected"
|
||||
|
||||
|
||||
def test_reject_without_reason_is_valid(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_empty_when_nothing_approved(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_includes_approved_item(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/approve")
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
assert len(lines) == 1
|
||||
pair = json.loads(lines[0])
|
||||
assert pair["id"] == item["id"]
|
||||
assert pair["modality"] == "scanner"
|
||||
assert "messages" in pair
|
||||
assert len(pair["messages"]) == 2
|
||||
assert pair["messages"][0]["role"] == "user"
|
||||
assert pair["messages"][1]["role"] == "assistant"
|
||||
|
||||
|
||||
def test_export_includes_edited_item_with_correction(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
corrected = {**GROUND_TRUTH, "servings": 4}
|
||||
client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/edit",
|
||||
json={"corrected": corrected},
|
||||
)
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
pair = json.loads(lines[0])
|
||||
assistant_content = json.loads(pair["messages"][1]["content"])
|
||||
assert assistant_content["servings"] == 4
|
||||
|
||||
|
||||
def test_export_excludes_rejected_items(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
# ── Stats ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_stats_counts_all_statuses(client):
|
||||
items = [_item(), _item(), _item(), _item()]
|
||||
_import(client, items)
|
||||
client.post(f"/api/recipe-scan/items/{items[0]['id']}/approve")
|
||||
client.post(f"/api/recipe-scan/items/{items[1]['id']}/edit", json={"corrected": GROUND_TRUTH})
|
||||
client.post(f"/api/recipe-scan/items/{items[2]['id']}/reject")
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 4
|
||||
assert stats["by_status"]["pending"] == 1
|
||||
assert stats["by_status"]["approved"] == 1
|
||||
assert stats["by_status"]["edited"] == 1
|
||||
assert stats["by_status"]["rejected"] == 1
|
||||
assert stats["export_ready"] == 2 # approved + edited
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""API integration tests for app/sft.py — /api/sft/* endpoints."""
|
||||
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -7,17 +7,17 @@ from pathlib import Path
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_sft_globals(tmp_path):
|
||||
from app import sft as sft_module
|
||||
_prev_data = sft_module._SFT_DATA_DIR
|
||||
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
||||
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
sft_module.set_sft_data_dir(tmp_path)
|
||||
sft_module.set_sft_config_dir(tmp_path)
|
||||
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
from app.data import corrections as corr_module
|
||||
_prev_data = corr_module._DATA_DIR
|
||||
_prev_cfg = corr_module._CONFIG_DIR
|
||||
_prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
yield
|
||||
sft_module.set_sft_data_dir(_prev_data)
|
||||
sft_module.set_sft_config_dir(_prev_cfg)
|
||||
sft_module.set_default_bench_results_dir(_prev_default)
|
||||
corr_module.set_data_dir(_prev_data)
|
||||
corr_module.set_config_dir(_prev_cfg)
|
||||
corr_module.set_default_bench_results_dir(_prev_default)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
|
|||
)
|
||||
|
||||
|
||||
# ── /api/sft/runs ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/runs -------------------------------------------------------------
|
||||
|
||||
def test_runs_returns_empty_when_no_config(client):
|
||||
r = client.get("/api/sft/runs")
|
||||
|
|
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
|
|||
def test_runs_marks_already_imported(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
candidates = sft_module._candidates_file()
|
||||
candidates.parent.mkdir(parents=True, exist_ok=True)
|
||||
candidates.write_text(
|
||||
|
|
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
|
|||
assert r.json()[0]["already_imported"] is True
|
||||
|
||||
|
||||
# ── /api/sft/import ─────────────────────────────────────────────────────────
|
||||
# -- /api/sft/import -----------------------------------------------------------
|
||||
|
||||
def test_import_adds_records(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||
|
|
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
|
|||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/sft/queue ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/queue ------------------------------------------------------------
|
||||
|
||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
path = sft_module._candidates_file()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
|
|
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
|
|||
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
||||
|
||||
|
||||
# ── /api/sft/submit ─────────────────────────────────────────────────────────
|
||||
# -- /api/sft/submit -----------------------------------------------------------
|
||||
|
||||
def test_submit_correct_sets_approved(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
|
|
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["status"] == "approved"
|
||||
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
||||
|
|
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
|
|||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
|
|
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
||||
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
|
|||
"failure_category": "style_violation",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] == "style_violation"
|
||||
|
||||
|
|
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] is None
|
||||
|
||||
|
|
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
|||
assert r.status_code == 422
|
||||
|
||||
|
||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
||||
# -- /api/sft/undo -------------------------------------------------------------
|
||||
|
||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
||||
|
||||
|
||||
|
|
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
client.post("/api/sft/undo", json={"id": "a"})
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert not any(r["id"] == "a" for r in approved)
|
||||
|
|
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
|
|||
assert r.status_code == 409
|
||||
|
||||
|
||||
# ── /api/sft/export ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/export -----------------------------------------------------------
|
||||
|
||||
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
approved = {
|
||||
**_make_record("a"),
|
||||
|
|
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
|||
|
||||
|
||||
def test_export_excludes_non_approved(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
||||
|
|
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
|
|||
assert r.text.strip() == ""
|
||||
|
||||
|
||||
# ── /api/sft/stats ───────────────────────────────────────────────────────────
|
||||
# -- /api/sft/stats ------------------------------------------------------------
|
||||
|
||||
def test_stats_counts_by_status(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
_make_record("a"),
|
||||
|
|
|
|||
187
tests/test_train.py
Normal file
187
tests/test_train.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for app/train/train.py -- /api/train/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.train import train as train_module
|
||||
train_module.set_db_path(tmp_path / "train_jobs.db")
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
train_module._running_procs.clear()
|
||||
yield
|
||||
train_module._running_procs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_list_jobs_empty(client):
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"jobs": []}
|
||||
|
||||
|
||||
def test_create_job_returns_queued_record(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "classifier", "model_key": "deberta-small",
|
||||
"config_json": {"epochs": 3}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["status"] == "queued"
|
||||
assert data["type"] == "classifier"
|
||||
assert data["model_key"] == "deberta-small"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
def test_create_job_invalid_type_returns_400(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "unknown-type", "model_key": "deberta-small"})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_job_appears_in_list(client):
|
||||
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()["jobs"]) == 1
|
||||
|
||||
|
||||
def test_get_job_returns_record(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["id"] == job_id
|
||||
|
||||
|
||||
def test_get_job_404_for_unknown(client):
|
||||
r = client.get("/api/train/jobs/no-such-id")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_queued_job(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["status"] == "cancelled"
|
||||
r3 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r3.json()["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_cancel_completed_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.delete("/api/train/jobs/abc/cancel")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_cancel_terminates_running_proc(client):
|
||||
from app.train import train as train_module
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
train_module._running_procs[job_id] = mock_proc
|
||||
with train_module._db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
mock_proc.terminate.assert_called_once()
|
||||
|
||||
|
||||
def test_run_job_streams_sse(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}/run")
|
||||
assert r2.status_code == 200
|
||||
assert "text/event-stream" in r2.headers.get("content-type", "")
|
||||
events = _parse_sse(r2.content)
|
||||
assert any(e["type"] == "complete" for e in events)
|
||||
|
||||
|
||||
def test_run_job_marks_completed_in_db(client):
|
||||
from app.train import train as train_module
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "completed"
|
||||
|
||||
|
||||
def test_run_job_marks_failed_on_nonzero_exit(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_nonqueued_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.get("/api/train/jobs/xyz/run")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_run_unknown_job_returns_404(client):
|
||||
r = client.get("/api/train/jobs/no-such/run")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_results_empty_when_no_models_dir(client):
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"results": []}
|
||||
|
||||
|
||||
def test_results_returns_training_info(client, tmp_path):
|
||||
from app.train import train as train_module
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
|
||||
(models_dir / "training_info.json").write_text(json.dumps(info))
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])
|
||||
42
web/package-lock.json
generated
42
web/package-lock.json
generated
|
|
@ -2676,9 +2676,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz",
|
||||
"integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -2890,9 +2890,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/defu": {
|
||||
"version": "6.1.4",
|
||||
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.4.tgz",
|
||||
"integrity": "sha512-mEQCMmwJu317oSz8CwdIOdwf3xMif1ttiM8LTufzc3g6kR+9Pe236twL8j3IYT1F7GfRgGcW6MWxzZjLIkuHIg==",
|
||||
"version": "6.1.7",
|
||||
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.7.tgz",
|
||||
"integrity": "sha512-7z22QmUWiQ/2d0KkdYmANbRUVABpZ9SNYyH5vx6PZ+nE5bcC0l7uFvEfHlyld/HcGBFTL536ClDt3DEcSlEJAQ==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
|
|
@ -3725,9 +3725,9 @@
|
|||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
|
|
@ -3769,9 +3769,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/postcss": {
|
||||
"version": "8.5.8",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
|
||||
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
|
||||
"version": "8.5.14",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
|
||||
"integrity": "sha512-SoSL4+OSEtR99LHFZQiJLkT59C5B1amGO1NzTwj7TT1qCUgUO6hxOvzkOYxD+vMrXBM3XJIKzokoERdqQq/Zmg==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "opencollective",
|
||||
|
|
@ -4325,9 +4325,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/undici": {
|
||||
"version": "7.22.0",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.22.0.tgz",
|
||||
"integrity": "sha512-RqslV2Us5BrllB+JeiZnK4peryVTndy9Dnqq62S3yYRRTj0tFQCwEniUy2167skdGOy3vqRzEvl1Dm4sV2ReDg==",
|
||||
"version": "7.25.0",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.25.0.tgz",
|
||||
"integrity": "sha512-xXnp4kTyor2Zq+J1FfPI6Eq3ew5h6Vl0F/8d9XU5zZQf1tX9s2Su1/3PiMmUANFULpmksxkClamIZcaUqryHsQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
|
@ -4422,9 +4422,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/vite": {
|
||||
"version": "7.3.1",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz",
|
||||
"integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==",
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz",
|
||||
"integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -4921,9 +4921,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/yaml": {
|
||||
"version": "2.8.2",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz",
|
||||
"integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==",
|
||||
"version": "2.8.4",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.4.tgz",
|
||||
"integrity": "sha512-ml/JPOj9fOQK8RNnWojA67GbZ0ApXAUlN2UQclwv2eVgTgn7O9gg9o7paZWKMp4g0H3nTLtS9LVzhkpOFIKzog==",
|
||||
"license": "ISC",
|
||||
"bin": {
|
||||
"yaml": "bin.mjs"
|
||||
|
|
|
|||
124
web/src/components/AppSidebar.test.ts
Normal file
124
web/src/components/AppSidebar.test.ts
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import AppSidebar from './AppSidebar.vue'
|
||||
|
||||
// Minimal router so RouterLink renders without warnings
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: { template: '<div />' } },
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
{ path: '/data/label', component: { template: '<div />' } },
|
||||
{ path: '/data/fetch', component: { template: '<div />' } },
|
||||
{ path: '/data/corrections', component: { template: '<div />' } },
|
||||
{ path: '/data/imitate', component: { template: '<div />' } },
|
||||
{ path: '/eval/benchmark', component: { template: '<div />' } },
|
||||
{ path: '/eval/compare', component: { template: '<div />' } },
|
||||
{ path: '/train/jobs', component: { template: '<div />' } },
|
||||
{ path: '/train/results', component: { template: '<div />' } },
|
||||
{ path: '/settings', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
function makeFetch(signals: Record<string, boolean> = {}) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
labeled_since_last_eval: 0,
|
||||
last_eval_timestamp: null,
|
||||
last_eval_best_score: null,
|
||||
active_jobs: [],
|
||||
corrections_export_ready: 0,
|
||||
signals,
|
||||
}),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
localStorage.clear()
|
||||
vi.stubGlobal('fetch', makeFetch())
|
||||
})
|
||||
|
||||
describe('AppSidebar structure', () => {
|
||||
it('renders section headers for Data, Eval, Train', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const text = w.text()
|
||||
expect(text).toContain('Data')
|
||||
expect(text).toContain('Eval')
|
||||
expect(text).toContain('Train')
|
||||
})
|
||||
|
||||
it('renders all sub-links', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const anchors = w.findAll('a')
|
||||
const hrefs = anchors.map(a => a.attributes('href') ?? '')
|
||||
expect(hrefs.some(h => h.includes('/data/label'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/fetch'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/corrections'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/imitate'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/eval/benchmark'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/eval/compare'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/train/jobs'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/train/results'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/fleet'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/settings'))).toBe(true)
|
||||
})
|
||||
|
||||
it('does NOT render the old /benchmark or /models links', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const anchors = w.findAll('a')
|
||||
const hrefs = anchors.map(a => a.attributes('href') ?? '')
|
||||
// Old paths must not appear as direct links (they're only redirects)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/benchmark'))).toBe(true)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/models'))).toBe(true)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/stats'))).toBe(true)
|
||||
})
|
||||
|
||||
it('shows no signal badges when all signals are false', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.findAll('.signal-badge').length).toBe(0)
|
||||
})
|
||||
|
||||
it('shows signal badge on Data section when data_to_eval is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: true, eval_to_train: false, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const badges = w.findAll('.signal-badge')
|
||||
expect(badges.length).toBe(1)
|
||||
// It should be inside the Data section header
|
||||
const dataHeader = w.find('[data-section="data"]')
|
||||
expect(dataHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows signal badge on Eval section when eval_to_train is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: true, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalHeader = w.find('[data-section="eval"]')
|
||||
expect(evalHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows signal badge on Train section when train_to_fleet is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: true }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainHeader = w.find('[data-section="train"]')
|
||||
expect(trainHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('stow toggle still works', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const nav = w.find('nav')
|
||||
expect(nav.classes()).not.toContain('stowed')
|
||||
await w.find('.stow-btn').trigger('click')
|
||||
expect(nav.classes()).toContain('stowed')
|
||||
})
|
||||
})
|
||||
|
|
@ -28,12 +28,70 @@
|
|||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Nav items -->
|
||||
<!-- Nav -->
|
||||
<ul class="nav-list" role="list">
|
||||
<li v-for="item in navItems" :key="item.path">
|
||||
<!-- Top-level links -->
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Dashboard' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">📊</span>
|
||||
<span v-if="!stowed" class="nav-label">Dashboard</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/fleet"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Fleet' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚡</span>
|
||||
<span v-if="!stowed" class="nav-label">Fleet</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/nodes"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Nodes' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">🖥️</span>
|
||||
<span v-if="!stowed" class="nav-label">Nodes</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ① Data section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="data" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">① Data</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge"
|
||||
title="Enough new labels to run eval"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">①</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Eval recommended"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in dataItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
|
|
@ -41,10 +99,94 @@
|
|||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ② Eval section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="eval" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">② Eval</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge"
|
||||
title="Strong eval result — consider finetuning"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">②</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Finetune recommended"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in evalItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ③ Train section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="train" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">③ Train</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge"
|
||||
title="Trained model ready for fleet registration"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">③</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Fleet registration recommended"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in trainItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- Divider + Settings -->
|
||||
<li class="nav-divider" aria-hidden="true" />
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/settings"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Settings' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚙️</span>
|
||||
<span v-if="!stowed" class="nav-label">Settings</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
</ul>
|
||||
</nav>
|
||||
|
||||
<!-- Mobile hamburger button rendered outside the sidebar so it's visible when stowed -->
|
||||
<!-- Mobile hamburger button — visible when sidebar is stowed on mobile -->
|
||||
<button
|
||||
v-if="isMobile && stowed"
|
||||
class="mobile-hamburger"
|
||||
|
|
@ -61,25 +203,68 @@ import { RouterLink } from 'vue-router'
|
|||
|
||||
const LS_KEY = 'cf-avocet-nav-stowed'
|
||||
|
||||
const navItems = [
|
||||
{ path: '/', icon: '🃏', label: 'Label' },
|
||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
||||
{ path: '/models', icon: '🤗', label: 'Models' },
|
||||
{ path: '/imitate', icon: '🪞', label: 'Imitate' },
|
||||
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
|
||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||
interface NavItem {
|
||||
path: string
|
||||
icon: string
|
||||
label: string
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
const dataItems: NavItem[] = [
|
||||
{ path: '/data/label', icon: '🏷', label: 'Label' },
|
||||
{ path: '/data/fetch', icon: '📬', label: 'Fetch' },
|
||||
{ path: '/data/corrections', icon: '✏️', label: 'Corrections' },
|
||||
{ path: '/data/imitate', icon: '🪞', label: 'Imitate' },
|
||||
{ path: '/data/recipe-scan', icon: '📷', label: 'Recipe Scan' },
|
||||
]
|
||||
|
||||
const evalItems: NavItem[] = [
|
||||
{ path: '/eval/benchmark', icon: '📊', label: 'Benchmark' },
|
||||
{ path: '/eval/compare', icon: '🔍', label: 'Compare' },
|
||||
{ path: '/eval/embed-compare', icon: '🧮', label: 'Embed Compare' },
|
||||
]
|
||||
|
||||
const trainItems: NavItem[] = [
|
||||
{ path: '/train/jobs', icon: '🧠', label: 'Jobs' },
|
||||
{ path: '/train/results', icon: '📈', label: 'Results' },
|
||||
]
|
||||
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
|
||||
const signals = ref<DashboardSignals>({
|
||||
data_to_eval: false,
|
||||
eval_to_train: false,
|
||||
train_to_fleet: false,
|
||||
})
|
||||
|
||||
async function loadSignals() {
|
||||
try {
|
||||
const res = await fetch('/api/dashboard')
|
||||
if (res.ok) {
|
||||
const data = await res.json() as { signals?: DashboardSignals }
|
||||
if (data.signals) {
|
||||
signals.value = {
|
||||
data_to_eval: data.signals.data_to_eval ?? false,
|
||||
eval_to_train: data.signals.eval_to_train ?? false,
|
||||
train_to_fleet: data.signals.train_to_fleet ?? false,
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Non-fatal: badges simply stay hidden if API is unreachable
|
||||
}
|
||||
}
|
||||
|
||||
function toggle() {
|
||||
stowed.value = !stowed.value
|
||||
localStorage.setItem(LS_KEY, String(stowed.value))
|
||||
// Update CSS variable on :root so .app-main margin-left syncs
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
}
|
||||
|
||||
|
|
@ -93,13 +278,12 @@ function onResize() { winWidth.value = window.innerWidth }
|
|||
|
||||
onMounted(() => {
|
||||
window.addEventListener('resize', onResize)
|
||||
// Apply persisted sidebar width to :root on mount
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
// On mobile, default to stowed
|
||||
if (isMobile.value && !localStorage.getItem(LS_KEY)) {
|
||||
stowed.value = true
|
||||
document.documentElement.style.setProperty('--sidebar-width', '56px')
|
||||
}
|
||||
loadSignals()
|
||||
})
|
||||
|
||||
onUnmounted(() => window.removeEventListener('resize', onResize))
|
||||
|
|
@ -121,18 +305,15 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar.stowed {
|
||||
width: 56px;
|
||||
}
|
||||
.sidebar.stowed { width: 56px; }
|
||||
|
||||
/* Mobile: slide in/out from left */
|
||||
.sidebar.mobile {
|
||||
box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.sidebar.mobile.stowed {
|
||||
transform: translateX(-100%);
|
||||
width: 200px; /* keep width so slide-in looks right */
|
||||
width: 200px;
|
||||
transition: transform 250ms ease, width 250ms ease;
|
||||
}
|
||||
|
||||
|
|
@ -165,10 +346,7 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.logo-icon {
|
||||
font-size: 1.25rem;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.logo-icon { font-size: 1.25rem; flex-shrink: 0; }
|
||||
|
||||
.logo-name {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
|
|
@ -193,16 +371,76 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.stow-btn:hover {
|
||||
background: var(--color-border, #d0d7e8);
|
||||
}
|
||||
.stow-btn:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
.nav-list {
|
||||
list-style: none;
|
||||
padding: 0.5rem 0;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
/* ── Section headers ── */
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
padding: 0.55rem 0.75rem 0.25rem;
|
||||
margin-top: 0.5rem;
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.07em;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.section-icon {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
width: 24px;
|
||||
text-align: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* ── Signal badges ── */
|
||||
.signal-badge {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: var(--color-warning, #d4891a);
|
||||
flex-shrink: 0;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.signal-badge-stowed {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
}
|
||||
|
||||
/* Make the stowed section header container position:relative for the badge */
|
||||
.sidebar.stowed .section-header {
|
||||
position: relative;
|
||||
justify-content: center;
|
||||
padding: 0.55rem 0 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Nav divider ── */
|
||||
.nav-divider {
|
||||
height: 1px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
margin: 0.5rem 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Nav items ── */
|
||||
.nav-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
|
@ -238,6 +476,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
border-radius: 0 2px 2px 0;
|
||||
}
|
||||
|
||||
/* Sub-items are indented slightly in expanded state */
|
||||
.nav-subitem { padding-left: 1.1rem; font-size: 0.875rem; }
|
||||
|
||||
.nav-icon {
|
||||
font-size: 1.1rem;
|
||||
flex-shrink: 0;
|
||||
|
|
@ -245,12 +486,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
text-align: center;
|
||||
}
|
||||
|
||||
.nav-label {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.nav-label { overflow: hidden; text-overflow: ellipsis; }
|
||||
|
||||
/* Mobile hamburger — visible when sidebar is stowed on mobile */
|
||||
/* Mobile hamburger */
|
||||
.mobile-hamburger {
|
||||
position: fixed;
|
||||
top: 0.75rem;
|
||||
|
|
|
|||
170
web/src/components/nodes/CatalogEntryFormModal.vue
Normal file
170
web/src/components/nodes/CatalogEntryFormModal.vue
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import type { CatalogEntryFull } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
svcName: string
|
||||
modelName?: string
|
||||
entry?: CatalogEntryFull
|
||||
}>()
|
||||
const emit = defineEmits<{
|
||||
save: [svcName: string, modelName: string, entry: CatalogEntryFull]
|
||||
cancel: []
|
||||
}>()
|
||||
|
||||
const name = ref(props.modelName ?? '')
|
||||
const path = ref(props.entry?.path ?? '')
|
||||
const vramMb = ref(props.entry?.vram_mb ?? 0)
|
||||
const description = ref(props.entry?.description ?? '')
|
||||
const multiGpu = ref(props.entry?.multi_gpu ?? false)
|
||||
const envPairs = ref<{ k: string; v: string }[]>(
|
||||
Object.entries(props.entry?.env ?? {}).map(([k, v]) => ({ k, v }))
|
||||
)
|
||||
const formError = ref('')
|
||||
|
||||
watch(() => props.entry, (e) => {
|
||||
name.value = props.modelName ?? ''
|
||||
path.value = e?.path ?? ''
|
||||
vramMb.value = e?.vram_mb ?? 0
|
||||
description.value = e?.description ?? ''
|
||||
multiGpu.value = e?.multi_gpu ?? false
|
||||
envPairs.value = Object.entries(e?.env ?? {}).map(([k, v]) => ({ k, v }))
|
||||
})
|
||||
|
||||
function addEnvPair() {
|
||||
envPairs.value = [...envPairs.value, { k: '', v: '' }]
|
||||
}
|
||||
function removeEnvPair(i: number) {
|
||||
envPairs.value = envPairs.value.filter((_, idx) => idx !== i)
|
||||
}
|
||||
|
||||
function submit() {
|
||||
formError.value = ''
|
||||
if (!name.value.trim()) { formError.value = 'Model name is required.'; return }
|
||||
if (!path.value.trim()) { formError.value = 'Path is required.'; return }
|
||||
if (!vramMb.value || vramMb.value < 0) { formError.value = 'vram_mb must be a positive number.'; return }
|
||||
|
||||
const envObj: Record<string, string> = {}
|
||||
for (const { k, v } of envPairs.value) {
|
||||
if (k.trim()) envObj[k.trim()] = v
|
||||
}
|
||||
|
||||
const entry: CatalogEntryFull = { path: path.value.trim(), vram_mb: vramMb.value }
|
||||
if (description.value.trim()) entry.description = description.value.trim()
|
||||
if (multiGpu.value) entry.multi_gpu = true
|
||||
if (Object.keys(envObj).length) entry.env = envObj
|
||||
|
||||
emit('save', props.svcName, name.value.trim(), entry)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${modelName ? 'Edit' : 'Add'} catalog entry`">
|
||||
<div class="modal-box">
|
||||
<h3 class="modal-title">{{ modelName ? 'Edit' : 'Add' }} Catalog Entry — {{ svcName }}</h3>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-name">Model name</label>
|
||||
<input id="ce-name" v-model="name" class="field-input" :readonly="!!modelName" placeholder="deepseek-r1-7b" />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-path">Path</label>
|
||||
<input id="ce-path" v-model="path" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models/..." />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-vram">VRAM (MB)</label>
|
||||
<input id="ce-vram" v-model.number="vramMb" type="number" min="0" class="field-input field-input--sm" />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-desc">Description</label>
|
||||
<input id="ce-desc" v-model="description" class="field-input" placeholder="Short description" />
|
||||
</div>
|
||||
<div class="field-row field-row--check">
|
||||
<input id="ce-mgpu" v-model="multiGpu" type="checkbox" />
|
||||
<label for="ce-mgpu">Multi-GPU span</label>
|
||||
</div>
|
||||
|
||||
<div class="env-section">
|
||||
<div class="env-header">
|
||||
<span class="field-label">Env vars</span>
|
||||
<button type="button" class="btn-link" @click="addEnvPair">+ Add</button>
|
||||
</div>
|
||||
<div v-for="(pair, i) in envPairs" :key="i" class="env-row">
|
||||
<input v-model="pair.k" class="field-input field-input--sm" placeholder="CF_TEXT_4BIT" />
|
||||
<span>=</span>
|
||||
<input v-model="pair.v" class="field-input field-input--sm" placeholder="1" />
|
||||
<button type="button" class="btn-icon" @click="removeEnvPair(i)" aria-label="Remove">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
|
||||
<button class="btn-primary" @click="submit">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.modal-backdrop {
|
||||
position: fixed; inset: 0;
|
||||
background: rgba(0,0,0,0.5);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
z-index: 200;
|
||||
}
|
||||
.modal-box {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
width: 100%; max-width: 500px;
|
||||
max-height: 90vh; overflow-y: auto;
|
||||
display: flex; flex-direction: column; gap: 0.75rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.field-row { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.field-row--check { gap: 0.4rem; color: var(--color-text); }
|
||||
.field-label { min-width: 8rem; font-size: 0.85rem; color: var(--color-text-muted); }
|
||||
.field-input {
|
||||
flex: 1;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.field-input--sm { flex: 0 0 8rem; }
|
||||
.env-section { display: flex; flex-direction: column; gap: 0.35rem; }
|
||||
.env-header { display: flex; align-items: center; justify-content: space-between; }
|
||||
.env-row { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.btn-link { background: none; border: none; color: var(--app-primary); cursor: pointer; font-size: 0.8rem; padding: 0; }
|
||||
.btn-link:hover { color: var(--app-primary-hover); }
|
||||
.btn-icon { background: none; border: none; color: var(--color-text-muted); cursor: pointer; padding: 0 0.2rem; font-size: 0.85rem; }
|
||||
.btn-icon:hover { color: var(--color-error); }
|
||||
.form-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
|
||||
.btn-primary {
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 1rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--app-primary-hover); }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.75rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
</style>
|
||||
129
web/src/components/nodes/GpuRow.vue
Normal file
129
web/src/components/nodes/GpuRow.vue
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import ServiceBadge from './ServiceBadge.vue'
|
||||
import type { GpuEntry, ServiceInfo } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
gpu: GpuEntry
|
||||
nodeId: string
|
||||
profileLoaded: boolean
|
||||
servicesCatalog: Record<string, ServiceInfo>
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const saving = ref(false)
|
||||
const saveError = ref('')
|
||||
|
||||
const vramPct = computed(() => {
|
||||
if (!props.gpu.vram_total_mb) return 0
|
||||
return Math.round((props.gpu.vram_used_mb / props.gpu.vram_total_mb) * 100)
|
||||
})
|
||||
|
||||
function serviceState(svcName: string): 'running' | 'stopped' | 'assigned-only' | 'available' | 'incompatible' | 'unknown' {
|
||||
const svc = props.servicesCatalog[svcName]
|
||||
if (!svc) return 'unknown'
|
||||
const cap = props.gpu.compute_cap ?? 0
|
||||
if (cap < svc.min_compute_cap) return 'incompatible'
|
||||
if (props.gpu.services_running.includes(svcName)) return 'running'
|
||||
if (props.gpu.services_assigned.includes(svcName)) return 'assigned-only'
|
||||
return 'available'
|
||||
}
|
||||
|
||||
async function toggleService(svcName: string) {
|
||||
if (!props.profileLoaded || saving.value) return
|
||||
const current = [...props.gpu.services_assigned]
|
||||
const removing = current.includes(svcName)
|
||||
if (removing && !confirm(`Remove ${svcName} from GPU ${props.gpu.gpu_id}?`)) return
|
||||
const next = removing ? current.filter(s => s !== svcName) : [...current, svcName]
|
||||
|
||||
saving.value = true
|
||||
saveError.value = ''
|
||||
try {
|
||||
const r = await fetch(
|
||||
`/api/nodes-mgmt/nodes/${props.nodeId}/gpu/${props.gpu.gpu_id}/services`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ services: next }),
|
||||
},
|
||||
)
|
||||
if (!r.ok) {
|
||||
const data = await r.json().catch(() => ({}))
|
||||
throw new Error((data as { detail?: string }).detail ?? `HTTP ${r.status}`)
|
||||
}
|
||||
const data = await r.json() as { ok: boolean; reloaded: boolean; warnings: string[] }
|
||||
if (data.warnings?.length) saveError.value = `Saved (warning: ${data.warnings.join(', ')})`
|
||||
emit('updated')
|
||||
} catch (e) {
|
||||
saveError.value = e instanceof Error ? e.message : 'Failed to update services'
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="gpu-row">
|
||||
<div class="gpu-info">
|
||||
<span class="gpu-label">GPU {{ gpu.gpu_id }}: {{ gpu.card }}</span>
|
||||
<span v-if="gpu.compute_cap != null" class="gpu-meta">sm{{ gpu.compute_cap }}</span>
|
||||
<span v-if="gpu.temp_c != null" class="gpu-meta">{{ gpu.temp_c }}°C</span>
|
||||
<span v-if="gpu.utilization_pct != null" class="gpu-meta">{{ gpu.utilization_pct }}%</span>
|
||||
</div>
|
||||
|
||||
<div class="vram-wrap">
|
||||
<div
|
||||
class="vram-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="gpu.vram_used_mb"
|
||||
aria-valuemin="0"
|
||||
:aria-valuemax="gpu.vram_total_mb"
|
||||
:aria-label="`VRAM: ${gpu.vram_used_mb} of ${gpu.vram_total_mb} MB used`"
|
||||
>
|
||||
<div class="vram-fill" :style="{ width: `${vramPct}%` }" />
|
||||
</div>
|
||||
<span class="vram-text">{{ gpu.vram_used_mb }} / {{ gpu.vram_total_mb }} MB ({{ vramPct }}%)</span>
|
||||
</div>
|
||||
|
||||
<div v-if="profileLoaded" class="services-row" aria-label="Service assignments">
|
||||
<ServiceBadge
|
||||
v-for="(_, svcName) in servicesCatalog"
|
||||
:key="String(svcName)"
|
||||
:service-name="String(svcName)"
|
||||
:state="serviceState(String(svcName))"
|
||||
:assigned="gpu.services_assigned.includes(String(svcName))"
|
||||
:disabled="saving"
|
||||
@toggle="toggleService(String(svcName))"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="saveError" class="save-msg" role="alert">{{ saveError }}</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.gpu-row {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-radius: 4px;
|
||||
background: var(--color-surface-alt);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
.gpu-info { display: flex; gap: 0.75rem; align-items: center; flex-wrap: wrap; font-size: 0.875rem; }
|
||||
.gpu-label { font-weight: 500; color: var(--color-text); }
|
||||
.gpu-meta { color: var(--color-text-muted); font-size: 0.8rem; }
|
||||
.vram-wrap { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.vram-bar {
|
||||
flex: 1;
|
||||
height: 8px;
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.vram-fill { height: 100%; background: var(--app-primary); transition: width 0.3s; }
|
||||
.vram-text { font-size: 0.75rem; color: var(--color-text-muted); white-space: nowrap; }
|
||||
.services-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
|
||||
.save-msg { color: var(--color-warning); font-size: 0.8rem; }
|
||||
</style>
|
||||
134
web/src/components/nodes/HfNodeModelPanel.vue
Normal file
134
web/src/components/nodes/HfNodeModelPanel.vue
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
interface CatalogEntry {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description: string
|
||||
multi_gpu: boolean
|
||||
}
|
||||
|
||||
interface ServiceProfile {
|
||||
catalog: Record<string, CatalogEntry>
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
}
|
||||
|
||||
interface NodeProfile {
|
||||
services: Record<string, ServiceProfile>
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
}>()
|
||||
|
||||
const profile = ref<NodeProfile | null>(null)
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
async function fetchProfile() {
|
||||
fetchAbort?.abort()
|
||||
fetchAbort = new AbortController()
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
|
||||
signal: fetchAbort.signal,
|
||||
})
|
||||
if (r.status === 404) { profile.value = null; return }
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
profile.value = await r.json() as NodeProfile
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
error.value = e instanceof Error ? e.message : 'Failed to load profile'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(fetchProfile)
|
||||
onUnmounted(() => { fetchAbort?.abort() })
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="hf-panel">
|
||||
<h3 class="panel-title">Model Catalog</h3>
|
||||
<p class="hf-hint">
|
||||
To download a new HuggingFace model,
|
||||
<a href="#/fleet" class="hf-link">go to Fleet</a>.
|
||||
Models downloaded there are automatically registered in node catalogs.
|
||||
</p>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading catalog...</span>
|
||||
</div>
|
||||
<div v-if="error" class="panel-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && !profile" class="panel-empty">No profile loaded for this node.</div>
|
||||
<div v-else-if="!loading && profile" class="catalog-body">
|
||||
<div
|
||||
v-for="(svcInfo, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-section"
|
||||
>
|
||||
<h4 class="svc-name">{{ svcName }}</h4>
|
||||
<ul class="catalog-list" role="list">
|
||||
<li
|
||||
v-if="!Object.keys(svcInfo.catalog ?? {}).length"
|
||||
class="catalog-empty"
|
||||
>
|
||||
No models in catalog.
|
||||
</li>
|
||||
<li
|
||||
v-for="(entry, modelName) in (svcInfo.catalog ?? {})"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.hf-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 6px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.panel-title { margin: 0 0 0.5rem; font-size: 0.9rem; color: var(--color-text); }
|
||||
.hf-hint { font-size: 0.8rem; color: var(--color-text-muted); margin: 0 0 0.75rem; }
|
||||
.hf-link { color: var(--app-primary); }
|
||||
.hf-link:hover { color: var(--app-primary-hover); }
|
||||
.svc-section { margin-bottom: 0.75rem; }
|
||||
.svc-name {
|
||||
margin: 0 0 0.25rem;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.2rem; }
|
||||
.catalog-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.catalog-model { font-family: var(--font-mono, monospace); flex: 1; }
|
||||
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--color-text-muted); font-size: 0.75rem; flex: 2; }
|
||||
.catalog-empty, .panel-empty { color: var(--color-text-muted); font-size: 0.875rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
</style>
|
||||
148
web/src/components/nodes/NodeCard.vue
Normal file
148
web/src/components/nodes/NodeCard.vue
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import GpuRow from './GpuRow.vue'
|
||||
import OllamaModelPanel from './OllamaModelPanel.vue'
|
||||
import ProfileEditorPanel from './ProfileEditorPanel.vue'
|
||||
import type { NodeSummary, FullProfile } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{ node: NodeSummary }>()
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const showOllama = ref(false)
|
||||
const showEditor = ref(false)
|
||||
const loadedProfile = ref<FullProfile | null>(null)
|
||||
const profileLoading = ref(false)
|
||||
const profileError = ref('')
|
||||
|
||||
async function openEditor() {
|
||||
if (showEditor.value) { showEditor.value = false; return }
|
||||
profileLoading.value = true
|
||||
profileError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.node.node_id}/profile`)
|
||||
if (r.status === 404) {
|
||||
loadedProfile.value = null
|
||||
} else if (!r.ok) {
|
||||
throw new Error(`HTTP ${r.status}`)
|
||||
} else {
|
||||
loadedProfile.value = await r.json() as FullProfile
|
||||
}
|
||||
showEditor.value = true
|
||||
} catch (e) {
|
||||
profileError.value = e instanceof Error ? e.message : 'Failed to load profile'
|
||||
} finally {
|
||||
profileLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function onProfileSaved() {
|
||||
showEditor.value = false
|
||||
emit('updated')
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="node-card" :class="{ offline: !node.online }">
|
||||
<header class="node-card-header">
|
||||
<div class="node-identity">
|
||||
<span
|
||||
class="status-dot"
|
||||
:class="node.online ? 'online' : 'offline'"
|
||||
:aria-label="node.online ? 'Online' : 'Offline'"
|
||||
role="img"
|
||||
/>
|
||||
<h2 class="node-name">{{ node.node_id }}</h2>
|
||||
<span class="node-agent">{{ node.agent_url }}</span>
|
||||
</div>
|
||||
<div class="node-actions">
|
||||
<button
|
||||
v-if="node.profile_loaded"
|
||||
class="btn-secondary btn-sm"
|
||||
@click="showOllama = !showOllama"
|
||||
>
|
||||
{{ showOllama ? 'Hide Ollama' : 'Ollama' }}
|
||||
</button>
|
||||
<button
|
||||
class="btn-secondary btn-sm"
|
||||
:disabled="profileLoading"
|
||||
@click="openEditor"
|
||||
>
|
||||
{{ profileLoading ? 'Loading…' : node.profile_loaded ? (showEditor ? 'Close Editor' : 'Edit Profile') : 'Create Profile' }}
|
||||
</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div v-if="!node.profile_loaded" class="no-profile" role="status">
|
||||
No profile configured for this node. GPU stats are visible; service assignment is disabled.
|
||||
</div>
|
||||
|
||||
<div class="gpu-list">
|
||||
<GpuRow
|
||||
v-for="gpu in node.gpus"
|
||||
:key="gpu.gpu_id"
|
||||
:gpu="gpu"
|
||||
:node-id="node.node_id"
|
||||
:profile-loaded="node.profile_loaded"
|
||||
:services-catalog="node.services_catalog"
|
||||
@updated="emit('updated')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<OllamaModelPanel v-if="showOllama" :node-id="node.node_id" />
|
||||
<div v-if="profileError" class="profile-load-error" role="alert">{{ profileError }}</div>
|
||||
<ProfileEditorPanel
|
||||
v-if="showEditor"
|
||||
:node-id="node.node_id"
|
||||
:initial-profile="loadedProfile"
|
||||
@saved="onProfileSaved"
|
||||
@close="showEditor = false"
|
||||
/>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.node-card {
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
.node-card.offline { opacity: 0.65; }
|
||||
.node-card-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.node-identity { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; }
|
||||
.node-name { margin: 0; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.node-agent { color: var(--color-text-muted); font-size: 0.8rem; font-family: var(--font-mono, monospace); }
|
||||
.status-dot { width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0; }
|
||||
.status-dot.online { background: var(--color-success); }
|
||||
.status-dot.offline { background: var(--color-warning); }
|
||||
.node-actions { display: flex; gap: 0.5rem; flex-shrink: 0; }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.65rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-sm { font-size: 0.8rem; padding: 0.25rem 0.6rem; }
|
||||
.no-profile {
|
||||
padding: 0.6rem 0.75rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text-muted);
|
||||
font-size: 0.875rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.gpu-list { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
.profile-load-error { color: var(--color-error); font-size: 0.8rem; margin-top: 0.5rem; }
|
||||
</style>
|
||||
242
web/src/components/nodes/OllamaModelPanel.vue
Normal file
242
web/src/components/nodes/OllamaModelPanel.vue
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const props = defineProps<{ nodeId: string }>()
|
||||
|
||||
interface OllamaModel {
|
||||
name: string
|
||||
size: number
|
||||
modified_at: string
|
||||
}
|
||||
|
||||
const models = ref<OllamaModel[]>([])
|
||||
const loading = ref(true)
|
||||
const loadError = ref('')
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
|
||||
// AbortController for the SSE pull stream
|
||||
const abortCtrl = ref<AbortController | null>(null)
|
||||
|
||||
// AbortController for the one-shot fetchModels request
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
async function fetchModels() {
|
||||
fetchAbort?.abort()
|
||||
fetchAbort = new AbortController()
|
||||
loading.value = true
|
||||
loadError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`, {
|
||||
signal: fetchAbort.signal,
|
||||
})
|
||||
const data = await r.json() as { models?: OllamaModel[]; error?: string }
|
||||
if (data.error) { loadError.value = data.error; return }
|
||||
models.value = data.models ?? []
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
loadError.value = e instanceof Error ? e.message : 'Failed to load models'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function doPull() {
|
||||
const name = pullName.value.trim()
|
||||
if (!name || pulling.value) return
|
||||
pulling.value = true
|
||||
pullStatus.value = 'Starting...'
|
||||
pullError.value = ''
|
||||
pullPct.value = 0
|
||||
|
||||
const ctrl = new AbortController()
|
||||
abortCtrl.value = ctrl
|
||||
|
||||
try {
|
||||
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name }),
|
||||
signal: ctrl.signal,
|
||||
})
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
if (!resp.body) throw new Error('No response body')
|
||||
|
||||
const reader = resp.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buf = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += decoder.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue
|
||||
try {
|
||||
const evt = JSON.parse(line.slice(6)) as {
|
||||
status?: string; error?: string; total?: number; completed?: number
|
||||
}
|
||||
if (evt.error) {
|
||||
pullError.value = evt.error
|
||||
break
|
||||
}
|
||||
if (evt.status) pullStatus.value = evt.status
|
||||
if (evt.total && evt.completed) {
|
||||
pullPct.value = Math.round((evt.completed / evt.total) * 100)
|
||||
}
|
||||
if (evt.status === 'success') {
|
||||
pullStatus.value = 'Done!'
|
||||
pullName.value = ''
|
||||
break
|
||||
}
|
||||
} catch { /* skip malformed line */ }
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh model list after the stream closes (success or benign end)
|
||||
await fetchModels()
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
pullError.value = e instanceof Error ? e.message : 'Pull failed'
|
||||
} finally {
|
||||
pulling.value = false
|
||||
abortCtrl.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteModel(name: string) {
|
||||
if (!confirm(`Delete model "${name}" from node ${props.nodeId}?`)) return
|
||||
try {
|
||||
const r = await fetch(
|
||||
`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/${encodeURIComponent(name)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
await fetchModels()
|
||||
} catch (e) {
|
||||
loadError.value = e instanceof Error ? e.message : 'Delete failed'
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes: number): string {
|
||||
return (bytes / 1e9).toFixed(1) + ' GB'
|
||||
}
|
||||
|
||||
onMounted(fetchModels)
|
||||
onUnmounted(() => {
|
||||
abortCtrl.value?.abort()
|
||||
fetchAbort?.abort()
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="ollama-panel">
|
||||
<h3 class="panel-title">Ollama Models</h3>
|
||||
|
||||
<form class="pull-form" @submit.prevent="doPull">
|
||||
<input
|
||||
v-model="pullName"
|
||||
type="text"
|
||||
placeholder="nomic-embed-text, llama3.2:3b, ..."
|
||||
:disabled="pulling"
|
||||
aria-label="Model name to pull from Ollama"
|
||||
class="pull-input"
|
||||
/>
|
||||
<button type="submit" :disabled="pulling || !pullName.trim()" class="btn-primary btn-sm">
|
||||
{{ pulling ? 'Pulling...' : 'Pull' }}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div v-if="pulling || pullStatus" class="pull-progress" aria-live="polite">
|
||||
<div
|
||||
class="progress-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="pullPct"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
:aria-label="`Pull progress: ${pullStatus}`"
|
||||
>
|
||||
<div class="progress-fill" :style="{ width: `${pullPct}%` }" />
|
||||
</div>
|
||||
<span class="progress-label">{{ pullStatus }}{{ pullPct > 0 ? ` (${pullPct}%)` : '' }}</span>
|
||||
</div>
|
||||
|
||||
<div v-if="pullError" class="pull-error" role="alert">
|
||||
{{ pullError }}
|
||||
<span v-if="pullError.includes('permission denied')">
|
||||
— Remove the partial file on the node and retry.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading...</span>
|
||||
</div>
|
||||
<div v-if="loadError" class="panel-error" role="alert">{{ loadError }}</div>
|
||||
<ul v-if="!loading && !loadError" class="model-list" role="list">
|
||||
<li v-if="!models.length" class="model-empty">No Ollama models installed on this node.</li>
|
||||
<li v-for="m in models" :key="m.name" class="model-item">
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span class="model-size">{{ formatSize(m.size) }}</span>
|
||||
<button
|
||||
class="btn-danger btn-xs"
|
||||
@click="deleteModel(m.name)"
|
||||
:aria-label="`Delete ${m.name}`"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.ollama-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 6px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.panel-title { margin: 0 0 0.75rem; font-size: 0.9rem; color: var(--color-text); }
|
||||
.pull-form { display: flex; gap: 0.5rem; margin-bottom: 0.5rem; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.pull-progress { margin-bottom: 0.5rem; }
|
||||
.progress-bar {
|
||||
height: 8px;
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
.progress-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
|
||||
.progress-label { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.pull-error, .panel-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-loading { color: var(--color-text-muted); font-size: 0.875rem; }
|
||||
.model-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.model-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.model-name { flex: 1; font-family: var(--font-mono, monospace); }
|
||||
.model-size { color: var(--color-text-muted); font-size: 0.8rem; }
|
||||
.model-empty { color: var(--color-text-muted); font-size: 0.875rem; padding: 0.25rem 0; }
|
||||
</style>
|
||||
597
web/src/components/nodes/ProfileEditorPanel.vue
Normal file
597
web/src/components/nodes/ProfileEditorPanel.vue
Normal file
|
|
@ -0,0 +1,597 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import type { FullProfile, ServiceDefinition, CatalogEntryFull } from '../../types/nodes'
|
||||
import ServiceFormModal from './ServiceFormModal.vue'
|
||||
import CatalogEntryFormModal from './CatalogEntryFormModal.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
initialProfile: FullProfile | null
|
||||
}>()
|
||||
const emit = defineEmits<{ saved: []; close: [] }>()
|
||||
|
||||
// Deep-clone initial profile so edits don't mutate the parent's data
|
||||
const profile = ref<FullProfile>(
|
||||
props.initialProfile
|
||||
? JSON.parse(JSON.stringify(props.initialProfile))
|
||||
: { services: {}, nodes: {} }
|
||||
)
|
||||
|
||||
const saving = ref(false)
|
||||
const generating = ref(false)
|
||||
const opError = ref('')
|
||||
const expandedSvcs = ref<Set<string>>(new Set())
|
||||
|
||||
// Service modal
|
||||
const showSvcModal = ref(false)
|
||||
const editingSvcName = ref<string | undefined>()
|
||||
const editingSvcDef = ref<ServiceDefinition | undefined>()
|
||||
|
||||
// Catalog modal
|
||||
const showCatalogModal = ref(false)
|
||||
const catalogTargetSvc = ref('')
|
||||
const editingModelName = ref<string | undefined>()
|
||||
const editingEntry = ref<CatalogEntryFull | undefined>()
|
||||
|
||||
// ── Generate nodes section from coordinator ────────────────────────────────────
|
||||
|
||||
async function generate() {
|
||||
generating.value = true
|
||||
opError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile/generate`, { method: 'POST' })
|
||||
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
|
||||
const generated = await r.json() as FullProfile
|
||||
// Merge: keep current services edits, replace nodes section
|
||||
profile.value = { ...generated, services: profile.value.services }
|
||||
} catch (e) {
|
||||
opError.value = e instanceof Error ? e.message : 'Generate failed'
|
||||
} finally {
|
||||
generating.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Save full profile ──────────────────────────────────────────────────────────
|
||||
|
||||
async function save() {
|
||||
saving.value = true
|
||||
opError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ profile: profile.value }),
|
||||
})
|
||||
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
|
||||
emit('saved')
|
||||
} catch (e) {
|
||||
opError.value = e instanceof Error ? e.message : 'Save failed'
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Service CRUD ───────────────────────────────────────────────────────────────
|
||||
|
||||
function openAddService() {
|
||||
editingSvcName.value = undefined
|
||||
editingSvcDef.value = undefined
|
||||
showSvcModal.value = true
|
||||
}
|
||||
|
||||
function openEditService(name: string) {
|
||||
editingSvcName.value = name
|
||||
editingSvcDef.value = JSON.parse(JSON.stringify(profile.value.services[name]))
|
||||
showSvcModal.value = true
|
||||
}
|
||||
|
||||
function onServiceSaved(name: string, def: ServiceDefinition) {
|
||||
profile.value = { ...profile.value, services: { ...profile.value.services, [name]: def } }
|
||||
expandedSvcs.value = new Set([...expandedSvcs.value, name])
|
||||
showSvcModal.value = false
|
||||
}
|
||||
|
||||
function deleteService(name: string) {
|
||||
if (!confirm(`Remove service "${name}" from this profile?`)) return
|
||||
const svcs = { ...profile.value.services }
|
||||
delete svcs[name]
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
expandedSvcs.value = new Set([...expandedSvcs.value].filter(s => s !== name))
|
||||
}
|
||||
|
||||
function toggleSvc(name: string) {
|
||||
const s = new Set(expandedSvcs.value)
|
||||
s.has(name) ? s.delete(name) : s.add(name)
|
||||
expandedSvcs.value = s
|
||||
}
|
||||
|
||||
// ── Catalog CRUD ───────────────────────────────────────────────────────────────
|
||||
|
||||
function openAddCatalogEntry(svcName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = undefined
|
||||
editingEntry.value = undefined
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function openEditCatalogEntry(svcName: string, modelName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = modelName
|
||||
editingEntry.value = JSON.parse(JSON.stringify(profile.value.services[svcName].catalog![modelName]))
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function onCatalogEntrySaved(svcName: string, modelName: string, entry: CatalogEntryFull) {
|
||||
const svcs = { ...profile.value.services }
|
||||
const svc = { ...svcs[svcName], catalog: { ...(svcs[svcName].catalog ?? {}), [modelName]: entry } }
|
||||
svcs[svcName] = svc
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
showCatalogModal.value = false
|
||||
}
|
||||
|
||||
function deleteCatalogEntry(svcName: string, modelName: string) {
|
||||
if (!confirm(`Remove model "${modelName}" from ${svcName} catalog?`)) return
|
||||
const svcs = { ...profile.value.services }
|
||||
const catalog = { ...(svcs[svcName].catalog ?? {}) }
|
||||
delete catalog[modelName]
|
||||
svcs[svcName] = { ...svcs[svcName], catalog }
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
}
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
function gpuList() {
|
||||
return (profile.value.nodes[props.nodeId]?.gpus ?? [])
|
||||
}
|
||||
|
||||
function serviceCount() {
|
||||
return Object.keys(profile.value.services).length
|
||||
}
|
||||
|
||||
// ── Ollama model suggestions ───────────────────────────────────────────────────
|
||||
|
||||
interface OllamaModel { name: string; size: number }
|
||||
const ollamaModels = ref<OllamaModel[]>([])
|
||||
const ollamaLoading = ref(false)
|
||||
|
||||
onMounted(async () => {
|
||||
ollamaLoading.value = true
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
|
||||
if (r.ok) {
|
||||
const d = await r.json() as { models?: OllamaModel[] }
|
||||
ollamaModels.value = d.models ?? []
|
||||
}
|
||||
} catch { /* Ollama offline — silently skip */ }
|
||||
finally { ollamaLoading.value = false }
|
||||
})
|
||||
|
||||
function ollamaNotInCatalog(svcName: string): OllamaModel[] {
|
||||
const catalog = profile.value.services[svcName]?.catalog ?? {}
|
||||
return ollamaModels.value.filter(m => !(m.name in catalog))
|
||||
}
|
||||
|
||||
function openAddFromOllama(svcName: string, modelName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = modelName
|
||||
editingEntry.value = {
|
||||
path: profile.value.services[svcName]?.model_base_path
|
||||
? `${profile.value.services[svcName].model_base_path}/${modelName}`
|
||||
: '',
|
||||
vram_mb: 0,
|
||||
}
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function formatMb(bytes: number): string {
|
||||
return bytes >= 1_000_000_000
|
||||
? `${(bytes / 1_073_741_824).toFixed(1)} GB`
|
||||
: `${Math.round(bytes / 1_048_576)} MB`
|
||||
}
|
||||
|
||||
// ── Pull model onto node ───────────────────────────────────────────────────────
|
||||
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
let pullAbort: AbortController | null = null
|
||||
|
||||
async function doPull() {
|
||||
const name = pullName.value.trim()
|
||||
if (!name || pulling.value) return
|
||||
pulling.value = true
|
||||
pullStatus.value = 'Starting…'
|
||||
pullError.value = ''
|
||||
pullPct.value = 0
|
||||
pullAbort?.abort()
|
||||
pullAbort = new AbortController()
|
||||
|
||||
try {
|
||||
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name }),
|
||||
signal: pullAbort.signal,
|
||||
})
|
||||
if (!resp.ok || !resp.body) {
|
||||
pullError.value = `HTTP ${resp.status}`
|
||||
return
|
||||
}
|
||||
const reader = resp.body.getReader()
|
||||
const dec = new TextDecoder()
|
||||
let buf = ''
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += dec.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data:')) continue
|
||||
try {
|
||||
const d = JSON.parse(line.slice(5)) as {
|
||||
status?: string; completed?: number; total?: number; error?: string; done?: boolean
|
||||
}
|
||||
if (d.error) { pullError.value = d.error; return }
|
||||
pullStatus.value = d.status ?? ''
|
||||
if (d.total && d.total > 0) pullPct.value = Math.round((d.completed ?? 0) / d.total * 100)
|
||||
if (d.done) {
|
||||
pullName.value = ''
|
||||
pullPct.value = 100
|
||||
// Refresh Ollama model list so new model appears in suggest chips
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
|
||||
if (r.ok) { const d2 = await r.json() as { models?: OllamaModel[] }; ollamaModels.value = d2.models ?? [] }
|
||||
}
|
||||
} catch { /* skip malformed SSE line */ }
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name !== 'AbortError') pullError.value = e.message
|
||||
} finally {
|
||||
pulling.value = false
|
||||
if (pullPct.value === 100) setTimeout(() => { pullStatus.value = ''; pullPct.value = 0 }, 2000)
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="pep" aria-label="Profile editor">
|
||||
<!-- Header -->
|
||||
<div class="pep-header">
|
||||
<div class="pep-title-row">
|
||||
<h3 class="pep-title">Profile — {{ nodeId }}</h3>
|
||||
<span class="pep-svc-count">{{ serviceCount() }} service{{ serviceCount() === 1 ? '' : 's' }}</span>
|
||||
</div>
|
||||
<div class="pep-actions">
|
||||
<button class="btn-secondary btn-sm" :disabled="generating" @click="generate">
|
||||
{{ generating ? 'Refreshing…' : 'Refresh Hardware' }}
|
||||
</button>
|
||||
<button class="btn-primary btn-sm" :disabled="saving" @click="save">
|
||||
{{ saving ? 'Saving…' : 'Save Profile' }}
|
||||
</button>
|
||||
<button class="btn-icon-lg" aria-label="Close editor" @click="emit('close')">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="opError" class="pep-error" role="alert">{{ opError }}</div>
|
||||
|
||||
<!-- Meta fields -->
|
||||
<div class="pep-meta">
|
||||
<label class="meta-label" for="pep-vram">vram_total_mb</label>
|
||||
<input id="pep-vram" v-model.number="profile.vram_total_mb" type="number" min="0" class="meta-input" />
|
||||
<label class="meta-label" for="pep-evict">eviction_timeout_s</label>
|
||||
<input id="pep-evict" v-model.number="profile.eviction_timeout_s" type="number" min="0" step="0.5" class="meta-input" />
|
||||
</div>
|
||||
|
||||
<!-- Hardware summary -->
|
||||
<div v-if="gpuList().length" class="hw-section">
|
||||
<span class="hw-label">Hardware</span>
|
||||
<span v-for="g in gpuList()" :key="g.id" class="hw-gpu">
|
||||
GPU {{ g.id }}: {{ g.card || 'unknown' }} · {{ g.vram_mb }} MB · sm{{ g.compute_cap ?? '?' }}
|
||||
</span>
|
||||
<span v-if="!gpuList().length" class="hw-none">No hardware data — click Refresh Hardware.</span>
|
||||
</div>
|
||||
<div v-else class="hw-section">
|
||||
<span class="hw-none">No hardware data — click Refresh Hardware to seed from coordinator.</span>
|
||||
</div>
|
||||
|
||||
<!-- Services -->
|
||||
<div class="svcs-header">
|
||||
<span class="svcs-title">Services</span>
|
||||
<button class="btn-secondary btn-sm" @click="openAddService">+ Add Service</button>
|
||||
</div>
|
||||
|
||||
<div v-if="serviceCount() === 0" class="svcs-empty">
|
||||
No services defined. Add a service to configure what can run on this node.
|
||||
</div>
|
||||
|
||||
<ul class="svcs-list" role="list">
|
||||
<li
|
||||
v-for="(def, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-item"
|
||||
>
|
||||
<!-- Service row header -->
|
||||
<div class="svc-row">
|
||||
<button
|
||||
class="svc-toggle"
|
||||
:aria-expanded="expandedSvcs.has(String(svcName))"
|
||||
@click="toggleSvc(String(svcName))"
|
||||
>
|
||||
<span class="svc-arrow">{{ expandedSvcs.has(String(svcName)) ? '▾' : '▸' }}</span>
|
||||
<span class="svc-name">{{ svcName }}</span>
|
||||
</button>
|
||||
<span class="svc-badges">
|
||||
<span class="badge">{{ def.max_mb }} MB</span>
|
||||
<span class="badge">p{{ def.priority }}</span>
|
||||
<span v-if="def.shared" class="badge badge--blue">shared</span>
|
||||
<span v-if="def.managed" class="badge badge--dim">managed</span>
|
||||
<span v-if="def.catalog" class="badge badge--dim">{{ Object.keys(def.catalog).length }} models</span>
|
||||
</span>
|
||||
<div class="svc-btns">
|
||||
<button class="btn-secondary btn-xs" @click="openEditService(String(svcName))">Edit</button>
|
||||
<button class="btn-danger btn-xs" @click="deleteService(String(svcName))">Delete</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Expanded catalog -->
|
||||
<div v-if="expandedSvcs.has(String(svcName))" class="svc-detail">
|
||||
<div class="svc-detail-meta">
|
||||
<span v-if="def.min_compute_cap">min sm{{ def.min_compute_cap }}</span>
|
||||
<span v-if="def.max_concurrent">max_concurrent: {{ def.max_concurrent }}</span>
|
||||
<span v-if="def.idle_stop_after_s">idle_stop: {{ def.idle_stop_after_s }}s</span>
|
||||
<span v-if="def.always_on" class="badge badge--blue">always_on</span>
|
||||
</div>
|
||||
|
||||
<!-- Ollama model suggestions + pull -->
|
||||
<div class="ollama-suggest">
|
||||
<div class="suggest-row">
|
||||
<span class="suggest-label">On node (Ollama):</span>
|
||||
<span v-if="ollamaLoading" class="suggest-loading">loading…</span>
|
||||
<template v-else-if="ollamaNotInCatalog(String(svcName)).length">
|
||||
<button
|
||||
v-for="m in ollamaNotInCatalog(String(svcName))"
|
||||
:key="m.name"
|
||||
class="suggest-chip"
|
||||
:title="`Add ${m.name} (${formatMb(m.size)}) to this service catalog`"
|
||||
@click="openAddFromOllama(String(svcName), m.name)"
|
||||
>
|
||||
+ {{ m.name }} <span class="chip-size">{{ formatMb(m.size) }}</span>
|
||||
</button>
|
||||
</template>
|
||||
<span v-else-if="!ollamaLoading" class="suggest-none">All Ollama models already in catalog.</span>
|
||||
</div>
|
||||
|
||||
<!-- Pull model onto this node -->
|
||||
<div class="pull-row">
|
||||
<input
|
||||
v-model="pullName"
|
||||
class="pull-input"
|
||||
placeholder="Pull model on node (e.g. llama3:8b)"
|
||||
:disabled="pulling"
|
||||
@keyup.enter="doPull"
|
||||
/>
|
||||
<button class="btn-pull" :disabled="pulling || !pullName.trim()" @click="doPull">
|
||||
{{ pulling ? 'Pulling…' : 'Pull' }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="pulling || pullPct > 0" class="pull-progress">
|
||||
<div class="pull-bar"><div class="pull-fill" :style="{ width: pullPct + '%' }" /></div>
|
||||
<span class="pull-status">{{ pullStatus }}</span>
|
||||
</div>
|
||||
<div v-if="pullError" class="pull-err" role="alert">{{ pullError }}</div>
|
||||
</div>
|
||||
|
||||
<div class="catalog-header">
|
||||
<span class="catalog-title">Catalog</span>
|
||||
<button class="btn-link" @click="openAddCatalogEntry(String(svcName))">+ Add Model</button>
|
||||
</div>
|
||||
|
||||
<div v-if="!def.catalog || !Object.keys(def.catalog).length" class="catalog-empty">
|
||||
No catalog entries. Only services like cf-text need a catalog.
|
||||
</div>
|
||||
<ul v-else class="catalog-list" role="list">
|
||||
<li
|
||||
v-for="(entry, modelName) in def.catalog"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.multi_gpu" class="badge badge--dim">multi-gpu</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
<div class="catalog-btns">
|
||||
<button class="btn-secondary btn-xs" @click="openEditCatalogEntry(String(svcName), String(modelName))">Edit</button>
|
||||
<button class="btn-danger btn-xs" @click="deleteCatalogEntry(String(svcName), String(modelName))">✕</button>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<!-- Service form modal -->
|
||||
<ServiceFormModal
|
||||
v-if="showSvcModal"
|
||||
:service-name="editingSvcName"
|
||||
:definition="editingSvcDef"
|
||||
@save="onServiceSaved"
|
||||
@cancel="showSvcModal = false"
|
||||
/>
|
||||
|
||||
<!-- Catalog entry form modal -->
|
||||
<CatalogEntryFormModal
|
||||
v-if="showCatalogModal"
|
||||
:svc-name="catalogTargetSvc"
|
||||
:model-name="editingModelName"
|
||||
:entry="editingEntry"
|
||||
@save="onCatalogEntrySaved"
|
||||
@cancel="showCatalogModal = false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.pep {
|
||||
margin-top: 0.75rem;
|
||||
padding: 1rem;
|
||||
border: 1px solid var(--color-primary);
|
||||
border-radius: 6px;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
.pep-header {
|
||||
display: flex; align-items: center; justify-content: space-between; gap: 0.5rem;
|
||||
margin-bottom: 0.75rem; flex-wrap: wrap;
|
||||
}
|
||||
.pep-title-row { display: flex; align-items: baseline; gap: 0.5rem; }
|
||||
.pep-title { margin: 0; font-size: 0.95rem; font-weight: 600; color: var(--color-text); }
|
||||
.pep-svc-count { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.pep-actions { display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap; }
|
||||
.pep-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.pep-meta {
|
||||
display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap;
|
||||
padding: 0.5rem; background: var(--color-surface-alt); border-radius: 4px; margin-bottom: 0.75rem;
|
||||
}
|
||||
.meta-label { font-size: 0.8rem; color: var(--color-text-muted); }
|
||||
.meta-input {
|
||||
width: 7rem; background: var(--color-surface); border: 1px solid var(--color-border);
|
||||
border-radius: 4px; padding: 0.2rem 0.4rem; color: var(--color-text); font-size: 0.8rem;
|
||||
}
|
||||
.hw-section {
|
||||
display: flex; flex-wrap: wrap; align-items: center; gap: 0.5rem;
|
||||
font-size: 0.8rem; color: var(--color-text-muted);
|
||||
padding: 0.4rem 0.5rem; border-radius: 4px; background: var(--color-surface-alt);
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.hw-label { font-weight: 600; color: var(--color-text); }
|
||||
.hw-gpu { font-family: monospace; color: var(--color-text); }
|
||||
.hw-none { font-style: italic; }
|
||||
.svcs-header {
|
||||
display: flex; align-items: center; justify-content: space-between;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.svcs-title { font-size: 0.85rem; font-weight: 600; color: var(--color-text); }
|
||||
.svcs-empty { color: var(--color-text-muted); font-size: 0.85rem; padding: 0.5rem 0; }
|
||||
.svcs-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.4rem; }
|
||||
.svc-item { border: 1px solid var(--color-border); border-radius: 4px; overflow: hidden; }
|
||||
.svc-row {
|
||||
display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0.5rem;
|
||||
background: var(--color-surface-alt); flex-wrap: wrap;
|
||||
}
|
||||
.svc-toggle {
|
||||
display: flex; align-items: center; gap: 0.35rem;
|
||||
background: none; border: none; cursor: pointer; color: var(--color-text); padding: 0; flex: 1; min-width: 0;
|
||||
}
|
||||
.svc-arrow { font-size: 0.7rem; color: var(--color-text-muted); }
|
||||
.svc-name { font-size: 0.875rem; font-weight: 500; font-family: monospace; }
|
||||
.svc-badges { display: flex; gap: 0.3rem; flex-wrap: wrap; }
|
||||
.svc-btns { display: flex; gap: 0.3rem; margin-left: auto; }
|
||||
.svc-detail { padding: 0.5rem 0.75rem; display: flex; flex-direction: column; gap: 0.5rem; background: var(--color-surface-raised); }
|
||||
.svc-detail-meta {
|
||||
display: flex; gap: 0.5rem; flex-wrap: wrap;
|
||||
font-size: 0.78rem; color: var(--color-text-muted);
|
||||
}
|
||||
.ollama-suggest {
|
||||
display: flex; flex-direction: column; gap: 0.35rem;
|
||||
padding: 0.4rem 0.5rem;
|
||||
background: var(--color-primary-light);
|
||||
border: 1px solid var(--color-border-light);
|
||||
border-radius: 4px;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
.suggest-row { display: flex; flex-wrap: wrap; align-items: center; gap: 0.35rem; }
|
||||
.suggest-label { color: var(--color-text-muted); font-weight: 500; white-space: nowrap; }
|
||||
.suggest-loading { color: var(--color-text-muted); font-style: italic; }
|
||||
.suggest-none { color: var(--color-text-muted); font-style: italic; }
|
||||
.suggest-chip {
|
||||
display: inline-flex; align-items: center; gap: 0.25rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 3px;
|
||||
color: var(--color-text);
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
transition: border-color 0.15s, background 0.15s;
|
||||
}
|
||||
.suggest-chip:hover { border-color: var(--app-primary); background: var(--color-surface-alt); }
|
||||
.chip-size { color: var(--color-text-muted); font-size: 0.72rem; }
|
||||
.pull-row { display: flex; gap: 0.4rem; align-items: center; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text);
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
.pull-input:disabled { opacity: 0.5; }
|
||||
.btn-pull {
|
||||
padding: 0.25rem 0.6rem;
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.btn-pull:hover:not(:disabled) { background: var(--app-primary-hover); }
|
||||
.btn-pull:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.pull-progress { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.pull-bar {
|
||||
flex: 1; height: 6px;
|
||||
background: var(--color-border);
|
||||
border-radius: 3px; overflow: hidden;
|
||||
}
|
||||
.pull-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
|
||||
.pull-status { color: var(--color-text-muted); font-size: 0.72rem; white-space: nowrap; max-width: 14rem; overflow: hidden; text-overflow: ellipsis; }
|
||||
.pull-err { color: var(--color-error); font-size: 0.75rem; }
|
||||
.catalog-header { display: flex; align-items: center; justify-content: space-between; }
|
||||
.catalog-title { font-size: 0.8rem; font-weight: 600; color: var(--color-text-muted); text-transform: uppercase; letter-spacing: 0.05em; }
|
||||
.catalog-empty { font-size: 0.8rem; color: var(--color-text-muted); font-style: italic; }
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.25rem; }
|
||||
.catalog-item {
|
||||
display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap;
|
||||
padding: 0.25rem 0.5rem; background: var(--color-surface-alt); border-radius: 3px; font-size: 0.8rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.catalog-model { font-family: monospace; flex: 1; min-width: 12rem; }
|
||||
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--color-text-muted); flex: 2; font-size: 0.75rem; }
|
||||
.catalog-btns { display: flex; gap: 0.25rem; margin-left: auto; }
|
||||
.badge {
|
||||
padding: 0.1rem 0.4rem; border-radius: 3px; font-size: 0.72rem;
|
||||
background: var(--color-surface); border: 1px solid var(--color-border); color: var(--color-text);
|
||||
}
|
||||
.badge--blue { border-color: var(--color-primary); color: var(--color-primary); background: var(--color-primary-light); }
|
||||
.badge--dim { opacity: 0.75; }
|
||||
.btn-link { background: none; border: none; color: var(--color-accent); cursor: pointer; font-size: 0.8rem; padding: 0; }
|
||||
.btn-link:hover { color: var(--color-accent-hover); }
|
||||
.btn-primary {
|
||||
background: var(--color-primary); color: var(--color-text-inverse); border: none;
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--color-primary-hover); }
|
||||
.btn-primary:disabled { opacity: 0.6; cursor: not-allowed; }
|
||||
.btn-secondary {
|
||||
background: transparent; border: 1px solid var(--color-border); color: var(--color-text);
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
.btn-secondary:disabled { opacity: 0.6; cursor: not-allowed; }
|
||||
.btn-danger {
|
||||
background: transparent; border: 1px solid var(--color-error); color: var(--color-error);
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-danger:hover { background: var(--color-surface-alt); }
|
||||
.btn-sm { padding: 0.3rem 0.6rem; }
|
||||
.btn-xs { padding: 0.15rem 0.4rem; }
|
||||
.btn-icon-lg { background: none; border: none; color: var(--color-text-muted); cursor: pointer; font-size: 1rem; padding: 0.2rem 0.3rem; }
|
||||
.btn-icon-lg:hover { color: var(--color-text); }
|
||||
</style>
|
||||
82
web/src/components/nodes/ServiceBadge.vue
Normal file
82
web/src/components/nodes/ServiceBadge.vue
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
<script setup lang="ts">
|
||||
type ServiceState =
|
||||
| 'running'
|
||||
| 'stopped'
|
||||
| 'assigned-only'
|
||||
| 'available'
|
||||
| 'incompatible'
|
||||
| 'vram-tight'
|
||||
| 'unknown'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName: string
|
||||
state: ServiceState
|
||||
assigned: boolean
|
||||
disabled?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ toggle: [] }>()
|
||||
|
||||
const STATE_LABELS: Record<ServiceState, string> = {
|
||||
running: 'Running',
|
||||
stopped: 'Stopped',
|
||||
'assigned-only': 'Assigned',
|
||||
available: 'Available',
|
||||
incompatible: 'Incompatible',
|
||||
'vram-tight': 'VRAM tight',
|
||||
unknown: 'Unknown',
|
||||
}
|
||||
|
||||
const STATE_ICONS: Record<ServiceState, string> = {
|
||||
running: '▶',
|
||||
stopped: '⏹',
|
||||
'assigned-only': '📌',
|
||||
available: '○',
|
||||
incompatible: '✕',
|
||||
'vram-tight': '⚠',
|
||||
unknown: '?',
|
||||
}
|
||||
|
||||
function handleToggle() {
|
||||
if (!props.disabled && props.state !== 'incompatible') emit('toggle')
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<button
|
||||
class="service-badge"
|
||||
:class="[`state-${state}`, { assigned, 'is-disabled': disabled || state === 'incompatible' }]"
|
||||
:aria-pressed="assigned"
|
||||
:aria-label="`${serviceName}: ${STATE_LABELS[state] ?? state}${assigned ? ' (assigned)' : ''}`"
|
||||
:disabled="disabled || state === 'incompatible'"
|
||||
@click="handleToggle"
|
||||
>
|
||||
<span class="badge-icon" aria-hidden="true">{{ STATE_ICONS[state] ?? '?' }}</span>
|
||||
<span class="badge-name">{{ serviceName }}</span>
|
||||
<span class="badge-state">{{ STATE_LABELS[state] ?? state }}</span>
|
||||
</button>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.service-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--color-border);
|
||||
background: var(--color-surface);
|
||||
color: var(--color-text);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.1s, border-color 0.1s;
|
||||
}
|
||||
.service-badge:hover:not(.is-disabled) { opacity: 0.8; }
|
||||
.service-badge.is-disabled { cursor: not-allowed; opacity: 0.5; }
|
||||
.service-badge.state-running { border-color: var(--color-success); }
|
||||
.service-badge.state-stopped { border-color: var(--color-warning); }
|
||||
.service-badge.state-assigned-only { border-color: var(--color-info); }
|
||||
.service-badge.state-incompatible { border-color: var(--color-error); }
|
||||
.service-badge.state-vram-tight { border-color: var(--color-warning); }
|
||||
.badge-state { color: var(--color-text-muted); }
|
||||
</style>
|
||||
231
web/src/components/nodes/ServiceFormModal.vue
Normal file
231
web/src/components/nodes/ServiceFormModal.vue
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import type { ServiceDefinition } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName?: string
|
||||
definition?: ServiceDefinition
|
||||
}>()
|
||||
const emit = defineEmits<{
|
||||
save: [name: string, def: ServiceDefinition]
|
||||
cancel: []
|
||||
}>()
|
||||
|
||||
const name = ref(props.serviceName ?? '')
|
||||
const maxMb = ref(props.definition?.max_mb ?? 0)
|
||||
const priority = ref(props.definition?.priority ?? 1)
|
||||
const minCap = ref(props.definition?.min_compute_cap ?? 0)
|
||||
const prefCap = ref<number | ''>(props.definition?.preferred_compute_cap ?? '')
|
||||
const shared = ref(props.definition?.shared ?? false)
|
||||
const maxConcurrent = ref<number | ''>(props.definition?.max_concurrent ?? '')
|
||||
const idleStop = ref<number | ''>(props.definition?.idle_stop_after_s ?? '')
|
||||
const alwaysOn = ref(props.definition?.always_on ?? false)
|
||||
const modelBasePath = ref(props.definition?.model_base_path ?? '')
|
||||
const hasManaged = ref(!!props.definition?.managed)
|
||||
const managedJson = ref(
|
||||
props.definition?.managed ? JSON.stringify(props.definition.managed, null, 2) : ''
|
||||
)
|
||||
const formError = ref('')
|
||||
|
||||
watch(() => props.definition, (d) => {
|
||||
name.value = props.serviceName ?? ''
|
||||
maxMb.value = d?.max_mb ?? 0
|
||||
priority.value = d?.priority ?? 1
|
||||
minCap.value = d?.min_compute_cap ?? 0
|
||||
prefCap.value = d?.preferred_compute_cap ?? ''
|
||||
shared.value = d?.shared ?? false
|
||||
maxConcurrent.value = d?.max_concurrent ?? ''
|
||||
idleStop.value = d?.idle_stop_after_s ?? ''
|
||||
alwaysOn.value = d?.always_on ?? false
|
||||
modelBasePath.value = d?.model_base_path ?? ''
|
||||
hasManaged.value = !!d?.managed
|
||||
managedJson.value = d?.managed ? JSON.stringify(d.managed, null, 2) : ''
|
||||
})
|
||||
|
||||
const managedJsonError = computed(() => {
|
||||
if (!hasManaged.value || !managedJson.value.trim()) return ''
|
||||
try { JSON.parse(managedJson.value); return '' }
|
||||
catch { return 'Invalid JSON' }
|
||||
})
|
||||
|
||||
function submit() {
|
||||
formError.value = ''
|
||||
if (!name.value.trim()) { formError.value = 'Service name is required.'; return }
|
||||
if (!maxMb.value || maxMb.value <= 0) { formError.value = 'max_mb must be > 0.'; return }
|
||||
if (managedJsonError.value) { formError.value = 'Fix the managed JSON before saving.'; return }
|
||||
|
||||
const def: ServiceDefinition = { max_mb: maxMb.value, priority: priority.value }
|
||||
if (minCap.value) def.min_compute_cap = minCap.value
|
||||
if (prefCap.value !== '') def.preferred_compute_cap = Number(prefCap.value)
|
||||
if (shared.value) def.shared = true
|
||||
if (maxConcurrent.value !== '') def.max_concurrent = Number(maxConcurrent.value)
|
||||
if (idleStop.value !== '') def.idle_stop_after_s = Number(idleStop.value)
|
||||
if (alwaysOn.value) def.always_on = true
|
||||
if (modelBasePath.value.trim()) def.model_base_path = modelBasePath.value.trim()
|
||||
if (hasManaged.value && managedJson.value.trim()) {
|
||||
def.managed = JSON.parse(managedJson.value)
|
||||
}
|
||||
// Preserve existing catalog when editing
|
||||
if (props.definition?.catalog) def.catalog = props.definition.catalog
|
||||
|
||||
emit('save', name.value.trim(), def)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${serviceName ? 'Edit' : 'Add'} service`">
|
||||
<div class="modal-box">
|
||||
<h3 class="modal-title">{{ serviceName ? 'Edit' : 'Add' }} Service</h3>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-name">Service name</label>
|
||||
<input id="sf-name" v-model="name" class="field-input" :readonly="!!serviceName" placeholder="cf-text" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-maxmb">max_mb</label>
|
||||
<input id="sf-maxmb" v-model.number="maxMb" type="number" min="0" class="field-input field-input--sm" />
|
||||
<span class="field-hint">VRAM ceiling</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-prio">priority</label>
|
||||
<input id="sf-prio" v-model.number="priority" type="number" min="1" max="10" class="field-input field-input--sm" />
|
||||
<span class="field-hint">1 = highest</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-mincap">min_compute_cap</label>
|
||||
<input id="sf-mincap" v-model.number="minCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="0.0" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-prefcap">preferred_cap</label>
|
||||
<input id="sf-prefcap" v-model="prefCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="optional" />
|
||||
</div>
|
||||
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-shared" v-model="shared" type="checkbox" />
|
||||
<label for="sf-shared">shared (multiple concurrent users)</label>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-maxcon">max_concurrent</label>
|
||||
<input id="sf-maxcon" v-model="maxConcurrent" type="number" min="1" class="field-input field-input--sm" placeholder="optional" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-idle">idle_stop_after_s</label>
|
||||
<input id="sf-idle" v-model="idleStop" type="number" min="0" class="field-input field-input--sm" placeholder="optional" />
|
||||
<span class="field-hint">seconds</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-always" v-model="alwaysOn" type="checkbox" />
|
||||
<label for="sf-always">always_on (never evict)</label>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-base">model_base_path</label>
|
||||
<input id="sf-base" v-model="modelBasePath" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models (optional)" />
|
||||
</div>
|
||||
|
||||
<div class="managed-section">
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-has-managed" v-model="hasManaged" type="checkbox" />
|
||||
<label for="sf-has-managed">Has managed process config</label>
|
||||
</div>
|
||||
<div v-if="hasManaged" class="managed-body">
|
||||
<label class="field-label" for="sf-managed">managed (JSON)</label>
|
||||
<textarea
|
||||
id="sf-managed"
|
||||
v-model="managedJson"
|
||||
class="field-textarea"
|
||||
rows="6"
|
||||
spellcheck="false"
|
||||
placeholder='{"type": "process", "exec_path": "...", "args_template": "...", "port": 8008, "host_port": 8008}'
|
||||
/>
|
||||
<span v-if="managedJsonError" class="json-error" role="alert">{{ managedJsonError }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
|
||||
<button class="btn-primary" @click="submit">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.modal-backdrop {
|
||||
position: fixed; inset: 0;
|
||||
background: rgba(0,0,0,0.5);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
z-index: 200;
|
||||
}
|
||||
.modal-box {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
width: 100%; max-width: 540px;
|
||||
max-height: 90vh; overflow-y: auto;
|
||||
display: flex; flex-direction: column; gap: 0.65rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.field-row { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.field-row--check { gap: 0.4rem; font-size: 0.875rem; color: var(--color-text); }
|
||||
.field-label { min-width: 9rem; font-size: 0.85rem; color: var(--color-text-muted); flex-shrink: 0; }
|
||||
.field-hint { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.field-input {
|
||||
flex: 1;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.field-input--sm { flex: 0 0 8rem; }
|
||||
.managed-section { display: flex; flex-direction: column; gap: 0.4rem; border-top: 1px solid var(--color-border); padding-top: 0.5rem; }
|
||||
.managed-body { display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.field-textarea {
|
||||
width: 100%;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
resize: vertical;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.json-error { color: var(--color-error); font-size: 0.78rem; }
|
||||
.form-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
|
||||
.btn-primary {
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 1rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--app-primary-hover); }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.75rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
</style>
|
||||
|
|
@ -1,25 +1,53 @@
|
|||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import LabelView from '../views/LabelView.vue'
|
||||
|
||||
// Views are lazy-loaded to keep initial bundle small
|
||||
// Lazy-loaded views
|
||||
const DashboardView = () => import('../views/DashboardView.vue')
|
||||
const LabelView = () => import('../views/LabelView.vue')
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const StatsView = () => import('../views/StatsView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
const ImitateView = () => import('../views/ImitateView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const CompareView = () => import('../views/CompareView.vue')
|
||||
const TrainJobsView = () => import('../views/TrainJobsView.vue')
|
||||
const TrainResultsView = () => import('../views/TrainResultsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const NodeManagementView = () => import('../views/NodeManagementView.vue')
|
||||
|
||||
export const routes = [
|
||||
// ── Top-level ────────────────────────────────────────────
|
||||
{ path: '/', component: DashboardView, meta: { title: 'Dashboard' } },
|
||||
{ path: '/fleet', component: ModelsView, meta: { title: 'Fleet' } },
|
||||
{ path: '/nodes', component: NodeManagementView, meta: { title: 'Nodes' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
|
||||
// ── Data domain ──────────────────────────────────────────
|
||||
{ path: '/data/label', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/data/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/data/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/data/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||
{ path: '/data/recipe-scan', component: () => import('../views/RecipeScanView.vue'), meta: { title: 'Recipe Scan' } },
|
||||
|
||||
// ── Eval domain ──────────────────────────────────────────
|
||||
{ path: '/eval/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/eval/compare', component: CompareView, meta: { title: 'Compare' } },
|
||||
{ path: '/eval/embed-compare', component: () => import('../views/EmbedCompareView.vue'), meta: { title: 'Embed Compare' } },
|
||||
|
||||
// ── Train domain ─────────────────────────────────────────
|
||||
{ path: '/train/jobs', component: TrainJobsView, meta: { title: 'Training Jobs' } },
|
||||
{ path: '/train/results', component: TrainResultsView, meta: { title: 'Training Results' } },
|
||||
|
||||
// ── Backward-compat redirects ────────────────────────────
|
||||
{ path: '/benchmark', redirect: '/eval/benchmark' },
|
||||
{ path: '/models', redirect: '/fleet' },
|
||||
{ path: '/stats', redirect: '/' },
|
||||
{ path: '/label', redirect: '/data/label' },
|
||||
{ path: '/fetch', redirect: '/data/fetch' },
|
||||
{ path: '/corrections', redirect: '/data/corrections' },
|
||||
{ path: '/imitate', redirect: '/data/imitate' },
|
||||
]
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
|
||||
{ path: '/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
],
|
||||
routes,
|
||||
})
|
||||
|
|
|
|||
94
web/src/router/router.test.ts
Normal file
94
web/src/router/router.test.ts
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import { describe, it, expect } from 'vitest'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
|
||||
// Import the raw routes array so we can test structure without mounting App
|
||||
import { routes } from './index'
|
||||
|
||||
describe('router routes', () => {
|
||||
it('exports a routes array', () => {
|
||||
expect(Array.isArray(routes)).toBe(true)
|
||||
})
|
||||
|
||||
it('has / pointing to DashboardView', () => {
|
||||
const root = routes.find(r => r.path === '/')
|
||||
expect(root).toBeDefined()
|
||||
// Component should be async (lazy) or have a name
|
||||
expect(root?.component).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /fleet route', () => {
|
||||
const r = routes.find(r => r.path === '/fleet')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/label route', () => {
|
||||
const r = routes.find(r => r.path === '/data/label')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/fetch route', () => {
|
||||
const r = routes.find(r => r.path === '/data/fetch')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/corrections route', () => {
|
||||
const r = routes.find(r => r.path === '/data/corrections')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/imitate route', () => {
|
||||
const r = routes.find(r => r.path === '/data/imitate')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /eval/benchmark route', () => {
|
||||
const r = routes.find(r => r.path === '/eval/benchmark')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /eval/compare route', () => {
|
||||
const r = routes.find(r => r.path === '/eval/compare')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /train/jobs route', () => {
|
||||
const r = routes.find(r => r.path === '/train/jobs')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /train/results route', () => {
|
||||
const r = routes.find(r => r.path === '/train/results')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /settings route', () => {
|
||||
const r = routes.find(r => r.path === '/settings')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /benchmark to /eval/benchmark', () => {
|
||||
const r = routes.find(r => r.path === '/benchmark')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/eval/benchmark')
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /models to /fleet', () => {
|
||||
const r = routes.find(r => r.path === '/models')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/fleet')
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /stats to /', () => {
|
||||
const r = routes.find(r => r.path === '/stats')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/')
|
||||
})
|
||||
|
||||
it('can create a functional router instance', () => {
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes,
|
||||
})
|
||||
expect(router).toBeDefined()
|
||||
})
|
||||
})
|
||||
89
web/src/types/nodes.ts
Normal file
89
web/src/types/nodes.ts
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
export interface GpuEntry {
|
||||
gpu_id: number
|
||||
card: string
|
||||
vram_total_mb: number
|
||||
vram_used_mb: number
|
||||
vram_free_mb: number
|
||||
temp_c: number | null
|
||||
utilization_pct: number | null
|
||||
compute_cap: number | null
|
||||
services_assigned: string[]
|
||||
services_running: string[]
|
||||
}
|
||||
|
||||
export interface ServiceInfo {
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
catalog_size: number
|
||||
}
|
||||
|
||||
export interface NodeSummary {
|
||||
node_id: string
|
||||
online: boolean
|
||||
agent_url: string
|
||||
gpus: GpuEntry[]
|
||||
profile_loaded: boolean
|
||||
services_catalog: Record<string, ServiceInfo>
|
||||
}
|
||||
|
||||
// ── Full profile types (for profile editor) ────────────────────────────────────
|
||||
|
||||
export interface ServiceManaged {
|
||||
type: string
|
||||
exec_path?: string
|
||||
args_template?: string
|
||||
port?: number
|
||||
host_port?: number
|
||||
base_port?: number
|
||||
health_path?: string
|
||||
cwd?: string
|
||||
adopt?: boolean
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export interface CatalogEntryFull {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description?: string
|
||||
multi_gpu?: boolean
|
||||
env?: Record<string, string>
|
||||
}
|
||||
|
||||
export interface ServiceDefinition {
|
||||
max_mb: number
|
||||
priority: number
|
||||
min_compute_cap?: number
|
||||
preferred_compute_cap?: number
|
||||
shared?: boolean
|
||||
max_concurrent?: number
|
||||
idle_stop_after_s?: number
|
||||
always_on?: boolean
|
||||
model_base_path?: string
|
||||
managed?: ServiceManaged
|
||||
catalog?: Record<string, CatalogEntryFull>
|
||||
}
|
||||
|
||||
export interface NodeHardwareGpu {
|
||||
id: number
|
||||
vram_mb: number
|
||||
compute_cap?: number
|
||||
card?: string
|
||||
role?: string
|
||||
services?: string[]
|
||||
}
|
||||
|
||||
export interface NodeHardwareEntry {
|
||||
local_model_root?: string
|
||||
agent_url?: string
|
||||
gpus: NodeHardwareGpu[]
|
||||
}
|
||||
|
||||
export interface FullProfile {
|
||||
schema_version?: number
|
||||
name?: string
|
||||
vram_total_mb?: number
|
||||
eviction_timeout_s?: number
|
||||
services: Record<string, ServiceDefinition>
|
||||
nodes: Record<string, NodeHardwareEntry>
|
||||
model_size_hints?: Record<string, string>
|
||||
}
|
||||
987
web/src/views/AssignmentsTab.vue
Normal file
987
web/src/views/AssignmentsTab.vue
Normal file
|
|
@ -0,0 +1,987 @@
|
|||
<template>
|
||||
<div class="assignments-tab">
|
||||
|
||||
<!-- ── Toast ───────────────────────────────────────────── -->
|
||||
<div v-if="toast" class="toast" :class="toast.type" role="status" aria-live="polite">
|
||||
{{ toast.message }}
|
||||
</div>
|
||||
|
||||
<!-- ── Assignments section ─────────────────────────────── -->
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Task Assignments</h2>
|
||||
<button class="btn-primary btn-sm" @click="openNewAssignment">+ New Assignment</button>
|
||||
</div>
|
||||
|
||||
<div class="filter-row">
|
||||
<label for="product-filter" class="filter-label">Product</label>
|
||||
<select id="product-filter" v-model="productFilter" class="filter-select">
|
||||
<option value="">All products</option>
|
||||
<option v-for="p in allProducts" :key="p" :value="p">{{ p }}</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div v-if="assignmentsLoading" class="empty-state">Loading assignments…</div>
|
||||
<div v-else-if="assignmentsError" class="error-notice" role="alert">{{ assignmentsError }}</div>
|
||||
<div v-else-if="filteredGroups.length === 0" class="empty-state">No assignments yet. Add one above.</div>
|
||||
<div v-else class="product-groups">
|
||||
<div v-for="group in filteredGroups" :key="group.product" class="product-group">
|
||||
<h3 class="product-name">{{ group.product.toUpperCase() }}</h3>
|
||||
<div class="assignment-list">
|
||||
<div v-for="a in group.assignments" :key="`${a.product}/${a.task}`" class="assignment-row">
|
||||
<div class="assignment-main">
|
||||
<span class="task-id">{{ a.task }}</span>
|
||||
<span
|
||||
class="model-name"
|
||||
:title="a.model_id"
|
||||
>{{ displayModelId(a) }}</span>
|
||||
<span v-if="a.vram_mb" class="chip chip-vram">{{ formatVram(a.vram_mb) }}</span>
|
||||
<span v-if="a.service_type" class="chip" :class="serviceChipClass(a.service_type)">{{ a.service_type }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Node deployment status -->
|
||||
<div v-if="deploymentMap[`${a.product}/${a.task}`]" class="node-statuses">
|
||||
<span
|
||||
v-for="ns in deploymentMap[`${a.product}/${a.task}`]"
|
||||
:key="ns.node_id"
|
||||
class="node-badge-wrap"
|
||||
>
|
||||
<span
|
||||
class="node-badge"
|
||||
:class="ns.status"
|
||||
:title="`${ns.node_id}: ${ns.status}`"
|
||||
>
|
||||
<span class="node-icon">{{ nodeIcon(ns.status) }}</span>
|
||||
{{ ns.node_id }}
|
||||
</span>
|
||||
<button
|
||||
v-if="ns.status === 'absent'"
|
||||
class="btn-deploy"
|
||||
:disabled="deploying.has(`${a.product}/${a.task}/${ns.node_id}`)"
|
||||
:title="`Register ${a.model_id} in ${ns.node_id} catalog`"
|
||||
@click="deployModel(a, ns.node_id)"
|
||||
>{{ deploying.has(`${a.product}/${a.task}/${ns.node_id}`) ? '…' : 'Register' }}</button>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="assignment-actions">
|
||||
<button
|
||||
v-if="editingKey !== `${a.product}/${a.task}`"
|
||||
class="btn-ghost btn-sm"
|
||||
@click="startEdit(a)"
|
||||
>Edit</button>
|
||||
<button
|
||||
class="btn-ghost btn-sm btn-danger"
|
||||
@click="deleteAssignment(a.product, a.task)"
|
||||
>Delete</button>
|
||||
</div>
|
||||
|
||||
<!-- Inline edit form -->
|
||||
<div v-if="editingKey === `${a.product}/${a.task}`" class="inline-edit">
|
||||
<select v-model="editDraft.model_id" class="edit-select" aria-label="Model">
|
||||
<option value="" disabled>Select model…</option>
|
||||
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
|
||||
{{ m.alias || truncate(m.model_id, 40) }}
|
||||
</option>
|
||||
</select>
|
||||
<input
|
||||
v-model="editDraft.description"
|
||||
type="text"
|
||||
class="edit-input"
|
||||
placeholder="Description (optional)"
|
||||
/>
|
||||
<div class="inline-edit-btns">
|
||||
<button class="btn-primary btn-sm" :disabled="!editDraft.model_id" @click="saveEdit(a)">Save</button>
|
||||
<button class="btn-ghost btn-sm" @click="editingKey = null">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ── Model Registry section ───────────────────────────── -->
|
||||
<div class="section-header section-header-mt">
|
||||
<h2 class="section-title">Model Registry</h2>
|
||||
<button class="btn-primary btn-sm" @click="showRegisterModal = true">Register Model</button>
|
||||
</div>
|
||||
|
||||
<div v-if="registryLoading" class="empty-state">Loading model registry…</div>
|
||||
<div v-else-if="registryError" class="error-notice" role="alert">{{ registryError }}</div>
|
||||
<div v-else-if="registryModels.length === 0" class="empty-state">No models registered yet.</div>
|
||||
<div v-else class="registry-table-wrap">
|
||||
<table class="registry-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Alias</th>
|
||||
<th>Model ID</th>
|
||||
<th>VRAM</th>
|
||||
<th>Service</th>
|
||||
<th class="col-hf">HF Repo</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="m in registryModels" :key="m.model_id">
|
||||
<td>{{ m.alias || '—' }}</td>
|
||||
<td>
|
||||
<span class="truncated" :title="m.model_id">{{ truncate(m.model_id, 36) }}</span>
|
||||
</td>
|
||||
<td>{{ formatVram(m.vram_mb) }}</td>
|
||||
<td><span class="chip" :class="serviceChipClass(m.service_type)">{{ m.service_type }}</span></td>
|
||||
<td class="col-hf">
|
||||
<a
|
||||
v-if="m.hf_repo"
|
||||
:href="`https://huggingface.co/${m.hf_repo}`"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="hf-link"
|
||||
>{{ truncate(m.hf_repo, 30) }}</a>
|
||||
<span v-else class="text-muted">—</span>
|
||||
</td>
|
||||
<td>
|
||||
<button class="btn-ghost btn-sm btn-danger" @click="deleteModel(m.model_id)">Delete</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- ── New Assignment modal ─────────────────────────────── -->
|
||||
<div v-if="showNewAssignmentModal" class="modal-backdrop" @click.self="showNewAssignmentModal = false">
|
||||
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-new-assignment-title">
|
||||
<h3 id="modal-new-assignment-title" class="modal-title">New Assignment</h3>
|
||||
<label class="form-label">Product</label>
|
||||
<input
|
||||
v-model="newAssignment.product"
|
||||
list="product-list"
|
||||
class="form-input"
|
||||
placeholder="e.g. peregrine"
|
||||
autocomplete="off"
|
||||
/>
|
||||
<datalist id="product-list">
|
||||
<option v-for="p in allProducts" :key="p" :value="p" />
|
||||
</datalist>
|
||||
|
||||
<label class="form-label">Task ID</label>
|
||||
<input
|
||||
v-model="newAssignment.task"
|
||||
type="text"
|
||||
class="form-input"
|
||||
placeholder="e.g. cover_letter"
|
||||
/>
|
||||
|
||||
<label class="form-label">Model</label>
|
||||
<select v-model="newAssignment.model_id" class="form-select">
|
||||
<option value="" disabled>Select from registry…</option>
|
||||
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
|
||||
{{ m.alias || truncate(m.model_id, 50) }}
|
||||
</option>
|
||||
</select>
|
||||
|
||||
<label class="form-label">Description <span class="optional">(optional)</span></label>
|
||||
<input
|
||||
v-model="newAssignment.description"
|
||||
type="text"
|
||||
class="form-input"
|
||||
placeholder="Human-readable note for operators"
|
||||
/>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!newAssignment.product || !newAssignment.task || !newAssignment.model_id || saving"
|
||||
@click="saveNewAssignment"
|
||||
>{{ saving ? 'Saving…' : 'Save' }}</button>
|
||||
<button class="btn-ghost" @click="showNewAssignmentModal = false">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ── Register Model modal ─────────────────────────────── -->
|
||||
<div v-if="showRegisterModal" class="modal-backdrop" @click.self="showRegisterModal = false">
|
||||
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-register-title">
|
||||
<h3 id="modal-register-title" class="modal-title">Register Model</h3>
|
||||
|
||||
<label class="form-label">Model ID <span class="hint">(HuggingFace slug, e.g. ibm-granite/granite-4.1-8b)</span></label>
|
||||
<input v-model="newModel.model_id" type="text" class="form-input" placeholder="org/model-name" />
|
||||
|
||||
<label class="form-label">Alias <span class="optional">(optional, short name for assignments)</span></label>
|
||||
<input v-model="newModel.alias" type="text" class="form-input" placeholder="e.g. granite-8b" />
|
||||
|
||||
<label class="form-label">Service type</label>
|
||||
<select v-model="newModel.service_type" class="form-select">
|
||||
<option value="" disabled>Select service…</option>
|
||||
<option value="cf-text">cf-text — Language Models</option>
|
||||
<option value="cf-stt">cf-stt — Speech Recognition</option>
|
||||
<option value="cf-tts">cf-tts — Text to Speech</option>
|
||||
<option value="cf-vision">cf-vision — Vision / VLM</option>
|
||||
<option value="cf-image">cf-image — Image Generation</option>
|
||||
<option value="cf-voice">cf-voice — Audio Classification</option>
|
||||
<option value="vllm">vllm — vLLM inference</option>
|
||||
<option value="ollama">ollama — Ollama inference</option>
|
||||
</select>
|
||||
|
||||
<label class="form-label">VRAM required (MB)</label>
|
||||
<input v-model.number="newModel.vram_mb" type="number" min="0" class="form-input" placeholder="e.g. 16384" />
|
||||
|
||||
<label class="form-label">HF Repo <span class="optional">(optional)</span></label>
|
||||
<input v-model="newModel.hf_repo" type="text" class="form-input" placeholder="org/repo-name" />
|
||||
|
||||
<label class="form-label">Description <span class="optional">(optional)</span></label>
|
||||
<input v-model="newModel.description" type="text" class="form-input" placeholder="Human-readable note" />
|
||||
|
||||
<div class="modal-actions">
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!newModel.model_id || !newModel.service_type || !newModel.vram_mb || saving"
|
||||
@click="saveNewModel"
|
||||
>{{ saving ? 'Saving…' : 'Register' }}</button>
|
||||
<button class="btn-ghost" @click="showRegisterModal = false">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
// ── Types ──────────────────────────────────────────────
|
||||
|
||||
interface AssignmentNode {
|
||||
node_id: string
|
||||
status: 'present' | 'absent' | 'vram_tight'
|
||||
}
|
||||
|
||||
interface DeployingKey {
|
||||
nodeId: string
|
||||
assignmentKey: string
|
||||
}
|
||||
|
||||
interface Assignment {
|
||||
product: string
|
||||
task: string
|
||||
model_id: string
|
||||
description: string
|
||||
alias?: string
|
||||
service_type?: string
|
||||
vram_mb?: number
|
||||
nodes?: AssignmentNode[]
|
||||
}
|
||||
|
||||
interface RegistryModel {
|
||||
model_id: string
|
||||
alias: string
|
||||
service_type: string
|
||||
vram_mb: number
|
||||
hf_repo: string
|
||||
description: string
|
||||
}
|
||||
|
||||
interface ProductGroup {
|
||||
product: string
|
||||
assignments: Assignment[]
|
||||
}
|
||||
|
||||
interface Toast {
|
||||
message: string
|
||||
type: 'success' | 'error'
|
||||
}
|
||||
|
||||
// ── State ──────────────────────────────────────────────
|
||||
|
||||
const assignments = ref<Assignment[]>([])
|
||||
const assignmentsLoading = ref(false)
|
||||
const assignmentsError = ref<string | null>(null)
|
||||
|
||||
const registryModels = ref<RegistryModel[]>([])
|
||||
const registryLoading = ref(false)
|
||||
const registryError = ref<string | null>(null)
|
||||
|
||||
const productFilter = ref('')
|
||||
const editingKey = ref<string | null>(null)
|
||||
const editDraft = ref({ model_id: '', description: '' })
|
||||
|
||||
const showNewAssignmentModal = ref(false)
|
||||
const newAssignment = ref({ product: '', task: '', model_id: '', description: '' })
|
||||
|
||||
const showRegisterModal = ref(false)
|
||||
const newModel = ref({ model_id: '', alias: '', service_type: '', vram_mb: 0, hf_repo: '', description: '' })
|
||||
|
||||
const saving = ref(false)
|
||||
const toast = ref<Toast | null>(null)
|
||||
let toastTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const deploying = ref<Set<string>>(new Set())
|
||||
|
||||
// ── Derived ────────────────────────────────────────────
|
||||
|
||||
const allProducts = computed(() => {
|
||||
const seen = new Set<string>()
|
||||
for (const a of assignments.value) seen.add(a.product)
|
||||
return [...seen].sort()
|
||||
})
|
||||
|
||||
const deploymentMap = computed(() => {
|
||||
const map: Record<string, AssignmentNode[]> = {}
|
||||
for (const a of assignments.value) {
|
||||
if (a.nodes) map[`${a.product}/${a.task}`] = a.nodes
|
||||
}
|
||||
return map
|
||||
})
|
||||
|
||||
const filteredGroups = computed((): ProductGroup[] => {
|
||||
const filtered = productFilter.value
|
||||
? assignments.value.filter(a => a.product === productFilter.value)
|
||||
: assignments.value
|
||||
|
||||
const byProduct: Record<string, Assignment[]> = {}
|
||||
for (const a of filtered) {
|
||||
if (!byProduct[a.product]) byProduct[a.product] = []
|
||||
byProduct[a.product].push(a)
|
||||
}
|
||||
return Object.keys(byProduct)
|
||||
.sort()
|
||||
.map(product => ({ product, assignments: byProduct[product] }))
|
||||
})
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────
|
||||
|
||||
function truncate(s: string, max: number): string {
|
||||
return s.length > max ? s.slice(0, max - 1) + '…' : s
|
||||
}
|
||||
|
||||
function displayModelId(a: Assignment): string {
|
||||
if (a.alias) return a.alias
|
||||
const id = a.model_id
|
||||
// Show only the model name part (after /) and truncate long slugs
|
||||
const short = id.includes('/') ? id.split('/').slice(1).join('/') : id
|
||||
return truncate(short, 36)
|
||||
}
|
||||
|
||||
function formatVram(mb: number | undefined): string {
|
||||
if (!mb) return ''
|
||||
if (mb >= 1024) return `${(mb / 1024).toFixed(1)} GB`
|
||||
return `${mb} MB`
|
||||
}
|
||||
|
||||
function serviceChipClass(service: string): string {
|
||||
return `chip-service-${service.replace(/[^a-z0-9]/g, '-')}`
|
||||
}
|
||||
|
||||
function nodeIcon(status: string): string {
|
||||
if (status === 'present') return '✓'
|
||||
if (status === 'vram_tight') return '~'
|
||||
return '✗'
|
||||
}
|
||||
|
||||
function showToast(message: string, type: 'success' | 'error' = 'success') {
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
toast.value = { message, type }
|
||||
toastTimer = setTimeout(() => { toast.value = null }, 3500)
|
||||
}
|
||||
|
||||
function openNewAssignment() {
|
||||
newAssignment.value = { product: '', task: '', model_id: '', description: '' }
|
||||
showNewAssignmentModal.value = true
|
||||
}
|
||||
|
||||
function startEdit(a: Assignment) {
|
||||
editingKey.value = `${a.product}/${a.task}`
|
||||
editDraft.value = { model_id: a.model_id, description: a.description }
|
||||
}
|
||||
|
||||
// ── API ────────────────────────────────────────────────
|
||||
|
||||
async function loadAssignments() {
|
||||
assignmentsLoading.value = true
|
||||
assignmentsError.value = null
|
||||
try {
|
||||
// Fetch both list and deployment status in parallel
|
||||
const [listRes, statusRes] = await Promise.all([
|
||||
fetch('/api/cforch/assignments'),
|
||||
fetch('/api/cforch/assignments/deployment-status'),
|
||||
])
|
||||
if (!listRes.ok) throw new Error(`HTTP ${listRes.status}`)
|
||||
const list: Assignment[] = (await listRes.json()).assignments ?? []
|
||||
|
||||
// Merge deployment status into assignments if available
|
||||
if (statusRes.ok) {
|
||||
const statusList: Assignment[] = (await statusRes.json()).deployment_status ?? []
|
||||
const statusMap: Record<string, AssignmentNode[]> = {}
|
||||
for (const s of statusList) {
|
||||
statusMap[`${s.product}/${s.task}`] = s.nodes ?? []
|
||||
}
|
||||
for (const a of list) {
|
||||
a.nodes = statusMap[`${a.product}/${a.task}`] ?? []
|
||||
// Enrich with service_type/vram_mb from status payload
|
||||
const s = statusList.find(x => x.product === a.product && x.task === a.task)
|
||||
if (s) {
|
||||
a.service_type = s.service_type
|
||||
a.vram_mb = s.vram_mb
|
||||
a.alias = s.alias
|
||||
}
|
||||
}
|
||||
}
|
||||
assignments.value = list
|
||||
} catch (e) {
|
||||
assignmentsError.value = `Could not load assignments: ${e}`
|
||||
} finally {
|
||||
assignmentsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadRegistry() {
|
||||
registryLoading.value = true
|
||||
registryError.value = null
|
||||
try {
|
||||
const res = await fetch('/api/cforch/model-registry')
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
registryModels.value = (await res.json()).models ?? []
|
||||
} catch (e) {
|
||||
registryError.value = `Could not load model registry: ${e}`
|
||||
} finally {
|
||||
registryLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function saveNewAssignment() {
|
||||
saving.value = true
|
||||
try {
|
||||
const res = await fetch('/api/cforch/assignments', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(newAssignment.value),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showNewAssignmentModal.value = false
|
||||
showToast('Assignment saved')
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Save failed: ${e}`, 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function saveEdit(a: Assignment) {
|
||||
saving.value = true
|
||||
try {
|
||||
const res = await fetch('/api/cforch/assignments', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
product: a.product,
|
||||
task: a.task,
|
||||
model_id: editDraft.value.model_id,
|
||||
description: editDraft.value.description,
|
||||
}),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
editingKey.value = null
|
||||
showToast('Assignment updated')
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Update failed: ${e}`, 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteAssignment(product: string, task: string) {
|
||||
if (!confirm(`Delete assignment ${product}.${task}?`)) return
|
||||
try {
|
||||
const res = await fetch(
|
||||
`/api/cforch/assignments/${encodeURIComponent(product)}/${encodeURIComponent(task)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showToast('Assignment deleted')
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Delete failed: ${e}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
async function saveNewModel() {
|
||||
saving.value = true
|
||||
try {
|
||||
const res = await fetch('/api/cforch/model-registry', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(newModel.value),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showRegisterModal.value = false
|
||||
showToast('Model registered')
|
||||
await loadRegistry()
|
||||
} catch (e) {
|
||||
showToast(`Register failed: ${e}`, 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteModel(model_id: string) {
|
||||
if (!confirm(`Remove ${model_id} from the registry?`)) return
|
||||
try {
|
||||
const res = await fetch(
|
||||
`/api/cforch/model-registry/${encodeURIComponent(model_id)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
showToast('Model removed')
|
||||
await loadRegistry()
|
||||
} catch (e) {
|
||||
showToast(`Delete failed: ${e}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
async function deployModel(a: Assignment, nodeId: string) {
|
||||
const key = `${a.product}/${a.task}/${nodeId}`
|
||||
if (deploying.value.has(key)) return
|
||||
|
||||
// Look up hf_repo from registry for cleaner path construction
|
||||
const regEntry = registryModels.value.find(m => m.model_id === a.model_id)
|
||||
const hf_repo = regEntry?.hf_repo ?? ''
|
||||
const service_type = a.service_type ?? regEntry?.service_type ?? ''
|
||||
const vram_mb = a.vram_mb ?? regEntry?.vram_mb ?? 0
|
||||
const description = regEntry?.alias ? `${regEntry.alias} (via assignments)` : ''
|
||||
|
||||
if (!service_type) {
|
||||
showToast(`No service type for model ${a.model_id}`, 'error')
|
||||
return
|
||||
}
|
||||
|
||||
deploying.value = new Set([...deploying.value, key])
|
||||
try {
|
||||
const res = await fetch(`/api/nodes-mgmt/nodes/${encodeURIComponent(nodeId)}/models/deploy`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model_id: a.model_id, service_type, vram_mb, hf_repo, description }),
|
||||
})
|
||||
if (!res.ok) throw new Error(await res.text())
|
||||
const data = await res.json()
|
||||
showToast(`Registered ${a.model_id} on ${nodeId} at ${data.path}`)
|
||||
|
||||
// Optimistic update: flip node to 'present' immediately so the Register button
|
||||
// disappears before the coordinator reload confirms. loadAssignments() reconciles
|
||||
// with real server state on the next round-trip.
|
||||
assignments.value = assignments.value.map(asgn => {
|
||||
if (asgn.product !== a.product || asgn.task !== a.task) return asgn
|
||||
return {
|
||||
...asgn,
|
||||
nodes: (asgn.nodes ?? []).map(ns =>
|
||||
ns.node_id === nodeId ? { ...ns, status: 'present' as const } : ns
|
||||
),
|
||||
}
|
||||
})
|
||||
|
||||
await loadAssignments()
|
||||
} catch (e) {
|
||||
showToast(`Deploy failed: ${e}`, 'error')
|
||||
} finally {
|
||||
deploying.value = new Set([...deploying.value].filter(k => k !== key))
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadAssignments()
|
||||
loadRegistry()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.assignments-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.25rem;
|
||||
}
|
||||
|
||||
/* ── Toast ── */
|
||||
.toast {
|
||||
position: fixed;
|
||||
bottom: 1.5rem;
|
||||
right: 1.5rem;
|
||||
padding: 0.65rem 1.1rem;
|
||||
border-radius: 0.5rem;
|
||||
font-size: 0.88rem;
|
||||
font-weight: 500;
|
||||
z-index: 200;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
|
||||
}
|
||||
.toast.success {
|
||||
background: var(--color-success, #2a8050);
|
||||
color: #fff;
|
||||
}
|
||||
.toast.error {
|
||||
background: var(--color-danger, #b03030);
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
/* ── Section headers ── */
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 1rem;
|
||||
}
|
||||
.section-header-mt {
|
||||
margin-top: 1.5rem;
|
||||
}
|
||||
.section-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Filter row ── */
|
||||
.filter-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
}
|
||||
.filter-label {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
}
|
||||
.filter-select {
|
||||
padding: 0.3rem 0.6rem;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2030);
|
||||
}
|
||||
|
||||
/* ── Product groups ── */
|
||||
.product-groups {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
.product-group {}
|
||||
.product-name {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.08em;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
margin: 0 0 0.4rem;
|
||||
}
|
||||
.assignment-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
|
||||
/* ── Assignment rows ── */
|
||||
.assignment-row {
|
||||
background: var(--color-surface-raised, #f0f4fa);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.65rem 0.85rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
.assignment-main {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.task-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.88rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2030);
|
||||
min-width: 0;
|
||||
}
|
||||
.model-name {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 280px;
|
||||
cursor: default;
|
||||
}
|
||||
.assignment-actions {
|
||||
display: flex;
|
||||
gap: 0.4rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
/* ── Node status badges ── */
|
||||
.node-statuses {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.node-badge-wrap {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.2rem;
|
||||
}
|
||||
.node-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.2rem;
|
||||
font-size: 0.78rem;
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 0.35rem;
|
||||
font-weight: 500;
|
||||
}
|
||||
.node-badge.present {
|
||||
background: color-mix(in srgb, var(--color-success, #2a8050) 15%, transparent);
|
||||
color: var(--color-success, #2a8050);
|
||||
border: 1px solid color-mix(in srgb, var(--color-success, #2a8050) 30%, transparent);
|
||||
}
|
||||
.node-badge.absent {
|
||||
background: color-mix(in srgb, var(--color-danger, #b03030) 12%, transparent);
|
||||
color: var(--color-danger, #b03030);
|
||||
border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
|
||||
}
|
||||
.node-badge.vram_tight {
|
||||
background: color-mix(in srgb, #c08030 15%, transparent);
|
||||
color: #8a5500;
|
||||
border: 1px solid color-mix(in srgb, #c08030 30%, transparent);
|
||||
}
|
||||
.node-icon {
|
||||
font-size: 0.85em;
|
||||
}
|
||||
.btn-deploy {
|
||||
padding: 0.1rem 0.4rem;
|
||||
font-size: 0.72rem;
|
||||
font-weight: 600;
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
|
||||
color: var(--app-primary, #2A6080);
|
||||
border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 30%, transparent);
|
||||
border-radius: 0.3rem;
|
||||
cursor: pointer;
|
||||
white-space: nowrap;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-deploy:hover:not(:disabled) {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 22%, transparent);
|
||||
}
|
||||
.btn-deploy:disabled { opacity: 0.5; cursor: default; }
|
||||
|
||||
/* ── Inline edit ── */
|
||||
.inline-edit {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.4rem;
|
||||
padding-top: 0.35rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.edit-select,
|
||||
.edit-input {
|
||||
flex: 1;
|
||||
min-width: 160px;
|
||||
padding: 0.35rem 0.55rem;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2030);
|
||||
}
|
||||
.inline-edit-btns {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
/* ── Registry table ── */
|
||||
.registry-table-wrap {
|
||||
overflow-x: auto;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.registry-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.registry-table th {
|
||||
text-align: left;
|
||||
padding: 0.5rem 0.75rem;
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
background: var(--color-surface-raised, #f0f4fa);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
white-space: nowrap;
|
||||
}
|
||||
.registry-table td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
vertical-align: middle;
|
||||
}
|
||||
.registry-table tbody tr:last-child td {
|
||||
border-bottom: none;
|
||||
}
|
||||
.truncated {
|
||||
display: inline-block;
|
||||
max-width: 220px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
vertical-align: bottom;
|
||||
cursor: default;
|
||||
}
|
||||
.hf-link {
|
||||
color: var(--app-primary, #2A6080);
|
||||
text-decoration: none;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
.hf-link:hover { text-decoration: underline; }
|
||||
.text-muted { color: var(--color-text-muted, #6b7a99); }
|
||||
|
||||
/* ── Chips ── */
|
||||
.chip {
|
||||
display: inline-block;
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 0.35rem;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.chip-vram {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
|
||||
color: var(--app-primary, #2A6080);
|
||||
border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 25%, transparent);
|
||||
}
|
||||
/* service chips — match ModelsView convention */
|
||||
.chip-service-cf-text { background: #e8f0fe; color: #1a5276; border: 1px solid #a9c4e8; }
|
||||
.chip-service-cf-stt { background: #eaf6ea; color: #1e6b3a; border: 1px solid #a2d9b1; }
|
||||
.chip-service-cf-tts { background: #fdf3e3; color: #7d4e00; border: 1px solid #e8c98a; }
|
||||
.chip-service-cf-vision { background: #f3e8fd; color: #5b2d8e; border: 1px solid #c8a0e8; }
|
||||
.chip-service-cf-image { background: #fce8f0; color: #8e1a4f; border: 1px solid #e8a0c0; }
|
||||
.chip-service-cf-voice { background: #e8f8fc; color: #0a5c6e; border: 1px solid #88d0e0; }
|
||||
.chip-service-vllm { background: #f5ece0; color: #7a3800; border: 1px solid #d4a87a; }
|
||||
.chip-service-ollama { background: #eeeeee; color: #444; border: 1px solid #ccc; }
|
||||
|
||||
/* ── Buttons ── */
|
||||
.btn-primary {
|
||||
padding: 0.45rem 1rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border: none;
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: default; }
|
||||
.btn-primary:not(:disabled):hover { opacity: 0.88; }
|
||||
|
||||
.btn-ghost {
|
||||
padding: 0.35rem 0.75rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.82rem;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-ghost:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.btn-ghost.btn-danger { color: var(--color-danger, #b03030); border-color: color-mix(in srgb, var(--color-danger, #b03030) 30%, transparent); }
|
||||
.btn-ghost.btn-danger:hover { background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent); }
|
||||
|
||||
.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }
|
||||
|
||||
/* ── Empty / error states ── */
|
||||
.empty-state {
|
||||
padding: 1.5rem;
|
||||
text-align: center;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
font-size: 0.9rem;
|
||||
background: var(--color-surface-raised, #f0f4fa);
|
||||
border: 1px dashed var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
.error-notice {
|
||||
padding: 0.75rem 1rem;
|
||||
background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent);
|
||||
color: var(--color-danger, #b03030);
|
||||
border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.87rem;
|
||||
}
|
||||
|
||||
/* ── Modal ── */
|
||||
.modal-backdrop {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
background: rgba(0,0,0,0.35);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 100;
|
||||
padding: 1rem;
|
||||
}
|
||||
.modal {
|
||||
background: var(--color-surface, #fff);
|
||||
border-radius: 0.65rem;
|
||||
padding: 1.5rem;
|
||||
width: 100%;
|
||||
max-width: 480px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.65rem;
|
||||
box-shadow: 0 8px 32px rgba(0,0,0,0.18);
|
||||
max-height: 90vh;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.modal-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0 0 0.25rem;
|
||||
}
|
||||
.form-label {
|
||||
font-size: 0.82rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
}
|
||||
.form-input,
|
||||
.form-select {
|
||||
padding: 0.4rem 0.65rem;
|
||||
font-size: 0.88rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.4rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2030);
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.form-input:focus, .form-select:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: 1px;
|
||||
}
|
||||
.modal-actions {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
justify-content: flex-end;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
.optional, .hint {
|
||||
font-weight: 400;
|
||||
color: var(--color-text-muted, #6b7a99);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
/* ── Responsive ── */
|
||||
@media (max-width: 600px) {
|
||||
.assignment-main { flex-direction: column; align-items: flex-start; }
|
||||
.col-hf { display: none; }
|
||||
.model-name { max-width: 100%; }
|
||||
.modal { padding: 1rem; }
|
||||
}
|
||||
</style>
|
||||
82
web/src/views/BenchmarkView.test.ts
Normal file
82
web/src/views/BenchmarkView.test.ts
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import BenchmarkView from './BenchmarkView.vue'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockImplementation((url: string) => {
|
||||
// LlmEvalTab calls /api/cforch/models and expects { models: CfOrchModel[] }
|
||||
if (url.includes('/api/cforch/models')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ models: [] }),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
// Default: satisfies ClassifierTab (/api/benchmark/results, /api/benchmark/models,
|
||||
// /api/finetune/status), StyleTab (/api/style/models, /api/style/results),
|
||||
// and any other tab that tolerates empty arrays/objects.
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ models: {}, categories: {}, tasks: [], types: [], results: [] }),
|
||||
text: async () => '',
|
||||
})
|
||||
}))
|
||||
vi.stubGlobal('EventSource', class {
|
||||
onmessage = null
|
||||
onerror = null
|
||||
close() {}
|
||||
})
|
||||
})
|
||||
|
||||
describe('BenchmarkView', () => {
|
||||
it('renders page title "Benchmark"', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('Benchmark')
|
||||
})
|
||||
|
||||
it('has mode buttons: Classifier, LLM Eval, Writing Style', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const text = w.text()
|
||||
expect(text).toContain('Classifier')
|
||||
expect(text).toContain('LLM Eval')
|
||||
expect(text).toContain('Writing Style')
|
||||
})
|
||||
|
||||
it('does NOT have a Compare mode button', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const buttons = w.findAll('.mode-btn')
|
||||
const labels = buttons.map(b => b.text())
|
||||
expect(labels.every(l => !l.includes('Compare'))).toBe(true)
|
||||
})
|
||||
|
||||
it('shows Classifier tab by default', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
// ClassifierTab has a .classifier-tab root
|
||||
expect(w.find('.classifier-tab').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('switches to LlmEvalTab when LLM Eval clicked', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const llmBtn = w.findAll('.mode-btn').find(b => b.text().includes('LLM Eval'))!
|
||||
await llmBtn.trigger('click')
|
||||
await flushPromises()
|
||||
expect(w.find('.llm-eval-tab').exists()).toBe(true)
|
||||
expect(w.find('.classifier-tab').exists()).toBe(false)
|
||||
expect(llmBtn.classes()).toContain('active')
|
||||
})
|
||||
|
||||
it('switches to StyleTab when Writing Style clicked', async () => {
|
||||
const w = mount(BenchmarkView)
|
||||
await flushPromises()
|
||||
const styleBtn = w.findAll('.mode-btn').find(b => b.text().includes('Writing Style'))!
|
||||
await styleBtn.trigger('click')
|
||||
await flushPromises()
|
||||
expect(w.find('.style-tab').exists()).toBe(true)
|
||||
expect(w.find('.classifier-tab').exists()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
|
@ -16,22 +16,22 @@
|
|||
:class="{ active: benchMode === 'llm' }"
|
||||
@click="benchMode = 'llm'"
|
||||
>🤖 LLM Eval</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: benchMode === 'compare' }"
|
||||
@click="benchMode = 'compare'"
|
||||
>⚖️ Compare</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: benchMode === 'style' }"
|
||||
@click="benchMode = 'style'"
|
||||
>✍️ Writing Style</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: benchMode === 'plans' }"
|
||||
@click="benchMode = 'plans'"
|
||||
>📐 Planning</button>
|
||||
</div>
|
||||
|
||||
<ClassifierTab v-if="benchMode === 'classifier'" />
|
||||
<LlmEvalTab v-if="benchMode === 'llm'" />
|
||||
<CompareTab v-if="benchMode === 'compare'" />
|
||||
<StyleTab v-if="benchMode === 'style'" />
|
||||
<PlansBenchTab v-if="benchMode === 'plans'" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
|
@ -39,10 +39,10 @@
|
|||
import { ref } from 'vue'
|
||||
import ClassifierTab from './ClassifierTab.vue'
|
||||
import LlmEvalTab from './LlmEvalTab.vue'
|
||||
import CompareTab from './CompareTab.vue'
|
||||
import StyleTab from './StyleTab.vue'
|
||||
import PlansBenchTab from './PlansBenchTab.vue'
|
||||
|
||||
type BenchMode = 'classifier' | 'llm' | 'compare' | 'style'
|
||||
type BenchMode = 'classifier' | 'llm' | 'style' | 'plans'
|
||||
const benchMode = ref<BenchMode>('classifier')
|
||||
</script>
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ const benchMode = ref<BenchMode>('classifier')
|
|||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Mode toggle (segmented control) ────────────────────── */
|
||||
/* ── Mode toggle (segmented control) ── */
|
||||
.mode-toggle {
|
||||
display: inline-flex;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
|
|||
|
|
@ -325,7 +325,7 @@ function toggleCategory(models: AvailableModel[], checked: boolean) {
|
|||
|
||||
async function loadModelCategories() {
|
||||
modelsLoading.value = true
|
||||
const { data } = await useApiFetch<ModelCategoriesResponse>('/api/benchmark/models')
|
||||
const { data } = await useApiFetch<ModelCategoriesResponse>('/api/cforch/models')
|
||||
modelsLoading.value = false
|
||||
if (data?.categories) {
|
||||
modelCategories.value = data.categories
|
||||
|
|
@ -342,7 +342,7 @@ const modelCount = computed(() => modelNames.value.length)
|
|||
const labelNames = computed(() => {
|
||||
const canonical = Object.keys(LABEL_META)
|
||||
const inResults = new Set(
|
||||
modelNames.value.flatMap(n => Object.keys(results.value!.models[n].per_label))
|
||||
modelNames.value.flatMap(n => Object.keys(results.value?.models[n]?.per_label ?? {}))
|
||||
)
|
||||
return [...canonical.filter(l => inResults.has(l)), ...[...inResults].filter(l => !canonical.includes(l))]
|
||||
})
|
||||
|
|
@ -401,16 +401,16 @@ function formatDate(iso: string | null): string {
|
|||
// ── Data loading ─────────────────────────────────────────────────────────────
|
||||
async function loadResults() {
|
||||
loading.value = true
|
||||
const { data } = await useApiFetch<BenchResults>('/api/benchmark/results')
|
||||
const { data } = await useApiFetch<BenchResults>('/api/cforch/results')
|
||||
loading.value = false
|
||||
if (data && Object.keys(data.models).length > 0) {
|
||||
if (data?.models && Object.keys(data.models).length > 0) {
|
||||
results.value = data
|
||||
}
|
||||
}
|
||||
|
||||
async function loadFineTunedModels() {
|
||||
const { data } = await useApiFetch<FineTunedModel[]>('/api/finetune/status')
|
||||
if (Array.isArray(data)) fineTunedModels.value = data
|
||||
const { data } = await useApiFetch<{ results: FineTunedModel[] }>('/api/train/results')
|
||||
if (Array.isArray(data?.results)) fineTunedModels.value = data.results
|
||||
}
|
||||
|
||||
// ── Benchmark run ────────────────────────────────────────────────────────────
|
||||
|
|
@ -428,7 +428,7 @@ function startBenchmark() {
|
|||
params.set('model_names', [...selectedModels.value].join(','))
|
||||
}
|
||||
const qs = params.toString()
|
||||
const url = `/api/benchmark/run${qs ? `?${qs}` : ''}`
|
||||
const url = `/api/cforch/run${qs ? `?${qs}` : ''}`
|
||||
useApiSSE(
|
||||
url,
|
||||
async (event) => {
|
||||
|
|
@ -457,7 +457,7 @@ function startBenchmark() {
|
|||
}
|
||||
|
||||
async function cancelBenchmark() {
|
||||
await fetch('/api/benchmark/cancel', { method: 'POST' }).catch(() => {})
|
||||
await fetch('/api/cforch/cancel', { method: 'POST' }).catch(() => {})
|
||||
}
|
||||
|
||||
// ── Fine-tune ─────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -71,34 +71,37 @@
|
|||
rows="6"
|
||||
/>
|
||||
|
||||
<!-- Ollama model picker -->
|
||||
<!-- LLM model picker (ollama + vllm + cf-text) -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">🤖 Ollama Models</span>
|
||||
<span class="picker-badge">{{ cmpSelectedModels.size }} / {{ ollamaLlmModels.length }}</span>
|
||||
<span class="picker-title">🤖 LLM Models</span>
|
||||
<span class="picker-badge">{{ cmpSelectedModels.size }} / {{ llmSelectableModels.length }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<label class="picker-cat-header">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="cmpSelectedModels.size === ollamaLlmModels.length"
|
||||
:indeterminate="cmpSelectedModels.size > 0 && cmpSelectedModels.size < ollamaLlmModels.length"
|
||||
:checked="cmpSelectedModels.size === llmSelectableModels.length"
|
||||
:indeterminate="cmpSelectedModels.size > 0 && cmpSelectedModels.size < llmSelectableModels.length"
|
||||
@change="toggleAllCmpModels(($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-cat-name">All ollama models</span>
|
||||
<span class="picker-cat-name">All LLM models</span>
|
||||
</label>
|
||||
<div v-for="(models, service) in llmModelsByService" :key="service" class="picker-category">
|
||||
<span class="picker-cat-section">{{ service }}</span>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="m in ollamaLlmModels" :key="m.id" class="picker-model-row">
|
||||
<label v-for="m in models" :key="m.id" class="picker-model-row">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="cmpSelectedModels.has(m.id)"
|
||||
@change="toggleCmpModel(m.id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-model-name">{{ m.name }}</span>
|
||||
<span class="picker-adapter-type">{{ m.tags.slice(0, 3).join(', ') }}</span>
|
||||
<span class="picker-adapter-type">{{ m.tags.slice(0, 2).join(', ') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Run controls -->
|
||||
|
|
@ -232,10 +235,22 @@ const cmpResults = ref<CmpResult[]>([])
|
|||
const cmpEventSource = ref<EventSource | null>(null)
|
||||
|
||||
// ── Computed ────────────────────────────────────────────────────────────────
|
||||
const ollamaLlmModels = computed(() =>
|
||||
llmModels.value.filter(m => m.service === 'ollama')
|
||||
const LLM_SERVICES = new Set(['ollama', 'vllm', 'cf-text'])
|
||||
|
||||
const llmSelectableModels = computed(() =>
|
||||
llmModels.value.filter(m => LLM_SERVICES.has(m.service))
|
||||
)
|
||||
|
||||
/** Group selectable models by service for the picker UI */
|
||||
const llmModelsByService = computed((): Record<string, CfOrchModel[]> => {
|
||||
const groups: Record<string, CfOrchModel[]> = {}
|
||||
for (const m of llmSelectableModels.value) {
|
||||
if (!groups[m.service]) groups[m.service] = []
|
||||
groups[m.service].push(m)
|
||||
}
|
||||
return groups
|
||||
})
|
||||
|
||||
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
|
||||
const groups: Record<string, CfOrchTask[]> = {}
|
||||
for (const t of llmTasks.value) {
|
||||
|
|
@ -270,7 +285,7 @@ function toggleCmpModel(id: string, checked: boolean) {
|
|||
|
||||
function toggleAllCmpModels(checked: boolean) {
|
||||
cmpSelectedModels.value = checked
|
||||
? new Set(ollamaLlmModels.value.map(m => m.id))
|
||||
? new Set(llmSelectableModels.value.map(m => m.id))
|
||||
: new Set()
|
||||
}
|
||||
|
||||
|
|
@ -288,9 +303,8 @@ async function loadLlmModels() {
|
|||
const { data } = await useApiFetch<{ models: CfOrchModel[] }>('/api/cforch/models')
|
||||
if (data?.models) {
|
||||
llmModels.value = data.models
|
||||
// Pre-select all ollama models
|
||||
cmpSelectedModels.value = new Set(
|
||||
data.models.filter(m => m.service === 'ollama').map(m => m.id)
|
||||
data.models.filter(m => LLM_SERVICES.has(m.service)).map(m => m.id)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
31
web/src/views/CompareView.test.ts
Normal file
31
web/src/views/CompareView.test.ts
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import CompareView from './CompareView.vue'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ tasks: [], types: [], models: [] }),
|
||||
text: async () => '',
|
||||
}))
|
||||
vi.stubGlobal('EventSource', class {
|
||||
onmessage = null
|
||||
onerror = null
|
||||
close() {}
|
||||
})
|
||||
})
|
||||
|
||||
describe('CompareView', () => {
|
||||
it('renders page title "Compare"', async () => {
|
||||
const w = mount(CompareView)
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Compare')
|
||||
})
|
||||
|
||||
it('wraps CompareTab component', async () => {
|
||||
const w = mount(CompareView)
|
||||
await flushPromises()
|
||||
// CompareTab renders a .compare-tab root div
|
||||
expect(w.find('.compare-tab').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
36
web/src/views/CompareView.vue
Normal file
36
web/src/views/CompareView.vue
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
<template>
|
||||
<div class="compare-view">
|
||||
<header class="compare-header">
|
||||
<h1 class="page-title">🔍 Compare</h1>
|
||||
</header>
|
||||
<CompareTab />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import CompareTab from './CompareTab.vue'
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.compare-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.compare-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
119
web/src/views/DashboardView.test.ts
Normal file
119
web/src/views/DashboardView.test.ts
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import DashboardView from './DashboardView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: { template: '<div />' } },
|
||||
{ path: '/eval/benchmark', component: { template: '<div />' } },
|
||||
{ path: '/train/jobs', component: { template: '<div />' } },
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
const baseDashboard = {
|
||||
labeled_since_last_eval: 0,
|
||||
last_eval_timestamp: null,
|
||||
last_eval_best_score: null,
|
||||
active_jobs: [],
|
||||
corrections_export_ready: 0,
|
||||
signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false },
|
||||
}
|
||||
|
||||
function mockFetch(overrides: Partial<typeof baseDashboard> = {}) {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ ...baseDashboard, ...overrides }),
|
||||
text: async () => '',
|
||||
}))
|
||||
}
|
||||
|
||||
beforeEach(() => mockFetch())
|
||||
|
||||
describe('DashboardView', () => {
|
||||
it('renders page title', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('Dashboard')
|
||||
})
|
||||
|
||||
it('shows three stage cards: Data, Eval, Train', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.stage-card[data-stage="data"]').exists()).toBe(true)
|
||||
expect(w.find('.stage-card[data-stage="eval"]').exists()).toBe(true)
|
||||
expect(w.find('.stage-card[data-stage="train"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows labeled_since_last_eval count in Data card', async () => {
|
||||
mockFetch({ labeled_since_last_eval: 42 })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.stage-card[data-stage="data"]').text()).toContain('42')
|
||||
})
|
||||
|
||||
it('does NOT show Run Eval CTA when data_to_eval is false', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const dataCard = w.find('.stage-card[data-stage="data"]')
|
||||
expect(dataCard.find('.cta-btn').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('shows Run Eval CTA when data_to_eval is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: true, eval_to_train: false, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const dataCard = w.find('.stage-card[data-stage="data"]')
|
||||
expect(dataCard.find('.cta-btn').exists()).toBe(true)
|
||||
expect(dataCard.find('.cta-btn').text()).toContain('Run Eval')
|
||||
})
|
||||
|
||||
it('shows Queue Finetune CTA when eval_to_train is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: true, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalCard = w.find('.stage-card[data-stage="eval"]')
|
||||
expect(evalCard.find('.cta-btn').text()).toContain('Queue Finetune')
|
||||
})
|
||||
|
||||
it('shows Register in Fleet CTA when train_to_fleet is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: true } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainCard = w.find('.stage-card[data-stage="train"]')
|
||||
expect(trainCard.find('.cta-btn').text()).toContain('Register in Fleet')
|
||||
})
|
||||
|
||||
it('shows active job status pills in Train card', async () => {
|
||||
mockFetch({ active_jobs: [{ id: 'j1', type: 'classifier', model_key: 'deberta-v3', status: 'running' }] })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainCard = w.find('.stage-card[data-stage="train"]')
|
||||
expect(trainCard.find('.status-pill').exists()).toBe(true)
|
||||
expect(trainCard.text()).toContain('deberta-v3')
|
||||
})
|
||||
|
||||
it('shows last eval score in Eval card when present', async () => {
|
||||
mockFetch({ last_eval_best_score: 0.821 })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalCard = w.find('.stage-card[data-stage="eval"]')
|
||||
expect(evalCard.text()).toContain('82.1%')
|
||||
})
|
||||
|
||||
it('shows error state when API call fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 503, text: async () => '' }))
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows refresh button', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.refresh-btn').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
406
web/src/views/DashboardView.vue
Normal file
406
web/src/views/DashboardView.vue
Normal file
|
|
@ -0,0 +1,406 @@
|
|||
<template>
|
||||
<div class="dashboard-view">
|
||||
<header class="dashboard-header">
|
||||
<h1 class="page-title">📊 Dashboard</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="load" aria-label="Refresh dashboard">
|
||||
🔄
|
||||
</button>
|
||||
</header>
|
||||
|
||||
<div v-if="loading && !data" class="loading-state">Loading…</div>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="load">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="data" class="flywheel-grid">
|
||||
|
||||
<!-- ① Data card -->
|
||||
<div class="stage-card" data-stage="data">
|
||||
<div class="card-header">
|
||||
<span class="card-step">①</span>
|
||||
<h2 class="card-title">Data</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p class="card-metric">
|
||||
<strong class="metric-value">{{ data.labeled_since_last_eval.toLocaleString() }}</strong>
|
||||
<span class="metric-label"> labeled since last eval</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ② Eval card -->
|
||||
<div class="stage-card" data-stage="eval">
|
||||
<div class="card-header">
|
||||
<span class="card-step">②</span>
|
||||
<h2 class="card-title">Eval</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="bench-run-table">
|
||||
<div
|
||||
v-for="(run, type) in data.recent_bench_runs"
|
||||
:key="type"
|
||||
class="bench-run-row"
|
||||
>
|
||||
<span class="bench-type-label">{{ BENCH_LABELS[type as BenchType] ?? type }}</span>
|
||||
<span class="bench-run-time" :class="{ 'metric-muted': !run.timestamp }">
|
||||
{{ run.timestamp ? formatBenchTs(run.timestamp) : '—' }}
|
||||
</span>
|
||||
<span v-if="run.score != null" class="bench-run-score">
|
||||
{{ formatScore(run.score) }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="data.signals.eval_to_train" class="card-cta">
|
||||
<RouterLink to="/train/jobs" class="cta-btn">Queue Finetune</RouterLink>
|
||||
</div>
|
||||
<div v-if="data.signals.data_to_eval" class="card-cta">
|
||||
<RouterLink to="/eval/benchmark" class="cta-btn">Run Eval</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ③ Train card -->
|
||||
<div class="stage-card" data-stage="train">
|
||||
<div class="card-header">
|
||||
<span class="card-step">③</span>
|
||||
<h2 class="card-title">Train</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<template v-if="data.active_jobs.length > 0">
|
||||
<div
|
||||
v-for="job in data.active_jobs"
|
||||
:key="job.id"
|
||||
class="job-row"
|
||||
>
|
||||
<span class="job-key">{{ job.model_key }}</span>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<p v-else class="card-metric metric-muted">No active jobs</p>
|
||||
|
||||
<p v-if="data.corrections_export_ready > 0" class="card-metric">
|
||||
<strong class="metric-value">{{ data.corrections_export_ready }}</strong>
|
||||
<span class="metric-label"> corrections ready</span>
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="data.signals.train_to_fleet" class="card-cta">
|
||||
<RouterLink to="/fleet" class="cta-btn">Register in Fleet</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
interface ActiveJob {
|
||||
id: string
|
||||
type: string
|
||||
model_key: string
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
interface BenchRun {
|
||||
timestamp: string | null
|
||||
metric: string | null
|
||||
score: number | null
|
||||
}
|
||||
|
||||
type BenchType = 'classifier' | 'llm' | 'style' | 'plans'
|
||||
|
||||
interface DashboardData {
|
||||
labeled_since_last_eval: number
|
||||
last_eval_timestamp: string | null
|
||||
last_eval_best_score: number | null
|
||||
active_jobs: ActiveJob[]
|
||||
corrections_export_ready: number
|
||||
recent_bench_runs: Record<BenchType, BenchRun>
|
||||
signals: DashboardSignals
|
||||
}
|
||||
|
||||
const BENCH_LABELS: Record<BenchType, string> = {
|
||||
classifier: 'Classifier',
|
||||
llm: 'LLM Eval',
|
||||
style: 'Style',
|
||||
plans: 'Planning',
|
||||
}
|
||||
|
||||
const data = ref<DashboardData | null>(null)
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
function formatBenchTs(ts: string): string {
|
||||
const date = new Date(ts)
|
||||
if (!isNaN(date.getTime())) {
|
||||
const diff = Date.now() - date.getTime()
|
||||
const mins = Math.floor(diff / 60000)
|
||||
if (mins < 1) return 'just now'
|
||||
if (mins < 60) return `${mins}m ago`
|
||||
const hrs = Math.floor(mins / 60)
|
||||
if (hrs < 24) return `${hrs}h ago`
|
||||
return `${Math.floor(hrs / 24)}d ago`
|
||||
}
|
||||
// Non-ISO: show as-is (plans bench uses "YYYY-MM-DD HH:MM")
|
||||
return ts.length > 16 ? ts.slice(0, 16) : ts
|
||||
}
|
||||
|
||||
function formatScore(score: number): string {
|
||||
return `${(score * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
async function load() {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const res = await fetch('/api/dashboard')
|
||||
if (!res.ok) {
|
||||
error.value = `Could not load dashboard (HTTP ${res.status}).`
|
||||
return
|
||||
}
|
||||
data.value = await res.json() as DashboardData
|
||||
} catch {
|
||||
error.value = 'Network error. Is the Avocet API running?'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => load())
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.dashboard-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.dashboard-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
/* ── Flywheel grid ── */
|
||||
.flywheel-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
@media (max-width: 680px) {
|
||||
.flywheel-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── Stage cards ── */
|
||||
.stage-card {
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-lg, 1rem);
|
||||
padding: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
padding-bottom: 0.6rem;
|
||||
}
|
||||
|
||||
.card-step {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.card-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.card-metric {
|
||||
margin: 0;
|
||||
font-size: 0.875rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 1.05rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.metric-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.card-cta { margin-top: auto; }
|
||||
|
||||
.cta-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
text-align: center;
|
||||
padding: 0.5rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 0.375rem;
|
||||
text-decoration: none;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.cta-btn:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 85%, black); }
|
||||
|
||||
/* ── Bench run table ── */
|
||||
.bench-run-table {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.bench-run-row {
|
||||
display: grid;
|
||||
grid-template-columns: 6rem 1fr auto;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.bench-type-label {
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.bench-run-time {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.bench-run-score {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
color: var(--app-primary, #2A6080);
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
padding: 0.1rem 0.35rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Job pills ── */
|
||||
.job-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.job-key {
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.75rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: 100px;
|
||||
font-weight: 600;
|
||||
flex-shrink: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.status-pill.status-running { background: #d4f4e0; color: #1a7a3a; }
|
||||
.status-pill.status-queued { background: #fef3cd; color: #856404; }
|
||||
.status-pill.status-failed { background: #fde8e8; color: #842029; }
|
||||
.status-pill.status-completed { background: #e0f0ff; color: #0c5481; }
|
||||
|
||||
/* ── State indicators ── */
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
background: #fde8e8;
|
||||
color: #842029;
|
||||
border: 1px solid #f5c2c7;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.75rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
background: transparent;
|
||||
border: 1px solid currentColor;
|
||||
border-radius: 0.25rem;
|
||||
color: inherit;
|
||||
cursor: pointer;
|
||||
font-size: 0.75rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
|
||||
705
web/src/views/EmbedCompareTab.vue
Normal file
705
web/src/views/EmbedCompareTab.vue
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
<template>
|
||||
<div class="embed-compare-page">
|
||||
<!-- Step indicator (non-interactive) -->
|
||||
<ol class="step-indicator" aria-label="Setup progress">
|
||||
<li :class="{ complete: corpus.length > 0 }">Corpus</li>
|
||||
<li :class="{ complete: queries.length > 0 }">Queries</li>
|
||||
<li :class="{ complete: selectedModels.length > 0 }">Models</li>
|
||||
<li :class="{ complete: hasResults }">Run & Rate</li>
|
||||
</ol>
|
||||
|
||||
<!-- Persistent aria-live region — always in DOM, never v-if -->
|
||||
<div
|
||||
ref="liveRegion"
|
||||
class="sr-live"
|
||||
aria-live="polite"
|
||||
aria-atomic="true"
|
||||
v-text="liveMessage"
|
||||
/>
|
||||
|
||||
<!-- ① Corpus section -->
|
||||
<section class="card" aria-labelledby="corpus-heading">
|
||||
<h2 id="corpus-heading">① Corpus</h2>
|
||||
<div class="corpus-controls">
|
||||
<div class="field">
|
||||
<label for="corpus-paste">Paste chunks (one per line)</label>
|
||||
<textarea
|
||||
id="corpus-paste"
|
||||
v-model="rawCorpus"
|
||||
rows="6"
|
||||
placeholder="Paste one chunk per line, or use Import below..."
|
||||
@change="onCorpusPaste"
|
||||
/>
|
||||
</div>
|
||||
<div class="import-row">
|
||||
<label for="imitate-product-select">Import from product</label>
|
||||
<select id="imitate-product-select" v-model="selectedProduct">
|
||||
<option value="">-- select product --</option>
|
||||
<option
|
||||
v-for="p in imitateProducts"
|
||||
:key="p.id"
|
||||
:value="p.id"
|
||||
>{{ p.name }}</option>
|
||||
</select>
|
||||
<button
|
||||
class="btn-secondary"
|
||||
:disabled="!selectedProduct || importing"
|
||||
@click="importCorpus"
|
||||
>
|
||||
{{ importing ? 'Importing…' : 'Import' }}
|
||||
</button>
|
||||
<span v-if="importError" class="error-text" role="alert">{{ importError }}</span>
|
||||
</div>
|
||||
<p v-if="corpus.length > 0" class="corpus-count">
|
||||
{{ corpus.length }} chunk{{ corpus.length === 1 ? '' : 's' }} loaded.
|
||||
</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- ② Queries section -->
|
||||
<section class="card" aria-labelledby="queries-heading">
|
||||
<h2 id="queries-heading">② Queries</h2>
|
||||
<div class="field">
|
||||
<label for="query-input">Enter queries (one per line)</label>
|
||||
<textarea
|
||||
id="query-input"
|
||||
v-model="rawQueries"
|
||||
rows="4"
|
||||
placeholder="One query per line..."
|
||||
@change="onQueriesChange"
|
||||
/>
|
||||
</div>
|
||||
<p v-if="queries.length > 0" class="query-count">
|
||||
{{ queries.length }} quer{{ queries.length === 1 ? 'y' : 'ies' }}.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- ③ Model selection -->
|
||||
<section class="card" aria-labelledby="models-heading">
|
||||
<h2 id="models-heading">③ Models</h2>
|
||||
<p v-if="loadingModels" class="muted">Loading models from Ollama…</p>
|
||||
<p v-else-if="modelsError" class="error-text" role="alert">{{ modelsError }}</p>
|
||||
<ul v-else class="model-list" role="list">
|
||||
<li v-for="m in availableModels" :key="m.name">
|
||||
<label class="model-checkbox">
|
||||
<input
|
||||
type="checkbox"
|
||||
:value="m.name"
|
||||
v-model="selectedModels"
|
||||
/>
|
||||
{{ m.name }}
|
||||
<span class="model-size muted" aria-label="model size">
|
||||
{{ formatBytes(m.size) }}
|
||||
</span>
|
||||
</label>
|
||||
</li>
|
||||
</ul>
|
||||
<p v-if="availableModels.length === 0 && !loadingModels && !modelsError" class="muted">
|
||||
No Ollama models found. Pull an embedding model first.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- ④ Run controls -->
|
||||
<section class="card run-controls" aria-labelledby="run-heading">
|
||||
<h2 id="run-heading">④ Run</h2>
|
||||
<div class="run-row">
|
||||
<div class="field-inline">
|
||||
<label for="top-k-input">Results per query</label>
|
||||
<input
|
||||
id="top-k-input"
|
||||
type="number"
|
||||
v-model.number="topK"
|
||||
min="1"
|
||||
max="20"
|
||||
style="width: 5rem"
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!canRun || running"
|
||||
@click="startRun"
|
||||
>
|
||||
{{ running ? 'Running…' : 'Run' }}
|
||||
</button>
|
||||
<button
|
||||
v-if="running"
|
||||
class="btn-danger"
|
||||
aria-label="Cancel embedding run"
|
||||
@click="cancelRun"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
<p v-if="!canRun && !running" class="muted">
|
||||
Fill corpus, at least one query, and select at least one model to run.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- Results -->
|
||||
<section
|
||||
v-if="hasResults"
|
||||
class="card results-section"
|
||||
aria-labelledby="results-heading"
|
||||
>
|
||||
<h2 id="results-heading">Results</h2>
|
||||
|
||||
<!-- Query pagination -->
|
||||
<div class="query-nav" role="navigation" aria-label="Query navigation">
|
||||
<button
|
||||
class="btn-secondary"
|
||||
aria-label="Previous query"
|
||||
:disabled="currentQueryIdx === 0"
|
||||
@click="currentQueryIdx--"
|
||||
>‹</button>
|
||||
<span class="query-counter">
|
||||
Query {{ currentQueryIdx + 1 }} of {{ uniqueQueries.length }}:
|
||||
<em>{{ uniqueQueries[currentQueryIdx] }}</em>
|
||||
</span>
|
||||
<button
|
||||
class="btn-secondary"
|
||||
aria-label="Next query"
|
||||
:disabled="currentQueryIdx >= uniqueQueries.length - 1"
|
||||
@click="currentQueryIdx++"
|
||||
>›</button>
|
||||
</div>
|
||||
|
||||
<!-- Results table: one column per model -->
|
||||
<div class="table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th scope="col" class="rank-col">#</th>
|
||||
<th
|
||||
v-for="model in selectedModels"
|
||||
:key="model"
|
||||
scope="col"
|
||||
>{{ model }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="rank in topK" :key="rank">
|
||||
<td class="rank-col muted">{{ rank }}</td>
|
||||
<td
|
||||
v-for="model in selectedModels"
|
||||
:key="model"
|
||||
class="hit-cell"
|
||||
>
|
||||
<template v-if="getHit(currentQueryIdx, model, rank - 1) as hit">
|
||||
<div class="hit-text">{{ hit.text }}</div>
|
||||
<!-- Visual score bar: decorative only -->
|
||||
<div class="score-row">
|
||||
<div class="score-bar-wrap" aria-hidden="true">
|
||||
<div class="score-bar" :style="{ width: `${hit.score * 100}%` }" />
|
||||
</div>
|
||||
<span class="score-label">{{ hit.score.toFixed(3) }}</span>
|
||||
</div>
|
||||
<!-- Rating buttons -->
|
||||
<div class="rating-row">
|
||||
<button
|
||||
class="rate-btn"
|
||||
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant' }"
|
||||
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant'"
|
||||
aria-label="Mark as relevant"
|
||||
@click="rate(currentQueryIdx, model, hit, 'relevant')"
|
||||
>
|
||||
👍 Relevant
|
||||
</button>
|
||||
<button
|
||||
class="rate-btn rate-btn-neg"
|
||||
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant' }"
|
||||
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant'"
|
||||
aria-label="Mark as not relevant"
|
||||
@click="rate(currentQueryIdx, model, hit, 'not_relevant')"
|
||||
>
|
||||
👎 Not relevant
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
<span v-else class="muted">—</span>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Export -->
|
||||
<section
|
||||
v-if="hasResults"
|
||||
class="card export-section"
|
||||
aria-labelledby="export-heading"
|
||||
>
|
||||
<h2 id="export-heading">Export Ratings</h2>
|
||||
<div class="export-row">
|
||||
<fieldset class="export-format-group">
|
||||
<legend>Format</legend>
|
||||
<label><input type="radio" v-model="exportFormat" value="csv" /> CSV</label>
|
||||
<label><input type="radio" v-model="exportFormat" value="json" /> JSON</label>
|
||||
</fieldset>
|
||||
<button class="btn-secondary" @click="exportRatings">Export</button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
// ── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
interface OllamaModel { name: string; size: number }
|
||||
interface ImitateProduct { id: string; name: string }
|
||||
interface HitResult { chunk_idx: number; text: string; score: number }
|
||||
interface ResultEvent {
|
||||
type: 'result'
|
||||
query_idx: number
|
||||
query: string
|
||||
model: string
|
||||
hits: HitResult[]
|
||||
}
|
||||
|
||||
// ── State ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
const rawCorpus = ref('')
|
||||
const corpus = ref<string[]>([])
|
||||
const rawQueries = ref('')
|
||||
const queries = ref<string[]>([])
|
||||
const selectedModels = ref<string[]>([])
|
||||
const topK = ref(5)
|
||||
const availableModels = ref<OllamaModel[]>([])
|
||||
const loadingModels = ref(false)
|
||||
const modelsError = ref('')
|
||||
const imitateProducts = ref<ImitateProduct[]>([])
|
||||
const selectedProduct = ref('')
|
||||
const importing = ref(false)
|
||||
const importError = ref('')
|
||||
const running = ref(false)
|
||||
const liveMessage = ref('')
|
||||
const resultEvents = ref<ResultEvent[]>([])
|
||||
const runController = ref<AbortController | null>(null)
|
||||
|
||||
const currentQueryIdx = ref(0)
|
||||
const exportFormat = ref<'csv' | 'json'>('csv')
|
||||
|
||||
type RatingMap = Record<string, Record<string, Record<number, 'relevant' | 'not_relevant'>>>
|
||||
const ratings = ref<RatingMap>({})
|
||||
|
||||
const uniqueQueries = computed(() => {
|
||||
const seen = new Set<string>()
|
||||
const out: string[] = []
|
||||
for (const e of resultEvents.value) {
|
||||
if (!seen.has(e.query)) { seen.add(e.query); out.push(e.query) }
|
||||
}
|
||||
return out
|
||||
})
|
||||
|
||||
const hasResults = computed(() => resultEvents.value.length > 0)
|
||||
const canRun = computed(
|
||||
() => corpus.value.length > 0 && queries.value.length > 0 && selectedModels.value.length > 0
|
||||
)
|
||||
|
||||
// ── Corpus helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
function onCorpusPaste() {
|
||||
const chunks = rawCorpus.value.split('\n').map(l => l.trim()).filter(Boolean)
|
||||
corpus.value = chunks
|
||||
if (chunks.length > 0) {
|
||||
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded.`
|
||||
}
|
||||
}
|
||||
|
||||
function onQueriesChange() {
|
||||
queries.value = rawQueries.value.split('\n').map(l => l.trim()).filter(Boolean)
|
||||
}
|
||||
|
||||
async function importCorpus() {
|
||||
if (!selectedProduct.value) return
|
||||
importing.value = true
|
||||
importError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/imitate/products/${selectedProduct.value}/sample-chunks`)
|
||||
if (!r.ok) {
|
||||
const text = await r.text()
|
||||
throw new Error(text || `HTTP ${r.status}`)
|
||||
}
|
||||
const data = await r.json() as { chunks?: string[] }
|
||||
const chunks = data.chunks ?? []
|
||||
corpus.value = chunks
|
||||
rawCorpus.value = chunks.join('\n')
|
||||
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded from import.`
|
||||
} catch (err) {
|
||||
importError.value = String(err)
|
||||
} finally {
|
||||
importing.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Model loading ─────────────────────────────────────────────────────────────
|
||||
|
||||
async function loadModels() {
|
||||
loadingModels.value = true
|
||||
modelsError.value = ''
|
||||
try {
|
||||
const r = await fetch('/api/embed-bench/models')
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
const data = await r.json() as { models: OllamaModel[] }
|
||||
availableModels.value = data.models
|
||||
} catch (err) {
|
||||
modelsError.value = `Failed to load models: ${err}`
|
||||
} finally {
|
||||
loadingModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
async function startRun() {
|
||||
if (!canRun.value) return
|
||||
running.value = true
|
||||
resultEvents.value = []
|
||||
liveMessage.value = 'Starting embedding run…'
|
||||
runController.value = new AbortController()
|
||||
|
||||
try {
|
||||
const resp = await fetch('/api/embed-bench/run', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
corpus: corpus.value,
|
||||
queries: queries.value,
|
||||
models: selectedModels.value,
|
||||
top_k: topK.value,
|
||||
}),
|
||||
signal: runController.value.signal,
|
||||
})
|
||||
|
||||
const reader = resp.body!.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buf = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += decoder.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue
|
||||
const event = JSON.parse(line.slice(6))
|
||||
if (event.type === 'progress') {
|
||||
liveMessage.value = event.msg
|
||||
} else if (event.type === 'result') {
|
||||
resultEvents.value.push(event as ResultEvent)
|
||||
} else if (event.type === 'done') {
|
||||
liveMessage.value = 'Run complete.'
|
||||
} else if (event.type === 'error') {
|
||||
liveMessage.value = `Error: ${event.msg}`
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if ((err as Error).name !== 'AbortError') {
|
||||
liveMessage.value = `Run failed: ${err}`
|
||||
}
|
||||
} finally {
|
||||
running.value = false
|
||||
runController.value = null
|
||||
}
|
||||
}
|
||||
|
||||
function cancelRun() {
|
||||
runController.value?.abort()
|
||||
liveMessage.value = 'Run cancelled.'
|
||||
}
|
||||
|
||||
// ── Utilities ─────────────────────────────────────────────────────────────────
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (bytes < 1_000_000) return `${(bytes / 1000).toFixed(0)} KB`
|
||||
if (bytes < 1_000_000_000) return `${(bytes / 1_000_000).toFixed(0)} MB`
|
||||
return `${(bytes / 1_000_000_000).toFixed(1)} GB`
|
||||
}
|
||||
|
||||
function getHit(queryIdx: number, model: string, rank: number): HitResult | null {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
if (!query) return null
|
||||
const ev = resultEvents.value.find(e => e.query === query && e.model === model)
|
||||
return ev?.hits[rank] ?? null
|
||||
}
|
||||
|
||||
function getRating(queryIdx: number, model: string, chunkIdx: number): string | undefined {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
return ratings.value[query]?.[model]?.[chunkIdx]
|
||||
}
|
||||
|
||||
async function rate(
|
||||
queryIdx: number,
|
||||
model: string,
|
||||
hit: HitResult,
|
||||
rating: 'relevant' | 'not_relevant',
|
||||
) {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
// Optimistic update
|
||||
if (!ratings.value[query]) ratings.value[query] = {}
|
||||
if (!ratings.value[query][model]) ratings.value[query][model] = {}
|
||||
ratings.value[query][model][hit.chunk_idx] = rating
|
||||
|
||||
try {
|
||||
await fetch('/api/embed-bench/rate', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
query,
|
||||
model,
|
||||
chunk_text: hit.text,
|
||||
chunk_idx: hit.chunk_idx,
|
||||
rating,
|
||||
}),
|
||||
})
|
||||
liveMessage.value = `Rated chunk ${hit.chunk_idx + 1} as ${rating}.`
|
||||
} catch (err) {
|
||||
liveMessage.value = `Rating failed: ${err}`
|
||||
}
|
||||
}
|
||||
|
||||
async function exportRatings() {
|
||||
const r = await fetch(`/api/embed-bench/export?format=${exportFormat.value}`)
|
||||
if (!r.ok) {
|
||||
liveMessage.value = `Export failed: HTTP ${r.status}`
|
||||
return
|
||||
}
|
||||
const blob = await r.blob()
|
||||
const disposition = r.headers.get('Content-Disposition') ?? ''
|
||||
const filenameMatch = disposition.match(/filename="([^"]+)"/)
|
||||
const filename = filenameMatch ? filenameMatch[1] : `embed_comparison.${exportFormat.value}`
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = filename
|
||||
a.click()
|
||||
URL.revokeObjectURL(url)
|
||||
liveMessage.value = `Exported ${filename}.`
|
||||
}
|
||||
|
||||
// ── Lifecycle ─────────────────────────────────────────────────────────────────
|
||||
|
||||
onMounted(() => {
|
||||
loadModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.embed-compare-page {
|
||||
padding: var(--space-4, 1.5rem);
|
||||
max-width: 1100px;
|
||||
}
|
||||
|
||||
/* Step indicator */
|
||||
.step-indicator {
|
||||
display: flex;
|
||||
gap: 0;
|
||||
list-style: none;
|
||||
margin: 0 0 var(--space-4, 1.5rem);
|
||||
padding: 0;
|
||||
border-bottom: 2px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.step-indicator li {
|
||||
padding: 0.4rem 1rem;
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -2px;
|
||||
}
|
||||
.step-indicator li.complete {
|
||||
color: var(--app-primary, #2A6080);
|
||||
border-bottom-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
/* Accessibility: screen-reader live region — visually hidden but always present */
|
||||
.sr-live {
|
||||
position: absolute;
|
||||
width: 1px; height: 1px;
|
||||
overflow: hidden;
|
||||
clip: rect(0 0 0 0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* Cards */
|
||||
.card {
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
padding: var(--space-4, 1.5rem);
|
||||
margin-bottom: var(--space-4, 1.5rem);
|
||||
}
|
||||
.card h2 {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
margin: 0 0 var(--space-3, 1rem);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.field { display: flex; flex-direction: column; gap: 0.3rem; margin-bottom: 0.75rem; }
|
||||
.field label { font-size: 0.85rem; font-weight: 600; }
|
||||
textarea, input[type="number"] {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
padding: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
.corpus-controls { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
.import-row {
|
||||
display: flex; flex-wrap: wrap; gap: 0.5rem; align-items: center;
|
||||
}
|
||||
.import-row label { font-size: 0.85rem; font-weight: 600; }
|
||||
.corpus-count, .query-count { font-size: 0.875rem; color: var(--app-primary, #2A6080); margin: 0; }
|
||||
|
||||
.model-list { list-style: none; padding: 0; margin: 0; display: flex; flex-wrap: wrap; gap: 0.5rem; }
|
||||
.model-checkbox {
|
||||
display: flex; align-items: center; gap: 0.4rem;
|
||||
font-size: 0.875rem; cursor: pointer;
|
||||
padding: 0.3rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
}
|
||||
.model-size { font-size: 0.75rem; }
|
||||
|
||||
.run-row { display: flex; flex-wrap: wrap; gap: 0.75rem; align-items: flex-end; }
|
||||
.field-inline { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.field-inline label { font-size: 0.85rem; font-weight: 600; white-space: nowrap; }
|
||||
|
||||
.btn-primary, .btn-secondary, .btn-danger {
|
||||
padding: 0.4rem 1rem;
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
border: 1px solid transparent;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-primary { background: var(--app-primary, #2A6080); color: #fff; }
|
||||
.btn-primary:hover:not(:disabled) { filter: brightness(1.1); }
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-secondary { background: var(--color-surface, #f0f4fb); color: var(--color-text, #1a2338); border-color: var(--color-border, #d0d7e8); }
|
||||
.btn-secondary:hover:not(:disabled) { background: var(--color-border, #d0d7e8); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-danger { background: var(--color-error, #c0392b); color: #fff; }
|
||||
|
||||
.muted { color: var(--color-text-muted, #4a5c7a); font-size: 0.875rem; }
|
||||
.error-text { color: var(--color-error, #c0392b); font-size: 0.875rem; }
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.import-row { flex-direction: column; align-items: flex-start; }
|
||||
.run-row { flex-direction: column; }
|
||||
.model-list { flex-direction: column; }
|
||||
}
|
||||
|
||||
/* Results table */
|
||||
.table-wrap { overflow-x: auto; }
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.results-table thead th {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 2px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.5rem 0.75rem;
|
||||
text-align: left;
|
||||
font-weight: 700;
|
||||
white-space: nowrap;
|
||||
z-index: 1;
|
||||
}
|
||||
.results-table td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
vertical-align: top;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.rank-col { width: 2rem; text-align: center; }
|
||||
|
||||
.hit-text { margin-bottom: 0.25rem; line-height: 1.4; }
|
||||
|
||||
.score-row { display: flex; align-items: center; gap: 0.4rem; margin-bottom: 0.25rem; }
|
||||
.score-bar-wrap {
|
||||
flex: 1;
|
||||
height: 6px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
border-radius: 3px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.score-bar {
|
||||
height: 100%;
|
||||
background: var(--app-primary, #2A6080);
|
||||
border-radius: 3px;
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
.score-label { font-size: 0.75rem; color: var(--color-text-muted, #4a5c7a); min-width: 3rem; text-align: right; }
|
||||
|
||||
.rating-row { display: flex; gap: 0.4rem; flex-wrap: wrap; }
|
||||
.rate-btn {
|
||||
padding: 0.2rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, border-color 0.15s;
|
||||
}
|
||||
.rate-btn.active {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 20%, transparent);
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
font-weight: 700;
|
||||
}
|
||||
.rate-btn-neg.active {
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent);
|
||||
border-color: var(--color-error, #c0392b);
|
||||
}
|
||||
|
||||
/* Query nav */
|
||||
.query-nav {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.query-counter { font-size: 0.875rem; flex: 1; }
|
||||
|
||||
/* Export */
|
||||
.export-row { display: flex; gap: 1rem; align-items: center; flex-wrap: wrap; }
|
||||
.export-format-group {
|
||||
border: none;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
.export-format-group legend {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
margin-bottom: 0.25rem;
|
||||
float: left;
|
||||
margin-right: 0.5rem;
|
||||
}
|
||||
.export-format-group label { font-size: 0.875rem; display: flex; align-items: center; gap: 0.3rem; }
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.results-table thead th,
|
||||
.results-table td { padding: 0.35rem 0.4rem; font-size: 0.8rem; }
|
||||
.query-nav { flex-direction: column; align-items: flex-start; }
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.score-bar { transition: none; }
|
||||
}
|
||||
</style>
|
||||
7
web/src/views/EmbedCompareView.vue
Normal file
7
web/src/views/EmbedCompareView.vue
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
<template>
|
||||
<EmbedCompareTab />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import EmbedCompareTab from './EmbedCompareTab.vue'
|
||||
</script>
|
||||
|
|
@ -6,6 +6,8 @@
|
|||
<summary class="picker-summary">
|
||||
<span class="picker-title">📋 Task Selection</span>
|
||||
<span class="picker-badge">{{ llmTaskBadge }}</span>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="selectAllTasks()">All</button>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="clearAllTasks()">None</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks…</div>
|
||||
|
|
@ -44,6 +46,8 @@
|
|||
<summary class="picker-summary">
|
||||
<span class="picker-title">🎯 Model Selection</span>
|
||||
<span class="picker-badge">{{ llmModelBadge }}</span>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="selectAllModels()">All</button>
|
||||
<button class="picker-bulk-btn" @click.stop.prevent="clearAllModels()">None</button>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmModelsLoading" class="picker-loading">Loading models…</div>
|
||||
|
|
@ -78,6 +82,33 @@
|
|||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Node Selection -->
|
||||
<div class="node-picker" v-if="llmNodes.length > 0">
|
||||
<span class="node-picker-label">Nodes:</span>
|
||||
<label
|
||||
v-for="node in llmNodes"
|
||||
:key="node.node_id"
|
||||
class="node-chip"
|
||||
:class="{ 'node-chip--off': !enabledNodes.has(node.node_id), 'node-chip--offline': !node.online }"
|
||||
:title="node.online ? `${node.node_id} — ${node.gpus.length} GPU(s)` : `${node.node_id} — offline`"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
class="node-chip-check"
|
||||
:checked="enabledNodes.has(node.node_id)"
|
||||
:disabled="!node.online || llmRunning"
|
||||
@change="toggleNode(node.node_id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
{{ node.node_id }}
|
||||
<span class="node-chip-status" v-if="!node.online">offline</span>
|
||||
</label>
|
||||
<span class="node-picker-hint">
|
||||
{{ enabledNodeIds.length === llmNodes.filter(n => n.online).length
|
||||
? 'auto-routing (all nodes)'
|
||||
: `restricted to: ${enabledNodeIds.join(', ')}` }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Run Controls -->
|
||||
<div class="run-controls">
|
||||
<button
|
||||
|
|
@ -88,6 +119,24 @@
|
|||
{{ llmRunning ? '⏳ Running…' : '▶ Run LLM Eval' }}
|
||||
</button>
|
||||
<button v-if="llmRunning" class="btn-cancel" @click="cancelLlmBenchmark">✕ Cancel</button>
|
||||
<input
|
||||
v-model="llmJudgeUrl"
|
||||
class="judge-url-input"
|
||||
placeholder="Judge URL — leave empty to skip LLM judge scoring"
|
||||
:disabled="llmRunning"
|
||||
title="Optional: URL of a running cf-text service (e.g. http://10.1.10.158:8008). When set, each LLM response gets a secondary score from the judge model — adds a 'judge' column to results. Empty = primary quality scoring only."
|
||||
/>
|
||||
<label class="workers-label" title="Run this many models concurrently (requires multiple GPUs)">
|
||||
<span class="workers-prefix">workers</span>
|
||||
<input
|
||||
v-model.number="llmWorkers"
|
||||
type="number"
|
||||
min="1"
|
||||
max="8"
|
||||
class="workers-input"
|
||||
:disabled="llmRunning"
|
||||
/>
|
||||
</label>
|
||||
<span v-if="selectedLlmTasks.size === 0 || selectedLlmModels.size === 0" class="run-hint">
|
||||
Select at least one task and one model to run.
|
||||
</span>
|
||||
|
|
@ -119,6 +168,7 @@
|
|||
<tr>
|
||||
<th class="hm-label-col">Model</th>
|
||||
<th class="hm-model-col">overall</th>
|
||||
<th v-if="llmHasJudge" class="hm-model-col hm-judge-col">judge</th>
|
||||
<th v-for="col in llmTaskTypeCols" :key="col" class="hm-model-col">{{ col }}</th>
|
||||
<th class="hm-model-col">tok/s</th>
|
||||
</tr>
|
||||
|
|
@ -130,6 +180,12 @@
|
|||
class="hm-value-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
|
||||
>{{ pct(row.avg_quality_score) }}</td>
|
||||
<td
|
||||
v-if="llmHasJudge"
|
||||
class="hm-value-cell hm-judge-cell"
|
||||
:class="{ 'bt-best': llmBestByCol['judge'] === row.model_id }"
|
||||
title="LLM-as-judge secondary score"
|
||||
>{{ row.avg_judge_score != null ? pct(row.avg_judge_score) : '—' }}</td>
|
||||
<td
|
||||
v-for="col in llmTaskTypeCols"
|
||||
:key="col"
|
||||
|
|
@ -168,6 +224,12 @@ interface CfOrchModel {
|
|||
vram_estimate_mb?: number
|
||||
}
|
||||
|
||||
interface CfOrchNode {
|
||||
node_id: string
|
||||
online: boolean
|
||||
gpus: { gpu_id: number; name: string; vram_total_mb: number; vram_free_mb: number }[]
|
||||
}
|
||||
|
||||
interface LlmModelResult {
|
||||
model_name: string
|
||||
model_id: string
|
||||
|
|
@ -175,9 +237,11 @@ interface LlmModelResult {
|
|||
avg_tokens_per_sec: number
|
||||
avg_completion_ms: number
|
||||
avg_quality_score: number
|
||||
avg_judge_score: number | null
|
||||
finetune_candidates: number
|
||||
error_count: number
|
||||
quality_by_task_type: Record<string, number>
|
||||
judge_score_by_task_type?: Record<string, number>
|
||||
}
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
|
|
@ -195,6 +259,10 @@ const llmError = ref('')
|
|||
const llmResults = ref<LlmModelResult[]>([])
|
||||
const llmEventSource = ref<EventSource | null>(null)
|
||||
const llmLogEl = ref<HTMLElement | null>(null)
|
||||
const llmJudgeUrl = ref('')
|
||||
const llmWorkers = ref(1)
|
||||
const llmNodes = ref<CfOrchNode[]>([])
|
||||
const enabledNodes = ref<Set<string>>(new Set())
|
||||
|
||||
// ── Computed ────────────────────────────────────────────────────────────────
|
||||
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
|
||||
|
|
@ -234,11 +302,19 @@ const llmModelBadge = computed(() => {
|
|||
const llmTaskTypeCols = computed(() => {
|
||||
const types = new Set<string>()
|
||||
for (const r of llmResults.value) {
|
||||
for (const k of Object.keys(r.quality_by_task_type)) types.add(k)
|
||||
for (const k of Object.keys(r.quality_by_task_type ?? {})) types.add(k)
|
||||
}
|
||||
return [...types].sort()
|
||||
})
|
||||
|
||||
const llmHasJudge = computed(() =>
|
||||
llmResults.value.some(r => r.avg_judge_score != null)
|
||||
)
|
||||
|
||||
const enabledNodeIds = computed(() =>
|
||||
llmNodes.value.filter(n => n.online && enabledNodes.value.has(n.node_id)).map(n => n.node_id)
|
||||
)
|
||||
|
||||
const llmBestByCol = computed((): Record<string, string> => {
|
||||
const best: Record<string, string> = {}
|
||||
if (llmResults.value.length === 0) return best
|
||||
|
|
@ -249,10 +325,20 @@ const llmBestByCol = computed((): Record<string, string> => {
|
|||
}
|
||||
best['overall'] = bestId
|
||||
|
||||
if (llmHasJudge.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
if (r.avg_judge_score != null && r.avg_judge_score > bestVal) {
|
||||
bestVal = r.avg_judge_score; bestId = r.model_id
|
||||
}
|
||||
}
|
||||
best['judge'] = bestId
|
||||
}
|
||||
|
||||
for (const col of llmTaskTypeCols.value) {
|
||||
bestId = ''; bestVal = -Infinity
|
||||
for (const r of llmResults.value) {
|
||||
const v = r.quality_by_task_type[col]
|
||||
const v = r.quality_by_task_type?.[col]
|
||||
if (v != null && v > bestVal) { bestVal = v; bestId = r.model_id }
|
||||
}
|
||||
best[col] = bestId
|
||||
|
|
@ -306,6 +392,15 @@ function toggleService(models: CfOrchModel[], checked: boolean) {
|
|||
}
|
||||
selectedLlmModels.value = next
|
||||
}
|
||||
function selectAllTasks() { selectedLlmTasks.value = new Set(llmTasks.value.map(t => t.id)) }
|
||||
function clearAllTasks() { selectedLlmTasks.value = new Set() }
|
||||
function selectAllModels() { selectedLlmModels.value = new Set(llmModels.value.map(m => m.id)) }
|
||||
function clearAllModels() { selectedLlmModels.value = new Set() }
|
||||
function toggleNode(id: string, checked: boolean) {
|
||||
const next = new Set(enabledNodes.value)
|
||||
checked ? next.add(id) : next.delete(id)
|
||||
enabledNodes.value = next
|
||||
}
|
||||
|
||||
// ── Data loaders ─────────────────────────────────────────────────────────────
|
||||
async function loadLlmTasks() {
|
||||
|
|
@ -335,6 +430,21 @@ async function loadLlmResults() {
|
|||
}
|
||||
}
|
||||
|
||||
async function loadLlmConfig() {
|
||||
const { data } = await useApiFetch<{ judge_url?: string }>('/api/cforch/config')
|
||||
if (data?.judge_url && !llmJudgeUrl.value) {
|
||||
llmJudgeUrl.value = data.judge_url
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmNodes() {
|
||||
const { data } = await useApiFetch<{ nodes: CfOrchNode[] }>('/api/cforch/nodes')
|
||||
if (data?.nodes) {
|
||||
llmNodes.value = data.nodes
|
||||
enabledNodes.value = new Set(data.nodes.filter(n => n.online).map(n => n.node_id))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run / cancel ──────────────────────────────────────────────────────────────
|
||||
function startLlmBenchmark() {
|
||||
llmRunning.value = true
|
||||
|
|
@ -344,6 +454,15 @@ function startLlmBenchmark() {
|
|||
const params = new URLSearchParams()
|
||||
const taskIds = [...selectedLlmTasks.value].join(',')
|
||||
if (taskIds) params.set('task_ids', taskIds)
|
||||
const modelIds = [...selectedLlmModels.value].join(',')
|
||||
if (modelIds) params.set('model_ids', modelIds)
|
||||
if (llmJudgeUrl.value.trim()) params.set('judge_url', llmJudgeUrl.value.trim())
|
||||
if (llmWorkers.value > 1) params.set('workers', String(llmWorkers.value))
|
||||
const onlineNodeIds = llmNodes.value.filter(n => n.online).map(n => n.node_id)
|
||||
const isRestricted = enabledNodeIds.value.length < onlineNodeIds.length
|
||||
if (isRestricted && enabledNodeIds.value.length > 0) {
|
||||
params.set('node_ids', enabledNodeIds.value.join(','))
|
||||
}
|
||||
|
||||
const es = new EventSource(`/api/cforch/run?${params}`)
|
||||
llmEventSource.value = es
|
||||
|
|
@ -387,6 +506,8 @@ onMounted(() => {
|
|||
loadLlmTasks()
|
||||
loadLlmModels()
|
||||
loadLlmResults()
|
||||
loadLlmConfig()
|
||||
loadLlmNodes()
|
||||
})
|
||||
</script>
|
||||
|
||||
|
|
@ -451,6 +572,43 @@ onMounted(() => {
|
|||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.judge-url-input {
|
||||
flex: 1;
|
||||
min-width: 14rem;
|
||||
max-width: 24rem;
|
||||
padding: 0.35rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
.judge-url-input:disabled { opacity: 0.5; }
|
||||
.judge-url-input::placeholder { color: var(--color-text-secondary, #6b7a99); }
|
||||
|
||||
.workers-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
white-space: nowrap;
|
||||
}
|
||||
.workers-prefix { font-family: var(--font-mono, monospace); }
|
||||
.workers-input {
|
||||
width: 3.2rem;
|
||||
padding: 0.35rem 0.4rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
text-align: center;
|
||||
}
|
||||
.workers-input:disabled { opacity: 0.5; }
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
@ -592,6 +750,15 @@ onMounted(() => {
|
|||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.hm-judge-col {
|
||||
background: color-mix(in srgb, var(--color-surface-raised, #e4ebf5) 80%, #c6d5f5);
|
||||
}
|
||||
.hm-judge-cell {
|
||||
background: color-mix(in srgb, var(--color-surface, #fff) 85%, #c6d5f5);
|
||||
font-style: italic;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
/* ── Model Picker ───────────────────────────────────────── */
|
||||
.model-picker {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
@ -630,6 +797,24 @@ details[open] .picker-summary::before { content: '▼ '; }
|
|||
margin-left: auto;
|
||||
}
|
||||
|
||||
.picker-bulk-btn {
|
||||
padding: 0.1rem 0.45rem;
|
||||
font-size: 0.7rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.12s, color 0.12s;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.picker-bulk-btn:hover {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.picker-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
@ -712,4 +897,61 @@ details[open] .picker-summary::before { content: '▼ '; }
|
|||
.picker-model-list { padding-left: 0; }
|
||||
.picker-model-name { max-width: 14ch; }
|
||||
}
|
||||
|
||||
/* ── Node picker ────────────────────────────────────── */
|
||||
.node-picker {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
padding: 0.5rem 0.75rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
|
||||
.node-picker-label {
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.node-chip {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 1rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
color: var(--color-text, #1a2338);
|
||||
cursor: pointer;
|
||||
transition: background 0.12s, opacity 0.12s;
|
||||
}
|
||||
.node-chip--off {
|
||||
opacity: 0.45;
|
||||
background: transparent;
|
||||
}
|
||||
.node-chip--offline {
|
||||
opacity: 0.35;
|
||||
cursor: not-allowed;
|
||||
font-style: italic;
|
||||
}
|
||||
.node-chip-check { accent-color: var(--app-primary, #2A6080); }
|
||||
.node-chip-status {
|
||||
font-size: 0.66rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.node-picker-hint {
|
||||
font-size: 0.72rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -2,6 +2,24 @@
|
|||
<div class="models-view">
|
||||
<h1 class="page-title">🤗 Models</h1>
|
||||
|
||||
<!-- ── Fleet tab bar ─────────────────────────────── -->
|
||||
<div class="mode-toggle" role="group" aria-label="Fleet view">
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: fleetTab === 'models' }"
|
||||
@click="fleetTab = 'models'"
|
||||
>Models</button>
|
||||
<button
|
||||
class="mode-btn"
|
||||
:class="{ active: fleetTab === 'assignments' }"
|
||||
@click="fleetTab = 'assignments'"
|
||||
>Assignments</button>
|
||||
</div>
|
||||
|
||||
<AssignmentsTab v-if="fleetTab === 'assignments'" />
|
||||
|
||||
<template v-if="fleetTab === 'models'">
|
||||
|
||||
<!-- ── 1. HF Lookup ───────────────────────────────── -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">HuggingFace Lookup</h2>
|
||||
|
|
@ -51,8 +69,31 @@
|
|||
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
|
||||
{{ lookupResult.adapter_recommendation }}
|
||||
</span>
|
||||
<span v-if="lookupResult.size != null" class="preview-size">
|
||||
{{ humanBytes(lookupResult.size) }}
|
||||
<span v-if="selectedQuantSize > 0" class="preview-size">
|
||||
{{ humanBytes(selectedQuantSize) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- GGUF quantization picker — only shown for GGUF repos -->
|
||||
<div v-if="lookupResult.gguf_files?.length" class="quant-picker">
|
||||
<label class="quant-label" for="quant-select">Quantization</label>
|
||||
<select
|
||||
id="quant-select"
|
||||
v-model="selectedQuant"
|
||||
class="quant-select"
|
||||
aria-label="Select quantization variant"
|
||||
>
|
||||
<option :value="null" disabled>Select quantization…</option>
|
||||
<option
|
||||
v-for="f in lookupResult.gguf_files"
|
||||
:key="f.filename"
|
||||
:value="f.quant_name ?? f.filename"
|
||||
>
|
||||
{{ f.quant_name ?? f.filename }} — {{ humanBytes(f.size) }}
|
||||
</option>
|
||||
</select>
|
||||
<span class="quant-hint">
|
||||
Q5_K_M or Q6_K recommended for 8 GB GPUs. Q8_0 for max quality.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
|
@ -67,7 +108,7 @@
|
|||
|
||||
<button
|
||||
class="btn-primary btn-add-queue"
|
||||
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue"
|
||||
:disabled="!canAddToQueue"
|
||||
@click="addToQueue"
|
||||
>
|
||||
{{ addingToQueue ? 'Adding…' : 'Add to queue' }}
|
||||
|
|
@ -99,9 +140,39 @@
|
|||
<span v-if="model.role" class="chip chip-role">{{ model.role }}</span>
|
||||
<span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span>
|
||||
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
|
||||
<span v-if="model.quant_pattern" class="chip chip-quant">{{ model.quant_pattern }}</span>
|
||||
</div>
|
||||
<!-- Allow manual service/role assignment for unrecognized pipeline tags -->
|
||||
<div v-if="!model.service" class="classify-row queue-classify">
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.id]?.service ?? ''"
|
||||
@change="onServiceChange(model.id, ($event.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign service"
|
||||
>
|
||||
<option value="" disabled>Service…</option>
|
||||
<option v-for="svc in CLASSIFIABLE_SERVICES" :key="svc.value" :value="svc.value">{{ svc.label }}</option>
|
||||
</select>
|
||||
<select
|
||||
class="classify-select"
|
||||
:value="classifyDraft[model.id]?.role ?? ''"
|
||||
:disabled="!classifyDraft[model.id]?.service"
|
||||
@change="(e) => setClassifyRole(model.id, (e.target as HTMLSelectElement).value)"
|
||||
aria-label="Assign role"
|
||||
>
|
||||
<option value="" disabled>Role…</option>
|
||||
<option
|
||||
v-for="role in rolesForService(classifyDraft[model.id]?.service ?? '')"
|
||||
:key="role"
|
||||
:value="role"
|
||||
>{{ role }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="model-card-actions">
|
||||
<button class="btn-primary btn-sm" @click="approveModel(model.id)">
|
||||
<button
|
||||
class="btn-primary btn-sm"
|
||||
@click="approveModel(model.id, classifyDraft[model.id])"
|
||||
>
|
||||
Approve download
|
||||
</button>
|
||||
</div>
|
||||
|
|
@ -244,14 +315,26 @@
|
|||
</div>
|
||||
</template>
|
||||
</section>
|
||||
|
||||
</template><!-- end fleetTab === 'models' -->
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, onUnmounted } from 'vue'
|
||||
import AssignmentsTab from './AssignmentsTab.vue'
|
||||
|
||||
type FleetTab = 'models' | 'assignments'
|
||||
const fleetTab = ref<FleetTab>('models')
|
||||
|
||||
// ── Type definitions ──────────────────────────────────
|
||||
|
||||
interface GgufFile {
|
||||
filename: string
|
||||
size: number
|
||||
quant_name: string | null
|
||||
}
|
||||
|
||||
interface LookupResult {
|
||||
repo_id: string
|
||||
pipeline_tag: string | null
|
||||
|
|
@ -260,7 +343,8 @@ interface LookupResult {
|
|||
service: string | null
|
||||
compatible: boolean
|
||||
warning: string | null
|
||||
size: number | null
|
||||
model_size_bytes: number
|
||||
gguf_files: GgufFile[] | null
|
||||
description: string | null
|
||||
already_installed: boolean
|
||||
already_queued: boolean
|
||||
|
|
@ -274,6 +358,7 @@ interface QueuedModel {
|
|||
adapter_recommendation: string | null
|
||||
role: string | null
|
||||
service: string | null
|
||||
quant_pattern: string | null
|
||||
}
|
||||
|
||||
interface InstalledModel {
|
||||
|
|
@ -302,6 +387,26 @@ const lookupLoading = ref(false)
|
|||
const lookupError = ref<string | null>(null)
|
||||
const lookupResult = ref<LookupResult | null>(null)
|
||||
const addingToQueue = ref(false)
|
||||
const selectedQuant = ref<string | null>(null)
|
||||
|
||||
// Size of the selected GGUF file, or total model size for non-GGUF repos.
|
||||
const selectedQuantSize = computed<number>(() => {
|
||||
const r = lookupResult.value
|
||||
if (!r) return 0
|
||||
if (r.gguf_files?.length && selectedQuant.value) {
|
||||
const f = r.gguf_files.find(f => (f.quant_name ?? f.filename) === selectedQuant.value)
|
||||
return f?.size ?? r.model_size_bytes
|
||||
}
|
||||
return r.model_size_bytes
|
||||
})
|
||||
|
||||
// Disable "Add to queue" when a GGUF repo but no quant chosen yet.
|
||||
const canAddToQueue = computed(() => {
|
||||
const r = lookupResult.value
|
||||
if (!r || r.already_installed || r.already_queued || addingToQueue.value) return false
|
||||
if (r.gguf_files?.length && !selectedQuant.value) return false
|
||||
return true
|
||||
})
|
||||
|
||||
const queuedModels = ref<QueuedModel[]>([])
|
||||
const installedModels = ref<InstalledModel[]>([])
|
||||
|
|
@ -411,6 +516,7 @@ async function doLookup() {
|
|||
lookupLoading.value = true
|
||||
lookupError.value = null
|
||||
lookupResult.value = null
|
||||
selectedQuant.value = null
|
||||
|
||||
try {
|
||||
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
|
||||
|
|
@ -442,7 +548,15 @@ async function addToQueue() {
|
|||
const res = await fetch('/api/models/queue', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ repo_id, pipeline_tag, adapter_recommendation, role, service }),
|
||||
body: JSON.stringify({
|
||||
repo_id,
|
||||
pipeline_tag,
|
||||
adapter_recommendation,
|
||||
role,
|
||||
service,
|
||||
model_size_bytes: selectedQuantSize.value,
|
||||
quant_pattern: selectedQuant.value,
|
||||
}),
|
||||
})
|
||||
if (res.ok) {
|
||||
lookupResult.value = { ...lookupResult.value, already_queued: true }
|
||||
|
|
@ -454,8 +568,16 @@ async function addToQueue() {
|
|||
}
|
||||
}
|
||||
|
||||
async function approveModel(id: string) {
|
||||
async function approveModel(id: string, draft?: { service: string; role: string }) {
|
||||
try {
|
||||
// If the user picked a service/role for an unrecognized model, patch it first.
|
||||
if (draft?.service && draft?.role) {
|
||||
await fetch(`/api/models/queue/${encodeURIComponent(id)}`, {
|
||||
method: 'PATCH',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ service: draft.service, role: draft.role }),
|
||||
})
|
||||
}
|
||||
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
|
||||
if (res.ok) {
|
||||
await loadQueue()
|
||||
|
|
@ -640,6 +762,39 @@ onUnmounted(() => {
|
|||
color: var(--color-primary, #2d5a27);
|
||||
}
|
||||
|
||||
/* ── Fleet tab bar (mode-toggle pattern from BenchmarkView) ── */
|
||||
.mode-toggle {
|
||||
display: inline-flex;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
align-self: flex-start;
|
||||
}
|
||||
.mode-btn {
|
||||
padding: 0.4rem 1.1rem;
|
||||
font-size: 0.85rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
font-weight: 500;
|
||||
border: none;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
.mode-btn:not(:last-child) {
|
||||
border-right: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.mode-btn.active {
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
}
|
||||
.mode-btn:not(.active):hover {
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
@media (max-width: 600px) {
|
||||
.mode-btn { padding: 0.4rem 0.65rem; font-size: 0.78rem; }
|
||||
}
|
||||
|
||||
/* ── Sections ── */
|
||||
.section {
|
||||
display: flex;
|
||||
|
|
@ -774,6 +929,44 @@ onUnmounted(() => {
|
|||
align-self: flex-start;
|
||||
}
|
||||
|
||||
/* ── Quant picker ── */
|
||||
.quant-picker {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.quant-label {
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.quant-select {
|
||||
padding: 0.4rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.9rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.quant-hint {
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.chip-quant {
|
||||
background: color-mix(in srgb, var(--color-primary, #2A6080) 15%, transparent);
|
||||
color: var(--color-primary, #2A6080);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Model cards (queue + downloads) ── */
|
||||
.model-card {
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
|
|
|
|||
165
web/src/views/NodeManagementView.vue
Normal file
165
web/src/views/NodeManagementView.vue
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import NodeCard from '../components/nodes/NodeCard.vue'
|
||||
import AssignmentsTab from './AssignmentsTab.vue'
|
||||
import type { NodeSummary } from '../types/nodes'
|
||||
|
||||
type Tab = 'nodes' | 'assignments'
|
||||
|
||||
const activeTab = ref<Tab>('nodes')
|
||||
const nodes = ref<NodeSummary[]>([])
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
async function fetchNodes() {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
try {
|
||||
const r = await fetch('/api/nodes-mgmt/nodes')
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
nodes.value = (await r.json()) as NodeSummary[]
|
||||
} catch (e) {
|
||||
error.value = e instanceof Error ? e.message : 'Failed to load nodes'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(fetchNodes)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<main class="fleet-page">
|
||||
<header class="fleet-header">
|
||||
<h1 class="fleet-title">Fleet</h1>
|
||||
</header>
|
||||
|
||||
<!-- Tab bar -->
|
||||
<nav class="tab-bar" role="tablist" aria-label="Fleet sections">
|
||||
<button
|
||||
id="tab-nodes"
|
||||
role="tab"
|
||||
:aria-selected="activeTab === 'nodes'"
|
||||
:class="['tab', { active: activeTab === 'nodes' }]"
|
||||
@click="activeTab = 'nodes'"
|
||||
>Nodes</button>
|
||||
<button
|
||||
id="tab-assignments"
|
||||
role="tab"
|
||||
:aria-selected="activeTab === 'assignments'"
|
||||
:class="['tab', { active: activeTab === 'assignments' }]"
|
||||
@click="activeTab = 'assignments'"
|
||||
>Assignments</button>
|
||||
</nav>
|
||||
|
||||
<!-- Nodes tab -->
|
||||
<section
|
||||
v-if="activeTab === 'nodes'"
|
||||
role="tabpanel"
|
||||
aria-labelledby="tab-nodes"
|
||||
class="tab-panel"
|
||||
>
|
||||
<div class="nodes-toolbar">
|
||||
<button class="btn-secondary btn-sm" @click="fetchNodes" :disabled="loading">Refresh</button>
|
||||
</div>
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading nodes...</span>
|
||||
</div>
|
||||
<div v-if="error" class="nodes-status nodes-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && nodes.length === 0" class="nodes-status">
|
||||
No nodes found. Check <code>coordinator_url</code> in config.
|
||||
</div>
|
||||
<div v-else-if="!loading" class="nodes-grid">
|
||||
<NodeCard
|
||||
v-for="node in nodes"
|
||||
:key="node.node_id"
|
||||
:node="node"
|
||||
@updated="fetchNodes"
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Assignments tab -->
|
||||
<section
|
||||
v-else-if="activeTab === 'assignments'"
|
||||
role="tabpanel"
|
||||
aria-labelledby="tab-assignments"
|
||||
class="tab-panel"
|
||||
>
|
||||
<AssignmentsTab />
|
||||
</section>
|
||||
</main>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.fleet-page { padding: 1.5rem; }
|
||||
|
||||
.fleet-header {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.fleet-title {
|
||||
margin: 0;
|
||||
font-size: 1.5rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
/* ── Tab bar ── */
|
||||
.tab-bar {
|
||||
display: flex;
|
||||
gap: 0;
|
||||
border-bottom: 2px solid var(--color-border);
|
||||
margin-bottom: 1.25rem;
|
||||
}
|
||||
.tab {
|
||||
padding: 0.55rem 1.1rem;
|
||||
font-size: 0.88rem;
|
||||
font-weight: 600;
|
||||
background: none;
|
||||
border: none;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -2px;
|
||||
cursor: pointer;
|
||||
color: var(--color-text-muted);
|
||||
transition: color 0.15s, border-color 0.15s;
|
||||
}
|
||||
.tab:hover { color: var(--color-text); }
|
||||
.tab.active {
|
||||
color: var(--app-primary);
|
||||
border-bottom-color: var(--app-primary);
|
||||
}
|
||||
|
||||
/* ── Tab panel ── */
|
||||
.tab-panel { min-height: 200px; }
|
||||
|
||||
/* ── Nodes toolbar ── */
|
||||
.nodes-toolbar {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
/* ── Nodes grid / status ── */
|
||||
.nodes-grid { display: flex; flex-direction: column; gap: 1.5rem; }
|
||||
.nodes-status {
|
||||
color: var(--color-text-muted);
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
}
|
||||
.nodes-error { color: var(--color-error); }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
|
||||
/* ── Shared button ── */
|
||||
.btn-secondary {
|
||||
padding: 0.4rem 0.9rem;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-secondary:hover:not(:disabled) { background: var(--color-surface-raised); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: default; }
|
||||
.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }
|
||||
</style>
|
||||
1043
web/src/views/PlansBenchTab.vue
Normal file
1043
web/src/views/PlansBenchTab.vue
Normal file
File diff suppressed because it is too large
Load diff
536
web/src/views/RecipeScanView.vue
Normal file
536
web/src/views/RecipeScanView.vue
Normal file
|
|
@ -0,0 +1,536 @@
|
|||
<template>
|
||||
<div class="rsv">
|
||||
<!-- Header -->
|
||||
<header class="rsv-header">
|
||||
<h1 class="rsv-title">Recipe Scan Review</h1>
|
||||
<div class="rsv-stats" v-if="stats">
|
||||
<span class="stat-chip">{{ stats.by_status?.pending ?? 0 }} pending</span>
|
||||
<span class="stat-chip stat-chip--ok">{{ stats.by_status?.approved ?? 0 }} approved</span>
|
||||
<span class="stat-chip stat-chip--edited">{{ stats.by_status?.edited ?? 0 }} edited</span>
|
||||
<span class="stat-chip stat-chip--bad">{{ stats.by_status?.rejected ?? 0 }} rejected</span>
|
||||
<a
|
||||
v-if="(stats.export_ready ?? 0) > 0"
|
||||
:href="`${apiBase}/api/recipe-scan/export`"
|
||||
download
|
||||
class="btn-export"
|
||||
>
|
||||
⬇ Export {{ stats.export_ready }} pairs
|
||||
</a>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Loading -->
|
||||
<div v-if="loading" class="rsv-state" aria-label="Loading">
|
||||
<div class="skeleton-block" />
|
||||
</div>
|
||||
|
||||
<!-- Error -->
|
||||
<div v-else-if="apiError" class="rsv-state rsv-error" role="alert">
|
||||
<p>{{ apiError }}</p>
|
||||
<button class="btn-action" @click="fetchNext">Retry</button>
|
||||
</div>
|
||||
|
||||
<!-- Queue empty -->
|
||||
<div v-else-if="!item" class="rsv-state rsv-empty">
|
||||
<p>Queue is empty — all items reviewed.</p>
|
||||
<p class="rsv-hint">Import items from the Kiwi pipeline to continue.</p>
|
||||
</div>
|
||||
|
||||
<!-- Review panel -->
|
||||
<div v-else class="rsv-workspace">
|
||||
<!-- Left: image -->
|
||||
<section class="rsv-image-panel" aria-label="Scan image">
|
||||
<div class="rsv-panel-label">
|
||||
<span class="modality-badge">{{ item.modality }}</span>
|
||||
<span class="source-badge">{{ item.source }}</span>
|
||||
</div>
|
||||
<div class="rsv-image-wrap">
|
||||
<img
|
||||
v-if="imageUrl"
|
||||
:src="imageUrl"
|
||||
:alt="`Recipe scan — ${item.source}`"
|
||||
class="rsv-image"
|
||||
/>
|
||||
<div v-else class="rsv-image-placeholder">
|
||||
<span>Image not available</span>
|
||||
<code class="rsv-path">{{ item.image_path }}</code>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Right: JSON comparison -->
|
||||
<section class="rsv-json-panel" aria-label="Extraction review">
|
||||
|
||||
<!-- Ground truth (read-only reference) -->
|
||||
<div class="rsv-json-block">
|
||||
<h2 class="rsv-json-label">Ground truth <span class="label-tag">reference</span></h2>
|
||||
<pre class="rsv-json rsv-json--ground-truth" tabindex="0" aria-label="Ground truth JSON">{{ prettyJson(item.ground_truth) }}</pre>
|
||||
</div>
|
||||
|
||||
<!-- Extracted / editable -->
|
||||
<div class="rsv-json-block">
|
||||
<h2 class="rsv-json-label">
|
||||
Extracted
|
||||
<span class="label-tag label-tag--edit">edit before approving</span>
|
||||
</h2>
|
||||
<textarea
|
||||
v-model="draftJson"
|
||||
class="rsv-json rsv-json--edit"
|
||||
spellcheck="false"
|
||||
aria-label="Extracted JSON — edit to correct"
|
||||
:class="{ 'rsv-json--invalid': jsonError }"
|
||||
/>
|
||||
<p v-if="jsonError" class="rsv-json-error" role="alert">{{ jsonError }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Actions -->
|
||||
<div class="rsv-actions" role="group" aria-label="Review actions">
|
||||
<button
|
||||
class="btn-approve"
|
||||
:disabled="acting"
|
||||
@click="handleApprove"
|
||||
title="Extracted JSON is accurate — approve as-is (A)"
|
||||
>
|
||||
✓ Approve
|
||||
</button>
|
||||
<button
|
||||
class="btn-edit"
|
||||
:disabled="acting || !!jsonError"
|
||||
@click="handleEdit"
|
||||
title="Approve the edited JSON in the text area (E)"
|
||||
>
|
||||
✎ Approve edited
|
||||
</button>
|
||||
<button
|
||||
class="btn-reject"
|
||||
:disabled="acting"
|
||||
@click="handleReject"
|
||||
title="Extraction too broken to use — reject (R)"
|
||||
>
|
||||
✕ Reject
|
||||
</button>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
</div>
|
||||
|
||||
<!-- Feedback toast -->
|
||||
<Transition name="toast">
|
||||
<div v-if="toast" class="rsv-toast" role="status" aria-live="polite">
|
||||
{{ toast }}
|
||||
</div>
|
||||
</Transition>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const apiBase = window.location.origin
|
||||
|
||||
interface RecipeScanItem {
|
||||
id: string
|
||||
image_path: string
|
||||
modality: string
|
||||
source: string
|
||||
extracted: Record<string, unknown>
|
||||
ground_truth: Record<string, unknown>
|
||||
status: string
|
||||
}
|
||||
|
||||
interface Stats {
|
||||
total: number
|
||||
by_status: Record<string, number>
|
||||
by_modality: Record<string, number>
|
||||
export_ready: number
|
||||
}
|
||||
|
||||
const item = ref<RecipeScanItem | null>(null)
|
||||
const stats = ref<Stats | null>(null)
|
||||
const loading = ref(true)
|
||||
const acting = ref(false)
|
||||
const apiError = ref('')
|
||||
const draftJson = ref('')
|
||||
const toast = ref('')
|
||||
let toastTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const jsonError = computed(() => {
|
||||
if (!draftJson.value.trim()) return ''
|
||||
try {
|
||||
JSON.parse(draftJson.value)
|
||||
return ''
|
||||
} catch (e) {
|
||||
return 'Invalid JSON — fix before approving'
|
||||
}
|
||||
})
|
||||
|
||||
const imageUrl = computed(() => {
|
||||
if (!item.value) return ''
|
||||
const encoded = encodeURIComponent(item.value.image_path)
|
||||
return `${apiBase}/api/recipe-scan/image?path=${encoded}`
|
||||
})
|
||||
|
||||
function prettyJson(obj: unknown): string {
|
||||
return JSON.stringify(obj, null, 2)
|
||||
}
|
||||
|
||||
function showToast(msg: string) {
|
||||
toast.value = msg
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
toastTimer = setTimeout(() => { toast.value = '' }, 2500)
|
||||
}
|
||||
|
||||
async function fetchNext() {
|
||||
loading.value = true
|
||||
apiError.value = ''
|
||||
try {
|
||||
const r = await fetch(`${apiBase}/api/recipe-scan/next`)
|
||||
if (r.status === 404) {
|
||||
item.value = null
|
||||
} else if (!r.ok) {
|
||||
throw new Error(`API error ${r.status}`)
|
||||
} else {
|
||||
item.value = await r.json()
|
||||
draftJson.value = prettyJson(item.value!.extracted)
|
||||
}
|
||||
} catch (e) {
|
||||
apiError.value = e instanceof Error ? e.message : 'Could not reach API'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchStats() {
|
||||
try {
|
||||
const r = await fetch(`${apiBase}/api/recipe-scan/stats`)
|
||||
if (r.ok) stats.value = await r.json()
|
||||
} catch { /* non-critical */ }
|
||||
}
|
||||
|
||||
async function act(endpoint: string, body?: unknown) {
|
||||
if (!item.value || acting.value) return
|
||||
acting.value = true
|
||||
try {
|
||||
const r = await fetch(`${apiBase}/api/recipe-scan/items/${item.value.id}/${endpoint}`, {
|
||||
method: 'POST',
|
||||
headers: body ? { 'Content-Type': 'application/json' } : {},
|
||||
body: body ? JSON.stringify(body) : undefined,
|
||||
})
|
||||
if (!r.ok) throw new Error(`API error ${r.status}`)
|
||||
} catch (e) {
|
||||
showToast(e instanceof Error ? e.message : 'Action failed')
|
||||
acting.value = false
|
||||
return
|
||||
}
|
||||
acting.value = false
|
||||
await Promise.all([fetchNext(), fetchStats()])
|
||||
}
|
||||
|
||||
async function handleApprove() {
|
||||
showToast('Approved')
|
||||
await act('approve')
|
||||
}
|
||||
|
||||
async function handleEdit() {
|
||||
if (jsonError.value) return
|
||||
let corrected: unknown
|
||||
try {
|
||||
corrected = JSON.parse(draftJson.value)
|
||||
} catch {
|
||||
return
|
||||
}
|
||||
showToast('Saved edit')
|
||||
await act('edit', { corrected })
|
||||
}
|
||||
|
||||
async function handleReject() {
|
||||
showToast('Rejected')
|
||||
await act('reject')
|
||||
}
|
||||
|
||||
// Keyboard shortcuts: A = approve, E = edit+approve, R = reject
|
||||
function handleKey(e: KeyboardEvent) {
|
||||
const tag = (e.target as HTMLElement)?.tagName?.toLowerCase()
|
||||
if (tag === 'textarea' || tag === 'input') return
|
||||
if (e.key === 'a' || e.key === 'A') handleApprove()
|
||||
if (e.key === 'e' || e.key === 'E') handleEdit()
|
||||
if (e.key === 'r' || e.key === 'R') handleReject()
|
||||
}
|
||||
|
||||
watch(item, (newItem) => {
|
||||
if (newItem) draftJson.value = prettyJson(newItem.extracted)
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
fetchNext()
|
||||
fetchStats()
|
||||
window.addEventListener('keydown', handleKey)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
window.removeEventListener('keydown', handleKey)
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.rsv {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
padding: var(--space-md, 1rem);
|
||||
gap: var(--space-md, 1rem);
|
||||
box-sizing: border-box;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* Header */
|
||||
.rsv-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--space-md, 1rem);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.rsv-title {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 600;
|
||||
margin: 0;
|
||||
color: var(--color-text, #fff);
|
||||
}
|
||||
.rsv-stats {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.stat-chip {
|
||||
font-size: 0.75rem;
|
||||
padding: 2px 8px;
|
||||
border-radius: 12px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
color: var(--color-text-muted, #aaa);
|
||||
}
|
||||
.stat-chip--ok { background: #1a3a1a; color: #6fcf97; }
|
||||
.stat-chip--edited { background: #2a2a00; color: #f2c94c; }
|
||||
.stat-chip--bad { background: #3a1a1a; color: #eb5757; }
|
||||
.btn-export {
|
||||
font-size: 0.8rem;
|
||||
padding: 4px 12px;
|
||||
border-radius: 6px;
|
||||
background: var(--color-accent, #4a9eff);
|
||||
color: #fff;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* State panels */
|
||||
.rsv-state {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
color: var(--color-text-muted, #aaa);
|
||||
}
|
||||
.rsv-error { color: var(--color-danger, #eb5757); }
|
||||
.rsv-empty { font-size: 1rem; }
|
||||
.rsv-hint { font-size: 0.85rem; opacity: 0.7; margin: 0; }
|
||||
.skeleton-block {
|
||||
width: 100%; height: 300px;
|
||||
border-radius: 8px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
|
||||
|
||||
/* Workspace: two-column layout */
|
||||
.rsv-workspace {
|
||||
flex: 1;
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: var(--space-md, 1rem);
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
@media (max-width: 900px) {
|
||||
.rsv-workspace {
|
||||
grid-template-columns: 1fr;
|
||||
overflow-y: auto;
|
||||
}
|
||||
}
|
||||
|
||||
/* Image panel */
|
||||
.rsv-image-panel {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
min-height: 0;
|
||||
}
|
||||
.rsv-panel-label {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.modality-badge, .source-badge {
|
||||
font-size: 0.72rem;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
color: var(--color-text-muted, #aaa);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
.rsv-image-wrap {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: var(--color-surface-alt, #111);
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
min-height: 200px;
|
||||
}
|
||||
.rsv-image {
|
||||
max-width: 100%;
|
||||
max-height: 100%;
|
||||
object-fit: contain;
|
||||
}
|
||||
.rsv-image-placeholder {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
color: var(--color-text-muted, #666);
|
||||
font-size: 0.85rem;
|
||||
padding: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
.rsv-path {
|
||||
font-size: 0.7rem;
|
||||
word-break: break-all;
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
/* JSON panel */
|
||||
.rsv-json-panel {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
min-height: 0;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.rsv-json-block {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.25rem;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
.rsv-json-label {
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #aaa);
|
||||
margin: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.label-tag {
|
||||
font-size: 0.68rem;
|
||||
font-weight: 400;
|
||||
padding: 1px 6px;
|
||||
border-radius: 8px;
|
||||
background: var(--color-surface-alt, #2a2a2a);
|
||||
color: var(--color-text-muted, #888);
|
||||
}
|
||||
.label-tag--edit {
|
||||
background: #2a2a00;
|
||||
color: #f2c94c;
|
||||
}
|
||||
.rsv-json {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
line-height: 1.5;
|
||||
padding: 0.75rem;
|
||||
border-radius: 6px;
|
||||
min-height: 120px;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
resize: vertical;
|
||||
white-space: pre;
|
||||
}
|
||||
.rsv-json--ground-truth {
|
||||
background: var(--color-surface-alt, #111);
|
||||
color: var(--color-text, #ccc);
|
||||
border: 1px solid var(--color-border, #333);
|
||||
}
|
||||
.rsv-json--edit {
|
||||
background: var(--color-surface, #1a1a1a);
|
||||
color: var(--color-text, #e0e0e0);
|
||||
border: 1px solid var(--color-border, #444);
|
||||
caret-color: var(--color-accent, #4a9eff);
|
||||
outline: none;
|
||||
transition: border-color 0.15s;
|
||||
}
|
||||
.rsv-json--edit:focus {
|
||||
border-color: var(--color-accent, #4a9eff);
|
||||
}
|
||||
.rsv-json--invalid {
|
||||
border-color: var(--color-danger, #eb5757) !important;
|
||||
}
|
||||
.rsv-json-error {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-danger, #eb5757);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* Action buttons */
|
||||
.rsv-actions {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
padding-top: 0.25rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.btn-approve, .btn-edit, .btn-reject {
|
||||
flex: 1;
|
||||
min-width: 80px;
|
||||
padding: 0.5rem 0.75rem;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-approve, .btn-edit, .btn-reject {
|
||||
opacity: 1;
|
||||
}
|
||||
.btn-approve:disabled, .btn-edit:disabled, .btn-reject:disabled {
|
||||
opacity: 0.4;
|
||||
cursor: default;
|
||||
}
|
||||
.btn-approve { background: #1e6e1e; color: #6fcf97; }
|
||||
.btn-approve:hover:not(:disabled) { background: #256325; }
|
||||
.btn-edit { background: #4a4a00; color: #f2c94c; }
|
||||
.btn-edit:hover:not(:disabled) { background: #606000; }
|
||||
.btn-reject { background: #6e1e1e; color: #eb8f8f; }
|
||||
.btn-reject:hover:not(:disabled) { background: #7a2222; }
|
||||
|
||||
/* Toast */
|
||||
.rsv-toast {
|
||||
position: fixed;
|
||||
bottom: 1.5rem;
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
background: var(--color-surface, #222);
|
||||
color: var(--color-text, #fff);
|
||||
border: 1px solid var(--color-border, #444);
|
||||
border-radius: 8px;
|
||||
padding: 0.5rem 1.25rem;
|
||||
font-size: 0.85rem;
|
||||
box-shadow: 0 4px 20px rgba(0,0,0,0.4);
|
||||
pointer-events: none;
|
||||
z-index: 100;
|
||||
}
|
||||
.toast-enter-active, .toast-leave-active { transition: opacity 0.2s, transform 0.2s; }
|
||||
.toast-enter-from, .toast-leave-to { opacity: 0; transform: translateX(-50%) translateY(8px); }
|
||||
</style>
|
||||
161
web/src/views/TrainJobsView.test.ts
Normal file
161
web/src/views/TrainJobsView.test.ts
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import TrainJobsView from './TrainJobsView.vue'
|
||||
|
||||
const sampleJob = {
|
||||
id: 'job-abc123',
|
||||
type: 'classifier',
|
||||
model_key: 'deberta-v3-small',
|
||||
status: 'queued',
|
||||
created_at: '2026-05-01T10:00:00Z',
|
||||
config: null,
|
||||
}
|
||||
|
||||
function makeFetch(jobs: unknown[] = []) {
|
||||
return vi.fn().mockImplementation((url: string, opts?: RequestInit) => {
|
||||
if ((opts?.method ?? 'GET') === 'POST') {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ ...sampleJob, id: 'new-job', status: 'queued' }),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
if ((opts?.method ?? 'GET') === 'DELETE') {
|
||||
return Promise.resolve({ ok: true, json: async () => ({}), text: async () => '' })
|
||||
}
|
||||
// GET
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ jobs }),
|
||||
text: async () => '',
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
class MockEventSource {
|
||||
onmessage: ((e: MessageEvent) => void) | null = null
|
||||
onerror: ((e: Event) => void) | null = null
|
||||
private _url: string
|
||||
constructor(url: string) { this._url = url }
|
||||
close() {}
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', makeFetch([sampleJob]))
|
||||
vi.stubGlobal('EventSource', MockEventSource)
|
||||
})
|
||||
|
||||
describe('TrainJobsView', () => {
|
||||
it('renders page title "Training Jobs"', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Training Jobs')
|
||||
})
|
||||
|
||||
it('renders the new job form with type selector and model key input', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('select.job-type-select').exists()).toBe(true)
|
||||
expect(w.find('input.model-key-input').exists()).toBe(true)
|
||||
expect(w.find('button.submit-job-btn').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('type selector has classifier and llm-sft options', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
const options = w.findAll('select.job-type-select option')
|
||||
const values = options.map(o => o.attributes('value') ?? o.element.textContent)
|
||||
expect(values).toContain('classifier')
|
||||
expect(values).toContain('llm-sft')
|
||||
})
|
||||
|
||||
it('submit button is disabled when model key is empty', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
const btn = w.find('button.submit-job-btn')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('submit button is enabled when model key is entered', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('input.model-key-input').setValue('deberta-v3-small')
|
||||
const btn = w.find('button.submit-job-btn')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(false)
|
||||
})
|
||||
|
||||
it('shows job table with existing jobs', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('table.jobs-table').exists()).toBe(true)
|
||||
expect(w.text()).toContain('deberta-v3-small')
|
||||
})
|
||||
|
||||
it('shows status pill for each job', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('.status-pill').exists()).toBe(true)
|
||||
expect(w.find('.status-queued').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows cancel button for queued/running jobs', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('button.cancel-btn').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('submitting new job calls POST /api/train/jobs and refreshes', async () => {
|
||||
const fetchMock = makeFetch([])
|
||||
vi.stubGlobal('fetch', fetchMock)
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('input.model-key-input').setValue('my-model')
|
||||
await w.find('button.submit-job-btn').trigger('click')
|
||||
await flushPromises()
|
||||
const calls = (fetchMock as ReturnType<typeof vi.fn>).mock.calls as [string, RequestInit?][]
|
||||
const postCall = calls.find(([, opts]) => (opts?.method ?? 'GET') === 'POST')
|
||||
expect(postCall).toBeDefined()
|
||||
expect(postCall![0]).toContain('/api/train/jobs')
|
||||
})
|
||||
|
||||
it('shows View Log button for running jobs', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch([{ ...sampleJob, status: 'running' }]))
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('button.view-log-btn').exists()).toBe(true)
|
||||
})
|
||||
it('shows error when config JSON is invalid', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('input.model-key-input').setValue('my-model')
|
||||
await w.find('textarea.config-textarea').setValue('{ not valid json }')
|
||||
await w.find('button.submit-job-btn').trigger('click')
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
expect(w.find('.error-notice').text()).toContain('not valid')
|
||||
})
|
||||
|
||||
it('shows error notice when jobs load fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
json: async () => ({}),
|
||||
text: async () => '',
|
||||
}))
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
expect(w.find('table.jobs-table').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('cancel button optimistically updates job status to cancelled', async () => {
|
||||
const w = mount(TrainJobsView)
|
||||
await flushPromises()
|
||||
await w.find('button.cancel-btn').trigger('click')
|
||||
await flushPromises()
|
||||
// After cancel, job should show status-cancelled pill (not status-queued)
|
||||
expect(w.find('.status-cancelled').exists()).toBe(true)
|
||||
expect(w.find('.status-queued').exists()).toBe(false)
|
||||
})
|
||||
|
||||
})
|
||||
593
web/src/views/TrainJobsView.vue
Normal file
593
web/src/views/TrainJobsView.vue
Normal file
|
|
@ -0,0 +1,593 @@
|
|||
<template>
|
||||
<div class="train-jobs-view">
|
||||
<header class="view-header">
|
||||
<h1 class="page-title">🧠 Training Jobs</h1>
|
||||
</header>
|
||||
|
||||
<!-- New Job form -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">New Job</h2>
|
||||
<form class="new-job-form" @submit.prevent="submitJob">
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="job-type">Type</label>
|
||||
<select
|
||||
id="job-type"
|
||||
v-model="form.type"
|
||||
class="job-type-select form-control"
|
||||
>
|
||||
<option value="classifier">classifier</option>
|
||||
<option value="llm-sft">llm-sft</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="model-key">Model key</label>
|
||||
<input
|
||||
id="model-key"
|
||||
v-model.trim="form.model_key"
|
||||
type="text"
|
||||
class="model-key-input form-control"
|
||||
placeholder="e.g. microsoft/deberta-v3-small"
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<label class="form-label" for="job-config">Config JSON <span class="form-hint">(optional)</span></label>
|
||||
<textarea
|
||||
id="job-config"
|
||||
v-model="form.config_raw"
|
||||
class="config-textarea form-control"
|
||||
rows="4"
|
||||
placeholder='{"learning_rate": 2e-5}'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="submitError" class="error-notice" role="alert">{{ submitError }}</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
class="submit-job-btn btn-primary"
|
||||
:disabled="submitting || !form.model_key"
|
||||
@click.prevent="submitJob"
|
||||
>
|
||||
{{ submitting ? 'Queuing…' : 'Queue Job' }}
|
||||
</button>
|
||||
</form>
|
||||
</section>
|
||||
|
||||
<!-- Job queue table -->
|
||||
<section class="section">
|
||||
<h2 class="section-title">Job Queue</h2>
|
||||
|
||||
<div v-if="loadError" class="error-notice" role="alert">
|
||||
{{ loadError }}
|
||||
<button class="btn-retry" @click="loadJobs">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-else-if="jobs.length === 0" class="empty-notice">
|
||||
No training jobs yet.
|
||||
</div>
|
||||
|
||||
<div v-else class="jobs-table-wrap">
|
||||
<table class="jobs-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>Type</th>
|
||||
<th>Model</th>
|
||||
<th>Status</th>
|
||||
<th>Created</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="job in jobs" :key="job.id">
|
||||
<td class="td-id" :title="job.id">{{ job.id.slice(0, 8) }}</td>
|
||||
<td>
|
||||
<span class="type-chip">{{ job.type }}</span>
|
||||
</td>
|
||||
<td class="td-model">{{ job.model_key }}</td>
|
||||
<td>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</td>
|
||||
<td class="td-date">{{ formatDate(job.created_at) }}</td>
|
||||
<td class="td-actions">
|
||||
<button
|
||||
v-if="job.status === 'running'"
|
||||
class="view-log-btn btn-sm"
|
||||
@click="openLog(job.id)"
|
||||
>
|
||||
View Log
|
||||
</button>
|
||||
<button
|
||||
v-if="job.status === 'queued' || job.status === 'running'"
|
||||
class="cancel-btn btn-sm btn-danger-sm"
|
||||
:disabled="cancellingId === job.id"
|
||||
@click="cancelJob(job.id)"
|
||||
>
|
||||
{{ cancellingId === job.id ? '…' : 'Cancel' }}
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div v-if="cancelError" class="error-notice" role="alert">{{ cancelError }}</div>
|
||||
</section>
|
||||
|
||||
<!-- Log panel (SSE) -->
|
||||
<section v-if="logJobId" class="section log-section">
|
||||
<div class="log-header">
|
||||
<h2 class="section-title">Log — {{ logJobId.slice(0, 8) }}</h2>
|
||||
<button class="btn-close-log" @click="closeLog">✕ Close</button>
|
||||
</div>
|
||||
<div class="log-panel" ref="logPanelEl">
|
||||
<div
|
||||
v-for="(line, i) in logLines"
|
||||
:key="i"
|
||||
class="log-line"
|
||||
>{{ line }}</div>
|
||||
<div v-if="logLines.length === 0" class="log-line log-muted">Connecting…</div>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, nextTick, onUnmounted } from 'vue'
|
||||
import { useApiSSE } from '../composables/useApi'
|
||||
|
||||
interface TrainJob {
|
||||
id: string
|
||||
type: 'classifier' | 'llm-sft'
|
||||
model_key: string
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
created_at: string
|
||||
config: Record<string, unknown> | null
|
||||
}
|
||||
|
||||
const jobs = ref<TrainJob[]>([])
|
||||
const loadError = ref<string | null>(null)
|
||||
const submitError = ref<string | null>(null)
|
||||
const submitting = ref(false)
|
||||
const cancellingId = ref<string | null>(null)
|
||||
const cancelError = ref<string | null>(null)
|
||||
|
||||
const form = ref({
|
||||
type: 'classifier' as 'classifier' | 'llm-sft',
|
||||
model_key: '',
|
||||
config_raw: '',
|
||||
})
|
||||
|
||||
// ── Log panel state ──
|
||||
const logJobId = ref<string | null>(null)
|
||||
const logLines = ref<string[]>([])
|
||||
const logPanelEl = ref<HTMLElement | null>(null)
|
||||
let closeSSE: (() => void) | null = null
|
||||
|
||||
// ── Data loading ──
|
||||
|
||||
async function loadJobs() {
|
||||
loadError.value = null
|
||||
try {
|
||||
const res = await fetch('/api/train/jobs')
|
||||
if (!res.ok) { loadError.value = `Failed to load jobs (HTTP ${res.status}).`; return }
|
||||
const data = await res.json() as { jobs: TrainJob[] }
|
||||
jobs.value = data.jobs ?? []
|
||||
} catch {
|
||||
loadError.value = 'Network error loading jobs.'
|
||||
}
|
||||
}
|
||||
|
||||
// ── Submit ──
|
||||
|
||||
async function submitJob() {
|
||||
if (!form.value.model_key) return
|
||||
submitError.value = null
|
||||
submitting.value = true
|
||||
|
||||
let config: Record<string, unknown> | null = null
|
||||
if (form.value.config_raw.trim()) {
|
||||
try {
|
||||
config = JSON.parse(form.value.config_raw) as Record<string, unknown>
|
||||
} catch {
|
||||
submitError.value = 'Config JSON is not valid. Fix it before submitting.'
|
||||
submitting.value = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/train/jobs', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
type: form.value.type,
|
||||
model_key: form.value.model_key,
|
||||
config_json: config,
|
||||
}),
|
||||
})
|
||||
if (!res.ok) {
|
||||
const detail = await res.text().catch(() => '')
|
||||
submitError.value = `Failed to queue job (HTTP ${res.status})${detail ? `: ${detail}` : '.'}`
|
||||
return
|
||||
}
|
||||
const newJob = await res.json() as TrainJob
|
||||
jobs.value = [newJob, ...jobs.value]
|
||||
form.value = { type: 'classifier', model_key: '', config_raw: '' }
|
||||
} catch {
|
||||
submitError.value = 'Network error submitting job.'
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Cancel ──
|
||||
|
||||
async function cancelJob(id: string) {
|
||||
cancellingId.value = id
|
||||
cancelError.value = null
|
||||
try {
|
||||
const res = await fetch(`/api/train/jobs/${encodeURIComponent(id)}/cancel`, { method: 'DELETE' })
|
||||
if (res.ok) {
|
||||
jobs.value = jobs.value.map(j =>
|
||||
j.id === id ? { ...j, status: 'cancelled' as const } : j
|
||||
)
|
||||
} else {
|
||||
cancelError.value = `Failed to cancel job (HTTP ${res.status}).`
|
||||
}
|
||||
} catch {
|
||||
cancelError.value = 'Network error cancelling job.'
|
||||
} finally {
|
||||
cancellingId.value = null
|
||||
}
|
||||
}
|
||||
|
||||
// ── Log SSE ──
|
||||
|
||||
function openLog(id: string) {
|
||||
closeLog()
|
||||
logJobId.value = id
|
||||
logLines.value = []
|
||||
|
||||
closeSSE = useApiSSE(
|
||||
`/api/train/jobs/${encodeURIComponent(id)}/run`,
|
||||
(data) => {
|
||||
if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
|
||||
logLines.value = [...logLines.value, String(data.message ?? '')]
|
||||
nextTick(() => {
|
||||
if (logPanelEl.value) {
|
||||
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
|
||||
}
|
||||
})
|
||||
}
|
||||
if (data.type === 'error') {
|
||||
logLines.value = [...logLines.value, '--- stream ended with error ---']
|
||||
nextTick(() => {
|
||||
if (logPanelEl.value) {
|
||||
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
|
||||
}
|
||||
})
|
||||
}
|
||||
},
|
||||
() => {
|
||||
logLines.value = [...logLines.value, '--- stream complete ---']
|
||||
},
|
||||
() => {
|
||||
logLines.value = [...logLines.value, '--- connection lost ---']
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
function closeLog() {
|
||||
closeSSE?.()
|
||||
closeSSE = null
|
||||
logJobId.value = null
|
||||
logLines.value = []
|
||||
}
|
||||
|
||||
// ── Helpers ──
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
const d = new Date(iso)
|
||||
if (isNaN(d.getTime())) return iso
|
||||
return d.toLocaleString(undefined, { dateStyle: 'short', timeStyle: 'short' })
|
||||
}
|
||||
|
||||
// ── Lifecycle ──
|
||||
|
||||
loadJobs()
|
||||
|
||||
onUnmounted(() => {
|
||||
closeSSE?.()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.train-jobs-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2rem;
|
||||
}
|
||||
|
||||
.view-header { display: flex; align-items: center; }
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
padding-bottom: 0.4rem;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.new-job-form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
max-width: 480px;
|
||||
}
|
||||
|
||||
.form-row {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.form-label {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.form-hint {
|
||||
font-weight: 400;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.form-control {
|
||||
padding: 0.45rem 0.65rem;
|
||||
border: 1px solid var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.9rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
}
|
||||
|
||||
.form-control:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: -1px;
|
||||
}
|
||||
|
||||
.config-textarea {
|
||||
resize: vertical;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
padding: 0.4rem 0.9rem;
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
align-self: flex-start;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-primary:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-sm {
|
||||
padding: 0.2rem 0.55rem;
|
||||
font-size: 0.78rem;
|
||||
border-radius: 0.3rem;
|
||||
cursor: pointer;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
border: 1px solid;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.view-log-btn {
|
||||
border-color: var(--color-info, #1e6091);
|
||||
background: transparent;
|
||||
color: var(--color-info, #1e6091);
|
||||
}
|
||||
|
||||
.view-log-btn:hover {
|
||||
background: color-mix(in srgb, var(--color-info, #1e6091) 10%, transparent);
|
||||
}
|
||||
|
||||
.btn-danger-sm {
|
||||
border-color: var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
}
|
||||
|
||||
.btn-danger-sm:hover:not(:disabled) {
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
}
|
||||
|
||||
.btn-danger-sm:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.btn-retry {
|
||||
margin-left: 0.5rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
cursor: pointer;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
padding: 0.6rem 0.8rem;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
color: var(--color-error, #c0392b);
|
||||
font-size: 0.88rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.empty-notice {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px dashed var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
}
|
||||
|
||||
.jobs-table-wrap { overflow-x: auto; }
|
||||
|
||||
.jobs-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.jobs-table th {
|
||||
text-align: left;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.jobs-table td {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.td-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.td-model {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.td-date {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.td-actions {
|
||||
display: flex;
|
||||
gap: 0.35rem;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.68rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: var(--radius-full, 9999px);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.status-queued { background: var(--color-surface-alt, #dde4f0); color: var(--color-text-muted, #4a5c7a); }
|
||||
.status-running { background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent); color: var(--color-info, #1e6091); }
|
||||
.status-completed { background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent); color: var(--color-success, #3a7a32); }
|
||||
.status-failed { background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent); color: var(--color-error, #c0392b); }
|
||||
.status-cancelled { background: color-mix(in srgb, var(--color-warning, #d4891a) 15%, transparent); color: var(--color-warning, #d4891a); }
|
||||
|
||||
.type-chip {
|
||||
font-size: 0.72rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 0.25rem;
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.log-section { gap: 0.5rem; }
|
||||
|
||||
.log-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.btn-close-log {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.8rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.btn-close-log:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
|
||||
.log-panel {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
max-height: 320px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.log-line {
|
||||
color: var(--color-text, #1a2338);
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.log-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
@media (max-width: 560px) {
|
||||
.jobs-table th:nth-child(4),
|
||||
.jobs-table td:nth-child(4),
|
||||
.jobs-table th:nth-child(5),
|
||||
.jobs-table td:nth-child(5) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
101
web/src/views/TrainResultsView.test.ts
Normal file
101
web/src/views/TrainResultsView.test.ts
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import TrainResultsView from './TrainResultsView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
const sampleResult = {
|
||||
id: 'run-xyz',
|
||||
job_id: 'job-abc123',
|
||||
model_type: 'classifier',
|
||||
base_model: 'microsoft/deberta-v3-small',
|
||||
val_macro_f1: 0.847,
|
||||
val_accuracy: 0.891,
|
||||
sample_count: 1240,
|
||||
duration_seconds: 842,
|
||||
created_at: '2026-05-01T11:30:00Z',
|
||||
}
|
||||
|
||||
function makeFetch(results: unknown[] = []) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ results }),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', makeFetch([sampleResult]))
|
||||
})
|
||||
|
||||
describe('TrainResultsView', () => {
|
||||
it('renders page title "Training Results"', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Training Results')
|
||||
})
|
||||
|
||||
it('shows empty notice when there are no results', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch([]))
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.empty-notice').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('renders results table when results exist', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('table.results-table').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows base_model in table', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('deberta-v3-small')
|
||||
})
|
||||
|
||||
it('shows val_macro_f1 formatted as percentage', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('84.7%')
|
||||
})
|
||||
|
||||
it('shows val_accuracy formatted as percentage', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('89.1%')
|
||||
})
|
||||
|
||||
it('shows duration formatted as minutes and seconds', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
// 842 seconds = 14m 2s
|
||||
expect(w.text()).toContain('14m')
|
||||
})
|
||||
|
||||
it('shows Register in Fleet button for classifier results', async () => {
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('a.register-btn').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('does NOT show Register in Fleet button for llm-sft results', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch([{ ...sampleResult, model_type: 'llm-sft' }]))
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('a.register-btn').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('shows error notice when API fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 500, text: async () => '' }))
|
||||
const w = mount(TrainResultsView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
296
web/src/views/TrainResultsView.vue
Normal file
296
web/src/views/TrainResultsView.vue
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
<template>
|
||||
<div class="train-results-view">
|
||||
<header class="view-header">
|
||||
<h1 class="page-title">Training Results</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="loadResults" aria-label="Refresh">🔄</button>
|
||||
</header>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="loadResults">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="loading" class="loading-state" aria-live="polite">Loading…</div>
|
||||
|
||||
<div v-if="!error && results.length === 0 && !loading" class="empty-notice">
|
||||
No training results yet. Completed jobs will appear here.
|
||||
</div>
|
||||
|
||||
<div v-if="results.length > 0" class="results-table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Run</th>
|
||||
<th>Type</th>
|
||||
<th>Base Model</th>
|
||||
<th class="th-numeric">Macro F1</th>
|
||||
<th class="th-numeric">Accuracy</th>
|
||||
<th class="th-numeric">Samples</th>
|
||||
<th class="th-numeric">Duration</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="r in results" :key="r.id">
|
||||
<td class="td-id" :title="r.id">{{ r.id.slice(0, 8) }}</td>
|
||||
<td>
|
||||
<span class="type-chip">{{ r.model_type }}</span>
|
||||
</td>
|
||||
<td class="td-model" :title="r.base_model">{{ shortModel(r.base_model) }}</td>
|
||||
<td class="td-numeric">
|
||||
<span class="metric-val" :class="scoreClass(r.val_macro_f1)">
|
||||
{{ formatPct(r.val_macro_f1) }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="td-numeric">{{ formatPct(r.val_accuracy) }}</td>
|
||||
<td class="td-numeric">{{ r.sample_count.toLocaleString() }}</td>
|
||||
<td class="td-numeric">{{ formatDuration(r.duration_seconds) }}</td>
|
||||
<td class="td-actions">
|
||||
<RouterLink
|
||||
v-if="r.model_type === 'classifier'"
|
||||
:to="`/fleet?model=${encodeURIComponent(r.base_model)}`"
|
||||
class="register-btn btn-sm-link"
|
||||
>
|
||||
Register in Fleet
|
||||
</RouterLink>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
interface TrainResult {
|
||||
id: string
|
||||
job_id: string
|
||||
model_type: string
|
||||
base_model: string
|
||||
val_macro_f1: number | null
|
||||
val_accuracy: number | null
|
||||
sample_count: number
|
||||
duration_seconds: number | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
const results = ref<TrainResult[]>([])
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
async function loadResults() {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const res = await fetch('/api/train/results')
|
||||
if (!res.ok) {
|
||||
error.value = `Failed to load results (HTTP ${res.status}).`
|
||||
return
|
||||
}
|
||||
const raw = await res.json() as { results?: TrainResult[] }
|
||||
results.value = Array.isArray(raw?.results) ? raw.results : []
|
||||
} catch {
|
||||
error.value = 'Network error loading results.'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function formatPct(v: number | null | undefined): string {
|
||||
if (v == null) return '—'
|
||||
return `${(v * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
function formatDuration(seconds: number | null | undefined): string {
|
||||
if (seconds == null) return '—'
|
||||
const mins = Math.floor(seconds / 60)
|
||||
const secs = seconds % 60
|
||||
if (mins === 0) return `${secs}s`
|
||||
return `${mins}m ${secs}s`
|
||||
}
|
||||
|
||||
function shortModel(model: string): string {
|
||||
const parts = model.split('/')
|
||||
return parts[parts.length - 1] ?? model
|
||||
}
|
||||
|
||||
function scoreClass(f1: number | null | undefined): string {
|
||||
if (f1 == null) return ''
|
||||
if (f1 >= 0.85) return 'score-great'
|
||||
if (f1 >= 0.75) return 'score-good'
|
||||
return 'score-fair'
|
||||
}
|
||||
|
||||
onMounted(() => loadResults())
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.train-results-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.view-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
.error-notice {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 0.75rem 1rem;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
|
||||
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
color: var(--color-error, #c0392b);
|
||||
font-size: 0.88rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-error, #c0392b);
|
||||
background: transparent;
|
||||
color: var(--color-error, #c0392b);
|
||||
cursor: pointer;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.empty-notice {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px dashed var(--color-border, #a8b8d0);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
}
|
||||
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
padding: 0.75rem;
|
||||
}
|
||||
|
||||
.results-table-wrap { overflow-x: auto; }
|
||||
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.results-table th {
|
||||
text-align: left;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
border-bottom: 1px solid var(--color-border, #a8b8d0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.th-numeric { text-align: right; }
|
||||
|
||||
.results-table td {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.td-id {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.td-model {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.82rem;
|
||||
max-width: 16ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.td-numeric {
|
||||
text-align: right;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.td-actions { text-align: right; }
|
||||
|
||||
.metric-val { font-weight: 600; }
|
||||
.score-great { color: var(--color-success, #3a7a32); }
|
||||
.score-good { color: var(--color-warning, #d4891a); }
|
||||
.score-fair { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.type-chip {
|
||||
font-size: 0.72rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 0.25rem;
|
||||
background: var(--color-surface-alt, #dde4f0);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.btn-sm-link {
|
||||
font-size: 0.78rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
border-radius: 0.3rem;
|
||||
border: 1px solid var(--app-primary, #2A6080);
|
||||
color: var(--app-primary, #2A6080);
|
||||
background: transparent;
|
||||
text-decoration: none;
|
||||
white-space: nowrap;
|
||||
display: inline-block;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
|
||||
.btn-sm-link:hover {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.results-table th:nth-child(6),
|
||||
.results-table td:nth-child(6),
|
||||
.results-table th:nth-child(7),
|
||||
.results-table td:nth-child(7) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
Loading…
Reference in a new issue