feat: Vue 3 label tab — complete card-stack UI with ASMR bucket UX #1
115 changed files with 6451 additions and 32181 deletions
23
.env.example
@@ -1,23 +0,0 @@
# Avocet — environment variable configuration
# Copy to .env and fill in values. All keys are optional.
# label_tool.yaml takes precedence over env vars where both exist.

# ── Local inference (Ollama) ───────────────────────────────────────────────────
# OLLAMA_HOST defaults to http://localhost:11434 if unset.
OLLAMA_HOST=http://localhost:11434
OLLAMA_MODEL=llama3.2:3b

# ── cf-orch coordinator (paid/premium tiers) ───────────────────────────────────
# Required for multi-GPU LLM benchmarking via the cf-orch benchmark harness.
# Free-tier users can leave these unset and use Ollama only.
CF_ORCH_URL=http://localhost:7700
CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx

# ── Cloud LLM backends (optional — paid/premium) ──────────────────────────────
# Set one of these to use a cloud LLM instead of a local model.
# ANTHROPIC_API_KEY=sk-ant-...
# OPENAI_API_KEY=sk-...

# ── HuggingFace (required for gated/terms-restricted model downloads) ─────────
# Generate at https://huggingface.co/settings/tokens and accept model terms first.
# HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
13
.gitignore
vendored
@@ -8,22 +8,9 @@ __pycache__/
config/label_tool.yaml

# Data files (user-generated, not for version control)
data/corpus.db
data/corpus.db-wal
data/corpus.db-shm
data/email_score.jsonl
data/email_label_queue.jsonl
data/email_compare_sample.jsonl
data/sft_candidates.jsonl
data/sft_approved.jsonl

# Conda/pip artifacts
.env

# Claude context — BSL 1.1, keep out of version control
CLAUDE.md
docs/superpowers/
.superpowers/

# Git worktrees
.worktrees/
173
CLAUDE.md
Normal file
@@ -0,0 +1,173 @@
# Avocet — Email Classifier Training Tool

## What it is

Shared infrastructure for building and benchmarking email classifiers across the CircuitForge menagerie.
Named for the avocet's sweeping-bill technique — it sweeps through email streams and filters out categories.

**Pipeline:**
```
Scrape (IMAP, wide search, multi-account) → data/email_label_queue.jsonl
↓
Label (card-stack UI) → data/email_score.jsonl
↓
Benchmark (HuggingFace NLI/reranker) → per-model macro-F1 + latency
```

## Environment

- Python env: `conda run -n job-seeker <cmd>` for basic use (streamlit, yaml, stdlib only)
- Classifier env: `conda run -n job-seeker-classifiers <cmd>` for benchmark (transformers, FlagEmbedding, gliclass)
- Run tests: `/devl/miniconda3/envs/job-seeker/bin/pytest tests/ -v`
  (direct binary — `conda run pytest` can spawn runaway processes)
- Create classifier env: `conda env create -f environment.yml`

## Label Tool (app/label_tool.py)

Card-stack Streamlit UI for manually labeling recruitment emails.

```
conda run -n job-seeker streamlit run app/label_tool.py --server.port 8503
```

- Config: `config/label_tool.yaml` (gitignored — copy from `.example`, or use ⚙️ Settings tab)
- Queue: `data/email_label_queue.jsonl` (gitignored)
- Output: `data/email_score.jsonl` (gitignored)
- Four tabs: 🃏 Label, 📥 Fetch, 📊 Stats, ⚙️ Settings
- Keyboard shortcuts: 1–9 = label, 0 = Other (wildcard, prompts free-text input), S = skip, U = undo
- Dedup: MD5 of `(subject + body[:100])` — cross-account safe
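
The dedup rule in the last bullet can be sketched as follows. This is a minimal illustration, assuming the key is built exactly as the bullet describes; the helper names `entry_key` and `dedup` are hypothetical and are not taken from `label_tool.py`.

```python
import hashlib


def entry_key(item: dict) -> str:
    """MD5 of subject + first 100 characters of the body (hypothetical helper)."""
    key = item.get("subject", "") + (item.get("body", "") or "")[:100]
    return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()


def dedup(items: list[dict]) -> list[dict]:
    """Keep the first occurrence of each key, preserving input order."""
    seen: set[str] = set()
    unique: list[dict] = []
    for item in items:
        k = entry_key(item)
        if k not in seen:
            seen.add(k)
            unique.append(item)
    return unique


# The same message fetched through two accounts collapses to one entry,
# because the account name is not part of the key.
batch = [
    {"subject": "Interview invite", "body": "Hi, can you do Tuesday?", "account": "work"},
    {"subject": "Interview invite", "body": "Hi, can you do Tuesday?", "account": "personal"},
    {"subject": "Your application", "body": "We received it.", "account": "work"},
]
unique = dedup(batch)
print(len(unique))  # 2
```

Because only the subject and the first 100 body characters feed the hash, two different emails that share both would also collapse; that trade-off is inherent in the rule as stated.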

### Settings Tab (⚙️)
- Add / edit / remove IMAP accounts via form UI — no manual YAML editing required
- Per-account fields: display name, host, port, SSL toggle, username, password (masked), folder, days back
- **🔌 Test connection** button per account — connects, logs in, selects folder, reports message count
- Global: max emails per account per fetch
- **💾 Save** writes `config/label_tool.yaml`; **↩ Reload** discards unsaved changes
- `_sync_settings_to_state()` collects widget values before any add/remove to avoid index-key drift

## Benchmark (scripts/benchmark_classifier.py)

```
# List available models
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models

# Score against labeled JSONL
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score

# Visual comparison on live IMAP emails
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20

# Include slow/large models
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow

# Export DB-labeled emails (⚠️ LLM-generated labels — review first)
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --export-db --db /path/to/staging.db
```

## Labels (peregrine defaults — configurable per product)

| Label | Key | Meaning |
|-------|-----|---------|
| `interview_scheduled` | 1 | Phone screen, video call, or on-site invitation |
| `offer_received` | 2 | Formal job offer or offer letter |
| `rejected` | 3 | Application declined or not moving forward |
| `positive_response` | 4 | Recruiter interest or request to connect |
| `survey_received` | 5 | Culture-fit survey or assessment invitation |
| `neutral` | 6 | ATS confirmation (application received, etc.) |
| `event_rescheduled` | 7 | Interview or event moved to a new time |
| `digest` | 8 | Job digest or multi-listing email (scrapeable) |
| `new_lead` | 9 | Unsolicited recruiter outreach or cold contact |
| `hired` | h | Offer accepted, onboarding, welcome email, start date |

## Model Registry (13 models, 7 defaults)

See `scripts/benchmark_classifier.py:MODEL_REGISTRY`.
Default models run without `--include-slow`.
Add `--models deberta-small deberta-small-2pass` to test a specific subset.

## Config Files

- `config/label_tool.yaml` — gitignored; multi-account IMAP config
- `config/label_tool.yaml.example` — committed template

## Data Files

- `data/email_score.jsonl` — gitignored; manually-labeled ground truth
- `data/email_score.jsonl.example` — committed sample for CI
- `data/email_label_queue.jsonl` — gitignored; IMAP fetch queue

## Key Design Notes

- `ZeroShotAdapter.load()` instantiates the pipeline object; `classify()` calls the object.
  Tests patch `scripts.classifier_adapters.pipeline` (the module-level factory) with a
  two-level mock: `mock_factory.return_value = MagicMock(return_value={...})`.
- `two_pass=True` on ZeroShotAdapter: first pass ranks all 6 labels; second pass re-runs
  with only top-2, forcing a binary choice. 2× cost, better confidence.
- `--compare` uses the first account in `label_tool.yaml` for live IMAP emails.
- DB export labels are llama3.1:8b-generated — treat as noisy, not gold truth.
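
The two-pass idea from the notes above can be sketched independently of any model. This is an illustration only, not the `ZeroShotAdapter` code: it assumes the classifier is a callable returning a dict with `labels` and `scores` sorted by descending score (the shape the HuggingFace zero-shot pipeline returns), and `fake_classifier` is a hypothetical stand-in with made-up weights.

```python
from typing import Callable, Sequence

# Assumed classifier shape: (text, candidate_labels) -> {"labels": [...], "scores": [...]}
Classifier = Callable[[str, Sequence[str]], dict]


def two_pass_classify(classifier: Classifier, text: str, labels: Sequence[str]):
    """Rank all labels, then re-run with only the top two candidates."""
    first = classifier(text, list(labels))
    top_two = first["labels"][:2]
    second = classifier(text, top_two)  # second call doubles the inference cost
    return second["labels"][0], second["scores"][0]


def fake_classifier(text: str, candidate_labels: Sequence[str]) -> dict:
    """Hypothetical stand-in: fixed weights, normalised over the candidates given."""
    weights = {"rejected": 0.5, "neutral": 0.3, "interview_scheduled": 0.2}
    raw = [weights.get(label, 0.05) for label in candidate_labels]
    total = sum(raw)
    ranked = sorted(zip(candidate_labels, raw), key=lambda pair: pair[1], reverse=True)
    return {
        "labels": [label for label, _ in ranked],
        "scores": [score / total for _, score in ranked],
    }


all_labels = ["interview_scheduled", "rejected", "neutral", "digest"]
text = "We will not be moving forward with your application."
single_pass = fake_classifier(text, all_labels)
label, score = two_pass_classify(fake_classifier, text, all_labels)
```

With the stand-in, the winning label is the same in both passes, but its score rises in the second pass because the probability mass is shared between only two candidates. That is the "better confidence" effect the note describes, bought with a second model call.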

## Vue Label UI (app/api.py + web/)

FastAPI on port 8503 serves both the REST API and the built Vue SPA (`web/dist/`).

```
./manage.sh start-api # build Vue SPA + start FastAPI (binds 0.0.0.0:8503 — LAN accessible)
./manage.sh stop-api
./manage.sh open-api # xdg-open http://localhost:8503
```

Logs: `log/api.log`

## Email Field Schema — IMPORTANT

Two schemas exist. The normalization layer in `app/api.py` bridges them automatically.

### JSONL on-disk schema (written by `label_tool.py`, including its IMAP fetch)

| Field | Type | Notes |
|-------|------|-------|
| `subject` | str | Email subject line |
| `body` | str | Plain-text body, truncated at 800 chars; HTML stripped by `_strip_html()` |
| `from_addr` | str | Sender address string (`"Name <addr>"`) |
| `date` | str | Raw RFC 2822 date string |
| `account` | str | Display name of the IMAP account that fetched it |
| *(no `id`)* | — | Dedup key is MD5 of `(subject + body[:100])` — never stored on disk |

### Vue API schema (returned by `GET /api/queue`, required by POST endpoints)

| Field | Type | Notes |
|-------|------|-------|
| `id` | str | MD5 content hash, or stored `id` if item has one |
| `subject` | str | Unchanged |
| `body` | str | Unchanged |
| `from` | str | Mapped from `from_addr` (or `from` if already present) |
| `date` | str | Unchanged |
| `source` | str | Mapped from `account` (or `source` if already present) |

### Normalization layer (`_normalize()` in `app/api.py`)

`_normalize(item)` handles the mapping and ID generation. All `GET /api/queue` responses
pass through it. Mutating endpoints (`/api/label`, `/api/skip`, `/api/discard`) look up
items via `_normalize(x)["id"]`, so both real data (no `id`, uses content hash) and test
fixtures (explicit `id` field) work transparently.
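
A short demonstration of the lookup behaviour described above. `_item_id` and `_normalize` are reproduced from the `app/api.py` changes in this diff; `find_in_queue` and the sample records are hypothetical, added only to show the two cases side by side.

```python
import hashlib


def _item_id(item: dict) -> str:
    """Content-hash ID: MD5 of subject + first 100 characters of the body."""
    key = item.get("subject", "") + (item.get("body", "") or "")[:100]
    return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()


def _normalize(item: dict) -> dict:
    """Map an on-disk JSONL item to the Vue API schema."""
    return {
        "id": item.get("id") or _item_id(item),
        "subject": item.get("subject", ""),
        "body": item.get("body", ""),
        "from": item.get("from") or item.get("from_addr", ""),
        "date": item.get("date", ""),
        "source": item.get("source") or item.get("account", ""),
    }


def find_in_queue(items: list[dict], wanted_id: str):
    """Hypothetical helper mirroring how the mutating endpoints locate an item."""
    return next((x for x in items if _normalize(x)["id"] == wanted_id), None)


# Real data: on-disk field names, no stored id.
real = {
    "subject": "Offer letter",
    "body": "We are pleased to offer you the role.",
    "from_addr": "HR <hr@example.com>",
    "date": "Mon, 01 Jan 2024 09:00:00 +0000",
    "account": "work",
}
# Test fixture: explicit id and API-style field names.
fixture = {"id": "test-1", "subject": "Hi", "body": "x",
           "from": "a@example.com", "source": "fixture"}

queue = [real, fixture]
api_view = [_normalize(x) for x in queue]
```

The lookup returns the original on-disk dict rather than the normalized view, which is what lets the endpoints append the untouched record to the score file.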

### Peregrine integration

Peregrine's `staging.db` uses different field names again:

| staging.db column | Maps to avocet JSONL field |
|-------------------|---------------------------|
| `subject` | `subject` |
| `body` | `body` (may contain HTML — run through `_strip_html()` before queuing) |
| `from_address` | `from_addr` |
| `received_date` | `date` |
| `account` or source context | `account` |

When exporting from Peregrine's DB for avocet labeling, transform to the JSONL schema above
(not the Vue API schema). The `--export-db` flag in `benchmark_classifier.py` does this.
Any new export path should also call `_strip_html()` on the body before writing.
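
A new export path following the mapping table might look like the sketch below. It is not the `--export-db` implementation: `strip_html` is a simplified, hypothetical stand-in for avocet's `_strip_html()` (whose real behaviour is not shown here), and applying the 800-character body limit at export time is an assumption based on the on-disk schema notes.

```python
import json
import re
from html import unescape

BODY_LIMIT = 800  # assumption: same truncation as the on-disk JSONL schema


def strip_html(text: str) -> str:
    """Hypothetical stand-in for _strip_html(): drop tags, collapse whitespace."""
    no_tags = re.sub(r"<[^>]+>", " ", text or "")
    return re.sub(r"\s+", " ", unescape(no_tags)).strip()


def staging_row_to_jsonl(row: dict) -> dict:
    """Map a staging.db row (as a dict) to the avocet JSONL on-disk schema."""
    return {
        "subject": row.get("subject", ""),
        "body": strip_html(row.get("body", ""))[:BODY_LIMIT],
        "from_addr": row.get("from_address", ""),
        "date": row.get("received_date", ""),
        "account": row.get("account", ""),
    }


row = {
    "subject": "Next steps",
    "body": "<p>Hello <b>there</b> &amp; welcome</p>",
    "from_address": "Recruiter <r@example.com>",
    "received_date": "Tue, 02 Jan 2024 10:30:00 +0000",
    "account": "work",
}
record = staging_row_to_jsonl(row)
line = json.dumps(record, ensure_ascii=False)  # one line of the queue file
```

Note that the output uses `from_addr` and `account`, not `from` and `source`; the API-side renaming is left to `_normalize()`.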

## Relationship to Peregrine

Avocet started as `peregrine/tools/label_tool.py` + `peregrine/scripts/classifier_adapters.py`.
Peregrine retains copies during stabilization; once avocet is proven, peregrine will import from here.
177
README.md
@@ -1,177 +0,0 @@
<div align="center">
<img src="docs/avocet-logo.svg" alt="Avocet" height="96" />

# Avocet

**Email classifier training tool — label, benchmark, fine-tune.**

[]()
[](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet/releases)
[](LICENSE)
[]()
[](https://circuitforge.tech)
</div>

---

## What is Avocet?

Avocet is the internal data pipeline Circuit Forge uses to build, evaluate, and fine-tune email classifiers. It implements a three-stage workflow: human labelers review emails one at a time in a drag-to-bucket UI and produce a ground-truth dataset; the benchmark harness scores any number of HuggingFace zero-shot models against that dataset and produces a ranked comparison; and the fine-tune harness adapts the best-scoring base model to the labeled distribution. The output feeds directly into Peregrine's email classification layer. No LLM API key required for the label tool or benchmark — all inference runs locally via HuggingFace Transformers.

---

## Quick Start

```bash
git clone https://git.opensourcesolarpunk.com/Circuit-Forge/avocet.git
cd avocet

# Copy config template and fill in your IMAP credentials
cp config/label_tool.yaml.example config/label_tool.yaml

# Start the label tool (Vue SPA + FastAPI, port 8503)
./manage.sh start
./manage.sh open
```

---

## Features

- **Drag-to-bucket label UI** — ASMR-style card interface; drag emails into labeled buckets or discard without queuing noise into the training set
- **Targeted IMAP fetch** — pull emails by date range, sender, or subject filter across multiple accounts without flooding the queue
- **Email classifier benchmark** — score any HuggingFace zero-shot model against your labeled JSONL; side-by-side comparison on live IMAP emails
- **Planning benchmark** — evaluate LLMs on structured planning tasks; compare models head-to-head with verbose diff output
- **Writing style benchmark** — compare Ollama models on writing style coherence; scan local disk for existing outputs
- **Fine-tune harness** — HuggingFace Transformers fine-tuning from labeled ground truth; classifier adapter interface for swapping backends at runtime
- **Local inference first** — no API key required; GPU optional; designed to run on developer hardware
- **Hot-reload dev mode** — uvicorn `--reload` + Vite HMR (hot module replacement) for fast iteration on both API and UI

---

## CLI Reference

All operations go through `manage.sh`.

### Label Tool

```bash
./manage.sh start     # Build Vue SPA and start FastAPI on port 8503
./manage.sh stop      # Stop FastAPI server
./manage.sh restart   # Stop, rebuild, and restart
./manage.sh status    # Show running state and port
./manage.sh logs      # Tail the API log
./manage.sh open      # Open http://localhost:8503 in browser
./manage.sh dev       # Hot-reload: uvicorn --reload + Vite HMR
./manage.sh test      # Run pytest suite
```

### Email Classifier Benchmark

```bash
./manage.sh benchmark [args]       # Run benchmark_classifier.py
./manage.sh list-models            # List available zero-shot models
./manage.sh score                  # Score models against labeled JSONL
./manage.sh score --include-slow   # Include large/slow models
./manage.sh compare --limit 30     # Side-by-side comparison on live IMAP emails
```

### Planning Benchmark

```bash
./manage.sh plans-bench [args]              # Run benchmark_plans.py
./manage.sh plans-list                      # List available models
./manage.sh plans-run <model> [args]        # Run a single model (verbose)
./manage.sh plans-compare <m1> <m2> [...]   # Compare models side-by-side
```

### Writing Style Benchmark

```bash
./manage.sh style-bench [args]   # Run benchmark_style.py
./manage.sh style-list           # List available Ollama models
./manage.sh style-run [args]     # Run writing style benchmark
./manage.sh style-last           # Print most recent benchmark report
```

---

## Data Flow

```
IMAP accounts
  → fetch (targeted or wide)
    → email_label_queue.jsonl

email_label_queue.jsonl
  → label tool drag-to-bucket UI
    → email_score.jsonl (ground truth)

email_score.jsonl
  → benchmark harness
    → model rankings

best model
  → fine-tune harness
    → Peregrine classifier adapter
```

---

## Labels

| Label | Key |
|-------|-----|
| `interview_scheduled` | 1 |
| `offer_received` | 2 |
| `rejected` | 3 |
| `positive_response` | 4 |
| `survey_received` | 5 |
| `neutral` | 6 |
| `event_rescheduled` | 7 |
| `unrelated` | 8 |
| `digest` | 9 |

---

## Stack

| Layer | Technology |
|-------|-----------|
| Label UI | Vue 3 SPA (Vite) |
| API | FastAPI + uvicorn (port 8503) |
| Benchmark | Python + HuggingFace Transformers |
| Email fetch | IMAP (multi-account, targeted date/sender/subject filter) |
| Data | JSONL (`data/email_label_queue.jsonl`, `data/email_score.jsonl`) |
| Runtime | SQLite |
| Config | `config/label_tool.yaml` (gitignored — `.example` committed) |

---

## Logo

The Avocet logo (`avocet_v1_poly.svg`) lives in the shared graphics repo. Copy it to `docs/avocet-logo.svg` to render correctly in this README.

---

## About

Avocet is internal CircuitForge infrastructure, open source as a reference implementation. It is not a user-facing product. The primary consumer is [Peregrine](https://git.opensourcesolarpunk.com/Circuit-Forge/peregrine), CircuitForge's job-search pipeline tool.

Docs: [docs.circuitforge.tech/avocet](https://docs.circuitforge.tech/avocet)

## Forgejo-primary

Avocet is developed and maintained on Forgejo at [git.opensourcesolarpunk.com/Circuit-Forge/avocet](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet). GitHub and Codeberg are read-only mirrors.

---

## License

[Business Source License 1.1](LICENSE) — classifier training is an AI feature under the CircuitForge licensing model.

Free for personal non-commercial self-hosting. Commercial use or SaaS re-hosting requires a paid license. Converts to MIT after 4 years.

Humans own design, architecture, code review, testing, and verification. LLMs are part of our development workflow. [Our positions on LLM use →](https://circuitforge.tech/positions)

© 2026 Circuit Forge LLC — Privacy · Safety · Accessibility
596
app/api.py
@@ -1,95 +1,565 @@
"""Avocet -- FastAPI app factory.
"""Avocet — FastAPI REST layer.

Mounts all domain routers and serves the Vue SPA.
All business logic lives in the domain modules below.
JSONL read/write helpers and FastAPI app instance.
Endpoints and static file serving are added in subsequent tasks.
"""
from __future__ import annotations

import hashlib
import json
import os
import subprocess as _subprocess
import yaml
from pathlib import Path

from fastapi import FastAPI
from datetime import datetime, timezone

from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel

_ROOT = Path(__file__).parent.parent
_DATA_DIR: Path = _ROOT / "data"  # overridable in tests via set_data_dir()
_MODELS_DIR: Path = _ROOT / "models"  # overridable in tests via set_models_dir()
_CONFIG_DIR: Path | None = None  # None = use real path

# Process registry for running jobs — used by cancel endpoints.
# Keys: "benchmark" | "finetune". Values: the live Popen object.
_running_procs: dict = {}
_cancelled_jobs: set = set()


def set_data_dir(path: Path) -> None:
    """Override data directory — used by tests."""
    global _DATA_DIR
    _DATA_DIR = path


def _best_cuda_device() -> str:
    """Return the index of the GPU with the most free VRAM as a string.

    Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
    if nvidia-smi is unavailable or no GPUs are found. Restricting the
    training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
    PyTorch DataParallel from replicating the model across all GPUs, which
    would OOM the GPU with less headroom.
    """
    try:
        out = _subprocess.check_output(
            ["nvidia-smi", "--query-gpu=index,memory.free",
             "--format=csv,noheader,nounits"],
            text=True,
            timeout=5,
        )
        best_idx, best_free = "", 0
        for line in out.strip().splitlines():
            parts = line.strip().split(", ")
            if len(parts) == 2:
                idx, free = parts[0].strip(), int(parts[1].strip())
                if free > best_free:
                    best_free, best_idx = free, idx
        return best_idx
    except Exception:
        return ""


def set_models_dir(path: Path) -> None:
    """Override models directory — used by tests."""
    global _MODELS_DIR
    _MODELS_DIR = path


def set_config_dir(path: Path | None) -> None:
    """Override config directory — used by tests."""
    global _CONFIG_DIR
    _CONFIG_DIR = path


def _config_file() -> Path:
    if _CONFIG_DIR is not None:
        return _CONFIG_DIR / "label_tool.yaml"
    return _ROOT / "config" / "label_tool.yaml"


def reset_last_action() -> None:
    """Reset undo state — used by tests."""
    global _last_action
    _last_action = None


def _queue_file() -> Path:
    return _DATA_DIR / "email_label_queue.jsonl"


def _score_file() -> Path:
    return _DATA_DIR / "email_score.jsonl"


def _discarded_file() -> Path:
    return _DATA_DIR / "discarded.jsonl"


def _read_jsonl(path: Path) -> list[dict]:
    if not path.exists():
        return []
    lines = path.read_text(encoding="utf-8").splitlines()
    return [json.loads(l) for l in lines if l.strip()]


def _write_jsonl(path: Path, records: list[dict]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    text = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
    path.write_text(text + "\n" if records else "", encoding="utf-8")


def _append_jsonl(path: Path, record: dict) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def _item_id(item: dict) -> str:
    """Stable content-hash ID — matches label_tool.py _entry_key dedup logic."""
    key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
    return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()


def _normalize(item: dict) -> dict:
    """Normalize JSONL item to the Vue frontend schema.

    label_tool.py stores: subject, body, from_addr, date, account (no id).
    The Vue app expects: id, subject, body, from, date, source.
    Both old (from_addr/account) and new (from/source) field names are handled.
    """
    return {
        "id": item.get("id") or _item_id(item),
        "subject": item.get("subject", ""),
        "body": item.get("body", ""),
        "from": item.get("from") or item.get("from_addr", ""),
        "date": item.get("date", ""),
        "source": item.get("source") or item.get("account", ""),
    }


app = FastAPI(title="Avocet API")

# -- Domain routers --------------------------------------------------------

from app.data.label import router as label_router
app.include_router(label_router, prefix="/api")

from app.data.fetch import router as fetch_router
app.include_router(fetch_router, prefix="/api")

from app.data.corrections import router as corrections_router
app.include_router(corrections_router, prefix="/api/corrections")

# Backward-compat alias -- remove when Vue SPA is updated to /api/corrections/*
app.include_router(corrections_router, prefix="/api/sft")

from app.data.imitate import router as imitate_router
app.include_router(imitate_router, prefix="/api/imitate")

from app.eval.cforch import router as eval_router
app.include_router(eval_router, prefix="/api")

from app.train.train import router as train_router
app.include_router(train_router, prefix="/api/train")

from app.plans_bench import router as plans_bench_router
app.include_router(plans_bench_router, prefix="/api/plans-bench")

# In-memory last-action store (single user, local tool — in-memory is fine)
_last_action: dict | None = None

# -- Backward-compat shims (ClassifierTab still uses old /api/finetune/* paths)
# Remove once ClassifierTab fine-tune section is migrated to TrainJobsView.

from fastapi import Query
from fastapi.responses import StreamingResponse as _StreamingResponse
@app.get("/api/queue")
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
    items = _read_jsonl(_queue_file())
    return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}


class LabelRequest(BaseModel):
    id: str
    label: str


@app.post("/api/label")
def post_label(req: LabelRequest):
    global _last_action
    items = _read_jsonl(_queue_file())
    match = next((x for x in items if _normalize(x)["id"] == req.id), None)
    if not match:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")
    record = {**match, "label": req.label,
              "labeled_at": datetime.now(timezone.utc).isoformat()}
    _append_jsonl(_score_file(), record)
    _write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
    _last_action = {"type": "label", "item": match, "label": req.label}
    return {"ok": True}


class SkipRequest(BaseModel):
    id: str


@app.post("/api/skip")
def post_skip(req: SkipRequest):
    global _last_action
    items = _read_jsonl(_queue_file())
    match = next((x for x in items if _normalize(x)["id"] == req.id), None)
    if not match:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")
    reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
    _write_jsonl(_queue_file(), reordered)
    _last_action = {"type": "skip", "item": match}
    return {"ok": True}


class DiscardRequest(BaseModel):
    id: str


@app.post("/api/discard")
def post_discard(req: DiscardRequest):
    global _last_action
    items = _read_jsonl(_queue_file())
    match = next((x for x in items if _normalize(x)["id"] == req.id), None)
    if not match:
        raise HTTPException(404, f"Item {req.id!r} not found in queue")
    record = {**match, "label": "__discarded__",
              "discarded_at": datetime.now(timezone.utc).isoformat()}
    _append_jsonl(_discarded_file(), record)
    _write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
    _last_action = {"type": "discard", "item": match}
    return {"ok": True}


@app.delete("/api/label/undo")
def delete_undo():
    global _last_action
    if not _last_action:
        raise HTTPException(404, "No action to undo")
    action = _last_action
    item = action["item"]  # always the original clean queue item

    # Perform file operations FIRST — only clear _last_action on success
    if action["type"] == "label":
        records = _read_jsonl(_score_file())
        if not records:
            raise HTTPException(409, "Score file is empty — cannot undo label")
        _write_jsonl(_score_file(), records[:-1])
        items = _read_jsonl(_queue_file())
        _write_jsonl(_queue_file(), [item] + items)
    elif action["type"] == "discard":
        records = _read_jsonl(_discarded_file())
        if not records:
            raise HTTPException(409, "Discarded file is empty — cannot undo discard")
        _write_jsonl(_discarded_file(), records[:-1])
        items = _read_jsonl(_queue_file())
        _write_jsonl(_queue_file(), [item] + items)
    elif action["type"] == "skip":
        items = _read_jsonl(_queue_file())
        item_id = _normalize(item)["id"]
        items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
        _write_jsonl(_queue_file(), items)

    # Clear AFTER all file operations succeed
    _last_action = None
    return {"undone": {"type": action["type"], "item": _normalize(item)}}


# Label metadata — 10 labels matching label_tool.py
_LABEL_META = [
    {"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
    {"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
    {"name": "rejected", "emoji": "\u274c", "color": "#F44336", "key": "3"},
    {"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
    {"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
    {"name": "neutral", "emoji": "\u2b1c", "color": "#607D8B", "key": "6"},
    {"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
    {"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
    {"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
    {"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
]


@app.get("/api/config/labels")
def get_labels():
    return _LABEL_META


@app.get("/api/config")
def get_config():
    f = _config_file()
    if not f.exists():
        return {"accounts": [], "max_per_account": 500}
    raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}


class ConfigPayload(BaseModel):
    accounts: list[dict]
    max_per_account: int = 500


@app.post("/api/config")
def post_config(payload: ConfigPayload):
    f = _config_file()
    f.parent.mkdir(parents=True, exist_ok=True)
    tmp = f.with_suffix(".tmp")
    tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
                   encoding="utf-8")
    tmp.rename(f)
    return {"ok": True}


@app.get("/api/stats")
def get_stats():
    records = _read_jsonl(_score_file())
    counts: dict[str, int] = {}
    for r in records:
        lbl = r.get("label", "")
        if lbl:
            counts[lbl] = counts.get(lbl, 0) + 1
    return {
        "total": len(records),
        "counts": counts,
        "score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
    }


@app.get("/api/stats/download")
def download_stats():
    from fastapi.responses import FileResponse
    if not _score_file().exists():
        raise HTTPException(404, "No score file")
    return FileResponse(
        str(_score_file()),
        filename="email_score.jsonl",
        media_type="application/jsonlines",
        headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
    )


class AccountTestRequest(BaseModel):
    account: dict


@app.post("/api/accounts/test")
def test_account(req: AccountTestRequest):
    from app.imap_fetch import test_connection
    ok, message, count = test_connection(req.account)
    return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/benchmark/results")
|
||||
def get_benchmark_results():
|
||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
||||
path = _DATA_DIR / "benchmark_results.json"
|
||||
if not path.exists():
|
||||
return {"models": {}, "sample_count": 0, "timestamp": None}
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.get("/api/benchmark/run")
|
||||
def run_benchmark(include_slow: bool = False):
|
||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
||||
cmd = [python_bin, script, "--score", "--save"]
|
||||
if include_slow:
|
||||
cmd.append("--include-slow")
|
||||
|
||||
def generate():
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_running_procs["benchmark"] = proc
|
||||
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "benchmark" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("benchmark")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("benchmark", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
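Reviewer note: the benchmark and finetune endpoints share one pattern: spawn a child process, forward each non-empty output line as a Server-Sent Events `data:` message, then emit a terminal event based on the exit code. The sketch below isolates that pattern without FastAPI. The names `sse_event` and `stream_process` are illustrative assumptions, and the cancellation branch is left out.

```python
import json
import subprocess
from typing import Iterator


def sse_event(payload: dict) -> str:
    """Format one SSE message: a 'data: <json>' line followed by a blank line."""
    return f"data: {json.dumps(payload)}\n\n"


def stream_process(cmd: list[str]) -> Iterator[str]:
    """Yield a progress event per non-empty output line, then one final event."""
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # fold stderr into the same stream
        text=True,
        bufsize=1,  # line-buffered on the reading side
    )
    assert proc.stdout is not None
    for line in proc.stdout:
        line = line.rstrip()
        if line:
            yield sse_event({"type": "progress", "message": line})
    proc.wait()
    if proc.returncode == 0:
        yield sse_event({"type": "complete"})
    else:
        yield sse_event(
            {"type": "error", "message": f"Process exited with code {proc.returncode}"}
        )
```

Passing `stream_process(cmd)` to a streaming response with media type `text/event-stream` would reproduce the shape of the endpoints here. The child may still buffer its own stdout unless it is run unbuffered.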
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Finetune endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/finetune/status")
|
||||
def get_finetune_status():
|
||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
results = []
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/finetune/run")
|
||||
def finetune_run_compat(model: str = Query(...), epochs: int = Query(5)) -> _StreamingResponse:
|
||||
"""Shim: create a classifier train job and immediately stream it."""
|
||||
from app.train.train import create_job, run_job, CreateJobRequest
|
||||
job = create_job(CreateJobRequest(type="classifier", model_key=model, config_json={"epochs": epochs}))
|
||||
return run_job(job["id"])
|
||||
def run_finetune_endpoint(
|
||||
model: str = "deberta-small",
|
||||
epochs: int = 5,
|
||||
score: list[str] = Query(default=[]),
|
||||
):
|
||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
||||
data_root = _DATA_DIR.resolve()
|
||||
for score_file in score:
|
||||
resolved = (_DATA_DIR / score_file).resolve()
|
||||
if not resolved.is_relative_to(data_root):  # component-wise check; a string prefix would also accept siblings such as data_old/
|
||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
|
||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
||||
# single device prevents DataParallel from replicating the model across all
|
||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
|
||||
def generate():
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
env=proc_env,
|
||||
)
|
||||
_running_procs["finetune"] = proc
|
||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "finetune" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("finetune")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("finetune", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
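Reviewer note: the `--score` arguments are user-supplied paths that must stay inside the data directory. Containment checks of this kind are easy to get subtly wrong with string prefixes, because `/srv/data_backup` starts with `/srv/data`. The sketch below compares whole path components instead. The helper name `resolve_inside` is hypothetical, and it raises `ValueError` where the endpoint would raise an HTTP 400.

```python
from pathlib import Path


def resolve_inside(root: Path, user_path: str) -> Path:
    """Resolve user_path under root, rejecting anything that escapes root."""
    root = root.resolve()
    candidate = (root / user_path).resolve()
    # is_relative_to (Python 3.9+) compares path components, so a sibling
    # directory that merely shares a name prefix is not treated as a child.
    if not candidate.is_relative_to(root):
        raise ValueError(f"Invalid path: {user_path!r}")
    return candidate
```

Resolving before comparing also collapses `..` segments and follows symlinks, so both escape routes are covered by the one check.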
|
||||
|
||||
@app.post("/api/benchmark/cancel")
|
||||
def cancel_benchmark():
|
||||
"""Kill the running benchmark subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("benchmark")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No benchmark is running")
|
||||
_cancelled_jobs.add("benchmark")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
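Reviewer note: both cancel endpoints use a terminate-then-kill escalation: send a polite termination request, wait a few seconds, then force-kill. A minimal standalone sketch follows. The name `stop_process` is an assumption, and the final `wait()` after `kill()` is an addition the endpoints above do not make.

```python
import subprocess


def stop_process(proc: subprocess.Popen, grace_s: float = 3.0) -> int:
    """Ask proc to exit, then force-kill it if it ignores the request."""
    proc.terminate()  # SIGTERM on POSIX, TerminateProcess on Windows
    try:
        proc.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
        proc.kill()  # SIGKILL on POSIX; cannot be caught or ignored
        proc.wait()  # reap the child so it does not linger as a zombie
    return proc.returncode
```

On POSIX the return code is negative when the child died from a signal (for example `-15` after SIGTERM), which is why the streaming generators treat a non-zero exit as cancelled only when the job was flagged in `_cancelled_jobs`.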
|
||||
|
||||
@app.post("/api/finetune/cancel")
|
||||
def finetune_cancel_compat() -> dict:
|
||||
"""Shim: cancel the most recent running classifier job."""
|
||||
from app.train.train import _db, _init_db, cancel_job
|
||||
from fastapi import HTTPException
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT id FROM jobs WHERE type='classifier' AND status='running' ORDER BY started_at DESC LIMIT 1"
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return {"status": "nothing_running"}
|
||||
return cancel_job(row["id"])
|
||||
def cancel_finetune():
|
||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("finetune")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No finetune is running")
|
||||
_cancelled_jobs.add("finetune")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
from app.data.log_corpus import router as log_corpus_router
|
||||
app.include_router(log_corpus_router, prefix="/api/corpus")
|
||||
|
||||
from app.data.recipe_scan import router as recipe_scan_router
|
||||
app.include_router(recipe_scan_router, prefix="/api/recipe-scan")
|
||||
@app.get("/api/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
days_back: int = Query(default=90, ge=1, le=365),
|
||||
limit: int = Query(default=150, ge=1, le=1000),
|
||||
mode: str = Query(default="wide"),
|
||||
):
|
||||
from app.imap_fetch import fetch_account_stream
|
||||
|
||||
from app.dashboard import router as dashboard_router
|
||||
app.include_router(dashboard_router, prefix="/api")
|
||||
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
|
||||
config = get_config() # reuse existing endpoint logic
|
||||
selected = [a for a in config["accounts"] if a.get("name") in selected_names]
|
||||
|
||||
from app.models import router as models_router
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
def generate():
|
||||
known_keys = {_item_id(x) for x in _read_jsonl(_queue_file())}
|
||||
total_added = 0
|
||||
|
||||
from app.nodes import router as nodes_router
|
||||
app.include_router(nodes_router, prefix="/api/nodes-mgmt")
|
||||
for acc in selected:
|
||||
try:
|
||||
batch_emails: list[dict] = []
|
||||
for event in fetch_account_stream(acc, days_back, limit, known_keys):
|
||||
if event["type"] == "done":
|
||||
batch_emails = event.pop("emails", [])
|
||||
total_added += event["added"]
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
# Write new emails to queue after each account
|
||||
if batch_emails:
|
||||
existing = _read_jsonl(_queue_file())
|
||||
_write_jsonl(_queue_file(), existing + batch_emails)
|
||||
except Exception as exc:
|
||||
error_event = {"type": "error", "account": acc.get("name", "?"),
|
||||
"message": str(exc)}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
# -- Static SPA -- MUST be last (catches all unmatched paths) ---------------
|
||||
queue_size = len(_read_jsonl(_queue_file()))
|
||||
complete = {"type": "complete", "total_added": total_added, "queue_size": queue_size}
|
||||
yield f"data: {json.dumps(complete)}\n\n"
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
# Static SPA — MUST be last (catches all unmatched paths)
|
||||
_DIST = _ROOT / "web" / "dist"
|
||||
if _DIST.exists():
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Serve index.html with no-cache so browsers always fetch fresh HTML after rebuilds.
|
||||
# Hashed assets (/assets/index-abc123.js) can be cached forever — they change names
|
||||
# when content changes (standard Vite cache-busting strategy).
|
||||
_NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"}
|
||||
|
||||
@app.get("/")
|
||||
|
|
|
|||
653
app/cforch.py
|
|
@@ -1,653 +0,0 @@
|
|||
"""Avocet — cf-orch benchmark integration API.
|
||||
|
||||
Wraps cf-orch's benchmark.py script and exposes it via the Avocet API.
|
||||
Config is read from label_tool.yaml under the `cforch:` key.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/cforch".
|
||||
|
||||
Module-level globals (_CONFIG_DIR, _BENCH_RUNNING, _bench_proc) follow the
|
||||
same testability pattern as sft.py — override _CONFIG_DIR via set_config_dir()
|
||||
in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import select as _select
|
||||
import subprocess as _subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import urllib.parse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None # live Popen object while benchmark runs
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section, falling back to environment variables.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. label_tool.yaml cforch: key
|
||||
2. Environment variables (CF_ORCH_URL, CF_LICENSE_KEY, OLLAMA_HOST, OLLAMA_MODEL, CF_JUDGE_URL, HF_TOKEN)
|
||||
"""
|
||||
f = _config_file()
|
||||
file_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
file_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse cforch config %s: %s", f, exc)
|
||||
|
||||
# Env var fallbacks — only used when the yaml key is absent or empty
|
||||
def _coalesce(file_val: str, env_key: str) -> str:
|
||||
return file_val if file_val else os.environ.get(env_key, "")
|
||||
|
||||
return {
|
||||
**file_cfg,
|
||||
"coordinator_url": _coalesce(file_cfg.get("coordinator_url", ""), "CF_ORCH_URL"),
|
||||
"license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"),
|
||||
"ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"),
|
||||
"ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"),
|
||||
"judge_url": _coalesce(file_cfg.get("judge_url", ""), "CF_JUDGE_URL"),
|
||||
"hf_token": _coalesce(file_cfg.get("hf_token", ""), "HF_TOKEN"),
|
||||
}
|
||||
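Reviewer note: the precedence rule in `_load_cforch_config` (YAML first, environment variable only when the YAML value is absent or empty) can be shown in isolation. In this sketch the `environ` parameter and the function name `resolve` are assumptions added so the behaviour can be exercised without touching the real process environment. Only two of the six keys are shown.

```python
import os
from typing import Mapping, Optional


def coalesce(file_val: str, env_key: str, environ: Optional[Mapping[str, str]] = None) -> str:
    """Prefer the YAML value; fall back to the environment only when it is empty."""
    env = os.environ if environ is None else environ
    return file_val if file_val else env.get(env_key, "")


def resolve(file_cfg: dict, environ: Optional[Mapping[str, str]] = None) -> dict:
    """Overlay env fallbacks on the YAML section without dropping other YAML keys."""
    return {
        **file_cfg,
        "coordinator_url": coalesce(file_cfg.get("coordinator_url", ""), "CF_ORCH_URL", environ),
        "ollama_url": coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST", environ),
    }
```

One consequence is that a key set to an empty string in the YAML does not disable the environment fallback, because empty and absent are treated the same way.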
|
||||
|
||||
def _validate_service_url(url: str, param_name: str) -> str:
|
||||
"""Validate that a URL is a well-formed http/https URL with a hostname.
|
||||
|
||||
Partial SSRF guard: only http/https is allowed; the URL must have a
|
||||
non-empty host. Does not enforce an allowlist — call sites are internal
|
||||
tooling, not a public API.
|
||||
"""
|
||||
if not url:
|
||||
return url
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
except Exception:
|
||||
raise HTTPException(400, f"{param_name}: not a valid URL")
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise HTTPException(400, f"{param_name}: URL must start with http:// or https://")
|
||||
if not parsed.hostname:
|
||||
raise HTTPException(400, f"{param_name}: URL has no hostname")
|
||||
return url
|
||||
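Reviewer note: to make the accepted and rejected inputs concrete, here is a standalone variant of the same check. It raises `ValueError` instead of `HTTPException` so it runs without FastAPI, and the name `check_service_url` is hypothetical.

```python
from urllib.parse import urlparse


def check_service_url(url: str, param_name: str) -> str:
    """Accept empty strings and http(s) URLs that carry a hostname."""
    if not url:
        return url  # empty means "not provided"; the caller falls back to config
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"{param_name}: URL must start with http:// or https://")
    if not parsed.hostname:
        raise ValueError(f"{param_name}: URL has no hostname")
    return url
```

This rejects `file:///etc/passwd` and `http://` but still accepts any reachable host, including loopback and link-local addresses, which matches the docstring's statement that no allowlist is enforced.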
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape codes from a string."""
|
||||
return re.sub(r'\x1b\[[0-9;]*m', '', text)
|
||||
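Reviewer note: the pattern in `_strip_ansi` matches only SGR sequences (colour and style codes ending in `m`). Other escape sequences, such as cursor movement or line erase, pass through unchanged. A short sketch with the same regex, under the assumed name `strip_sgr`:

```python
import re

# ESC [ <digits and semicolons> m  -> "Select Graphic Rendition" (colour/style)
_SGR = re.compile(r"\x1b\[[0-9;]*m")


def strip_sgr(text: str) -> str:
    """Remove colour/style escape sequences, leaving all other text intact."""
    return _SGR.sub("", text)
```

If the benchmark script ever emits progress bars that use erase-line codes such as `ESC[2K`, those would still reach the SSE stream and would need a broader pattern.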
|
||||
|
||||
def _find_latest_summary(results_dir: str | None) -> Path | None:
|
||||
"""Find the newest summary.json under results_dir, or None if not found."""
|
||||
if not results_dir:
|
||||
return None
|
||||
rdir = Path(results_dir)
|
||||
if not rdir.exists():
|
||||
return None
|
||||
# Subdirs are named YYYY-MM-DD-HHMMSS; sort lexicographically for chronological order
|
||||
subdirs = sorted(
|
||||
[d for d in rdir.iterdir() if d.is_dir()],
|
||||
key=lambda d: d.name,
|
||||
)
|
||||
for subdir in reversed(subdirs):
|
||||
summary = subdir / "summary.json"
|
||||
if summary.exists():
|
||||
return summary
|
||||
return None
|
||||
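Reviewer note: `_find_latest_summary` relies on zero-padded `YYYY-MM-DD-HHMMSS` directory names sorting chronologically as plain strings, and it skips newer runs that never wrote a `summary.json`. The sketch below reproduces that lookup. The function name `latest_summary` is an assumption.

```python
from pathlib import Path
from typing import Optional


def latest_summary(results_dir: Path) -> Optional[Path]:
    """Return summary.json from the newest run directory that has one."""
    # Zero-padded timestamps sort chronologically under plain string order.
    subdirs = sorted(
        (d for d in results_dir.iterdir() if d.is_dir()),
        key=lambda d: d.name,
        reverse=True,
    )
    for sub in subdirs:
        candidate = sub / "summary.json"
        if candidate.exists():
            return candidate
    return None
```

A directory whose name does not follow the timestamp convention would still take part in the sort, so an unrelated folder such as `tmp` containing a `summary.json` could shadow real runs.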
|
||||
|
||||
# ── GET /tasks ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/tasks")
|
||||
def get_tasks() -> dict:
|
||||
"""Return task list from bench_tasks.yaml."""
|
||||
cfg = _load_cforch_config()
|
||||
tasks_path = cfg.get("bench_tasks", "")
|
||||
if not tasks_path:
|
||||
return {"tasks": [], "types": []}
|
||||
|
||||
p = Path(tasks_path)
|
||||
if not p.exists():
|
||||
return {"tasks": [], "types": []}
|
||||
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_tasks.yaml %s: %s", p, exc)
|
||||
return {"tasks": [], "types": []}
|
||||
|
||||
tasks_raw = raw.get("tasks", []) or []
|
||||
tasks: list[dict] = []
|
||||
seen_types: list[str] = []
|
||||
types_set: set[str] = set()
|
||||
|
||||
for t in tasks_raw:
|
||||
if not isinstance(t, dict):
|
||||
continue
|
||||
tasks.append({
|
||||
"id": t.get("id", ""),
|
||||
"name": t.get("name", ""),
|
||||
"type": t.get("type", ""),
|
||||
"prompt": (t.get("prompt") or "").strip(),
|
||||
"system": (t.get("system") or "").strip(),
|
||||
})
|
||||
task_type = t.get("type", "")
|
||||
if task_type and task_type not in types_set:
|
||||
seen_types.append(task_type)
|
||||
types_set.add(task_type)
|
||||
|
||||
return {"tasks": tasks, "types": seen_types}
|
||||
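Reviewer note: `get_tasks` keeps a list plus a set to collect task types in first-seen order without duplicates. The same result can be had from `dict.fromkeys`, since dicts preserve insertion order. A small sketch, with the hypothetical name `unique_in_order`:

```python
from typing import Iterable


def unique_in_order(values: Iterable[str]) -> list[str]:
    """Drop empty strings and duplicates while keeping first-seen order."""
    return list(dict.fromkeys(v for v in values if v))
```

For example, `unique_in_order(t.get("type", "") for t in tasks_raw)` would give the same `types` list, assuming every entry in `tasks_raw` is a dict.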
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
# Services and roles surfaced in the benchmark model picker.
|
||||
# Covers all cf-orch service types that benchmark.py can route tasks to.
|
||||
_BENCH_SERVICES = frozenset({
|
||||
"cf-text", "vllm", # LLM text generation
|
||||
"cf-stt", # speech-to-text
|
||||
"cf-tts", # text-to-speech
|
||||
"cf-vision", # image classification / embedding
|
||||
"cf-voice", # audio context classification
|
||||
})
|
||||
_BENCH_ROLES = frozenset({
|
||||
"generator", "vlm", # LLM roles
|
||||
"stt", "alm", # speech recognition
|
||||
"tts", # speech synthesis
|
||||
"vision", "embedding", # image understanding
|
||||
"classifier", # audio classification (cf-voice)
|
||||
})
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return model list from bench_models.yaml merged with locally installed models.
|
||||
|
||||
bench_models.yaml entries are listed first and take precedence; any installed
|
||||
model whose model_id is already present in the YAML is skipped. Only models
|
||||
whose service is in _BENCH_SERVICES (cf-text, vllm, cf-stt, cf-tts, cf-vision,
|
||||
cf-voice) and whose role is in _BENCH_ROLES are surfaced from the installed registry.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
models_path = cfg.get("bench_models", "")
|
||||
|
||||
models: list[dict] = []
|
||||
bench_ids: set[str] = set()
|
||||
|
||||
if models_path:
|
||||
p = Path(models_path)
|
||||
if p.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
|
||||
raw = {}
|
||||
for m in (raw.get("models", []) or []):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
model_id = m.get("id", "")
|
||||
models.append({
|
||||
"name": m.get("name", ""),
|
||||
"id": model_id,
|
||||
"service": m.get("service", "ollama"),
|
||||
"tags": m.get("tags", []) or [],
|
||||
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
|
||||
})
|
||||
if model_id:
|
||||
bench_ids.add(model_id)
|
||||
|
||||
# Merge installed models (any service in _BENCH_SERVICES with a role in _BENCH_ROLES) not already in bench_models.yaml.
|
||||
try:
|
||||
from app.models import list_installed # local import avoids circular dependency at module load
|
||||
for installed in list_installed():
|
||||
model_id: str = installed.get("model_id") or ""
|
||||
service: str = installed.get("service") or ""
|
||||
role: str = installed.get("role") or ""
|
||||
if not model_id:
|
||||
continue
|
||||
if service not in _BENCH_SERVICES or role not in _BENCH_ROLES:
|
||||
continue
|
||||
if model_id in bench_ids:
|
||||
continue
|
||||
display_name = model_id.split("/", 1)[-1] if "/" in model_id else model_id
|
||||
models.append({
|
||||
"name": display_name,
|
||||
"id": model_id,
|
||||
"service": service,
|
||||
"tags": [role],
|
||||
"vram_estimate_mb": installed.get("vram_mb") or 0,
|
||||
})
|
||||
bench_ids.add(model_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into model list: %s", exc)
|
||||
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /nodes, GET /run ───────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
|
||||
def get_nodes() -> dict:
|
||||
"""Proxy the coordinator's /api/nodes list, returning node_id + online status.
|
||||
|
||||
Online is inferred from last_heartbeat: any node that has ever reported a heartbeat (non-null last_heartbeat) counts as online; recency is not checked.
|
||||
Returns an empty list if the coordinator is unreachable.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "").rstrip("/")
|
||||
if not coordinator_url:
|
||||
return {"nodes": []}
|
||||
try:
|
||||
import httpx as _httpx
|
||||
resp = _httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
raw_nodes = resp.json().get("nodes", [])
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": n.get("node_id", ""),
|
||||
"online": n.get("last_heartbeat") is not None,
|
||||
"gpus": [
|
||||
{
|
||||
"gpu_id": g.get("gpu_id"),
|
||||
"name": g.get("name", ""),
|
||||
"vram_total_mb": g.get("vram_total_mb", 0),
|
||||
"vram_free_mb": g.get("vram_free_mb", 0),
|
||||
}
|
||||
for g in n.get("gpus", [])
|
||||
],
|
||||
}
|
||||
for n in raw_nodes
|
||||
]
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch nodes from coordinator: %s", exc)
|
||||
return {"nodes": []}
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_benchmark(
|
||||
task_ids: str = "",
|
||||
model_ids: str = "",
|
||||
model_tags: str = "",
|
||||
coordinator_url: str = "",
|
||||
ollama_url: str = "",
|
||||
judge_url: str = "",
|
||||
judge_backend: str = "chat",
|
||||
workers: int = 1,
|
||||
node_ids: str = "",
|
||||
) -> StreamingResponse:
|
||||
"""Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
# Check if the process is actually still alive; reset stale flag if not.
|
||||
if _BENCH_RUNNING:
|
||||
if _bench_proc is not None and _bench_proc.poll() is None:
|
||||
raise HTTPException(409, "A benchmark is already running")
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
|
||||
cfg = _load_cforch_config()
|
||||
bench_script = cfg.get("bench_script", "")
|
||||
bench_tasks = cfg.get("bench_tasks", "")
|
||||
bench_models = cfg.get("bench_models", "")
|
||||
results_dir = cfg.get("results_dir", "")
|
||||
python_bin = cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")
|
||||
cfg_coordinator = cfg.get("coordinator_url", "")
|
||||
cfg_ollama = cfg.get("ollama_url", "")
|
||||
cfg_license_key = cfg.get("license_key", "")
|
||||
cfg_judge_url = cfg.get("judge_url", "")
|
||||
|
||||
# Validate URL params before spawning the subprocess.
|
||||
# _validate_service_url raises HTTPException on bad input (caught by FastAPI before streaming starts).
|
||||
_validate_service_url(coordinator_url, "coordinator_url")
|
||||
_validate_service_url(ollama_url, "ollama_url")
|
||||
_validate_service_url(judge_url, "judge_url")
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not bench_script or not Path(bench_script).exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n"
|
||||
return
|
||||
|
||||
# Build effective models file: bench_models.yaml + any installed models
|
||||
# whose IDs were selected but are absent from the YAML (e.g. downloaded
|
||||
# via the Models view). Written to a temp file so benchmark.py sees one
|
||||
# unified list; cleaned up in the finally block.
|
||||
effective_models_file = bench_models
|
||||
_tmp_models_path: str | None = None
|
||||
|
||||
if model_ids and bench_models and Path(bench_models).exists():
|
||||
requested_ids = set(model_ids.split(","))
|
||||
try:
|
||||
raw_bench = yaml.safe_load(Path(bench_models).read_text(encoding="utf-8")) or {}
|
||||
bench_entries: list[dict] = raw_bench.get("models", []) or []
|
||||
bench_id_set = {m.get("id", "") for m in bench_entries if isinstance(m, dict)}
|
||||
missing_ids = requested_ids - bench_id_set
|
||||
if missing_ids:
|
||||
from app.models import list_installed
|
||||
installed_map = {
|
||||
m["model_id"]: m
|
||||
for m in list_installed()
|
||||
if m.get("model_id") and m.get("service") in _BENCH_SERVICES
|
||||
}
|
||||
extra: list[dict] = []
|
||||
for mid in missing_ids:
|
||||
if mid in installed_map:
|
||||
inst = installed_map[mid]
|
||||
entry: dict[str, Any] = {
|
||||
"id": mid,
|
||||
"name": mid.split("/", 1)[-1] if "/" in mid else mid,
|
||||
"service": inst.get("service", "cf-text"),
|
||||
"vram_estimate_mb": inst.get("vram_mb") or 0,
|
||||
"tags": [inst.get("role", "generator")],
|
||||
"temperature": 0.0,
|
||||
}
|
||||
local_path = inst.get("path", "") or inst.get("local_path", "")
|
||||
if local_path:
|
||||
entry["model_path"] = local_path
|
||||
extra.append(entry)
|
||||
if extra:
|
||||
merged = {"models": bench_entries + extra}
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False,
|
||||
prefix="avocet_bench_models_",
|
||||
)
|
||||
yaml.dump(merged, tf)
|
||||
tf.close()
|
||||
_tmp_models_path = tf.name
|
||||
effective_models_file = _tmp_models_path
|
||||
except Exception as exc:
|
||||
logger.warning("Could not merge installed models into temp bench file: %s", exc)
|
||||
|
||||
cmd = [
|
||||
python_bin,
|
||||
bench_script,
|
||||
"--tasks", bench_tasks,
|
||||
"--models", effective_models_file,
|
||||
"--output", results_dir,
|
||||
]
|
||||
|
||||
if task_ids:
|
||||
cmd.extend(["--filter-tasks"] + task_ids.split(","))
|
||||
if model_ids:
|
||||
cmd.extend(["--filter-models"] + model_ids.split(","))
|
||||
if model_tags:
|
||||
cmd.extend(["--filter-tags"] + model_tags.split(","))
|
||||
|
||||
# query param overrides config, config overrides env var (already resolved by _load_cforch_config)
|
||||
effective_coordinator = coordinator_url if coordinator_url else cfg_coordinator
|
||||
effective_ollama = ollama_url if ollama_url else cfg_ollama
|
||||
if effective_coordinator:
|
||||
cmd.extend(["--coordinator", effective_coordinator])
|
||||
if effective_ollama:
|
||||
cmd.extend(["--ollama-url", effective_ollama])
|
||||
effective_judge = judge_url if judge_url else cfg_judge_url
|
||||
if effective_judge:
|
||||
cmd.extend(["--judge-url", effective_judge])
|
||||
if judge_backend and judge_backend != "chat":
|
||||
cmd.extend(["--judge-backend", judge_backend])
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
if node_ids:
|
||||
cmd.extend(["--nodes"] + node_ids.split(","))
|
||||
|
||||
# Pass license key as env var so subprocess can authenticate with cf-orch
|
||||
proc_env = {**os.environ}
|
||||
if cfg_license_key:
|
||||
proc_env["CF_LICENSE_KEY"] = cfg_license_key
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=proc_env,
|
||||
)
|
||||
_bench_proc = proc
|
||||
_IDLE_TIMEOUT_S = 120 # kill if no output for 2 minutes (node crash)
|
||||
try:
|
||||
while True:
|
||||
ready = _select.select([proc.stdout], [], [], _IDLE_TIMEOUT_S)
|
||||
if not ready[0]:
|
||||
# No output for _IDLE_TIMEOUT_S seconds; the node has likely crashed
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
msg = f"Benchmark timed out — no output for {_IDLE_TIMEOUT_S}s (cluster node may have crashed)"
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': msg})}\n\n"
|
||||
break
|
||||
line = proc.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
line = _strip_ansi(line.rstrip())
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
summary_path = _find_latest_summary(results_dir)
|
||||
if summary_path is not None:
|
||||
try:
|
||||
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'summary': summary})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read summary.json: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
if _tmp_models_path:
|
||||
try:
|
||||
os.unlink(_tmp_models_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
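Reviewer note: the idle-timeout loop above uses `select` on the child's stdout, which is POSIX-only and can misbehave when a text-mode reader has already buffered data. As an alternative technique (not what this file does), a reader thread feeding a queue gives the same "kill after N seconds of silence" behaviour portably. All names in this sketch are illustrative assumptions.

```python
import queue
import subprocess
import threading
from typing import Iterator


def lines_with_idle_timeout(cmd: list[str], idle_s: float) -> Iterator[str]:
    """Yield output lines; kill the child and raise if it stays silent too long."""
    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1
    )
    q: queue.Queue = queue.Queue()

    def pump() -> None:
        assert proc.stdout is not None
        for line in proc.stdout:
            q.put(line)
        q.put(None)  # sentinel: the pipe reached EOF

    threading.Thread(target=pump, daemon=True).start()
    while True:
        try:
            line = q.get(timeout=idle_s)
        except queue.Empty:
            proc.kill()
            proc.wait()
            raise TimeoutError(f"no output for {idle_s}s")
        if line is None:
            break
        yield line.rstrip()
    proc.wait()
```

The timer resets on every line, so a long-running but chatty benchmark is never killed; only a silent one is.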
|
||||
|
||||
# ── GET /config ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config")
|
||||
def get_cforch_config() -> dict:
|
||||
"""Return resolved cf-orch connection config (env vars merged with yaml).
|
||||
|
||||
Redacts license_key — only returns whether it is set, not the value.
|
||||
Used by the Settings UI to show current connection state.
|
||||
"""
|
||||
cfg = _load_cforch_config()
|
||||
return {
|
||||
"coordinator_url": cfg.get("coordinator_url", ""),
|
||||
"ollama_url": cfg.get("ollama_url", ""),
|
||||
"ollama_model": cfg.get("ollama_model", ""),
|
||||
"judge_url": cfg.get("judge_url", ""),
|
||||
"license_key_set": bool(cfg.get("license_key", "")),
|
||||
"source": "env" if not _config_file().exists() else "yaml+env",
|
||||
}
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def get_results() -> dict:
|
||||
"""Return the latest benchmark summary.json from results_dir."""
|
||||
cfg = _load_cforch_config()
|
||||
results_dir = cfg.get("results_dir", "")
|
||||
summary_path = _find_latest_summary(results_dir)
|
||||
if summary_path is None:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read summary.json: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_benchmark() -> dict:
|
||||
"""Kill the running benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate benchmark process: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
# ── Coordinator proxy helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _coordinator_url() -> str:
|
||||
"""Return coordinator base URL from config, or raise 503 if not configured."""
|
||||
url = _load_cforch_config().get("coordinator_url", "").rstrip("/")
|
||||
if not url:
|
||||
raise HTTPException(503, "cf-orch coordinator_url not configured")
|
||||
return url
|
||||
|
||||
|
||||
def _coordinator_get(path: str) -> Any:
|
||||
"""GET from coordinator, return parsed JSON body. Raises HTTPException on error."""
|
||||
import httpx as _httpx
|
||||
try:
|
||||
resp = _httpx.get(f"{_coordinator_url()}{path}", timeout=10.0)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _coordinator_post(path: str, body: dict) -> Any:
|
||||
import httpx as _httpx
|
||||
try:
|
||||
async with _httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(f"{_coordinator_url()}{path}", json=body)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _coordinator_delete(path: str) -> Any:
|
||||
import httpx as _httpx
|
||||
try:
|
||||
async with _httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.delete(f"{_coordinator_url()}{path}")
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
|
||||
if not resp.is_success:
|
||||
raise HTTPException(resp.status_code, resp.text)
|
||||
return resp.json()
|
||||
|
||||
|
||||
# ── GET /assignments/deployment-status ───────────────────────────────────────
|
||||
|
||||
@router.get("/assignments/deployment-status")
|
||||
def get_deployment_status() -> Any:
|
||||
return _coordinator_get("/api/assignments/deployment-status")
|
||||
|
||||
|
||||
# ── /assignments ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/assignments")
|
||||
def list_assignments() -> Any:
|
||||
return _coordinator_get("/api/assignments")
|
||||
|
||||
|
||||
class AssignmentBody(BaseModel):
|
||||
product: str
|
||||
task: str
|
||||
model_id: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
@router.post("/assignments")
|
||||
async def upsert_assignment(body: AssignmentBody) -> Any:
|
||||
return await _coordinator_post("/api/assignments", body.model_dump())
|
||||
|
||||
|
||||
@router.delete("/assignments/{product}/{task}")
|
||||
async def delete_assignment(product: str, task: str) -> Any:
|
||||
return await _coordinator_delete(f"/api/assignments/{urllib.parse.quote(product, safe='')}/{urllib.parse.quote(task, safe='')}")
|
||||
|
||||
|
||||
# ── /model-registry ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/model-registry")
|
||||
def list_model_registry() -> Any:
|
||||
return _coordinator_get("/api/model-registry")
|
||||
|
||||
|
||||
class ModelRegistryBody(BaseModel):
|
||||
model_id: str
|
||||
service_type: str
|
||||
vram_mb: int
|
||||
description: str = ""
|
||||
hf_repo: str = ""
|
||||
alias: str = ""
|
||||
|
||||
|
||||
@router.post("/model-registry")
|
||||
async def upsert_model_registry(body: ModelRegistryBody) -> Any:
|
||||
return await _coordinator_post("/api/model-registry", body.model_dump())
|
||||
|
||||
|
||||
@router.delete("/model-registry/{model_id:path}")
|
||||
async def delete_model_registry(model_id: str) -> Any:
|
||||
return await _coordinator_delete(f"/api/model-registry/{urllib.parse.quote(model_id, safe='')}")
|
||||
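The two delete routes above percent-encode each path segment with `urllib.parse.quote(..., safe='')` before forwarding the request to the coordinator. The sketch below illustrates why `safe=''` matters when an identifier may contain a slash. The helper name `coordinator_path` and the sample identifiers are assumptions made for illustration; they are not part of the module.

```python
from urllib.parse import quote


def coordinator_path(prefix: str, *segments: str) -> str:
    """Append segments to prefix, percent-encoding each one.

    safe="" means "/" is encoded too, so a segment such as "org/model"
    stays a single path segment instead of becoming two.
    """
    return prefix + "".join("/" + quote(segment, safe="") for segment in segments)


# Hypothetical identifiers, chosen only to show the encoding.
print(coordinator_path("/api/model-registry", "org/model-7b"))
print(coordinator_path("/api/assignments", "avocet", "email classify"))
```

With the default `safe="/"`, the slash in a model ID would pass through unencoded and the coordinator would likely see an extra path segment.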
|
|
@ -1,34 +0,0 @@
|
|||
"""
|
||||
Avocet cloud session — thin wrapper around cf_core.cloud_session.
|
||||
|
||||
Usage in FastAPI routes:
|
||||
|
||||
from app.cloud_session import get_session, require_tier, CloudUser
|
||||
from fastapi import Depends
|
||||
|
||||
@router.get("/api/imitate")
|
||||
def imitate(session: CloudUser = Depends(get_session)):
|
||||
# session.user_id — Directus UUID (cloud) or "local" (self-hosted)
|
||||
# session.tier — free | paid | premium | ultra | local
|
||||
# session.has_byok — True if user has a configured LLM backend
|
||||
...
|
||||
|
||||
@router.post("/api/custom-models")
|
||||
def list_custom_models(session: CloudUser = Depends(require_tier("premium"))):
|
||||
...
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.cloud_session import CloudSessionFactory, CloudUser, detect_byok
|
||||
|
||||
__all__ = ["CloudUser", "get_session", "require_tier"]
|
||||
|
||||
_factory = CloudSessionFactory(
|
||||
product="avocet",
|
||||
byok_detector=detect_byok,
|
||||
)
|
||||
|
||||
get_session = _factory.dependency()
|
||||
require_tier = _factory.require_tier
|
||||
282
app/dashboard.py
|
|
@ -1,282 +0,0 @@
|
|||
"""Avocet -- dashboard aggregate API.
|
||||
|
||||
GET /api/dashboard returns the current flywheel state:
|
||||
labeled_since_last_eval -- items labeled after the most recent bench run
|
||||
last_eval_timestamp -- ISO timestamp of newest bench_results summary
|
||||
last_eval_best_score -- best macro_f1 from that summary
|
||||
active_jobs -- jobs with status queued or running
|
||||
corrections_pending -- sft_candidates with status=needs_review
|
||||
corrections_export_ready -- approved sft candidates with non-blank correction
|
||||
recent_bench_runs -- most-recent timestamp + score per bench type
|
||||
signals -- computed booleans for UI nudge indicators
|
||||
|
||||
Thresholds are read from the pipeline: section of label_tool.yaml:
|
||||
pipeline:
|
||||
data_eval_threshold: 50 # labeled items since last bench to trigger nudge
|
||||
eval_train_threshold: 0.05 # improvement delta needed before retraining (future)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DEFAULT_DATA_EVAL_THRESHOLD = 50
|
||||
_DEFAULT_EVAL_TRAIN_THRESHOLD = 0.05
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _load_thresholds() -> tuple[int, float]:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
pipeline = raw.get("pipeline", {}) or {}
|
||||
return (
|
||||
int(pipeline.get("data_eval_threshold", _DEFAULT_DATA_EVAL_THRESHOLD)),
|
||||
float(pipeline.get("eval_train_threshold", _DEFAULT_EVAL_TRAIN_THRESHOLD)),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read pipeline thresholds: %s", exc)
|
||||
return _DEFAULT_DATA_EVAL_THRESHOLD, _DEFAULT_EVAL_TRAIN_THRESHOLD
|
||||
|
||||
def _load_score_records() -> list[dict]:
|
||||
path = _DATA_DIR / "email_score.jsonl"
|
||||
if not path.exists():
|
||||
return []
|
||||
records = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return records
|
||||
|
||||
def _find_latest_classifier_bench(results_dir_override: str = "") -> tuple[str | None, float | None]:
|
||||
"""Return (iso_timestamp, best_macro_f1) from the newest bench_results summary.
|
||||
|
||||
Checks results_dir from cforch config if set, then falls back to
|
||||
_ROOT/bench_results/. Returns (None, None) if no results exist.
|
||||
"""
|
||||
candidates = []
|
||||
if results_dir_override:
|
||||
candidates.append(Path(results_dir_override))
|
||||
else:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
candidates.append(Path(rd))
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
|
||||
candidates.append(_ROOT / "bench_results")
|
||||
|
||||
for rdir in candidates:
|
||||
if not rdir.exists():
|
||||
continue
|
||||
subdirs = sorted([d for d in rdir.iterdir() if d.is_dir()], key=lambda d: d.name)
|
||||
for subdir in reversed(subdirs):
|
||||
summary = subdir / "summary.json"
|
||||
if summary.exists():
|
||||
try:
|
||||
data = json.loads(summary.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
continue # cforch LLM-bench summaries are lists; skip
|
||||
ts = data.get("timestamp") or subdir.name
|
||||
score = data.get("best_macro_f1") if data.get("best_macro_f1") is not None else data.get("macro_f1")  # keep a legitimate 0.0 score
|
||||
return ts, (float(score) if isinstance(score, (int, float)) else None)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse summary.json at %s: %s", summary, exc)
|
||||
return None, None
|
||||
|
||||
# Keep old name as alias so existing callers in tests still work.
|
||||
_find_latest_eval = _find_latest_classifier_bench
|
||||
|
||||
|
||||
def _count_corrections() -> tuple[int, int]:
|
||||
"""Return (pending_count, export_ready_count)."""
|
||||
pending = 0
|
||||
export_ready = 0
|
||||
candidates_path = _DATA_DIR / "sft_candidates.jsonl"
|
||||
approved_path = _DATA_DIR / "sft_approved.jsonl"
|
||||
if candidates_path.exists():
|
||||
for line in candidates_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if r.get("status") == "needs_review":
|
||||
pending += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if approved_path.exists():
|
||||
for line in approved_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
if (r.get("status") == "approved"
|
||||
and r.get("corrected_response")
|
||||
and str(r["corrected_response"]).strip()):
|
||||
export_ready += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return pending, export_ready
|
||||
|
||||
def _get_active_jobs() -> list[dict]:
|
||||
"""Query train SQLite DB for queued/running jobs. Returns [] if DB absent."""
|
||||
try:
|
||||
from app.train.train import _DB_PATH, _db, _init_db
|
||||
if not _DB_PATH.exists():
|
||||
return []
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
|
||||
).fetchall()
|
||||
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to query train jobs DB: %s", exc)
|
||||
return []
|
||||
|
||||
def _count_labeled_since(since_ts: str | None) -> int:
|
||||
records = _load_score_records()
|
||||
if since_ts is None:
|
||||
return len(records)
|
||||
return sum(1 for r in records if (r.get("labeled_at") or "") > since_ts)  # tolerate null labeled_at
|
||||
|
||||
|
||||
def _get_recent_bench_runs() -> dict:
|
||||
"""Return most-recent run summary for each bench type.
|
||||
|
||||
Each entry: {"timestamp": str|None, "metric": str|None, "score": float|None}
|
||||
"""
|
||||
runs: dict[str, dict] = {
|
||||
"classifier": {"timestamp": None, "metric": "macro_f1", "score": None},
|
||||
"llm": {"timestamp": None, "metric": None, "score": None},
|
||||
"style": {"timestamp": None, "metric": None, "score": None},
|
||||
"plans": {"timestamp": None, "metric": "avg_score", "score": None},
|
||||
}
|
||||
|
||||
# ── Classifier: bench_results/<run>/summary.json ──────────────────────
|
||||
clf_ts, clf_score = _find_latest_classifier_bench()
|
||||
if clf_ts:
|
||||
runs["classifier"]["timestamp"] = clf_ts
|
||||
runs["classifier"]["score"] = clf_score
|
||||
|
||||
# ── LLM bench + Style: benchmark_results/ ─────────────────────────────
|
||||
f = _config_file()
|
||||
bench_dir: Path | None = None
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
|
||||
if rd:
|
||||
bench_dir = Path(rd)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
|
||||
if bench_dir is None:
|
||||
bench_dir = _ROOT / "benchmark_results"
|
||||
|
||||
if bench_dir.exists():
|
||||
llm_files = sorted(
|
||||
[p for p in bench_dir.glob("*.json") if not p.name.startswith("style_")],
|
||||
key=lambda p: p.stat().st_mtime, reverse=True,
|
||||
)
|
||||
if llm_files:
|
||||
try:
|
||||
data = json.loads(llm_files[0].read_text(encoding="utf-8"))
|
||||
runs["llm"]["timestamp"] = data.get("timestamp") or llm_files[0].stem
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
style_files = sorted(bench_dir.glob("style_*.json"), reverse=True)
|
||||
if style_files:
|
||||
try:
|
||||
data = json.loads(style_files[0].read_text(encoding="utf-8"))
|
||||
if isinstance(data, list) and data:
|
||||
runs["style"]["timestamp"] = data[0].get("timestamp") or style_files[0].stem
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Plans bench: data/plans_bench_results/plans_*.json ────────────────
|
||||
plans_dir = _DATA_DIR / "plans_bench_results"
|
||||
if plans_dir.exists():
|
||||
plans_files = sorted(plans_dir.glob("plans_*.json"), reverse=True)
|
||||
if plans_files:
|
||||
run_id = plans_files[0].stem
|
||||
try:
|
||||
d: dict = json.loads(plans_files[0].read_text(encoding="utf-8"))
|
||||
all_scores = [
|
||||
r["total_score"]
|
||||
for results in d.values()
|
||||
for r in results
|
||||
if isinstance(r, dict) and not r.get("error")
|
||||
]
|
||||
avg = round(sum(all_scores) / len(all_scores), 3) if all_scores else None
|
||||
try:
|
||||
date_part = run_id.removeprefix("plans_")
|
||||
date, time_part = date_part.split("_")
|
||||
ts_display = f"{date} {time_part[:2]}:{time_part[2:4]}"
|
||||
except Exception:
|
||||
ts_display = run_id
|
||||
runs["plans"]["timestamp"] = ts_display
|
||||
runs["plans"]["score"] = avg
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/dashboard")
|
||||
def get_dashboard() -> dict:
|
||||
data_threshold, _train_threshold = _load_thresholds()
|
||||
last_ts, last_score = _find_latest_classifier_bench()
|
||||
labeled_since = _count_labeled_since(last_ts)
|
||||
corrections_pending, corrections_export_ready = _count_corrections()
|
||||
active_jobs = _get_active_jobs()
|
||||
recent_bench = _get_recent_bench_runs()
|
||||
return {
|
||||
"labeled_since_last_eval": labeled_since,
|
||||
"last_eval_timestamp": last_ts,
|
||||
"last_eval_best_score": last_score,
|
||||
"active_jobs": active_jobs,
|
||||
"corrections_pending": corrections_pending,
|
||||
"corrections_export_ready": corrections_export_ready,
|
||||
"recent_bench_runs": recent_bench,
|
||||
"signals": {
|
||||
"data_to_eval": labeled_since >= data_threshold,
|
||||
"eval_to_train": False, # future: implement delta-F1 comparison
|
||||
"train_to_fleet": False, # future: implement fleet sync signal
|
||||
},
|
||||
}
|
||||
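`_count_labeled_since` in the file above compares `labeled_at` and the last benchmark timestamp as plain strings, which only orders correctly when both use the same format and offset. A minimal sketch of a datetime-based comparison follows, assuming both values are ISO-8601 strings and that naive values should be treated as UTC. The names `parse_ts` and `count_labeled_since` are illustrative, not the module's own.

```python
from datetime import datetime, timezone


def parse_ts(value):
    """Parse an ISO-8601 string into an aware UTC datetime, or return None."""
    if not value:
        return None
    try:
        parsed = datetime.fromisoformat(value)
    except (TypeError, ValueError):
        return None
    if parsed.tzinfo is None:
        # Assumption: naive timestamps are already in UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def count_labeled_since(records, since_ts):
    """Count records labeled strictly after since_ts.

    Falls back to counting every record when since_ts is missing or
    unparseable, mirroring the original behaviour for since_ts=None.
    """
    cutoff = parse_ts(since_ts)
    if cutoff is None:
        return len(records)
    count = 0
    for record in records:
        labeled = parse_ts(record.get("labeled_at"))
        if labeled is not None and labeled > cutoff:
            count += 1
    return count
```

Note that the original falls back to the run directory name when `summary.json` has no `timestamp`, so a real implementation would also need a rule for non-ISO run names; this sketch simply counts everything in that case.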
|
|
@ -1,393 +0,0 @@
|
|||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
|
||||
|
||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||
testability pattern as api.py -- override them via set_data_dir() and
|
||||
set_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams ---------------------------------------------------------
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Internal helpers ----------------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
d = (raw.get("sft") or {}).get("bench_results_dir", "")
|
||||
if d:
|
||||
return Path(d)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# -- GET /runs -----------------------------------------------------------------
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# -- POST /import --------------------------------------------------------------
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _DATA_DIR)
|
||||
|
||||
|
||||
# -- GET /queue ----------------------------------------------------------------
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
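`get_queue` above slices the pending list with `(page - 1) * per_page`, so a `page` of 0 or a negative value produces a negative start index. A minimal sketch of the same pagination with clamped inputs, assuming that clamping to 1 is the desired behaviour (returning a 422 would be an equally reasonable choice). The function name `paginate` is illustrative.

```python
def paginate(items, page=1, per_page=20):
    """Return one page of items in the same envelope get_queue uses."""
    page = max(1, page)          # assumption: clamp instead of rejecting
    per_page = max(1, per_page)
    start = (page - 1) * per_page
    return {
        "items": items[start:start + per_page],
        "total": len(items),
        "page": page,
        "per_page": per_page,
    }
```

A page past the end returns an empty `items` list with the correct `total`, which lets a client detect that it has run out of pages.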
|
||||
|
||||
# -- POST /submit --------------------------------------------------------------
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /undo ----------------------------------------------------------------
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None, "failure_category": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- GET /export ---------------------------------------------------------------
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
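The export route above turns each approved record into one JSONL line whose `messages` list is the stored prompt followed by the corrected response as the assistant turn. The sketch below shows that transformation in isolation. The helper name `to_sft_line` and the sample record are assumptions for illustration.

```python
import json


def to_sft_line(record: dict) -> str:
    """Serialize one approved record as a single SFT-ready JSONL line."""
    messages = list(record.get("prompt_messages", [])) + [
        {"role": "assistant", "content": record["corrected_response"]}
    ]
    return json.dumps({"messages": messages}) + "\n"


# Hypothetical record, shaped like the ones /ingest writes.
sample = {
    "prompt_messages": [{"role": "user", "content": "Classify this email."}],
    "corrected_response": "interview_scheduled",
}
print(to_sft_line(sample), end="")
```

Only `prompt_messages` and `corrected_response` reach the export; fields such as `model_response` and `failure_category` are dropped.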
|
||||
|
||||
# -- GET /stats ----------------------------------------------------------------
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# -- GET /config ---------------------------------------------------------------
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.replace(f)  # replace() overwrites an existing target on every platform
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /ingest --------------------------------------------------------------
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
source: str # e.g. "peregrine", "kiwi"
|
||||
task_type: str # e.g. "email_classification", "recipe_suggestion"
|
||||
prompt: str # the prompt that was sent to the LLM
|
||||
response: str # the LLM's original response
|
||||
correction: str # the human-corrected response
|
||||
label: str | None = None # optional label/category
|
||||
|
||||
|
||||
@router.post("/ingest")
|
||||
def post_ingest(
|
||||
req: IngestRequest,
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""Ingest a correction from a sibling CF product.
|
||||
|
||||
Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>
|
||||
|
||||
Creates an sft_candidates record with status='approved' (pre-approved by
|
||||
the calling product -- human review already happened upstream). Also writes
|
||||
to sft_approved.jsonl so it is immediately included in export counts.
|
||||
|
||||
Returns {"ok": True, "id": "<uuid>"}.
|
||||
"""
|
||||
expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
|
||||
if not expected_secret:
|
||||
raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")
|
||||
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(401, "Missing or malformed Authorization header")
|
||||
|
||||
token = authorization.removeprefix("Bearer ").strip()
|
||||
if token != expected_secret:
|
||||
raise HTTPException(403, "Invalid ingestion secret")
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
record = {
|
||||
"id": record_id,
|
||||
"source": req.source,
|
||||
"task_type": req.task_type,
|
||||
"status": "approved",
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": req.response,
|
||||
"corrected_response": req.correction,
|
||||
"label": req.label,
|
||||
"timestamp": now,
|
||||
"benchmark_run_id": None,
|
||||
}
|
||||
append_jsonl(_candidates_file(), record)
|
||||
append_jsonl(_approved_file(), record)
|
||||
return {"ok": True, "id": record_id}
|
||||
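`post_ingest` above compares the bearer token to `AVOCET_INGESTION_SECRET` with `!=`, which can return as soon as the first differing character is found. A minimal sketch of a constant-time variant using the standard library's `hmac.compare_digest` follows. The helper name `bearer_token_matches` is an assumption, and the sketch returns a boolean rather than raising the 401/403/503 responses the route distinguishes.

```python
import hmac


def bearer_token_matches(authorization, expected_secret):
    """Return True only if the header carries exactly the expected secret."""
    if not expected_secret or not authorization:
        return False
    if not authorization.startswith("Bearer "):
        return False
    token = authorization.removeprefix("Bearer ").strip()
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(
        token.encode("utf-8"), expected_secret.encode("utf-8")
    )
```

`str.removeprefix` requires Python 3.9 or later, which the original code already relies on.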
|
|
@ -1,243 +0,0 @@
|
|||
"""Avocet -- IMAP fetch utilities and fetch API routes.
|
||||
|
||||
All IMAP helper functions (from app/imap_fetch.py) plus the
|
||||
/api/accounts/test and /api/fetch/stream endpoints.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import extract_body, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
|
||||
def _get_config_accounts() -> list[dict]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return []
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return raw.get("accounts") or []
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
|
||||
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
|
||||
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
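`entry_key` above hashes the subject plus the first 100 characters of the body, and the fetch generator later in this file keeps a `known_keys` set for deduplication across accounts. The sketch below shows how the two fit together. `entry_key` is copied from the diff; the `dedup` helper and the sample entries are assumptions for illustration.

```python
import hashlib


def entry_key(e: dict) -> str:
    """Stable MD5 content hash: subject plus first 100 chars of body."""
    key = e.get("subject", "") + (e.get("body", "") or "")[:100]
    return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()


def dedup(entries, known_keys):
    """Return unseen entries, adding their keys to known_keys in place."""
    fresh = []
    for entry in entries:
        key = entry_key(entry)
        if key in known_keys:
            continue
        known_keys.add(key)
        fresh.append(entry)
    return fresh


seen = set()
batch = [
    {"subject": "Interview", "body": "x" * 100 + " first tail"},
    {"subject": "Interview", "body": "x" * 100 + " second tail"},
    {"subject": "Offer", "body": None},
]
print(len(dedup(batch, seen)))
```

Because only the first 100 body characters are hashed, two messages that differ later in the body collapse to one entry, as the first two sample entries do here. MD5 is used as a content fingerprint, not for security.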
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
|
||||
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
|
||||
host = acc.get("host", "")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc.get("username", "")
|
||||
password = acc.get("password", "")
|
||||
folder = acc.get("folder", "INBOX")
|
||||
if not host or not username or not password:
|
||||
return False, "Host, username, and password are all required.", None
|
||||
try:
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
_, data = conn.select(folder, readonly=True)
|
||||
count_raw = data[0].decode() if data and data[0] else "0"
|
||||
count = int(count_raw) if count_raw.isdigit() else 0
|
||||
conn.logout()
|
||||
return True, f"Connected — {count:,} message(s) in {folder}.", count
|
||||
except Exception as exc:
|
||||
return False, str(exc), None
|
||||
|
||||
|
||||
def fetch_account_stream(
|
||||
acc: dict,
|
||||
days_back: int,
|
||||
limit: int,
|
||||
known_keys: set[str],
|
||||
) -> Iterator[dict]:
|
||||
"""Generator — yields progress dicts while fetching emails via IMAP.
|
||||
|
||||
Mutates `known_keys` in place for cross-account dedup within one fetch session.
|
||||
|
||||
Yields event dicts with "type" key:
|
||||
{"type": "start", "account": str, "total_uids": int}
|
||||
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
|
||||
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
|
||||
"""
|
||||
name = acc.get("name", acc.get("username", "?"))
|
||||
host = acc.get("host", "imap.gmail.com")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc["username"]
|
||||
password = acc["password"]
|
||||
folder = acc.get("folder", "INBOX")
|
||||
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
conn.select(folder, readonly=True)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
uids = list(seen_uids.keys())[: limit * 3]
|
||||
yield {"type": "start", "account": name, "total_uids": len(uids)}
|
||||
|
||||
emails: list[dict] = []
|
||||
skipped = 0
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if i % 5 == 0:
|
||||
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = extract_body(msg)[:800]
|
||||
entry = {"subject": subj, "body": body, "from_addr": from_addr,
|
||||
"date": date, "account": name}
|
||||
k = entry_key(entry)
|
||||
if k not in known_keys:
|
||||
known_keys.add(k)
|
||||
emails.append(entry)
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception:
|
||||
skipped += 1
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
|
||||
"emails": emails}
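Review note: the generator's contract (start, progress, done events, with the email payload riding on the final event) can be exercised without an IMAP server. The sketch below uses a hypothetical `fake_account_stream` stand-in that yields events in the documented shape, and a `collect` helper (also hypothetical) that detaches the `emails` payload before events are forwarded, the same way `fetch_stream` does with `event.pop("emails", [])`.

```python
from typing import Iterator


def fake_account_stream() -> Iterator[dict]:
    """Hypothetical stand-in yielding events in the documented shape."""
    yield {"type": "start", "account": "demo", "total_uids": 2}
    yield {"type": "progress", "account": "demo", "fetched": 0, "total_uids": 2}
    yield {"type": "done", "account": "demo", "added": 1, "skipped": 1,
           "emails": [{"subject": "Interview", "body": "hi"}]}


def collect(events: Iterator[dict]) -> tuple:
    """Split a stream into UI-safe events and the harvested emails."""
    ui_events, emails = [], []
    for event in events:
        if event["type"] == "done":
            # Detach the payload so message bodies are not sent to the browser.
            emails.extend(event.pop("emails", []))
        ui_events.append(event)
    return ui_events, emails


ui_events, emails = collect(fake_account_stream())
```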
|
||||
|
||||
|
||||
class AccountTestRequest(BaseModel):
|
||||
account: dict
|
||||
|
||||
|
||||
@router.post("/accounts/test")
|
||||
def test_account_route(req: AccountTestRequest) -> dict:
|
||||
ok, message, count = test_connection(req.account)
|
||||
return {"ok": ok, "message": message, "count": count}
|
||||
|
||||
|
||||
@router.get("/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
days_back: int = Query(default=90, ge=1, le=365),
|
||||
limit: int = Query(default=150, ge=1, le=1000),
|
||||
mode: str = Query(default="wide"),
|
||||
) -> StreamingResponse:
|
||||
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
|
||||
all_accounts = _get_config_accounts()
|
||||
selected = [a for a in all_accounts if a.get("name") in selected_names]
|
||||
|
||||
def generate():
|
||||
known_keys = {entry_key(x) for x in read_jsonl(_queue_file())}
|
||||
total_added = 0
|
||||
for acc in selected:
|
||||
try:
|
||||
batch_emails: list[dict] = []
|
||||
for event in fetch_account_stream(acc, days_back, limit, known_keys):
|
||||
if event["type"] == "done":
|
||||
batch_emails = event.pop("emails", [])
|
||||
total_added += event["added"]
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
if batch_emails:
|
||||
existing = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), existing + batch_emails)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'account': acc.get('name', '?'), 'message': str(exc)})}\n\n"
|
||||
queue_size = len(read_jsonl(_queue_file()))
|
||||
yield f"data: {json.dumps({'type': 'complete', 'total_added': total_added, 'queue_size': queue_size})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
|
@ -1,729 +0,0 @@
|
|||
"""Avocet — Imitate tab API.
|
||||
|
||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||
pushed into the SFT corrections queue for human review.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||
|
||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_imitate_config() -> dict:
|
||||
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse imitate config %s: %s", f, exc)
|
||||
return {}
|
||||
return raw.get("imitate", {}) or {}
|
||||
|
||||
|
||||
def _load_cforch_config() -> dict:
    """Read the cforch section of label_tool.yaml (coordinator_url and the ollama_url fallback)."""
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Log parse failures the same way _load_imitate_config() does.
        logger.warning("Failed to parse cforch config %s: %s", f, exc)
        return {}
    return raw.get("cforch", {}) or {}
|
||||
|
||||
|
||||
def _ollama_url(cfg: dict) -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
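Review note: the URL resolution is a three-step fallback (imitate section, then cforch section, then the localhost default). Because it chains with `or`, an empty string counts as "unset". A minimal sketch, assuming both config sections are passed in as plain dicts (`resolve_ollama_url` is a hypothetical name, not part of the module):

```python
def resolve_ollama_url(imitate_cfg: dict, cforch_cfg: dict) -> str:
    # First truthy value wins; "" and None both fall through.
    return (
        imitate_cfg.get("ollama_url")
        or cforch_cfg.get("ollama_url")
        or "http://localhost:11434"
    )


from_imitate = resolve_ollama_url({"ollama_url": "http://a:11434"}, {"ollama_url": "http://b:11434"})
from_cforch = resolve_ollama_url({"ollama_url": ""}, {"ollama_url": "http://b:11434"})
from_default = resolve_ollama_url({}, {})
```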
|
||||
|
||||
|
||||
def _cforch_url() -> str:
|
||||
cforch = _load_cforch_config()
|
||||
return cforch.get("coordinator_url") or "http://localhost:7700"
|
||||
|
||||
|
||||
def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
|
||||
"""Return {model_id, service_type} for a product.task assignment, or None if not found.
|
||||
|
||||
Calls GET coordinator/api/assignments and filters by product+task.
|
||||
The model registry entry is fetched separately to get service_type.
|
||||
Returns None (not raises) — callers emit a 'model_done' error event instead.
|
||||
"""
|
||||
try:
|
||||
asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
|
||||
asgn_resp.raise_for_status()
|
||||
assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
|
||||
match = next(
|
||||
(a for a in assignments if a.get("product") == product and a.get("task") == task),
|
||||
None,
|
||||
)
|
||||
if match is None:
|
||||
return None
|
||||
model_id: str = match.get("model_id", "")
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
# Look up service_type from model registry
|
||||
reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
|
||||
service_type = "cf-text" # sensible default
|
||||
if reg_resp.is_success:
|
||||
models: list[dict] = reg_resp.json().get("models", []) or []
|
||||
reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
|
||||
if reg_entry:
|
||||
service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
|
||||
|
||||
return {"model_id": model_id, "service_type": service_type}
|
||||
except Exception as exc:
|
||||
logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
||||
"""Fetch the live cf-text catalog from cf-orch.
|
||||
|
||||
Filters out proxy entries (ollama://, vllm://, http://) — those models are
|
||||
served by their own services and should not be allocated via cf-text.
|
||||
Returns only models with real file-system paths that cf-text can load directly.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/catalog",
|
||||
params={"node_id": "heimdall"},
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
result = []
|
||||
for model_id, entry in raw.items():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
path = entry.get("path", "")
|
||||
# Skip proxy entries — they're routed through other services
|
||||
if "://" in path:
|
||||
continue
|
||||
result.append({
|
||||
"id": model_id,
|
||||
"vram_mb": entry.get("vram_mb", 0),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("Could not fetch cf-orch catalog: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
|
||||
"""Fetch JSON from url; raise URLError on failure."""
|
||||
req = Request(url, headers={"Accept": "application/json"})
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
|
||||
|
||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
|
||||
"""Return True if the product's health endpoint responds OK."""
|
||||
try:
|
||||
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
|
||||
return bool(data)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_sample(
    raw: Any,
    text_fields: list[str],
    sample_index: int = 0,
    sample_key: str | None = None,
) -> dict[str, Any]:
    """Pull one item from a list or dict response and extract text_fields.

    sample_key: if provided, unwrap raw[sample_key] before looking for a list.
    Falls back to a set of conventional envelope keys if sample_key is absent
    or does not hold a list. Returns {} when the response contains no items.
    """
    _ENVELOPE_KEYS = (
        "samples", "items", "results", "data", "jobs", "listings",
        "pantry", "saved_searches", "entries", "calls", "records",
    )
    item: Any
    if isinstance(raw, list):
        if not raw:
            return {}
        item = raw[min(sample_index, len(raw) - 1)]
    elif isinstance(raw, dict):
        # Use declared sample_key first, then fall back to conventional names.
        search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
        for key in search_keys:
            if key in raw and isinstance(raw[key], list):
                lst = raw[key]
                if not lst:
                    # Empty envelope: report "no items", same as the bare-list case,
                    # so get_sample() can return its 404.
                    return {}
                item = lst[min(sample_index, len(lst) - 1)]
                break
        else:
            # No envelope key matched: treat the dict itself as the item.
            item = raw
    else:
        return {}

    if not isinstance(item, dict):
        # A list of scalars has no named fields to extract.
        return {}

    parts = []
    for field in text_fields:
        val = item.get(field)
        if val and str(val).strip():
            parts.append(f"**{field}**: {val}")
    return {"item": item, "text": "\n\n".join(parts)}
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"data: {json.dumps(data)}\n\n"
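Review note: each event goes out as a single `data:` line followed by a blank line. Since `json.dumps` escapes newlines inside strings, a frame never contains a literal blank line, and a client can split on `"\n\n"`. The round trip can be sketched as below; `parse_sse` is a hypothetical helper that only handles the single-line `data:` frames this module emits, not the full SSE format.

```python
import json


def sse(data: dict) -> str:
    # Same framing as the module's _sse() helper.
    return f"data: {json.dumps(data)}\n\n"


def parse_sse(stream: str) -> list:
    """Minimal parser for single-line `data:` frames only."""
    events = []
    for frame in stream.split("\n\n"):
        if frame.startswith("data: "):
            events.append(json.loads(frame[len("data: "):]))
    return events


wire = sse({"type": "model_start", "model": "m1"}) + sse({"type": "complete", "results": []})
events = parse_sse(wire)
```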
|
||||
|
||||
|
||||
def _fetch_image_b64(image_url: str) -> str:
|
||||
"""Download an image URL and return it as a base64 string for ollama.
|
||||
|
||||
Returns empty string on any failure — a missing image is non-fatal;
|
||||
the model will still run against the text prompt alone.
|
||||
"""
|
||||
try:
|
||||
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
|
||||
with urlopen(req, timeout=10) as resp:
|
||||
return base64.b64encode(resp.read()).decode("ascii")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch image %s: %s", image_url, exc)
|
||||
return ""
|
||||
|
||||
|
||||
def _run_ollama_streaming(
    ollama_base: str,
    model_id: str,
    prompt: str,
    temperature: float,
    system: str = "",
    images: list[str] | None = None,
) -> tuple[str, int]:
    """Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).

    Despite the name, this blocks until the model finishes and yields nothing;
    streaming to the client is handled by the SSE generator in run_imitate().

    system: optional system prompt passed as a separate field to ollama.
    images: list of base64-encoded image strings (vision models only).

    Raises RuntimeError on any request or decoding failure.
    """
    url = f"{ollama_base.rstrip('/')}/api/generate"
    body: dict = {
        "model": model_id,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature},
    }
    if system:
        body["system"] = system
    if images:
        body["images"] = images
    payload = json.dumps(body).encode("utf-8")
    req = Request(url, data=payload, method="POST",
                  headers={"Content-Type": "application/json"})
    t0 = time.time()
    try:
        with urlopen(req, timeout=120) as resp:
            # Separate name so the request dict `body` is not shadowed by the reply.
            reply = json.loads(resp.read().decode("utf-8"))
        elapsed = int((time.time() - t0) * 1000)
        return reply.get("response", ""), elapsed
    except Exception as exc:
        raise RuntimeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _run_cftext(
|
||||
cforch_base: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str,
|
||||
temperature: float,
|
||||
startup_timeout_s: float = 180.0,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, int, bool]:
|
||||
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
|
||||
|
||||
Raises RuntimeError on allocation failure or generation error.
|
||||
cold_started=True means the service was launched from scratch (caller may log this).
|
||||
|
||||
Cold-start detection uses coordinator state signals (running/stopped) rather than
|
||||
polling the service health endpoint — this fails fast on model load errors instead
|
||||
of waiting out the full timeout.
|
||||
"""
|
||||
# Allocate
|
||||
alloc_resp = httpx.post(
|
||||
f"{cforch_base}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "imitate",
|
||||
**({"user_id": user_id} if user_id else {}),
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
alloc_resp.raise_for_status()
|
||||
data = alloc_resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
cold_started = data.get("started", False) and not data.get("warm", True)
|
||||
|
||||
# Wait for ready using coordinator state signals
|
||||
if cold_started:
|
||||
deadline = time.monotonic() + startup_timeout_s
|
||||
probe_misses = 0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
status = httpx.get(
|
||||
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
|
||||
)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
break
|
||||
elif state == "stopped":
|
||||
if allocation_id:
|
||||
httpx.delete(
|
||||
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
|
||||
timeout=5.0,
|
||||
)
|
||||
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
|
||||
else:
|
||||
probe_misses += 1
|
||||
if probe_misses >= 6:
|
||||
# Coordinator hasn't registered instance yet — fall back to health poll
|
||||
try:
|
||||
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2.0)
|
||||
else:
|
||||
if allocation_id:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
|
||||
|
||||
# Generate
|
||||
messages: list[dict] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
gen_resp = httpx.post(
|
||||
f"{service_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": 300,
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
gen_resp.raise_for_status()
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
content = gen_resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms, cold_started
|
||||
except Exception as exc:
|
||||
elapsed_ms = int((time.time() - t0) * 1000)
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
finally:
|
||||
if allocation_id:
|
||||
try:
|
||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/products")
|
||||
def get_products() -> dict:
|
||||
"""List configured CF products with live online status."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
products = []
|
||||
for p in products_raw:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
base_url = p.get("base_url", "")
|
||||
products.append({
|
||||
"id": p.get("id", ""),
|
||||
"name": p.get("name", ""),
|
||||
"icon": p.get("icon", "📦"),
|
||||
"description": p.get("description", ""),
|
||||
"base_url": base_url,
|
||||
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
||||
})
|
||||
return {"products": products}
|
||||
|
||||
|
||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||
|
||||
@router.get("/products/{product_id}/sample")
|
||||
def get_sample(product_id: str, index: int = 0) -> dict:
|
||||
"""Fetch a real sample from the given product's API."""
|
||||
cfg = _load_imitate_config()
|
||||
products_raw = cfg.get("products", []) or []
|
||||
|
||||
product: dict | None = None
|
||||
for p in products_raw:
|
||||
if isinstance(p, dict) and p.get("id") == product_id:
|
||||
product = p
|
||||
break
|
||||
|
||||
if product is None:
|
||||
raise HTTPException(404, f"Product '{product_id}' not in config")
|
||||
|
||||
base_url = product.get("base_url", "").rstrip("/")
|
||||
endpoint = product.get("sample_endpoint", "")
|
||||
if not base_url or not endpoint:
|
||||
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
||||
|
||||
url = f"{base_url}{endpoint}"
|
||||
try:
|
||||
raw = _http_get_json(url, timeout=5)
|
||||
except URLError as exc:
|
||||
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
||||
|
||||
text_fields = product.get("text_fields", []) or []
|
||||
sample_key = product.get("sample_key") or None
|
||||
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
|
||||
if not extracted:
|
||||
raise HTTPException(404, "No sample items returned by product API")
|
||||
|
||||
prompt_template = product.get("prompt_template", "{text}")
|
||||
prompt = prompt_template.replace("{text}", extracted["text"])
|
||||
# Also substitute any {field_name} placeholders from the raw item fields.
|
||||
item = extracted.get("item", {})
|
||||
for field, val in item.items():
|
||||
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
|
||||
|
||||
# Expose system_prompt and image_url if the product API returns them.
|
||||
# system_prompt: Peregrine, Snipe (vision analysis instructions)
|
||||
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
|
||||
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
|
||||
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
|
||||
|
||||
return {
|
||||
"product_id": product_id,
|
||||
"sample_index": index,
|
||||
"text": extracted["text"],
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt,
|
||||
"image_url": image_url,
|
||||
"raw_item": item,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
|
||||
def get_catalog() -> dict:
|
||||
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
||||
models = _cforch_catalog(_cforch_url())
|
||||
return {"models": models}
|
||||
|
||||
|
||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
||||
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
||||
try:
|
||||
from app.cloud_session import get_session
|
||||
return get_session(request, response)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/run")
|
||||
def run_imitate(
|
||||
prompt: str = "",
|
||||
model_ids: str = "", # comma-separated ollama model IDs
|
||||
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
||||
task_ids: str = "", # comma-separated "product/task" strings — resolved via assignments
|
||||
temperature: float = 0.7,
|
||||
product_id: str = "",
|
||||
system: str = "", # optional system prompt
|
||||
image_url: str = "", # optional image URL for vision models
|
||||
session: "Any" = Depends(_get_imitate_session),
|
||||
) -> StreamingResponse:
|
||||
"""Run a prompt through selected models and stream results as SSE.
|
||||
|
||||
Models can be selected three ways (combinable):
|
||||
- model_ids: explicit ollama model IDs
|
||||
- cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
|
||||
- task_ids: "product/task" strings resolved via the coordinator assignments table
|
||||
|
||||
If image_url is provided, the image is downloaded once and passed to every
|
||||
model as a base64-encoded blob — allowing vision-capable local models to
|
||||
evaluate listing photos the same way Snipe's background task pipeline does.
|
||||
"""
|
||||
|
||||
if not prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
|
||||
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
||||
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
||||
raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
|
||||
|
||||
# Resolve task assignments to concrete model IDs, routing to the right service.
|
||||
# Models that fail to resolve emit an error event at run time (non-fatal).
|
||||
if raw_task_ids:
|
||||
cforch_base = _cforch_url()
|
||||
for task_spec in raw_task_ids:
|
||||
parts = task_spec.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
|
||||
continue
|
||||
product_name, task_name = parts
|
||||
resolved = _resolve_task_model(cforch_base, product_name, task_name)
|
||||
if resolved is None:
|
||||
logger.warning("No assignment found for task %r", task_spec)
|
||||
# Record the unresolved task as a sentinel in cftext_ids; generate() turns it into a model_done error event.
|
||||
cftext_ids.append(f"__task_unresolved__:{task_spec}")
|
||||
continue
|
||||
mid = resolved["model_id"]
|
||||
svc = resolved["service_type"]
|
||||
if svc == "ollama":
|
||||
if mid not in ollama_ids:
|
||||
ollama_ids.append(mid)
|
||||
else:
|
||||
# cf-text, vllm, and any other cf-orch-managed service
|
||||
if mid not in cftext_ids:
|
||||
cftext_ids.append(mid)
|
||||
|
||||
if not ollama_ids and not cftext_ids:
|
||||
raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
|
||||
|
||||
cfg = _load_imitate_config()
|
||||
ollama_base = _ollama_url(cfg)
|
||||
cforch_base = _cforch_url()
|
||||
system_ctx = system.strip() or ""
|
||||
total_models = len(ollama_ids) + len(cftext_ids)
|
||||
|
||||
# Download image once before streaming — shared across ollama vision models
|
||||
images: list[str] = []
|
||||
if image_url.strip():
|
||||
b64 = _fetch_image_b64(image_url.strip())
|
||||
if b64:
|
||||
images = [b64]
|
||||
|
||||
def generate():
|
||||
results: list[dict] = []
|
||||
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
|
||||
|
||||
# Ollama models
|
||||
for model_id in ollama_ids:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
|
||||
try:
|
||||
response, elapsed_ms = _run_ollama_streaming(
|
||||
ollama_base, model_id, prompt, temperature,
|
||||
system=system_ctx, images=images or None,
|
||||
)
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
||||
# Partition the list: real cf-text IDs vs unresolved-task sentinels.
|
||||
cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
|
||||
cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
|
||||
for sentinel in cftext_unresolved:
|
||||
task_spec = sentinel.split(":", 1)[1]
|
||||
result = {
|
||||
"model": task_spec,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": f"No assignment configured for task '{task_spec}'",
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
if cftext_real:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Announce all models upfront so the UI can show loading states immediately
|
||||
for model_id in cftext_real:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
||||
|
||||
_user_id: str | None = getattr(session, "user_id", None)
|
||||
# Only forward real cloud user IDs — skip local/anon sessions
|
||||
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
|
||||
_user_id = None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
|
||||
future_to_model = {
|
||||
pool.submit(
|
||||
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
|
||||
180.0, _user_id,
|
||||
): mid
|
||||
for mid in cftext_real
|
||||
}
|
||||
for future in as_completed(future_to_model):
|
||||
model_id = future_to_model[future]
|
||||
try:
|
||||
response, elapsed_ms, cold_started = future.result()
|
||||
if cold_started:
|
||||
yield _sse({"type": "model_coldstart", "model": model_id})
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": response,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"model": model_id,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
yield _sse({"type": "complete", "results": results})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||
|
||||
class ImitateResult(BaseModel):
|
||||
model: str
|
||||
response: str
|
||||
elapsed_ms: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PushCorrectionsRequest(BaseModel):
|
||||
product_id: str
|
||||
prompt: str
|
||||
results: list[ImitateResult]
|
||||
|
||||
|
||||
@router.post("/push-corrections")
|
||||
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
||||
"""Append imitate results to sft_candidates.jsonl for human review."""
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(422, "prompt is required")
|
||||
if not req.results:
|
||||
raise HTTPException(422, "results list is empty")
|
||||
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
records = []
|
||||
for r in req.results:
|
||||
if r.error or not r.response.strip():
|
||||
continue
|
||||
records.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"source": "imitate",
|
||||
"product_id": req.product_id,
|
||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||
"model_response": r.response,
|
||||
"model_id": r.model,
|
||||
"elapsed_ms": r.elapsed_ms,
|
||||
"status": "pending",
|
||||
"created_at": ts,
|
||||
})
|
||||
|
||||
if not records:
|
||||
raise HTTPException(422, "No non-error results to push")
|
||||
|
||||
dest = _candidates_file()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
for record in records:
|
||||
append_jsonl(dest, record)
|
||||
|
||||
return {"pushed": len(records)}
|
||||
|
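Review note: the streaming generator earlier in this file fans out one request per model and emits a server-sent event as each future completes. The sketch below isolates that pattern. It is a minimal illustration under assumptions: `sse`, `stream_results`, and `fake_model` are hypothetical stand-ins (the real `_sse` and `_run_cftext` helpers are not shown in this hunk), and the payload is trimmed to three fields.

```python
import json
from concurrent.futures import ThreadPoolExecutor, as_completed


def sse(payload: dict) -> str:
    # Hypothetical stand-in for the file's _sse helper: one SSE "data:" frame.
    return f"data: {json.dumps(payload)}\n\n"


def stream_results(model_ids, run_model):
    """Yield one SSE frame per finished model, then a final summary frame."""
    results = []
    # Note: max_workers=0 raises ValueError, so callers must pass at least one model.
    with ThreadPoolExecutor(max_workers=len(model_ids)) as pool:
        future_to_model = {pool.submit(run_model, mid): mid for mid in model_ids}
        for future in as_completed(future_to_model):
            model_id = future_to_model[future]
            try:
                result = {"model": model_id, "response": future.result(), "error": None}
            except Exception as exc:
                # A failing model becomes an error record instead of aborting the stream.
                result = {"model": model_id, "response": "", "error": str(exc)}
            results.append(result)
            yield sse({"type": "model_done", **result})
    yield sse({"type": "complete", "results": results})


def fake_model(model_id: str) -> str:
    # Hypothetical worker used only for this demonstration.
    if model_id == "broken":
        raise RuntimeError("boom")
    return f"hello from {model_id}"


events = [
    json.loads(frame.removeprefix("data: "))
    for frame in stream_results(["m1", "broken"], fake_model)
]
```

Frames arrive in completion order, not submission order, so a client should key on the `model` field and not on position.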
|
@@ -1,222 +0,0 @@
|
|||
"""Avocet -- label queue API.
|
||||
|
||||
All label/skip/discard/undo/stats/config endpoints.
|
||||
Extracted from app/api.py as part of the v2 domain split.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import yaml
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_last_action: dict | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
def reset_last_action() -> None:
|
||||
global _last_action
|
||||
_last_action = None
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
def _queue_file() -> Path:
|
||||
return _DATA_DIR / "email_label_queue.jsonl"
|
||||
|
||||
def _score_file() -> Path:
|
||||
return _DATA_DIR / "email_score.jsonl"
|
||||
|
||||
def _discarded_file() -> Path:
|
||||
return _DATA_DIR / "discarded.jsonl"
|
||||
|
||||
def _item_id(item: dict) -> str:
|
||||
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
|
||||
|
||||
def _normalize(item: dict) -> dict:
|
||||
return {
|
||||
"id": item.get("id") or _item_id(item),
|
||||
"subject": item.get("subject", ""),
|
||||
"body": item.get("body", ""),
|
||||
"from": item.get("from") or item.get("from_addr", ""),
|
||||
"date": item.get("date", ""),
|
||||
"source": item.get("source") or item.get("account", ""),
|
||||
}
|
||||
|
||||
_LABEL_META = [
|
||||
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
|
||||
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
|
||||
{"name": "rejected", "emoji": "❌", "color": "#F44336", "key": "3"},
|
||||
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
|
||||
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
|
||||
{"name": "neutral", "emoji": "⬜", "color": "#607D8B", "key": "6"},
|
||||
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
|
||||
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
|
||||
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
|
||||
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
|
||||
]
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
|
||||
items = read_jsonl(_queue_file())
|
||||
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
|
||||
|
||||
class LabelRequest(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
@router.post("/label")
|
||||
def post_label(req: LabelRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": req.label,
|
||||
"labeled_at": datetime.now(timezone.utc).isoformat()}
|
||||
append_jsonl(_score_file(), record)
|
||||
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "label", "item": match, "label": req.label}
|
||||
return {"ok": True}
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
@router.post("/skip")
|
||||
def post_skip(req: SkipRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
|
||||
write_jsonl(_queue_file(), reordered)
|
||||
_last_action = {"type": "skip", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
class DiscardRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
@router.post("/discard")
|
||||
def post_discard(req: DiscardRequest):
|
||||
global _last_action
|
||||
items = read_jsonl(_queue_file())
|
||||
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
|
||||
if not match:
|
||||
raise HTTPException(404, f"Item {req.id!r} not found in queue")
|
||||
record = {**match, "label": "__discarded__",
|
||||
"discarded_at": datetime.now(timezone.utc).isoformat()}
|
||||
append_jsonl(_discarded_file(), record)
|
||||
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
|
||||
_last_action = {"type": "discard", "item": match}
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/label/undo")
|
||||
def delete_undo():
|
||||
global _last_action
|
||||
if not _last_action:
|
||||
raise HTTPException(404, "No action to undo")
|
||||
action = _last_action
|
||||
item = action["item"]
|
||||
if action["type"] == "label":
|
||||
records = read_jsonl(_score_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Score file is empty -- cannot undo label")
|
||||
write_jsonl(_score_file(), records[:-1])
|
||||
items = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "discard":
|
||||
records = read_jsonl(_discarded_file())
|
||||
if not records:
|
||||
raise HTTPException(409, "Discarded file is empty -- cannot undo discard")
|
||||
write_jsonl(_discarded_file(), records[:-1])
|
||||
items = read_jsonl(_queue_file())
|
||||
write_jsonl(_queue_file(), [item] + items)
|
||||
elif action["type"] == "skip":
|
||||
items = read_jsonl(_queue_file())
|
||||
item_id = _normalize(item)["id"]
|
||||
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
|
||||
write_jsonl(_queue_file(), items)
|
||||
_last_action = None
|
||||
return {"undone": {"type": action["type"], "item": _normalize(item)}}
|
||||
|
||||
@router.get("/config/labels")
|
||||
def get_labels():
|
||||
return _LABEL_META
|
||||
|
||||
@router.get("/config")
|
||||
def get_config():
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"accounts": [], "max_per_account": 500}
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
|
||||
|
||||
class ConfigPayload(BaseModel):
|
||||
accounts: list[dict]
|
||||
max_per_account: int = 500
|
||||
|
||||
@router.post("/config")
|
||||
def post_config(payload: ConfigPayload):
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8")
|
||||
tmp.replace(f)  # replace() overwrites an existing target on every platform; rename() fails on Windows
|
||||
return {"ok": True}
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats():
|
||||
records = read_jsonl(_score_file())
|
||||
counts: dict[str, int] = {}
|
||||
for r in records:
|
||||
lbl = r.get("label", "")
|
||||
if lbl:
|
||||
counts[lbl] = counts.get(lbl, 0) + 1
|
||||
benchmark_results: dict = {}
|
||||
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"total": len(records),
|
||||
"counts": counts,
|
||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||
"benchmark_results": benchmark_results,
|
||||
}
|
||||
|
||||
@router.get("/stats/download")
|
||||
def download_stats():
|
||||
if not _score_file().exists():
|
||||
raise HTTPException(404, "No score file")
|
||||
return FileResponse(
|
||||
str(_score_file()),
|
||||
filename="email_score.jsonl",
|
||||
media_type="application/jsonlines",
|
||||
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
|
||||
)
|
||||
|
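Review note: the label and undo endpoints in this file keep a single-level undo record and move items between JSONL files. The sketch below reproduces that round trip in isolation. It is a simplified illustration under assumptions: the three JSONL helpers are hypothetical stand-ins for `app.utils` (imported above but not shown in this diff), and items are matched by a plain `id` field instead of going through `_normalize`.

```python
import json
import tempfile
from pathlib import Path


# Hypothetical stand-ins for app.utils.read_jsonl / write_jsonl / append_jsonl.
def read_jsonl(path: Path) -> list[dict]:
    if not path.exists():
        return []
    text = path.read_text(encoding="utf-8")
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def write_jsonl(path: Path, rows: list[dict]) -> None:
    path.write_text("".join(json.dumps(r) + "\n" for r in rows), encoding="utf-8")


def append_jsonl(path: Path, row: dict) -> None:
    with path.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(row) + "\n")


def label(queue: Path, scores: Path, item_id: str, label_name: str) -> dict:
    """Move one item from the queue to the score file and return the undo record."""
    items = read_jsonl(queue)
    # The endpoint answers 404 when nothing matches; this sketch would raise StopIteration.
    match = next(x for x in items if x["id"] == item_id)
    append_jsonl(scores, {**match, "label": label_name})
    write_jsonl(queue, [x for x in items if x["id"] != item_id])
    return {"type": "label", "item": match}


def undo(queue: Path, scores: Path, action: dict) -> None:
    """Reverse a label: drop the last score record and put the item at the queue front."""
    write_jsonl(scores, read_jsonl(scores)[:-1])
    write_jsonl(queue, [action["item"]] + read_jsonl(queue))


def demo() -> tuple[list[str], int]:
    with tempfile.TemporaryDirectory() as tmp:
        queue, scores = Path(tmp) / "queue.jsonl", Path(tmp) / "scores.jsonl"
        write_jsonl(queue, [{"id": "a"}, {"id": "b"}])
        action = label(queue, scores, "b", "neutral")
        undo(queue, scores, action)
        return [x["id"] for x in read_jsonl(queue)], len(read_jsonl(scores))
```

As in the endpoint, an undone item returns to the front of the queue, not to its original position, and only the most recent action can be reversed.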
|
@@ -1,462 +0,0 @@
|
|||
"""Avocet — Log Corpus receiver and labeling API.
|
||||
|
||||
Receives push batches from consented Turnstone nodes, stores entries for labeling,
|
||||
and exports labeled data as JSONL for the logreading fine-tune pipeline.
|
||||
|
||||
DB: data/corpus.db (separate from train_jobs.db — different lifecycle)
|
||||
Auth: Bearer token validated against corpus_sources table (seeded from label_tool.yaml).
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with prefix="/api/corpus".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_DB_PATH: Path = _ROOT / "data" / "corpus.db"
|
||||
|
||||
_PIPELINE_SOURCE_HOST = "pipeline_scrape"
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS corpus_sources (
|
||||
token TEXT PRIMARY KEY,
|
||||
source_host TEXT NOT NULL,
|
||||
owner TEXT NOT NULL,
|
||||
consent_date TEXT NOT NULL,
|
||||
consent_method TEXT NOT NULL,
|
||||
active INTEGER NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS corpus_batches (
|
||||
id TEXT PRIMARY KEY,
|
||||
source_host TEXT NOT NULL,
|
||||
batch_type TEXT NOT NULL,
|
||||
received_at TEXT NOT NULL,
|
||||
entry_count INTEGER NOT NULL,
|
||||
watermark_from TEXT,
|
||||
watermark_to TEXT,
|
||||
raw_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS corpus_entries (
|
||||
id TEXT PRIMARY KEY,
|
||||
batch_id TEXT NOT NULL REFERENCES corpus_batches(id),
|
||||
source_host TEXT NOT NULL,
|
||||
origin_entry_id TEXT,
|
||||
timestamp_iso TEXT,
|
||||
severity TEXT,
|
||||
source_id TEXT,
|
||||
text TEXT NOT NULL,
|
||||
matched_patterns TEXT DEFAULT '[]',
|
||||
label_state TEXT NOT NULL DEFAULT 'unlabeled',
|
||||
failure_type TEXT,
|
||||
plain_explanation TEXT,
|
||||
known_pattern TEXT,
|
||||
labeled_at TEXT,
|
||||
labeled_by TEXT DEFAULT 'alan',
|
||||
pii_flagged INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_label_state ON corpus_entries(label_state);
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_source ON corpus_entries(source_host);
|
||||
CREATE INDEX IF NOT EXISTS idx_ce_severity ON corpus_entries(severity);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ingested_pipeline_files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
ingested_at TEXT NOT NULL,
|
||||
entry_count INTEGER NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR, _DB_PATH
|
||||
_DATA_DIR = path
|
||||
_DB_PATH = path / "corpus.db"
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
with _db() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
_seed_sources(conn)
|
||||
|
||||
|
||||
def _pipeline_ingest_dir() -> Path | None:
|
||||
"""Return the configured pipeline log ingest directory, or None if unset."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return None
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return None
|
||||
val = (raw.get("corpus") or {}).get("pipeline_ingest_dir", "") or ""  # tolerate an empty `corpus:` key
|
||||
return Path(val) if val else None
|
||||
|
||||
|
||||
def _load_corpus_config() -> list[dict]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return []
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse corpus config: %s", exc)
|
||||
return []
|
||||
return (raw.get("corpus") or {}).get("sources", []) or []  # tolerate an empty `corpus:` key
|
||||
|
||||
|
||||
def _seed_sources(conn: sqlite3.Connection) -> None:
|
||||
for src in _load_corpus_config():
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_sources (token, source_host, owner, consent_date, consent_method) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(src["token"], src["source_host"], src["owner"],
|
||||
src["consent_date"], src["consent_method"]),
|
||||
)
|
||||
|
||||
|
||||
def _validate_token(token: str, conn: sqlite3.Connection) -> str:
|
||||
"""Return source_host for token, or raise 403."""
|
||||
row = conn.execute(
|
||||
"SELECT source_host FROM corpus_sources WHERE token = ? AND active = 1",
|
||||
(token,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=403, detail="Unknown or revoked consent token")
|
||||
return row["source_host"]
|
||||
|
||||
|
||||
def _extract_bearer(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Bearer token required")
|
||||
return auth.removeprefix("Bearer ").strip()
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
# ── Startup ────────────────────────────────────────────────────────────────────
|
||||
|
||||
_init_db()
|
||||
|
||||
|
||||
# ── POST /api/corpus/log-batch ─────────────────────────────────────────────────
|
||||
|
||||
@router.post("/log-batch")
|
||||
def receive_batch(request: Request, payload: dict) -> dict:
|
||||
"""Accept a push batch from a Turnstone node."""
|
||||
token = _extract_bearer(request)
|
||||
|
||||
batch_type = payload.get("batch_type", "raw_entries")
|
||||
entries_raw = payload.get("entries", [])
|
||||
batch_id = payload.get("batch_id") or str(uuid.uuid4())
|
||||
|
||||
with _db() as conn:
|
||||
source_host = _validate_token(token, conn)
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, "
|
||||
"watermark_from, watermark_to, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(batch_id, source_host, batch_type, _now_iso(), len(entries_raw),
|
||||
str(payload.get("watermark_from", "")),
|
||||
str(payload.get("watermark_to", "")),
|
||||
json.dumps(payload)),
|
||||
)
|
||||
|
||||
stored = 0
|
||||
for entry in entries_raw:
|
||||
text = entry.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_entries "
|
||||
"(id, batch_id, source_host, origin_entry_id, timestamp_iso, severity, "
|
||||
"source_id, text, matched_patterns) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), batch_id, source_host,
|
||||
entry.get("entry_id") or entry.get("id"),
|
||||
entry.get("timestamp_iso"),
|
||||
entry.get("severity"),
|
||||
entry.get("source_id"),
|
||||
text,
|
||||
json.dumps(entry.get("matched_patterns", []))),
|
||||
)
|
||||
stored += 1
|
||||
|
||||
logger.info("Received batch %s from %s: %d/%d entries stored",
|
||||
batch_id, source_host, stored, len(entries_raw))
|
||||
return {"received": True, "batch_id": batch_id, "entries_stored": stored}
|
||||
|
||||
|
||||
# ── GET /api/corpus/entries ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/entries")
|
||||
def list_entries(
|
||||
state: str = "unlabeled",
|
||||
source_host: str | None = None,
|
||||
limit: int = 25,
|
||||
) -> dict:
|
||||
"""Return entries for labeling. Default: unlabeled entries, oldest first."""
|
||||
with _db() as conn:
|
||||
query = "SELECT * FROM corpus_entries WHERE label_state = ?"
|
||||
params: list = [state]
|
||||
if source_host:
|
||||
query += " AND source_host = ?"
|
||||
params.append(source_host)
|
||||
query += " ORDER BY rowid LIMIT ?"
|
||||
params.append(min(limit, 100))
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return {"entries": [dict(r) for r in rows], "count": len(rows)}
|
||||
|
||||
|
||||
# ── POST /api/corpus/entries/{id}/label ───────────────────────────────────────
|
||||
|
||||
@router.post("/entries/{entry_id}/label")
|
||||
def label_entry(entry_id: str, body: dict) -> dict:
|
||||
"""Submit a label for a corpus entry."""
|
||||
failure_type = body.get("failure_type")
|
||||
plain_explanation = body.get("plain_explanation", "").strip()
|
||||
known_pattern = body.get("known_pattern")
|
||||
pii_flagged = int(bool(body.get("pii_flagged", False)))
|
||||
|
||||
if not failure_type:
|
||||
raise HTTPException(status_code=422, detail="failure_type is required")
|
||||
valid_types = {"hardware", "software", "network", "security", "application", "none", "other"}
|
||||
if failure_type not in valid_types:
|
||||
raise HTTPException(status_code=422, detail=f"failure_type must be one of {sorted(valid_types)}")
|
||||
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Entry not found")
|
||||
conn.execute(
|
||||
"UPDATE corpus_entries SET label_state='labeled', failure_type=?, plain_explanation=?, "
|
||||
"known_pattern=?, labeled_at=?, pii_flagged=? WHERE id=?",
|
||||
(failure_type, plain_explanation, known_pattern, _now_iso(), pii_flagged, entry_id),
|
||||
)
|
||||
return {"labeled": True, "entry_id": entry_id}
|
||||
|
||||
|
||||
# ── POST /api/corpus/entries/{id}/skip ────────────────────────────────────────
|
||||
|
||||
@router.post("/entries/{entry_id}/skip")
|
||||
def skip_entry(entry_id: str) -> dict:
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Entry not found")
|
||||
conn.execute(
|
||||
"UPDATE corpus_entries SET label_state='skipped' WHERE id=?", (entry_id,)
|
||||
)
|
||||
return {"skipped": True, "entry_id": entry_id}
|
||||
|
||||
|
||||
# ── GET /api/corpus/stats ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict:
|
||||
with _db() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM corpus_entries").fetchone()[0]
|
||||
by_state = {
|
||||
r["label_state"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT label_state, COUNT(*) AS cnt FROM corpus_entries GROUP BY label_state"
|
||||
).fetchall()
|
||||
}
|
||||
by_source = {
|
||||
r["source_host"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT source_host, COUNT(*) AS cnt FROM corpus_entries GROUP BY source_host"
|
||||
).fetchall()
|
||||
}
|
||||
by_severity = {
|
||||
r["severity"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT severity, COUNT(*) AS cnt FROM corpus_entries "
|
||||
"WHERE severity IS NOT NULL GROUP BY severity"
|
||||
).fetchall()
|
||||
}
|
||||
batch_count = conn.execute("SELECT COUNT(*) FROM corpus_batches").fetchone()[0]
|
||||
return {
|
||||
"total_entries": total,
|
||||
"batch_count": batch_count,
|
||||
"by_label_state": by_state,
|
||||
"by_source": by_source,
|
||||
"by_severity": by_severity,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /api/corpus/export ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def export_labeled() -> StreamingResponse:
|
||||
"""Stream labeled, non-PII entries as JSONL for SFT harness."""
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT source_host, source_id, severity, text, failure_type, plain_explanation, known_pattern "
|
||||
"FROM corpus_entries "
|
||||
"WHERE label_state = 'labeled' AND pii_flagged = 0 AND plain_explanation != '' "
|
||||
"ORDER BY rowid"
|
||||
).fetchall()
|
||||
|
||||
def _generate():
|
||||
for row in rows:
|
||||
record = {
|
||||
"input": row["text"],
|
||||
"output": row["plain_explanation"],
|
||||
"metadata": {
|
||||
"failure_type": row["failure_type"],
|
||||
"source": row["source_host"],
|
||||
"source_id": row["source_id"],
|
||||
"severity": row["severity"],
|
||||
"known_pattern": row["known_pattern"],
|
||||
},
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": "attachment; filename=log_corpus_labeled.jsonl"},
|
||||
)
|
||||
|
||||
|
||||
# ── POST /api/corpus/pipeline-ingest ─────────────────────────────────────────
|
||||
|
||||
def _ingest_one_file(conn: sqlite3.Connection, path: Path) -> int:
|
||||
"""Parse a pipeline JSONL file and insert entries. Returns count stored."""
|
||||
batch_id = str(uuid.uuid4())
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
entries_raw: list[dict] = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
entries_raw.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed line in %s", path.name)
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, raw_json) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(batch_id, _PIPELINE_SOURCE_HOST, "pipeline_log", _now_iso(),
|
||||
len(entries_raw), json.dumps({"file": path.name})),
|
||||
)
|
||||
|
||||
stored = 0
|
||||
for entry in entries_raw:
|
||||
text = (entry.get("msg") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO corpus_entries "
|
||||
"(id, batch_id, source_host, timestamp_iso, severity, source_id, text, matched_patterns) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), batch_id, _PIPELINE_SOURCE_HOST,
|
||||
entry.get("ts"),
|
||||
entry.get("level"),
|
||||
entry.get("logger"),
|
||||
text,
|
||||
json.dumps([entry["extra"]] if entry.get("extra") else [])),
|
||||
)
|
||||
stored += 1
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO ingested_pipeline_files (filename, ingested_at, entry_count) VALUES (?, ?, ?)",
|
||||
(path.name, _now_iso(), stored),
|
||||
)
|
||||
return stored
|
||||
|
||||
|
||||
@router.post("/pipeline-ingest")
|
||||
def pipeline_ingest() -> dict:
|
||||
"""Walk the configured pipeline log directory and ingest new JSONL files.
|
||||
|
||||
Skips files already recorded in ingested_pipeline_files. Safe to call
|
||||
repeatedly — idempotent by filename.
|
||||
"""
|
||||
ingest_dir = _pipeline_ingest_dir()
|
||||
if ingest_dir is None:
|
||||
raise HTTPException(404, "pipeline_ingest_dir not configured in label_tool.yaml")
|
||||
|
||||
ingested = 0
|
||||
skipped = 0
|
||||
total_stored = 0
|
||||
files_detail: list[dict] = []
|
||||
|
||||
with _db() as conn:
|
||||
already_done: set[str] = {
|
||||
row[0]
|
||||
for row in conn.execute("SELECT filename FROM ingested_pipeline_files").fetchall()
|
||||
}
|
||||
|
||||
for path in sorted(ingest_dir.glob("*.jsonl")):
|
||||
if path.name in already_done:
|
||||
skipped += 1
|
||||
continue
|
||||
stored = _ingest_one_file(conn, path)
|
||||
ingested += 1
|
||||
total_stored += stored
|
||||
files_detail.append({"file": path.name, "entries_stored": stored})
|
||||
|
||||
logger.info("Pipeline ingest: %d files ingested, %d skipped, %d entries stored",
|
||||
ingested, skipped, total_stored)
|
||||
return {
|
||||
"ingested_files": ingested,
|
||||
"skipped_files": skipped,
|
||||
"entries_stored": total_stored,
|
||||
"files": files_detail,
|
||||
}
|
||||
|
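Review note: `pipeline_ingest` stays idempotent by recording each processed filename and skipping it on later calls. The sketch below shows that bookkeeping on an in-memory SQLite database. It is a reduced illustration under assumptions: the two-table schema, the `ingest` function, and the dict of filename to lines are hypothetical simplifications of the real tables and directory walk.

```python
import sqlite3

# Hypothetical, trimmed-down schema; the real tables carry more columns.
SCHEMA = """
CREATE TABLE ingested_files (filename TEXT PRIMARY KEY, entry_count INTEGER NOT NULL);
CREATE TABLE entries (id INTEGER PRIMARY KEY, filename TEXT NOT NULL, text TEXT NOT NULL);
"""


def ingest(conn: sqlite3.Connection, files: dict[str, list[str]]) -> dict:
    """Store entries from files not seen before; skip filenames already recorded."""
    done = {row[0] for row in conn.execute("SELECT filename FROM ingested_files")}
    ingested = skipped = stored = 0
    for name in sorted(files):
        if name in done:
            skipped += 1
            continue
        # Blank entries are dropped, mirroring the empty-text check in the endpoint.
        texts = [t.strip() for t in files[name] if t.strip()]
        conn.executemany(
            "INSERT INTO entries (filename, text) VALUES (?, ?)",
            [(name, t) for t in texts],
        )
        conn.execute("INSERT INTO ingested_files VALUES (?, ?)", (name, len(texts)))
        ingested += 1
        stored += len(texts)
    conn.commit()
    return {"ingested_files": ingested, "skipped_files": skipped, "entries_stored": stored}


def demo() -> tuple[dict, dict]:
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA)
    files = {"a.jsonl": ["boot ok", ""], "b.jsonl": ["disk full"]}
    first = ingest(conn, files)
    second = ingest(conn, files)  # same filenames again: everything is skipped
    conn.close()
    return first, second
```

Because the key is the filename alone, a file that gains lines after its first ingest would not be re-read under this scheme.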
|
@@ -1,313 +0,0 @@
|
|||
"""Avocet — Recipe scan labeling API (avocet#65).
|
||||
|
||||
Receives recipe scan items from the Kiwi pipeline (scanner/phone image +
|
||||
docuvision OCR extraction + ground-truth structured recipe), presents them
|
||||
for human review, and exports approved/edited pairs in the messages chat
|
||||
format for the vision fine-tune harness.
|
||||
|
||||
DB: data/recipe_scan.db (separate from corpus.db — different lifecycle)
|
||||
No auth required — local admin tool, not a push endpoint.
|
||||
|
||||
All endpoints registered on `router`. api.py includes this with
|
||||
prefix="/api/recipe-scan".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "recipe_scan.db"
|
||||
|
||||
_VALID_MODALITIES = {"scanner", "phone", "handwritten"}
|
||||
_VALID_STATUSES = {"pending", "approved", "edited", "rejected"}
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS recipe_scan_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
image_path TEXT NOT NULL,
|
||||
modality TEXT NOT NULL DEFAULT 'scanner',
|
||||
source TEXT NOT NULL DEFAULT 'purple_carrot',
|
||||
extracted TEXT NOT NULL,
|
||||
ground_truth TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
corrected TEXT,
|
||||
labeled_at TEXT,
|
||||
rejected_reason TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_rsi_status ON recipe_scan_items(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_rsi_modality ON recipe_scan_items(modality);
|
||||
"""
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seam ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
with _db() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _build_training_pair(row: sqlite3.Row) -> dict:
|
||||
"""Build a messages-format training pair from a labeled row.
|
||||
|
||||
user message: correction prompt + the docuvision-extracted JSON draft.
|
||||
Trains the model to review and correct an existing extraction, which is
|
||||
more data-efficient than producing from scratch when OCR is usually close.
|
||||
|
||||
assistant message: the approved ground truth (or human-corrected JSON).
|
||||
"""
|
||||
target_str = row["corrected"] if row["corrected"] else row["ground_truth"]
|
||||
extracted = json.loads(row["extracted"])
|
||||
target = json.loads(target_str)
|
||||
user_content = (
|
||||
"Review and correct this recipe extraction. "
|
||||
"Return valid JSON with fields: title, description, ingredients, steps, "
|
||||
"prep_time, cook_time, servings.\n\n"
|
||||
f"Extraction to review:\n{json.dumps(extracted, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
return {
|
||||
"id": row["id"],
|
||||
"modality": row["modality"],
|
||||
"source": row["source"],
|
||||
"image_path": row["image_path"],
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": json.dumps(target, ensure_ascii=False)},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
_init_db()
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────────
|
||||
|
||||
class ImportItem(BaseModel):
|
||||
id: str = ""
|
||||
image_path: str
|
||||
modality: Literal["scanner", "phone", "handwritten"] = "scanner"
|
||||
source: str = "purple_carrot"
|
||||
extracted: dict
|
||||
ground_truth: dict
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def default_id(cls, v: str) -> str:
|
||||
return v or str(uuid.uuid4())
|
||||
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
items: list[ImportItem]
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def import_items(body: ImportRequest) -> dict:
|
||||
"""Bulk-import scan items from the Kiwi pipeline. Idempotent by item id."""
|
||||
stored = 0
|
||||
with _db() as conn:
|
||||
for item in body.items:
|
||||
result = conn.execute(
|
||||
"INSERT OR IGNORE INTO recipe_scan_items "
|
||||
"(id, image_path, modality, source, extracted, ground_truth) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(item.id, item.image_path, item.modality, item.source,
|
||||
json.dumps(item.extracted), json.dumps(item.ground_truth)),
|
||||
)
|
||||
stored += result.rowcount
|
||||
return {"imported": stored, "total_submitted": len(body.items)}
|
||||
|
||||
|
||||
# ── GET /next ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/next")
|
||||
def get_next() -> dict:
|
||||
"""Return the next pending item for review, oldest-first."""
|
||||
with _db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM recipe_scan_items WHERE status = 'pending' ORDER BY rowid LIMIT 1"
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "No pending items in queue")
|
||||
return {
|
||||
**dict(row),
|
||||
"extracted": json.loads(row["extracted"]),
|
||||
"ground_truth": json.loads(row["ground_truth"]),
|
||||
}
|
||||
|
||||
|
||||
# ── POST /items/{id}/approve ──────────────────────────────────────────────────
|
||||
|
||||
@router.post("/items/{item_id}/approve")
|
||||
def approve_item(item_id: str) -> dict:
|
||||
"""Mark item as approved — extracted JSON is close enough to ground truth."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='approved', labeled_at=? WHERE id=?",
|
||||
(_now_iso(), item_id),
|
||||
)
|
||||
return {"status": "approved", "id": item_id}
|
||||
|
||||
|
||||
# ── POST /items/{id}/edit ─────────────────────────────────────────────────────
|
||||
|
||||
class EditBody(BaseModel):
|
||||
corrected: dict
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/edit")
|
||||
def edit_item(item_id: str, body: EditBody) -> dict:
|
||||
"""Approve with a human-corrected JSON. corrected overrides extracted in export."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='edited', corrected=?, labeled_at=? WHERE id=?",
|
||||
(json.dumps(body.corrected), _now_iso(), item_id),
|
||||
)
|
||||
return {"status": "edited", "id": item_id}
|
||||
|
||||
|
||||
# ── POST /items/{id}/reject ───────────────────────────────────────────────────
|
||||
|
||||
class RejectBody(BaseModel):
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/reject")
|
||||
def reject_item(item_id: str, body: RejectBody = RejectBody()) -> dict:
|
||||
"""Reject item — extraction too broken to use for training."""
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, "Item not found")
|
||||
conn.execute(
|
||||
"UPDATE recipe_scan_items SET status='rejected', rejected_reason=?, labeled_at=? WHERE id=?",
|
||||
(body.reason or None, _now_iso(), item_id),
|
||||
)
|
||||
return {"status": "rejected", "id": item_id}
|
||||
|
||||
|
||||
# ── GET /stats ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict:
|
||||
with _db() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM recipe_scan_items").fetchone()[0]
|
||||
by_status = {
|
||||
r["status"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT status, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY status"
|
||||
).fetchall()
|
||||
}
|
||||
by_modality = {
|
||||
r["modality"]: r["cnt"]
|
||||
for r in conn.execute(
|
||||
"SELECT modality, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY modality"
|
||||
).fetchall()
|
||||
}
|
||||
export_ready = conn.execute(
|
||||
"SELECT COUNT(*) FROM recipe_scan_items WHERE status IN ('approved', 'edited')"
|
||||
).fetchone()[0]
|
||||
return {
|
||||
"total": total,
|
||||
"by_status": by_status,
|
||||
"by_modality": by_modality,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def export_pairs() -> StreamingResponse:
|
||||
"""Stream approved/edited items as JSONL training pairs (messages format)."""
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM recipe_scan_items WHERE status IN ('approved', 'edited') ORDER BY rowid"
|
||||
).fetchall()
|
||||
|
||||
def _generate():
|
||||
for row in rows:
|
||||
yield json.dumps(_build_training_pair(row), ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": "attachment; filename=recipe_scan_pairs.jsonl"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /image ────────────────────────────────────────────────────────────────
|
||||
|
||||
_IMAGE_ROOT = Path("/Library/Assets/kiwi")
|
||||
|
||||
|
||||
@router.get("/image")
|
||||
def serve_image(path: str) -> StreamingResponse:
|
||||
"""Serve a scan image from /Library/Assets/kiwi/.
|
||||
|
||||
path must resolve within /Library/Assets/kiwi/ — rejects traversal attempts.
|
||||
"""
|
||||
try:
|
||||
resolved = Path(path).resolve()
|
||||
_IMAGE_ROOT.resolve() # ensure root itself is valid
|
||||
resolved.relative_to(_IMAGE_ROOT.resolve())
|
||||
except (ValueError, OSError):
|
||||
raise HTTPException(403, "Path outside allowed image directory")
|
||||
|
||||
if not resolved.exists():
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
suffix = resolved.suffix.lower()
|
||||
media_types = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".webp": "image/webp"}
|
||||
media_type = media_types.get(suffix, "application/octet-stream")
|
||||
|
||||
return StreamingResponse(
|
||||
open(resolved, "rb"),
|
||||
media_type=media_type,
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
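Review note: the traversal guard in `serve_image` comes down to "resolve, then require the result to sit under the root". The sketch below isolates that check, assuming a temporary directory in place of the real image root; the `is_within` helper is a hypothetical name used only for illustration.

```python
import tempfile
from pathlib import Path


def is_within(root: Path, candidate: str) -> bool:
    """Return True if candidate resolves to a location inside root."""
    try:
        # resolve() collapses ".." segments and follows symlinks;
        # relative_to() raises ValueError when the path is not under root.
        Path(candidate).resolve().relative_to(root.resolve())
    except (ValueError, OSError):
        return False
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "images"
    root.mkdir()
    inside = is_within(root, str(root / "scan.png"))
    escaped = is_within(root, str(root / ".." / "secret.txt"))
    # A sibling directory that merely shares a name prefix is also rejected,
    # because relative_to() compares whole path components.
    sibling = is_within(root, str(Path(tmp) / "images_backup" / "scan.png"))
```

Comparing path components rather than string prefixes is the reason `relative_to` is preferable to a plain `startswith` check here.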
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""Avocet -- eval router aggregator.
|
||||
|
||||
Collects benchmark sub-routers into a single importable `router`
|
||||
for the api.py factory. Each sub-router retains its established prefix
|
||||
so no frontend URL changes are needed.
|
||||
|
||||
Route prefixes when mounted at /api in api.py:
|
||||
/api/cforch/* -- cf-orch benchmark routes
|
||||
/api/style/* -- writing style benchmark routes
|
||||
/api/voice/* -- voice benchmark routes
|
||||
/api/plans-bench/* -- plans benchmark routes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.cforch import router as _cforch_router
|
||||
from app.style import router as _style_router
|
||||
from app.voice import router as _voice_router
|
||||
from app.plans_bench import router as _plans_router
|
||||
from app.eval.embed_bench import router as _embed_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(_cforch_router, prefix="/cforch")
|
||||
router.include_router(_style_router, prefix="/style")
|
||||
router.include_router(_voice_router, prefix="/voice")
|
||||
router.include_router(_plans_router, prefix="/plans-bench")
|
||||
router.include_router(_embed_router, prefix="/embed-bench")
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Propagate config dir override to all sub-modules -- used by tests."""
|
||||
import app.cforch as _cforch_mod
|
||||
import app.style as _style_mod
|
||||
import app.voice as _voice_mod
|
||||
import app.plans_bench as _plans_mod
|
||||
import app.eval.embed_bench as _embed_mod
|
||||
_cforch_mod.set_config_dir(path)
|
||||
_style_mod.set_config_dir(path)
|
||||
_voice_mod.set_config_dir(path)
|
||||
_plans_mod.set_config_dir(path)
|
||||
_embed_mod.set_config_dir(path)
|
||||
|
|
@ -1,293 +0,0 @@
|
|||
"""Avocet — embedding model comparison harness.
|
||||
|
||||
Exposes FastAPI routes under /api/embed-bench (mounted via app/eval/cforch.py).
|
||||
All computation is local: no LLM inference, Ollama only. MIT tier throughout.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override via set_config_dir() in tests
|
||||
_RUN_ACTIVE: bool = False
|
||||
_RATINGS_FILE = _ROOT / "data" / "embed_bench_ratings.jsonl"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seam ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
return yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse embed_bench config %s: %s", f, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _ollama_url() -> str:
|
||||
cfg = _load_config()
|
||||
embed_cfg = cfg.get("embed_bench", {}) or {}
|
||||
cforch_cfg = cfg.get("cforch", {}) or {}
|
||||
return (
|
||||
embed_cfg.get("ollama_url")
|
||||
or cforch_cfg.get("ollama_url", "http://localhost:11434")
|
||||
)
|
||||
|
||||
|
||||
def _ratings_path() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "embed_bench_ratings.jsonl"
|
||||
return _RATINGS_FILE
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
if len(a) != len(b):
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch: {len(a)} vs {len(b)}"
|
||||
)
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x * x for x in a))
|
||||
mag_b = math.sqrt(sum(x * x for x in b))
|
||||
if mag_a == 0.0 or mag_b == 0.0:
|
||||
return 0.0
|
||||
return dot / (mag_a * mag_b)
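Review note: a small worked example may help when reading how `_cosine` feeds the top-k ranking in `/run`. The sketch below uses hand-made 2-D vectors instead of real Ollama embeddings; the `cosine` and `top_k` names and the toy data are assumptions for illustration only.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity; 0.0 when either vector has zero magnitude."""
    if len(a) != len(b):
        raise ValueError(f"Embedding dimension mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    mag_a = math.sqrt(sum(x * x for x in a))
    mag_b = math.sqrt(sum(x * x for x in b))
    if mag_a == 0.0 or mag_b == 0.0:
        return 0.0
    return dot / (mag_a * mag_b)


def top_k(query_vec: list[float], corpus_vecs: list[list[float]], k: int) -> list[dict]:
    """Score every corpus vector against the query and keep the k best."""
    scored = [
        {"chunk_idx": i, "score": round(cosine(query_vec, vec), 4)}
        for i, vec in enumerate(corpus_vecs)
    ]
    return sorted(scored, key=lambda h: h["score"], reverse=True)[:k]


# Toy vectors: identical, orthogonal, and 45 degrees away from the query.
corpus = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hits = top_k([1.0, 0.0], corpus, k=2)
```

The orthogonal vector scores 0.0 and is dropped, while the 45-degree vector scores about 0.7071 and stays in second place.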
|
||||
|
||||
|
||||
# ── GET /models ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return Ollama embedding models available on the configured instance."""
|
||||
ollama = _ollama_url()
|
||||
models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(f"{ollama}/api/tags", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
for entry in resp.json().get("models", []):
|
||||
models.append({
|
||||
"name": entry.get("name", ""),
|
||||
"size": entry.get("size", 0),
|
||||
})
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.warning("Ollama /api/tags returned HTTP %s: %s", exc.response.status_code, exc)
|
||||
except httpx.RequestError as exc:
|
||||
logger.warning("Failed to reach Ollama for model list: %s", exc)
|
||||
return {"models": models, "ollama_url": ollama}
|
||||
|
||||
|
||||
# ── POST /run ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
corpus: list[str]
|
||||
queries: list[str]
|
||||
models: list[str]
|
||||
top_k: int = 5
|
||||
ollama_url: str = ""
|
||||
|
||||
@field_validator("corpus")
|
||||
@classmethod
|
||||
def corpus_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("corpus must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("queries")
|
||||
@classmethod
|
||||
def queries_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("queries must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("models")
|
||||
@classmethod
|
||||
def models_nonempty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("models must contain at least one model name")
|
||||
return v
|
||||
|
||||
|
||||
def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]:
|
||||
"""Batch-embed texts via Ollama /v1/embeddings. Returns one vector per text."""
|
||||
resp = httpx.post(
|
||||
f"{ollama}/v1/embeddings",
|
||||
json={"model": model, "input": texts},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json().get("data", [])
|
||||
return [item["embedding"] for item in data]
|
||||
|
||||
|
||||
def _sse(event: dict) -> str:
|
||||
return f"data: {json.dumps(event)}\n\n"
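Review note: each frame produced by `_sse` is a single `data:` line followed by a blank line. The sketch below shows the framing together with a deliberately simple client-side parser; `parse_sse` is a hypothetical helper that assumes one `data:` line per event, which is all this endpoint emits, and it is not a general Server-Sent Events parser.

```python
import json


def sse(event: dict) -> str:
    """Frame one event as a Server-Sent Events message."""
    return f"data: {json.dumps(event)}\n\n"


def parse_sse(stream: str) -> list[dict]:
    """Split a buffered stream on blank lines and decode each data payload."""
    events = []
    for frame in stream.split("\n\n"):
        if frame.startswith("data: "):
            events.append(json.loads(frame[len("data: "):]))
    return events


stream = sse({"type": "progress", "msg": "Indexing corpus..."}) + sse({"type": "done"})
events = parse_sse(stream)
```

Because `json.dumps` escapes newlines inside strings, a payload can never contain the blank line that terminates a frame.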
|
||||
|
||||
|
||||
@router.post("/run")
|
||||
def run_embed_bench(req: RunRequest) -> StreamingResponse:
|
||||
"""Embed corpus + queries with each model; stream SSE results."""
|
||||
global _RUN_ACTIVE
|
||||
|
||||
if _RUN_ACTIVE:
|
||||
raise HTTPException(409, "An embedding benchmark run is already active")
|
||||
|
||||
ollama = req.ollama_url or _ollama_url()
|
||||
|
||||
def _generate():
|
||||
global _RUN_ACTIVE
|
||||
_RUN_ACTIVE = True
|
||||
try:
|
||||
for model_idx, model in enumerate(req.models, start=1):
|
||||
yield _sse({
|
||||
"type": "progress",
|
||||
"msg": f"Indexing corpus with {model} ({model_idx}/{len(req.models)})...",
|
||||
})
|
||||
try:
|
||||
corpus_vecs = _embed_texts(ollama, model, req.corpus)
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"})
|
||||
continue
|
||||
|
||||
yield _sse({
|
||||
"type": "progress",
|
||||
"msg": f"Running queries with {model}...",
|
||||
})
|
||||
|
||||
for q_idx, query in enumerate(req.queries):
|
||||
try:
|
||||
q_vecs = _embed_texts(ollama, model, [query])
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"})
|
||||
continue
|
||||
q_vec = q_vecs[0]
|
||||
scored = sorted(
|
||||
[
|
||||
{"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)}
|
||||
for i, (chunk, cv) in enumerate(zip(req.corpus, corpus_vecs))
|
||||
],
|
||||
key=lambda h: h["score"],
|
||||
reverse=True,
|
||||
)[: req.top_k]
|
||||
yield _sse({
|
||||
"type": "result",
|
||||
"query_idx": q_idx,
|
||||
"query": query,
|
||||
"model": model,
|
||||
"hits": scored,
|
||||
})
|
||||
|
||||
yield _sse({"type": "done"})
|
||||
finally:
|
||||
_RUN_ACTIVE = False
|
||||
|
||||
return StreamingResponse(_generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# ── POST /rate ────────────────────────────────────────────────────────────────
|
||||
|
||||
_VALID_RATINGS = {"relevant", "not_relevant"}
|
||||
|
||||
|
||||
class RatingRequest(BaseModel):
|
||||
query: str
|
||||
model: str
|
||||
chunk_text: str
|
||||
chunk_idx: int
|
||||
rating: str
|
||||
|
||||
@field_validator("rating")
|
||||
@classmethod
|
||||
def rating_valid(cls, v: str) -> str:
|
||||
if v not in _VALID_RATINGS:
|
||||
raise ValueError(f"rating must be one of {_VALID_RATINGS}")
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/rate")
|
||||
def rate_result(req: RatingRequest) -> dict:
|
||||
"""Append one rating to the JSONL ratings file."""
|
||||
entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"query": req.query,
|
||||
"model": req.model,
|
||||
"chunk_idx": req.chunk_idx,
|
||||
"chunk_text": req.chunk_text,
|
||||
"rating": req.rating,
|
||||
}
|
||||
path = _ratings_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(entry) + "\n")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||
|
||||
_CSV_FIELDS = ["timestamp", "query", "model", "chunk_idx", "chunk_text", "rating"]
|
||||
|
||||
|
||||
@router.get("/export")
|
||||
def export_ratings(format: str = "csv") -> Any:
|
||||
"""Download ratings as CSV or JSON."""
|
||||
path = _ratings_path()
|
||||
rows: list[dict] = []
|
||||
if path.exists():
|
||||
for raw in path.read_text(encoding="utf-8").splitlines():
|
||||
raw = raw.strip()
|
||||
if raw:
|
||||
try:
|
||||
rows.append(json.loads(raw))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
if format == "json":
|
||||
content = json.dumps(rows, ensure_ascii=False, indent=2)
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
|
||||
)
|
||||
|
||||
# Default: CSV
|
||||
buf = io.StringIO()
|
||||
writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
return StreamingResponse(
|
||||
iter([buf.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
|
||||
)
|
||||
|
|
@ -1,9 +1,214 @@
|
|||
"""Backward-compat shim -- logic moved to app/data/fetch.py."""
|
||||
import imaplib # noqa: F401 -- re-exported so existing patch("app.imap_fetch.imaplib...") calls still work
|
||||
from app.data.fetch import ( # noqa: F401
|
||||
entry_key,
|
||||
fetch_account_stream,
|
||||
test_connection,
|
||||
_decode_str,
|
||||
_WIDE_TERMS,
|
||||
)
|
||||
"""Avocet — IMAP fetch utilities.
|
||||
|
||||
Shared between app/api.py (FastAPI SSE endpoint) and app/label_tool.py (Streamlit).
|
||||
No Streamlit imports here — stdlib + imaplib only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
# ── HTML → plain text ────────────────────────────────────────────────────────
|
||||
|
||||
class _TextExtractor(HTMLParser):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._parts: list[str] = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
stripped = data.strip()
|
||||
if stripped:
|
||||
self._parts.append(stripped)
|
||||
|
||||
def get_text(self) -> str:
|
||||
return " ".join(self._parts)
|
||||
|
||||
|
||||
def strip_html(html_str: str) -> str:
|
||||
try:
|
||||
ex = _TextExtractor()
|
||||
ex.feed(html_str)
|
||||
return ex.get_text()
|
||||
except Exception:
|
||||
return re.sub(r"<[^>]+>", " ", html_str).strip()
|
||||
|
||||
|
||||
# ── IMAP decode helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def _extract_body(msg: Any) -> str:
|
||||
if msg.is_multipart():
|
||||
html_fallback: str | None = None
|
||||
for part in msg.walk():
|
||||
ct = part.get_content_type()
|
||||
if ct == "text/plain":
|
||||
try:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
return part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
elif ct == "text/html" and html_fallback is None:
|
||||
try:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
raw = part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
html_fallback = strip_html(raw)
|
||||
except Exception:
|
||||
pass
|
||||
return html_fallback or ""
|
||||
else:
|
||||
try:
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
raw = msg.get_payload(decode=True).decode(charset, errors="replace")
|
||||
if msg.get_content_type() == "text/html":
|
||||
return strip_html(raw)
|
||||
return raw
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def entry_key(e: dict) -> str:
|
||||
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
|
||||
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
|
||||
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
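Review note: the dedup key hashes the subject plus only the first 100 characters of the body, so two messages that differ later in the body are treated as the same entry. The sketch below illustrates that behaviour with invented emails; the `dedup` helper is hypothetical and mirrors the `known_keys` handling in `fetch_account_stream`.

```python
import hashlib


def entry_key(e: dict) -> str:
    """MD5 of subject plus the first 100 characters of the body."""
    key = e.get("subject", "") + (e.get("body", "") or "")[:100]
    return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()


def dedup(entries: list[dict], known_keys: set[str]) -> list[dict]:
    """Keep entries whose key is new; known_keys is updated in place."""
    kept = []
    for entry in entries:
        k = entry_key(entry)
        if k not in known_keys:
            known_keys.add(k)
            kept.append(entry)
    return kept


shared_prefix = "x" * 100
entries = [
    {"subject": "Interview", "body": shared_prefix + "A"},
    {"subject": "Interview", "body": shared_prefix + "B"},  # differs after char 100
    {"subject": "Job offer", "body": shared_prefix + "A"},
]
known: set[str] = set()
kept = dedup(entries, known)
```

The second entry is dropped even though its body is not identical, which is worth keeping in mind for templated emails that share a long opening.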
|
||||
|
||||
|
||||
# ── Wide search terms ────────────────────────────────────────────────────────
|
||||
|
||||
_WIDE_TERMS = [
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
|
||||
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
|
||||
"new jobs", "job alert",
|
||||
"came across your profile", "reaching out about", "great fit for a role",
|
||||
"exciting opportunity",
|
||||
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
|
||||
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
|
||||
host = acc.get("host", "")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc.get("username", "")
|
||||
password = acc.get("password", "")
|
||||
folder = acc.get("folder", "INBOX")
|
||||
if not host or not username or not password:
|
||||
return False, "Host, username, and password are all required.", None
|
||||
try:
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
_, data = conn.select(folder, readonly=True)
|
||||
count_raw = data[0].decode() if data and data[0] else "0"
|
||||
count = int(count_raw) if count_raw.isdigit() else 0
|
||||
conn.logout()
|
||||
return True, f"Connected — {count:,} message(s) in {folder}.", count
|
||||
except Exception as exc:
|
||||
return False, str(exc), None
|
||||
|
||||
|
||||
def fetch_account_stream(
|
||||
acc: dict,
|
||||
days_back: int,
|
||||
limit: int,
|
||||
known_keys: set[str],
|
||||
) -> Iterator[dict]:
|
||||
"""Generator — yields progress dicts while fetching emails via IMAP.
|
||||
|
||||
Mutates `known_keys` in place for cross-account dedup within one fetch session.
|
||||
|
||||
Yields event dicts with "type" key:
|
||||
{"type": "start", "account": str, "total_uids": int}
|
||||
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
|
||||
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
|
||||
"""
|
||||
name = acc.get("name", acc.get("username", "?"))
|
||||
host = acc.get("host", "imap.gmail.com")
|
||||
port = int(acc.get("port", 993))
|
||||
use_ssl = acc.get("use_ssl", True)
|
||||
username = acc["username"]
|
||||
password = acc["password"]
|
||||
folder = acc.get("folder", "INBOX")
|
||||
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
conn.select(folder, readonly=True)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
uids = list(seen_uids.keys())[: limit * 3]
|
||||
yield {"type": "start", "account": name, "total_uids": len(uids)}
|
||||
|
||||
emails: list[dict] = []
|
||||
skipped = 0
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if i % 5 == 0:
|
||||
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = _extract_body(msg)[:800]
|
||||
entry = {"subject": subj, "body": body, "from_addr": from_addr,
|
||||
"date": date, "account": name}
|
||||
k = entry_key(entry)
|
||||
if k not in known_keys:
|
||||
known_keys.add(k)
|
||||
emails.append(entry)
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception:
|
||||
skipped += 1
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
|
||||
"emails": emails}
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
|
||||
from app.data.imitate import router # noqa: F401
|
||||
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401
|
||||
1186
app/label_tool.py
Normal file
File diff suppressed because it is too large
1111
app/models.py
File diff suppressed because it is too large
535
app/nodes.py
|
|
@ -1,535 +0,0 @@
|
|||
"""Avocet — Node Management API.
|
||||
|
||||
Proxies cf-orch coordinator and agent APIs to expose per-node GPU state,
|
||||
service affinity management, and Ollama model management.
|
||||
|
||||
Config is read from label_tool.yaml under the `cforch:` key.
|
||||
The `profiles_dir` key (new) points to the cf-orch node profile YAML directory.
|
||||
|
||||
Module-level globals follow the set_config_dir() testability pattern from cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section. Returns empty dict on missing or parse error."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
return raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse config %s: %s", f, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _profiles_dir() -> Path | None:
|
||||
"""Return the cf-orch node profiles directory, or None if not configured."""
|
||||
cfg = _load_config()
|
||||
pd = cfg.get("profiles_dir", "") or ""
|
||||
if pd:
|
||||
return Path(pd)
|
||||
bench = cfg.get("bench_script", "") or ""
|
||||
if bench:
|
||||
return Path(bench).parent.parent / "profiles" / "nodes"
|
||||
return None
|
||||
|
||||
|
||||
def _profile_path(node_id: str) -> Path | None:
|
||||
"""Return the path to a node's profile YAML, or None if profiles_dir is unknown."""
|
||||
pd = _profiles_dir()
|
||||
if pd is None:
|
||||
return None
|
||||
return pd / f"{node_id}.yaml"
|
||||
|
||||
|
||||
def _load_profile(node_id: str) -> dict | None:
|
||||
"""Load and parse a node profile YAML. Returns None if not found or malformed."""
|
||||
p = _profile_path(node_id)
|
||||
if p is None or not p.exists():
|
||||
return None
|
||||
try:
|
||||
return yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Malformed profile YAML %s: %s", p, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _get_ollama_url(node_id: str) -> str:
|
||||
"""Derive Ollama URL from the node profile's agent_url (same host, port 11434)."""
|
||||
profile = _load_profile(node_id)
|
||||
if profile:
|
||||
nodes_section = profile.get("nodes", {}) or {}
|
||||
node_entry = nodes_section.get(node_id, {}) or {}
|
||||
agent_url = node_entry.get("agent_url", "") or ""
|
||||
if agent_url:
|
||||
parsed = urlparse(agent_url)
|
||||
return f"{parsed.scheme}://{parsed.hostname}:11434"
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Cannot determine Ollama URL for node {node_id}: no agent_url in profile",
|
||||
)
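Review note: `_get_ollama_url` keeps the scheme and host of the node's `agent_url` and swaps in port 11434. The sketch below isolates that step; the `ollama_url_from_agent` helper and the sample agent URLs are assumptions for illustration and do not come from a real profile.

```python
from urllib.parse import urlparse


def ollama_url_from_agent(agent_url: str, port: int = 11434) -> str:
    """Reuse the agent URL's scheme and host, replacing the port."""
    parsed = urlparse(agent_url)
    # hostname drops the original port, any credentials, and any path.
    return f"{parsed.scheme}://{parsed.hostname}:{port}"


url = ollama_url_from_agent("http://10.0.0.5:7701")
with_path = ollama_url_from_agent("https://gpu-node.local:7701/agent")
```

One caveat under this approach: `hostname` returns an IPv6 literal without its brackets, so a profile that uses an IPv6 `agent_url` would need the brackets added back before the port.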
|
||||
|
||||
|
||||
# ── Endpoints ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/nodes")
|
||||
def list_nodes() -> list:
|
||||
"""Return all nodes with live GPU stats merged with profile YAML."""
|
||||
import httpx
|
||||
|
||||
cfg = _load_config()
|
||||
coordinator_url = cfg.get("coordinator_url", "") or ""
|
||||
if not coordinator_url:
|
||||
return []
|
||||
|
||||
try:
|
||||
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
|
||||
r.raise_for_status()
|
||||
        coord_nodes: list[dict] = r.json().get("nodes", [])
    except httpx.HTTPError as exc:
        logger.warning("Coordinator unreachable: %s", exc)
        return []

    try:
        sr = httpx.get(f"{coordinator_url}/api/services", timeout=5.0)
        sr.raise_for_status()
        services_data: list[dict] = sr.json().get("services", [])
    except httpx.HTTPError:
        logger.warning("Services API unreachable for %s, skipping", coordinator_url)
        services_data = []

    # Build per-node, per-GPU running services map
    running: dict[str, dict[int, list[str]]] = {}
    for svc in services_data:
        nid = svc.get("node_id", "")
        gid = svc.get("gpu_id")
        svc_name = svc.get("service", "")
        if nid and gid is not None and svc_name:
            running.setdefault(nid, {}).setdefault(gid, []).append(svc_name)

    result = []
    for node in coord_nodes:
        node_id = node.get("node_id", "") or node.get("id", "")
        profile = _load_profile(node_id) if node_id else None
        profile_loaded = profile is not None

        gpus = []
        for gpu in (node.get("gpus", []) or []):
            gpu_id = gpu.get("gpu_id", gpu.get("id", 0))
            services_assigned: list[str] = []
            if profile:
                node_entry = (profile.get("nodes", {}) or {}).get(node_id, {}) or {}
                for g in (node_entry.get("gpus", []) or []):
                    if isinstance(g, dict) and g.get("id") == gpu_id:
                        services_assigned = g.get("services", []) or []
                        break
            gpus.append({
                "gpu_id": gpu_id,
                "card": gpu.get("card", ""),
                "vram_total_mb": gpu.get("vram_total_mb", 0),
                "vram_used_mb": gpu.get("vram_used_mb", 0),
                "vram_free_mb": gpu.get("vram_free_mb", 0),
                "temp_c": gpu.get("temp_c"),
                "utilization_pct": gpu.get("utilization_pct"),
                "compute_cap": gpu.get("compute_cap"),
                "services_assigned": services_assigned,
                "services_running": running.get(node_id, {}).get(gpu_id, []),
            })

        services_catalog: dict = {}
        if profile:
            for svc_name, svc_info in (profile.get("services", {}) or {}).items():
                catalog = svc_info.get("catalog", {}) or {}
                services_catalog[svc_name] = {
                    "min_compute_cap": svc_info.get("min_compute_cap", 0.0),
                    "max_mb": svc_info.get("max_mb", 0),
                    "catalog_size": len(catalog),
                }

        result.append({
            "node_id": node_id,
            "online": node.get("online", True),
            "agent_url": node.get("agent_url", ""),
            "gpus": gpus,
            "profile_loaded": profile_loaded,
            "services_catalog": services_catalog,
        })

    return result
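The per-node, per-GPU map in the function above is built with chained `dict.setdefault` calls. A minimal standalone sketch of that grouping step follows, assuming service records shaped like the `/api/services` payload read there. The helper name `build_running_map` and the sample records are illustrative only and do not appear in the removed module.

```python
def build_running_map(services: list[dict]) -> dict[str, dict[int, list[str]]]:
    """Group running service names by node ID, then by GPU ID."""
    running: dict[str, dict[int, list[str]]] = {}
    for svc in services:
        nid = svc.get("node_id", "")
        gid = svc.get("gpu_id")
        name = svc.get("service", "")
        # GPU 0 is a valid ID, so compare against None instead of truthiness.
        if nid and gid is not None and name:
            running.setdefault(nid, {}).setdefault(gid, []).append(name)
    return running


# Hypothetical sample records; the last one is skipped (empty node_id).
sample = [
    {"node_id": "n1", "gpu_id": 0, "service": "cf-text"},
    {"node_id": "n1", "gpu_id": 0, "service": "ollama"},
    {"node_id": "n1", "gpu_id": 1, "service": "cf-text"},
    {"node_id": "", "gpu_id": 0, "service": "ignored"},
]
running_map = build_running_map(sample)
```

Lookups can then use `running_map.get(node_id, {}).get(gpu_id, [])`, which is the access pattern the response builder relies on for `services_running`.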
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
    """Return the full parsed profile YAML for a node."""
    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")
    try:
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}")
    return data


class UpdateServicesRequest(BaseModel):
    services: list[str]


@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
    """Set service assignment for a GPU with compatibility validation, then atomic write."""
    import httpx

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""

    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")

    try:
        profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}")

    nodes_section = profile.get("nodes", {}) or {}
    node_entry = nodes_section.get(node_id, {}) or {}
    gpu_list = node_entry.get("gpus", []) or []

    gpu_entry = next(
        (g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
        None,
    )
    if gpu_entry is None:
        raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")

    gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
    gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
    services_def = profile.get("services", {}) or {}

    for svc_name in body.services:
        if svc_name not in services_def:
            raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
        svc = services_def[svc_name]
        min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
        if gpu_compute_cap < min_cap:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
            )
        catalog = svc.get("catalog", {}) or {}
        min_catalog_vram = (
            min((m.get("vram_mb", 0) for m in catalog.values()), default=0)
            if catalog else svc.get("max_mb", 0)
        )
        if gpu_vram_mb < min_catalog_vram:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
            )

    # Immutable update of GPU services list
    new_gpu_list = [
        ({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
        for g in gpu_list
    ]
    new_profile = {
        **profile,
        "nodes": {
            **nodes_section,
            node_id: {**node_entry, "gpus": new_gpu_list},
        },
    }

    # Atomic write: write to .tmp then rename
    tmp_yaml = Path(str(p) + ".tmp")
    tmp_yaml.write_text(yaml.dump(new_profile, default_flow_style=False), encoding="utf-8")
    os.replace(tmp_yaml, p)

    # Trigger coordinator profile reload
    reloaded = False
    if coordinator_url:
        try:
            rr = httpx.post(
                f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
            )
            reloaded = rr.status_code < 300
        except Exception as exc:
            logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)

    return {"ok": True, "reloaded": reloaded, "warnings": []}
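The compatibility validation in `update_gpu_services` can be read as a pure function of one GPU entry and one service definition. The sketch below isolates that rule under the same field names the profile uses (`compute_cap`, `vram_mb`, `min_compute_cap`, `catalog`, `max_mb`). The function name `check_service_fit` and the sample dicts are hypothetical, and the sketch returns a message instead of raising `HTTPException` so it runs without FastAPI.

```python
from typing import Optional


def check_service_fit(gpu: dict, svc: dict) -> Optional[str]:
    """Return a reason string if the service cannot run on the GPU, else None."""
    gpu_cap = gpu.get("compute_cap") or 0.0
    gpu_vram = gpu.get("vram_mb") or 0
    min_cap = svc.get("min_compute_cap", 0.0) or 0.0
    if gpu_cap < min_cap:
        return f"requires compute_cap >= {min_cap}; GPU has {gpu_cap}"
    catalog = svc.get("catalog", {}) or {}
    # With a catalog, the smallest model sets the bar; otherwise fall back to max_mb.
    needed = (
        min((m.get("vram_mb", 0) for m in catalog.values()), default=0)
        if catalog else svc.get("max_mb", 0)
    )
    if gpu_vram < needed:
        return f"requires {needed} MB VRAM; GPU has {gpu_vram} MB"
    return None


# Hypothetical profile fragments for illustration.
gpu = {"compute_cap": 7.5, "vram_mb": 8192}
svc_ok = {"min_compute_cap": 7.0, "catalog": {"a": {"vram_mb": 4096}, "b": {"vram_mb": 16000}}}
svc_cap = {"min_compute_cap": 8.0, "max_mb": 1000}
svc_vram = {"max_mb": 12000}
```

Note the design consequence: a service with a catalog passes as long as its smallest model fits, even if larger catalog entries would not.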
# ── Profile save / generate ────────────────────────────────────────────────────

class SaveProfileRequest(BaseModel):
    profile: dict


@router.put("/nodes/{node_id}/profile", status_code=200)
def save_profile(node_id: str, body: SaveProfileRequest) -> dict:
    """Write a full profile dict to disk as YAML, then trigger coordinator reload."""
    p = _profile_path(node_id)
    if p is None:
        raise HTTPException(500, "profiles_dir not configured in label_tool.yaml")
    p.parent.mkdir(parents=True, exist_ok=True)
    tmp = Path(str(p) + ".tmp")
    tmp.write_text(
        yaml.dump(body.profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )
    os.replace(tmp, p)

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""
    reloaded = False
    if coordinator_url:
        try:
            import httpx
            rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
            reloaded = rr.status_code < 300
        except Exception as exc:
            logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
    return {"ok": True, "reloaded": reloaded}


@router.post("/nodes/{node_id}/profile/generate")
def generate_profile(node_id: str) -> dict:
    """Return a profile skeleton seeded from coordinator GPU data.

    If a profile already exists, preserves its services section and only
    refreshes the nodes hardware section. Never writes to disk — the caller
    must call PUT /profile to persist.
    """
    import httpx

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""
    if not coordinator_url:
        raise HTTPException(503, "coordinator_url not configured")

    try:
        r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
        r.raise_for_status()
        coord_nodes: list[dict] = r.json().get("nodes", [])
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"Coordinator unreachable: {exc}")

    node = next((n for n in coord_nodes if n.get("node_id") == node_id), None)
    if node is None:
        raise HTTPException(404, f"Node {node_id!r} not found in coordinator")

    gpus = [
        {
            "id": g.get("gpu_id", i),
            "vram_mb": g.get("vram_total_mb", 0),
            "compute_cap": g.get("compute_cap", 0.0),
            "card": g.get("card", g.get("name", "")),
            "role": "inference",
            "services": [],
        }
        for i, g in enumerate(node.get("gpus", []))
    ]
    vram_total = max((g["vram_mb"] for g in gpus), default=0)

    existing = _load_profile(node_id) or {}
    return {
        "schema_version": existing.get("schema_version", 1),
        "name": existing.get("name", f"node-{node_id}"),
        "vram_total_mb": vram_total,
        "eviction_timeout_s": existing.get("eviction_timeout_s", 10.0),
        "services": existing.get("services", {}),
        "nodes": {
            node_id: {
                "local_model_root": (
                    (existing.get("nodes", {}) or {})
                    .get(node_id, {})
                    .get("local_model_root", "")
                ),
                "gpus": gpus,
            }
        },
        "model_size_hints": existing.get("model_size_hints", {}),
    }
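`save_profile` and `update_gpu_services` both persist YAML by writing a sibling `.tmp` file and then calling `os.replace`. A minimal sketch of that pattern in isolation is shown below. The helper name `atomic_write_text` and the demo paths are assumptions for illustration; the rename is only expected to be atomic when the temp file and the target live on the same filesystem.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_text(path: Path, text: str) -> None:
    """Write text to a sibling temp file, then swap it into place."""
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = Path(str(path) + ".tmp")
    tmp.write_text(text, encoding="utf-8")
    # os.replace overwrites an existing destination, so readers see either
    # the old file or the new one, never a half-written file.
    os.replace(tmp, path)


# Demo in a throwaway directory (hypothetical file name).
demo_dir = Path(tempfile.mkdtemp())
target = demo_dir / "profiles" / "node-a.yaml"
atomic_write_text(target, "schema_version: 1\n")
atomic_write_text(target, "schema_version: 2\n")
```

Because the temp name is fixed (`<target>.tmp`), two concurrent writers to the same profile could still clobber each other's temp file; a unique temp name would be one way to avoid that if concurrent edits are possible.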
# ── Ollama model management ────────────────────────────────────────────────────

class PullRequest(BaseModel):
    name: str


@router.get("/nodes/{node_id}/models/ollama")
def list_ollama_models(node_id: str) -> dict:
    """Proxy GET {ollama_url}/api/tags for a specific node."""
    import httpx

    ollama_url = _get_ollama_url(node_id)
    try:
        r = httpx.get(f"{ollama_url}/api/tags", timeout=10.0)
        r.raise_for_status()
        return r.json()
    except Exception as exc:
        return {"error": str(exc)}


@router.post("/nodes/{node_id}/models/ollama/pull")
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
    """Stream Ollama pull progress as SSE events."""
    import httpx

    if not body.name:
        raise HTTPException(400, "name is required")

    ollama_url = _get_ollama_url(node_id)

    def stream():
        try:
            with httpx.stream(
                "POST",
                f"{ollama_url}/api/pull",
                json={"name": body.name, "stream": True},
                timeout=300.0,
            ) as resp:
                for line in resp.iter_lines():
                    if line:
                        yield f"data: {line}\n\n"
        except Exception as exc:
            yield f"data: {json.dumps({'error': str(exc)})}\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")


@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
def delete_ollama_model(node_id: str, name: str) -> dict:
    """Proxy DELETE to Ollama for a specific node."""
    import httpx

    ollama_url = _get_ollama_url(node_id)
    try:
        r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
        if r.status_code == 404:
            raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
        r.raise_for_status()
        return {"ok": True}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(502, f"Ollama unreachable: {exc}")
# ── Model deploy (add catalog entry) ──────────────────────────────────────────

class DeployModelRequest(BaseModel):
    model_id: str
    service_type: str
    vram_mb: int
    description: str = ""
    hf_repo: str = ""
    path: str = ""  # explicit path; if empty, constructed from model_base_path + hf_repo slug


@router.post("/nodes/{node_id}/models/deploy", status_code=200)
def deploy_model(node_id: str, body: DeployModelRequest) -> dict:
    """Register a model in the node's service catalog.

    Adds (or updates) the catalog entry for body.model_id under the given
    service_type in the node's profile YAML, then triggers a coordinator reload.
    Does not download the model — that is the user's responsibility.
    Returns the resolved path so the caller can see where the model should land.
    """
    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id!r}")

    try:
        profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}")

    services_def = profile.get("services", {}) or {}
    svc = services_def.get(body.service_type)
    if svc is None:
        raise HTTPException(
            422,
            f"Service '{body.service_type}' not defined in node '{node_id}' profile; "
            "add it first via the profile editor",
        )

    # Resolve path: explicit > model_base_path + hf slug > model_id slug
    model_path = body.path.strip()
    if not model_path:
        base = (svc.get("model_base_path", "") or "").rstrip("/")
        if not base:
            raise HTTPException(
                422,
                f"Service '{body.service_type}' has no model_base_path; supply an explicit path",
            )
        slug_src = body.hf_repo.strip() if body.hf_repo.strip() else body.model_id
        hf_slug = slug_src.replace("/", "--")
        model_path = f"{base}/{hf_slug}"

    # Immutable catalog update — spread, never mutate
    entry: dict = {"path": model_path, "vram_mb": body.vram_mb}
    if body.description:
        entry["description"] = body.description
    new_catalog = {**(svc.get("catalog") or {}), body.model_id: entry}
    new_svc = {**svc, "catalog": new_catalog}
    new_services = {**services_def, body.service_type: new_svc}
    new_profile = {**profile, "services": new_services}

    # Atomic write
    tmp = Path(str(p) + ".tmp")
    tmp.write_text(
        yaml.dump(new_profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )
    os.replace(tmp, p)

    # Trigger coordinator reload
    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""
    reloaded = False
    if coordinator_url:
        try:
            import httpx
            rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
            reloaded = rr.status_code < 300
        except Exception as exc:
            logger.warning("Coordinator reload failed for %s: %s", node_id, exc)

    return {"ok": True, "reloaded": reloaded, "path": model_path}
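The path-resolution order used by `deploy_model` (explicit path, then `model_base_path` plus a slug from `hf_repo`, then a slug from `model_id`) can be sketched as a small pure function. The name `resolve_model_path`, its signature, and the sample paths are assumptions for illustration; the sketch raises `ValueError` where the endpoint would respond with HTTP 422.

```python
def resolve_model_path(explicit: str, base_path: str, hf_repo: str, model_id: str) -> str:
    """Resolve where a model should live: explicit > base + hf slug > base + model_id."""
    explicit = explicit.strip()
    if explicit:
        return explicit
    base = (base_path or "").rstrip("/")
    if not base:
        raise ValueError("no model_base_path; supply an explicit path")
    slug_src = hf_repo.strip() or model_id
    # "org/name" becomes "org--name" so the slug is a single path component.
    return f"{base}/{slug_src.replace('/', '--')}"


# Hypothetical inputs covering the three branches.
from_repo = resolve_model_path("", "/models/text/", "Qwen/Qwen2.5-7B", "qwen2.5-7b")
from_explicit = resolve_model_path(" /opt/m.gguf ", "/models/text", "Qwen/Qwen2.5-7B", "qwen2.5-7b")
from_model_id = resolve_model_path("", "/models", "", "qwen2.5-3b")
```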
@@ -1,327 +0,0 @@
"""Avocet — CF planning benchmark integration API.

Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
Connection config (api_base) is read from label_tool.yaml under the
`plans_bench:` key (optional; falls back to localhost:8080).

All endpoints are registered on `router` (FastAPI APIRouter).
api.py includes this router with prefix="/api/plans-bench".
"""
from __future__ import annotations

import json
import logging
import subprocess as _subprocess
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import httpx
import yaml
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

logger = logging.getLogger(__name__)

_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None  # override in tests via set_config_dir()
_BENCH_RUNNING: bool = False
_bench_proc: Any = None

_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"

router = APIRouter()

# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
# Kept here so the UI can list them without importing the script.

MODEL_REGISTRY: dict[str, str] = {
    "deepseek-r1-1.5b": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
    "deepseek-r1-7b-4bit": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
    "deepseek-r1-0528-qwen3-8b-gguf": "DeepSeek R1 0528 Qwen3 8B GGUF (4 nodes)",
    "deepseek-coder-6.7b-4bit": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
    "granite-4.1-8b": "IBM Granite 4.1 8B, 4-bit (cf-orch catalog key)",
    "qwen2.5-3b": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
    "qwen2.5-7b": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
    "capybarahermes-2.5-mistral-7b-gguf": "CapybaraHermes 2.5 Mistral 7B GGUF (4 nodes)",
    "darwin-9b-opus-gguf": "Darwin 9B Opus GGUF -- long-form writing (3 nodes)",
}

RUBRIC_LABELS: dict[str, str] = {
    "task_structure": "Task structure (checkboxes + commits)",
    "tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
    "privacy_pillar": "Privacy pillar (local-first, no logging)",
    "safety_pillar": "Safety pillar (human approval, reversibility)",
    "accessibility": "Accessibility (ND/adaptive users)",
    "license_split": "License awareness (MIT vs BSL)",
    "file_paths": "File paths (plausible project paths)",
    "cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
    "length_ok": "Response length (200–2500 words)",
}


# ── Testability seam ───────────────────────────────────────────────────────────

def set_config_dir(path: Path | None) -> None:
    global _CONFIG_DIR
    _CONFIG_DIR = path


# ── Internal helpers ───────────────────────────────────────────────────────────

def _config_file() -> Path:
    if _CONFIG_DIR is not None:
        return _CONFIG_DIR / "label_tool.yaml"
    return _ROOT / "config" / "label_tool.yaml"


def _load_config() -> dict:
    f = _config_file()
    cforch_cfg: dict = {}
    bench_cfg: dict = {}
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
            cforch_cfg = raw.get("cforch", {}) or {}
            bench_cfg = raw.get("plans_bench", {}) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
    return {
        "coordinator_url": cforch_cfg.get("coordinator_url",
            bench_cfg.get("coordinator_url", "http://10.1.10.71:7700")),
        "python_bin": cforch_cfg.get("python_bin",
            bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")),
    }


def _results_file(run_id: str) -> Path:
    return _RESULTS_DIR / f"{run_id}.json"


# ── GET /models ────────────────────────────────────────────────────────────────

@router.get("/models")
def get_models() -> dict:
    """Return registered model shortcuts, live cf-orch catalog, and rubric labels."""
    cfg = _load_config()

    cforch_models: list[dict] = []
    try:
        resp = httpx.get(
            f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
            timeout=5.0,
        )
        resp.raise_for_status()
        for model_id, entry in resp.json().items():
            if isinstance(entry, dict):
                cforch_models.append({
                    "id": model_id,
                    "name": model_id,
                    "vram_mb": entry.get("vram_mb"),
                    "description": entry.get("description", ""),
                })
    except Exception as exc:
        logger.warning("Failed to fetch cf-orch catalog: %s", exc)

    return {
        "registry": [
            {"key": k, "description": v}
            for k, v in MODEL_REGISTRY.items()
        ],
        "cforch_models": cforch_models,
        "coordinator_url": cfg["coordinator_url"],
        "rubric_labels": RUBRIC_LABELS,
    }
# ── GET /run ───────────────────────────────────────────────────────────────────

@router.get("/run")
def run_plans_benchmark(
    models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
    prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
    use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
    api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
    workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
) -> StreamingResponse:
    """Spawn benchmark_plans.py and stream stdout as SSE progress events.

    On successful completion emits a `type: result` event with parsed JSON
    and saves results to data/plans_bench_results/<run_id>.json.
    """
    global _BENCH_RUNNING, _bench_proc

    if _BENCH_RUNNING:
        raise HTTPException(409, "A planning benchmark is already running")

    cfg = _load_config()
    python_bin = cfg["python_bin"]
    coordinator_url = cfg["coordinator_url"]

    model_keys = [m.strip() for m in models.split(",") if m.strip()]
    if not model_keys:
        raise HTTPException(400, "At least one model key is required")

    run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
    output_path = _results_file(run_id)
    _RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    def generate():
        global _BENCH_RUNNING, _bench_proc

        if not _BENCH_SCRIPT.exists():
            yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})}\n\n"
            return

        cmd = [python_bin, str(_BENCH_SCRIPT)]
        if len(model_keys) > 1:
            cmd.extend(["--compare"] + model_keys)
        else:
            cmd.extend(["--model", model_keys[0]])

        if use_cforch:
            cmd.extend(["--cforch", "--cforch-url", coordinator_url])
        elif api_base.strip():
            cmd.extend(["--api-base", api_base.strip()])

        cmd.extend(["--verbose", "--output", str(output_path)])
        if workers > 1:
            cmd.extend(["--workers", str(workers)])

        if prompt_ids.strip():
            cmd.extend(["--prompts"] + [p.strip() for p in prompt_ids.split(",") if p.strip()])

        _BENCH_RUNNING = True
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,
                text=True,
                bufsize=1,
                cwd=str(_ROOT),
            )
            _bench_proc = proc
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
                proc.wait()
                if proc.returncode == 0 and output_path.exists():
                    try:
                        results = json.loads(output_path.read_text(encoding="utf-8"))
                        yield f"data: {json.dumps({'type': 'result', 'run_id': run_id, 'results': results})}\n\n"
                    except Exception as exc:
                        logger.warning("Failed to read plans benchmark output: %s", exc)
                    yield f"data: {json.dumps({'type': 'complete'})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
            finally:
                _bench_proc = None
        except Exception as exc:
            yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
        finally:
            _BENCH_RUNNING = False

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
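The `/run` endpoint turns each line of subprocess output into one server-sent event frame (`data: <json>` followed by a blank line). The sketch below reproduces only that framing step without FastAPI, so it can be run directly. The generator name `stream_process_as_sse` and the inline child command are assumptions for illustration; the event `type` values mirror the ones emitted above.

```python
import json
import subprocess
import sys
from typing import Iterator


def stream_process_as_sse(cmd: list[str]) -> Iterator[str]:
    """Yield one SSE frame per non-empty output line, then a final status frame."""
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so errors show up as progress lines
        text=True,
        bufsize=1,  # line-buffered in text mode
    )
    assert proc.stdout is not None
    for line in proc.stdout:
        line = line.rstrip()
        if line:
            yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
    proc.wait()
    if proc.returncode == 0:
        yield f"data: {json.dumps({'type': 'complete'})}\n\n"
    else:
        msg = f"Process exited with code {proc.returncode}"
        yield f"data: {json.dumps({'type': 'error', 'message': msg})}\n\n"


# A stand-in child process that prints two lines and exits cleanly.
frames = list(stream_process_as_sse([sys.executable, "-c", "print('step 1'); print('step 2')"]))
```

Wrapping the message in `json.dumps` keeps each frame on a single `data:` line even when the child prints quotes or backslashes, which is why the endpoint encodes every line instead of forwarding it raw.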
# ── GET /results ───────────────────────────────────────────────────────────────

@router.get("/results")
def list_results() -> list[dict]:
    """List past planning benchmark runs, newest first."""
    if not _RESULTS_DIR.exists():
        return []

    runs: list[dict] = []
    for f in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
        run_id = f.stem
        try:
            data: dict = json.loads(f.read_text(encoding="utf-8"))
            model_keys = list(data.keys())
            # Average total_score across all models and prompts
            all_scores = [
                r["total_score"]
                for results in data.values()
                for r in results
                if not r.get("error")
            ]
            avg_score = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
        except Exception:
            model_keys = []
            avg_score = 0.0

        # Parse display date from run_id (plans_2026-04-27_143022)
        try:
            date_part = run_id.removeprefix("plans_")  # 2026-04-27_143022
            date, time = date_part.split("_")
            display_date = f"{date} {time[:2]}:{time[2:4]}"
        except Exception:
            display_date = run_id

        runs.append({
            "run_id": run_id,
            "filename": f.name,
            "date": display_date,
            "models": model_keys,
            "avg_score": avg_score,
        })

    return runs


@router.get("/results/latest")
def get_latest_results() -> dict:
    """Return the most recent planning benchmark results dict."""
    if not _RESULTS_DIR.exists():
        raise HTTPException(404, "No benchmark results found")
    files = sorted(_RESULTS_DIR.glob("plans_*.json"))
    if not files:
        raise HTTPException(404, "No benchmark results found")
    try:
        return json.loads(files[-1].read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc


@router.get("/results/{run_id}")
def get_results_by_run_id(run_id: str) -> dict:
    """Return planning benchmark results for a specific run."""
    if not run_id.startswith("plans_"):
        raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")
    f = _results_file(run_id)
    if not f.exists():
        raise HTTPException(404, f"Results not found: {run_id}")
    try:
        return json.loads(f.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
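Run IDs are produced with `strftime("plans_%Y-%m-%d_%H%M%S")` and later unpacked by `list_results` into a display date. Because the timestamp is zero-padded and ordered from year to second, sorting the file names lexicographically also sorts them chronologically, which is what the `sorted(...glob(...))` calls rely on. A small round-trip sketch, with the helper name `display_date_from_run_id` and the fixed timestamp chosen only for illustration:

```python
from datetime import datetime, timezone


def display_date_from_run_id(run_id: str) -> str:
    """Turn 'plans_YYYY-MM-DD_HHMMSS' into 'YYYY-MM-DD HH:MM'; fall back to the raw ID."""
    try:
        date_part = run_id.removeprefix("plans_")
        date, time = date_part.split("_")  # raises ValueError unless exactly two parts
        return f"{date} {time[:2]}:{time[2:4]}"
    except ValueError:
        return run_id


# Fixed timestamp so the example is reproducible.
stamp = datetime(2026, 4, 27, 14, 30, 22, tzinfo=timezone.utc)
run_id = stamp.strftime("plans_%Y-%m-%d_%H%M%S")
shown = display_date_from_run_id(run_id)
fallback = display_date_from_run_id("plans_bad")
```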
# ── POST /cancel ───────────────────────────────────────────────────────────────

@router.post("/cancel")
def cancel_plans_benchmark() -> dict:
    """Kill the running planning benchmark subprocess."""
    global _BENCH_RUNNING, _bench_proc

    if not _BENCH_RUNNING:
        raise HTTPException(404, "No planning benchmark is currently running")

    if _bench_proc is not None:
        try:
            _bench_proc.terminate()
        except Exception as exc:
            logger.warning("Failed to terminate plans benchmark: %s", exc)

    _BENCH_RUNNING = False
    _bench_proc = None
    return {"status": "cancelled"}
@@ -1,8 +0,0 @@
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
from app.data.corrections import (  # noqa: F401
    router,
    set_data_dir as set_sft_data_dir,
    set_config_dir as set_sft_config_dir,
    set_default_bench_results_dir,
    _DEFAULT_BENCH_RESULTS_DIR,
)
427
app/style.py
@@ -1,427 +0,0 @@
"""Avocet — Writing style benchmark integration API.

Wraps scripts/benchmark_style.py and exposes it via the Avocet API.
Connection config (coordinator_url, ollama_url, python_bin) is read
from label_tool.yaml under the `cforch:` key — the same block used
by cforch.py, so no new config section is needed.

All endpoints are registered on `router` (a FastAPI APIRouter).
api.py includes this router with prefix="/api/style".

Module-level globals (_BENCH_RUNNING, _bench_proc) follow the same
testability pattern as cforch.py.
"""
from __future__ import annotations

import json
import logging
import subprocess as _subprocess
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import httpx
import yaml
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

logger = logging.getLogger(__name__)

_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None  # override in tests via set_config_dir()
_BENCH_RUNNING: bool = False
_bench_proc: Any = None

_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_style.py"
_RESULTS_DIR = _ROOT / "benchmark_results"

router = APIRouter()


# ── Testability seams ──────────────────────────────────────────────────────────

def set_config_dir(path: Path | None) -> None:
    global _CONFIG_DIR
    _CONFIG_DIR = path


# ── Internal helpers ───────────────────────────────────────────────────────────

def _config_file() -> Path:
    if _CONFIG_DIR is not None:
        return _CONFIG_DIR / "label_tool.yaml"
    return _ROOT / "config" / "label_tool.yaml"


def _load_config() -> dict:
    """Read label_tool.yaml cforch section for coordinator/ollama/python config."""
    f = _config_file()
    file_cfg: dict = {}
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
            file_cfg = raw.get("cforch", {}) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse style config %s: %s", f, exc)
    return {
        "coordinator_url": file_cfg.get("coordinator_url", "http://10.1.10.71:7700"),
        "ollama_url": file_cfg.get("ollama_url", "http://localhost:11434"),
        "python_bin": file_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python"),
    }


# ── GET /models ────────────────────────────────────────────────────────────────

@router.get("/models")
def get_models() -> dict:
    """Return available models grouped by source.

    - ollama: fetched live from /api/tags (includes any models downloaded
      via the Models view — automatically in sync)
    - cf_text: fetched from cf-orch catalog endpoint (requires node profile
      entry + coordinator restart when new GGUFs are added)
    """
    cfg = _load_config()

    # Ollama models — live query so newly downloaded models appear immediately
    ollama_models: list[dict] = []
    try:
        resp = httpx.get(f"{cfg['ollama_url']}/api/tags", timeout=5.0)
        resp.raise_for_status()
        for m in resp.json().get("models", []):
            name = m.get("name", "")
            if name:
                size_bytes = m.get("size", 0)
                ollama_models.append({
                    "id": name,
                    "name": name,
                    "source": "ollama",
                    "size_mb": round(size_bytes / (1024 * 1024)) if size_bytes else None,
                    "vram_mb": None,
                })
    except Exception as exc:
        logger.warning("Failed to fetch ollama models: %s", exc)

    # cf-text catalog — fetched from cf-orch coordinator
    cftext_models: list[dict] = []
    try:
        resp = httpx.get(
            f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
            timeout=5.0,
        )
        resp.raise_for_status()
        for model_id, entry in resp.json().items():
            if isinstance(entry, dict):
                cftext_models.append({
                    "id": model_id,
                    "name": model_id,
                    "source": "cf-text",
                    "vram_mb": entry.get("vram_mb"),
                    "description": entry.get("description", ""),
                })
    except Exception as exc:
        logger.warning("Failed to fetch cf-text catalog: %s", exc)

    return {"ollama": ollama_models, "cf_text": cftext_models}
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_style_benchmark(
|
||||
models: str = Query("", description="Comma-separated model IDs (empty = all)"),
|
||||
use_cforch: bool = Query(False),
|
||||
max_vram: int = Query(7200, description="Max VRAM MB for cf-orch OOM filter"),
|
||||
include_large: bool = Query(False, description="Include large (30B+) ollama models"),
|
||||
workers: int = Query(1, description="Parallel workers — run N models simultaneously"),
|
||||
) -> StreamingResponse:
|
||||
"""Spawn benchmark_style.py and stream stdout as SSE progress events.
|
||||
|
||||
On successful completion, emits a final `type: result` event containing
|
||||
the parsed JSON from the newest style_*.json file.
|
||||
"""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if _BENCH_RUNNING:
|
||||
raise HTTPException(409, "A writing style benchmark is already running")
|
||||
|
||||
cfg = _load_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_SCRIPT.exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_style.py not found at {_BENCH_SCRIPT}'})}\n\n"
|
||||
return
|
||||
|
||||
cmd = [python_bin, str(_BENCH_SCRIPT), "run"]
|
||||
|
||||
if models:
|
||||
cmd.extend(["--models", ",".join(m.strip() for m in models.split(",") if m.strip())])
|
||||
if use_cforch:
|
||||
cmd.extend(["--cforch", "--cforch-url", cfg["coordinator_url"],
|
||||
"--max-vram", str(max_vram)])
|
||||
if include_large:
|
||||
cmd.append("--include-large")
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_bench_proc = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
result_files = sorted(_RESULTS_DIR.glob("style_*.json"))
|
||||
if result_files:
|
||||
try:
|
||||
results = json.loads(result_files[-1].read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'results': results, 'filename': result_files[-1].name})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read style results: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
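The `/run` handler above frames every message as a single `data: <json>` line followed by a blank line. A minimal sketch of that framing and of a matching client-side parser, assuming the whole stream has been buffered into one string; `sse_event` and `parse_sse` are hypothetical helper names, not part of the module:

```python
import json


def sse_event(payload: dict) -> str:
    """Frame one JSON payload as a single SSE `data:` event."""
    # json.dumps escapes newlines, so the payload always stays on one line.
    return f"data: {json.dumps(payload)}\n\n"


def parse_sse(stream: str) -> list[dict]:
    """Split a buffered SSE stream into decoded JSON payloads."""
    events = []
    for block in stream.split("\n\n"):
        block = block.strip()
        if block.startswith("data: "):
            events.append(json.loads(block[len("data: "):]))
    return events


# Illustrative stream: one progress line, then the completion marker.
stream = (
    sse_event({"type": "progress", "message": "loading model"})
    + sse_event({"type": "complete"})
)
events = parse_sse(stream)
```

A real client would parse incrementally as chunks arrive; buffering the whole stream only keeps the sketch short.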
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
"""List past writing style benchmark runs, newest first.
|
||||
|
||||
Returns lightweight summaries (date, model count, top score).
|
||||
Use /results/{filename} to fetch full model-level detail.
|
||||
"""
|
||||
if not _RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
runs: list[dict] = []
|
||||
for f in sorted(_RESULTS_DIR.glob("style_*.json"), reverse=True):
|
||||
stem = f.stem # style_2026-04-22_1502
|
||||
date_str = stem.removeprefix("style_") # 2026-04-22_1502
|
||||
try:
|
||||
date_part, time_part = date_str.split("_")
|
||||
display_date = f"{date_part} {time_part[:2]}:{time_part[2:]}"
|
||||
except Exception:
|
||||
display_date = date_str
|
||||
|
||||
try:
|
||||
results = json.loads(f.read_text(encoding="utf-8"))
|
||||
top_score = max((r.get("avg_score", 0) for r in results), default=0)
|
||||
model_count = len(results)
|
||||
except Exception:
|
||||
top_score = 0
|
||||
model_count = 0
|
||||
|
||||
runs.append({
|
||||
"filename": f.name,
|
||||
"date": display_date,
|
||||
"model_count": model_count,
|
||||
"top_score": round(top_score, 1),
|
||||
})
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
|
||||
def get_latest_results() -> list[dict]:
|
||||
"""Return the latest writing style benchmark result list."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
files = sorted(_RESULTS_DIR.glob("style_*.json"))
|
||||
if not files:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(files[-1].read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
def get_results_by_filename(filename: str) -> list[dict]:
|
||||
"""Return writing style benchmark results for a specific run file."""
|
||||
if not filename.startswith("style_") or not filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename — expected style_*.json")
|
||||
f = _RESULTS_DIR / filename
|
||||
if not f.exists():
|
||||
raise HTTPException(404, f"Results file not found: {filename}")
|
||||
try:
|
||||
return json.loads(f.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /send-to-corrections ──────────────────────────────────────────────────
|
||||
|
||||
class SendToCorrectionsRequest(BaseModel):
|
||||
filename: str # style_YYYY-MM-DD_HHMM.json — the source run file
|
||||
model_ids: list[str] = [] # empty = all models in the run
|
||||
|
||||
|
||||
@router.post("/send-to-corrections")
|
||||
def send_to_corrections(req: SendToCorrectionsRequest) -> dict:
|
||||
"""Push writing style benchmark outputs into the SFT corrections queue.
|
||||
|
||||
Each prompt_result from the selected models becomes one SFT candidate
|
||||
with status='needs_review'. Duplicates are skipped via the 'id' field
|
||||
(UUIDv5 derived from run id + model_id + tag).
|
||||
"""
|
||||
if not req.filename.startswith("style_") or not req.filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename")
|
||||
|
||||
src = _RESULTS_DIR / req.filename
|
||||
if not src.exists():
|
||||
raise HTTPException(404, f"Results file not found: {req.filename}")
|
||||
|
||||
try:
|
||||
run_results: list[dict] = json.loads(src.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
# Resolve sft_candidates.jsonl path (same logic as sft.py)
|
||||
sft_data_dir = _ROOT / "data"
|
||||
sft_file = sft_data_dir / "sft_candidates.jsonl"
|
||||
|
||||
# Load existing IDs to deduplicate
|
||||
existing_ids: set[str] = set()
|
||||
if sft_file.exists():
|
||||
for line in sft_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
existing_ids.add(json.loads(line)["id"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
run_id = req.filename.removesuffix(".json") # style_2026-04-22_1502
|
||||
timestamp = datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
new_candidates: list[dict] = []
|
||||
for model_result in run_results:
|
||||
model_id = model_result.get("model_id", "")
|
||||
if req.model_ids and model_id not in req.model_ids:
|
||||
continue
|
||||
for pr in model_result.get("prompt_results", []):
|
||||
tag = pr.get("tag", "")
|
||||
# Stable id: deterministic hash of run + model + prompt tag
|
||||
candidate_id = str(uuid.uuid5(
|
||||
uuid.NAMESPACE_URL,
|
||||
f"style-benchmark/{run_id}/{model_id}/{tag}",
|
||||
))
|
||||
if candidate_id in existing_ids:
|
||||
continue
|
||||
|
||||
score_pct = pr.get("score", 0.0) / 100.0
|
||||
signals = pr.get("signals", {})
|
||||
|
||||
# Build the prompt message list matching the benchmark's actual request
|
||||
prompt_messages = [
|
||||
{"role": "system", "content": _STYLE_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": pr.get("user_prompt", tag)},
|
||||
]
|
||||
|
||||
new_candidates.append({
|
||||
"id": candidate_id,
|
||||
"source": "style-benchmark",
|
||||
"benchmark_run_id": run_id,
|
||||
"timestamp": timestamp,
|
||||
"status": "needs_review",
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_response": pr.get("output", ""),
|
||||
"corrected_response": None,
|
||||
"quality_score": round(score_pct, 4),
|
||||
"failure_reason": _build_failure_reason(pr, signals),
|
||||
"failure_category": None,
|
||||
"task_id": f"style/{tag}",
|
||||
"task_type": "style-match",
|
||||
"task_name": tag.replace("_", " ").title(),
|
||||
"model_id": model_id,
|
||||
"model_name": model_id,
|
||||
"node_id": "",
|
||||
"gpu_id": 0,
|
||||
"tokens_per_sec": 0,
|
||||
})
|
||||
existing_ids.add(candidate_id)
|
||||
|
||||
if new_candidates:
|
||||
sft_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(sft_file, "a", encoding="utf-8") as fh:
|
||||
for c in new_candidates:
|
||||
fh.write(json.dumps(c) + "\n")
|
||||
|
||||
return {"imported": len(new_candidates), "skipped": 0}
|
||||
|
||||
|
||||
# Excerpt of the system prompt used in benchmark_style.py — reproduced here
|
||||
# so the SFT candidate records the prompt context it was generated under.
|
||||
_STYLE_SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No em dashes. No semicolons. No filler phrases.\n"
|
||||
"- Direct. Opinionated. Community-first."
|
||||
)
|
||||
|
||||
|
||||
def _build_failure_reason(pr: dict, signals: dict) -> str | None:
|
||||
"""Return a human-readable failure reason string if there are violations."""
|
||||
reasons = []
|
||||
if signals.get("em_dash_count", 0) > 0:
|
||||
reasons.append(f"{signals['em_dash_count']} em dash(es)")
|
||||
if signals.get("semicolon_count", 0) > 0:
|
||||
reasons.append(f"{signals['semicolon_count']} semicolon(s)")
|
||||
if signals.get("filler_hits"):
|
||||
reasons.append(f"filler phrases: {', '.join(signals['filler_hits'])}")
|
||||
if not pr.get("output", "").strip():
|
||||
reasons.append("empty output")
|
||||
return "; ".join(reasons) if reasons else None
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_style_benchmark() -> dict:
|
||||
"""Kill the running writing style benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No writing style benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate style benchmark: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
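The deduplication in `send_to_corrections` above relies on `uuid.uuid5` being deterministic: the same run, model and tag always produce the same id, so a repeated import finds the id already present. The sketch below isolates that idea and, unlike the handler (which reports `skipped` as a constant 0), counts the duplicates it skips. `candidate_id` and `import_candidates` are hypothetical names, and the run and model ids are made up for illustration:

```python
import uuid


def candidate_id(run_id: str, model_id: str, tag: str) -> str:
    """Deterministic id: the same inputs always yield the same UUID."""
    return str(uuid.uuid5(
        uuid.NAMESPACE_URL,
        f"style-benchmark/{run_id}/{model_id}/{tag}",
    ))


def import_candidates(rows, existing_ids: set) -> tuple:
    """Return (new ids, number of duplicates skipped)."""
    imported, skipped = [], 0
    for run_id, model_id, tag in rows:
        cid = candidate_id(run_id, model_id, tag)
        if cid in existing_ids:
            skipped += 1
            continue
        existing_ids.add(cid)
        imported.append(cid)
    return imported, skipped


# Hypothetical rows; the third repeats the first.
rows = [
    ("style_2026-04-22_1502", "model-a", "short_reply"),
    ("style_2026-04-22_1502", "model-b", "short_reply"),
    ("style_2026-04-22_1502", "model-a", "short_reply"),
]
seen: set = set()
first_import, first_skipped = import_candidates(rows, seen)
second_import, second_skipped = import_candidates(rows, seen)
```

Because the run id is part of the hashed name, the same model and tag from a later run get a fresh id and are imported again.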
|
||||
|
|
@ -1,339 +0,0 @@
|
|||
"""Avocet -- train job queue API.
|
||||
|
||||
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
|
||||
_running_procs dict in api.py with a persistent, inspectable queue.
|
||||
|
||||
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
|
||||
GET /jobs -- list all jobs, newest first
|
||||
POST /jobs -- create a new job
|
||||
GET /jobs/{id} -- get one job by id
|
||||
DELETE /jobs/{id}/cancel -- cancel a queued or running job
|
||||
GET /jobs/{id}/run -- SSE: run the job, stream stdout
|
||||
GET /results -- list completed models with training_info.json metrics
|
||||
|
||||
SQLite schema:
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
|
||||
Testability seam: _DB_PATH global, override via set_db_path().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_running_procs: dict[str, Any] = {}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams -------------------------------------------------
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
def set_config_dir(path: "Path | None") -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Config helpers ----------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_train_config() -> dict:
|
||||
"""Read python_bin from label_tool.yaml.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. label_tool.yaml train: python_bin
|
||||
2. label_tool.yaml cforch: python_bin
|
||||
3. Hardcoded default (classifiers conda env)
|
||||
"""
|
||||
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
f = _config_file()
|
||||
train_cfg: dict = {}
|
||||
cforch_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
train_cfg = raw.get("train", {}) or {}
|
||||
cforch_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse train config %s: %s", f, exc)
|
||||
return {
|
||||
"python_bin": train_cfg.get(
|
||||
"python_bin",
|
||||
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
|
||||
),
|
||||
}
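The priority order described in `_load_train_config` (the `train:` block first, then `cforch:`, then the hardcoded default) comes down to nested `dict.get` calls. A minimal sketch on plain dicts, assuming the YAML has already been parsed; `resolve_python_bin` and the interpreter paths are hypothetical:

```python
def resolve_python_bin(raw: dict, default: str) -> str:
    """Pick python_bin from train, then cforch, then the default."""
    # `or {}` guards against a key that is present but set to null in YAML.
    train_cfg = raw.get("train") or {}
    cforch_cfg = raw.get("cforch") or {}
    return train_cfg.get("python_bin", cforch_cfg.get("python_bin", default))


DEFAULT = "/usr/bin/python3"  # placeholder, not the module's real default

both = resolve_python_bin(
    {"train": {"python_bin": "/opt/train/python"},
     "cforch": {"python_bin": "/opt/cforch/python"}},
    DEFAULT,
)
only_cforch = resolve_python_bin(
    {"train": None, "cforch": {"python_bin": "/opt/cforch/python"}},
    DEFAULT,
)
neither = resolve_python_bin({}, DEFAULT)
```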
|
||||
|
||||
|
||||
# -- Database helpers --------------------------------------------------
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
"""Create jobs table if it does not exist. Called lazily per request."""
|
||||
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with _db() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||
return {k: row[k] for k in row.keys()}
|
||||
|
||||
|
||||
# -- GPU selection (copied from api.py) --------------------------------
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return index of GPU with most free VRAM, or empty string."""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True, timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
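The selection logic in `_best_cuda_device` can be exercised without a GPU by separating the parsing from the `nvidia-smi` call. The sketch below runs the same "most free memory wins" rule over a captured CSV string; `pick_gpu` is a hypothetical helper and the sample output is invented:

```python
def pick_gpu(csv_output: str) -> str:
    """Return the index of the GPU with the most free memory, or ''."""
    best_idx, best_free = "", 0
    for line in csv_output.strip().splitlines():
        parts = [p.strip() for p in line.split(",")]
        # Expect exactly "index, free_mb"; ignore anything else.
        if len(parts) == 2 and parts[1].isdigit():
            free = int(parts[1])
            if free > best_free:
                best_idx, best_free = parts[0], free
    return best_idx


# Invented sample in the "csv,noheader,nounits" layout.
sample = "0, 1024\n1, 20480\n2, 512\n"
chosen = pick_gpu(sample)
no_gpu = pick_gpu("")
```

Splitting on a bare comma and stripping each field tolerates output with or without a space after the separator.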
|
||||
|
||||
|
||||
# -- Pydantic models ---------------------------------------------------
|
||||
|
||||
class CreateJobRequest(BaseModel):
|
||||
type: str # "classifier" | "llm-sft"
|
||||
model_key: str # e.g. "deberta-small"
|
||||
config_json: dict = {}
|
||||
|
||||
|
||||
# -- Routes ------------------------------------------------------------
|
||||
|
||||
@router.get("/jobs")
|
||||
def list_jobs() -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||
return {"jobs": [_row_to_dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("/jobs")
|
||||
def create_job(req: CreateJobRequest) -> dict:
|
||||
if req.type not in ("classifier", "llm-sft"):
|
||||
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
|
||||
_init_db()
|
||||
job_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES (?, ?, ?, 'queued', ?, ?)",
|
||||
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
|
||||
)
|
||||
return {"id": job_id, "type": req.type, "model_key": req.model_key,
|
||||
"status": "queued", "config_json": req.config_json,
|
||||
"created_at": now, "started_at": None, "completed_at": None, "error": None}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
def get_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
return _row_to_dict(row)
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}/cancel")
|
||||
def cancel_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] not in ("queued", "running"):
|
||||
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
|
||||
proc = _running_procs.pop(job_id, None)
|
||||
if proc is not None:
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
try:
|
||||
proc.kill()
|
||||
except OSError:
|
||||
pass
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/run")
|
||||
def run_job(job_id: str) -> StreamingResponse:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] != "queued":
|
||||
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
|
||||
job = _row_to_dict(row)
|
||||
|
||||
def generate():
|
||||
cfg = _load_train_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
config = json.loads(job["config_json"] or "{}")
|
||||
model_key = job["model_key"]
|
||||
epochs = config.get("epochs", 5)
|
||||
|
||||
if job["type"] == "classifier":
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
data_dir = _ROOT / "data"
|
||||
for sf in config.get("score_files", []):
|
||||
resolved = (data_dir / sf).resolve()
|
||||
if resolved.is_relative_to(data_dir.resolve()):
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
elif job["type"] == "llm-sft":
|
||||
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
else:
|
||||
job_type = job["type"]
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
|
||||
return
|
||||
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
|
||||
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
|
||||
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
|
||||
)
|
||||
_running_procs[job_id] = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if proc.returncode == 0:
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
|
||||
(finished_at, job_id))
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
err = f"Process exited with code {proc.returncode}"
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop(job_id, None)
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> dict:
|
||||
if not _MODELS_DIR.exists():
|
||||
return {"results": []}
|
||||
results = []
|
||||
for sub in _MODELS_DIR.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||
return {"results": results}
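The `score_files` loop in `run_job` above only passes a file to the training script when its resolved path stays inside the data directory. A minimal sketch of that containment check in isolation, assuming Python 3.9+ for `Path.is_relative_to`; `safe_score_paths` and the file names are hypothetical:

```python
import tempfile
from pathlib import Path


def safe_score_paths(data_dir: Path, names: list) -> list:
    """Keep only the names that resolve to a path inside data_dir."""
    root = data_dir.resolve()
    kept = []
    for name in names:
        # resolve() collapses ".." segments; an absolute name replaces data_dir.
        resolved = (data_dir / name).resolve()
        if resolved.is_relative_to(root):
            kept.append(resolved)
    return kept


with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp)
    kept = safe_score_paths(
        data_dir,
        ["a.jsonl", "../secret.jsonl", "/etc/passwd"],
    )
    kept_names = [p.name for p in kept]
```

Entries that escape the directory are dropped silently, which matches the loop above; a stricter variant could reject the whole job instead.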
|
||||
117
app/utils.py
|
|
@ -1,117 +0,0 @@
|
|||
"""Shared email utility functions for Avocet.
|
||||
|
||||
Pure-stdlib helpers extracted from the retired label_tool.py Streamlit app.
|
||||
These are reused by the FastAPI backend and the test suite.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from html.parser import HTMLParser
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ── HTML → plain-text extractor ──────────────────────────────────────────────
|
||||
|
||||
class _TextExtractor(HTMLParser):
|
||||
"""Extract visible text from an HTML email body, preserving line breaks."""
|
||||
_BLOCK = {"p", "div", "br", "li", "tr", "h1", "h2", "h3", "h4", "h5", "h6", "blockquote"}
|
||||
_SKIP = {"script", "style", "head", "noscript"}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(convert_charrefs=True)
|
||||
self._parts: list[str] = []
|
||||
self._depth_skip = 0
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
tag = tag.lower()
|
||||
if tag in self._SKIP:
|
||||
self._depth_skip += 1
|
||||
elif tag in self._BLOCK:
|
||||
self._parts.append("\n")
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
if tag.lower() in self._SKIP:
|
||||
self._depth_skip = max(0, self._depth_skip - 1)
|
||||
|
||||
def handle_data(self, data):
|
||||
if not self._depth_skip:
|
||||
self._parts.append(data)
|
||||
|
||||
def get_text(self) -> str:
|
||||
text = "".join(self._parts)
|
||||
lines = [ln.strip() for ln in text.splitlines()]
|
||||
return "\n".join(ln for ln in lines if ln)
|
||||
|
||||
|
||||
def strip_html(html_str: str) -> str:
|
||||
"""Convert HTML email body to plain text. Pure stdlib, no dependencies."""
|
||||
try:
|
||||
extractor = _TextExtractor()
|
||||
extractor.feed(html_str)
|
||||
return extractor.get_text()
|
||||
except Exception:
|
||||
return re.sub(r"<[^>]+>", " ", html_str).strip()
|
||||
|
||||
|
||||
def extract_body(msg: Any) -> str:
|
||||
"""Return plain-text body. Strips HTML when no text/plain part exists."""
|
||||
if msg.is_multipart():
|
||||
html_fallback: str | None = None
|
||||
for part in msg.walk():
|
||||
ct = part.get_content_type()
|
||||
if ct == "text/plain":
|
||||
try:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
return part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
elif ct == "text/html" and html_fallback is None:
|
||||
try:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
raw = part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
html_fallback = strip_html(raw)
|
||||
except Exception:
|
||||
pass
|
||||
return html_fallback or ""
|
||||
else:
|
||||
try:
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
raw = msg.get_payload(decode=True).decode(charset, errors="replace")
|
||||
if msg.get_content_type() == "text/html":
|
||||
return strip_html(raw)
|
||||
return raw
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def read_jsonl(path: Path) -> list[dict]:
|
||||
"""Read a JSONL file, returning valid records. Skips blank lines and malformed JSON."""
|
||||
if not path.exists():
|
||||
return []
|
||||
records: list[dict] = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return records
|
||||
|
||||
|
||||
def write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
"""Write records to a JSONL file, overwriting any existing content."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
|
||||
path.write_text(content + ("\n" if records else ""), encoding="utf-8")
|
||||
|
||||
|
||||
def append_jsonl(path: Path, record: dict) -> None:
|
||||
"""Append a single record to a JSONL file."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(record, ensure_ascii=False) + "\n")
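A short round trip shows how the JSONL helpers above behave together: appends create the parent directory on demand, and reads tolerate blank or malformed lines. The two functions below are condensed copies of `append_jsonl` and `read_jsonl` so the sketch runs on its own; the records and file names are invented:

```python
import json
import tempfile
from pathlib import Path


def append_jsonl(path: Path, record: dict) -> None:
    """Append one record, creating parent directories if needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: Path) -> list:
    """Read valid records, skipping blank and malformed lines."""
    if not path.exists():
        return []
    records = []
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError:
            pass
    return records


with tempfile.TemporaryDirectory() as tmp:
    queue = Path(tmp) / "nested" / "queue.jsonl"
    append_jsonl(queue, {"id": 1, "label": "interview"})
    # Simulate a corrupted line followed by a blank line.
    with open(queue, "a", encoding="utf-8") as fh:
        fh.write("{not valid json\n\n")
    append_jsonl(queue, {"id": 2, "label": "café"})
    loaded = read_jsonl(queue)
    missing = read_jsonl(Path(tmp) / "absent.jsonl")
```

With `ensure_ascii=False` the non-ASCII label is stored as readable UTF-8 rather than as `\u` escapes.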
|
||||
427
app/voice.py
|
|
@ -1,427 +0,0 @@
|
|||
"""Avocet — Voice benchmark integration API.
|
||||
|
||||
Wraps scripts/benchmark_voice.py and exposes it via the Avocet API.
|
||||
Connection config (coordinator_url, ollama_url, python_bin) is read
|
||||
from label_tool.yaml under the `cforch:` key — the same block used
|
||||
by cforch.py, so no new config section is needed.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/voice".
|
||||
|
||||
Module-level globals (_BENCH_RUNNING, _bench_proc) follow the same
|
||||
testability pattern as cforch.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_BENCH_RUNNING: bool = False
|
||||
_bench_proc: Any = None
|
||||
|
||||
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_voice.py"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Read label_tool.yaml cforch section for coordinator/ollama/python config."""
|
||||
f = _config_file()
|
||||
file_cfg: dict = {}
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
file_cfg = raw.get("cforch", {}) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse voice config %s: %s", f, exc)
|
||||
return {
|
||||
"coordinator_url": file_cfg.get("coordinator_url", "http://10.1.10.71:7700"),
|
||||
"ollama_url": file_cfg.get("ollama_url", "http://localhost:11434"),
|
||||
"python_bin": file_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python"),
|
||||
}
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/models")
|
||||
def get_models() -> dict:
|
||||
"""Return available models grouped by source.
|
||||
|
||||
- ollama: fetched live from /api/tags (includes any models downloaded
|
||||
via the Models view — automatically in sync)
|
||||
- cf_text: fetched from cf-orch catalog endpoint (requires node profile
|
||||
entry + coordinator restart when new GGUFs are added)
|
||||
"""
|
||||
cfg = _load_config()
|
||||
|
||||
# Ollama models — live query so newly downloaded models appear immediately
|
||||
ollama_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(f"{cfg['ollama_url']}/api/tags", timeout=5.0)
|
||||
resp.raise_for_status()
|
||||
for m in resp.json().get("models", []):
|
||||
name = m.get("name", "")
|
||||
if name:
|
||||
size_bytes = m.get("size", 0)
|
||||
ollama_models.append({
|
||||
"id": name,
|
||||
"name": name,
|
||||
"source": "ollama",
|
||||
"size_mb": round(size_bytes / (1024 * 1024)) if size_bytes else None,
|
||||
"vram_mb": None,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch ollama models: %s", exc)
|
||||
|
||||
# cf-text catalog — fetched from cf-orch coordinator
|
||||
cftext_models: list[dict] = []
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
|
||||
timeout=5.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
for model_id, entry in resp.json().items():
|
||||
if isinstance(entry, dict):
|
||||
cftext_models.append({
|
||||
"id": model_id,
|
||||
"name": model_id,
|
||||
"source": "cf-text",
|
||||
"vram_mb": entry.get("vram_mb"),
|
||||
"description": entry.get("description", ""),
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch cf-text catalog: %s", exc)
|
||||
|
||||
return {"ollama": ollama_models, "cf_text": cftext_models}
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/run")
|
||||
def run_voice_benchmark(
|
||||
models: str = Query("", description="Comma-separated model IDs (empty = all)"),
|
||||
use_cforch: bool = Query(False),
|
||||
max_vram: int = Query(7200, description="Max VRAM MB for cf-orch OOM filter"),
|
||||
include_large: bool = Query(False, description="Include large (30B+) ollama models"),
|
||||
workers: int = Query(1, description="Parallel workers — run N models simultaneously"),
|
||||
) -> StreamingResponse:
|
||||
"""Spawn benchmark_voice.py and stream stdout as SSE progress events.
|
||||
|
||||
On successful completion, emits a `type: result` event containing
the parsed JSON from the newest voice_*.json file, followed by a final
`type: complete` event.
|
||||
"""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if _BENCH_RUNNING:
|
||||
raise HTTPException(409, "A voice benchmark is already running")
|
||||
|
||||
cfg = _load_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
|
||||
def generate():
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_SCRIPT.exists():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_voice.py not found at {_BENCH_SCRIPT}'})}\n\n"
|
||||
return
|
||||
|
||||
cmd = [python_bin, str(_BENCH_SCRIPT), "run"]
|
||||
|
||||
if models:
|
||||
cmd.extend(["--models", ",".join(m.strip() for m in models.split(",") if m.strip())])
|
||||
if use_cforch:
|
||||
cmd.extend(["--cforch", "--cforch-url", cfg["coordinator_url"],
|
||||
"--max-vram", str(max_vram)])
|
||||
if include_large:
|
||||
cmd.append("--include-large")
|
||||
if workers > 1:
|
||||
cmd.extend(["--workers", str(workers)])
|
||||
|
||||
_BENCH_RUNNING = True
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_bench_proc = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
result_files = sorted(_RESULTS_DIR.glob("voice_*.json"))
|
||||
if result_files:
|
||||
try:
|
||||
results = json.loads(result_files[-1].read_text(encoding="utf-8"))
|
||||
yield f"data: {json.dumps({'type': 'result', 'results': results, 'filename': result_files[-1].name})}\n\n"
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read voice results: %s", exc)
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_bench_proc = None
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
finally:
|
||||
_BENCH_RUNNING = False
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
"""List past voice benchmark runs, newest first.
|
||||
|
||||
Returns lightweight summaries (date, model count, top score).
|
||||
Use /results/{filename} to fetch full model-level detail.
|
||||
"""
|
||||
if not _RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
runs: list[dict] = []
|
||||
for f in sorted(_RESULTS_DIR.glob("voice_*.json"), reverse=True):
|
||||
stem = f.stem # voice_2026-04-22_1502
|
||||
date_str = stem.removeprefix("voice_") # 2026-04-22_1502
|
||||
try:
|
||||
date_part, time_part = date_str.split("_")
|
||||
display_date = f"{date_part} {time_part[:2]}:{time_part[2:]}"
|
||||
except Exception:
|
||||
display_date = date_str
|
||||
|
||||
try:
|
||||
results = json.loads(f.read_text(encoding="utf-8"))
|
||||
top_score = max((r.get("avg_score", 0) for r in results), default=0)
|
||||
model_count = len(results)
|
||||
except Exception:
|
||||
top_score = 0
|
||||
model_count = 0
|
||||
|
||||
runs.append({
|
||||
"filename": f.name,
|
||||
"date": display_date,
|
||||
"model_count": model_count,
|
||||
"top_score": round(top_score, 1),
|
||||
})
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/results/latest")
|
||||
def get_latest_results() -> list[dict]:
|
||||
"""Return the latest voice benchmark result list."""
|
||||
if not _RESULTS_DIR.exists():
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
files = sorted(_RESULTS_DIR.glob("voice_*.json"))
|
||||
if not files:
|
||||
raise HTTPException(404, "No benchmark results found")
|
||||
try:
|
||||
return json.loads(files[-1].read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
def get_results_by_filename(filename: str) -> list[dict]:
|
||||
"""Return voice benchmark results for a specific run file."""
|
||||
if not filename.startswith("voice_") or not filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename — expected voice_*.json")
|
||||
f = _RESULTS_DIR / filename
|
||||
if not f.exists():
|
||||
raise HTTPException(404, f"Results file not found: {filename}")
|
||||
try:
|
||||
return json.loads(f.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
|
||||
# ── POST /send-to-corrections ──────────────────────────────────────────────────
|
||||
|
||||
class SendToCorrectionsRequest(BaseModel):
|
||||
filename: str # voice_YYYY-MM-DD_HHMM.json — the source run file
|
||||
model_ids: list[str] = [] # empty = all models in the run
|
||||
|
||||
|
||||
@router.post("/send-to-corrections")
|
||||
def send_to_corrections(req: SendToCorrectionsRequest) -> dict:
|
||||
"""Push voice benchmark outputs into the SFT corrections queue.
|
||||
|
||||
Each prompt_result from the selected models becomes one SFT candidate
|
||||
with status='needs_review'. Duplicates are skipped via the 'id' field
|
||||
(a deterministic UUIDv5 of run id + model_id + tag).
|
||||
"""
|
||||
if not req.filename.startswith("voice_") or not req.filename.endswith(".json"):
|
||||
raise HTTPException(400, "Invalid filename")
|
||||
|
||||
src = _RESULTS_DIR / req.filename
|
||||
if not src.exists():
|
||||
raise HTTPException(404, f"Results file not found: {req.filename}")
|
||||
|
||||
try:
|
||||
run_results: list[dict] = json.loads(src.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||
|
||||
# Resolve sft_candidates.jsonl path (same logic as sft.py)
|
||||
sft_data_dir = _ROOT / "data"
|
||||
sft_file = sft_data_dir / "sft_candidates.jsonl"
|
||||
|
||||
# Load existing IDs to deduplicate
|
||||
existing_ids: set[str] = set()
|
||||
if sft_file.exists():
|
||||
for line in sft_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
existing_ids.add(json.loads(line)["id"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
run_id = req.filename.removesuffix(".json") # voice_2026-04-22_1502
|
||||
timestamp = datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
new_candidates: list[dict] = []
|
||||
for model_result in run_results:
|
||||
model_id = model_result.get("model_id", "")
|
||||
if req.model_ids and model_id not in req.model_ids:
|
||||
continue
|
||||
for pr in model_result.get("prompt_results", []):
|
||||
tag = pr.get("tag", "")
|
||||
# Stable id: deterministic hash of run + model + prompt tag
|
||||
candidate_id = str(uuid.uuid5(
|
||||
uuid.NAMESPACE_URL,
|
||||
f"voice-benchmark/{run_id}/{model_id}/{tag}",
|
||||
))
|
||||
if candidate_id in existing_ids:
|
||||
continue
|
||||
|
||||
score_pct = pr.get("score", 0.0) / 100.0
|
||||
signals = pr.get("signals", {})
|
||||
|
||||
# Build the prompt message list matching the benchmark's actual request
|
||||
prompt_messages = [
|
||||
{"role": "system", "content": _VOICE_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": pr.get("user_prompt", tag)},
|
||||
]
|
||||
|
||||
new_candidates.append({
|
||||
"id": candidate_id,
|
||||
"source": "voice-benchmark",
|
||||
"benchmark_run_id": run_id,
|
||||
"timestamp": timestamp,
|
||||
"status": "needs_review",
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_response": pr.get("output", ""),
|
||||
"corrected_response": None,
|
||||
"quality_score": round(score_pct, 4),
|
||||
"failure_reason": _build_failure_reason(pr, signals),
|
||||
"failure_category": None,
|
||||
"task_id": f"voice/{tag}",
|
||||
"task_type": "voice-match",
|
||||
"task_name": tag.replace("_", " ").title(),
|
||||
"model_id": model_id,
|
||||
"model_name": model_id,
|
||||
"node_id": "",
|
||||
"gpu_id": 0,
|
||||
"tokens_per_sec": 0,
|
||||
})
|
||||
existing_ids.add(candidate_id)
|
||||
|
||||
if new_candidates:
|
||||
sft_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(sft_file, "a", encoding="utf-8") as fh:
|
||||
for c in new_candidates:
|
||||
fh.write(json.dumps(c) + "\n")
|
||||
|
||||
return {"imported": len(new_candidates), "skipped": 0}
|
||||
|
||||
|
||||
# Excerpt of the system prompt used in benchmark_voice.py — reproduced here
|
||||
# so the SFT candidate captures the full generation context.
|
||||
_VOICE_SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No em dashes. No semicolons. No filler phrases.\n"
|
||||
"- Direct. Opinionated. Community-first."
|
||||
)
|
||||
|
||||
|
||||
def _build_failure_reason(pr: dict, signals: dict) -> str | None:
|
||||
"""Return a human-readable failure reason string if there are violations."""
|
||||
reasons = []
|
||||
if signals.get("em_dash_count", 0) > 0:
|
||||
reasons.append(f"{signals['em_dash_count']} em dash(es)")
|
||||
if signals.get("semicolon_count", 0) > 0:
|
||||
reasons.append(f"{signals['semicolon_count']} semicolon(s)")
|
||||
if signals.get("filler_hits"):
|
||||
reasons.append(f"filler phrases: {', '.join(signals['filler_hits'])}")
|
||||
if not pr.get("output", "").strip():
|
||||
reasons.append("empty output")
|
||||
return "; ".join(reasons) if reasons else None
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/cancel")
|
||||
def cancel_voice_benchmark() -> dict:
|
||||
"""Kill the running voice benchmark subprocess."""
|
||||
global _BENCH_RUNNING, _bench_proc
|
||||
|
||||
if not _BENCH_RUNNING:
|
||||
raise HTTPException(404, "No voice benchmark is currently running")
|
||||
|
||||
if _bench_proc is not None:
|
||||
try:
|
||||
_bench_proc.terminate()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to terminate voice benchmark: %s", exc)
|
||||
|
||||
_BENCH_RUNNING = False
|
||||
_bench_proc = None
|
||||
return {"status": "cancelled"}
|
||||
|
|
@ -21,124 +21,3 @@ accounts:
|
|||
|
||||
# Optional: limit emails fetched per account per run (0 = unlimited)
|
||||
max_per_account: 500
|
||||
|
||||
# cf-orch SFT candidate import — path to the bench_results/ directory
|
||||
# produced by circuitforge-orch's benchmark harness.
|
||||
sft:
|
||||
bench_results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
||||
|
||||
# cf-orch integration — LLM benchmark harness via cf-orch coordinator.
|
||||
# All keys here override the corresponding environment variables.
|
||||
# Omit any key to fall back to the env var (see .env.example).
|
||||
cforch:
|
||||
# Path to cf-orch's benchmark.py script
|
||||
bench_script: /path/to/circuitforge-orch/scripts/benchmark.py
|
||||
# Task and model definition files (yaml)
|
||||
bench_tasks: /path/to/circuitforge-orch/scripts/bench_tasks.yaml
|
||||
bench_models: /path/to/circuitforge-orch/scripts/bench_models.yaml
|
||||
# Where benchmark results are written (also used for SFT candidate discovery)
|
||||
results_dir: /path/to/circuitforge-orch/scripts/bench_results
|
||||
# Python interpreter with cf-orch installed
|
||||
python_bin: /devl/miniconda3/envs/cf/bin/python
|
||||
|
||||
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST / CF_JUDGE_URL / HF_TOKEN
|
||||
# coordinator_url: http://localhost:7700
|
||||
# license_key: CFG-AVCT-xxxx-xxxx-xxxx
|
||||
# ollama_url: http://localhost:11434
|
||||
# ollama_model: llama3.2:3b
|
||||
# embed_model: nomic-embed-text # Ollama embedding model for EmbeddingKNNAdapter
|
||||
# judge_url: http://10.1.10.158:8008 # Sif cf-text — LLM-as-judge secondary scorer
|
||||
# judge_url: http://10.1.10.71:8008 # Heimdall cf-text (alternative)
|
||||
# Or set CF_JUDGE_URL. Populates the Judge URL field in the LLM Eval UI automatically.
|
||||
# hf_token: hf_xxxxxxxxxxxxxxxxxxxx # HuggingFace token — required for gated/terms-restricted models
|
||||
|
||||
# Directory containing per-node profile YAMLs (cf-orch node profiles).
|
||||
# Default: derived from bench_script location (../../profiles/nodes).
|
||||
# profiles_dir: /Library/Development/CircuitForge/circuitforge-orch/circuitforge_orch/profiles/nodes
|
||||
|
||||
# Imitate tab — pull real samples from sibling CF product APIs and run them
|
||||
# through local LLMs to build a corrections dataset.
|
||||
# ollama_url defaults to cforch.ollama_url if omitted here.
|
||||
imitate:
|
||||
ollama_url: http://localhost:11434 # optional — falls back to cforch.ollama_url
|
||||
|
||||
products:
|
||||
- id: peregrine
|
||||
name: Peregrine
|
||||
icon: "🦅"
|
||||
description: Job search assistant — live job listings
|
||||
base_url: http://localhost:8601
|
||||
health_path: /api/jobs/counts
|
||||
sample_endpoint: /api/jobs?status=pending&limit=5
|
||||
text_fields: [title, company, description]
|
||||
prompt_template: "Analyze this job listing and identify the key requirements, must-have skills, and any culture signals that would help tailor an application:\n\n{text}"
|
||||
|
||||
- id: osprey
|
||||
name: Osprey
|
||||
icon: "📞"
|
||||
description: Gov't hold-line automation — recent call records
|
||||
base_url: http://localhost:8520
|
||||
health_path: /api/health
|
||||
sample_endpoint: /api/calls/recent
|
||||
text_fields: [agency, issue, notes]
|
||||
prompt_template: "Draft a clear, professional follow-up letter for this government hold-line call. Include what was discussed, what action the agency committed to, and a polite deadline for response:\n\n{text}"
|
||||
|
||||
- id: linnet
|
||||
name: Linnet
|
||||
icon: "🐦"
|
||||
description: Real-time tone annotation — Elcor-style subtext for ND users
|
||||
base_url: http://localhost:8522
|
||||
health_path: /health
|
||||
sample_endpoint: /samples
|
||||
text_fields: [text, context]
|
||||
prompt_template: "Annotate the emotional tone and subtext of the following text using explicit Elcor-style markers (e.g. [SINCERELY], [UNCERTAIN], [FRUSTRATED]). Identify implied emotions, potential sarcasm, and any ambiguity that might be misread by neurodivergent readers:\n\n{text}"
|
||||
|
||||
- id: kiwi
|
||||
name: Kiwi
|
||||
icon: "🥝"
|
||||
description: Pantry tracker
|
||||
base_url: http://localhost:8511
|
||||
sample_endpoint: /api/inventory
|
||||
text_fields: [name, category, notes]
|
||||
prompt_template: "Describe this pantry item and estimate how best to use it:\n\n{text}"
|
||||
|
||||
- id: snipe
|
||||
name: Snipe
|
||||
icon: "🎯"
|
||||
description: eBay trust scoring
|
||||
base_url: http://localhost:8509
|
||||
sample_endpoint: /api/listings
|
||||
text_fields: [title, description, seller_info]
|
||||
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
|
||||
|
||||
- id: pagepiper
|
||||
name: Pagepiper
|
||||
icon: "📄"
|
||||
description: "PDF/rulebook RAG tool: page-level text chunks"
|
||||
base_url: http://localhost:8511
|
||||
health_path: /api/health
|
||||
sample_endpoint: /api/library
|
||||
chunk_endpoint: /api/library/sample-chunks?limit=50 # requires pagepiper#6
|
||||
text_fields: [title]
|
||||
prompt_template: "Summarize the key rules described in this passage:\n\n{text}"
|
||||
|
||||
# ── Log corpus (Turnstone training data) ──────────────────────────────────────
|
||||
corpus:
|
||||
# Directory containing pipeline JSONL log files to ingest (pull-side).
|
||||
# Files named <script>_<ts>.jsonl; one structured record per line.
|
||||
# POST /api/corpus/pipeline-ingest walks this dir and imports new files.
|
||||
# NFS-mounted on both Heimdall and Sif at /Library/Assets/
|
||||
pipeline_ingest_dir: /Library/Assets/logs/pipeline/
|
||||
|
||||
# Turnstone push sources (consent-gated, token-authenticated).
|
||||
# sources:
|
||||
# - token: "your-bearer-token"
|
||||
# source_host: "node.local"
|
||||
# owner: YourName
|
||||
# consent_date: "2026-05-17"
|
||||
# consent_method: signal_chat
|
||||
|
||||
# ── Embedding model comparison harness ────────────────────────────────────────
|
||||
embed_bench:
|
||||
# ollama_url: http://localhost:11434 # optional; falls back to cforch.ollama_url
|
||||
# top_k: 5 # default hits per model per query
|
||||
|
|
|
|||
95
docs/plans/2026-03-08-anime-animation-design.md
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
# Anime.js Animation Integration — Design
|
||||
|
||||
**Date:** 2026-03-08
|
||||
**Status:** Approved
|
||||
**Branch:** feat/vue-label-tab
|
||||
|
||||
## Problem
|
||||
|
||||
The current animation system mixes CSS keyframes, CSS transitions, and imperative inline-style bindings across three files. The seams between systems produce:
|
||||
|
||||
- Abrupt ball pickup (instant scale/borderRadius jump)
|
||||
- No spring snap-back when the card is released without hitting a target
|
||||
- Rigid CSS dismissals with no timing control
|
||||
- Bucket grid and badge pop on basic `@keyframes`
|
||||
|
||||
## Decision
|
||||
|
||||
Integrate **Anime.js v4** as a single animation layer. Vue reactive state is unchanged; Anime.js owns all DOM motion imperatively.
|
||||
|
||||
## Architecture
|
||||
|
||||
One new composable, minimal changes to two existing files, CSS cleanup in two files.
|
||||
|
||||
```
|
||||
web/src/composables/useCardAnimation.ts ← NEW
|
||||
web/src/components/EmailCardStack.vue ← modify
|
||||
web/src/views/LabelView.vue ← modify
|
||||
```
|
||||
|
||||
**Data flow:**
|
||||
```
|
||||
pointer events → Vue refs (isHeld, deltaX, deltaY, dismissType)
|
||||
↓ watched by
|
||||
useCardAnimation(cardEl, stackEl, isHeld, ...)
|
||||
↓ imperatively drives
|
||||
Anime.js → DOM transforms
|
||||
```
|
||||
|
||||
`useCardAnimation` is a pure side-effect composable — returns nothing to the template. The `cardStyle` computed in `EmailCardStack.vue` is removed; Anime.js owns the element's transform directly.
|
||||
|
||||
## Animation Surfaces
|
||||
|
||||
### Pickup morph
|
||||
```
|
||||
animate(cardEl, { scale: 0.55, borderRadius: '50%', y: -80, duration: 200, ease: spring({ mass: 1, stiffness: 80, damping: 10 }) })
|
||||
```
|
||||
Replaces the instant CSS transform jump on `onPointerDown`.
|
||||
|
||||
### Drag tracking
|
||||
Raw `cardEl.style.translate` update on `onPointerMove` — no animation, just position. Easing only at boundaries (pickup / release), not during active drag.
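
A minimal sketch of the position arithmetic during drag, assuming the 80px lift from the pickup morph stays applied while the card follows the pointer. `dragPosition` and `PICKUP_Y_OFFSET` are illustrative names, not taken from the component.

```ts
// Hypothetical helper: maps raw pointer deltas to the card's translate values.
// Assumption: the card keeps floating 80px above the pointer (the pickup morph's y: -80).
const PICKUP_Y_OFFSET = 80

interface DragPosition { x: number; y: number }

function dragPosition(deltaX: number, deltaY: number): DragPosition {
  // No easing here: the card tracks the pointer 1:1, lifted by the pickup offset.
  return { x: deltaX, y: deltaY - PICKUP_Y_OFFSET }
}
```

For example, a pointer delta of (50, 30) places the card at x = 50, y = -50.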
|
||||
|
||||
### Snap-back
|
||||
```
|
||||
animate(cardEl, { x: 0, y: 0, scale: 1, borderRadius: '1rem', ease: spring({ mass: 1, stiffness: 80, damping: 10 }) })
|
||||
```
|
||||
Fires on `onPointerUp` when no zone/bucket target was hit.
|
||||
|
||||
### Dismissals (replace CSS `@keyframes`)
|
||||
- **fileAway** — `animate(cardEl, { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' })`
|
||||
- **crumple** — 2-step timeline: shrink + redden → `scale(0)` + rotate
|
||||
- **slideUnder** — `animate(cardEl, { x: '110%', rotate: 5, opacity: 0, duration: 260 })`
|
||||
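
The three dismissals listed above can be sketched as plain data keyed by dismiss type. This assumes timing options live in the same object as the animated properties (the two-argument `animate()` form noted in the implementation plan) and that `label`, `skip`, and `discard` map to fileAway, slideUnder, and crumple respectively. `dismissSpec` and `DismissSpec` are hypothetical names.

```ts
type DismissType = 'label' | 'skip' | 'discard'

// Hypothetical shape: a readable name plus the params object for a single animate() call.
// `params` is null for crumple because it is a two-step sequence, not a single call,
// and its step values are not specified in this design.
interface DismissSpec {
  name: string
  params: Record<string, string | number> | null
}

function dismissSpec(type: DismissType): DismissSpec {
  switch (type) {
    case 'label':
      return { name: 'fileAway', params: { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' } }
    case 'skip':
      return { name: 'slideUnder', params: { x: '110%', rotate: 5, opacity: 0, duration: 260 } }
    case 'discard':
      return { name: 'crumple', params: null }
  }
}
```

Keeping the values in one lookup makes it easy to check in a unit test that each dismiss type selects the intended animation without touching the DOM.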
|
||||
### Bucket grid rise
|
||||
`animate(gridEl, { y: -8, opacity: 0.45 })` on `isHeld` → true; reversed on false. Spring easing.
|
||||
|
||||
### Badge pop
|
||||
`animate(badgeEl, { scale: [0.6, 1], opacity: [0, 1], ease: spring({ mass: 1.5, stiffness: 80, damping: 8 }), duration: 300 })`, triggered when the badge mounts: either from Vue's `onMounted` lifecycle hook in a `BadgePop` wrapper component, or from a `<Transition>` `@enter` JavaScript hook.
|
||||
|
||||
## Constraints
|
||||
|
||||
### Reduced motion
|
||||
`useCardAnimation` checks `motion.rich.value` before firing any Anime.js call. If false, all animations are skipped — instant state changes only. Consistent with existing `useMotion` pattern.
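
One way to express this guard is a small wrapper around each animation function. This is a sketch under assumptions: `whenRich` is a hypothetical helper, and `RefLike` is a stand-in for a Vue `Ref` so the example runs without importing Vue.

```ts
// Minimal stand-in for a Vue ref, so the sketch has no dependencies.
interface RefLike<T> { value: T }
interface Motion { rich: RefLike<boolean> }

// Hypothetical helper: returns a function that only runs `run` when rich motion is enabled.
function whenRich<A extends unknown[]>(
  motion: Motion,
  run: (...args: A) => void,
): (...args: A) => void {
  return (...args: A) => {
    if (!motion.rich.value) return // reduced motion: skip the animation entirely
    run(...args)
  }
}

// Usage: the first call is skipped, the second runs once rich motion is turned on.
const motion: Motion = { rich: { value: false } }
const fired: string[] = []
const pickup = whenRich(motion, (label: string) => { fired.push(label) })
pickup('pickup')
motion.rich.value = true
pickup('pickup')
```

The check happens at call time rather than at wrap time, so toggling the preference takes effect on the next interaction.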
|
||||
|
||||
### Bundle size
|
||||
Anime.js v4 core ~17KB gzipped. Only `animate`, `spring`, and `createTimeline` are imported — Vite ESM tree-shaking keeps footprint minimal. The `draggable` module is not used.
|
||||
|
||||
### Tests
|
||||
Existing `EmailCardStack.test.ts` tests emit behavior, not animation — they remain passing. Anime.js mocked at module level in Vitest via `vi.mock('animejs')` where needed.
|
||||
|
||||
### CSS cleanup
|
||||
Remove from `EmailCardStack.vue` and `LabelView.vue`:
|
||||
- `@keyframes fileAway`, `crumple`, `slideUnder`
|
||||
- `@keyframes badge-pop`
|
||||
- `.dismiss-label`, `.dismiss-skip`, `.dismiss-discard` classes (Anime.js fires on element refs directly)
|
||||
- The `dismissClass` computed in `EmailCardStack.vue`
|
||||
|
||||
## Files Changed
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `web/package.json` | Add `animejs` dependency |
|
||||
| `web/src/composables/useCardAnimation.ts` | New — all Anime.js animation logic |
|
||||
| `web/src/components/EmailCardStack.vue` | Remove `cardStyle` computed + dismiss classes; call `useCardAnimation` |
|
||||
| `web/src/views/LabelView.vue` | Badge pop + bucket grid rise via Anime.js |
|
||||
| `web/src/assets/avocet.css` | Remove any global animation keyframes if present |
|
||||
573
docs/plans/2026-03-08-anime-animation-plan.md
Normal file
|
|
@ -0,0 +1,573 @@
|
|||
# Anime.js Animation Integration — Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Replace the current mixed CSS keyframes / inline-style animation system with Anime.js v4 for all card motion — pickup morph, drag tracking, spring snap-back, dismissals, bucket grid rise, and badge pop.
|
||||
|
||||
**Architecture:** A new `useCardAnimation` composable owns all Anime.js calls imperatively against DOM refs. Vue reactive state (`isHeld`, `deltaX`, `deltaY`, `dismissType`) is unchanged. `cardStyle` computed and `dismissClass` computed are deleted; Anime.js writes to the element directly.
|
||||
|
||||
**Tech Stack:** Anime.js v4 (`animejs`), Vue 3 Composition API, `@vue/test-utils` + Vitest for tests.
|
||||
|
||||
---
|
||||
|
||||
## Task 1: Install Anime.js
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/package.json`
|
||||
|
||||
**Step 1: Install the package**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet/web
|
||||
npm install animejs
|
||||
```
|
||||
|
||||
**Step 2: Verify the import resolves**
|
||||
|
||||
Create a throwaway check — open `web/src/main.ts` briefly and confirm:
|
||||
```ts
|
||||
import { animate, spring } from 'animejs'
|
||||
```
|
||||
resolves without error in the editor (TypeScript types ship with animejs v4).
|
||||
Remove the import immediately after verifying — do not commit it.
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet/web
|
||||
git add package.json package-lock.json
|
||||
git commit -m "feat(avocet): add animejs v4 dependency"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 2: Create `useCardAnimation` composable
|
||||
|
||||
**Files:**
|
||||
- Create: `web/src/composables/useCardAnimation.ts`
|
||||
- Create: `web/src/composables/useCardAnimation.test.ts`
|
||||
|
||||
**Background — Anime.js v4 transform model:**
|
||||
Anime.js v4 tracks `x`, `y`, `scale`, `rotate`, etc. as separate transform components internally.
|
||||
Use `utils.set(el, props)` for instant (no-animation) property updates — this keeps the internal cache consistent.
|
||||
Never mix direct `el.style.transform = "..."` with Anime.js on the same element, or the cache desyncs.
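
The desync warning can be illustrated with a toy model. This is not Anime.js's internal implementation; `TransformCache` is a hypothetical class that only mimics the idea of rebuilding the transform string from cached components.

```ts
// Toy model only: NOT Anime.js internals. It rebuilds `transform` from cached parts on every set().
interface Parts { x: number; y: number; scale: number }

class TransformCache {
  private parts: Parts = { x: 0, y: 0, scale: 1 }
  private style: { transform: string }

  constructor(style: { transform: string }) {
    this.style = style
  }

  set(next: Partial<Parts>): void {
    Object.assign(this.parts, next)
    const { x, y, scale } = this.parts
    this.style.transform = `translateX(${x}px) translateY(${y}px) scale(${scale})`
  }
}

const style = { transform: '' }
const cache = new TransformCache(style)

cache.set({ x: 50 })                  // cache and style agree
style.transform = 'translateX(200px)' // direct write: the cache never sees this
cache.set({ scale: 0.5 })             // rebuilt from the cached x = 50; the 200px write is lost
```

After the last call the style reads `translateX(50px) translateY(0px) scale(0.5)`, which is why instant updates should go through the same layer that owns the animated values.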
|
||||
|
||||
**Step 1: Write the failing tests**
|
||||
|
||||
`web/src/composables/useCardAnimation.test.ts`:
|
||||
```ts
|
||||
import { ref } from 'vue'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
|
||||
// Mock animejs before importing the composable
|
||||
vi.mock('animejs', () => ({
|
||||
animate: vi.fn(),
|
||||
spring: vi.fn(() => 'mock-spring'),
|
||||
utils: { set: vi.fn() },
|
||||
}))
|
||||
|
||||
import { useCardAnimation } from './useCardAnimation'
|
||||
import { animate, utils } from 'animejs'
|
||||
|
||||
const mockAnimate = animate as ReturnType<typeof vi.fn>
|
||||
const mockSet = utils.set as ReturnType<typeof vi.fn>
|
||||
|
||||
function makeEl() {
|
||||
return document.createElement('div')
|
||||
}
|
||||
|
||||
describe('useCardAnimation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('pickup() calls animate with ball shape', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { pickup } = useCardAnimation(cardEl, motion)
|
||||
pickup()
|
||||
expect(mockAnimate).toHaveBeenCalledWith(
|
||||
el,
|
||||
expect.objectContaining({ scale: 0.55, borderRadius: '50%' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('pickup() is a no-op when motion.rich is false', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(false) }
|
||||
const { pickup } = useCardAnimation(cardEl, motion)
|
||||
pickup()
|
||||
expect(mockAnimate).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('setDragPosition() calls utils.set with translated coords', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { setDragPosition } = useCardAnimation(cardEl, motion)
|
||||
setDragPosition(50, 30)
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ x: 50, y: -50 }))
|
||||
// y = deltaY - 80 = 30 - 80 = -50
|
||||
})
|
||||
|
||||
it('snapBack() calls animate returning to card shape', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { snapBack } = useCardAnimation(cardEl, motion)
|
||||
snapBack()
|
||||
expect(mockAnimate).toHaveBeenCalledWith(
|
||||
el,
|
||||
expect.objectContaining({ x: 0, y: 0, scale: 1 }),
|
||||
)
|
||||
})
|
||||
|
||||
it('animateDismiss("label") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('label')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss("discard") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('discard')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss("skip") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('skip')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss is a no-op when motion.rich is false', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(false) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('label')
|
||||
expect(mockAnimate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
**Step 2: Run tests to confirm they fail**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet/web
|
||||
npm test -- useCardAnimation
|
||||
```
|
||||
|
||||
Expected: FAIL — "Cannot find module './useCardAnimation'"
|
||||
|
||||
**Step 3: Implement the composable**
|
||||
|
||||
`web/src/composables/useCardAnimation.ts`:
|
||||
```ts
|
||||
import { type Ref } from 'vue'
|
||||
import { animate, spring, utils } from 'animejs'
|
||||
|
||||
const BALL_SCALE = 0.55
|
||||
const BALL_RADIUS = '50%'
|
||||
const CARD_RADIUS = '1rem'
|
||||
const PICKUP_Y_OFFSET = 80 // px above finger
|
||||
const PICKUP_DURATION = 200
|
||||
// NOTE: animejs v4 — spring() takes an object, not positional args
|
||||
const SNAP_SPRING = spring({ mass: 1, stiffness: 80, damping: 10 })
|
||||
|
||||
interface Motion { rich: Ref<boolean> }
|
||||
|
||||
export function useCardAnimation(
|
||||
cardEl: Ref<HTMLElement | null>,
|
||||
motion: Motion,
|
||||
) {
|
||||
function pickup() {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
// NOTE: animejs v4 — animate() is 2-arg; timing options merge into the params object
|
||||
animate(cardEl.value, {
|
||||
scale: BALL_SCALE,
|
||||
borderRadius: BALL_RADIUS,
|
||||
y: -PICKUP_Y_OFFSET,
|
||||
duration: PICKUP_DURATION,
|
||||
ease: SNAP_SPRING,
|
||||
})
|
||||
}
|
||||
|
||||
function setDragPosition(dx: number, dy: number) {
|
||||
if (!cardEl.value) return
|
||||
utils.set(cardEl.value, { x: dx, y: dy - PICKUP_Y_OFFSET })
|
||||
}
|
||||
|
||||
function snapBack() {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
// No duration — spring physics determines settling time
|
||||
animate(cardEl.value, {
|
||||
x: 0,
|
||||
y: 0,
|
||||
scale: 1,
|
||||
borderRadius: CARD_RADIUS,
|
||||
ease: SNAP_SPRING,
|
||||
})
|
||||
}
|
||||
|
||||
function animateDismiss(type: 'label' | 'skip' | 'discard') {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
const el = cardEl.value
|
||||
if (type === 'label') {
|
||||
animate(el, { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' })
|
||||
} else if (type === 'discard') {
|
||||
// Two-step: crumple then shrink (keyframes array in params object)
|
||||
animate(el, { keyframes: [
|
||||
{ scale: 0.95, rotate: 2, filter: 'brightness(0.6) sepia(1) hue-rotate(-20deg)', duration: 140 },
|
||||
{ scale: 0, rotate: 8, opacity: 0, duration: 210 },
|
||||
] })
|
||||
} else if (type === 'skip') {
|
||||
animate(el, { x: '110%', rotate: 5, opacity: 0, duration: 260, ease: 'out(2)' })
|
||||
}
|
||||
}
|
||||
|
||||
return { pickup, setDragPosition, snapBack, animateDismiss }
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Run tests — expect pass**
|
||||
|
||||
```bash
|
||||
npm test -- useCardAnimation
|
||||
```
|
||||
|
||||
Expected: All 8 tests PASS.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/composables/useCardAnimation.ts web/src/composables/useCardAnimation.test.ts
|
||||
git commit -m "feat(avocet): add useCardAnimation composable with Anime.js"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 3: Wire `useCardAnimation` into `EmailCardStack.vue`
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/src/components/EmailCardStack.vue`
|
||||
- Modify: `web/src/components/EmailCardStack.test.ts`
|
||||
|
||||
**What changes:**
|
||||
- Remove `cardStyle` computed and `:style="cardStyle"` binding
|
||||
- Remove `dismissClass` computed and `:class="[dismissClass, ...]"` binding (keep `is-held`)
|
||||
- Remove `deltaX`, `deltaY` reactive refs (position now owned by Anime.js)
|
||||
- Call `pickup()` in `onPointerDown`, `setDragPosition()` in `onPointerMove`, `snapBack()` in `onPointerUp` (no-target path)
|
||||
- Watch `props.dismissType` and call `animateDismiss()`
|
||||
- Remove CSS `@keyframes fileAway`, `crumple`, `slideUnder` and their `.dismiss-*` rule blocks from `<style>`
|
||||
|
||||
**Step 1: Update the tests that check dismiss classes**
|
||||
|
||||
In `EmailCardStack.test.ts`, the 5 tests checking `.dismiss-label`, `.dismiss-discard`, `.dismiss-skip` classes are testing implementation (CSS class name), not behavior. Replace them with a single test that verifies `animateDismiss` is called:
|
||||
|
||||
```ts
|
||||
// Add at the top of the file (after existing imports):
|
||||
vi.mock('../composables/useCardAnimation', () => ({
|
||||
useCardAnimation: vi.fn(() => ({
|
||||
pickup: vi.fn(),
|
||||
setDragPosition: vi.fn(),
|
||||
snapBack: vi.fn(),
|
||||
animateDismiss: vi.fn(),
|
||||
})),
|
||||
}))
|
||||
|
||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
||||
```
|
||||
|
||||
Replace the five `dismissType` class tests (lines 25–46) with:
|
||||
|
||||
```ts
|
||||
it('calls animateDismiss with type when dismissType prop changes', async () => {
|
||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: null } })
|
||||
const { animateDismiss } = (useCardAnimation as ReturnType<typeof vi.fn>).mock.results[0].value
|
||||
await w.setProps({ dismissType: 'label' })
|
||||
await nextTick()
|
||||
expect(animateDismiss).toHaveBeenCalledWith('label')
|
||||
})
|
||||
```
|
||||
|
||||
Add `nextTick` import to the test file header if not already present:
|
||||
```ts
|
||||
import { nextTick } from 'vue'
|
||||
```
|
||||
|
||||
**Step 2: Run tests to confirm the replaced tests fail**
|
||||
|
||||
```bash
|
||||
npm test -- EmailCardStack
|
||||
```
|
||||
|
||||
Expected: FAIL — `animateDismiss` not called (not yet wired in component)
|
||||
|
||||
**Step 3: Modify `EmailCardStack.vue`**
|
||||
|
||||
Script section changes:
|
||||
|
||||
```ts
|
||||
// Change `import { ref, computed } from 'vue'` to:
|
||||
import { ref, watch } from 'vue'
|
||||
|
||||
// Add import:
|
||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
||||
|
||||
// Remove these refs:
|
||||
// const deltaX = ref(0)
|
||||
// const deltaY = ref(0)
|
||||
|
||||
// Add after const motion = useMotion():
|
||||
const { pickup, setDragPosition, snapBack, animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
|
||||
// Add watcher:
|
||||
watch(() => props.dismissType, (type) => {
|
||||
if (type) animateDismiss(type)
|
||||
})
|
||||
|
||||
// Remove dismissClass computed entirely.
|
||||
|
||||
// In onPointerDown — add after isHeld.value = true:
|
||||
pickup()
|
||||
|
||||
// In onPointerMove — replace deltaX/deltaY assignments with:
|
||||
const dx = e.clientX - pickupX.value
|
||||
const dy = e.clientY - pickupY.value
|
||||
setDragPosition(dx, dy)
|
||||
// (keep the zone/bucket detection that uses e.clientX/e.clientY — those stay the same)
|
||||
|
||||
// In onPointerUp — in the snap-back else branch, replace:
|
||||
// deltaX.value = 0
|
||||
// deltaY.value = 0
|
||||
// with:
|
||||
snapBack()
|
||||
```
|
||||
|
||||
Template changes — on the `.card-wrapper` div:
|
||||
```html
|
||||
<!-- Remove: :class="[dismissClass, { 'is-held': isHeld }]" -->
|
||||
<!-- Replace with: -->
|
||||
:class="{ 'is-held': isHeld }"
|
||||
<!-- Remove: :style="cardStyle" -->
|
||||
```
|
||||
|
||||
CSS changes in `<style scoped>` — delete these entire blocks:
|
||||
```
|
||||
@keyframes fileAway { ... }
|
||||
@keyframes crumple { ... }
|
||||
@keyframes slideUnder { ... }
|
||||
.card-wrapper.dismiss-label { ... }
|
||||
.card-wrapper.dismiss-discard { ... }
|
||||
.card-wrapper.dismiss-skip { ... }
|
||||
```
|
||||
|
||||
Also delete `--card-dismiss` and `--card-skip` CSS var usages if present.
|
||||
|
||||
**Step 4: Run all tests**
|
||||
|
||||
```bash
|
||||
npm test
|
||||
```
|
||||
|
||||
Expected: All pass (both `useCardAnimation.test.ts` and `EmailCardStack.test.ts`).
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/components/EmailCardStack.vue web/src/components/EmailCardStack.test.ts
|
||||
git commit -m "feat(avocet): wire Anime.js card animation into EmailCardStack"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 4: Bucket grid rise animation
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/src/views/LabelView.vue`
|
||||
|
||||
**What changes:**
|
||||
Replace the CSS class-toggle animation on `.bucket-grid-footer.grid-active` with an Anime.js watch in `LabelView.vue`. The `position: sticky → fixed` switch stays as a CSS class (can't animate position), but `translateY` and `opacity` move to Anime.js.
|
||||
|
||||
**Step 1: Add gridEl ref and import animate**
|
||||
|
||||
In `LabelView.vue` `<script setup>`:
|
||||
```ts
|
||||
// Add to imports:
|
||||
import { ref, onMounted, onUnmounted, watch } from 'vue'
|
||||
import { animate, spring } from 'animejs'
|
||||
|
||||
// Add ref:
|
||||
const gridEl = ref<HTMLElement | null>(null)
|
||||
```
|
||||
|
||||
**Step 2: Add watcher for isHeld**
|
||||
|
||||
```ts
|
||||
watch(isHeld, (held) => {
|
||||
if (!motion.rich.value || !gridEl.value) return
|
||||
// animejs v4: 2-arg animate, spring() takes object
|
||||
animate(gridEl.value,
|
||||
held
|
||||
? { y: -8, opacity: 0.45, ease: spring({ mass: 1, stiffness: 80, damping: 10 }), duration: 250 }
|
||||
: { y: 0, opacity: 1, ease: spring({ mass: 1, stiffness: 80, damping: 10 }), duration: 250 }
|
||||
)
|
||||
})
|
||||
```
|
||||
|
||||
**Step 3: Wire ref in template**
|
||||
|
||||
On the `.bucket-grid-footer` div:
|
||||
```html
|
||||
<div ref="gridEl" class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
||||
```
|
||||
|
||||
**Step 4: Remove CSS transition from `.bucket-grid-footer`**
|
||||
|
||||
In `LabelView.vue <style scoped>`, delete the `transition:` line from `.bucket-grid-footer`:
|
||||
```css
|
||||
/* DELETE this line: */
|
||||
transition: transform 250ms cubic-bezier(0.34, 1.56, 0.64, 1),
|
||||
opacity 200ms ease,
|
||||
background 200ms ease;
|
||||
```
|
||||
Keep the `transform: translateY(-8px)` and `opacity: 0.45` rules on `.bucket-grid-footer.grid-active` as-is. The Anime.js `watch` guard (`if (!motion.rich.value)`) means reduced-motion users never reach Anime.js, so the CSS class alone handles that path, now without a transition.
|
||||
|
||||
**Step 5: Run tests**
|
||||
|
||||
```bash
|
||||
npm test
|
||||
```
|
||||
|
||||
Expected: All pass (LabelView has no dedicated tests, but full suite should be green).
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/views/LabelView.vue
|
||||
git commit -m "feat(avocet): animate bucket grid rise with Anime.js spring"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 5: Badge pop animation
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/src/views/LabelView.vue`
|
||||
|
||||
**What changes:**
|
||||
Replace `@keyframes badge-pop` (scale + opacity keyframe) with a Vue `<Transition>` `@enter` hook that calls `animate()`. Badges already appear/disappear via `v-if`, so they have natural mount/unmount lifecycle.
|
||||
|
||||
**Step 1: Wrap each badge in a `<Transition>`**
|
||||
|
||||
In `LabelView.vue` template, each badge `<span v-if="...">` gets wrapped:
|
||||
|
||||
```html
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
||||
</Transition>
|
||||
<!-- repeat for all 6 badges -->
|
||||
```
|
||||
|
||||
`:css="false"` tells Vue not to apply any CSS transition classes — Anime.js owns the enter animation entirely.
|
||||
|
||||
**Step 2: Add `onBadgeEnter` hook**
|
||||
|
||||
```ts
|
||||
function onBadgeEnter(el: Element, done: () => void) {
|
||||
if (!motion.rich.value) { done(); return }
|
||||
animate(el as HTMLElement,
|
||||
{ scale: [0.6, 1], opacity: [0, 1], ease: spring({ mass: 1.5, stiffness: 80, damping: 8 }), duration: 300, onComplete: done }
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: Remove `@keyframes badge-pop` from CSS**
|
||||
|
||||
In `LabelView.vue <style scoped>`:
|
||||
```css
|
||||
/* DELETE: */
|
||||
@keyframes badge-pop {
|
||||
from { transform: scale(0.6); opacity: 0; }
|
||||
to { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
/* DELETE animation line from .badge: */
|
||||
animation: badge-pop 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
```
|
||||
|
||||
**Step 4: Run tests**
|
||||
|
||||
```bash
|
||||
npm test
|
||||
```
|
||||
|
||||
Expected: All pass.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/views/LabelView.vue
|
||||
git commit -m "feat(avocet): badge pop via Anime.js spring transition hook"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 6: Build and smoke test
|
||||
|
||||
**Step 1: Build the SPA**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet
|
||||
./manage.sh start-api
|
||||
```
|
||||
|
||||
(This builds Vue + starts FastAPI on port 8503.)
|
||||
|
||||
**Step 2: Open the app**
|
||||
|
||||
```bash
|
||||
./manage.sh open-api
|
||||
```
|
||||
|
||||
**Step 3: Manual smoke test checklist**
|
||||
|
||||
- [ ] Pick up a card — ball morph is smooth (not instant jump)
|
||||
- [ ] Drag ball around — follows finger with no lag
|
||||
- [ ] Release in center — springs back to card with bounce
|
||||
- [ ] Release in left zone — discard fires (card crumples)
|
||||
- [ ] Release in right zone — skip fires (card slides right)
|
||||
- [ ] Release on a bucket — label fires (card files up)
|
||||
- [ ] Fling left fast — discard fires
|
||||
- [ ] Bucket grid rises smoothly on pickup, falls on release
|
||||
- [ ] Badge (label 10 in a row for 🔥) pops in with spring
|
||||
- [ ] Reduced motion: toggle in system settings → no animations, instant behavior
|
||||
- [ ] Keyboard labels (1–9) still work (pointer events unchanged)
|
||||
|
||||
**Step 4: Final commit if all green**
|
||||
|
||||
```bash
|
||||
git add -A
|
||||
git commit -m "feat(avocet): complete Anime.js animation integration"
|
||||
```
|
||||
1861
docs/superpowers/plans/2026-03-15-finetune-classifier.md
Normal file
File diff suppressed because it is too large
254
docs/superpowers/specs/2026-03-15-finetune-classifier-design.md
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
# Fine-tune Email Classifier — Design Spec
|
||||
|
||||
**Date:** 2026-03-15
|
||||
**Status:** Approved
|
||||
**Scope:** Avocet — `scripts/`, `app/api.py`, `web/src/views/BenchmarkView.vue`, `environment.yml`
|
||||
|
||||
---
|
||||
|
||||
## Problem
|
||||
|
||||
The benchmark baseline shows zero-shot macro-F1 of 0.366 for the best models (`deberta-zeroshot`, `deberta-base-anli`). Zero-shot inference cannot improve with more labeled data. Fine-tuning the fastest models (`deberta-small` at 111ms, `bge-m3` at 123ms) on the growing labeled dataset is the path to meaningful accuracy gains.
|
||||
|
||||
---
|
||||
|
||||
## Constraints
|
||||
|
||||
- 501 labeled samples after dropping 2 non-canonical `profile_alert` rows
|
||||
- Heavy class imbalance: `digest` 29%, `neutral` 26%, `new_lead` 2.6%, `survey_received` 3%
|
||||
- 8.2 GB VRAM (shared with Peregrine vLLM during dev)
|
||||
- Target models: `cross-encoder/nli-deberta-v3-small` (100M params), `MoritzLaurer/bge-m3-zeroshot-v2.0` (600M params)
|
||||
- Output: local `models/avocet-{name}/` directory
|
||||
- UI-triggerable via web interface (SSE streaming log)
|
||||
- Stack: transformers 4.57.3, torch 2.10.0, accelerate 1.12.0, sklearn, CUDA 8.2GB
|
||||
|
||||
---
|
||||
|
||||
## Environment changes
|
||||
|
||||
`environment.yml` must add:
|
||||
- `scikit-learn` — required for `train_test_split(stratify=...)` and `f1_score`
|
||||
- `peft` is NOT used by this spec; it is available in the env but not required here
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### New file: `scripts/finetune_classifier.py`
|
||||
|
||||
CLI entry point for fine-tuning. All prints use `flush=True` so stdout is SSE-streamable.
|
||||
|
||||
```
|
||||
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
|
||||
```
|
||||
|
||||
Supported `--model` values: `deberta-small`, `bge-m3`
|
||||
|
||||
**Model registry** (internal to this script):
|
||||
|
||||
| Key | Base model ID | Max tokens | fp16 | Batch size | Grad accum steps | Gradient checkpointing |
|-----|---------------|------------|------|------------|------------------|------------------------|
| `deberta-small` | `cross-encoder/nli-deberta-v3-small` | 512 | No | 16 | 1 | No |
| `bge-m3` | `MoritzLaurer/bge-m3-zeroshot-v2.0` | 512 | Yes | 4 | 4 | Yes |
|
||||
|
||||
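`bge-m3` uses `fp16=True` with batch size 4 and gradient accumulation 4 (effective batch 16, matching `deberta-small`), plus gradient checkpointing. Mixed precision and checkpointing mainly reduce activation memory; under the Trainer's standard mixed precision the AdamW state stays in fp32 (~4.8GB for this model), so headroom is tight. These settings are required to fit within 8.2GB VRAM. Still stop Peregrine vLLM before running bge-m3 fine-tuning.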
|
||||
|
||||
### Modified: `scripts/classifier_adapters.py`
|
||||
|
||||
Add `FineTunedAdapter(ClassifierAdapter)`:
|
||||
- Takes `model_dir: str` (path to a `models/avocet-*/` checkpoint)
|
||||
- Loads via `pipeline("text-classification", model=model_dir)`
|
||||
- `classify()` input format: **`f"{subject} [SEP] {body[:400]}"`** — must match the training format exactly. Do NOT use the zero-shot adapters' `f"Subject: {subject}\n\n{body[:600]}"` format; distribution shift will degrade accuracy.
|
||||
- Returns the top predicted label directly (single forward pass — no per-label NLI scoring loop)
|
||||
- Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot
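
The adapter described above could look roughly like the following sketch. The class name `FineTunedAdapterSketch`, the `classify(subject, body)` signature, and the injectable `pipe` argument are assumptions made for illustration; the real class would subclass `ClassifierAdapter`, whose interface is not shown here.

```python
from typing import Callable, Optional


def format_input(subject: str, body: str) -> str:
    """Build the model input; this must mirror the training-time format exactly."""
    return f"{subject} [SEP] {body[:400]}"


class FineTunedAdapterSketch:
    """Hypothetical stand-in for FineTunedAdapter (base class omitted)."""

    def __init__(self, model_dir: str, pipe: Optional[Callable] = None):
        self.model_dir = model_dir
        self._pipe = pipe  # injected in tests; built lazily otherwise

    def _load(self) -> Callable:
        if self._pipe is None:
            # Imported lazily so the sketch can be exercised without transformers.
            from transformers import pipeline
            self._pipe = pipeline("text-classification", model=self.model_dir)
        return self._pipe

    def classify(self, subject: str, body: str) -> str:
        # A text-classification pipeline returns [{"label": ..., "score": ...}]
        # for a single string, so one forward pass yields the top label.
        result = self._load()(format_input(subject, body))
        return result[0]["label"]


# Usage with a fake pipeline that records what it was given.
seen_inputs = []

def fake_pipe(text):
    seen_inputs.append(text)
    return [{"label": "digest", "score": 0.93}]

adapter = FineTunedAdapterSketch("models/avocet-deberta-small", pipe=fake_pipe)
predicted = adapter.classify("Weekly roundup", "b" * 1000)
```

Keeping `format_input` as a single shared function is one way to make sure training and inference cannot drift apart.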
|
||||
|
||||
### Modified: `scripts/benchmark_classifier.py`
|
||||
|
||||
At startup, scan `models/` for subdirectories containing `training_info.json`. Register each as a dynamic entry in the model registry using `FineTunedAdapter`. Silently skips if `models/` does not exist. Existing CLI behaviour unchanged.
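
A minimal sketch of that startup scan, assuming a helper named `discover_finetuned_models` (a hypothetical name) that maps each checkpoint directory name to its path. How the result is merged into the benchmark's model registry is left out, since that registry's shape is not shown here.

```python
import tempfile
from pathlib import Path
from typing import Dict, Union


def discover_finetuned_models(models_dir: Union[str, Path] = "models") -> Dict[str, Path]:
    """Return {dir_name: path} for subdirectories holding a training_info.json."""
    root = Path(models_dir)
    if not root.is_dir():
        return {}  # silently skip when models/ does not exist
    return {
        child.name: child
        for child in sorted(root.iterdir())
        if child.is_dir() and (child / "training_info.json").is_file()
    }


# Demo against a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "models"
    (root / "avocet-deberta-small").mkdir(parents=True)
    (root / "avocet-deberta-small" / "training_info.json").write_text("{}")
    (root / "scratch").mkdir()  # no training_info.json, so it is ignored
    found_names = list(discover_finetuned_models(root))
    missing = discover_finetuned_models(Path(tmp) / "does-not-exist")
```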
|
||||
|
||||
### Modified: `app/api.py`
|
||||
|
||||
Two new GET endpoints (GET required for `EventSource` compatibility):
|
||||
|
||||
**`GET /api/finetune/status`**
|
||||
Scans `models/` for `training_info.json` files. Returns:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401
|
||||
}
|
||||
]
|
||||
```
|
||||
Returns `[]` if no fine-tuned models exist.
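
One way the status payload could be assembled is sketched below. The function names are hypothetical, and the mapping from `base_model_id` (the key used in `training_info.json`) to `base_model` (the key in the response above) is an assumption drawn from comparing the two JSON examples in this spec.

```python
import json
from pathlib import Path

# Keys copied through unchanged from training_info.json.
STATUS_FIELDS = ("name", "val_macro_f1", "timestamp", "sample_count")


def summarize_training_info(info: dict) -> dict:
    """Reduce a training_info.json payload to the status response shape."""
    summary = {key: info.get(key) for key in STATUS_FIELDS}
    # Assumed rename: stored as base_model_id, reported as base_model.
    summary["base_model"] = info.get("base_model_id")
    return summary


def finetune_status(models_dir: str = "models") -> list:
    """Collect one summary per models/*/training_info.json; [] if none exist."""
    root = Path(models_dir)
    if not root.is_dir():
        return []
    return [
        summarize_training_info(json.loads(path.read_text()))
        for path in sorted(root.glob("*/training_info.json"))
    ]


example_info = {
    "name": "avocet-deberta-small",
    "base_model_id": "cross-encoder/nli-deberta-v3-small",
    "timestamp": "2026-03-15T12:00:00Z",
    "epochs_run": 5,
    "val_macro_f1": 0.712,
    "val_accuracy": 0.798,
    "sample_count": 401,
}
example_summary = summarize_training_info(example_info)
```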
|
||||
|
||||
**`GET /api/finetune/run?model=deberta-small&epochs=5`**
|
||||
Spawns `finetune_classifier.py` via the `job-seeker-classifiers` Python binary. Streams stdout as SSE `{"type":"progress","message":"..."}` events. Emits `{"type":"complete"}` on clean exit, `{"type":"error","message":"..."}` on non-zero exit. Same implementation pattern as `/api/benchmark/run`.
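
The streaming behaviour can be sketched as a plain generator, shown below under stated assumptions: `sse_event` and `stream_process` are hypothetical names, and the existing `/api/benchmark/run` implementation may differ in detail. In FastAPI such a generator would typically be served through `StreamingResponse(..., media_type="text/event-stream")`.

```python
import json
import subprocess
import sys
from typing import Iterator, List


def sse_event(payload: dict) -> str:
    """Encode one payload as a Server-Sent Events data frame."""
    return f"data: {json.dumps(payload)}\n\n"


def stream_process(cmd: List[str]) -> Iterator[str]:
    """Run cmd, yielding a progress event per output line, then complete/error."""
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # fold stderr into the same log stream
        text=True,
        bufsize=1,  # line-buffered, so flush=True prints arrive promptly
    )
    if proc.stdout is None:
        raise RuntimeError("stdout pipe was not created")
    for line in proc.stdout:
        yield sse_event({"type": "progress", "message": line.rstrip("\n")})
    code = proc.wait()
    if code == 0:
        yield sse_event({"type": "complete"})
    else:
        yield sse_event({"type": "error", "message": f"process exited with code {code}"})


# Demo with trivial child processes instead of the real fine-tuning script.
ok_events = list(stream_process([sys.executable, "-c", "print('epoch 1', flush=True)"]))
bad_events = list(stream_process([sys.executable, "-c", "import sys; sys.exit(3)"]))
```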
|
||||
|
||||
### Modified: `web/src/views/BenchmarkView.vue`
|
||||
|
||||
**Trained models badge row** (top of view, conditional on fine-tuned models existing):
|
||||
Shows each fine-tuned model name + val macro-F1 chip. Fetches from `/api/finetune/status` on mount.
|
||||
|
||||
**Fine-tune section** (collapsible, below benchmark charts):
|
||||
- Dropdown: `deberta-small` | `bge-m3`
|
||||
- Number input: epochs (default 5, range 1–20)
|
||||
- Run button → streams into existing log component
|
||||
- On `complete`: auto-triggers `/api/benchmark/run` (with `--save`) so charts update immediately
|
||||
|
||||
---
|
||||
|
||||
## Training Pipeline
|
||||
|
||||
### Data preparation
|
||||
|
||||
1. Load `data/email_score.jsonl`
|
||||
2. Drop rows where `label` not in canonical `LABELS` (removes `profile_alert` etc.)
|
||||
3. Check for classes with < 2 **total** samples (before any split). Drop those classes and warn. Additionally warn about (but do not drop) classes with < 5 training samples, noting that eval F1 for those classes will be unreliable.
|
||||
4. Input text: `f"{subject} [SEP] {body[:400]}"` — fits within 512 tokens for both target models
|
||||
5. Stratified 80/20 train/val split via `sklearn.model_selection.train_test_split(stratify=labels)`
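
Steps 1 to 4 above can be sketched with the standard library alone. The `LABELS` list here is a hypothetical four-label subset (the real canonical set has 10 labels), and the record keys `subject`, `body`, and `label` are assumed from the format strings in this spec.

```python
import json
from collections import Counter

# Hypothetical subset of the canonical label set, for illustration only.
LABELS = ["digest", "neutral", "new_lead", "survey_received"]


def prepare_examples(jsonl_lines, labels=LABELS, min_total=2):
    """Filter to canonical labels, drop under-populated classes, build inputs."""
    rows = [json.loads(line) for line in jsonl_lines if line.strip()]
    rows = [row for row in rows if row.get("label") in labels]
    counts = Counter(row["label"] for row in rows)

    warnings = []
    for label in labels:
        n = counts.get(label, 0)
        if n < min_total:
            warnings.append(f"dropping class {label!r}: only {n} total sample(s)")

    kept = {label for label, n in counts.items() if n >= min_total}
    rows = [row for row in rows if row["label"] in kept]
    texts = [f"{row['subject']} [SEP] {row['body'][:400]}" for row in rows]
    targets = [row["label"] for row in rows]
    return texts, targets, warnings


sample_lines = [
    json.dumps({"subject": "Weekly", "body": "x" * 500, "label": "digest"}),
    json.dumps({"subject": "Weekly 2", "body": "news", "label": "digest"}),
    json.dumps({"subject": "Hi", "body": "lead", "label": "new_lead"}),
    json.dumps({"subject": "Alert", "body": "p", "label": "profile_alert"}),
]
texts, targets, warnings = prepare_examples(sample_lines)
```

Step 5 would then be a call such as `train_test_split(texts, targets, test_size=0.2, stratify=targets)` from `sklearn.model_selection`.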
|
||||
|
||||
### Class weighting
|
||||
|
||||
Compute per-class weights: `total_samples / (n_classes × class_count)`. Pass to a `WeightedTrainer` subclass:
|
||||
|
||||
```python
|
||||
import torch.nn.functional as F
from transformers import Trainer


class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs is required: it absorbs num_items_in_batch, added in Transformers 4.46.
        # Do not remove it; removing it causes TypeError on the first training step.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Move class_weights to the same device as logits (required for GPU training).
        # class_weights is created on CPU; logits are on cuda:0 during training.
        weight = self.class_weights.to(outputs.logits.device)
        loss = F.cross_entropy(outputs.logits, labels, weight=weight)
        return (loss, outputs) if return_outputs else loss
|
||||
```
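
As a worked example of the weighting formula `total_samples / (n_classes × class_count)`, the sketch below computes weights for a toy distribution. The helper name `compute_class_weights` and the counts are illustrative, not taken from the real dataset.

```python
from collections import Counter


def compute_class_weights(labels: list) -> dict:
    """Inverse-frequency weights: total_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    total = len(labels)
    n_classes = len(counts)
    return {label: total / (n_classes * count) for label, count in counts.items()}


# Toy distribution: 6 digest, 3 neutral, 1 new_lead (10 samples, 3 classes).
example_labels = ["digest"] * 6 + ["neutral"] * 3 + ["new_lead"] * 1
weights = compute_class_weights(example_labels)
# digest   -> 10 / (3 * 6)
# neutral  -> 10 / (3 * 3)
# new_lead -> 10 / (3 * 1)
```

Rare classes get proportionally larger weights. Before being handed to the trainer, the values would need to be arranged into a float tensor ordered by label id.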
|
||||
|
||||
### Model setup
|
||||
|
||||
```python
|
||||
AutoModelForSequenceClassification.from_pretrained(
    base_model_id,
    num_labels=10,
    ignore_mismatched_sizes=True,  # see note below
    id2label=id2label,
    label2id=label2id,
)
|
||||
```
|
||||
|
||||
**Note on `ignore_mismatched_sizes=True`:** The pretrained NLI head is a 3-class linear projection. It mismatches the 10-class head constructed by `num_labels=10`, so its weights are skipped during loading. PyTorch initializes the new head from scratch using the model's default init scheme. The backbone weights load normally. Do not set this to `False` — it will raise a shape error.
|
||||
|
||||
### Training config and `compute_metrics`
|
||||
|
||||
The Trainer requires a `compute_metrics` callback that takes an `EvalPrediction` (logits + label_ids) and returns a dict with a `macro_f1` key. This is distinct from the existing `compute_metrics` in `classifier_adapters.py` (which operates on string predictions):
|
||||
|
||||
```python
|
||||
from sklearn.metrics import accuracy_score, f1_score
from transformers import EvalPrediction


def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
    # EvalPrediction unpacks to (predictions, label_ids); predictions are raw logits.
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
        "accuracy": accuracy_score(labels, preds),
    }
|
||||
```
|
||||
|
||||
`TrainingArguments` must include:
|
||||
- `load_best_model_at_end=True`
|
||||
- `metric_for_best_model="macro_f1"`
|
||||
- `greater_is_better=True`
|
||||
|
||||
These are required for `EarlyStoppingCallback` to work correctly. Without `load_best_model_at_end=True`, `EarlyStoppingCallback` raises `AssertionError` on init.
|
||||
|
||||
| Hyperparameter | deberta-small | bge-m3 |
|----------------|---------------|--------|
| Epochs | 5 (default, CLI-overridable) | 5 |
| Batch size | 16 | 4 |
| Gradient accumulation | 1 | 4 (effective batch = 16) |
| Learning rate | 2e-5 | 2e-5 |
| LR schedule | Linear with 10% warmup | same |
| Optimizer | AdamW | AdamW |
| fp16 | No | Yes |
| Gradient checkpointing | No | Yes |
| Eval strategy | Every epoch | Every epoch |
| Best checkpoint | By `macro_f1` | same |
| Early stopping patience | 3 epochs | 3 epochs |
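
The table can be translated into `TrainingArguments` keyword arguments roughly as follows. This is a sketch: `MODEL_CONFIGS`, `training_kwargs`, and the `output_dir` default are hypothetical names, and `save_strategy="epoch"` is an addition, included because `load_best_model_at_end=True` expects the save and eval strategies to match.

```python
# Per-model settings taken from the hyperparameter table.
MODEL_CONFIGS = {
    "deberta-small": {"batch_size": 16, "grad_accum": 1, "fp16": False, "grad_ckpt": False},
    "bge-m3": {"batch_size": 4, "grad_accum": 4, "fp16": True, "grad_ckpt": True},
}


def training_kwargs(model_key: str, epochs: int = 5, output_dir: str = "models/tmp") -> dict:
    """Build keyword arguments intended for transformers.TrainingArguments."""
    cfg = MODEL_CONFIGS[model_key]
    return {
        "output_dir": output_dir,
        "num_train_epochs": epochs,
        "per_device_train_batch_size": cfg["batch_size"],
        "gradient_accumulation_steps": cfg["grad_accum"],
        "learning_rate": 2e-5,
        "lr_scheduler_type": "linear",
        "warmup_ratio": 0.1,  # 10% warmup
        "fp16": cfg["fp16"],
        "gradient_checkpointing": cfg["grad_ckpt"],
        "eval_strategy": "epoch",
        "save_strategy": "epoch",  # assumed: kept in step with eval_strategy
        "load_best_model_at_end": True,
        "metric_for_best_model": "macro_f1",
        "greater_is_better": True,
    }


def effective_batch(kwargs: dict) -> int:
    """Samples contributing to each optimizer step on a single device."""
    return kwargs["per_device_train_batch_size"] * kwargs["gradient_accumulation_steps"]


small_kwargs = training_kwargs("deberta-small")
bge_kwargs = training_kwargs("bge-m3", epochs=3)
```

These would be passed as `TrainingArguments(**training_kwargs("bge-m3"))`, with `EarlyStoppingCallback(early_stopping_patience=3)` added to the trainer's callbacks.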
|
||||
|
||||
### Output
|
||||
|
||||
Saved to `models/avocet-{name}/`:
|
||||
- Model weights + tokenizer (standard HuggingFace format)
|
||||
- `training_info.json`:
|
||||
```json
|
||||
{
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"epochs_run": 5,
|
||||
"val_macro_f1": 0.712,
|
||||
"val_accuracy": 0.798,
|
||||
"sample_count": 401,
|
||||
"label_counts": { "digest": 116, "neutral": 104, ... }
|
||||
}
|
||||
```
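
A sketch of how this file might be produced after training. The function names are hypothetical. The `eval_`-prefixed metric keys follow the convention of `Trainer.evaluate()`, and treating `sample_count` as the number of training rows is an assumption (401 is roughly 80% of the 501 labeled samples).

```python
import json
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path


def build_training_info(name, base_model_id, epochs_run, metrics, train_labels):
    """Assemble the training_info.json payload from a finished run."""
    return {
        "name": name,
        "base_model_id": base_model_id,
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "epochs_run": epochs_run,
        "val_macro_f1": metrics["eval_macro_f1"],
        "val_accuracy": metrics["eval_accuracy"],
        "sample_count": len(train_labels),  # assumed: training rows only
        "label_counts": dict(Counter(train_labels)),
    }


def write_training_info(model_dir, info):
    """Write the payload next to the saved weights and tokenizer."""
    out_path = Path(model_dir) / "training_info.json"
    out_path.write_text(json.dumps(info, indent=2))
    return out_path


example_info = build_training_info(
    "avocet-deberta-small",
    "cross-encoder/nli-deberta-v3-small",
    5,
    {"eval_macro_f1": 0.712, "eval_accuracy": 0.798},
    ["digest", "digest", "neutral"],
)
```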
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
email_score.jsonl
|
||||
│
|
||||
▼
|
||||
finetune_classifier.py
|
||||
├── drop non-canonical labels
|
||||
├── check for < 2 total samples per class (drop + warn)
|
||||
├── stratified 80/20 split
|
||||
├── tokenize (subject [SEP] body[:400])
|
||||
├── compute class weights
|
||||
├── WeightedTrainer + EarlyStoppingCallback
|
||||
└── save → models/avocet-{name}/
|
||||
│
|
||||
├── FineTunedAdapter (classifier_adapters.py)
|
||||
│ ├── pipeline("text-classification")
|
||||
│ ├── input: subject [SEP] body[:400] ← must match training format
|
||||
│ └── ~10–20ms/email inference
|
||||
│
|
||||
└── training_info.json
|
||||
└── /api/finetune/status
|
||||
└── BenchmarkView badge row
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
- **Insufficient data (< 2 total samples in a class):** Drop class before split, print warning with class name and count.
|
||||
- **Low data warning (< 5 training samples in a class):** Warn but continue; note eval F1 for that class will be unreliable.
|
||||
- **VRAM OOM on bge-m3:** Surface as clear SSE error message. Suggest stopping Peregrine vLLM first (it holds ~5.7GB).
|
||||
- **Missing score file:** Raise `FileNotFoundError` with actionable message (same pattern as `load_scoring_jsonl`).
|
||||
- **Model dir already exists:** Overwrite with a warning log line. Re-running always produces a fresh checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
- Unit test `WeightedTrainer.compute_loss` with a mock model and known label distribution — verify weighted loss differs from unweighted; verify `**kwargs` does not raise `TypeError`
|
||||
- Unit test `compute_metrics_for_trainer` — verify `macro_f1` key in output, correct value on known inputs
|
||||
- Unit test `FineTunedAdapter.classify` with a mock pipeline — verify it returns a string from `LABELS` using `subject [SEP] body[:400]` format
|
||||
- Unit test auto-discovery in `benchmark_classifier.py` — mock `models/` dir with two `training_info.json` files, verify both appear in the active registry
|
||||
- Integration test: fine-tune on `data/email_score.jsonl.example` (8 samples, 5 of 10 labels represented, 1 epoch, `--model deberta-small`). The 5 missing labels trigger the `< 2 total samples` drop path — the test must verify the drop warning is emitted for each missing label rather than treating it as a failure. Verify `models/avocet-deberta-small/training_info.json` is written with correct keys.
|
||||
|
||||
---
|
||||
|
||||
## Out of Scope
|
||||
|
||||
- Pushing fine-tuned weights to HuggingFace Hub (future)
|
||||
- Cross-validation or k-fold evaluation (future — dataset too small to be meaningful now)
|
||||
- Hyperparameter search (future)
|
||||
- LoRA/PEFT adapter fine-tuning (future — relevant if model sizes grow beyond available VRAM)
|
||||
- Fine-tuning models other than `deberta-small` and `bge-m3`
|
||||
|
|
@ -22,8 +22,5 @@ dependencies:
|
|||
# Optional: BGE reranker adapter
|
||||
# - FlagEmbedding
|
||||
|
||||
# CircuitForge shared core (LLM router, tier system, config)
|
||||
- circuitforge-core>=0.9.0
|
||||
|
||||
# Dev
|
||||
- pytest>=8.0
|
||||
|
|
|
|||
271
manage.sh
|
|
@ -19,8 +19,9 @@ LOG_FILE="${LOG_DIR}/label_tool.log"
|
|||
DEFAULT_PORT=8503
|
||||
|
||||
CONDA_BASE="${CONDA_BASE:-/devl/miniconda3}"
|
||||
ENV_UI="${AVOCET_ENV:-cf}"
|
||||
ENV_UI="job-seeker"
|
||||
ENV_BM="job-seeker-classifiers"
|
||||
STREAMLIT="${CONDA_BASE}/envs/${ENV_UI}/bin/streamlit"
|
||||
PYTHON_BM="${CONDA_BASE}/envs/${ENV_BM}/bin/python"
|
||||
PYTHON_UI="${CONDA_BASE}/envs/${ENV_UI}/bin/python"
|
||||
|
||||
|
|
@ -78,11 +79,13 @@ usage() {
|
|||
echo ""
|
||||
echo " Usage: ./manage.sh <command> [args]"
|
||||
echo ""
|
||||
echo " Vue UI + FastAPI:"
|
||||
echo -e " ${GREEN}start${NC} Build Vue SPA + start FastAPI on port 8503"
|
||||
echo -e " ${GREEN}stop${NC} Stop FastAPI server"
|
||||
echo -e " ${GREEN}restart${NC} Stop + rebuild + restart FastAPI server"
|
||||
echo -e " ${GREEN}open${NC} Open Vue UI in browser (http://localhost:8503)"
|
||||
echo " Label tool:"
|
||||
echo -e " ${GREEN}start${NC} Start label tool UI (port collision-safe)"
|
||||
echo -e " ${GREEN}stop${NC} Stop label tool UI"
|
||||
echo -e " ${GREEN}restart${NC} Restart label tool UI"
|
||||
echo -e " ${GREEN}status${NC} Show running state and port"
|
||||
echo -e " ${GREEN}logs${NC} Tail label tool log output"
|
||||
echo -e " ${GREEN}open${NC} Open label tool in browser"
|
||||
echo ""
|
||||
echo " Benchmark:"
|
||||
echo -e " ${GREEN}benchmark [args]${NC} Run benchmark_classifier.py (args passed through)"
|
||||
|
|
@ -90,20 +93,13 @@ usage() {
|
|||
echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]"
|
||||
echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]"
|
||||
echo ""
|
||||
echo " Planning Benchmark:"
|
||||
echo -e " ${GREEN}plans-bench [args]${NC} Run benchmark_plans.py (args passed through)"
|
||||
echo -e " ${GREEN}plans-list${NC} Shortcut: --list-models"
|
||||
echo -e " ${GREEN}plans-run <model> [args]${NC} Run a single model (--verbose auto-added)"
|
||||
echo -e " ${GREEN}plans-compare <m1> <m2> [more]${NC} Compare models side-by-side"
|
||||
echo ""
|
||||
echo " Writing Style Benchmark:"
|
||||
echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)"
|
||||
echo -e " ${GREEN}style-list${NC} List available ollama models for style bench"
|
||||
echo -e " ${GREEN}style-run [args]${NC} Run writing style benchmark (--models, --samples, --include-large, --scan-disk PATH, --cforch)"
|
||||
echo -e " ${GREEN}style-last${NC} Print most recent writing style benchmark report"
|
||||
echo " Vue API:"
|
||||
echo -e " ${GREEN}start-api${NC} Build Vue SPA + start FastAPI on port 8503"
|
||||
echo -e " ${GREEN}stop-api${NC} Stop FastAPI server"
|
||||
echo -e " ${GREEN}restart-api${NC} Stop + rebuild + restart FastAPI server"
|
||||
echo -e " ${GREEN}open-api${NC} Open Vue UI in browser (http://localhost:8503)"
|
||||
echo ""
|
||||
echo " Dev:"
|
||||
echo -e " ${GREEN}dev${NC} Hot-reload: uvicorn --reload (:8503) + Vite HMR (:5173)"
|
||||
echo -e " ${GREEN}test${NC} Run pytest suite"
|
||||
echo ""
|
||||
echo " Port defaults to ${DEFAULT_PORT}; auto-increments if occupied."
|
||||
|
|
@ -125,107 +121,102 @@ shift || true
|
|||
case "$CMD" in
|
||||
|
||||
start)
|
||||
API_PID_FILE=".avocet-api.pid"
|
||||
API_PORT=8503
|
||||
if [[ -f "$API_PID_FILE" ]] && kill -0 "$(<"$API_PID_FILE")" 2>/dev/null; then
|
||||
warn "API already running (PID $(<"$API_PID_FILE")) → http://localhost:${API_PORT}"
|
||||
pid=$(_running_pid)
|
||||
if [[ -n "$pid" ]]; then
|
||||
port=$(_running_port)
|
||||
warn "Already running (PID ${pid}) on port ${port} → http://localhost:${port}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ ! -x "$STREAMLIT" ]]; then
|
||||
error "Streamlit not found at ${STREAMLIT}\nActivate env: conda run -n ${ENV_UI} ..."
|
||||
fi
|
||||
|
||||
port=$(_find_free_port "$DEFAULT_PORT")
|
||||
mkdir -p "$LOG_DIR"
|
||||
API_LOG="${LOG_DIR}/api.log"
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
info "Building Vue SPA…"
|
||||
(cd web && npm run build) >> "$API_LOG" 2>&1
|
||||
info "Starting FastAPI on port ${API_PORT}…"
|
||||
nohup "$PYTHON_UI" -m uvicorn app.api:app \
|
||||
--host 0.0.0.0 --port "$API_PORT" \
|
||||
>> "$API_LOG" 2>&1 &
|
||||
echo $! > "$API_PID_FILE"
|
||||
# Poll until port is actually bound (up to 10 s), not just process alive
|
||||
for _i in $(seq 1 20); do
|
||||
sleep 0.5
|
||||
if (echo "" >/dev/tcp/127.0.0.1/"$API_PORT") 2>/dev/null; then
|
||||
success "Avocet started → http://localhost:${API_PORT} (PID $(<"$API_PID_FILE"))"
|
||||
break
|
||||
fi
|
||||
if ! kill -0 "$(<"$API_PID_FILE")" 2>/dev/null; then
|
||||
rm -f "$API_PID_FILE"
|
||||
error "Server died during startup. Check ${API_LOG}"
|
||||
fi
|
||||
done
|
||||
if ! (echo "" >/dev/tcp/127.0.0.1/"$API_PORT") 2>/dev/null; then
|
||||
error "Server did not bind to port ${API_PORT} within 10 s. Check ${API_LOG}"
|
||||
|
||||
info "Starting label tool on port ${port}…"
|
||||
nohup "$STREAMLIT" run app/label_tool.py \
|
||||
--server.port "$port" \
|
||||
--server.headless true \
|
||||
--server.fileWatcherType none \
|
||||
>"$LOG_FILE" 2>&1 &
|
||||
|
||||
pid=$!
|
||||
echo "$pid" > "$PID_FILE"
|
||||
echo "$port" > "$PORT_FILE"
|
||||
|
||||
# Wait briefly and confirm the process survived
|
||||
sleep 1
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
success "Avocet label tool started → http://localhost:${port} (PID ${pid})"
|
||||
success "Logs: ${LOG_FILE}"
|
||||
else
|
||||
rm -f "$PID_FILE" "$PORT_FILE"
|
||||
error "Process died immediately. Check ${LOG_FILE} for details."
|
||||
fi
|
||||
;;
|
||||
|
||||
stop)
|
||||
API_PID_FILE=".avocet-api.pid"
|
||||
if [[ ! -f "$API_PID_FILE" ]]; then
|
||||
pid=$(_running_pid)
|
||||
if [[ -z "$pid" ]]; then
|
||||
warn "Not running."
|
||||
exit 0
|
||||
fi
|
||||
PID="$(<"$API_PID_FILE")"
|
||||
if kill -0 "$PID" 2>/dev/null; then
|
||||
kill "$PID" && rm -f "$API_PID_FILE"
|
||||
success "Stopped (PID ${PID})."
|
||||
else
|
||||
warn "Stale PID file (process ${PID} not running). Cleaning up."
|
||||
rm -f "$API_PID_FILE"
|
||||
info "Stopping label tool (PID ${pid})…"
|
||||
kill "$pid"
|
||||
# Wait up to 5 s for clean exit
|
||||
for _ in $(seq 1 10); do
|
||||
kill -0 "$pid" 2>/dev/null || break
|
||||
sleep 0.5
|
||||
done
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
warn "Process did not exit cleanly; sending SIGKILL…"
|
||||
kill -9 "$pid" 2>/dev/null || true
|
||||
fi
|
||||
rm -f "$PID_FILE" "$PORT_FILE"
|
||||
success "Stopped."
|
||||
;;
|
||||
|
||||
restart)
|
||||
bash "$0" stop
|
||||
pid=$(_running_pid)
|
||||
if [[ -n "$pid" ]]; then
|
||||
info "Stopping existing process (PID ${pid})…"
|
||||
kill "$pid"
|
||||
for _ in $(seq 1 10); do
|
||||
kill -0 "$pid" 2>/dev/null || break
|
||||
sleep 0.5
|
||||
done
|
||||
kill -0 "$pid" 2>/dev/null && kill -9 "$pid" 2>/dev/null || true
|
||||
rm -f "$PID_FILE" "$PORT_FILE"
|
||||
fi
|
||||
exec bash "$0" start
|
||||
;;
|
||||
|
||||
dev)
|
||||
API_PORT=8503
|
||||
VITE_PORT=5173
|
||||
DEV_API_PID_FILE=".avocet-dev-api.pid"
|
||||
mkdir -p "$LOG_DIR"
|
||||
DEV_API_LOG="${LOG_DIR}/dev-api.log"
|
||||
|
||||
# Load .env if present — sets HF_TOKEN and other optional overrides.
|
||||
[[ -f .env ]] && set -a && source .env && set +a
|
||||
|
||||
if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
|
||||
warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))"
|
||||
status)
|
||||
pid=$(_running_pid)
|
||||
if [[ -n "$pid" ]]; then
|
||||
port=$(_running_port)
|
||||
success "Running — PID ${pid} port ${port} → http://localhost:${port}"
|
||||
else
|
||||
info "Starting uvicorn with --reload on port ${API_PORT}…"
|
||||
nohup "$PYTHON_UI" -m uvicorn app.api:app \
|
||||
--host 0.0.0.0 --port "$API_PORT" --reload \
|
||||
>> "$DEV_API_LOG" 2>&1 &
|
||||
echo $! > "$DEV_API_PID_FILE"
|
||||
# Wait for API to bind
|
||||
for _i in $(seq 1 20); do
|
||||
sleep 0.5
|
||||
(echo "" >/dev/tcp/127.0.0.1/"$API_PORT") 2>/dev/null && break
|
||||
if ! kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
|
||||
rm -f "$DEV_API_PID_FILE"
|
||||
error "Dev API died during startup. Check ${DEV_API_LOG}"
|
||||
fi
|
||||
done
|
||||
success "API (hot-reload) → http://localhost:${API_PORT}"
|
||||
warn "Not running."
|
||||
fi
|
||||
;;
|
||||
|
||||
# Kill API on exit (Ctrl+C or Vite exits)
|
||||
_cleanup_dev() {
|
||||
local pid
|
||||
pid=$(<"$DEV_API_PID_FILE" 2>/dev/null || true)
|
||||
[[ -n "$pid" ]] && kill "$pid" 2>/dev/null && rm -f "$DEV_API_PID_FILE"
|
||||
info "Dev servers stopped."
|
||||
}
|
||||
trap _cleanup_dev EXIT INT TERM
|
||||
|
||||
info "Starting Vite HMR on port ${VITE_PORT} (proxy /api → :${API_PORT})…"
|
||||
success "Frontend (HMR) → http://localhost:${VITE_PORT}"
|
||||
(cd web && npm run dev -- --host 0.0.0.0 --port "$VITE_PORT")
|
||||
logs)
|
||||
if [[ ! -f "$LOG_FILE" ]]; then
|
||||
warn "No log file found at ${LOG_FILE}. Has the tool been started?"
|
||||
exit 0
|
||||
fi
|
||||
info "Tailing ${LOG_FILE} (Ctrl-C to stop)"
|
||||
tail -f "$LOG_FILE"
|
||||
;;
|
||||
|
||||
open)
|
||||
URL="http://localhost:8503"
|
||||
port=$(_running_port)
|
||||
pid=$(_running_pid)
|
||||
[[ -z "$pid" ]] && warn "Label tool does not appear to be running. Start with: ./manage.sh start"
|
||||
URL="http://localhost:${port}"
|
||||
info "Opening ${URL}"
|
||||
if command -v xdg-open &>/dev/null; then
|
||||
xdg-open "$URL"
|
||||
|
|
@ -266,48 +257,70 @@ case "$CMD" in
|
|||
exec "$0" benchmark --compare "$@"
|
||||
;;
|
||||
|
||||
plans-bench)
|
||||
info "Running planning benchmark (${ENV_UI})…"
|
||||
"$PYTHON_UI" scripts/benchmark_plans.py "$@"
|
||||
;;
|
||||
|
||||
plans-list)
|
||||
exec "$0" plans-bench --list-models
|
||||
;;
|
||||
|
||||
plans-run)
|
||||
if [[ $# -lt 1 ]]; then
|
||||
error "Usage: ./manage.sh plans-run <model-key> [extra args]"
|
||||
start-api)
|
||||
API_PID_FILE=".avocet-api.pid"
|
||||
API_PORT=8503
|
||||
if [[ -f "$API_PID_FILE" ]] && kill -0 "$(<"$API_PID_FILE")" 2>/dev/null; then
|
||||
warn "API already running (PID $(<"$API_PID_FILE")) → http://localhost:${API_PORT}"
|
||||
exit 0
|
||||
fi
|
||||
MODEL="$1"; shift
|
||||
exec "$0" plans-bench --model "$MODEL" --verbose "$@"
|
||||
;;
|
||||
|
||||
plans-compare)
|
||||
if [[ $# -lt 2 ]]; then
|
||||
error "Usage: ./manage.sh plans-compare <model1> <model2> [more…]"
|
||||
mkdir -p "$LOG_DIR"
|
||||
API_LOG="${LOG_DIR}/api.log"
|
||||
info "Building Vue SPA…"
|
||||
(cd web && npm run build) >> "$API_LOG" 2>&1
|
||||
info "Starting FastAPI on port ${API_PORT}…"
|
||||
nohup "$PYTHON_UI" -m uvicorn app.api:app \
|
||||
--host 0.0.0.0 --port "$API_PORT" \
|
||||
>> "$API_LOG" 2>&1 &
|
||||
echo $! > "$API_PID_FILE"
|
||||
# Poll until port is actually bound (up to 10 s), not just process alive
|
||||
for _i in $(seq 1 20); do
|
||||
sleep 0.5
|
||||
if (echo "" >/dev/tcp/127.0.0.1/"$API_PORT") 2>/dev/null; then
|
||||
success "Avocet API started → http://localhost:${API_PORT} (PID $(<"$API_PID_FILE"))"
|
||||
break
|
||||
fi
|
||||
if ! kill -0 "$(<"$API_PID_FILE")" 2>/dev/null; then
|
||||
rm -f "$API_PID_FILE"
|
||||
error "API died during startup. Check ${API_LOG}"
|
||||
fi
|
||||
done
|
||||
if ! (echo "" >/dev/tcp/127.0.0.1/"$API_PORT") 2>/dev/null; then
|
||||
error "API did not bind to port ${API_PORT} within 10 s. Check ${API_LOG}"
|
||||
fi
|
||||
exec "$0" plans-bench --compare "$@" --verbose
|
||||
;;
|
||||
|
||||
style-bench)
|
||||
info "Running writing style benchmark (${ENV_BM})…"
|
||||
if [[ ! -x "$PYTHON_BM" ]]; then
|
||||
error "Python not found in ${ENV_BM} env at ${PYTHON_BM}"
|
||||
stop-api)
|
||||
API_PID_FILE=".avocet-api.pid"
|
||||
if [[ ! -f "$API_PID_FILE" ]]; then
|
||||
warn "API not running."
|
||||
exit 0
|
||||
fi
|
||||
PID="$(<"$API_PID_FILE")"
|
||||
if kill -0 "$PID" 2>/dev/null; then
|
||||
kill "$PID" && rm -f "$API_PID_FILE"
|
||||
success "API stopped (PID ${PID})."
|
||||
else
|
||||
warn "Stale PID file (process ${PID} not running). Cleaning up."
|
||||
rm -f "$API_PID_FILE"
|
||||
fi
|
||||
"$PYTHON_BM" scripts/benchmark_style.py "$@"
|
||||
;;
|
||||
|
||||
style-list)
|
||||
exec "$0" style-bench --list-models
|
||||
restart-api)
|
||||
bash "$0" stop-api
|
||||
exec bash "$0" start-api
|
||||
;;
|
||||
|
||||
style-run)
|
||||
exec "$0" style-bench --run "$@"
|
||||
;;
|
||||
|
||||
style-last)
|
||||
exec "$0" style-bench --show-last
|
||||
open-api)
|
||||
URL="http://localhost:8503"
|
||||
info "Opening ${URL}"
|
||||
if command -v xdg-open &>/dev/null; then
|
||||
xdg-open "$URL"
|
||||
elif command -v open &>/dev/null; then
|
||||
open "$URL"
|
||||
else
|
||||
echo "$URL"
|
||||
fi
|
||||
;;
|
||||
|
||||
help|--help|-h)
|
||||
|
|
|
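The `dev` and `start-api` branches above treat the API as started only once its TCP port accepts a connection, not merely once the process exists. The same idea can be sketched in Python as a rough illustration. `wait_for_port` and its parameters are hypothetical names, not part of `manage.sh`, and the sketch omits the shell loop's extra check that the child process is still alive.

```python
import socket
import time


def wait_for_port(host: str, port: int, attempts: int = 20, delay_s: float = 0.5) -> bool:
    """Return True once a TCP connection to (host, port) succeeds.

    Hypothetical helper: up to `attempts` probes with `delay_s` between
    failures (20 x 0.5 s = 10 s by default, matching the shell loop's budget).
    """
    for _ in range(attempts):
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=delay_s):
                return True
        except OSError:
            # Refused or timed out: wait briefly, then probe again.
            time.sleep(delay_s)
    return False


# Demo: bind a throwaway listener on an OS-assigned port, then probe it.
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))
listener.listen(1)
demo_port = listener.getsockname()[1]
port_is_up = wait_for_port("127.0.0.1", demo_port, attempts=3, delay_s=0.1)
listener.close()
```

Probing the socket rather than the PID avoids reporting success while the server is still importing modules or has crashed before binding.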
|||
|
|
@ -3,6 +3,3 @@ testpaths = tests
|
|||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
markers =
|
||||
gpu: requires an idle GPU; excluded from default runs
|
||||
slow: long-running test; excluded from default CI runs
|
||||
|
|
|
|||
|
|
@ -3,4 +3,3 @@ pydantic>=2.0.0
|
|||
uvicorn[standard]>=0.20.0
|
||||
httpx>=0.24.0
|
||||
pytest>=7.0.0
|
||||
pyyaml>=6.0
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ from scripts.classifier_adapters import (
|
|||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
EmbeddingKNNAdapter,
|
||||
FineTunedAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
|
|
@ -131,13 +130,6 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
|
|||
"params": "600M",
|
||||
"default": False,
|
||||
},
|
||||
"embed-knn-nomic": {
|
||||
"adapter": EmbeddingKNNAdapter,
|
||||
"model_id": "nomic-embed-text",
|
||||
"params": "local-embed",
|
||||
"default": False, # requires orch or ollama; use --include-slow
|
||||
"kwargs": {"k": 3},
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -192,42 +184,6 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
|||
return found
|
||||
|
||||
|
||||
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
|
||||
"""Sample up to k_per_label formatted email texts per label from a scored JSONL.
|
||||
|
||||
Formats each row as 'Subject: {subject}\n\n{body[:600]}' — the same format
|
||||
EmbeddingKNNAdapter uses at classify() time. Rows missing the 'label' key
|
||||
are skipped silently.
|
||||
|
||||
Returns dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
|
||||
"""
|
||||
result: dict[str, list[str]] = {}
|
||||
p = Path(path)
|
||||
with p.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except json.JSONDecodeError as exc:
|
||||
print(f"[build_exemplars] WARN: skipping malformed line: {exc}", flush=True)
|
||||
continue
|
||||
label = row.get("label")
|
||||
if not label:
|
||||
continue
|
||||
subject = row.get("subject", "")
|
||||
body = row.get("body", "")
|
||||
if not subject and not body:
|
||||
continue
|
||||
texts = result.setdefault(label, [])
|
||||
if len(texts) < k_per_label:
|
||||
texts.append(
|
||||
f"Subject: {subject}\n\n{body[:600]}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
||||
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
||||
active: dict[str, dict[str, Any]] = {
|
||||
|
|
|
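The removed `build_exemplars_from_jsonl` helper above keeps at most `k_per_label` formatted texts per label and skips rows without a label or without any text. A minimal sketch of that capped grouping follows, working on already-parsed rows instead of a JSONL file. `cap_per_label` and the sample rows are hypothetical and exist only for illustration.

```python
def cap_per_label(rows: list[dict], k_per_label: int = 10) -> dict[str, list[str]]:
    """Group formatted texts by label, keeping at most k_per_label per label.

    Hypothetical helper: follows the capping logic of the removed function,
    but takes parsed rows rather than a file path.
    """
    result: dict[str, list[str]] = {}
    for row in rows:
        label = row.get("label")
        subject = row.get("subject", "")
        body = row.get("body", "")
        # Skip unlabeled rows and rows with no text at all.
        if not label or (not subject and not body):
            continue
        texts = result.setdefault(label, [])
        if len(texts) < k_per_label:
            # Same 'Subject: ...' format, with the body truncated to 600 chars.
            texts.append(f"Subject: {subject}\n\n{body[:600]}")
    return result


rows = [
    {"label": "interview", "subject": "Schedule", "body": "Are you free Tuesday?"},
    {"label": "interview", "subject": "Follow-up", "body": "Thanks for your time."},
    {"label": "rejection", "subject": "Update", "body": "We went another way."},
    {"subject": "No label", "body": "skipped"},
]
exemplars = cap_per_label(rows, k_per_label=1)
```

Because the cap is applied in file order, the first `k_per_label` rows per label win; shuffling the input beforehand would be one way to get a less order-dependent sample.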
|||
|
|
@ -1,734 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""CF-specific planning benchmark — compare base models before fine-tuning.
|
||||
|
||||
Sends held-out CircuitForge planning prompts to one or more models via the
|
||||
cf-text (local) or cf-orch API, then scores responses against CF-specific
|
||||
rubrics. Use this to select the best base model for SFT.
|
||||
|
||||
Scoring rubrics (each scored 0-1; the total is the mean across rubrics):
|
||||
- task_structure : uses checkbox syntax (- [ ]), git commit steps
|
||||
- tier_awareness : mentions Free/Paid/Premium/Ultra tiers
|
||||
- privacy_pillar : mentions privacy/local-inference/no-logging
|
||||
- safety_pillar : mentions safety, human approval, or reversibility
|
||||
- accessibility : mentions ND/accessibility/adaptive needs
|
||||
- license_split : mentions MIT vs BSL or open-core model
|
||||
- file_paths : uses plausible file path references
|
||||
- cf_conventions : uses conda run -n cf, /Library/Development/, or known CF dirs
|
||||
- paired_coherence : (paired only) plan references the design doc's feature name
|
||||
- length_ok : 200–2500 words (under-short = hallucination risk; over-long = padding)
|
||||
|
||||
Usage
|
||||
-----
|
||||
# List available model targets
|
||||
python scripts/benchmark_plans.py --list-models
|
||||
|
||||
# Run all held-out prompts against a single model, print report
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b
|
||||
|
||||
# Compare two models side-by-side
|
||||
python scripts/benchmark_plans.py --compare granite-4.1-8b deepseek-r1-7b-4bit
|
||||
|
||||
# Run with a custom API base (cf-text default: http://localhost:8080/v1)
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b --api-base http://localhost:8080/v1
|
||||
|
||||
# Export detailed results JSON
|
||||
python scripts/benchmark_plans.py --model granite-4.1-8b --output data/bench_results.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR = _ROOT / "data"
|
||||
|
||||
CF_TEXT_BASE = "http://localhost:8080/v1"
|
||||
CF_ORCH_BASE = "http://localhost:8090/v1"
|
||||
CF_COORD_URL = "http://10.1.10.71:7700" # cf-orch coordinator (LAN)
|
||||
|
||||
# ── Held-out prompts ───────────────────────────────────────────────────────────
|
||||
# These are NOT in the training export (no matching docs in circuitforge-plans/).
|
||||
# Each prompt exercises a different CF planning domain.
|
||||
|
||||
HELD_OUT_PROMPTS: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "ho_001",
|
||||
"name": "kiwi_barcode_ocr",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Kiwi, a CircuitForge pantry-tracking product. "
|
||||
"Write a detailed implementation plan for adding barcode scanning via device camera "
|
||||
"and receipt OCR to the item-add flow.\n\n"
|
||||
"The plan should include: file structure (create/modify), step-by-step task checklist "
|
||||
"with checkboxes, any DB migrations, and git commit steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_002",
|
||||
"name": "peregrine_ats_scoring",
|
||||
"domain": "feature_design",
|
||||
"prompt": (
|
||||
"Write a design document for Peregrine: ATS keyword scoring for job applications.\n\n"
|
||||
"Context: Peregrine users paste job descriptions and their resume. "
|
||||
"We want to score how well the resume keywords match the JD and suggest rewrites. "
|
||||
"Describe the architecture, data flow, and key design decisions."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "tier_awareness", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_003",
|
||||
"name": "tier_gate_local_llm",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the tier-gating architecture for a new CircuitForge product. "
|
||||
"The product should:\n"
|
||||
"- Default to local LLM inference for all tiers\n"
|
||||
"- Unlock cloud LLM for Paid tier and above\n"
|
||||
"- Keep fine-tuned model weights for Premium/Ultra only\n\n"
|
||||
"Describe how the tier check integrates with the LLM router, "
|
||||
"what happens when a Free user tries a Paid-tier feature, "
|
||||
"and how BYOK (bring-your-own-key) fits in."
|
||||
),
|
||||
"expected_signals": ["tier_awareness", "privacy_pillar", "license_split"],
|
||||
},
|
||||
{
|
||||
"id": "ho_004",
|
||||
"name": "heimdall_webhook_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Heimdall feature into a detailed implementation plan with "
|
||||
"file structure and task checkboxes — Stripe webhook handler for subscription lifecycle.\n\n"
|
||||
"Heimdall is the CircuitForge license server (FastAPI + SQLite). "
|
||||
"The webhook needs to handle checkout.session.completed, "
|
||||
"customer.subscription.updated, and customer.subscription.deleted events."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_005",
|
||||
"name": "nd_accessible_onboarding",
|
||||
"domain": "ux_design",
|
||||
"prompt": (
|
||||
"You are a product designer working on Harrier, a CircuitForge tool for "
|
||||
"helping people navigate government benefits applications.\n\n"
|
||||
"Design the onboarding flow for neurodivergent (ND) users. "
|
||||
"Consider: ADHD time-blindness, executive function challenges, demand avoidance, "
|
||||
"and rejection sensitivity. The flow should reduce cognitive load and "
|
||||
"never use urgency or panic patterns."
|
||||
),
|
||||
"expected_signals": ["accessibility", "safety_pillar", "privacy_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_006",
|
||||
"name": "circuitforge_core_extraction",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Produce a CircuitForge-style design document for the following circuitforge-core "
|
||||
"feature — shared ActivityPub federation module.\n\n"
|
||||
"Background: Multiple CF products (Kiwi, Rook, Snipe) want to publish updates "
|
||||
"to ActivityPub. Build it once in cf-core (MIT licensed) so all products can use it. "
|
||||
"Design the module API, describe what belongs in MIT vs BSL, and note federation "
|
||||
"privacy constraints."
|
||||
),
|
||||
"expected_signals": ["license_split", "privacy_pillar", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_007",
|
||||
"name": "snipe_trust_score_plan",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"You are a senior engineer on Snipe, a CircuitForge eBay trust-scoring tool. "
|
||||
"Write a step-by-step engineering plan for: seller trust score calculation.\n\n"
|
||||
"The score should combine: feedback ratio, account age, item-specifics completeness, "
|
||||
"listing photo quality, and shipping time accuracy. "
|
||||
"Include file structure, test plan, and migration steps."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
|
||||
},
|
||||
{
|
||||
"id": "ho_008",
|
||||
"name": "avocet_training_pipeline",
|
||||
"domain": "feature_plan",
|
||||
"prompt": (
|
||||
"Break the following Avocet feature into a detailed implementation plan — "
|
||||
"end-to-end fine-tuning pipeline from labeled JSONL to deployed GGUF model.\n\n"
|
||||
"Avocet is the CircuitForge email classifier training tool. "
|
||||
"The pipeline should: validate the dataset, run LoRA SFT via unsloth, "
|
||||
"quantize to Q5_K_M GGUF, run the benchmark harness, and register the model "
|
||||
"in the Avocet model queue if it beats the baseline."
|
||||
),
|
||||
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
|
||||
},
|
||||
{
|
||||
"id": "ho_009",
|
||||
"name": "privacy_data_flow",
|
||||
"domain": "architecture",
|
||||
"prompt": (
|
||||
"Design the data privacy architecture for a CircuitForge cloud product. "
|
||||
"Describe: what PII is collected, how it's stored, retention policy, "
|
||||
"obfuscation strategy for cloud-side logs, and how consent is obtained "
|
||||
"in plain language. The product handles job applications (resumes, cover letters)."
|
||||
),
|
||||
"expected_signals": ["privacy_pillar", "safety_pillar", "accessibility"],
|
||||
},
|
||||
{
|
||||
"id": "ho_010",
|
||||
"name": "git_workflow_doc",
|
||||
"domain": "process_doc",
|
||||
"prompt": (
|
||||
"Write a developer process document for CircuitForge: conventional commit and "
|
||||
"branch workflow for a BSL 1.1 open-core product.\n\n"
|
||||
"Cover: commit message format (type: description), branch naming, "
|
||||
"when to use feature branches vs direct main commits, "
|
||||
"how the MIT/BSL split affects which commits go in which branch, "
|
||||
"and how CI gates on gitleaks for secret scanning."
|
||||
),
|
||||
"expected_signals": ["license_split", "cf_conventions", "task_structure"],
|
||||
},
|
||||
]
|
||||
|
||||
# ── Rubric scoring ─────────────────────────────────────────────────────────────
|
||||
|
||||
_TASK_STRUCTURE_RE = re.compile(r"- \[ \]", re.MULTILINE)
|
||||
_COMMIT_RE = re.compile(r"git commit|git add", re.IGNORECASE)
|
||||
_TIER_RE = re.compile(r"\b(Free|Paid|Premium|Ultra)\s+tier|\btier\s+(Free|Paid|Premium|Ultra)", re.IGNORECASE)
|
||||
_PRIVACY_RE = re.compile(r"\b(privacy|local.?inference|no.?logging|no.?pii|user.?data|data.?reten|obfuscat)", re.IGNORECASE)
|
||||
_SAFETY_RE = re.compile(r"\b(human.?approv|reversib|safety|safe.?default|fail.?safe|harm)", re.IGNORECASE)
|
||||
_A11Y_RE = re.compile(r"\b(neurodiverg|ND\b|accessib|adaptive|ADHD|autism|executive.?function|demand.?avoid)", re.IGNORECASE)
|
||||
_LICENSE_RE = re.compile(r"\b(MIT|BSL|open.?core|proprietary|commercial.?licens)", re.IGNORECASE)
|
||||
_FILE_PATH_RE = re.compile(r"(app/|tests?/|src/|scripts?/)\w[\w/.-]{3,}", re.IGNORECASE)
|
||||
_CF_CONV_RE = re.compile(r"(conda run -n cf|/Library/Development/CircuitForge|circuitforge-core|manage\.sh)", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RubricScore:
|
||||
task_structure: float = 0.0
|
||||
tier_awareness: float = 0.0
|
||||
privacy_pillar: float = 0.0
|
||||
safety_pillar: float = 0.0
|
||||
accessibility: float = 0.0
|
||||
license_split: float = 0.0
|
||||
file_paths: float = 0.0
|
||||
cf_conventions: float = 0.0
|
||||
length_ok: float = 0.0
|
||||
|
||||
def total(self) -> float:
|
||||
vals = [self.task_structure, self.tier_awareness, self.privacy_pillar,
|
||||
self.safety_pillar, self.accessibility, self.license_split,
|
||||
self.file_paths, self.cf_conventions, self.length_ok]
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
def as_dict(self) -> dict[str, float]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def score_response(response: str, prompt_meta: dict[str, Any]) -> RubricScore:
|
||||
words = len(response.split())
|
||||
s = RubricScore()
|
||||
|
||||
# Task structure: 70% from checkbox count (full credit at 5+), 30% for any git add/commit step
|
||||
checkbox_hits = len(_TASK_STRUCTURE_RE.findall(response))
|
||||
has_commit = bool(_COMMIT_RE.search(response))
|
||||
s.task_structure = min(1.0, checkbox_hits / 5) * 0.7 + (0.3 if has_commit else 0.0)
|
||||
|
||||
# Tier awareness
|
||||
s.tier_awareness = min(1.0, len(_TIER_RE.findall(response)) / 2)
|
||||
|
||||
# Privacy pillar
|
||||
s.privacy_pillar = min(1.0, len(_PRIVACY_RE.findall(response)) / 3)
|
||||
|
||||
# Safety pillar
|
||||
s.safety_pillar = min(1.0, len(_SAFETY_RE.findall(response)) / 2)
|
||||
|
||||
# Accessibility
|
||||
s.accessibility = min(1.0, len(_A11Y_RE.findall(response)) / 2)
|
||||
|
||||
# License split awareness
|
||||
s.license_split = min(1.0, len(_LICENSE_RE.findall(response)) / 2)
|
||||
|
||||
# File paths: at least 3 plausible path references
|
||||
s.file_paths = min(1.0, len(_FILE_PATH_RE.findall(response)) / 3)
|
||||
|
||||
# CF conventions
|
||||
s.cf_conventions = min(1.0, len(_CF_CONV_RE.findall(response)) / 2)
|
||||
|
||||
# Length: 200–2500 words is healthy; outside = partial credit
|
||||
if 200 <= words <= 2500:
|
||||
s.length_ok = 1.0
|
||||
elif words < 200:
|
||||
s.length_ok = words / 200
|
||||
else:
|
||||
s.length_ok = max(0.0, 1.0 - (words - 2500) / 2500)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
# ── Model client ───────────────────────────────────────────────────────────────
|
||||
|
||||
# Registry of named model targets (shorthand → {api_base, model_name})
|
||||
MODEL_REGISTRY: dict[str, dict[str, str]] = {
|
||||
"deepseek-r1-1.5b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-1.5b",
|
||||
"description": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-7b-4bit",
|
||||
"description": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"deepseek-r1-0528-qwen3-8b-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-r1-0528-qwen3-8b-gguf",
|
||||
"description": "DeepSeek R1 0528 Qwen3 8B GGUF -- current reasoning model (4 nodes)",
|
||||
},
|
||||
"deepseek-coder-6.7b-4bit": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "deepseek-coder-6.7b-4bit",
|
||||
"description": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
|
||||
},
|
||||
"granite-4.1-8b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "granite-4.1-8b",
|
||||
"description": "IBM Granite 4.1 8B, 4-bit -- safety-trained (cf-orch catalog key)",
|
||||
},
|
||||
"capybarahermes-2.5-mistral-7b-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "capybarahermes-2.5-mistral-7b-gguf",
|
||||
"description": "CapybaraHermes 2.5 Mistral 7B GGUF -- conversational/creative (4 nodes)",
|
||||
},
|
||||
"darwin-9b-opus-gguf": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "darwin-9b-opus-gguf",
|
||||
"description": "Darwin 9B Opus GGUF -- high-quality long-form writing (3 nodes)",
|
||||
},
|
||||
"qwen2.5-3b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-3b",
|
||||
"description": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
|
||||
},
|
||||
"qwen2.5-7b": {
|
||||
"api_base": CF_TEXT_BASE,
|
||||
"model": "qwen2.5-7b",
|
||||
"description": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── cf-orch allocation ─────────────────────────────────────────────────────────
|
||||
|
||||
def _cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id via the cf-orch coordinator.
|
||||
|
||||
Returns (service_url, allocation_id) on success, None on failure.
|
||||
service_url is the direct node URL exposing /v1/chat/completions.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/cf-text/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "plans_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
# Use \n so the SSE generator sees the line immediately
|
||||
print(f" [cold start] loading {model_id!r} — polling every 3s…", flush=True)
|
||||
t0 = time.monotonic()
|
||||
deadline = t0 + startup_timeout_s
|
||||
probe_misses = 0
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
elapsed = time.monotonic() - t0
|
||||
try:
|
||||
status = httpx.get(f"{cforch_url}/api/services/cf-text/status", timeout=5.0)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
print(f" [cold start] ready in {elapsed:.0f}s", flush=True)
|
||||
return service_url, allocation_id
|
||||
elif state == "stopped":
|
||||
print(f" [cold start] failed — service stopped after {elapsed:.0f}s", flush=True)
|
||||
return None
|
||||
else:
|
||||
# still starting — emit keepalive so SSE stream stays alive
|
||||
print(f" [cold start] state={state!r} elapsed={elapsed:.0f}s", flush=True)
|
||||
else:
|
||||
probe_misses += 1
|
||||
print(f" [cold start] waiting… elapsed={elapsed:.0f}s", flush=True)
|
||||
if probe_misses >= 6:
|
||||
try:
|
||||
h = httpx.get(f"{service_url}/health", timeout=3.0)
|
||||
if h.is_success:
|
||||
print(f" [cold start] ready via health check in {elapsed:.0f}s", flush=True)
|
||||
return service_url, allocation_id
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(f" [cold start] status poll returned {status.status_code}, elapsed={elapsed:.0f}s", flush=True)
|
||||
except Exception as poll_exc:
|
||||
print(f" [cold start] poll error: {poll_exc} elapsed={elapsed:.0f}s", flush=True)
|
||||
time.sleep(3.0)
|
||||
|
||||
print(f" [cold start] timed out after {time.monotonic()-t0:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r}: {exc}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def _call_model_direct(service_url: str, model: str, prompt: str, timeout: int = 600) -> tuple[str, float]:
|
||||
"""Call an OpenAI-compatible /v1/chat/completions on a direct service URL."""
|
||||
t0 = time.monotonic()
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
latency = time.monotonic() - t0
|
||||
text = resp.json()["choices"][0]["message"]["content"]
|
||||
return text, latency
|
||||
|
||||
|
||||
def _call_model(api_base: str, model: str, prompt: str, timeout: int = 180) -> tuple[str, float]:
|
||||
"""Call an OpenAI-compatible /chat/completions endpoint. Returns (text, latency_s)."""
|
||||
t0 = time.monotonic()
|
||||
resp = httpx.post(
|
||||
f"{api_base}/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
latency = time.monotonic() - t0
|
||||
text = resp.json()["choices"][0]["message"]["content"]
|
||||
return text, latency
|
||||
|
||||
|
||||
# ── Benchmark runner ───────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class PromptResult:
|
||||
prompt_id: str
|
||||
prompt_name: str
|
||||
model_key: str
|
||||
response: str
|
||||
latency_s: float
|
||||
word_count: int
|
||||
scores: dict[str, float]
|
||||
total_score: float
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_key: str,
|
||||
model_name: str,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
verbose: bool = False,
|
||||
# cf-orch path
|
||||
use_cforch: bool = False,
|
||||
cforch_url: str = CF_COORD_URL,
|
||||
# direct path (used when not cf-orch)
|
||||
api_base: str = CF_TEXT_BASE,
|
||||
) -> list[PromptResult]:
|
||||
"""Run all prompts through one model. Uses cf-orch allocation when use_cforch=True."""
|
||||
if prompts is None:
|
||||
prompts = HELD_OUT_PROMPTS
|
||||
|
||||
# Allocate once per model when using cf-orch
|
||||
service_url: str | None = None
|
||||
if use_cforch:
|
||||
print(f" Allocating {model_name!r} via cf-orch…", flush=True)
|
||||
alloc = _cforch_allocate(model_name, cforch_url)
|
||||
if alloc is None:
|
||||
# Return all prompts as errors
|
||||
return [
|
||||
PromptResult(
|
||||
prompt_id=p["id"], prompt_name=p["name"], model_key=model_key,
|
||||
response="", latency_s=0.0, word_count=0, scores={}, total_score=0.0,
|
||||
error=f"cf-orch allocation failed for {model_name!r}",
|
||||
)
|
||||
for p in prompts
|
||||
]
|
||||
service_url, _alloc_id = alloc
|
||||
|
||||
results: list[PromptResult] = []
|
||||
for p in prompts:
|
||||
if verbose:
|
||||
print(f" [{p['id']}] {p['name']} … ", end="", flush=True)
|
||||
try:
|
||||
if service_url:
|
||||
response, latency = _call_model_direct(service_url, model_name, p["prompt"])
|
||||
else:
|
||||
response, latency = _call_model(api_base, model_name, p["prompt"])
|
||||
rubric = score_response(response, p)
|
||||
result = PromptResult(
|
||||
prompt_id=p["id"],
|
||||
prompt_name=p["name"],
|
||||
model_key=model_key,
|
||||
response=response,
|
||||
latency_s=round(latency, 2),
|
||||
word_count=len(response.split()),
|
||||
scores=rubric.as_dict(),
|
||||
total_score=round(rubric.total(), 3),
|
||||
)
|
||||
if verbose:
|
||||
print(f"score={result.total_score:.3f} ({result.word_count}w, {latency:.1f}s)")
|
||||
except Exception as exc:
|
||||
result = PromptResult(
|
||||
prompt_id=p["id"],
|
||||
prompt_name=p["name"],
|
||||
model_key=model_key,
|
||||
response="",
|
||||
latency_s=0.0,
|
||||
word_count=0,
|
||||
scores={},
|
||||
total_score=0.0,
|
||||
error=str(exc),
|
||||
)
|
||||
if verbose:
|
||||
print(f"ERROR: {exc}")
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
# ── Reporting ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_single_report(results: list[PromptResult], model_key: str) -> None:
|
||||
ok = [r for r in results if not r.error]
|
||||
err = [r for r in results if r.error]
|
||||
if not ok:
|
||||
print(f"\n[{model_key}] All {len(err)} prompts failed.\n")
|
||||
return
|
||||
|
||||
avg_total = sum(r.total_score for r in ok) / len(ok)
|
||||
avg_latency = sum(r.latency_s for r in ok) / len(ok)
|
||||
|
||||
# Aggregate per-rubric averages
|
||||
rubric_keys = list(ok[0].scores.keys())
|
||||
rubric_avgs = {k: sum(r.scores.get(k, 0) for r in ok) / len(ok) for k in rubric_keys}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Model : {model_key}")
|
||||
print(f" Prompts: {len(ok)}/{len(results)} passed ({len(err)} errors)")
|
||||
print(f" Overall score : {avg_total:.3f} (avg latency {avg_latency:.1f}s)")
|
||||
print(f"\n Rubric breakdown:")
|
||||
for k, v in sorted(rubric_avgs.items(), key=lambda x: -x[1]):
|
||||
bar = "█" * int(v * 20)
|
||||
print(f" {k:<22} {v:.3f} {bar}")
|
||||
print(f"\n Per-prompt scores:")
|
||||
for r in sorted(ok, key=lambda x: -x.total_score):
|
||||
flag = "⚠" if r.total_score < 0.3 else " "
|
||||
print(f" {flag} {r.prompt_id} {r.prompt_name:<35} {r.total_score:.3f} ({r.word_count}w)")
|
||||
if err:
|
||||
print(f"\n Errors:")
|
||||
for r in err:
|
||||
print(f" {r.prompt_id} {r.prompt_name}: {r.error}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def _print_comparison_table(all_results: dict[str, list[PromptResult]]) -> None:
|
||||
model_keys = list(all_results.keys())
|
||||
prompt_ids = [p["id"] for p in HELD_OUT_PROMPTS]
|
||||
|
||||
# Scores by (model, prompt_id)
|
||||
score_map: dict[tuple[str, str], float] = {}
|
||||
for mk, results in all_results.items():
|
||||
for r in results:
|
||||
score_map[(mk, r.prompt_id)] = r.total_score if not r.error else 0.0
|
||||
|
||||
col_w = 10
|
||||
header = f"{'Prompt':<35}" + "".join(f"{mk[:col_w-1]:<{col_w}}" for mk in model_keys)
|
||||
print(f"\n{'='*len(header)}")
|
||||
print(" COMPARISON TABLE")
|
||||
print(f"{'='*len(header)}")
|
||||
print(f" {header}")
|
||||
print(f" {'-'*len(header)}")
|
||||
|
||||
for pid in prompt_ids:
|
||||
pname = next(p["name"] for p in HELD_OUT_PROMPTS if p["id"] == pid)
|
||||
row = f" {pname:<35}"
|
||||
best = max(score_map.get((mk, pid), 0.0) for mk in model_keys)
|
||||
for mk in model_keys:
|
||||
v = score_map.get((mk, pid), 0.0)
|
||||
marker = "*" if v == best and len(model_keys) > 1 else " "
|
||||
row += f"{v:.3f}{marker} "
|
||||
print(row)
|
||||
|
||||
print(f" {'-'*len(header)}")
|
||||
avgs_row = f" {'AVERAGE':<35}"
|
||||
best_avg = -1.0
|
||||
avgs: dict[str, float] = {}
|
||||
for mk in model_keys:
|
||||
vals = [score_map.get((mk, pid), 0.0) for pid in prompt_ids]
|
||||
avgs[mk] = sum(vals) / len(vals)
|
||||
best_avg = max(best_avg, avgs[mk])
|
||||
for mk in model_keys:
|
||||
marker = "*" if avgs[mk] == best_avg and len(model_keys) > 1 else " "
|
||||
avgs_row += f"{avgs[mk]:.3f}{marker} "
|
||||
print(avgs_row)
|
||||
print(f"{'='*len(header)}\n")
|
||||
if len(model_keys) > 1:
|
||||
winner = max(avgs, key=lambda k: avgs[k])
|
||||
print(f" Winner: {winner} (avg {avgs[winner]:.3f})\n")
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser.add_argument("--list-models", action="store_true",
|
||||
help="Print registered model shortcuts and exit")
|
||||
parser.add_argument("--model", metavar="KEY",
|
||||
help="Benchmark a single model (registry key or raw model name)")
|
||||
parser.add_argument("--compare", nargs="+", metavar="KEY",
|
||||
help="Compare two or more models side-by-side")
|
||||
parser.add_argument("--cforch", action="store_true",
|
||||
help="Route inference through cf-orch coordinator (allocate per model)")
|
||||
parser.add_argument("--cforch-url", default=CF_COORD_URL, metavar="URL",
|
||||
help=f"cf-orch coordinator URL (default: {CF_COORD_URL})")
|
||||
parser.add_argument("--api-base", default=None,
|
||||
help="Direct API base URL when not using cf-orch")
|
||||
parser.add_argument("--model-name", default=None,
|
||||
help="Override model name sent to API (single-model runs only)")
|
||||
parser.add_argument("--prompts", nargs="+", metavar="ID",
|
||||
help="Run only specific prompt IDs (e.g. ho_001 ho_003)")
|
||||
parser.add_argument("--output", type=Path, default=None,
|
||||
help="Write detailed JSON results to this path")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Run N models concurrently (default 1). Set to number of available nodes.")
|
||||
parser.add_argument("--verbose", "-v", action="store_true",
|
||||
help="Print per-prompt progress")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_models:
|
||||
print("\nRegistered model shortcuts:")
|
||||
for key, info in MODEL_REGISTRY.items():
|
||||
print(f" {key:<20} {info['description']}")
|
||||
print(f"\nDefault endpoints:")
|
||||
print(f" direct {CF_TEXT_BASE}")
|
||||
print(f" cf-orch {CF_COORD_URL}")
|
||||
return
|
||||
|
||||
prompts = HELD_OUT_PROMPTS
|
||||
if args.prompts:
|
||||
ids = set(args.prompts)
|
||||
prompts = [p for p in HELD_OUT_PROMPTS if p["id"] in ids]
|
||||
if not prompts:
|
||||
print(f"No prompts matched IDs: {args.prompts}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
model_keys: list[str] = []
|
||||
if args.compare:
|
||||
model_keys = args.compare
|
||||
elif args.model:
|
||||
model_keys = [args.model]
|
||||
else:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
all_results: dict[str, list[PromptResult]] = {}
|
||||
print_lock = threading.Lock()
|
||||
|
||||
def _run_one(mk: str) -> tuple[str, list[PromptResult]]:
|
||||
if mk in MODEL_REGISTRY:
|
||||
reg = MODEL_REGISTRY[mk]
|
||||
model_name = args.model_name or reg["model"]
|
||||
direct_base = args.api_base or reg["api_base"]
|
||||
else:
|
||||
model_name = args.model_name or mk
|
||||
direct_base = args.api_base or CF_TEXT_BASE
|
||||
|
||||
if args.cforch:
|
||||
with print_lock:
|
||||
print(f"\nRunning [{mk}] via cf-orch ({args.cforch_url}) model={model_name}")
|
||||
results = run_benchmark(
|
||||
mk, model_name, prompts=prompts, verbose=args.verbose,
|
||||
use_cforch=True, cforch_url=args.cforch_url,
|
||||
)
|
||||
else:
|
||||
with print_lock:
|
||||
print(f"\nRunning [{mk}] → {direct_base} model={model_name}")
|
||||
results = run_benchmark(
|
||||
mk, model_name, prompts=prompts, verbose=args.verbose,
|
||||
api_base=direct_base,
|
||||
)
|
||||
|
||||
with print_lock:
|
||||
_print_single_report(results, mk)
|
||||
return mk, results
|
||||
|
||||
workers = max(1, args.workers)
|
||||
if workers == 1 or len(model_keys) == 1:
|
||||
for mk in model_keys:
|
||||
mk_out, results = _run_one(mk)
|
||||
all_results[mk_out] = results
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {pool.submit(_run_one, mk): mk for mk in model_keys}
|
||||
for fut in as_completed(futures):
|
||||
mk_out, results = fut.result()
|
||||
all_results[mk_out] = results
|
||||
|
||||
if len(model_keys) > 1:
|
||||
_print_comparison_table(all_results)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
mk: [asdict(r) for r in results]
|
||||
for mk, results in all_results.items()
|
||||
}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2, ensure_ascii=False)
|
||||
print(f"Wrote detailed results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,952 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Writing style benchmark harness -- score local text-gen models for writing style match.
|
||||
|
||||
Runs each model against a set of test prompts, extracts style signals from the
|
||||
outputs, compares them to a style corpus, and produces a ranked markdown table.
|
||||
|
||||
Usage:
|
||||
# List available ollama models
|
||||
conda run -n cf python scripts/benchmark_style.py --list-models
|
||||
|
||||
# Run against all models with default test prompts
|
||||
conda run -n cf python scripts/benchmark_style.py --run
|
||||
|
||||
# Run specific models only
|
||||
conda run -n cf python scripts/benchmark_style.py --run --models mistral:7b,llama3.1:8b
|
||||
|
||||
# Use a custom corpus directory
|
||||
conda run -n cf python scripts/benchmark_style.py --run --samples data/style_corpus/
|
||||
|
||||
# Print last results table
|
||||
conda run -n cf python scripts/benchmark_style.py --show-last
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CORPUS_DIR = _ROOT / "data" / "style_corpus"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
_OLLAMA_URL = "http://localhost:11434"
|
||||
_CFORCH_URL = "http://localhost:7700"
|
||||
|
||||
# Subdirectories under --scan-disk root that may contain GGUFs
|
||||
_SCAN_SUBDIRS = ["textgen/models", "llama.cpp/models", "cf-text/models", "vllm/models"]
|
||||
|
||||
# ── Filler phrases that should be absent from good style-match output ──────────
|
||||
FILLER_PHRASES: list[str] = [
|
||||
"delve", "certainly", "absolutely", "i apologize", "i'd be happy to",
|
||||
"of course", "great question", "i understand", "let me know if",
|
||||
"feel free to", "it's important to note", "it's worth noting",
|
||||
"in conclusion", "to summarize", "in summary",
|
||||
]
|
||||
|
||||
# ── Test prompts: dicts with "tag", "thread_title", "thread_body" keys ───────
|
||||
# These are representative threads that Magpie might reply to.
|
||||
# Extend this list with real examples as the corpus grows.
|
||||
TEST_PROMPTS: list[dict[str, str]] = [
|
||||
{
|
||||
"tag": "selfhosted_ai_fatigue",
|
||||
"thread_title": "Anyone else getting tired of re-explaining their setup every time an AI model forgets?",
|
||||
"thread_body": (
|
||||
"Every session I start over. My whole hardware setup, what tools I use, "
|
||||
"what I've already tried. It's exhausting. There has to be a better way."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "privacy_local_llm",
|
||||
"thread_title": "What's the point of running local LLMs if the apps still phone home?",
|
||||
"thread_body": (
|
||||
"I went through all the trouble of setting up ollama and now I find out "
|
||||
"the frontend I'm using is sending telemetry. Kind of defeats the purpose."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "solarpunk_tech",
|
||||
"thread_title": "What does solarpunk computing actually look like in practice?",
|
||||
"thread_body": (
|
||||
"I keep seeing the aesthetic but not a lot of concrete examples of "
|
||||
"people living it out with their tech choices. What does it mean day to day?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "nd_tools",
|
||||
"thread_title": "Tools that actually help with executive function vs ones that just add friction",
|
||||
"thread_body": (
|
||||
"I've tried a dozen productivity apps and most of them require more "
|
||||
"executive function to maintain than they save. What actually sticks for you?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "data_ownership",
|
||||
"thread_title": "Who actually owns your data when you use a 'free' AI tool?",
|
||||
"thread_body": (
|
||||
"Read the ToS on three different AI assistants today. In all three cases "
|
||||
"your inputs can be used for training, shared with partners, and retained "
|
||||
"indefinitely. At what point does 'free' just mean you're the product?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "digital_culture",
|
||||
"thread_title": "The internet used to feel like it belonged to everyone. What happened?",
|
||||
"thread_body": (
|
||||
"I grew up on forums, IRC, personal homepages. Now everything is a platform "
|
||||
"owned by someone trying to extract value from the community that built it. "
|
||||
"Is the fediverse / self-hosting movement actually reversing this or just "
|
||||
"a niche hobby?"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
GENERATION_PARAMS: dict[str, Any] = {
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"num_predict": 300,
|
||||
}
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No hype, no buzzwords, no em dashes, no semicolons.\n"
|
||||
"- Community-first perspective. Solarpunk values.\n"
|
||||
"- Direct and opinionated. No throat-clearing or filler.\n"
|
||||
"- When relevant, mention personal experience with real tools.\n\n"
|
||||
"Write ONLY the reply. No preamble, no 'Here is a reply:', no meta-commentary."
|
||||
)
|
||||
|
||||
|
||||
# ── Style signal extraction ───────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class StyleSignals:
|
||||
"""Quantitative style signals extracted from a text sample."""
|
||||
sentence_count: int = 0
|
||||
word_count: int = 0
|
||||
avg_sentence_length: float = 0.0
|
||||
em_dash_count: int = 0
|
||||
semicolon_count: int = 0
|
||||
filler_hits: list[str] = field(default_factory=list)
|
||||
question_ratio: float = 0.0 # fraction of sentences ending in '?'
|
||||
first_person_ratio: float = 0.0 # fraction of sentences starting with 'I'
|
||||
avg_word_length: float = 0.0
|
||||
|
||||
|
||||
def extract_signals(text: str) -> StyleSignals:
|
||||
"""Extract style signals from a text sample."""
|
||||
text = text.strip()
|
||||
if text.startswith("[ERROR:"):
|
||||
return StyleSignals() # failed generation: return all-zero signals (word_count == 0)
|
||||
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
|
||||
words = text.split()
|
||||
|
||||
if not sentences:
|
||||
return StyleSignals()
|
||||
|
||||
avg_sentence_length = len(words) / len(sentences) if sentences else 0.0
|
||||
avg_word_length = (sum(len(w.strip('.,!?;:"\'')) for w in words) / len(words)) if words else 0.0
|
||||
|
||||
em_dash_count = text.count('\u2014') + text.count('--')  # '--' already matches ' -- ', so count it once
|
||||
semicolon_count = text.count(';')
|
||||
|
||||
filler_hits = [p for p in FILLER_PHRASES if p.lower() in text.lower()]
|
||||
|
||||
question_ratio = sum(1 for s in sentences if s.endswith('?')) / len(sentences)
|
||||
first_person_ratio = sum(1 for s in sentences if re.match(r"^I\b", s)) / len(sentences)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=len(sentences),
|
||||
word_count=len(words),
|
||||
avg_sentence_length=avg_sentence_length,
|
||||
em_dash_count=em_dash_count,
|
||||
semicolon_count=semicolon_count,
|
||||
filler_hits=filler_hits,
|
||||
question_ratio=question_ratio,
|
||||
first_person_ratio=first_person_ratio,
|
||||
avg_word_length=avg_word_length,
|
||||
)
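As a rough illustration of the sentence splitting and ratio signals computed above, the following standalone sketch applies the same regex to a short sample. The helper name `split_sentences` and the sample text are assumptions made for illustration; they are not part of the script.

```python
import re


def split_sentences(text: str) -> list[str]:
    """Split on whitespace that follows '.', '!' or '?' (same regex as extract_signals)."""
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s.strip()]


# Hypothetical sample text, not taken from the style corpus.
sample = "I run everything locally. Why trust a cloud? It works fine."
sentences = split_sentences(sample)

# Fraction of sentences ending in '?'.
question_ratio = sum(1 for s in sentences if s.endswith("?")) / len(sentences)
# Fraction of sentences starting with the standalone word 'I' ("It ..." does not match).
first_person_ratio = sum(1 for s in sentences if re.match(r"^I\b", s)) / len(sentences)
# Whitespace-delimited words per sentence.
avg_sentence_length = len(sample.split()) / len(sentences)

print(sentences)
print(question_ratio, first_person_ratio, avg_sentence_length)
```

Note that the split relies on terminal punctuation followed by whitespace, so abbreviations such as "e.g. this" would be treated as a sentence boundary.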
|
||||
|
||||
|
||||
def build_corpus_profile(corpus_dir: Path) -> StyleSignals | None:
|
||||
"""Aggregate style signals across all corpus samples into a target profile."""
|
||||
samples = list(corpus_dir.glob("*.txt"))
|
||||
if not samples:
|
||||
return None
|
||||
|
||||
all_signals = [extract_signals(p.read_text(encoding="utf-8")) for p in samples]
|
||||
n = len(all_signals)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=int(sum(s.sentence_count for s in all_signals) / n),
|
||||
word_count=int(sum(s.word_count for s in all_signals) / n),
|
||||
avg_sentence_length=sum(s.avg_sentence_length for s in all_signals) / n,
|
||||
em_dash_count=int(sum(s.em_dash_count for s in all_signals) / n),
|
||||
semicolon_count=int(sum(s.semicolon_count for s in all_signals) / n),
|
||||
question_ratio=sum(s.question_ratio for s in all_signals) / n,
|
||||
first_person_ratio=sum(s.first_person_ratio for s in all_signals) / n,
|
||||
avg_word_length=sum(s.avg_word_length for s in all_signals) / n,
|
||||
)
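The profile above is a mean of per-sample values, so every corpus file carries equal weight regardless of its length. A minimal sketch of that averaging, using a hypothetical two-field `MiniSignals` dataclass in place of the full `StyleSignals`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniSignals:
    """Hypothetical reduced stand-in for StyleSignals (illustration only)."""
    avg_sentence_length: float
    question_ratio: float


def mean_profile(samples: list[MiniSignals]) -> Optional[MiniSignals]:
    """Average each field across samples; None when there are no samples."""
    if not samples:
        return None
    n = len(samples)
    return MiniSignals(
        avg_sentence_length=sum(s.avg_sentence_length for s in samples) / n,
        question_ratio=sum(s.question_ratio for s in samples) / n,
    )


profile = mean_profile([MiniSignals(10.0, 0.0), MiniSignals(20.0, 0.5)])
print(profile)
```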
|
||||
|
||||
|
||||
def score_against_profile(output_signals: StyleSignals, profile: StyleSignals | None) -> float:
|
||||
"""Score a model output against the corpus profile. Returns 0-100.
|
||||
|
||||
Penalties:
|
||||
- Em dashes: -5 each occurrence; semicolons: -3 each (hard CF style violations)
|
||||
- Filler phrases: -8 each hit (strong signal of non-style output)
|
||||
- Sentence length delta: proportional penalty (target: close to corpus avg)
|
||||
- Question ratio delta: smaller penalty (capped at 10)
|
||||
|
||||
When no corpus profile is available, falls back to absolute signal scores only.
|
||||
"""
|
||||
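if output_signals.word_count == 0:
return 0.0  # empty or failed generation scores zero instead of a clean 100
score = 100.0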
|
||||
|
||||
# Hard violations -- always penalised regardless of corpus
|
||||
score -= output_signals.em_dash_count * 5
|
||||
score -= output_signals.semicolon_count * 3
|
||||
score -= len(output_signals.filler_hits) * 8
|
||||
|
||||
if profile is not None:
|
||||
# Sentence length delta: penalise proportionally
|
||||
length_delta = abs(output_signals.avg_sentence_length - profile.avg_sentence_length)
|
||||
score -= min(length_delta * 2, 20)
|
||||
|
||||
# Question ratio delta
|
||||
question_delta = abs(output_signals.question_ratio - profile.question_ratio)
|
||||
score -= min(question_delta * 10, 10)
|
||||
|
||||
return max(0.0, score)
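To make the penalty arithmetic concrete, here is a small standalone sketch that mirrors the weights used in the function body (5 per em dash, 3 per semicolon, 8 per filler hit, plus the two capped deltas). The function name `penalty_score` and its flat argument list are assumptions for illustration.

```python
def penalty_score(
    em_dashes: int,
    semicolons: int,
    fillers: int,
    sentence_len_delta: float = 0.0,
    question_delta: float = 0.0,
) -> float:
    """Start at 100 and subtract penalties; clamp the result at 0."""
    score = 100.0
    score -= em_dashes * 5
    score -= semicolons * 3
    score -= fillers * 8
    score -= min(sentence_len_delta * 2, 20)   # capped at 20 points
    score -= min(question_delta * 10, 10)      # capped at 10 points
    return max(0.0, score)


# 100 - 10 (em dashes) - 3 (semicolon) - 8 (filler) - 9 (length) - 2 (questions)
example = penalty_score(2, 1, 1, sentence_len_delta=4.5, question_delta=0.2)
# Heavy violations are clamped at zero rather than going negative.
clamped = penalty_score(30, 0, 0)
print(example, clamped)
```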
|
||||
|
||||
|
||||
# ── Generation backends (cf-orch allocation, cf-text/vllm, ollama) ────────────
|
||||
|
||||
_CFORCH_NODE_ID = "heimdall"
|
||||
|
||||
|
||||
def cforch_list_catalog(
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
node_id: str = _CFORCH_NODE_ID,
|
||||
) -> dict[str, int]:
|
||||
"""Return the cf-text catalog from cf-orch as {model_id: vram_mb}.
|
||||
|
||||
Uses ?node_id= to request the catalog from a specific node's profile,
|
||||
avoiding cross-node catalog shadowing when multiple nodes define catalogs
|
||||
for the same service.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_url}/api/services/cf-text/catalog",
|
||||
params={"node_id": node_id} if node_id else {},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
return {
|
||||
model_id: (entry.get("vram_mb", 0) if isinstance(entry, dict) else 0)
|
||||
for model_id, entry in raw.items()
|
||||
}
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach cf-orch catalog at {cforch_url}: {exc}", file=sys.stderr)
|
||||
return {}
|
||||
|
||||
|
||||
def _cforch_allocate_service(
|
||||
service: str,
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float,
|
||||
health_path: str,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Generic cf-orch allocate + state-signal wait. Returns (service_url, allocation_id) or None.
|
||||
|
||||
After allocating, waits for the coordinator's service state to reach 'running'.
|
||||
Fails immediately if the state reaches 'stopped' (crashed load) — no waiting out
|
||||
the full timeout for a model that already failed.
|
||||
Falls back to health-polling if the coordinator doesn't expose a matching instance
|
||||
(e.g. older coordinator version or service not yet registered in probe loop).
|
||||
"""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/{service}/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "style_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
node_id: str = data.get("node_id", "")
|
||||
gpu_id: int | None = data.get("gpu_id")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
print(f" [cold start] waiting for {service} to load {model_id!r}...", end=" ", flush=True)
|
||||
t0 = time.monotonic()
|
||||
deadline = t0 + startup_timeout_s
|
||||
probe_misses = 0 # consecutive polls with no matching instance in status
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
status = httpx.get(
|
||||
f"{cforch_url}/api/services/{service}/status", timeout=5.0
|
||||
)
|
||||
if status.is_success:
|
||||
instances = status.json().get("instances", [])
|
||||
# Find our specific instance by node+gpu
|
||||
match = next(
|
||||
(i for i in instances
|
||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||
None,
|
||||
)
|
||||
if match:
|
||||
probe_misses = 0
|
||||
state = match.get("state", "")
|
||||
if state == "running":
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"ready ({elapsed:.0f}s)", flush=True)
|
||||
return service_url, allocation_id
|
||||
elif state == "stopped":
|
||||
print(f"failed (service stopped — model load error)", flush=True)
|
||||
return None
|
||||
# state == "starting" or unknown → keep waiting
|
||||
else:
|
||||
probe_misses += 1
|
||||
# After a grace period with no instance visible, fall back to
|
||||
# direct health-poll (coordinator may not have probed yet)
|
||||
if probe_misses >= 6:
|
||||
try:
|
||||
health = httpx.get(f"{service_url}{health_path}", timeout=3.0)
|
||||
if health.is_success:
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"ready via health ({elapsed:.0f}s)", flush=True)
|
||||
return service_url, allocation_id
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(3.0)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"timed out after {elapsed:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r} ({service}): {exc}", file=sys.stderr)
|
||||
return None
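The cold-start wait above is a deadline-bounded polling loop with an early exit on failure. A minimal sketch of that pattern follows, with the HTTP status call replaced by a hypothetical `get_state` callable so it runs without a coordinator; the health-poll fallback is omitted.

```python
import time
from typing import Callable


def wait_for_state(get_state: Callable[[], str], timeout_s: float, poll_s: float = 0.01) -> bool:
    """Poll until 'running' (True), 'stopped' (False), or the deadline passes (False)."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        state = get_state()
        if state == "running":
            return True
        if state == "stopped":
            return False  # fail fast: do not wait out the timeout for a crashed load
        time.sleep(poll_s)  # "starting" or unknown: keep waiting
    return False


# Hypothetical state sequences standing in for coordinator responses.
states = iter(["starting", "starting", "running"])
became_ready = wait_for_state(lambda: next(states), timeout_s=1.0)
crashed = wait_for_state(lambda: "stopped", timeout_s=1.0)
timed_out = wait_for_state(lambda: "starting", timeout_s=0.05)
print(became_ready, crashed, timed_out)
```

Using `time.monotonic()` rather than wall-clock time keeps the deadline stable if the system clock is adjusted during a long model load.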
|
||||
|
||||
|
||||
def cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 180.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id. Returns (service_url, allocation_id) or None."""
|
||||
return _cforch_allocate_service("cf-text", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_allocate_vllm(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a vllm instance for model_id. Returns (service_url, allocation_id) or None.
|
||||
|
||||
vllm exposes an OpenAI-compatible API — generate_cftext() works unchanged
|
||||
against the returned service_url. Startup timeout is longer (300s) because
|
||||
vllm loads large model weights from disk before becoming ready.
|
||||
"""
|
||||
return _cforch_allocate_service("vllm", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_release(allocation_id: str, cforch_url: str = _CFORCH_URL) -> None:
|
||||
"""Release a cf-orch allocation."""
|
||||
if not allocation_id:
|
||||
return
|
||||
try:
|
||||
httpx.delete(f"{cforch_url}/api/services/cf-text/allocations/{allocation_id}", timeout=10.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def generate_cftext(
|
||||
service_url: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
) -> tuple[str, float]:
|
||||
"""Call cf-text via OpenAI-compatible /v1/chat/completions. Returns (text, elapsed_ms)."""
|
||||
messages: list[dict[str, str]] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": GENERATION_PARAMS.get("num_predict", 300),
|
||||
"temperature": GENERATION_PARAMS.get("temperature", 0.7),
|
||||
"top_p": GENERATION_PARAMS.get("top_p", 0.9),
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json=payload,
|
||||
timeout=180.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
content = resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
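The request body sent to `/v1/chat/completions` can be sketched without any network access. The helper name `build_chat_payload` is an assumption; the key mapping (ollama-style `num_predict` to OpenAI-style `max_tokens`) follows the function above.

```python
from typing import Any, Optional


def build_chat_payload(
    model_id: str,
    prompt: str,
    system: str = "",
    params: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build an OpenAI-compatible chat payload from ollama-style generation params."""
    params = params or {}
    messages: list[dict[str, str]] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    return {
        "model": model_id,
        "messages": messages,
        "max_tokens": params.get("num_predict", 300),  # ollama name -> OpenAI name
        "temperature": params.get("temperature", 0.7),
        "top_p": params.get("top_p", 0.9),
        "stream": False,
    }


payload = build_chat_payload(
    "example-model",  # hypothetical model id
    "Write a reply:",
    system="You are a writing assistant.",
    params={"temperature": 0.7, "top_p": 0.9, "num_predict": 300},
)
print(payload["messages"])
```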
|
||||
|
||||
|
||||
def generate(model_id: str, prompt: str, system: str = "") -> tuple[str, float]:
|
||||
"""Call ollama /api/generate. Returns (text, elapsed_ms)."""
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": GENERATION_PARAMS,
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_OLLAMA_URL}/api/generate",
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return resp.json().get("response", "").strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def find_disk_ggufs(llm_root: Path) -> list[Path]:
|
||||
"""Recursively find .gguf files under known subdirs of llm_root.
|
||||
|
||||
Skips vocab-only GGUFs (ggml-vocab-*) which aren't standalone models.
|
||||
"""
|
||||
found: list[Path] = []
|
||||
search_dirs = [llm_root / sub for sub in _SCAN_SUBDIRS] + [llm_root]
|
||||
seen: set[Path] = set()
|
||||
for base in search_dirs:
|
||||
if not base.exists():
|
||||
continue
|
||||
for gguf in base.rglob("*.gguf"):
|
||||
if gguf in seen:
|
||||
continue
|
||||
seen.add(gguf)
|
||||
if gguf.name.startswith("ggml-vocab-"):
|
||||
continue
|
||||
found.append(gguf)
|
||||
return sorted(found)
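Because the root directory is scanned in addition to its known subdirectories, the same file can be reached twice, which is why the `seen` set matters. The sketch below reproduces that behaviour in a temporary directory; the function name `find_models` and the file names are assumptions for illustration.

```python
import tempfile
from pathlib import Path


def find_models(root: Path, subdirs: list[str]) -> list[Path]:
    """Collect *.gguf files once each, skipping vocab-only files."""
    found: list[Path] = []
    seen: set[Path] = set()
    for base in [root / sub for sub in subdirs] + [root]:
        if not base.exists():
            continue
        for path in base.rglob("*.gguf"):
            if path in seen:
                continue  # already reached via a more specific subdirectory
            seen.add(path)
            if path.name.startswith("ggml-vocab-"):
                continue
            found.append(path)
    return sorted(found)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    models_dir = root / "textgen" / "models"
    models_dir.mkdir(parents=True)
    (models_dir / "a.gguf").touch()
    (models_dir / "ggml-vocab-x.gguf").touch()
    (root / "b.gguf").touch()
    names = sorted(p.name for p in find_models(root, ["textgen/models"]))

print(names)
```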
|
||||
|
||||
|
||||
def gguf_to_ollama_tag(gguf_path: Path) -> str:
|
||||
"""Derive a stable ollama tag from a GGUF path.
|
||||
|
||||
Uses parent dir name + stem to avoid collisions, e.g.:
|
||||
claude-3.7-sonnet-reasoning-gemma3-12B/foo.Q8_0.gguf
|
||||
→ bench-claude-3-7-sonnet-reasoning-gemma3-12b-foo-q8-0:latest
|
||||
"""
|
||||
parent = gguf_path.parent.name.lower()
|
||||
stem = gguf_path.stem.lower()
|
||||
# If stem is contained in parent (common pattern), just use parent
|
||||
slug = parent if stem.replace("-", "").replace("_", "") in parent.replace("-", "").replace("_", "") else f"{parent}-{stem}"
|
||||
slug = re.sub(r"[^a-z0-9]+", "-", slug).strip("-")
|
||||
return f"bench-{slug}:latest"
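Tracing the slug logic on two paths shows both branches: a stem that is not contained in the parent directory name, and one that is. The function below is a standalone copy of the logic above under a different name (`to_tag`, an assumption) so it can be run in isolation; the second path is a made-up example.

```python
import re
from pathlib import Path


def to_tag(gguf_path: Path) -> str:
    """Standalone copy of the tag derivation, for tracing only."""
    parent = gguf_path.parent.name.lower()
    stem = gguf_path.stem.lower()
    squash = lambda s: s.replace("-", "").replace("_", "")
    # Use the parent alone when the stem adds no new information.
    slug = parent if squash(stem) in squash(parent) else f"{parent}-{stem}"
    # Every run of characters outside [a-z0-9], dots included, becomes one hyphen.
    slug = re.sub(r"[^a-z0-9]+", "-", slug).strip("-")
    return f"bench-{slug}:latest"


combined = to_tag(Path("claude-3.7-sonnet-reasoning-gemma3-12B/foo.Q8_0.gguf"))
parent_only = to_tag(Path("mistral-7b/mistral-7b.gguf"))
print(combined)
print(parent_only)
```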
|
||||
|
||||
|
||||
def register_gguf(gguf_path: Path, tag: str) -> bool:
|
||||
"""Create a temporary ollama model entry from a GGUF file. Returns True on success."""
|
||||
import subprocess
|
||||
import tempfile
|
||||
modelfile = f"FROM {gguf_path.resolve()}\n"
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".Modelfile", delete=False) as f:
|
||||
f.write(modelfile)
|
||||
modelfile_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ollama", "create", tag, "-f", modelfile_path],
|
||||
capture_output=True, text=True, timeout=60,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not register {gguf_path.name}: {exc}", file=sys.stderr)
|
||||
return False
|
||||
finally:
|
||||
Path(modelfile_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def deregister_gguf(tag: str) -> None:
|
||||
"""Remove a temporary ollama model entry."""
|
||||
import subprocess
|
||||
try:
|
||||
subprocess.run(["ollama", "rm", tag], capture_output=True, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def backfill_disk_models(
|
||||
llm_root: Path,
|
||||
existing_tags: set[str],
|
||||
max_vram_mb: int = 0,
|
||||
) -> list[str]:
|
||||
"""Register GGUFs from disk that aren't already in ollama. Returns new tags.
|
||||
|
||||
max_vram_mb: skip files whose size exceeds this threshold (0 = no limit).
|
||||
GGUF file size is a rough lower bound on VRAM -- quantized weights load ~1:1, but KV cache and runtime overhead come on top.
|
||||
"""
|
||||
ggufs = find_disk_ggufs(llm_root)
|
||||
if not ggufs:
|
||||
print(f"No .gguf files found under {llm_root}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
new_tags: list[str] = []
|
||||
skipped_oom = 0
|
||||
for gguf in ggufs:
|
||||
size_mb = gguf.stat().st_size // (1024 * 1024)
|
||||
if max_vram_mb and size_mb > max_vram_mb:
|
||||
print(f" [skip-oom] {gguf.name} ({size_mb} MB > {max_vram_mb} MB limit)")
|
||||
skipped_oom += 1
|
||||
continue
|
||||
tag = gguf_to_ollama_tag(gguf)
|
||||
if tag in existing_tags:
|
||||
print(f" [skip] {gguf.name} already registered as {tag}")
|
||||
continue
|
||||
print(f" [register] {gguf.name} ({size_mb} MB) → {tag} ...", end=" ", flush=True)
|
||||
if register_gguf(gguf, tag):
|
||||
print("ok")
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
print("failed")
|
||||
|
||||
if skipped_oom:
|
||||
print(f" [info] {skipped_oom} GGUF(s) skipped (exceed {max_vram_mb} MB VRAM limit)")
|
||||
return new_tags
|
||||
|
||||
|
||||
def list_ollama_models() -> list[str]:
|
||||
"""Return model names from ollama /api/tags, filtered to text-gen candidates."""
|
||||
try:
|
||||
resp = httpx.get(f"{_OLLAMA_URL}/api/tags", timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
models = resp.json().get("models", [])
|
||||
# Exclude embedding-only models
|
||||
exclude = {"mxbai-embed-large", "nomic-embed-text", "all-minilm"}
|
||||
return [
|
||||
m["name"] for m in models
|
||||
if not any(x in m["name"].lower() for x in exclude)
|
||||
]
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach ollama: {exc}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
|
||||
# ── Run benchmark ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ModelResult:
|
||||
model_id: str
|
||||
prompt_results: list[dict[str, Any]] = field(default_factory=list)
|
||||
avg_score: float = 0.0
|
||||
avg_latency_ms: float = 0.0
|
||||
total_filler_hits: int = 0
|
||||
total_em_dashes: int = 0
|
||||
total_semicolons: int = 0
|
||||
|
||||
|
||||
def _bench_one_model(
|
||||
model_id: str,
|
||||
prompts: list[dict[str, str]],
|
||||
profile: Any,
|
||||
use_cforch: bool,
|
||||
cforch_url: str,
|
||||
use_vllm: bool = False,
|
||||
) -> "ModelResult | None":
|
||||
"""Run all prompts for a single model. Thread-safe — all output is prefixed with model_id.
|
||||
|
||||
Dispatch priority:
|
||||
use_vllm=True → allocate vllm via cf-orch, then generate_cftext() (OpenAI-compatible)
|
||||
use_cforch=True → allocate cf-text via cf-orch, then generate_cftext()
|
||||
else → direct ollama generate()
|
||||
Both vllm and cf-text expose /v1/chat/completions so generate_cftext() works for both.
|
||||
"""
|
||||
prefix = f"[{model_id}]"
|
||||
result = ModelResult(model_id=model_id)
|
||||
|
||||
service_url: str | None = None
|
||||
allocation_id: str = ""
|
||||
if use_vllm:
|
||||
alloc = cforch_allocate_vllm(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] vllm allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} vllm allocated: {service_url}", flush=True)
|
||||
elif use_cforch:
|
||||
alloc = cforch_allocate(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] cf-orch allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} allocated: {service_url}", flush=True)
|
||||
|
||||
try:
|
||||
for prompt_def in prompts:
|
||||
tag = prompt_def["tag"]
|
||||
user_prompt = (
|
||||
f"Thread: {prompt_def['thread_title']}\n\n"
|
||||
f"{prompt_def['thread_body']}\n\n"
|
||||
f"Write a reply:"
|
||||
)
|
||||
print(f"{prefix} [{tag}] generating...", flush=True)
|
||||
|
||||
if (use_cforch or use_vllm) and service_url:
|
||||
# Both cf-text and vllm expose /v1/chat/completions — same call
|
||||
output, elapsed_ms = generate_cftext(service_url, model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
else:
|
||||
output, elapsed_ms = generate(model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
|
||||
signals = extract_signals(output)
|
||||
score = score_against_profile(signals, profile)
|
||||
|
||||
print(f"{prefix} [{tag}] {score:.0f}/100 ({elapsed_ms:.0f}ms)", flush=True)
|
||||
if signals.filler_hits:
|
||||
print(f"{prefix} ⚠ filler: {signals.filler_hits}", flush=True)
|
||||
if signals.em_dash_count:
|
||||
print(f"{prefix} ⚠ em-dashes: {signals.em_dash_count}", flush=True)
|
||||
|
||||
result.prompt_results.append({
|
||||
"tag": tag,
|
||||
"user_prompt": user_prompt,
|
||||
"output": output,
|
||||
"signals": {
|
||||
"avg_sentence_length": signals.avg_sentence_length,
|
||||
"em_dash_count": signals.em_dash_count,
|
||||
"semicolon_count": signals.semicolon_count,
|
||||
"filler_hits": signals.filler_hits,
|
||||
"question_ratio": signals.question_ratio,
|
||||
"word_count": signals.word_count,
|
||||
},
|
||||
"score": score,
|
||||
"latency_ms": elapsed_ms,
|
||||
})
|
||||
finally:
|
||||
if (use_cforch or use_vllm) and allocation_id:
|
||||
cforch_release(allocation_id, cforch_url)
|
||||
|
||||
if not result.prompt_results:
|
||||
return None
|
||||
|
||||
scores = [r["score"] for r in result.prompt_results]
|
||||
latencies = [r["latency_ms"] for r in result.prompt_results]
|
||||
result.avg_score = sum(scores) / len(scores)
|
||||
result.avg_latency_ms = sum(latencies) / len(latencies)
|
||||
result.total_filler_hits = sum(len(r["signals"]["filler_hits"]) for r in result.prompt_results)
|
||||
result.total_em_dashes = sum(r["signals"]["em_dash_count"] for r in result.prompt_results)
|
||||
result.total_semicolons = sum(r["signals"]["semicolon_count"] for r in result.prompt_results)
|
||||
|
||||
print(f"{prefix} done — avg score {result.avg_score:.0f}/100", flush=True)
|
||||
return result
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_ids: list[str],
|
||||
corpus_dir: Path,
|
||||
prompts: list[dict[str, str]],
|
||||
use_cforch: bool = False,
|
||||
use_vllm: bool = False,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
workers: int = 1,
|
||||
) -> list[ModelResult]:
|
||||
profile = build_corpus_profile(corpus_dir)
|
||||
if profile:
|
||||
print(f"Corpus profile loaded from {corpus_dir} ({len(list(corpus_dir.glob('*.txt')))} samples)")
|
||||
print(f" Target avg sentence length: {profile.avg_sentence_length:.1f} words")
|
||||
else:
|
||||
print(f"[warn] No corpus samples found in {corpus_dir} -- scoring on hard violations only")
|
||||
|
||||
backend = "vllm via cf-orch" if use_vllm else ("cf-text via cf-orch" if use_cforch else "ollama")
|
||||
print(f" Backend: {backend}")
|
||||
|
||||
effective_workers = min(workers, len(model_ids)) if model_ids else 1
|
||||
print(f" Workers: {effective_workers} (of {len(model_ids)} models)", flush=True)
|
||||
|
||||
results: list[ModelResult] = []
|
||||
|
||||
if effective_workers <= 1:
|
||||
# Sequential path — simpler output, easier to follow for single-model runs
|
||||
for model_id in model_ids:
|
||||
print(f"\n{'='*60}\nModel: {model_id}", flush=True)
|
||||
r = _bench_one_model(model_id, prompts, profile, use_cforch, cforch_url, use_vllm)
|
||||
if r:
|
||||
results.append(r)
|
||||
else:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
print(f" Fanning out {len(model_ids)} models across {effective_workers} workers...", flush=True)
|
||||
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(_bench_one_model, mid, prompts, profile, use_cforch, cforch_url, use_vllm): mid
|
||||
for mid in model_ids
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
r = future.result()
|
||||
if r:
|
||||
results.append(r)
|
||||
|
||||
return sorted(results, key=lambda r: r.avg_score, reverse=True)
|
||||
|
||||
|
||||
# ── Markdown report ───────────────────────────────────────────────────────────
|
||||
|
||||
def render_report(results: list[ModelResult], corpus_dir: Path) -> str:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
lines: list[str] = [
|
||||
f"# Writing Style Benchmark Results",
|
||||
f"",
|
||||
f"**Date:** {date_str} ",
|
||||
f"**Corpus:** `{corpus_dir}` ",
|
||||
f"**Models tested:** {len(results)} ",
|
||||
f"**Prompts per model:** {len(TEST_PROMPTS)}",
|
||||
f"",
|
||||
f"## Rankings",
|
||||
f"",
|
||||
f"| Rank | Model | Score | Latency | Em-dashes | Fillers | Semicolons |",
|
||||
f"|------|-------|-------|---------|-----------|---------|------------|",
|
||||
]
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
medal = {1: "🥇", 2: "🥈", 3: "🥉"}.get(i, f"#{i}")
|
||||
lines.append(
|
||||
f"| {medal} | `{r.model_id}` | {r.avg_score:.0f}/100 "
|
||||
f"| {r.avg_latency_ms:.0f}ms "
|
||||
f"| {r.total_em_dashes} "
|
||||
f"| {r.total_filler_hits} "
|
||||
f"| {r.total_semicolons} |"
|
||||
)
|
||||
|
||||
lines += ["", "## Sample Outputs", ""]
|
||||
|
||||
for r in results[:3]: # top 3 only to keep report readable
|
||||
lines += [f"### `{r.model_id}` (avg score: {r.avg_score:.0f})", ""]
|
||||
for pr in r.prompt_results:
|
||||
lines += [
|
||||
f"**Prompt:** {pr['tag']} ",
|
||||
f"**Score:** {pr['score']:.0f}/100 ",
|
||||
f"",
|
||||
f"```",
|
||||
pr["output"],
|
||||
f"```",
|
||||
f"",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def save_report(results: list[ModelResult], corpus_dir: Path) -> Path:
|
||||
_RESULTS_DIR.mkdir(exist_ok=True)
|
||||
date_str = datetime.now().strftime("%Y-%m-%d_%H%M")
|
||||
report_path = _RESULTS_DIR / f"style_{date_str}.md"
|
||||
report_path.write_text(render_report(results, corpus_dir), encoding="utf-8")
|
||||
|
||||
# Also save raw JSON for programmatic use
|
||||
json_path = _RESULTS_DIR / f"style_{date_str}.json"
|
||||
json_path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"model_id": r.model_id,
|
||||
"avg_score": r.avg_score,
|
||||
"avg_latency_ms": r.avg_latency_ms,
|
||||
"total_filler_hits": r.total_filler_hits,
|
||||
"total_em_dashes": r.total_em_dashes,
|
||||
"total_semicolons": r.total_semicolons,
|
||||
"prompt_results": r.prompt_results,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
return report_path
|
||||
|
||||
|
||||
# ── CLI commands ──────────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_list_models(_args: argparse.Namespace) -> None:
|
||||
models = list_ollama_models()
|
||||
if not models:
|
||||
print("No models found (is ollama running?)")
|
||||
return
|
||||
print(f"{len(models)} models available:\n")
|
||||
for m in models:
|
||||
print(f" {m}")
|
||||
|
||||
|
||||
def cmd_run(args: argparse.Namespace) -> None:
|
||||
corpus_dir = Path(args.samples)
|
||||
if not corpus_dir.exists():
|
||||
print(f"[error] Corpus directory not found: {corpus_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
max_vram_mb: int = getattr(args, "max_vram", 7200)
|
||||
use_cforch: bool = getattr(args, "cforch", False)
|
||||
use_vllm: bool = getattr(args, "vllm", False)
|
||||
cforch_url: str = getattr(args, "cforch_url", _CFORCH_URL)
|
||||
registered_tags: list[str] = []
|
||||
|
||||
def _filter_ollama_by_size(ids: list[str], include_large: bool) -> list[str]:
|
||||
"""Apply name-pattern size filter to ollama model list."""
|
||||
if include_large:
|
||||
return ids
|
||||
skip_patterns = ["270b", "70b", "32b", "30b", "21b", "20b", "deepseek-r1"]
|
||||
filtered = [m for m in ids if not any(p in m.lower() for p in skip_patterns)]
|
||||
skipped = len(ids) - len(filtered)
|
||||
if skipped:
|
||||
print(f"[info] Skipped {skipped} large model(s) by name pattern. "
|
||||
"Pass --include-large to include them.")
|
||||
return filtered
|
||||
|
||||
if args.models and args.models != "all":
|
||||
model_ids = [m.strip() for m in args.models.split(",") if m.strip()]
|
||||
elif use_cforch:
|
||||
# cf-orch path: pull model list from catalog, filter by vram_mb
|
||||
catalog = cforch_list_catalog(cforch_url)
|
||||
if not catalog:
|
||||
print("[warn] cf-orch catalog empty or unreachable -- falling back to ollama models")
|
||||
use_cforch = False
|
||||
model_ids = _filter_ollama_by_size(list_ollama_models(), args.include_large)
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
before = list(catalog.items())
|
||||
allowed = {mid: mb for mid, mb in before if mb == 0 or mb <= max_vram_mb}
|
||||
skipped_oom = {mid: mb for mid, mb in before if mid not in allowed}
|
||||
model_ids = list(allowed.keys())
|
||||
print(f"[info] cf-orch catalog: {len(before)} model(s), "
|
||||
f"{len(allowed)} within {max_vram_mb} MB VRAM limit")
|
||||
if skipped_oom:
|
||||
print(f"[info] Skipped (OOM risk): "
|
||||
+ ", ".join(f"{mid} ({mb} MB)" for mid, mb in sorted(skipped_oom.items())))
|
||||
else:
|
||||
# Ollama path
|
||||
model_ids = list_ollama_models()
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Backfill GGUFs from disk before filtering -- skips files that exceed VRAM limit
|
||||
if getattr(args, "scan_disk", None):
|
||||
llm_root = Path(args.scan_disk)
|
||||
print(f"\nScanning {llm_root} for unregistered GGUFs (limit: {max_vram_mb} MB)...")
|
||||
registered_tags = backfill_disk_models(llm_root, set(model_ids), max_vram_mb=max_vram_mb)
|
||||
model_ids = list_ollama_models() # re-fetch with new registrations
|
||||
|
||||
model_ids = _filter_ollama_by_size(model_ids, args.include_large)
|
||||
|
||||
print(f"\nRunning writing style benchmark on {len(model_ids)} model(s)...")
|
||||
try:
|
||||
results = run_benchmark(model_ids, corpus_dir, TEST_PROMPTS, use_cforch=use_cforch, use_vllm=use_vllm, cforch_url=cforch_url, workers=args.workers)
|
||||
report_path = save_report(results, corpus_dir)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results saved to: {report_path}")
|
||||
print(f"\n{render_report(results, corpus_dir)}")
|
||||
finally:
|
||||
if registered_tags:
|
||||
print(f"\nCleaning up {len(registered_tags)} temporary ollama registrations...")
|
||||
for tag in registered_tags:
|
||||
deregister_gguf(tag)
|
||||
|
||||
|
||||
def cmd_show_last(_args: argparse.Namespace) -> None:
|
||||
reports = sorted(_RESULTS_DIR.glob("style_*.md"), reverse=True)
|
||||
if not reports:
|
||||
print("No benchmark results found. Run --run first.")
|
||||
return
|
||||
print(reports[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Writing style benchmark harness for local text-gen models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="cmd")
|
||||
|
||||
sub.add_parser("list-models", help="List available ollama models")
|
||||
|
||||
run_p = sub.add_parser("run", help="Run the benchmark")
|
||||
run_p.add_argument("--models", default="all", help="Comma-separated model IDs, or 'all'")
|
||||
run_p.add_argument("--samples", default=str(_CORPUS_DIR), help="Path to style corpus directory")
|
||||
run_p.add_argument("--include-large", action="store_true", help="Include large models (names matching 20B+ sizes or deepseek-r1)")
|
||||
run_p.add_argument("--scan-disk", metavar="LLM_ROOT", help="Scan directory for GGUFs not yet in ollama (e.g. /Library/Assets/LLM)")
|
||||
run_p.add_argument("--cforch", action="store_true", help="Route generation through cf-orch/cf-text instead of direct ollama")
|
||||
run_p.add_argument("--vllm", action="store_true", help="Route generation through cf-orch/vllm (OpenAI-compatible) instead of ollama")
|
||||
run_p.add_argument("--cforch-url", default=_CFORCH_URL, help=f"cf-orch coordinator URL (default: {_CFORCH_URL})")
|
||||
run_p.add_argument("--max-vram", type=int, default=7200, metavar="MB",
|
||||
help="Skip models whose VRAM footprint exceeds this limit in MB (default: 7200)")
|
||||
run_p.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Parallel workers — run N models simultaneously (default: 1; use 4+ with cf-orch)")
|
||||
|
||||
sub.add_parser("show-last", help="Print the most recent benchmark report")
|
||||
|
||||
# Also support legacy --list-models / --run / --show-last flags for manage.sh compat
|
||||
parser.add_argument("--list-models", action="store_true")
|
||||
parser.add_argument("--run", action="store_true")
|
||||
parser.add_argument("--show-last", action="store_true")
|
||||
parser.add_argument("--models", default="all")
|
||||
parser.add_argument("--samples", default=str(_CORPUS_DIR))
|
||||
parser.add_argument("--include-large", action="store_true")
|
||||
parser.add_argument("--scan-disk", metavar="LLM_ROOT")
|
||||
parser.add_argument("--cforch", action="store_true")
|
||||
parser.add_argument("--vllm", action="store_true")
|
||||
parser.add_argument("--cforch-url", default=_CFORCH_URL)
|
||||
parser.add_argument("--max-vram", type=int, default=7200, metavar="MB")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cmd == "list-models" or args.list_models:
|
||||
cmd_list_models(args)
|
||||
elif args.cmd == "run" or args.run:
|
||||
cmd_run(args)
|
||||
elif args.cmd == "show-last" or args.show_last:
|
||||
cmd_show_last(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,909 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Voice benchmark harness -- score local text-gen models for writing style match.
|
||||
|
||||
Runs each model against a set of test prompts, extracts style signals from the
|
||||
outputs, compares them to a voice corpus, and produces a ranked markdown table.
|
||||
|
||||
Usage:
|
||||
# List available ollama models
|
||||
conda run -n cf python scripts/benchmark_voice.py --list-models
|
||||
|
||||
# Run against all models with default test prompts
|
||||
conda run -n cf python scripts/benchmark_voice.py --run
|
||||
|
||||
# Run specific models only
|
||||
conda run -n cf python scripts/benchmark_voice.py --run --models mistral:7b,llama3.1:8b
|
||||
|
||||
# Use a custom corpus directory
|
||||
conda run -n cf python scripts/benchmark_voice.py --run --samples data/voice_corpus/
|
||||
|
||||
# Print last results table
|
||||
conda run -n cf python scripts/benchmark_voice.py --show-last
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CORPUS_DIR = _ROOT / "data" / "voice_corpus"
|
||||
_RESULTS_DIR = _ROOT / "benchmark_results"
|
||||
_OLLAMA_URL = "http://localhost:11434"
|
||||
_CFORCH_URL = "http://localhost:7700"
|
||||
|
||||
# Subdirectories under --scan-disk root that may contain GGUFs
|
||||
_SCAN_SUBDIRS = ["textgen/models", "llama.cpp/models", "cf-text/models", "vllm/models"]
|
||||
|
||||
# ── Filler phrases that should be absent from good voice-match output ─────────
|
||||
FILLER_PHRASES: list[str] = [
|
||||
"delve", "certainly", "absolutely", "i apologize", "i'd be happy to",
|
||||
"of course", "great question", "i understand", "let me know if",
|
||||
"feel free to", "it's important to note", "it's worth noting",
|
||||
"in conclusion", "to summarize", "in summary",
|
||||
]
|
||||
|
||||
# ── Test prompts: dicts with keys tag, thread_title, thread_body ─────────────
|
||||
# These are representative threads that Magpie might reply to.
|
||||
# Extend this list with real examples as the corpus grows.
|
||||
TEST_PROMPTS: list[dict[str, str]] = [
|
||||
{
|
||||
"tag": "selfhosted_ai_fatigue",
|
||||
"thread_title": "Anyone else getting tired of re-explaining their setup every time an AI model forgets?",
|
||||
"thread_body": (
|
||||
"Every session I start over. My whole hardware setup, what tools I use, "
|
||||
"what I've already tried. It's exhausting. There has to be a better way."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "privacy_local_llm",
|
||||
"thread_title": "What's the point of running local LLMs if the apps still phone home?",
|
||||
"thread_body": (
|
||||
"I went through all the trouble of setting up ollama and now I find out "
|
||||
"the frontend I'm using is sending telemetry. Kind of defeats the purpose."
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "solarpunk_tech",
|
||||
"thread_title": "What does solarpunk computing actually look like in practice?",
|
||||
"thread_body": (
|
||||
"I keep seeing the aesthetic but not a lot of concrete examples of "
|
||||
"people living it out with their tech choices. What does it mean day to day?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "nd_tools",
|
||||
"thread_title": "Tools that actually help with executive function vs ones that just add friction",
|
||||
"thread_body": (
|
||||
"I've tried a dozen productivity apps and most of them require more "
|
||||
"executive function to maintain than they save. What actually sticks for you?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "data_ownership",
|
||||
"thread_title": "Who actually owns your data when you use a 'free' AI tool?",
|
||||
"thread_body": (
|
||||
"Read the ToS on three different AI assistants today. In all three cases "
|
||||
"your inputs can be used for training, shared with partners, and retained "
|
||||
"indefinitely. At what point does 'free' just mean you're the product?"
|
||||
),
|
||||
},
|
||||
{
|
||||
"tag": "digital_culture",
|
||||
"thread_title": "The internet used to feel like it belonged to everyone. What happened?",
|
||||
"thread_body": (
|
||||
"I grew up on forums, IRC, personal homepages. Now everything is a platform "
|
||||
"owned by someone trying to extract value from the community that built it. "
|
||||
"Is the fediverse / self-hosting movement actually reversing this or just "
|
||||
"a niche hobby?"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
GENERATION_PARAMS: dict[str, Any] = {
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"num_predict": 300,
|
||||
}
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a writing assistant. Your job is to write a Reddit reply that matches "
|
||||
"the voice, tone, and style of the provided samples exactly.\n\n"
|
||||
"Voice characteristics:\n"
|
||||
"- Casual engineer tone. Short punchy sentences.\n"
|
||||
"- No hype, no buzzwords, no em dashes, no semicolons.\n"
|
||||
"- Community-first perspective. Solarpunk values.\n"
|
||||
"- Direct and opinionated. No throat-clearing or filler.\n"
|
||||
"- When relevant, mention personal experience with real tools.\n\n"
|
||||
"Write ONLY the reply. No preamble, no 'Here is a reply:', no meta-commentary."
|
||||
)
|
||||
|
||||
|
||||
# ── Style signal extraction ───────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class StyleSignals:
|
||||
"""Quantitative style signals extracted from a text sample."""
|
||||
sentence_count: int = 0
|
||||
word_count: int = 0
|
||||
avg_sentence_length: float = 0.0
|
||||
em_dash_count: int = 0
|
||||
semicolon_count: int = 0
|
||||
filler_hits: list[str] = field(default_factory=list)
|
||||
question_ratio: float = 0.0 # fraction of sentences ending in '?'
|
||||
first_person_ratio: float = 0.0 # fraction of sentences starting with 'I'
|
||||
avg_word_length: float = 0.0
|
||||
|
||||
|
||||
def extract_signals(text: str) -> StyleSignals:
|
||||
"""Extract style signals from a text sample."""
|
||||
text = text.strip()
|
||||
if text.startswith("[ERROR:"):
|
||||
return StyleSignals()  # empty sentinel for failed generations; scoring does not penalise it, so callers must check for error output
|
||||
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
|
||||
words = text.split()
|
||||
|
||||
if not sentences:
|
||||
return StyleSignals()
|
||||
|
||||
avg_sentence_length = len(words) / len(sentences) if sentences else 0.0
|
||||
avg_word_length = (sum(len(w.strip('.,!?;:"\'')) for w in words) / len(words)) if words else 0.0
|
||||
|
||||
em_dash_count = text.count('\u2014') + text.count('--')  # '--' already covers ' -- ', so count it once
|
||||
semicolon_count = text.count(';')
|
||||
|
||||
filler_hits = [p for p in FILLER_PHRASES if p.lower() in text.lower()]
|
||||
|
||||
question_ratio = sum(1 for s in sentences if s.endswith('?')) / len(sentences)
|
||||
first_person_ratio = sum(1 for s in sentences if re.match(r"^I\b", s)) / len(sentences)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=len(sentences),
|
||||
word_count=len(words),
|
||||
avg_sentence_length=avg_sentence_length,
|
||||
em_dash_count=em_dash_count,
|
||||
semicolon_count=semicolon_count,
|
||||
filler_hits=filler_hits,
|
||||
question_ratio=question_ratio,
|
||||
first_person_ratio=first_person_ratio,
|
||||
avg_word_length=avg_word_length,
|
||||
)
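
# Illustrative sketch (not part of the original harness): how the splitter in
# extract_signals() above breaks text into sentences. The sample string and
# the _demo_* names are assumptions made up for this example.
import re

_demo_text = "Short one. Is this a question? Yes!"
_demo_sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', _demo_text) if s.strip()]
# The lookbehind keeps the terminator attached to each sentence.
assert _demo_sentences == ["Short one.", "Is this a question?", "Yes!"]
# question_ratio for this sample would be 1/3.
assert sum(1 for s in _demo_sentences if s.endswith('?')) / len(_demo_sentences) == 1 / 3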
|
||||
|
||||
|
||||
def build_corpus_profile(corpus_dir: Path) -> StyleSignals | None:
|
||||
"""Aggregate style signals across all corpus samples into a target profile."""
|
||||
samples = list(corpus_dir.glob("*.txt"))
|
||||
if not samples:
|
||||
return None
|
||||
|
||||
all_signals = [extract_signals(p.read_text(encoding="utf-8")) for p in samples]
|
||||
n = len(all_signals)
|
||||
|
||||
return StyleSignals(
|
||||
sentence_count=int(sum(s.sentence_count for s in all_signals) / n),
|
||||
word_count=int(sum(s.word_count for s in all_signals) / n),
|
||||
avg_sentence_length=sum(s.avg_sentence_length for s in all_signals) / n,
|
||||
em_dash_count=int(sum(s.em_dash_count for s in all_signals) / n),
|
||||
semicolon_count=int(sum(s.semicolon_count for s in all_signals) / n),
|
||||
question_ratio=sum(s.question_ratio for s in all_signals) / n,
|
||||
first_person_ratio=sum(s.first_person_ratio for s in all_signals) / n,
|
||||
avg_word_length=sum(s.avg_word_length for s in all_signals) / n,
|
||||
)
|
||||
|
||||
|
||||
def score_against_profile(output_signals: StyleSignals, profile: StyleSignals | None) -> float:
|
||||
"""Score a model output against the corpus profile. Returns 0-100.
|
||||
|
||||
Penalties:
|
||||
- Em dashes: -5 each; semicolons: -3 each (hard CF style violations)
|
||||
- Filler phrases: -8 each hit (strong signal of non-voice output)
|
||||
- Sentence length delta: proportional penalty (target: close to corpus avg)
|
||||
- Question ratio delta: smaller penalty (capped at 10 points)
|
||||
|
||||
When no corpus profile is available, falls back to absolute signal scores only.
|
||||
"""
|
||||
score = 100.0
|
||||
|
||||
# Hard violations -- always penalised regardless of corpus
|
||||
score -= output_signals.em_dash_count * 5
|
||||
score -= output_signals.semicolon_count * 3
|
||||
score -= len(output_signals.filler_hits) * 8
|
||||
|
||||
if profile is not None:
|
||||
# Sentence length delta: penalise proportionally
|
||||
length_delta = abs(output_signals.avg_sentence_length - profile.avg_sentence_length)
|
||||
score -= min(length_delta * 2, 20)
|
||||
|
||||
# Question ratio delta
|
||||
question_delta = abs(output_signals.question_ratio - profile.question_ratio)
|
||||
score -= min(question_delta * 10, 10)
|
||||
|
||||
return max(0.0, score)
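
# Worked example (illustrative sketch, not part of the original harness).
# example_score() is a hypothetical helper that restates the arithmetic of
# score_against_profile() above with plain numbers, assuming the same weights:
# 5 per em dash, 3 per semicolon, 8 per filler hit, plus the two capped deltas.
def example_score(em_dashes, semicolons, fillers,
                  sentence_len_delta=0.0, question_delta=0.0):
    score = 100.0
    score -= em_dashes * 5
    score -= semicolons * 3
    score -= fillers * 8
    score -= min(sentence_len_delta * 2, 20)   # capped at 20 points
    score -= min(question_delta * 10, 10)      # capped at 10 points
    return max(0.0, score)

# One em dash, two semicolons, one filler phrase, and sentences 4 words longer
# than the corpus average: 100 - 5 - 6 - 8 - 8 = 73
assert example_score(1, 2, 1, sentence_len_delta=4.0) == 73.0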
|
||||
|
||||
|
||||
# ── Generation backends (cf-orch, ollama) ─────────────────────────────────────
|
||||
|
||||
_CFORCH_NODE_ID = "heimdall"
|
||||
|
||||
|
||||
def cforch_list_catalog(
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
node_id: str = _CFORCH_NODE_ID,
|
||||
) -> dict[str, int]:
|
||||
"""Return the cf-text catalog from cf-orch as {model_id: vram_mb}.
|
||||
|
||||
Uses ?node_id= to request the catalog from a specific node's profile,
|
||||
avoiding cross-node catalog shadowing when multiple nodes define catalogs
|
||||
for the same service.
|
||||
"""
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{cforch_url}/api/services/cf-text/catalog",
|
||||
params={"node_id": node_id} if node_id else {},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()
|
||||
return {
|
||||
model_id: (entry.get("vram_mb", 0) if isinstance(entry, dict) else 0)
|
||||
for model_id, entry in raw.items()
|
||||
}
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach cf-orch catalog at {cforch_url}: {exc}", file=sys.stderr)
|
||||
return {}
|
||||
|
||||
|
||||
def _cforch_allocate_service(
|
||||
service: str,
|
||||
model_id: str,
|
||||
cforch_url: str,
|
||||
startup_timeout_s: float,
|
||||
health_path: str,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Generic cf-orch allocate + health-poll. Returns (service_url, allocation_id) or None."""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{cforch_url}/api/services/{service}/allocate",
|
||||
json={
|
||||
"model_candidates": [model_id],
|
||||
"caller": "avocet",
|
||||
"pipeline": "voice_benchmark",
|
||||
},
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
service_url: str = data["url"]
|
||||
allocation_id: str = data.get("allocation_id", "")
|
||||
|
||||
if data.get("started", False) and not data.get("warm", True):
|
||||
label = service
|
||||
print(f" [cold start] waiting for {label} to load {model_id!r}...", end=" ", flush=True)
|
||||
deadline = time.monotonic() + startup_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
health = httpx.get(f"{service_url}{health_path}", timeout=3.0)
|
||||
if health.is_success:
|
||||
print(f"ready ({time.monotonic() - (deadline - startup_timeout_s):.0f}s)", flush=True)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2.0)
|
||||
else:
|
||||
print(f"timed out after {startup_timeout_s:.0f}s", flush=True)
|
||||
return None
|
||||
|
||||
return service_url, allocation_id
|
||||
except Exception as exc:
|
||||
print(f"[warn] cf-orch allocation failed for {model_id!r} ({service}): {exc}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def cforch_allocate(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 180.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a cf-text instance for model_id. Returns (service_url, allocation_id) or None."""
|
||||
return _cforch_allocate_service("cf-text", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_allocate_vllm(
|
||||
model_id: str,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
startup_timeout_s: float = 300.0,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Allocate a vllm instance for model_id. Returns (service_url, allocation_id) or None.
|
||||
|
||||
vllm exposes an OpenAI-compatible API — generate_cftext() works unchanged
|
||||
against the returned service_url. Startup timeout is longer (300s) because
|
||||
vllm loads large model weights from disk before becoming ready.
|
||||
"""
|
||||
return _cforch_allocate_service("vllm", model_id, cforch_url, startup_timeout_s, "/health")
|
||||
|
||||
|
||||
def cforch_release(allocation_id: str, cforch_url: str = _CFORCH_URL) -> None:
|
||||
"""Release a cf-orch allocation."""
|
||||
if not allocation_id:
|
||||
return
|
||||
try:
|
||||
httpx.post(f"{cforch_url}/api/leases/{allocation_id}/release", timeout=10.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def generate_cftext(
|
||||
service_url: str,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
) -> tuple[str, float]:
|
||||
"""Call cf-text via OpenAI-compatible /v1/chat/completions. Returns (text, elapsed_ms)."""
|
||||
messages: list[dict[str, str]] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"max_tokens": GENERATION_PARAMS.get("num_predict", 300),
|
||||
"temperature": GENERATION_PARAMS.get("temperature", 0.7),
|
||||
"top_p": GENERATION_PARAMS.get("top_p", 0.9),
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{service_url.rstrip('/')}/v1/chat/completions",
|
||||
json=payload,
|
||||
timeout=180.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
content = resp.json()["choices"][0]["message"]["content"]
|
||||
return content.strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def generate(model_id: str, prompt: str, system: str = "") -> tuple[str, float]:
|
||||
"""Call ollama /api/generate. Returns (text, elapsed_ms)."""
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": GENERATION_PARAMS,
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_OLLAMA_URL}/api/generate",
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return resp.json().get("response", "").strip(), elapsed_ms
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
return f"[ERROR: {exc}]", elapsed_ms
|
||||
|
||||
|
||||
def find_disk_ggufs(llm_root: Path) -> list[Path]:
|
||||
"""Recursively find .gguf files under known subdirs of llm_root.
|
||||
|
||||
Skips vocab-only GGUFs (ggml-vocab-*) which aren't standalone models.
|
||||
"""
|
||||
found: list[Path] = []
|
||||
search_dirs = [llm_root / sub for sub in _SCAN_SUBDIRS] + [llm_root]
|
||||
seen: set[Path] = set()
|
||||
for base in search_dirs:
|
||||
if not base.exists():
|
||||
continue
|
||||
for gguf in base.rglob("*.gguf"):
|
||||
if gguf in seen:
|
||||
continue
|
||||
seen.add(gguf)
|
||||
if gguf.name.startswith("ggml-vocab-"):
|
||||
continue
|
||||
found.append(gguf)
|
||||
return sorted(found)
|
||||
|
||||
|
||||
def gguf_to_ollama_tag(gguf_path: Path) -> str:
|
||||
"""Derive a stable ollama tag from a GGUF path.
|
||||
|
||||
Uses parent dir name + stem to avoid collisions, e.g.:
|
||||
claude-3.7-sonnet-reasoning-gemma3-12B/foo.Q8_0.gguf
|
||||
→ bench-claude-3-7-sonnet-reasoning-gemma3-12b-foo-q8-0:latest
|
||||
"""
|
||||
parent = gguf_path.parent.name.lower()
|
||||
stem = gguf_path.stem.lower()
|
||||
# If stem is contained in parent (common pattern), just use parent
|
||||
slug = parent if stem.replace("-", "").replace("_", "") in parent.replace("-", "").replace("_", "") else f"{parent}-{stem}"
|
||||
slug = re.sub(r"[^a-z0-9]+", "-", slug).strip("-")
|
||||
return f"bench-{slug}:latest"
|
||||
|
||||
|
||||
def register_gguf(gguf_path: Path, tag: str) -> bool:
|
||||
"""Create a temporary ollama model entry from a GGUF file. Returns True on success."""
|
||||
import subprocess
|
||||
import tempfile
|
||||
modelfile = f"FROM {gguf_path.resolve()}\n"
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".Modelfile", delete=False) as f:
|
||||
f.write(modelfile)
|
||||
modelfile_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ollama", "create", tag, "-f", modelfile_path],
|
||||
capture_output=True, text=True, timeout=60,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not register {gguf_path.name}: {exc}", file=sys.stderr)
|
||||
return False
|
||||
finally:
|
||||
Path(modelfile_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def deregister_gguf(tag: str) -> None:
|
||||
"""Remove a temporary ollama model entry."""
|
||||
import subprocess
|
||||
try:
|
||||
subprocess.run(["ollama", "rm", tag], capture_output=True, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def backfill_disk_models(
|
||||
llm_root: Path,
|
||||
existing_tags: set[str],
|
||||
max_vram_mb: int = 0,
|
||||
) -> list[str]:
|
||||
"""Register GGUFs from disk that aren't already in ollama. Returns new tags.
|
||||
|
||||
max_vram_mb: skip files whose size exceeds this threshold (0 = no limit).
|
||||
GGUF file size is a rough lower bound on VRAM -- quantized weights load ~1:1, and the KV cache adds more on top.
|
||||
"""
|
||||
ggufs = find_disk_ggufs(llm_root)
|
||||
if not ggufs:
|
||||
print(f"No .gguf files found under {llm_root}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
new_tags: list[str] = []
|
||||
skipped_oom = 0
|
||||
for gguf in ggufs:
|
||||
size_mb = gguf.stat().st_size // (1024 * 1024)
|
||||
if max_vram_mb and size_mb > max_vram_mb:
|
||||
print(f" [skip-oom] {gguf.name} ({size_mb} MB > {max_vram_mb} MB limit)")
|
||||
skipped_oom += 1
|
||||
continue
|
||||
tag = gguf_to_ollama_tag(gguf)
|
||||
if tag in existing_tags:
|
||||
print(f" [skip] {gguf.name} already registered as {tag}")
|
||||
continue
|
||||
print(f" [register] {gguf.name} ({size_mb} MB) → {tag} ...", end=" ", flush=True)
|
||||
if register_gguf(gguf, tag):
|
||||
print("ok")
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
print("failed")
|
||||
|
||||
if skipped_oom:
|
||||
print(f" [info] {skipped_oom} GGUF(s) skipped (exceed {max_vram_mb} MB VRAM limit)")
|
||||
return new_tags
|
||||
|
||||
|
||||
def list_ollama_models() -> list[str]:
|
||||
"""Return model names from ollama /api/tags, filtered to text-gen candidates."""
|
||||
try:
|
||||
resp = httpx.get(f"{_OLLAMA_URL}/api/tags", timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
models = resp.json().get("models", [])
|
||||
# Exclude embedding-only models
|
||||
exclude = {"mxbai-embed-large", "nomic-embed-text", "all-minilm"}
|
||||
return [
|
||||
m["name"] for m in models
|
||||
if not any(x in m["name"].lower() for x in exclude)
|
||||
]
|
||||
except Exception as exc:
|
||||
print(f"[warn] Could not reach ollama: {exc}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
|
||||
# ── Run benchmark ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ModelResult:
|
||||
model_id: str
|
||||
prompt_results: list[dict[str, Any]] = field(default_factory=list)
|
||||
avg_score: float = 0.0
|
||||
avg_latency_ms: float = 0.0
|
||||
total_filler_hits: int = 0
|
||||
total_em_dashes: int = 0
|
||||
total_semicolons: int = 0
|
||||
|
||||
|
||||
def _bench_one_model(
|
||||
model_id: str,
|
||||
prompts: list[dict[str, str]],
|
||||
profile: Any,
|
||||
use_cforch: bool,
|
||||
cforch_url: str,
|
||||
use_vllm: bool = False,
|
||||
) -> "ModelResult | None":
|
||||
"""Run all prompts for a single model. Thread-safe — all output is prefixed with model_id.
|
||||
|
||||
Dispatch priority:
|
||||
use_vllm=True → allocate vllm via cf-orch, then generate_cftext() (OpenAI-compatible)
|
||||
use_cforch=True → allocate cf-text via cf-orch, then generate_cftext()
|
||||
else → direct ollama generate()
|
||||
Both vllm and cf-text expose /v1/chat/completions so generate_cftext() works for both.
|
||||
"""
|
||||
prefix = f"[{model_id}]"
|
||||
result = ModelResult(model_id=model_id)
|
||||
|
||||
service_url: str | None = None
|
||||
allocation_id: str = ""
|
||||
if use_vllm:
|
||||
alloc = cforch_allocate_vllm(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] vllm allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} vllm allocated: {service_url}", flush=True)
|
||||
elif use_cforch:
|
||||
alloc = cforch_allocate(model_id, cforch_url)
|
||||
if alloc is None:
|
||||
print(f"{prefix} [skip] cf-orch allocation failed", flush=True)
|
||||
return None
|
||||
service_url, allocation_id = alloc
|
||||
print(f"{prefix} allocated: {service_url}", flush=True)
|
||||
|
||||
try:
|
||||
for prompt_def in prompts:
|
||||
tag = prompt_def["tag"]
|
||||
user_prompt = (
|
||||
f"Thread: {prompt_def['thread_title']}\n\n"
|
||||
f"{prompt_def['thread_body']}\n\n"
|
||||
"Write a reply:"
|
||||
)
|
||||
print(f"{prefix} [{tag}] generating...", flush=True)
|
||||
|
||||
if (use_cforch or use_vllm) and service_url:
|
||||
# Both cf-text and vllm expose /v1/chat/completions — same call
|
||||
output, elapsed_ms = generate_cftext(service_url, model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
else:
|
||||
output, elapsed_ms = generate(model_id, user_prompt, system=SYSTEM_PROMPT)
|
||||
|
||||
signals = extract_signals(output)
|
||||
score = score_against_profile(signals, profile)
|
||||
|
||||
print(f"{prefix} [{tag}] {score:.0f}/100 ({elapsed_ms:.0f}ms)", flush=True)
|
||||
if signals.filler_hits:
|
||||
print(f"{prefix} ⚠ filler: {signals.filler_hits}", flush=True)
|
||||
if signals.em_dash_count:
|
||||
print(f"{prefix} ⚠ em-dashes: {signals.em_dash_count}", flush=True)
|
||||
|
||||
result.prompt_results.append({
|
||||
"tag": tag,
|
||||
"user_prompt": user_prompt,
|
||||
"output": output,
|
||||
"signals": {
|
||||
"avg_sentence_length": signals.avg_sentence_length,
|
||||
"em_dash_count": signals.em_dash_count,
|
||||
"semicolon_count": signals.semicolon_count,
|
||||
"filler_hits": signals.filler_hits,
|
||||
"question_ratio": signals.question_ratio,
|
||||
"word_count": signals.word_count,
|
||||
},
|
||||
"score": score,
|
||||
"latency_ms": elapsed_ms,
|
||||
})
|
||||
finally:
|
||||
if allocation_id:  # release cf-text and vllm allocations alike; both are made via cf-orch
|
||||
cforch_release(allocation_id, cforch_url)
|
||||
|
||||
if not result.prompt_results:
|
||||
return None
|
||||
|
||||
scores = [r["score"] for r in result.prompt_results]
|
||||
latencies = [r["latency_ms"] for r in result.prompt_results]
|
||||
result.avg_score = sum(scores) / len(scores)
|
||||
result.avg_latency_ms = sum(latencies) / len(latencies)
|
||||
result.total_filler_hits = sum(len(r["signals"]["filler_hits"]) for r in result.prompt_results)
|
||||
result.total_em_dashes = sum(r["signals"]["em_dash_count"] for r in result.prompt_results)
|
||||
result.total_semicolons = sum(r["signals"]["semicolon_count"] for r in result.prompt_results)
|
||||
|
||||
print(f"{prefix} done — avg score {result.avg_score:.0f}/100", flush=True)
|
||||
return result
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_ids: list[str],
|
||||
corpus_dir: Path,
|
||||
prompts: list[dict[str, str]],
|
||||
use_cforch: bool = False,
|
||||
use_vllm: bool = False,
|
||||
cforch_url: str = _CFORCH_URL,
|
||||
workers: int = 1,
|
||||
) -> list[ModelResult]:
|
||||
profile = build_corpus_profile(corpus_dir)
|
||||
if profile:
|
||||
print(f"Corpus profile loaded from {corpus_dir} ({len(list(corpus_dir.glob('*.txt')))} samples)")
|
||||
print(f" Target avg sentence length: {profile.avg_sentence_length:.1f} words")
|
||||
else:
|
||||
print(f"[warn] No corpus samples found in {corpus_dir} -- scoring on hard violations only")
|
||||
|
||||
backend = "vllm via cf-orch" if use_vllm else ("cf-text via cf-orch" if use_cforch else "ollama")
|
||||
print(f" Backend: {backend}")
|
||||
|
||||
effective_workers = min(workers, len(model_ids)) if model_ids else 1
|
||||
print(f" Workers: {effective_workers} (of {len(model_ids)} models)", flush=True)
|
||||
|
||||
results: list[ModelResult] = []
|
||||
|
||||
if effective_workers <= 1:
|
||||
# Sequential path — simpler output, easier to follow for single-model runs
|
||||
for model_id in model_ids:
|
||||
print(f"\n{'='*60}\nModel: {model_id}", flush=True)
|
||||
r = _bench_one_model(model_id, prompts, profile, use_cforch, cforch_url, use_vllm)
|
||||
if r:
|
||||
results.append(r)
|
||||
else:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
print(f" Fanning out {len(model_ids)} models across {effective_workers} workers...", flush=True)
|
||||
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(_bench_one_model, mid, prompts, profile, use_cforch, cforch_url, use_vllm): mid
|
||||
for mid in model_ids
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
r = future.result()
|
||||
if r:
|
||||
results.append(r)
|
||||
|
||||
return sorted(results, key=lambda r: r.avg_score, reverse=True)
|
||||
|
||||
|
||||
# ── Markdown report ───────────────────────────────────────────────────────────
|
||||
|
||||
def render_report(results: list[ModelResult], corpus_dir: Path) -> str:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
lines: list[str] = [
|
||||
f"# Voice Benchmark Results",
|
||||
f"",
|
||||
f"**Date:** {date_str} ",
|
||||
f"**Corpus:** `{corpus_dir}` ",
|
||||
f"**Models tested:** {len(results)} ",
|
||||
f"**Prompts per model:** {len(TEST_PROMPTS)}",
|
||||
f"",
|
||||
f"## Rankings",
|
||||
f"",
|
||||
f"| Rank | Model | Score | Latency | Em-dashes | Fillers | Semicolons |",
|
||||
f"|------|-------|-------|---------|-----------|---------|------------|",
|
||||
]
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
medal = {1: "🥇", 2: "🥈", 3: "🥉"}.get(i, f"#{i}")
|
||||
lines.append(
|
||||
f"| {medal} | `{r.model_id}` | {r.avg_score:.0f}/100 "
|
||||
f"| {r.avg_latency_ms:.0f}ms "
|
||||
f"| {r.total_em_dashes} "
|
||||
f"| {r.total_filler_hits} "
|
||||
f"| {r.total_semicolons} |"
|
||||
)
|
||||
|
||||
lines += ["", "## Sample Outputs", ""]
|
||||
|
||||
for r in results[:3]: # top 3 only to keep report readable
|
||||
lines += [f"### `{r.model_id}` (avg score: {r.avg_score:.0f})", ""]
|
||||
for pr in r.prompt_results:
|
||||
lines += [
|
||||
f"**Prompt:** {pr['tag']} ",
|
||||
f"**Score:** {pr['score']:.0f}/100 ",
|
||||
f"",
|
||||
f"```",
|
||||
pr["output"],
|
||||
f"```",
|
||||
f"",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def save_report(results: list[ModelResult], corpus_dir: Path) -> Path:
|
||||
_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
date_str = datetime.now().strftime("%Y-%m-%d_%H%M")
|
||||
report_path = _RESULTS_DIR / f"voice_{date_str}.md"
|
||||
report_path.write_text(render_report(results, corpus_dir), encoding="utf-8")
|
||||
|
||||
# Also save raw JSON for programmatic use
|
||||
json_path = _RESULTS_DIR / f"voice_{date_str}.json"
|
||||
json_path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"model_id": r.model_id,
|
||||
"avg_score": r.avg_score,
|
||||
"avg_latency_ms": r.avg_latency_ms,
|
||||
"total_filler_hits": r.total_filler_hits,
|
||||
"total_em_dashes": r.total_em_dashes,
|
||||
"total_semicolons": r.total_semicolons,
|
||||
"prompt_results": r.prompt_results,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
return report_path
|
||||
|
||||
|
||||
# ── CLI commands ──────────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_list_models(_args: argparse.Namespace) -> None:
|
||||
models = list_ollama_models()
|
||||
if not models:
|
||||
print("No models found (is ollama running?)")
|
||||
return
|
||||
print(f"{len(models)} models available:\n")
|
||||
for m in models:
|
||||
print(f" {m}")
|
||||
|
||||
|
||||
def cmd_run(args: argparse.Namespace) -> None:
|
||||
corpus_dir = Path(args.samples)
|
||||
if not corpus_dir.exists():
|
||||
print(f"[error] Corpus directory not found: {corpus_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
max_vram_mb: int = getattr(args, "max_vram", 7200)
|
||||
use_cforch: bool = getattr(args, "cforch", False)
|
||||
use_vllm: bool = getattr(args, "vllm", False)
|
||||
cforch_url: str = getattr(args, "cforch_url", _CFORCH_URL)
|
||||
registered_tags: list[str] = []
|
||||
|
||||
def _filter_ollama_by_size(ids: list[str], include_large: bool) -> list[str]:
|
||||
"""Apply name-pattern size filter to ollama model list."""
|
||||
if include_large:
|
||||
return ids
|
||||
skip_patterns = ["270b", "70b", "32b", "30b", "21b", "20b", "deepseek-r1"]
|
||||
filtered = [m for m in ids if not any(p in m.lower() for p in skip_patterns)]
|
||||
skipped = len(ids) - len(filtered)
|
||||
if skipped:
|
||||
print(f"[info] Skipped {skipped} large model(s) by name pattern. "
|
||||
"Pass --include-large to include them.")
|
||||
return filtered
|
||||
|
||||
if args.models and args.models != "all":
|
||||
model_ids = [m.strip() for m in args.models.split(",") if m.strip()]
|
||||
elif use_cforch:
|
||||
# cf-orch path: pull model list from catalog, filter by vram_mb
|
||||
catalog = cforch_list_catalog(cforch_url)
|
||||
if not catalog:
|
||||
print("[warn] cf-orch catalog empty or unreachable -- falling back to ollama models")
|
||||
use_cforch = False
|
||||
model_ids = _filter_ollama_by_size(list_ollama_models(), args.include_large)
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
before = list(catalog.items())
|
||||
allowed = {mid: mb for mid, mb in before if mb == 0 or mb <= max_vram_mb}
|
||||
skipped_oom = {mid: mb for mid, mb in before if mid not in allowed}
|
||||
model_ids = list(allowed.keys())
|
||||
print(f"[info] cf-orch catalog: {len(before)} model(s), "
|
||||
f"{len(allowed)} within {max_vram_mb} MB VRAM limit")
|
||||
if skipped_oom:
|
||||
print(f"[info] Skipped (OOM risk): "
|
||||
+ ", ".join(f"{mid} ({mb} MB)" for mid, mb in sorted(skipped_oom.items())))
|
||||
else:
|
||||
# Ollama path
|
||||
model_ids = list_ollama_models()
|
||||
if not model_ids:
|
||||
print("[error] No models found. Pass --models explicitly or check ollama.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Backfill GGUFs from disk before filtering -- skips files that exceed VRAM limit
|
||||
if getattr(args, "scan_disk", None):
|
||||
llm_root = Path(args.scan_disk)
|
||||
print(f"\nScanning {llm_root} for unregistered GGUFs (limit: {max_vram_mb} MB)...")
|
||||
registered_tags = backfill_disk_models(llm_root, set(model_ids), max_vram_mb=max_vram_mb)
|
||||
model_ids = list_ollama_models() # re-fetch with new registrations
|
||||
|
||||
model_ids = _filter_ollama_by_size(model_ids, args.include_large)
|
||||
|
||||
print(f"\nRunning voice benchmark on {len(model_ids)} model(s)...")
|
||||
try:
|
||||
results = run_benchmark(model_ids, corpus_dir, TEST_PROMPTS, use_cforch=use_cforch, use_vllm=use_vllm, cforch_url=cforch_url, workers=args.workers)
|
||||
report_path = save_report(results, corpus_dir)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results saved to: {report_path}")
|
||||
print(f"\n{render_report(results, corpus_dir)}")
|
||||
finally:
|
||||
if registered_tags:
|
||||
print(f"\nCleaning up {len(registered_tags)} temporary ollama registrations...")
|
||||
for tag in registered_tags:
|
||||
deregister_gguf(tag)
|
||||
|
||||
|
||||
def cmd_show_last(_args: argparse.Namespace) -> None:
|
||||
reports = sorted(_RESULTS_DIR.glob("voice_*.md"), reverse=True)
|
||||
if not reports:
|
||||
print("No benchmark results found. Run --run first.")
|
||||
return
|
||||
print(reports[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Voice benchmark harness for local text-gen models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="cmd")
|
||||
|
||||
sub.add_parser("list-models", help="List available ollama models")
|
||||
|
||||
run_p = sub.add_parser("run", help="Run the benchmark")
|
||||
run_p.add_argument("--models", default="all", help="Comma-separated model IDs, or 'all'")
|
||||
run_p.add_argument("--samples", default=str(_CORPUS_DIR), help="Path to voice corpus directory")
|
||||
run_p.add_argument("--include-large", action="store_true", help="Include large models skipped by default (names matching 20b and up, or deepseek-r1)")
|
||||
run_p.add_argument("--scan-disk", metavar="LLM_ROOT", help="Scan directory for GGUFs not yet in ollama (e.g. /Library/Assets/LLM)")
|
||||
run_p.add_argument("--cforch", action="store_true", help="Route generation through cf-orch/cf-text instead of direct ollama")
|
||||
run_p.add_argument("--vllm", action="store_true", help="Route generation through cf-orch/vllm (OpenAI-compatible) instead of ollama")
|
||||
run_p.add_argument("--cforch-url", default=_CFORCH_URL, help=f"cf-orch coordinator URL (default: {_CFORCH_URL})")
|
||||
run_p.add_argument("--max-vram", type=int, default=7200, metavar="MB",
|
||||
help="Skip models whose VRAM footprint exceeds this limit in MB (default: 7200)")
|
||||
run_p.add_argument("--workers", type=int, default=1, metavar="N",
|
||||
help="Parallel workers — run N models simultaneously (default: 1; use 4+ with cf-orch)")
|
||||
|
||||
sub.add_parser("show-last", help="Print the most recent benchmark report")
|
||||
|
||||
# Also support legacy --list-models / --run / --show-last flags for manage.sh compat
|
||||
parser.add_argument("--list-models", action="store_true")
|
||||
parser.add_argument("--run", action="store_true")
|
||||
parser.add_argument("--show-last", action="store_true")
|
||||
parser.add_argument("--models", default="all")
|
||||
parser.add_argument("--samples", default=str(_CORPUS_DIR))
|
||||
parser.add_argument("--include-large", action="store_true")
|
||||
parser.add_argument("--scan-disk", metavar="LLM_ROOT")
|
||||
parser.add_argument("--cforch", action="store_true")
|
||||
parser.add_argument("--vllm", action="store_true")
|
||||
parser.add_argument("--cforch-url", default=_CFORCH_URL)
|
||||
parser.add_argument("--max-vram", type=int, default=7200, metavar="MB")
|
||||
parser.add_argument("--workers", type=int, default=1, metavar="N")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cmd == "list-models" or args.list_models:
|
||||
cmd_list_models(args)
|
||||
elif args.cmd == "run" or args.run:
|
||||
cmd_run(args)
|
||||
elif args.cmd == "show-last" or args.show_last:
|
||||
cmd_show_last(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -7,26 +7,19 @@ from __future__ import annotations
|
|||
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
import httpx
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"LABELS",
|
||||
"LABEL_DESCRIPTIONS",
|
||||
"DEFAULT_EXEMPLARS",
|
||||
"compute_metrics",
|
||||
"ClassifierAdapter",
|
||||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
"EmbeddingKNNAdapter",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
LABELS: list[str] = [
|
||||
"interview_scheduled",
|
||||
"offer_received",
|
||||
|
|
@ -124,81 +117,6 @@ def compute_metrics(
|
|||
return result
|
||||
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
|
||||
|
||||
|
||||
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
|
||||
"interview_scheduled": [
|
||||
"Subject: Interview Invitation\n\nWe would like to invite you for a phone screen next week.",
|
||||
"Subject: Schedule a call\n\nCould you be available for a video interview on Tuesday?",
|
||||
"Subject: Next Steps\n\nWe'd like to move forward with a technical interview. Please select a time.",
|
||||
"Subject: Interview Details\n\nHere are the dial-in instructions for your interview tomorrow.",
|
||||
],
|
||||
"offer_received": [
|
||||
"Subject: Offer Letter Enclosed\n\nWe are pleased to extend you an offer of employment.",
|
||||
"Subject: Job Offer\n\nDear candidate, we are excited to offer you the position of Software Engineer.",
|
||||
"Subject: Employment Offer\n\nPlease find attached your formal offer letter and compensation details.",
|
||||
"Subject: Offer of Employment\n\nCongratulations! We would like to offer you a full-time position.",
|
||||
],
|
||||
"rejected": [
|
||||
"Subject: Your Application\n\nAfter careful consideration, we have decided to move forward with other candidates.",
|
||||
"Subject: Application Status\n\nWe regret to inform you that your application has not been selected.",
|
||||
"Subject: Thank you for applying\n\nWe appreciate your interest but have chosen not to proceed.",
|
||||
"Subject: Update on your candidacy\n\nWe will not be moving forward with your application at this time.",
|
||||
],
|
||||
"positive_response": [
|
||||
"Subject: Your profile\n\nI came across your LinkedIn and think you would be a great fit for our team.",
|
||||
"Subject: Exciting opportunity\n\nWe were impressed by your background and would love to connect.",
|
||||
"Subject: Following up\n\nThank you for your interest — we'd like to learn more about your experience.",
|
||||
"Subject: Great fit\n\nYour skills align well with what we are looking for. Let's set up a call.",
|
||||
],
|
||||
"survey_received": [
|
||||
"Subject: Candidate Experience Survey\n\nPlease complete this brief survey about your application experience.",
|
||||
"Subject: Culture Fit Assessment\n\nAs part of our process, we ask all candidates to complete a short assessment.",
|
||||
"Subject: Skills Assessment\n\nWe'd like you to complete our online coding assessment before proceeding.",
|
||||
"Subject: Personality Assessment\n\nPlease complete the following assessment as the next step in our process.",
|
||||
"Subject: Pre-interview questionnaire\n\nBefore we schedule your interview, please complete this brief skills survey.",
|
||||
],
|
||||
"neutral": [
|
||||
"Subject: Application Received\n\nWe have received your application and will be in touch.",
|
||||
"Subject: Thank you for applying\n\nYour application is under review. We will contact you if needed.",
|
||||
"Subject: Confirmation\n\nThis email confirms receipt of your application to our company.",
|
||||
"Subject: Application Confirmation\n\nThank you for your interest. We will review your materials and follow up.",
|
||||
],
|
||||
"event_rescheduled": [
|
||||
"Subject: Interview Rescheduled\n\nDue to a conflict, we need to move your interview to a new time.",
|
||||
"Subject: Change of interview time\n\nWe apologize — your interview has been rescheduled to Thursday.",
|
||||
"Subject: Updated interview details\n\nYour interview has been moved from Monday to Wednesday at 2pm.",
|
||||
"Subject: Reschedule request\n\nWould you be available to reschedule to a different time slot?",
|
||||
"Subject: New interview time\n\nYour phone screen has been moved from tomorrow to next week.",
|
||||
],
|
||||
"digest": [
|
||||
"Subject: 15 new jobs matching your search\n\nHere are the latest job postings that match your profile.",
|
||||
"Subject: Weekly Job Digest\n\nThis week's top opportunities for Software Engineers in your area.",
|
||||
"Subject: Jobs you might like\n\nBased on your profile, here are some positions we recommend.",
|
||||
"Subject: New jobs for you\n\nSee the latest openings from companies on your watchlist.",
|
||||
],
|
||||
"new_lead": [
|
||||
"Subject: Exciting opportunity at our company\n\nHi, I noticed your background and think you'd be a great fit.",
|
||||
"Subject: Are you open to new opportunities?\n\nI'm a recruiter reaching out about a role matching your experience.",
|
||||
"Subject: Quick question\n\nWould you be interested in hearing about a senior engineering role?",
|
||||
"Subject: Recruiting outreach\n\nI came across your profile and wanted to share an exciting opening.",
|
||||
],
|
||||
"hired": [
|
||||
"Subject: Welcome to the team!\n\nWe are thrilled to have you join us. Here are your onboarding details.",
|
||||
"Subject: Onboarding information\n\nCongratulations on accepting our offer. Your start date is confirmed.",
|
||||
"Subject: First day information\n\nWe look forward to your first day. Please arrive at 9am and ask for HR.",
|
||||
"Subject: Background check initiated\n\nAs part of your onboarding, we have initiated a background check.",
|
||||
"Subject: Equipment setup\n\nYour laptop and equipment will be ready for pickup on your first day.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ClassifierAdapter(abc.ABC):
|
||||
"""Abstract base for all email classifier adapters."""
|
||||
|
||||
|
|
@ -386,148 +304,3 @@ class FineTunedAdapter(ClassifierAdapter):
|
|||
text = f"{subject} [SEP] {body[:400]}"
|
||||
result = self._pipeline(text)
|
||||
return result[0]["label"]
|
||||
|
||||
|
||||
class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
|
||||
|
||||
load():
|
||||
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
|
||||
Falls back to ollama_url directly if orch allocation fails or is not configured.
|
||||
2. Pre-embeds all exemplar texts and stores per-label vector lists.
|
||||
|
||||
classify(subject, body):
|
||||
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||
with the highest total similarity score among tied vote counts wins.
|
||||
|
||||
unload():
|
||||
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_id: str,
|
||||
*,
|
||||
k: int = 3,
|
||||
orch_url: str = "",
|
||||
ollama_url: str = "",
|
||||
exemplar_texts: dict[str, list[str]] | None = None,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._k = k
|
||||
self._orch_url = orch_url
|
||||
self._ollama_url = ollama_url
|
||||
self._exemplar_texts: dict[str, list[str]] = (
|
||||
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
|
||||
)
|
||||
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
|
||||
self._node_url: str = ""
|
||||
self._allocation_id: str = ""
|
||||
self._orch_url_used: str = ""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def _resolve_urls(self) -> tuple[str, str]:
|
||||
if self._orch_url or self._ollama_url:
|
||||
return self._orch_url, self._ollama_url
|
||||
import yaml # noqa: PLC0415
|
||||
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
cforch = cfg.get("cforch", {}) or {}
|
||||
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||
|
||||
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
|
||||
resp = httpx.post(
|
||||
f"{node_url}/v1/embeddings",
|
||||
json={"model": self._model_id, "input": texts},
|
||||
timeout=30.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return [item["embedding"] for item in resp.json()["data"]]
|
||||
|
||||
def load(self) -> None:
|
||||
if self._allocation_id or self._exemplar_embeddings:
|
||||
raise RuntimeError(
|
||||
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
|
||||
)
|
||||
orch_url, ollama_url = self._resolve_urls()
|
||||
node_url = ""
|
||||
orch_url_used = ""
|
||||
if orch_url:
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{orch_url}/api/services/ollama/allocate",
|
||||
json={"model": self._model_id},
|
||||
timeout=15.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
node_url = data["url"]
|
||||
self._allocation_id = data["allocation_id"]
|
||||
orch_url_used = orch_url
|
||||
except Exception as exc:
|
||||
_logger.warning(
|
||||
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
|
||||
)
|
||||
if not node_url:
|
||||
node_url = ollama_url
|
||||
self._allocation_id = ""
|
||||
orch_url_used = ""
|
||||
self._node_url = node_url
|
||||
self._orch_url_used = orch_url_used
|
||||
try:
|
||||
embeddings: dict[str, list[list[float]]] = {}
|
||||
for label, texts in self._exemplar_texts.items():
|
||||
embeddings[label] = self._embed(node_url, texts)
|
||||
self._exemplar_embeddings = embeddings
|
||||
except Exception:
|
||||
self.unload()
|
||||
raise
|
||||
|
||||
def unload(self) -> None:
|
||||
if self._allocation_id and self._orch_url_used:
|
||||
try:
|
||||
httpx.request(
|
||||
"DELETE",
|
||||
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
|
||||
timeout=10.0,
|
||||
)
|
||||
except Exception:
|
||||
_logger.debug("cf-orch release failed; ignoring", exc_info=True)
|
||||
self._exemplar_embeddings = {}
|
||||
self._node_url = ""
|
||||
self._allocation_id = ""
|
||||
self._orch_url_used = ""
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if not self._exemplar_embeddings:
|
||||
self.load()
|
||||
text = f"Subject: {subject}\n\n{body[:600]}"
|
||||
[query_vec] = self._embed(self._node_url, [text])
|
||||
scored: list[tuple[float, str]] = [
|
||||
(_cosine(query_vec, vec), label)
|
||||
for label, vecs in self._exemplar_embeddings.items()
|
||||
for vec in vecs
|
||||
]
|
||||
top_k = sorted(scored, reverse=True)[: self._k]
|
||||
votes: dict[str, list[float]] = {}
|
||||
for score, label in top_k:
|
||||
votes.setdefault(label, []).append(score)
|
||||
return max(
|
||||
votes,
|
||||
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,458 +0,0 @@
|
|||
"""Export circuitforge-plans/ documents as instruction-tuning JSONL pairs.
|
||||
|
||||
Each record is a HuggingFace chat-format example:
|
||||
|
||||
{
|
||||
"id": "<sha256>",
|
||||
"messages": [
|
||||
{"role": "user", "content": "<reconstructed planning prompt>"},
|
||||
{"role": "assistant", "content": "<cleaned document content>"}
|
||||
],
|
||||
"meta": {
|
||||
"source": "peregrine/2026-03-03-feedback-button-design.md",
|
||||
"product": "peregrine",
|
||||
"doc_type": "design", # design | plan | spec | other ("implementation" maps to "plan")
|
||||
"date": "2026-03-03",
|
||||
"paired_with": "...", # sibling path, or null
|
||||
"word_count": 1847,
|
||||
"pair_role": "context" # "context" | "target" | "standalone"
|
||||
}
|
||||
}
|
||||
|
||||
Pairing strategy
|
||||
----------------
|
||||
When a design doc and a plan doc share the same date + feature-name prefix,
|
||||
they are treated as a pair:
|
||||
- design → plan: instruction = "Given this design doc, write the implementation plan."
|
||||
context appended = full design doc content.
|
||||
- Solo docs get a synthetic instruction from the title + first overview section.
|
||||
|
||||
Usage
|
||||
-----
|
||||
# Preview stats and 5 sample records
|
||||
python scripts/export_plans.py --preview
|
||||
|
||||
# Write full output
|
||||
python scripts/export_plans.py --output data/plan_pairs.jsonl
|
||||
|
||||
# Restrict to specific products
|
||||
python scripts/export_plans.py --products peregrine,kiwi --output data/plan_pairs.jsonl
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
# ── Paths ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_SCRIPT_DIR = Path(__file__).parent
|
||||
_AVOCET_ROOT = _SCRIPT_DIR.parent
|
||||
_DEFAULT_PLANS_DIR = Path("/Library/Development/CircuitForge/circuitforge-plans")
|
||||
_DEFAULT_OUTPUT = _AVOCET_ROOT / "data" / "plan_pairs.jsonl"
|
||||
|
||||
# ── Doc type detection ─────────────────────────────────────────────────────────
|
||||
|
||||
_TYPE_RE = re.compile(
|
||||
r"-(design|plan|spec|implementation|specs|plans)s?$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_SKIP_DIRS = {"__pycache__", ".git", "node_modules"}
|
||||
|
||||
# Boilerplate lines to strip from document content before using as output.
|
||||
_BOILERPLATE_RE = re.compile(
|
||||
r"""
|
||||
^\s*>\s*\*\*For\s+agentic\s+workers.* # superpowers agent hints
|
||||
|^\s*>\s*REQUIRED\s+SUB-SKILL.*
|
||||
|^\s*\*\*Date:\*\*.* # metadata header lines
|
||||
|\*\*Status:\*\*\s*Complete.* # completed-feature noise
|
||||
|\*\*Status:\*\*\s*Done.*
|
||||
|\*\*Product:\*\*.*
|
||||
|\*\*Repo:\*\*.*
|
||||
|\*\*Tech\s+Stack:\*\*.*
|
||||
|\*\*Candidate:\*\*.* # old synthetic personas
|
||||
|^Candidate:.*
|
||||
|^Team:.*
|
||||
""",
|
||||
re.VERBOSE | re.MULTILINE,
|
||||
)
|
||||
|
||||
# Old repo/path names to normalise to current equivalents.
|
||||
_PATH_NORMALIZATIONS: list[tuple[re.Pattern, str]] = [
|
||||
(re.compile(r"/devl/job-seeker", re.IGNORECASE), "/Library/Development/CircuitForge/peregrine"),
|
||||
(re.compile(r"\bjob-seeker\b", re.IGNORECASE), "peregrine"),
|
||||
(re.compile(r"Alex Rivera", re.IGNORECASE), "[user]"),
|
||||
]
|
||||
|
||||
# Instruction paraphrase templates per doc type.
|
||||
# Each entry is a str.format template; the paired templates embed the full design doc.
|
||||
# {title}, {product}, {type_phrase}, {overview}, {design_context} are substituted.
|
||||
_DESIGN_INSTRUCTIONS = [
|
||||
"Write a design document for {product}: {title}.\n\nContext: {overview}",
|
||||
"You are a software architect working on {product}. Draft a design spec for: {title}.\n\n{overview}",
|
||||
"Produce a CircuitForge-style design document for the following {product} feature — {title}.\n\nBackground: {overview}",
|
||||
]
|
||||
|
||||
_PLAN_INSTRUCTIONS = [
|
||||
"Write an implementation plan for {product}: {title}.\n\nContext: {overview}",
|
||||
"Break the following {product} feature into a detailed implementation plan with file structure and task checkboxes — {title}.\n\n{overview}",
|
||||
"You are a senior engineer on {product}. Produce a step-by-step engineering plan for: {title}.\n\n{overview}",
|
||||
]
|
||||
|
||||
_PAIRED_INSTRUCTIONS = [
|
||||
(
|
||||
"You are a software architect working on {product}, a CircuitForge product. "
|
||||
"Given the following design document, write a detailed implementation plan "
|
||||
"(file structure, task breakdown with checkboxes, migration steps if needed).\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"The following is a design spec for a {product} feature. "
|
||||
"Produce a concrete implementation plan: file list, task checklist, any DB migrations needed.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
(
|
||||
"Convert this {product} design document into an actionable implementation plan. "
|
||||
"Include all files to create/modify, step-by-step tasks with checkboxes, and migration steps.\n\n"
|
||||
"---\n{design_context}\n---"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _doc_type(stem: str) -> str:
|
||||
m = _TYPE_RE.search(stem)
|
||||
if not m:
|
||||
return "other"
|
||||
raw = m.group(1).lower().rstrip("s")
|
||||
return {"implementation": "plan"}.get(raw, raw)
|
||||
|
||||
|
||||
def _date_feature(stem: str) -> tuple[str, str]:
|
||||
"""Return (date, feature_slug) from '2026-03-03-feedback-button-design'."""
|
||||
m = re.match(r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$", stem, re.I)
|
||||
if m:
|
||||
return m.group(1), m.group(2)
|
||||
return "", stem
|
||||
|
||||
|
||||
# ── Content extraction ─────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_title(content: str) -> str:
|
||||
m = re.search(r"^#\s+(.+)", content, re.MULTILINE)
|
||||
return m.group(1).strip() if m else ""
|
||||
|
||||
|
||||
def _extract_overview(content: str) -> str:
    """Return the **Goal:** line if present, else the first h2 section body (≤300 chars)."""
|
||||
# Superpowers plans have an explicit **Goal:** line — prefer that.
|
||||
goal_m = re.search(r"\*\*Goal:\*\*\s*(.+)", content)
|
||||
if goal_m:
|
||||
return goal_m.group(1).strip()[:300]
|
||||
|
||||
# Otherwise use the body of the first h2 section.
|
||||
h2_m = re.search(
|
||||
r"^##\s+\d*\.?\s*.+\n([\s\S]+?)(?=^##|\Z)",
|
||||
content,
|
||||
re.MULTILINE,
|
||||
)
|
||||
if h2_m:
|
||||
body = h2_m.group(1).strip()
|
||||
# Strip markdown bullet/code noise for the instruction
|
||||
body = re.sub(r"```[\s\S]*?```", "", body)
|
||||
body = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), body)
|
||||
body = re.sub(r"\*\*([^*]+)\*\*", r"\1", body)
|
||||
body = re.sub(r"\s+", " ", body).strip()
|
||||
return body[:300]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _clean_content(content: str) -> str:
|
||||
"""Remove boilerplate, normalize old paths/names, collapse whitespace."""
|
||||
cleaned = _BOILERPLATE_RE.sub("", content)
|
||||
for pattern, replacement in _PATH_NORMALIZATIONS:
|
||||
cleaned = pattern.sub(replacement, cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _quality_flags(content: str) -> list[str]:
|
||||
"""Return a list of quality issue labels found in cleaned content."""
|
||||
flags = []
|
||||
if "Alex Rivera" in content or "[user]" in content:
|
||||
flags.append("persona-residue")
|
||||
if re.search(r"\bStatus:\s*(Complete|Done|Merged)\b", content):
|
||||
flags.append("completed-status")
|
||||
return flags
|
||||
|
||||
|
||||
def _make_instruction(
|
||||
title: str,
|
||||
product: str,
|
||||
doc_type: str,
|
||||
overview: str,
|
||||
design_context: str | None = None,
|
||||
variant: int = 0,
|
||||
) -> str:
|
||||
"""Synthesise a natural planning prompt for this document.
|
||||
|
||||
variant: 0-2 selects which paraphrase template to use. Caller cycles
|
||||
through all three to produce multiple training examples per document.
|
||||
"""
|
||||
product_label = product.replace("-", " ").title() if product else "CircuitForge"
|
||||
idx = variant % 3
|
||||
|
||||
if design_context:
|
||||
tmpl = _PAIRED_INSTRUCTIONS[idx]
|
||||
return tmpl.format(
|
||||
product=product_label,
|
||||
design_context=design_context[:2500],
|
||||
)
|
||||
|
||||
templates = _PLAN_INSTRUCTIONS if doc_type in ("plan",) else _DESIGN_INSTRUCTIONS
|
||||
tmpl = templates[idx]
|
||||
return tmpl.format(
|
||||
product=product_label,
|
||||
title=title,
|
||||
overview=overview or "",
|
||||
type_phrase="planning document",
|
||||
)
|
||||
|
||||
|
||||
def _record_id(content: str, source: str) -> str:
|
||||
return hashlib.sha256(f"{source}:{content}".encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
# ── Pair discovery ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _find_pairs(plans_dir: Path) -> dict[str, list[tuple[str, Path]]]:
|
||||
"""Return {prefix_key → [(doc_type, path), ...]} for docs sharing date+feature."""
|
||||
by_prefix: dict[str, list[tuple[str, Path]]] = {}
|
||||
for path in plans_dir.rglob("*.md"):
|
||||
if any(part in _SKIP_DIRS for part in path.parts):
|
||||
continue
|
||||
if path.name == "README.md":
|
||||
continue
|
||||
stem = path.stem
|
||||
date, feature = _date_feature(stem)
|
||||
if not date:
|
||||
continue
|
||||
key = str(path.parent / f"{date}-{feature}")
|
||||
by_prefix.setdefault(key, []).append((_doc_type(stem), path))
|
||||
return by_prefix
|
||||
|
||||
|
||||
# ── Record generation ──────────────────────────────────────────────────────────
|
||||
|
||||
def _records_for_group(
|
||||
doc_type_paths: list[tuple[str, Path]],
|
||||
plans_dir: Path,
|
||||
) -> Iterator[dict]:
|
||||
"""Yield one or more training records for a group of related docs."""
|
||||
# Separate design vs plan docs within this group
|
||||
designs = [(t, p) for t, p in doc_type_paths if t in ("design", "spec")]
|
||||
plans_ = [(t, p) for t, p in doc_type_paths if t in ("plan",)]
|
||||
others = [(t, p) for t, p in doc_type_paths if t not in ("design", "spec", "plan")]
|
||||
|
||||
all_paths = doc_type_paths
|
||||
|
||||
if designs and plans_:
|
||||
# Paired: yield a design→plan record (3 instruction variants)
|
||||
design_type, design_path = designs[0]
|
||||
plan_type, plan_path = plans_[0]
|
||||
design_content = design_path.read_text(encoding="utf-8")
|
||||
plan_content = plan_path.read_text(encoding="utf-8")
|
||||
|
||||
product = _product_from_path(plan_path, plans_dir)
|
||||
title = _extract_title(plan_content) or plan_path.stem
|
||||
cleaned = _clean_content(plan_content)
|
||||
design_cleaned = _clean_content(design_content)
|
||||
flags = _quality_flags(cleaned)
|
||||
|
||||
if len(cleaned.split()) >= 80:
|
||||
rel_src = str(plan_path.relative_to(plans_dir))
|
||||
rel_design = str(design_path.relative_to(plans_dir))
|
||||
for variant in range(3):
|
||||
instruction = _make_instruction(
|
||||
title=title,
|
||||
product=product,
|
||||
doc_type="plan",
|
||||
overview=_extract_overview(design_content),
|
||||
design_context=design_cleaned,
|
||||
variant=variant,
|
||||
)
|
||||
yield {
|
||||
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
|
||||
"messages": [
|
||||
{"role": "user", "content": instruction},
|
||||
{"role": "assistant", "content": cleaned},
|
||||
],
|
||||
"meta": {
|
||||
"source": rel_src,
|
||||
"product": product,
|
||||
"doc_type": "plan",
|
||||
"date": _date_feature(plan_path.stem)[0],
|
||||
"paired_with": rel_design,
|
||||
"word_count": len(cleaned.split()),
|
||||
"pair_role": "target",
|
||||
"variant": variant,
|
||||
"quality_flags": flags,
|
||||
},
|
||||
}
|
||||
|
||||
# Also yield the design doc as standalone variants
|
||||
all_paths = [(t, p) for t, p in all_paths if p != plan_path]
|
||||
|
||||
# Remaining docs as standalone records (3 instruction variants each)
|
||||
for doc_type, path in all_paths:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
cleaned = _clean_content(content)
|
||||
if len(cleaned.split()) < 80:
|
||||
continue
|
||||
|
||||
product = _product_from_path(path, plans_dir)
|
||||
title = _extract_title(content) or path.stem
|
||||
overview = _extract_overview(content)
|
||||
flags = _quality_flags(cleaned)
|
||||
rel_src = str(path.relative_to(plans_dir))
|
||||
|
||||
for variant in range(3):
|
||||
instruction = _make_instruction(
|
||||
title=title,
|
||||
product=product,
|
||||
doc_type=doc_type,
|
||||
overview=overview,
|
||||
variant=variant,
|
||||
)
|
||||
yield {
|
||||
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
|
||||
"messages": [
|
||||
{"role": "user", "content": instruction},
|
||||
{"role": "assistant", "content": cleaned},
|
||||
],
|
||||
"meta": {
|
||||
"source": rel_src,
|
||||
"product": product,
|
||||
"doc_type": doc_type,
|
||||
"date": _date_feature(path.stem)[0],
|
||||
"paired_with": None,
|
||||
"word_count": len(cleaned.split()),
|
||||
"pair_role": "standalone",
|
||||
"variant": variant,
|
||||
"quality_flags": flags,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _product_from_path(path: Path, plans_dir: Path) -> str:
|
||||
rel = path.relative_to(plans_dir)
|
||||
return rel.parts[0] if len(rel.parts) > 1 else "shared"
|
||||
|
||||
|
||||
# ── Main export ────────────────────────────────────────────────────────────────
|
||||
|
||||
def export(
|
||||
plans_dir: Path,
|
||||
products: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
groups = _find_pairs(plans_dir)
|
||||
records: list[dict] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for group_key, doc_type_paths in groups.items():
|
||||
# Filter by product if requested
|
||||
if products:
|
||||
paths = [p for _, p in doc_type_paths]
|
||||
prods = {_product_from_path(p, plans_dir) for p in paths}
|
||||
if not prods.intersection(products):
|
||||
continue
|
||||
|
||||
for record in _records_for_group(doc_type_paths, plans_dir):
|
||||
if record["id"] not in seen_ids:
|
||||
seen_ids.add(record["id"])
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _print_stats(records: list[dict]) -> None:
    from collections import Counter

    if not records:
        # Nothing matched (empty plans dir or an over-narrow --products filter);
        # return before indexing into an empty word-count list.
        print("\n Total records: 0\n")
        return

    products = Counter(r["meta"]["product"] for r in records)
    doc_types = Counter(r["meta"]["doc_type"] for r in records)
    pair_roles = Counter(r["meta"]["pair_role"] for r in records)
    wc = sorted(r["meta"]["word_count"] for r in records)

    print(f"\n{'='*55}")
    print(f" Total records: {len(records)}")
    print(f" Word counts : min={wc[0]}, median={wc[len(wc)//2]}, max={wc[-1]}")
    print(f"\n By product:")
    for p, n in products.most_common():
        print(f" {p:<22} {n}")
    print(f"\n By doc type:")
    for t, n in doc_types.most_common():
        print(f" {t:<22} {n}")
    print(f"\n Pair roles:")
    for r, n in pair_roles.most_common():
        print(f" {r:<22} {n}")
    print(f"{'='*55}\n")
|
||||
|
||||
|
||||
def _print_sample(records: list[dict], n: int = 3) -> None:
    import random

    sample = random.sample(records, min(n, len(records)))
    for i, rec in enumerate(sample, 1):
        meta = rec["meta"]
        user_msg = rec["messages"][0]["content"]
        asst_msg = rec["messages"][1]["content"]
        print(f"\n{'─'*55}")
        # Count against the number actually sampled, which may be fewer than n.
        print(f"SAMPLE {i}/{len(sample)} [{meta['product']} / {meta['doc_type']} / {meta['pair_role']}]")
        print(f"source: {meta['source']}")
        print(f"\nUSER ({len(user_msg)} chars):\n{user_msg[:500]}{'...' if len(user_msg)>500 else ''}")
        print(f"\nASSISTANT ({meta['word_count']} words):\n{asst_msg[:400]}{'...' if len(asst_msg)>400 else ''}")
    print(f"\n{'─'*55}\n")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser.add_argument("--plans-dir", type=Path, default=_DEFAULT_PLANS_DIR)
|
||||
parser.add_argument("--output", type=Path, default=None,
|
||||
help="Write JSONL to this path (omit for preview-only)")
|
||||
parser.add_argument("--products", default=None,
|
||||
help="Comma-separated product filter, e.g. peregrine,kiwi")
|
||||
parser.add_argument("--preview", action="store_true",
|
||||
help="Print stats + sample records, don't write output")
|
||||
parser.add_argument("--samples", type=int, default=3,
|
||||
help="Number of sample records to show in preview (default 3)")
|
||||
args = parser.parse_args()
|
||||
|
||||
products = [p.strip() for p in args.products.split(",")] if args.products else None
|
||||
|
||||
print(f"Scanning {args.plans_dir} …", file=sys.stderr)
|
||||
records = export(args.plans_dir, products=products)
|
||||
|
||||
_print_stats(records)
|
||||
|
||||
if args.preview or args.output is None:
|
||||
_print_sample(records, n=args.samples)
|
||||
if args.output is None:
|
||||
print("(Pass --output <path> to write JSONL)")
|
||||
return
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for rec in records:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Wrote {len(records)} records to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
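
The pair discovery in this script keys each document on its date plus feature slug, so a design doc and its plan land in the same group. A minimal sketch of that keying, reusing the stem regex from `_date_feature`; the helper names `date_feature` and `group_stems` and the sample stems are illustrative assumptions, not part of the script:

```python
import re

# Same pattern as _date_feature: date, lazy feature slug, optional doc-type suffix.
_STEM_RE = re.compile(
    r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$",
    re.I,
)


def date_feature(stem: str) -> tuple:
    """Return (date, feature_slug); date is empty when the stem has no date prefix."""
    m = _STEM_RE.match(stem)
    return (m.group(1), m.group(2)) if m else ("", stem)


def group_stems(stems: list) -> dict:
    """Group stems that share date + feature; undated stems are ignored."""
    groups: dict = {}
    for stem in stems:
        date, feature = date_feature(stem)
        if not date:
            continue
        groups.setdefault(f"{date}-{feature}", []).append(stem)
    return groups


# Hypothetical file stems for illustration.
stems = [
    "2026-03-03-feedback-button-design",
    "2026-03-03-feedback-button-plan",
    "README",
]
groups = group_stems(stems)
print(groups)
```

Because the feature group is lazy and the type suffix is optional, the suffix is peeled off only when it is the final hyphenated token, which is what lets the two related documents share one key.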
|
||||
|
|
@ -1,355 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Corpus gatherer for the voice benchmark fine-tune pipeline.
|
||||
|
||||
Pulls writing samples from multiple sources and drops .txt files into
|
||||
data/style_corpus/ in the format expected by benchmark_voice.py.
|
||||
|
||||
Sources:
|
||||
- Reddit: u/pyr0ball post history + comment history (public JSON API)
|
||||
- Campaign copy: claude-bridge/reddit-poster/campaigns/*.py (BODY strings)
|
||||
- Documents: brainmap, homeprojects notes, selected personal writing
|
||||
- Discord: requires manual export (see instructions below)
|
||||
|
||||
Usage:
|
||||
# Full gather (Reddit + local sources)
|
||||
conda run -n cf python scripts/gather_corpus.py
|
||||
|
||||
# Reddit only
|
||||
conda run -n cf python scripts/gather_corpus.py --source reddit
|
||||
|
||||
# Local files only (no network)
|
||||
conda run -n cf python scripts/gather_corpus.py --source local
|
||||
|
||||
# Process a Discord data export zip
|
||||
conda run -n cf python scripts/gather_corpus.py --discord /path/to/discord-export.zip
|
||||
|
||||
Discord export instructions:
|
||||
Discord Settings → Privacy & Safety → Request all my data
|
||||
Wait for email, download zip, then run with --discord flag.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Paths
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_CORPUS_DIR = _ROOT / "data" / "style_corpus"
|
||||
_CLAUDE_BRIDGE = Path("/Library/Development/CircuitForge/claude-bridge")
|
||||
_DOCUMENTS = Path("/Library/Documents")
|
||||
|
||||
_REDDIT_USER = "pyr0ball"
|
||||
_USER_AGENT = "Avocet/0.1 corpus-gatherer (CircuitForge; personal research)"
|
||||
_REDDIT_BASE = "https://www.reddit.com"
|
||||
|
||||
# Minimum character length to include a sample (filters out one-liners)
|
||||
_MIN_LENGTH = 80
|
||||
|
||||
# Phrases that suggest AI-generated content — skip these
|
||||
_AI_TELLS = [
|
||||
"certainly!", "absolutely!", "great question", "i'd be happy to",
|
||||
"i apologize for", "it's worth noting", "in conclusion,",
|
||||
"feel free to reach out",
|
||||
]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _is_ai_generated(text: str) -> bool:
|
||||
lower = text.lower()
|
||||
return any(phrase in lower for phrase in _AI_TELLS)
|
||||
|
||||
|
||||
def _clean(text: str) -> str:
|
||||
"""Strip Reddit formatting artifacts and normalize whitespace."""
|
||||
text = re.sub(r"\[deleted\]|\[removed\]", "", text)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def _write_corpus_file(filename: str, samples: list[str], source_label: str) -> None:
|
||||
"""Write samples to a corpus .txt file with minimal separators."""
|
||||
path = _CORPUS_DIR / filename
|
||||
kept = [s for s in samples if len(s) >= _MIN_LENGTH and not _is_ai_generated(s)]
|
||||
if not kept:
|
||||
print(f" [skip] {filename} — no samples passed filters")
|
||||
return
|
||||
separator = "\n\n---\n\n"
|
||||
path.write_text(separator.join(kept), encoding="utf-8")
|
||||
print(f" [ok] {filename} — {len(kept)} samples ({path.stat().st_size // 1024}KB)")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Reddit source
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _reddit_fetch_page(
|
||||
client: httpx.Client,
|
||||
listing_type: str,
|
||||
after: str | None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Fetch one page of a user's submitted posts or comments."""
|
||||
params: dict[str, Any] = {"limit": 100, "raw_json": 1}
|
||||
if after:
|
||||
params["after"] = after
|
||||
url = f"{_REDDIT_BASE}/user/{_REDDIT_USER}/{listing_type}.json"
|
||||
resp = client.get(url, params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
children = data["data"]["children"]
|
||||
new_after = data["data"].get("after")
|
||||
return [c["data"] for c in children], new_after
|
||||
|
||||
|
||||
def _reddit_fetch_all(listing_type: str, max_items: int = 1000) -> list[dict[str, Any]]:
|
||||
"""Paginate through a user listing until exhausted or max_items reached."""
|
||||
items: list[dict[str, Any]] = []
|
||||
after: str | None = None
|
||||
with httpx.Client(
|
||||
headers={"User-Agent": _USER_AGENT},
|
||||
follow_redirects=True,
|
||||
timeout=20.0,
|
||||
) as client:
|
||||
while len(items) < max_items:
|
||||
try:
|
||||
page, after = _reddit_fetch_page(client, listing_type, after)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
# Reddit blocks unauthenticated pagination after the first page;
|
||||
# save what we have rather than crashing.
|
||||
print(f" stopped at {len(items)} {listing_type} (HTTP {exc.response.status_code})")
|
||||
break
|
||||
if not page:
|
||||
break
|
||||
items.extend(page)
|
||||
print(f" fetched {len(items)} {listing_type}...")
|
||||
if not after:
|
||||
break
|
||||
time.sleep(1.0) # respect rate limit
|
||||
return items
|
||||
|
||||
|
||||
def gather_reddit() -> None:
|
||||
print("Fetching Reddit history for u/pyr0ball...")
|
||||
|
||||
# Posts (submitted)
|
||||
print(" Posts:")
|
||||
posts = _reddit_fetch_all("submitted")
|
||||
post_texts: list[str] = []
|
||||
for p in posts:
|
||||
body = _clean(p.get("selftext", "") or "")
|
||||
title = _clean(p.get("title", ""))
|
||||
if len(body) >= _MIN_LENGTH:
|
||||
post_texts.append(f"{title}\n\n{body}")
|
||||
elif len(title) >= 20:
|
||||
# Title-only posts (link posts) — include title as micro-sample
|
||||
post_texts.append(title)
|
||||
_write_corpus_file("social_post_reddit.txt", post_texts, "reddit/submitted")
|
||||
|
||||
# Comments
|
||||
print(" Comments:")
|
||||
comments = _reddit_fetch_all("comments")
|
||||
comment_texts: list[str] = []
|
||||
for c in comments:
|
||||
body = _clean(c.get("body", "") or "")
|
||||
if body and body not in ("[deleted]", "[removed]"):
|
||||
comment_texts.append(body)
|
||||
_write_corpus_file("social_reply_reddit_comments.txt", comment_texts, "reddit/comments")
|
||||
|
||||
print(f" Done. {len(posts)} posts, {len(comments)} comments fetched.")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Campaign copy source (claude-bridge)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _extract_body_from_campaign(py_file: Path) -> str | None:
|
||||
"""
|
||||
Parse a campaign Python file and extract the BODY string literal.
|
||||
Uses AST to handle multi-line strings safely.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(py_file.read_text(encoding="utf-8"))
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == "BODY":
|
||||
if isinstance(node.value, ast.Constant):
|
||||
return str(node.value.value)
|
||||
except (SyntaxError, UnicodeDecodeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def gather_campaigns() -> None:
|
||||
campaigns_dir = _CLAUDE_BRIDGE / "reddit-poster" / "campaigns"
|
||||
if not campaigns_dir.exists():
|
||||
print(f" [skip] campaigns dir not found: {campaigns_dir}")
|
||||
return
|
||||
|
||||
print("Gathering campaign copy from claude-bridge...")
|
||||
samples: list[str] = []
|
||||
for py_file in sorted(campaigns_dir.glob("*.py")):
|
||||
body = _extract_body_from_campaign(py_file)
|
||||
if body:
|
||||
samples.append(body.strip())
|
||||
print(f" {py_file.name} — {len(body)} chars")
|
||||
|
||||
_write_corpus_file("narrative_campaign_copy.txt", samples, "claude-bridge/campaigns")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Documents source
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def gather_documents() -> None:
|
||||
print("Gathering local Documents...")
|
||||
samples: list[str] = []
|
||||
|
||||
# brainmap — personal planning/thinking notes
|
||||
brainmap = _DOCUMENTS / "brainmap_v1.md"
|
||||
if brainmap.exists():
|
||||
text = _clean(brainmap.read_text(encoding="utf-8"))
|
||||
if len(text) >= _MIN_LENGTH:
|
||||
samples.append(text)
|
||||
print(f" brainmap_v1.md — {len(text)} chars")
|
||||
|
||||
# HomeProjects handoff notes — casual technical prose
|
||||
for handoff in sorted((_DOCUMENTS / "HomeProjects").glob("handoff*.md")):
|
||||
text = _clean(handoff.read_text(encoding="utf-8", errors="replace"))
|
||||
if len(text) >= _MIN_LENGTH:
|
||||
samples.append(text)
|
||||
print(f" {handoff.name} — {len(text)} chars")
|
||||
|
||||
# Personal letters (Closet folder) — intimate prose voice
|
||||
closet = _DOCUMENTS / "Closet"
|
||||
if closet.exists():
|
||||
for letter in closet.glob("*.md"):
|
||||
text = _clean(letter.read_text(encoding="utf-8", errors="replace"))
|
||||
if len(text) >= _MIN_LENGTH and not _is_ai_generated(text):
|
||||
samples.append(text)
|
||||
print(f" {letter.name} — {len(text)} chars")
|
||||
|
||||
_write_corpus_file("narrative_personal_docs.txt", samples, "documents")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Discord export source
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def gather_discord(export_zip: Path) -> None:
    """
    Process a Discord data export zip (from Settings → Privacy & Safety → Request all my data).

    Expected zip structure:
      messages/
        c{channel_id}/
          messages.json -- list of {ID, Timestamp, Contents, Attachments}
      account/
        user.json -- {username, ...}
    """
    print(f"Processing Discord export: {export_zip}")
    samples: list[str] = []
    message_count = 0

    with zipfile.ZipFile(export_zip) as zf:
        # Find all messages.json files
        message_files = [n for n in zf.namelist() if n.endswith("/messages.json")]
        print(f" Found {len(message_files)} channel(s)")

        for mf in message_files:
            try:
                data = json.loads(zf.read(mf))
            except (json.JSONDecodeError, KeyError):
                continue

            for msg in data:
                # Count every message seen, before filtering, so the summary
                # line below reports a real before/after ratio.
                message_count += 1
                content = _clean(msg.get("Contents", "") or "")
                # Skip system messages, bot commands, very short messages
                if (
                    len(content) < _MIN_LENGTH
                    or content.startswith("/")
                    or content.startswith("!")
                    or _is_ai_generated(content)
                ):
                    continue
                # Skip messages that are just URLs or attachments
                if re.match(r"^https?://\S+$", content):
                    continue
                samples.append(content)

    print(f" {message_count} messages → {len(samples)} passed filters")
    _write_corpus_file("social_reply_discord.txt", samples, "discord")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Entrypoint
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Gather writing corpus for voice benchmark")
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
choices=["reddit", "local", "all"],
|
||||
default="all",
|
||||
help="Which sources to gather (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discord",
|
||||
type=Path,
|
||||
metavar="ZIP",
|
||||
help="Path to Discord data export zip",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
_CORPUS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Output: {_CORPUS_DIR}\n")
|
||||
|
||||
if args.source in ("reddit", "all"):
|
||||
gather_reddit()
|
||||
print()
|
||||
|
||||
if args.source in ("local", "all"):
|
||||
gather_campaigns()
|
||||
print()
|
||||
gather_documents()
|
||||
print()
|
||||
|
||||
if args.discord:
|
||||
if not args.discord.exists():
|
||||
print(f"Error: Discord export not found: {args.discord}")
|
||||
else:
|
||||
gather_discord(args.discord)
|
||||
print()
|
||||
|
||||
if not args.discord and args.source in ("local", "all"):
|
||||
print("Discord: manual step required")
|
||||
print(" 1. Discord Settings → Privacy & Safety → Request all my data")
|
||||
print(" 2. Download the zip from the email link")
|
||||
print(" 3. Run: python scripts/gather_corpus.py --discord /path/to/package.zip")
|
||||
print()
|
||||
|
||||
# Summary
|
||||
corpus_files = sorted(_CORPUS_DIR.glob("*.txt"))
|
||||
total_chars = sum(f.stat().st_size for f in corpus_files)
|
||||
print(f"Corpus: {len(corpus_files)} file(s), {total_chars // 1024}KB total")
|
||||
for f in corpus_files:
|
||||
print(f" {f.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
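
The campaign gatherer reads `BODY` from each campaign file through the AST rather than importing or executing the file. A minimal sketch of that technique on an in-memory source string; `extract_body` and the sample source are illustrative assumptions, and unlike the script this version also checks that the constant is a string:

```python
import ast
from typing import Optional


def extract_body(source: str) -> Optional[str]:
    """Return the string literal assigned to BODY, or None if absent or unparsable."""
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return None
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == "BODY":
                value = node.value
                # Only plain literals are accepted; computed values are ignored.
                if isinstance(value, ast.Constant) and isinstance(value.value, str):
                    return value.value
    return None


# Hypothetical campaign source for illustration.
sample = 'TITLE = "Launch"\nBODY = """First line.\nSecond line."""\n'
print(extract_body(sample))
```

Parsing instead of importing means a campaign file with side effects (network calls, posting) is never run, at the cost of missing a `BODY` that is built by concatenation or formatting.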
|
||||
|
|
@ -1,110 +0,0 @@
|
|||
"""Avocet — SFT candidate run discovery and JSONL import.
|
||||
|
||||
No FastAPI dependency — pure Python file operations.
|
||||
Used by app/sft.py endpoints and can be run standalone.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CANDIDATES_FILENAME = "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def discover_runs(bench_results_dir: Path) -> list[dict]:
|
||||
"""Return one entry per run subdirectory that contains sft_candidates.jsonl.
|
||||
|
||||
Sorted newest-first by directory name (directories are named YYYY-MM-DD-HHMMSS
|
||||
by the cf-orch benchmark harness, so lexicographic order is chronological).
|
||||
|
||||
Each entry: {run_id, timestamp, candidate_count, sft_path}
|
||||
"""
|
||||
if not bench_results_dir.exists() or not bench_results_dir.is_dir():
|
||||
return []
|
||||
runs = []
|
||||
for subdir in bench_results_dir.iterdir():
|
||||
if not subdir.is_dir():
|
||||
continue
|
||||
sft_path = subdir / _CANDIDATES_FILENAME
|
||||
if not sft_path.exists():
|
||||
continue
|
||||
records = _read_jsonl(sft_path)
|
||||
runs.append({
|
||||
"run_id": subdir.name,
|
||||
"timestamp": subdir.name,
|
||||
"candidate_count": len(records),
|
||||
"sft_path": sft_path,
|
||||
})
|
||||
runs.sort(key=lambda r: r["run_id"], reverse=True)
|
||||
return runs
|
||||
|
||||
|
||||
def import_run(sft_path: Path, data_dir: Path) -> dict[str, int]:
|
||||
"""Append records from sft_path into data_dir/sft_candidates.jsonl.
|
||||
|
||||
Deduplicates on the `id` field — records whose id already exists in the
|
||||
destination file are skipped silently. Records missing an `id` field are
|
||||
also skipped (malformed input from a partial benchmark write).
|
||||
|
||||
    Returns {imported: N, skipped: M}; M counts duplicate ids only, not records missing an id.
|
||||
"""
|
||||
dest = data_dir / _CANDIDATES_FILENAME
|
||||
existing_ids = _read_existing_ids(dest)
|
||||
|
||||
new_records: list[dict] = []
|
||||
skipped = 0
|
||||
for record in _read_jsonl(sft_path):
|
||||
if "id" not in record:
|
||||
logger.warning("Skipping record missing 'id' field in %s", sft_path)
|
||||
continue # malformed — skip without crashing
|
||||
if record["id"] in existing_ids:
|
||||
skipped += 1
|
||||
continue
|
||||
new_records.append(record)
|
||||
existing_ids.add(record["id"])
|
||||
|
||||
if new_records:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(dest, "a", encoding="utf-8") as fh:
|
||||
for r in new_records:
|
||||
fh.write(json.dumps(r) + "\n")
|
||||
|
||||
return {"imported": len(new_records), "skipped": skipped}
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
"""Read a JSONL file, returning valid records. Skips blank lines and malformed JSON."""
|
||||
if not path.exists():
|
||||
return []
|
||||
records: list[dict] = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.warning("Skipping malformed JSON line in %s: %s", path, exc)
|
||||
return records
|
||||
|
||||
|
||||
def _read_existing_ids(path: Path) -> set[str]:
|
||||
"""Read only the id field from each line of a JSONL file."""
|
||||
if not path.exists():
|
||||
return set()
|
||||
ids: set[str] = set()
|
||||
    with path.open(encoding="utf-8") as f:  # match the UTF-8 encoding used when appending
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
record = json.loads(line)
|
||||
if "id" in record:
|
||||
ids.add(record["id"])
|
||||
except json.JSONDecodeError:
|
||||
pass # corrupt line, skip silently (ids file is our own output)
|
||||
return ids
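
The import step above is an append with deduplication on `id`, which makes re-importing the same run harmless. A minimal sketch of that behaviour against a temporary directory; `existing_ids`, `append_new`, and the sample records are illustrative assumptions rather than the module's API:

```python
import json
import tempfile
from pathlib import Path


def existing_ids(dest: Path) -> set:
    """Collect ids already present in dest; a missing file means no ids."""
    if not dest.exists():
        return set()
    ids = set()
    with dest.open(encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate a corrupt line
            if isinstance(record, dict) and "id" in record:
                ids.add(record["id"])
    return ids


def append_new(records: list, dest: Path) -> dict:
    """Append records whose id is not yet in dest; count duplicates as skipped."""
    seen = existing_ids(dest)
    fresh, skipped = [], 0
    for record in records:
        if "id" not in record:
            continue  # malformed record, ignored and not counted
        if record["id"] in seen:
            skipped += 1
            continue
        fresh.append(record)
        seen.add(record["id"])
    if fresh:
        dest.parent.mkdir(parents=True, exist_ok=True)
        with dest.open("a", encoding="utf-8") as fh:
            for record in fresh:
                fh.write(json.dumps(record) + "\n")
    return {"imported": len(fresh), "skipped": skipped}


# Hypothetical candidate records; only the "id" field matters here.
batch = [{"id": "a1"}, {"id": "b2"}, {"id": "a1"}, {"note": "no id"}]
with tempfile.TemporaryDirectory() as tmp:
    dest = Path(tmp) / "sft_candidates.jsonl"
    first = append_new(batch, dest)
    second = append_new(batch, dest)
print(first, second)
```

Adding each accepted id to the in-memory set as it is queued also deduplicates within a single batch, not just against the destination file.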
|
||||
|
|
@ -1,37 +1,23 @@
|
|||
"""Smoke tests for the app factory (app/api.py).
|
||||
import json
|
||||
|
||||
Detailed route tests live in test_data_label.py, test_data_fetch.py,
|
||||
test_data_corrections.py, test_train.py, and test_dashboard.py.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from app import api as api_module # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import api
|
||||
api.set_data_dir(tmp_path)
|
||||
api.reset_last_action()
|
||||
yield
|
||||
api.reset_last_action()
|
||||
|
||||
|
||||
def test_import():
|
||||
from app import api # noqa: F401
|
||||
|
||||
|
||||
def test_app_has_required_routes():
|
||||
from app.api import app
|
||||
paths = {r.path for r in app.routes}
|
||||
# Label routes
|
||||
assert "/api/queue" in paths
|
||||
assert "/api/label" in paths
|
||||
assert "/api/skip" in paths
|
||||
assert "/api/discard" in paths
|
||||
assert "/api/label/undo" in paths
|
||||
assert "/api/config/labels" in paths
|
||||
assert "/api/stats" in paths
|
||||
# Fetch routes
|
||||
assert "/api/accounts/test" in paths
|
||||
assert "/api/fetch/stream" in paths
|
||||
# Train routes
|
||||
assert "/api/train/jobs" in paths
|
||||
assert "/api/train/results" in paths
|
||||
# Dashboard
|
||||
assert "/api/dashboard" in paths
|
||||
# Corrections (new prefix)
|
||||
assert "/api/corrections/ingest" in paths
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -40,8 +26,536 @@ def client():
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
def test_queue_endpoint_reachable(client):
|
||||
@pytest.fixture
|
||||
def queue_with_items():
|
||||
"""Write 3 test emails to the queue file."""
|
||||
from app import api as api_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
|
||||
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert "items" in r.json()
|
||||
assert "total" in r.json()
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = api_module._read_jsonl(api_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part A: POST /api/discard ---
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
# --- Part B: DELETE /api/label/undo ---
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["undone"]["type"] == "label"
|
||||
score = api_module._read_jsonl(api_module._score_file())
|
||||
assert score == []
|
||||
# Item should be restored to front of queue
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert discarded == []
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part C: GET /api/config/labels ---
|
||||
|
||||
def test_config_labels_returns_metadata(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
assert "emoji" in labels[0]
|
||||
assert "color" in labels[0]
|
||||
assert "name" in labels[0]
|
||||
|
||||
|
||||
# ── /api/config ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(tmp_path):
|
||||
"""Give the API a writable config directory."""
|
||||
from app import api as api_module
|
||||
api_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
api_module.set_config_dir(None) # reset to default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir():
|
||||
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
|
||||
from app import api as api_module
|
||||
return api_module._DATA_DIR
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client, config_dir):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, config_dir):
|
||||
import yaml
|
||||
payload = {
|
||||
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}],
|
||||
"max_per_account": 200,
|
||||
}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
cfg_file = config_dir / "label_tool.yaml"
|
||||
assert cfg_file.exists()
|
||||
saved = yaml.safe_load(cfg_file.read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, config_dir):
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
# ── /api/stats ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def score_with_labels(tmp_path, data_dir):
|
||||
"""Write a score file with 3 labels for stats tests."""
|
||||
score_path = data_dir / "email_score.jsonl"
|
||||
records = [
|
||||
{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"},
|
||||
]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
return records
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, score_with_labels):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, score_with_labels):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/accounts/test ───────────────────────────────────────────────────────
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
from unittest.mock import MagicMock, patch
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
"""Parse SSE response body into list of event dicts."""
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
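For illustration, here is a minimal self-contained sketch of how a helper like `_parse_sse` behaves on a raw event stream. The `parse_sse` name and the sample payload are assumptions made up for this example; they are not output captured from the API.

```python
import json


def parse_sse(content: bytes) -> list[dict]:
    """Collect the JSON payload of every `data: ` line in an SSE body."""
    events = []
    for line in content.decode().splitlines():
        # Only `data: ` lines carry a payload; blank separator lines are skipped.
        if line.startswith("data: "):
            events.append(json.loads(line[len("data: "):]))
    return events


# Hypothetical stream, shaped like the start/done/complete events the tests assert on.
sample = (
    b'data: {"type": "start", "account": "Mock"}\n\n'
    b'data: {"type": "done", "added": 1}\n\n'
    b'data: {"type": "complete", "total_added": 1}\n\n'
)
events = parse_sse(sample)
print([e["type"] for e in events])
```

The sketch ignores `event:`, `id:` and multi-line `data:` fields, which is enough for a stream that emits one JSON object per `data:` line.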
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, config_dir):
|
||||
"""With no config, stream should immediately complete with 0 added."""
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
||||
"""With one configured account, stream should yield start/done/complete events."""
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Write a config with one account
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
|
||||
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
# ---- /api/finetune/status tests ----
|
||||
|
||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
|
||||
"""GET /api/finetune/status must return [] if models/ does not exist."""
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
||||
import json as _json
|
||||
from app import api as api_module
|
||||
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
info = {
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401,
|
||||
}
|
||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
||||
|
||||
api_module.set_models_dir(tmp_path / "models")
|
||||
try:
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
finally:
|
||||
api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
|
||||
"""GET /api/finetune/run must return text/event-stream content type."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen", return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "text/event-stream" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_finetune_run_emits_complete_on_success(client):
|
||||
"""GET /api/finetune/run must emit a complete event on clean exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["progress line\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen", return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '{"type": "complete"}' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
|
||||
"""GET /api/finetune/run must emit an error event on non-zero exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen", return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '"type": "error"' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_passes_score_files_to_subprocess(client):
|
||||
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
captured_cmd = []
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
captured_cmd.extend(cmd)
|
||||
m = MagicMock()
|
||||
m.stdout = iter([])
|
||||
m.returncode = 0
|
||||
m.wait = MagicMock()
|
||||
return m
|
||||
|
||||
with patch("app.api._subprocess.Popen", side_effect=mock_popen):
|
||||
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
|
||||
|
||||
assert "--score" in captured_cmd
|
||||
assert captured_cmd.count("--score") == 2
|
||||
# Paths are resolved to absolute — check filenames are present as substrings
|
||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
||||
|
||||
|
||||
# ---- Cancel endpoint tests ----
|
||||
|
||||
def test_benchmark_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/benchmark/cancel must return 404 if no benchmark is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_finetune_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_benchmark_cancel_terminates_running_process(client):
|
||||
"""POST /api/benchmark/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_cancel_terminates_running_process(client):
|
||||
"""POST /api/finetune/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["finetune"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_cancel_kills_process_on_timeout(client):
|
||||
"""POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
|
||||
api_module._running_procs["benchmark"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/benchmark/cancel")
|
||||
assert r.status_code == 200
|
||||
mock_proc.kill.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("benchmark", None)
|
||||
api_module._cancelled_jobs.discard("benchmark")
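The cancel tests above describe a terminate-then-kill escalation. A minimal sketch of that pattern follows, assuming a hypothetical `cancel_process` helper and a 3-second grace period taken from the test docstring; the endpoint's actual implementation is not shown in this diff and may differ.

```python
import subprocess
from unittest.mock import MagicMock


def cancel_process(proc, grace_seconds: float = 3.0) -> str:
    """Ask the process to stop, then force-kill it if it ignores the request."""
    proc.terminate()
    try:
        proc.wait(timeout=grace_seconds)
    except subprocess.TimeoutExpired:
        # The process did not exit in time, so escalate to kill().
        proc.kill()
        return "killed"
    return "terminated"


# A process that exits promptly is only terminated.
polite = MagicMock()
polite_result = cancel_process(polite)

# A process whose wait() times out is killed as well.
stubborn = MagicMock()
stubborn.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
stubborn_result = cancel_process(stubborn)

print(polite_result, stubborn_result)
```

With a real `subprocess.Popen` object, a second `wait()` after `kill()` is usually needed to reap the child; it is left out here because the mock's `wait()` always raises.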
|
||||
|
||||
|
||||
def test_finetune_run_emits_cancelled_event(client):
|
||||
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15 # SIGTERM
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("finetune")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen", side_effect=mock_popen):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_run_emits_cancelled_event(client):
|
||||
"""GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("benchmark")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen", side_effect=mock_popen):
|
||||
r = client.get("/api/benchmark/run")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("benchmark")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@
|
|||
import pytest
|
||||
|
||||
|
||||
def test_registry_has_thirteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 13
|
||||
|
||||
|
||||
def test_registry_default_count():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
|
||||
|
|
@ -161,95 +166,3 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
|
|||
|
||||
assert "avocet-deberta-small" in models
|
||||
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
||||
|
||||
|
||||
# ---- build_exemplars_from_jsonl() tests ----
|
||||
|
||||
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [{"subject": f"S{i}", "body": f"B{i}", "label": "rejected"} for i in range(15)]
|
||||
rows.append({"subject": "Hire", "body": "Welcome", "label": "hired"})
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f), k_per_label=10)
|
||||
|
||||
assert len(result["rejected"]) == 10
|
||||
assert len(result["hired"]) == 1
|
||||
assert result["rejected"][0].startswith("Subject: S")
|
||||
|
||||
|
||||
def test_build_exemplars_formats_text_correctly(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
row = {"subject": "My Subject", "body": "My Body", "label": "neutral"}
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text(json.dumps(row))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
|
||||
assert result["neutral"][0] == "Subject: My Subject\n\nMy Body"
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_missing_label(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [
|
||||
{"subject": "A", "body": "B", "label": "neutral"},
|
||||
{"subject": "No label here", "body": "Body"},
|
||||
]
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
assert list(result.keys()) == ["neutral"]
|
||||
|
||||
|
||||
def test_build_exemplars_truncates_body_at_600(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
row = {"subject": "S", "body": "x" * 800, "label": "neutral"}
|
||||
f = tmp_path / "score.jsonl"
|
||||
f.write_text(json.dumps(row))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
body_part = result["neutral"][0].split("\n\n", 1)[1]
|
||||
assert len(body_part) == 600
|
||||
|
||||
|
||||
def test_build_exemplars_skips_rows_with_no_content(tmp_path):
|
||||
from scripts.benchmark_classifier import build_exemplars_from_jsonl
|
||||
import json
|
||||
|
||||
rows = [
|
||||
{"label": "neutral"}, # no subject, no body -> skip
|
||||
{"subject": "S", "body": "B", "label": "neutral"}, # valid -> keep
|
||||
{"label": "rejected", "subject": "", "body": ""}, # empty strings -> skip
|
||||
]
|
||||
f = tmp_path / "score.jsonl"
|
||||
lines = [json.dumps(r) for r in rows]
|
||||
f.write_text("\n".join(lines))
|
||||
|
||||
result = build_exemplars_from_jsonl(str(f))
|
||||
assert list(result.keys()) == ["neutral"]
|
||||
assert len(result["neutral"]) == 1
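Taken together, the `build_exemplars_from_jsonl()` tests pin down a small contract: at most `k_per_label` exemplars per label, text formatted as `Subject: ...` followed by a blank line and the body, bodies cut at 600 characters, and rows without a label or without any content skipped. The sketch below is one function that would satisfy those assertions. The `build_exemplars` name, the `max_body` parameter, and the choice to keep the first `k` rows rather than sample randomly are all assumptions, not the project's implementation.

```python
import json
import tempfile
from collections import defaultdict
from pathlib import Path


def build_exemplars(path: str, k_per_label: int = 10, max_body: int = 600) -> dict[str, list[str]]:
    """Group labeled rows of a JSONL file into per-label exemplar strings."""
    exemplars = defaultdict(list)
    for line in Path(path).read_text().splitlines():
        if not line.strip():
            continue
        row = json.loads(line)
        label = row.get("label")
        subject = row.get("subject", "")
        body = row.get("body", "")
        # Skip rows with no label, or with neither a subject nor a body.
        if not label or not (subject or body):
            continue
        # Assumption: keep the first k rows per label (no random sampling).
        if len(exemplars[label]) < k_per_label:
            exemplars[label].append(f"Subject: {subject}\n\n{body[:max_body]}")
    return dict(exemplars)


rows = [
    {"subject": "S0", "body": "B0", "label": "rejected"},
    {"subject": "S1", "body": "B1", "label": "rejected"},
    {"subject": "S2", "body": "B2", "label": "rejected"},
    {"subject": "No label", "body": "Body"},
    {"subject": "Long", "body": "x" * 800, "label": "neutral"},
]
with tempfile.TemporaryDirectory() as tmp:
    score_file = Path(tmp) / "score.jsonl"
    score_file.write_text("\n".join(json.dumps(r) for r in rows))
    result = build_exemplars(str(score_file), k_per_label=2)

print({label: len(texts) for label, texts in result.items()})
```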
|
||||
|
||||
def test_registry_has_fourteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 14
|
||||
|
||||
|
||||
def test_embed_knn_nomic_registry_entry():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
entry = MODEL_REGISTRY["embed-knn-nomic"]
|
||||
assert entry["adapter"] is EmbeddingKNNAdapter
|
||||
assert entry["model_id"] == "nomic-embed-text"
|
||||
assert entry["params"] == "local-embed"
|
||||
assert entry["default"] is False
|
||||
assert entry.get("kwargs", {}).get("k") == 3
|
||||
|
|
|
|||
|
|
@ -1,418 +0,0 @@
|
|||
"""Tests for app/cforch.py — /api/cforch/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cforch_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path, reset running-state globals, and stub
|
||||
list_installed to return [] so real disk model directories don't bleed into
|
||||
tests that don't exercise the installed-model merge path."""
|
||||
from app import cforch as cforch_module
|
||||
|
||||
prev_config_dir = cforch_module._CONFIG_DIR
|
||||
prev_running = cforch_module._BENCH_RUNNING
|
||||
prev_proc = cforch_module._bench_proc
|
||||
|
||||
cforch_module.set_config_dir(tmp_path)
|
||||
cforch_module._BENCH_RUNNING = False
|
||||
cforch_module._bench_proc = None
|
||||
|
||||
with patch("app.models.list_installed", return_value=[]):
|
||||
yield tmp_path
|
||||
|
||||
cforch_module.set_config_dir(prev_config_dir)
|
||||
cforch_module._BENCH_RUNNING = prev_running
|
||||
cforch_module._bench_proc = prev_proc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(reset_cforch_globals):
|
||||
"""Return the tmp config dir (already set as _CONFIG_DIR)."""
|
||||
return reset_cforch_globals
|
||||
|
||||
|
||||
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
|
||||
"""Write a label_tool.yaml with the given cforch block into config_dir."""
|
||||
cfg = {"cforch": cforch_cfg}
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
yaml.dump(cfg), encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
def _write_tasks_yaml(path: Path, tasks: list[dict]) -> None:
|
||||
path.write_text(yaml.dump({"tasks": tasks}), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_models_yaml(path: Path, models: list[dict]) -> None:
|
||||
path.write_text(yaml.dump({"models": models}), encoding="utf-8")
|
||||
|
||||
|
||||
# ── GET /tasks ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_tasks_returns_empty_when_not_configured(client):
|
||||
"""No config file present — endpoint returns empty lists."""
|
||||
r = client.get("/api/cforch/tasks")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data == {"tasks": [], "types": []}
|
||||
|
||||
|
||||
def test_tasks_parses_yaml(client, config_dir, tmp_path):
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
_write_tasks_yaml(tasks_file, [
|
||||
{"id": "t1", "name": "Task One", "type": "instruction"},
|
||||
{"id": "t2", "name": "Task Two", "type": "reasoning"},
|
||||
])
|
||||
_write_config(config_dir, {"bench_tasks": str(tasks_file)})
|
||||
|
||||
r = client.get("/api/cforch/tasks")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["tasks"]) == 2
|
||||
# TaskEntry now includes optional prompt/system fields (default "")
|
||||
t1 = data["tasks"][0]
|
||||
assert t1["id"] == "t1" and t1["name"] == "Task One" and t1["type"] == "instruction"
|
||||
t2 = data["tasks"][1]
|
||||
assert t2["id"] == "t2" and t2["name"] == "Task Two" and t2["type"] == "reasoning"
|
||||
assert "instruction" in data["types"]
|
||||
assert "reasoning" in data["types"]
|
||||
|
||||
|
||||
def test_tasks_returns_types_deduplicated(client, config_dir, tmp_path):
|
||||
"""Multiple tasks sharing a type — types list must not duplicate."""
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
_write_tasks_yaml(tasks_file, [
|
||||
{"id": "t1", "name": "A", "type": "instruction"},
|
||||
{"id": "t2", "name": "B", "type": "instruction"},
|
||||
{"id": "t3", "name": "C", "type": "reasoning"},
|
||||
])
|
||||
_write_config(config_dir, {"bench_tasks": str(tasks_file)})
|
||||
|
||||
r = client.get("/api/cforch/tasks")
|
||||
data = r.json()
|
||||
assert data["types"].count("instruction") == 1
|
||||
assert len(data["types"]) == 2
|
||||
|
||||
|
||||
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_models_returns_empty_when_not_configured(client):
|
||||
"""No config file present — endpoint returns empty model list."""
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"models": []}
|
||||
|
||||
|
||||
def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
_write_models_yaml(models_file, [
|
||||
{
|
||||
"name": "llama3",
|
||||
"id": "llama3:8b",
|
||||
"service": "ollama",
|
||||
"tags": ["fast", "small"],
|
||||
"vram_estimate_mb": 6000,
|
||||
}
|
||||
])
|
||||
_write_config(config_dir, {"bench_models": str(models_file)})
|
||||
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["models"]) == 1
|
||||
m = data["models"][0]
|
||||
assert m["name"] == "llama3"
|
||||
assert m["id"] == "llama3:8b"
|
||||
assert m["service"] == "ollama"
|
||||
assert m["tags"] == ["fast", "small"]
|
||||
assert m["vram_estimate_mb"] == 6000
|
||||
|
||||
|
||||
def test_models_merges_installed_generators(client, config_dir, tmp_path):
|
||||
"""Installed cf-text/vllm generator models appear in the model list,
|
||||
deduplicated against bench_models.yaml entries."""
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
_write_models_yaml(models_file, [
|
||||
{"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
|
||||
{"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
|
||||
])
|
||||
_write_config(config_dir, {"bench_models": str(models_file)})
|
||||
|
||||
fake_installed = [
|
||||
# should be included — cf-text generator not already in YAML
|
||||
{"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000},
|
||||
# should be deduped: model_id matches the id of a YAML entry
|
||||
{"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000},
|
||||
# should be excluded: reranker, not a generator
|
||||
{"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500},
|
||||
]
|
||||
with patch("app.models.list_installed", return_value=fake_installed):
|
||||
r = client.get("/api/cforch/models")
|
||||
assert r.status_code == 200
|
||||
ids = [m["id"] for m in r.json()["models"]]
|
||||
assert "llama3:8b" in ids # from YAML
|
||||
assert "ibm-granite/granite-4.1-8b" in ids # from YAML (not duplicated)
|
||||
assert "meta-llama/Llama-3.1-8B" in ids # merged from installed
|
||||
assert "cross-encoder/ms-marco-MiniLM-L6" not in ids # filtered out (reranker)
|
||||
assert ids.count("ibm-granite/granite-4.1-8b") == 1 # no duplicate
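The merge behaviour this test asserts can be sketched as a small pure function. This is a hedged illustration only: `merge_models`, the output field mapping (`vram_mb` to `vram_estimate_mb`, `model_id` reused as `name`), and filtering on `role == "generator"` alone are assumptions inferred from the fixture data, not the code in `app/cforch.py`.

```python
def merge_models(yaml_models: list[dict], installed: list[dict]) -> list[dict]:
    """Append installed generator models that the YAML list does not already contain."""
    merged = list(yaml_models)
    seen = {m["id"] for m in yaml_models}
    for entry in installed:
        # Only generator models are benchmark candidates (assumed filter).
        if entry.get("role") != "generator":
            continue
        # Deduplicate against YAML entries and earlier installed entries.
        if entry["model_id"] in seen:
            continue
        seen.add(entry["model_id"])
        merged.append({
            "name": entry["model_id"],  # hypothetical: no separate display name
            "id": entry["model_id"],
            "service": entry["service"],
            "tags": [],
            "vram_estimate_mb": entry.get("vram_mb", 0),
        })
    return merged


yaml_models = [
    {"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
    {"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
]
installed = [
    {"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000},
    {"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000},
    {"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500},
]
ids = [m["id"] for m in merge_models(yaml_models, installed)]
print(ids)
```

Keeping YAML entries first means a hand-written entry always wins over an auto-discovered one with the same id.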
|
||||
|
||||
|
||||
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_returns_409_when_already_running(client):
|
||||
"""If a benchmark subprocess is actively running, GET /run returns 409."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import cforch as cforch_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None # process still alive
|
||||
cforch_module._BENCH_RUNNING = True
|
||||
cforch_module._bench_proc = mock_proc
|
||||
|
||||
r = client.get("/api/cforch/run")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_run_returns_error_when_bench_script_not_configured(client):
|
||||
"""No config at all — SSE stream contains an error event."""
|
||||
r = client.get("/api/cforch/run")
|
||||
assert r.status_code == 200
|
||||
assert '"type": "error"' in r.text
|
||||
assert "bench_script not configured" in r.text
|
||||
|
||||
|
||||
def test_run_streams_progress_events(client, config_dir, tmp_path):
|
||||
"""Mock subprocess — SSE stream emits progress events from stdout."""
|
||||
bench_script = tmp_path / "fake_benchmark.py"
|
||||
bench_script.write_text("# fake", encoding="utf-8")
|
||||
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")
|
||||
results_dir = tmp_path / "results"
|
||||
results_dir.mkdir()
|
||||
|
||||
_write_config(config_dir, {
|
||||
"bench_script": str(bench_script),
|
||||
"bench_tasks": str(tasks_file),
|
||||
"bench_models": str(models_file),
|
||||
"results_dir": str(results_dir),
|
||||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
mock_stdout = MagicMock()
|
||||
mock_stdout.readline.side_effect = ["Running task 1\n", "Running task 2\n", ""]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = mock_stdout
|
||||
mock_proc.returncode = 1 # non-zero so we don't need summary.json
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
|
||||
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
|
||||
r = client.get("/api/cforch/run")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert '"type": "progress"' in r.text
|
||||
assert "Running task 1" in r.text
|
||||
assert "Running task 2" in r.text
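The substring assertions above rely on each stdout line being wrapped in a JSON payload and framed as a Server-Sent Events message. A minimal sketch of that framing follows; `sse_event`, `stream_progress`, and the `message` field name are illustrative assumptions, not names confirmed by the diff.

```python
import json


def sse_event(payload: dict) -> str:
    """Frame one JSON payload as an SSE message (data line plus blank line)."""
    return f"data: {json.dumps(payload)}\n\n"


def stream_progress(lines):
    """Yield one 'progress' event per subprocess output line (hypothetical shape)."""
    for line in lines:
        yield sse_event({"type": "progress", "message": line.rstrip("\n")})
```

`json.dumps` uses `": "` as its default key separator, which is why the tests can look for the literal text `"type": "progress"` in the response body.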
|
||||
|
||||
|
||||
def test_run_emits_result_on_success(client, config_dir, tmp_path):
|
||||
"""Mock subprocess exit 0 + write fake summary.json — stream emits result event."""
|
||||
bench_script = tmp_path / "fake_benchmark.py"
|
||||
bench_script.write_text("# fake", encoding="utf-8")
|
||||
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
run_dir = results_dir / "2026-04-08-120000"
|
||||
run_dir.mkdir(parents=True)
|
||||
summary_data = {"score": 0.92, "models_evaluated": 3}
|
||||
(run_dir / "summary.json").write_text(json.dumps(summary_data), encoding="utf-8")
|
||||
|
||||
_write_config(config_dir, {
|
||||
"bench_script": str(bench_script),
|
||||
"bench_tasks": str(tasks_file),
|
||||
"bench_models": str(models_file),
|
||||
"results_dir": str(results_dir),
|
||||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
mock_stdout = MagicMock()
|
||||
mock_stdout.readline.side_effect = [""] # no output lines, immediate EOF
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = mock_stdout
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
|
||||
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
|
||||
r = client.get("/api/cforch/run")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert '"type": "result"' in r.text
|
||||
assert '"score": 0.92' in r.text
|
||||
assert '"type": "complete"' in r.text
|
||||
|
||||
|
||||
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_results_returns_404_when_no_results(client):
|
||||
"""No results_dir configured — endpoint returns 404."""
|
||||
r = client.get("/api/cforch/results")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_results_returns_latest_summary(client, config_dir, tmp_path):
|
||||
"""Write fake results dir with one subdir containing summary.json."""
|
||||
results_dir = tmp_path / "results"
|
||||
run_dir = results_dir / "2026-04-08-150000"
|
||||
run_dir.mkdir(parents=True)
|
||||
summary_data = {"score": 0.88, "run": "test"}
|
||||
(run_dir / "summary.json").write_text(json.dumps(summary_data), encoding="utf-8")
|
||||
|
||||
_write_config(config_dir, {"results_dir": str(results_dir)})
|
||||
|
||||
r = client.get("/api/cforch/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["score"] == 0.88
|
||||
assert data["run"] == "test"
|
||||
|
||||
|
||||
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_cancel_returns_404_when_not_running(client):
|
||||
"""POST /cancel when no benchmark running — returns 404."""
|
||||
r = client.post("/api/cforch/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_terminates_running_benchmark(client):
|
||||
"""POST /cancel when benchmark is running — terminates proc and returns cancelled."""
|
||||
from app import cforch as cforch_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
cforch_module._BENCH_RUNNING = True
|
||||
cforch_module._bench_proc = mock_proc
|
||||
|
||||
r = client.post("/api/cforch/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "cancelled"}
|
||||
mock_proc.terminate.assert_called_once()
|
||||
assert cforch_module._BENCH_RUNNING is False
|
||||
assert cforch_module._bench_proc is None
|
||||
|
||||
|
||||
# ── GET /config ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_config_returns_empty_when_no_yaml_no_env(client, monkeypatch):
|
||||
"""No yaml, no env vars — all fields empty, license_key_set False."""
|
||||
for key in ("CF_ORCH_URL", "CF_LICENSE_KEY", "OLLAMA_HOST", "OLLAMA_MODEL"):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
r = client.get("/api/cforch/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["coordinator_url"] == ""
|
||||
assert data["ollama_url"] == ""
|
||||
assert data["license_key_set"] is False
|
||||
|
||||
|
||||
def test_config_reads_env_vars_when_no_yaml(client, monkeypatch):
|
||||
"""Env vars populate fields when label_tool.yaml has no cforch section."""
|
||||
monkeypatch.setenv("CF_ORCH_URL", "http://orch.example.com:7700")
|
||||
monkeypatch.setenv("CF_LICENSE_KEY", "CFG-AVCT-TEST-TEST-TEST")
|
||||
monkeypatch.setenv("OLLAMA_HOST", "http://ollama.local:11434")
|
||||
monkeypatch.setenv("OLLAMA_MODEL", "mistral:7b")
|
||||
|
||||
r = client.get("/api/cforch/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["coordinator_url"] == "http://orch.example.com:7700"
|
||||
assert data["ollama_url"] == "http://ollama.local:11434"
|
||||
assert data["ollama_model"] == "mistral:7b"
|
||||
assert data["license_key_set"] is True # set, but value not exposed
|
||||
|
||||
|
||||
def test_config_yaml_overrides_env(client, config_dir, monkeypatch):
|
||||
"""label_tool.yaml cforch values take priority over env vars."""
|
||||
monkeypatch.setenv("CF_ORCH_URL", "http://env-orch:7700")
|
||||
monkeypatch.setenv("OLLAMA_HOST", "http://env-ollama:11434")
|
||||
|
||||
_write_config(config_dir, {
|
||||
"coordinator_url": "http://yaml-orch:7700",
|
||||
"ollama_url": "http://yaml-ollama:11434",
|
||||
})
|
||||
|
||||
r = client.get("/api/cforch/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["coordinator_url"] == "http://yaml-orch:7700"
|
||||
assert data["ollama_url"] == "http://yaml-ollama:11434"
|
||||
assert data["source"] == "yaml+env"
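The precedence rule exercised by the three config tests (YAML first, environment as fallback, license key reported only as a boolean) might look like the sketch below. `resolve_config` and `_ENV_KEYS` are hypothetical names, reading the license key from the environment only is an assumption, and the `source` field is left out because the diff does not spell out how it is derived.

```python
# Hypothetical mapping from response field to environment variable.
_ENV_KEYS = {
    "coordinator_url": "CF_ORCH_URL",
    "ollama_url": "OLLAMA_HOST",
    "ollama_model": "OLLAMA_MODEL",
}


def resolve_config(yaml_cfg: dict, env: dict) -> dict:
    """Prefer values from the YAML cforch section; fall back to env vars, then ''."""
    out = {
        field: yaml_cfg.get(field) or env.get(env_key, "")
        for field, env_key in _ENV_KEYS.items()
    }
    # Report only whether a key is present; never echo the secret itself.
    out["license_key_set"] = bool(env.get("CF_LICENSE_KEY"))
    return out
```

Passing `env` in as a plain dict (for example `dict(os.environ)`) keeps the function easy to test without `monkeypatch`.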
|
||||
|
||||
|
||||
def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path, monkeypatch):
|
||||
"""CF_LICENSE_KEY must be forwarded to the benchmark subprocess env."""
|
||||
monkeypatch.setenv("CF_LICENSE_KEY", "CFG-AVCT-ENV-ONLY-KEY")
|
||||
|
||||
bench_script = tmp_path / "benchmark.py"
|
||||
bench_script.write_text("# stub", encoding="utf-8")
|
||||
tasks_file = tmp_path / "bench_tasks.yaml"
|
||||
tasks_file.write_text(yaml.dump({"tasks": []}), encoding="utf-8")
|
||||
models_file = tmp_path / "bench_models.yaml"
|
||||
models_file.write_text(yaml.dump({"models": []}), encoding="utf-8")
|
||||
|
||||
_write_config(config_dir, {
|
||||
"bench_script": str(bench_script),
|
||||
"bench_tasks": str(tasks_file),
|
||||
"bench_models": str(models_file),
|
||||
"results_dir": str(tmp_path / "results"),
|
||||
"python_bin": "/usr/bin/python3",
|
||||
})
|
||||
|
||||
captured_env: dict = {}
|
||||
|
||||
def fake_popen(cmd, **kwargs):
|
||||
captured_env.update(kwargs.get("env", {}))
|
||||
mock = MagicMock()
|
||||
mock.stdout = iter([])
|
||||
mock.returncode = 0
|
||||
mock.wait = MagicMock()
|
||||
return mock
|
||||
|
||||
with patch("app.cforch._subprocess.Popen", side_effect=fake_popen):
|
||||
client.get("/api/cforch/run")
|
||||
|
||||
assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
|
||||
|
||||
|
||||
def test_eval_cforch_router_includes_all_sub_routers():
|
||||
"""eval/cforch.py router must include routes from all four sub-routers."""
|
||||
from app.eval.cforch import router
|
||||
paths = {r.path for r in router.routes}
|
||||
assert any("/cforch/" in p for p in paths), f"no /cforch/ routes found in {paths}"
|
||||
assert any("/style/" in p for p in paths), f"no /style/ routes found in {paths}"
|
||||
assert any("/voice/" in p for p in paths), f"no /voice/ routes found in {paths}"
|
||||
assert any("/plans-bench/" in p for p in paths), f"no /plans-bench/ routes found in {paths}"
|
||||
|
|
@ -268,373 +268,3 @@ def test_finetuned_adapter_unload_clears_pipeline():
|
|||
assert adapter._pipeline is not None
|
||||
adapter.unload()
|
||||
assert adapter._pipeline is None
|
||||
|
||||
# ---- _cosine() tests ----
|
||||
|
||||
def test_cosine_identical_unit_vectors():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal_vectors():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_known_value():
|
||||
import math
|
||||
from scripts.classifier_adapters import _cosine
|
||||
# [1,0] vs [1/sqrt(2), 1/sqrt(2)] → dot = 1/sqrt(2), both norms = 1 → 1/sqrt(2)
|
||||
v = [1.0 / math.sqrt(2), 1.0 / math.sqrt(2)]
|
||||
assert _cosine([1.0, 0.0], v) == pytest.approx(1.0 / math.sqrt(2))
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
|
||||
from scripts.classifier_adapters import _cosine
|
||||
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
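For reference, a cosine similarity that satisfies all four tests above, including the zero-vector case, can be written as follows. This is a sketch under the assumption that `_cosine` takes two equal-length lists of floats; the actual implementation in `scripts/classifier_adapters.py` is not shown in this diff.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # avoid division by zero for degenerate embeddings
    return dot / (norm_a * norm_b)
```

Returning 0.0 for a zero vector, rather than raising, means a degenerate embedding simply never ranks among the nearest neighbours.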
|
||||
|
||||
|
||||
# ---- DEFAULT_EXEMPLARS tests ----
|
||||
|
||||
def test_default_exemplars_covers_all_labels():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS, LABELS
|
||||
for label in LABELS:
|
||||
assert label in DEFAULT_EXEMPLARS, f"DEFAULT_EXEMPLARS missing label: {label}"
|
||||
assert len(DEFAULT_EXEMPLARS[label]) >= 4, f"{label} needs >= 4 exemplars for k=3 voting"
|
||||
|
||||
|
||||
def test_default_exemplars_sparse_labels_have_at_least_four():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
|
||||
# These labels have very few real examples; need >= 4 so k=3 vote is meaningful
|
||||
for label in ("hired", "survey_received", "event_rescheduled"):
|
||||
assert len(DEFAULT_EXEMPLARS[label]) >= 4, (
|
||||
f"{label} needs >= 4 exemplars for k=3 voting to work reliably"
|
||||
)
|
||||
|
||||
def test_default_exemplars_strings_are_formatted_correctly():
|
||||
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
|
||||
for label, texts in DEFAULT_EXEMPLARS.items():
|
||||
for text in texts:
|
||||
assert text.startswith("Subject: "), (
|
||||
f"{label!r} exemplar missing 'Subject: ' prefix: {text[:50]!r}"
|
||||
)
|
||||
assert "\n\n" in text, (
|
||||
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
|
||||
)
|
||||
|
||||
# ---- EmbeddingKNNAdapter constructor tests ----
|
||||
|
||||
def test_embedding_knn_is_classifier_adapter():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test-knn", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert isinstance(adapter, ClassifierAdapter)
|
||||
|
||||
|
||||
def test_embedding_knn_name_and_model_id():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"embed-knn-nomic", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter.name == "embed-knn-nomic"
|
||||
assert adapter.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_embedding_knn_uses_default_exemplars_when_none_given():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
|
||||
|
||||
|
||||
def test_embedding_knn_accepts_custom_exemplars():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
custom = {"rejected": ["Sorry, we went with others."]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=custom,
|
||||
)
|
||||
assert adapter._exemplar_texts is custom
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.load() tests ----
|
||||
|
||||
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
|
||||
"""Return a side_effect function for patching httpx.post.
|
||||
|
||||
Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3]
|
||||
embedding per input text.
|
||||
"""
|
||||
def _side_effect(url, *, json=None, timeout=None, **kwargs):
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
|
||||
else:
|
||||
n = len((json or {}).get("input", []))
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n}
|
||||
return resp
|
||||
return _side_effect
|
||||
|
||||
|
||||
def test_load_calls_allocate_then_embeds_each_label():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {
|
||||
"rejected": ["We went with others"],
|
||||
"hired": ["Welcome aboard!", "First day info"],
|
||||
}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
post_urls = []
|
||||
def capturing_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
post_urls.append(url)
|
||||
return _make_post_mock()(url, json=json, timeout=timeout)
|
||||
|
||||
with patch("httpx.post", side_effect=capturing_mock):
|
||||
adapter.load()
|
||||
|
||||
assert any("/allocate" in u for u in post_urls), "expected allocate call"
|
||||
assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
|
||||
assert adapter._allocation_id == "alloc-abc"
|
||||
assert adapter._node_url == "http://navi:11434"
|
||||
assert adapter._orch_url_used == "http://orch:7700"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
assert "hired" in adapter._exemplar_embeddings
|
||||
assert len(adapter._exemplar_embeddings["rejected"]) == 1
|
||||
assert len(adapter._exemplar_embeddings["hired"]) == 2
|
||||
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
|
||||
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_fails():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
resp = MagicMock()
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 503
|
||||
resp.json.return_value = {}
|
||||
else:
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=failing_allocate_mock):
|
||||
adapter.load()
|
||||
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
assert adapter._node_url == "http://ollama:11434"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
def test_load_falls_back_to_ollama_when_allocate_raises():
|
||||
from unittest.mock import patch, MagicMock
|
||||
import httpx as _httpx
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
|
||||
def raising_mock(url, *, json=None, timeout=None, **kwargs):
|
||||
if "/allocate" in url:
|
||||
raise _httpx.ConnectError("connection refused")
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=raising_mock):
|
||||
adapter.load()
|
||||
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
assert adapter._node_url == "http://ollama:11434"
|
||||
assert "rejected" in adapter._exemplar_embeddings
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.unload() tests ----
|
||||
|
||||
def test_unload_releases_orch_allocation_and_clears_state():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
|
||||
adapter._node_url = "http://navi:11434"
|
||||
adapter._allocation_id = "alloc-abc"
|
||||
adapter._orch_url_used = "http://orch:7700"
|
||||
|
||||
delete_calls = []
|
||||
def mock_request(method, url, **kwargs):
|
||||
delete_calls.append((method, url))
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
with patch("httpx.request", side_effect=mock_request):
|
||||
adapter.unload()
|
||||
|
||||
assert len(delete_calls) == 1
|
||||
method, url = delete_calls[0]
|
||||
assert method == "DELETE"
|
||||
assert "alloc-abc" in url
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
assert adapter._allocation_id == ""
|
||||
assert adapter._node_url == ""
|
||||
assert adapter._orch_url_used == ""
|
||||
|
||||
|
||||
def test_unload_skips_delete_on_ollama_fallback_path():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=3,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
|
||||
adapter._node_url = "http://ollama:11434"
|
||||
adapter._allocation_id = "" # fallback path: no allocation was made
|
||||
adapter._orch_url_used = ""
|
||||
|
||||
delete_calls = []
|
||||
with patch("httpx.request", side_effect=lambda *a, **k: delete_calls.append(a)):
|
||||
adapter.unload()
|
||||
|
||||
assert len(delete_calls) == 0
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
assert adapter._node_url == ""
|
||||
|
||||
|
||||
# ---- EmbeddingKNNAdapter.classify() tests ----
|
||||
|
||||
def _adapter_with_embeddings(exemplar_embeddings, k=3):
|
||||
"""Return a pre-loaded EmbeddingKNNAdapter (bypass load()) with given per-label vectors."""
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=k,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
adapter._exemplar_embeddings = exemplar_embeddings
|
||||
adapter._node_url = "http://navi:11434"
|
||||
return adapter
|
||||
|
||||
|
||||
def _embed_resp(vec):
|
||||
"""Return a mock httpx response for /v1/embeddings returning a single vector."""
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = {"data": [{"embedding": vec}]}
|
||||
return resp
|
||||
|
||||
|
||||
def test_classify_returns_majority_vote_label():
|
||||
from unittest.mock import patch
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
|
||||
"neutral": [[0.0, 1.0, 0.0]],
|
||||
}, k=3)
|
||||
|
||||
# Query [1,0,0] is closest to all three "rejected" exemplars
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("We went with others", "Thank you for applying.")
|
||||
|
||||
assert result == "rejected"
|
||||
|
||||
|
||||
def test_classify_tiebreak_by_mean_score():
|
||||
from unittest.mock import patch
|
||||
# k=2: each label gets exactly 1 vote → tie-break by mean similarity
|
||||
# [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] = 0.6 ("neutral")
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[1.0, 0.0]],
|
||||
"neutral": [[0.6, 0.8]],
|
||||
}, k=2)
|
||||
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
|
||||
result = adapter.classify("Rejection", "Sorry")
|
||||
|
||||
assert result == "rejected"
|
||||
|
||||
|
||||
def test_classify_sparse_label_can_win():
|
||||
from unittest.mock import patch
|
||||
# "hired" has only 1 exemplar; with k=1, the single closest match wins
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
|
||||
"hired": [[1.0, 0.0, 0.0]],
|
||||
}, k=1)
|
||||
|
||||
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("Welcome aboard", "Your first day details")
|
||||
|
||||
assert result == "hired"
|
||||
|
||||
|
||||
def test_classify_lazy_loads_when_not_loaded():
|
||||
from unittest.mock import patch
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
|
||||
exemplars = {"rejected": ["We went with others"]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text", k=1,
|
||||
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=exemplars,
|
||||
)
|
||||
assert adapter._exemplar_embeddings == {}
|
||||
|
||||
post_urls = []
|
||||
def mock_post(url, *, json=None, timeout=None, **kwargs):
|
||||
post_urls.append(url)
|
||||
from unittest.mock import MagicMock
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
if "/allocate" in url:
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
|
||||
else:
|
||||
n = len((json or {}).get("input", []))
|
||||
resp.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * n}
|
||||
return resp
|
||||
|
||||
with patch("httpx.post", side_effect=mock_post):
|
||||
result = adapter.classify("Rejection", "Sorry")
|
||||
|
||||
assert result == "rejected"
|
||||
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
|
||||
assert adapter._exemplar_embeddings != {}
|
||||
assert adapter._node_url == "http://navi:11434"
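The voting rule the `classify()` tests describe (take the k most similar exemplars, count votes per label, break ties by mean similarity) can be sketched without any HTTP calls. `knn_vote` and `_cos` are hypothetical helpers operating on precomputed vectors; how the real adapter orders ties beyond mean score is not visible in this diff.

```python
import math
from collections import defaultdict


def _cos(a, b):
    """Cosine similarity; 0.0 when either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def knn_vote(query, exemplar_embeddings, k=3):
    """Return the label with the most votes among the k nearest exemplars.

    Ties on vote count are broken by the mean similarity of that label's votes.
    """
    scored = [
        (_cos(query, vec), label)
        for label, vecs in exemplar_embeddings.items()
        for vec in vecs
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    votes = defaultdict(list)
    for score, label in scored[:k]:
        votes[label].append(score)
    return max(votes, key=lambda lb: (len(votes[lb]), sum(votes[lb]) / len(votes[lb])))
```

Ranking by the tuple `(vote count, mean score)` lets a single `max` call handle both the majority vote and the tie-break.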
|
||||
|
|
|
|||
|
|
@ -1,122 +0,0 @@
|
|||
"""Tests for app/dashboard.py -- GET /api/dashboard."""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import dashboard as dash_module
|
||||
dash_module.set_data_dir(tmp_path)
|
||||
dash_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _write_score(tmp_path: Path, records: list[dict]) -> None:
|
||||
(tmp_path / "email_score.jsonl").write_text(
|
||||
"\n".join(json.dumps(r) for r in records) + "\n"
|
||||
)
|
||||
|
||||
def _write_summary(tmp_path: Path, run_id: str, ts: str, score: float) -> None:
|
||||
run_dir = tmp_path / "bench_results" / run_id
|
||||
run_dir.mkdir(parents=True)
|
||||
(run_dir / "summary.json").write_text(
|
||||
json.dumps({"timestamp": ts, "best_macro_f1": score})
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_returns_expected_keys(client):
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
for key in ("labeled_since_last_eval", "last_eval_timestamp", "last_eval_best_score",
|
||||
"active_jobs", "corrections_pending", "corrections_export_ready", "signals"):
|
||||
assert key in data, f"missing key: {key}"
|
||||
for sig in ("data_to_eval", "eval_to_train", "train_to_fleet"):
|
||||
assert sig in data["signals"], f"missing signal: {sig}"
|
||||
|
||||
|
||||
def test_dashboard_empty_state(client):
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["labeled_since_last_eval"] == 0
|
||||
assert data["last_eval_timestamp"] is None
|
||||
assert data["last_eval_best_score"] is None
|
||||
assert data["active_jobs"] == []
|
||||
assert data["corrections_pending"] == 0
|
||||
assert data["corrections_export_ready"] == 0
|
||||
|
||||
|
||||
def test_labeled_since_counts_all_when_no_eval(client, tmp_path):
|
||||
_write_score(tmp_path, [
|
||||
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"},
|
||||
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
|
||||
])
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["labeled_since_last_eval"] == 2
|
||||
|
||||
|
||||
def test_labeled_since_filters_by_eval_timestamp(client, tmp_path):
|
||||
_write_summary(tmp_path, "2026-05-01-100000", "2026-05-01T10:00:00+00:00", 0.80)
|
||||
_write_score(tmp_path, [
|
||||
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T09:00:00+00:00"},
|
||||
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
|
||||
])
|
||||
(tmp_path / "label_tool.yaml").write_text(
|
||||
yaml.dump({"cforch": {"results_dir": str(tmp_path / "bench_results")}})
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
data = r.json()
|
||||
assert data["labeled_since_last_eval"] == 1
|
||||
assert abs(data["last_eval_best_score"] - 0.80) < 0.001
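The counting rule behind `labeled_since_last_eval` can be sketched as below. `labeled_since` is a hypothetical helper, and the strict "later than" comparison is an assumption; the tests only cover records one hour before and one hour after the eval timestamp.

```python
from datetime import datetime


def labeled_since(records, last_eval_ts):
    """Count records labeled after the last eval; count all if no eval exists."""
    if last_eval_ts is None:
        return len(records)
    cutoff = datetime.fromisoformat(last_eval_ts)
    return sum(
        1 for rec in records
        if datetime.fromisoformat(rec["labeled_at"]) > cutoff
    )
```

Parsing both sides with `datetime.fromisoformat` compares timezone-aware instants rather than raw strings, so records carrying a different UTC offset would still be ordered correctly.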
|
||||
|
||||
|
||||
def test_data_to_eval_false_below_threshold(client, tmp_path):
|
||||
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
|
||||
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(10)])
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["signals"]["data_to_eval"] is False
|
||||
|
||||
|
||||
def test_data_to_eval_true_at_threshold(client, tmp_path):
|
||||
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
|
||||
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(50)])
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["signals"]["data_to_eval"] is True
|
||||
|
||||
|
||||
def test_corrections_pending_count(client, tmp_path):
|
||||
candidates = [
|
||||
{"id": "c1", "status": "needs_review"},
|
||||
{"id": "c2", "status": "needs_review"},
|
||||
{"id": "c3", "status": "discarded"},
|
||||
]
|
||||
(tmp_path / "sft_candidates.jsonl").write_text(
|
||||
"\n".join(json.dumps(c) for c in candidates) + "\n"
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["corrections_pending"] == 2
|
||||
|
||||
|
||||
def test_corrections_export_ready_count(client, tmp_path):
|
||||
approved = [
|
||||
{"id": "a1", "status": "approved", "corrected_response": "Good answer"},
|
||||
{"id": "a2", "status": "approved", "corrected_response": ""},
|
||||
{"id": "a3", "status": "approved", "corrected_response": "Another answer"},
|
||||
]
|
||||
(tmp_path / "sft_approved.jsonl").write_text(
|
||||
"\n".join(json.dumps(a) for a in approved) + "\n"
|
||||
)
|
||||
r = client.get("/api/dashboard")
|
||||
assert r.json()["corrections_export_ready"] == 2
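The expected count of 2 implies that an approved record with an empty `corrected_response` is not export-ready. A minimal sketch of that filter, assuming whitespace-only responses are also treated as empty (`count_export_ready` is a hypothetical name):

```python
def count_export_ready(approved_records):
    """Count approved records that carry a non-empty corrected response."""
    return sum(
        1 for rec in approved_records
        if rec.get("status") == "approved"
        and (rec.get("corrected_response") or "").strip()
    )
```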
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
"""Tests for app/data/corrections.py -- POST /api/sft/ingest.
|
||||
|
||||
The corrections router is mounted at prefix="/api/sft" via the app/sft.py
|
||||
backward-compat shim, so ingest lives at /api/sft/ingest.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
_VALID_PAYLOAD = {
|
||||
"source": "peregrine",
|
||||
"task_type": "email_classification",
|
||||
"prompt": "Classify this email: ...",
|
||||
"response": "skip",
|
||||
"correction": "action_required",
|
||||
"label": "action_required",
|
||||
}
|
||||
|
||||
_SECRET = "test-secret-abc123"
|
||||
|
||||
|
||||
def test_ingest_503_when_secret_not_configured(client, monkeypatch):
|
||||
monkeypatch.delenv("AVOCET_INGESTION_SECRET", raising=False)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 503
|
||||
|
||||
|
||||
def test_ingest_401_when_no_auth_header(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD)
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_401_when_malformed_header(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": "Token bad-format"})
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_ingest_403_when_wrong_secret(client, monkeypatch):
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": "Bearer wrong-secret"})
|
||||
assert r.status_code == 403
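The four tests above fix an order of checks: 503 when no secret is configured, 401 for a missing or malformed header, 403 for a wrong token. A framework-free sketch of that decision is shown below; `ingest_auth_status` is a hypothetical helper, and the constant-time comparison is a suggested practice rather than something the diff confirms.

```python
import hmac


def ingest_auth_status(configured_secret, authorization):
    """Map the configured secret and Authorization header to an HTTP status."""
    if not configured_secret:
        return 503  # ingestion disabled until a secret is configured
    prefix = "Bearer "
    if not authorization or not authorization.startswith(prefix):
        return 401  # missing header or wrong scheme
    token = authorization[len(prefix):]
    # Constant-time comparison avoids leaking how much of the token matched.
    if not hmac.compare_digest(token.encode(), configured_secret.encode()):
        return 403
    return 200
```

Checking configuration before credentials means a misconfigured server reports 503 even to callers that send a valid-looking token, matching the first test.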
|
||||
|
||||
|
||||
def test_ingest_creates_approved_record(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert "id" in data
|
||||
candidates = corr_module.read_jsonl(corr_module._candidates_file())
|
||||
assert len(candidates) == 1
|
||||
rec = candidates[0]
|
||||
assert rec["status"] == "approved"
|
||||
assert rec["source"] == "peregrine"
|
||||
assert rec["corrected_response"] == "action_required"
|
||||
assert rec["id"] == data["id"]
|
||||
|
||||
|
||||
def test_ingest_also_writes_to_approved_file(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
approved = corr_module.read_jsonl(corr_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["id"] == r.json()["id"]
|
||||
|
||||
|
||||
def test_ingest_without_label_is_accepted(client, monkeypatch, tmp_path):
|
||||
from app.data import corrections as corr_module
|
||||
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
payload = {**_VALID_PAYLOAD, "label": None}
|
||||
r = client.post("/api/sft/ingest", json=payload,
|
||||
headers={"Authorization": f"Bearer {_SECRET}"})
|
||||
assert r.status_code == 200
|
||||
|
|
@@ -1,95 +0,0 @@
|
|||
"""Tests for app/data/fetch.py"""
|
||||
import json
|
||||
import yaml
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import fetch as fetch_module
|
||||
fetch_module.set_data_dir(tmp_path)
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
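`_parse_sse` only works if the server emits each event as a single `data: ` line carrying JSON. The sketch below shows the matching encoder next to a copy of the parser and round-trips two events. `format_sse` is a hypothetical name, and the event fields are taken from the assertions in this file rather than from the server code.

```python
import json


def format_sse(event: dict) -> str:
    """Hypothetical encoder: one SSE frame is a 'data: ' line followed by a blank line."""
    return f"data: {json.dumps(event)}\n\n"


def parse_sse(content: bytes) -> list[dict]:
    """Same logic as the test helper: keep 'data: ' lines and decode their JSON payload."""
    prefix = "data: "
    return [
        json.loads(line[len(prefix):])
        for line in content.decode().splitlines()
        if line.startswith(prefix)
    ]


stream = "".join(
    format_sse(e)
    for e in [{"type": "start", "account": "Mock"}, {"type": "complete", "total_added": 0}]
).encode()
events = parse_sse(stream)
print([e["type"] for e in events])
```

Because `json.dumps` escapes newlines inside strings, each payload stays on one line, which is the property the line-based parser depends on.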
|
||||
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test",
|
||||
json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, tmp_path):
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, tmp_path):
|
||||
from app.data import fetch as fetch_module
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
def test_entry_key_deterministic():
|
||||
from app.data.fetch import entry_key
|
||||
e = {"subject": "Test", "body": "Hello world"}
|
||||
assert entry_key(e) == entry_key(e)
|
||||
|
||||
|
||||
def test_entry_key_differs_by_subject():
|
||||
from app.data.fetch import entry_key
|
||||
a = {"subject": "A", "body": "same body"}
|
||||
b = {"subject": "B", "body": "same body"}
|
||||
assert entry_key(a) != entry_key(b)
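The two `entry_key` tests only require that the key is deterministic and changes when the subject changes. One way to get both properties is a content hash, sketched below. The choice of SHA-256 and the NUL separator are assumptions; the real `app.data.fetch.entry_key` may hash different fields or use another digest.

```python
import hashlib


def entry_key(entry: dict) -> str:
    """Hypothetical dedup key: a hash over subject and body."""
    subject = entry.get("subject", "")
    body = entry.get("body", "")
    # The separator keeps ("ab", "c") and ("a", "bc") from producing the same input.
    raw = f"{subject}\x00{body}".encode("utf-8")
    return hashlib.sha256(raw).hexdigest()


a = {"subject": "A", "body": "same body"}
b = {"subject": "B", "body": "same body"}
print(entry_key(a) == entry_key(a), entry_key(a) != entry_key(b))
```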
|
||||
|
|
@@ -1,219 +0,0 @@
|
|||
"""Tests for app/data/label.py"""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
label_module.set_config_dir(tmp_path)
|
||||
label_module.reset_last_action()
|
||||
yield
|
||||
label_module.reset_last_action()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_items(tmp_path):
|
||||
from app.data import label as label_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
(label_module._DATA_DIR / "email_label_queue.jsonl").write_text(
|
||||
"\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = label_module.read_jsonl(label_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
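The skip semantics tested here (skipped item goes to the back, the next item becomes the head, unknown ids are an error) can be sketched on a plain list. `skip_to_back` is a hypothetical helper, and raising `KeyError` stands in for the 404 the endpoint returns.

```python
def skip_to_back(queue: list[dict], item_id: str) -> list[dict]:
    """Return a new queue with the matching item moved to the end."""
    for i, item in enumerate(queue):
        if item["id"] == item_id:
            return queue[:i] + queue[i + 1:] + [item]
    raise KeyError(item_id)  # the API test expects a 404 in this case


queue = [{"id": "id0"}, {"id": "id1"}, {"id": "id2"}]
print([x["id"] for x in skip_to_back(queue, "id0")])
```

Returning a new list rather than mutating in place makes an undo straightforward: the caller can keep the previous queue and restore it.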
|
||||
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = label_module.read_jsonl(label_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["undone"]["type"] == "label"
|
||||
assert label_module.read_jsonl(label_module._score_file()) == []
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
assert label_module.read_jsonl(label_module._discarded_file()) == []
|
||||
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app.data import label as label_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = label_module.read_jsonl(label_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_config_labels_returns_10_labels(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
for lbl in labels:
|
||||
assert "emoji" in lbl and "color" in lbl and "name" in lbl
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_config_dir(tmp_path)
|
||||
payload = {"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}], "max_per_account": 200}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
saved = yaml.safe_load((tmp_path / "label_tool.yaml").read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_config_dir(tmp_path)
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
score_path = tmp_path / "email_score.jsonl"
|
||||
records = [{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"}]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
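The shape asserted for `/api/stats` (a `total` plus per-label `counts`) falls out of a single pass over the JSONL score file. A minimal sketch, assuming one JSON object per line with a `label` field; `label_stats` is a hypothetical name and the real endpoint also reports `score_file_bytes`, which is omitted here.

```python
import json
from collections import Counter


def label_stats(jsonl_text: str) -> dict:
    """Count records per label in JSONL text, ignoring blank lines."""
    records = [json.loads(line) for line in jsonl_text.splitlines() if line.strip()]
    counts = Counter(r["label"] for r in records)
    return {"total": len(records), "counts": dict(counts)}


text = "\n".join(
    json.dumps(r)
    for r in [
        {"id": "a", "label": "interview_scheduled"},
        {"id": "b", "label": "interview_scheduled"},
        {"id": "c", "label": "rejected"},
    ]
) + "\n"
print(label_stats(text))
```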
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, tmp_path):
|
||||
from app.data import label as label_module
|
||||
label_module.set_data_dir(tmp_path)
|
||||
(tmp_path / "email_score.jsonl").write_text(json.dumps({"id": "a", "label": "neutral"}) + "\n")
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
|
@@ -1,234 +0,0 @@
|
|||
"""Tests for app/eval/embed_bench.py."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_embed_bench_globals(tmp_path):
|
||||
"""Redirect config dir to tmp_path and reset running flag."""
|
||||
from app.eval import embed_bench as mod
|
||||
|
||||
prev_config_dir = mod._CONFIG_DIR
|
||||
prev_running = mod._RUN_ACTIVE
|
||||
|
||||
mod.set_config_dir(tmp_path)
|
||||
mod._RUN_ACTIVE = False
|
||||
|
||||
yield tmp_path
|
||||
|
||||
mod.set_config_dir(prev_config_dir)
|
||||
mod._RUN_ACTIVE = prev_running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ── cosine helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_cosine_identical():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_cosine_orthogonal():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_cosine_opposite():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)
|
||||
|
||||
|
||||
def test_cosine_zero_vector_returns_zero():
|
||||
from app.eval.embed_bench import _cosine
|
||||
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
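Taken together, the four `_cosine` tests specify ordinary cosine similarity with one extra rule: a zero-length vector yields 0.0 instead of a division error. A minimal sketch that meets those four cases follows; `cosine` is a stand-in name and this is not the implementation in `app/eval/embed_bench.py`.

```python
import math
from typing import Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, returning 0.0 when either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # guard required by the zero-vector test
    return dot / (norm_a * norm_b)


print(cosine([1.0, 0.0], [-1.0, 0.0]))
```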
|
||||
|
||||
|
||||
# ── models endpoint ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_models_returns_list_with_mock(client, tmp_path):
|
||||
"""GET /api/embed-bench/models returns list from Ollama tags endpoint."""
|
||||
import yaml
|
||||
cfg = {"cforch": {"ollama_url": "http://localhost:11434"}}
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {
|
||||
"models": [
|
||||
{"name": "nomic-embed-text", "size": 274302480},
|
||||
{"name": "mxbai-embed-large", "size": 669000000},
|
||||
]
|
||||
}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/embed-bench/models")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert isinstance(data["models"], list)
|
||||
assert any(m["name"] == "nomic-embed-text" for m in data["models"])
|
||||
|
||||
|
||||
def test_models_returns_empty_on_ollama_error(client, tmp_path):
|
||||
"""GET /api/embed-bench/models returns empty list if Ollama unreachable."""
|
||||
import httpx
|
||||
with patch("app.eval.embed_bench.httpx.get", side_effect=httpx.ConnectError("refused")):
|
||||
r = client.get("/api/embed-bench/models")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["models"] == []
|
||||
|
||||
|
||||
# ── run endpoint ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_run_empty_corpus_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": [], "queries": ["test"], "models": ["nomic-embed-text"], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_run_empty_queries_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1"], "queries": [], "models": ["nomic-embed-text"], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_run_empty_models_returns_422(client):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1"], "queries": ["test"], "models": [], "top_k": 3
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def _fake_embed_response(texts: list[str]) -> MagicMock:
|
||||
"""Build a mock httpx.post response returning unit vectors for each text."""
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.json.return_value = {
|
||||
"data": [{"embedding": [1.0, 0.0, 0.0] if i % 2 == 0 else [0.0, 1.0, 0.0]}
|
||||
for i, _ in enumerate(texts)]
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
def _collect_sse(raw: bytes) -> list[dict]:
|
||||
"""Parse SSE stream bytes into a list of decoded event dicts."""
|
||||
events = []
|
||||
for line in raw.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_run_single_model_returns_result_and_done(client, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk 1", "chunk 2"])):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk 1", "chunk 2"],
|
||||
"queries": ["what is chunk one?"],
|
||||
"models": ["nomic-embed-text"],
|
||||
"top_k": 2,
|
||||
})
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _collect_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "result" in types
|
||||
assert types[-1] == "done"
|
||||
result_events = [e for e in events if e["type"] == "result"]
|
||||
assert result_events[0]["model"] == "nomic-embed-text"
|
||||
assert result_events[0]["query_idx"] == 0
|
||||
assert len(result_events[0]["hits"]) <= 2
|
||||
|
||||
|
||||
def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
|
||||
|
||||
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk A", "chunk B"])):
|
||||
r = client.post("/api/embed-bench/run", json={
|
||||
"corpus": ["chunk A", "chunk B"],
|
||||
"queries": ["find it"],
|
||||
"models": ["nomic-embed-text", "mxbai-embed-large"],
|
||||
"top_k": 2,
|
||||
})
|
||||
|
||||
events = _collect_sse(r.content)
|
||||
result_events = [e for e in events if e["type"] == "result"]
|
||||
models_seen = {e["model"] for e in result_events}
|
||||
assert "nomic-embed-text" in models_seen
|
||||
assert "mxbai-embed-large" in models_seen
|
||||
|
||||
|
||||
# ── rate + export ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_rate_appends_jsonl_line(client, tmp_path):
|
||||
r = client.post("/api/embed-bench/rate", json={
|
||||
"query": "test query",
|
||||
"model": "nomic-embed-text",
|
||||
"chunk_text": "some text",
|
||||
"chunk_idx": 2,
|
||||
"rating": "relevant",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
ratings_file = tmp_path / "embed_bench_ratings.jsonl"
|
||||
assert ratings_file.exists()
|
||||
line = json.loads(ratings_file.read_text().strip())
|
||||
assert line["query"] == "test query"
|
||||
assert line["rating"] == "relevant"
|
||||
assert line["chunk_idx"] == 2
|
||||
assert "timestamp" in line
|
||||
|
||||
|
||||
def test_export_csv_two_rows(client, tmp_path):
|
||||
for i in range(2):
|
||||
client.post("/api/embed-bench/rate", json={
|
||||
"query": f"q{i}", "model": "nomic-embed-text",
|
||||
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "relevant",
|
||||
})
|
||||
r = client.get("/api/embed-bench/export?format=csv")
|
||||
assert r.status_code == 200
|
||||
assert "text/csv" in r.headers["content-type"]
|
||||
lines = r.text.strip().splitlines()
|
||||
assert len(lines) == 3 # header + 2 rows
|
||||
assert "query" in lines[0]
|
||||
|
||||
|
||||
def test_export_json_two_entries(client, tmp_path):
|
||||
for i in range(2):
|
||||
client.post("/api/embed-bench/rate", json={
|
||||
"query": f"q{i}", "model": "nomic-embed-text",
|
||||
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "not_relevant",
|
||||
})
|
||||
r = client.get("/api/embed-bench/export?format=json")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
assert data[0]["rating"] == "not_relevant"
|
||||
|
||||
|
||||
def test_export_empty_returns_csv_header_only(client):
|
||||
r = client.get("/api/embed-bench/export?format=csv")
|
||||
assert r.status_code == 200
|
||||
lines = r.text.strip().splitlines()
|
||||
assert len(lines) == 1 # header only
|
||||
assert "query" in lines[0]
|
||||
|
|
@@ -321,7 +321,6 @@ def test_load_and_prepare_data_single_path_still_works(tmp_path):
|
|||
|
||||
# ---- Integration test ----
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_integration_finetune_on_example_data(tmp_path):
|
||||
"""Fine-tune deberta-small on example data for 1 epoch.
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,242 +0,0 @@
|
|||
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api import app
|
||||
from app.data import imitate as _imitate_module
|
||||
|
||||
|
||||
# -- Fixtures ------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_globals(tmp_path):
|
||||
"""Reset module-level config + data dir globals after each test."""
|
||||
orig_cfg = _imitate_module._CONFIG_DIR
|
||||
orig_data = _imitate_module._DATA_DIR
|
||||
yield
|
||||
_imitate_module._CONFIG_DIR = orig_cfg
|
||||
_imitate_module._DATA_DIR = orig_data
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config_dir(tmp_path) -> Path:
|
||||
_imitate_module.set_config_dir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def data_dir(tmp_path) -> Path:
|
||||
_imitate_module.set_data_dir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def cfg_with_products(config_dir: Path) -> Path:
|
||||
"""Write a label_tool.yaml with two products."""
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
"""
|
||||
imitate:
|
||||
ollama_url: http://localhost:11434
|
||||
products:
|
||||
- id: peregrine
|
||||
name: Peregrine
|
||||
icon: "🦅"
|
||||
description: Job search assistant
|
||||
base_url: http://peregrine.local
|
||||
sample_endpoint: /api/jobs
|
||||
text_fields: [title, description]
|
||||
prompt_template: "Analyze: {text}"
|
||||
- id: kiwi
|
||||
name: Kiwi
|
||||
icon: "🥝"
|
||||
description: Pantry tracker
|
||||
base_url: http://kiwi.local
|
||||
sample_endpoint: /api/inventory
|
||||
text_fields: [name, notes]
|
||||
prompt_template: "Describe: {text}"
|
||||
"""
|
||||
)
|
||||
return config_dir
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client() -> TestClient:
|
||||
return TestClient(app, raise_server_exceptions=True)
|
||||
|
||||
|
||||
# -- GET /products -------------------------------------------------------------
|
||||
|
||||
def test_products_empty_when_no_config(config_dir, client):
|
||||
"""Returns empty list when label_tool.yaml has no imitate section."""
|
||||
(config_dir / "label_tool.yaml").write_text("accounts: []\n")
|
||||
resp = client.get("/api/imitate/products")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["products"] == []
|
||||
|
||||
|
||||
def test_products_listed(cfg_with_products, client):
|
||||
"""All configured products are returned with expected fields."""
|
||||
with patch.object(_imitate_module, "_is_online", return_value=True):
|
||||
resp = client.get("/api/imitate/products")
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()["products"]
|
||||
assert len(products) == 2
|
||||
ids = {p["id"] for p in products}
|
||||
assert ids == {"peregrine", "kiwi"}
|
||||
peregrine = next(p for p in products if p["id"] == "peregrine")
|
||||
assert peregrine["name"] == "Peregrine"
|
||||
assert peregrine["icon"] == "🦅"
|
||||
assert peregrine["online"] is True
|
||||
|
||||
|
||||
def test_products_offline_when_unreachable(cfg_with_products, client):
|
||||
"""Products with unreachable base_url are marked offline."""
|
||||
with patch.object(_imitate_module, "_is_online", return_value=False):
|
||||
resp = client.get("/api/imitate/products")
|
||||
assert all(not p["online"] for p in resp.json()["products"])
|
||||
|
||||
|
||||
# -- GET /products/{id}/sample -------------------------------------------------
|
||||
|
||||
def test_sample_unknown_product(cfg_with_products, client):
|
||||
"""Returns 404 for a product id not in config."""
|
||||
resp = client.get("/api/imitate/products/nonexistent/sample")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_sample_fetched_from_list(cfg_with_products, client):
|
||||
"""Extracts first item from a list API response."""
|
||||
fake_api = [
|
||||
{"title": "Engineer", "description": "Build things"},
|
||||
{"title": "Other", "description": "Ignore me"},
|
||||
]
|
||||
with patch.object(_imitate_module, "_http_get_json", return_value=fake_api):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "Engineer" in body["text"]
|
||||
assert "Build things" in body["text"]
|
||||
assert "Analyze:" in body["prompt"]
|
||||
|
||||
|
||||
def test_sample_fetched_from_dict_with_items_key(cfg_with_products, client):
|
||||
"""Extracts from a wrapper dict with a recognised list key."""
|
||||
fake_api = {"items": [{"title": "Wrapped Job", "description": "In a wrapper"}]}
|
||||
with patch.object(_imitate_module, "_http_get_json", return_value=fake_api):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 200
|
||||
assert "Wrapped Job" in resp.json()["text"]
|
||||
|
||||
|
||||
def test_sample_503_when_api_unreachable(cfg_with_products, client):
|
||||
"""Returns 503 when the product API is not reachable."""
|
||||
from urllib.error import URLError
|
||||
with patch.object(_imitate_module, "_http_get_json", side_effect=URLError("refused")):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
def test_sample_404_on_empty_list(cfg_with_products, client):
|
||||
"""Returns 404 when product API returns an empty list."""
|
||||
with patch.object(_imitate_module, "_http_get_json", return_value=[]):
|
||||
resp = client.get("/api/imitate/products/peregrine/sample")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# -- POST /push-corrections ----------------------------------------------------
|
||||
|
||||
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
|
||||
"""Successful push writes records to sft_candidates.jsonl."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": "Analyze this job:",
|
||||
"results": [
|
||||
{"model": "qwen2.5:0.5b", "response": "It's a good job.", "elapsed_ms": 800, "error": None},
|
||||
{"model": "llama3.1:8b", "response": "Strong candidate.", "elapsed_ms": 1500, "error": None},
|
||||
],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["pushed"] == 2
|
||||
|
||||
candidates = (data_dir / "sft_candidates.jsonl").read_text().splitlines()
|
||||
assert len(candidates) == 2
|
||||
for line in candidates:
|
||||
record = json.loads(line)
|
||||
assert record["source"] == "imitate"
|
||||
assert record["product_id"] == "peregrine"
|
||||
assert record["status"] == "pending"
|
||||
assert record["prompt_messages"][0]["role"] == "user"
|
||||
|
||||
|
||||
def test_push_corrections_skips_errors(cfg_with_products, data_dir, client):
|
||||
"""Results with errors are not written to the corrections file."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": "Analyze:",
|
||||
"results": [
|
||||
{"model": "good-model", "response": "Good answer.", "elapsed_ms": 500, "error": None},
|
||||
{"model": "bad-model", "response": "", "elapsed_ms": 0, "error": "connection refused"},
|
||||
],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["pushed"] == 1
|
||||
|
||||
|
||||
def test_push_corrections_empty_prompt_422(cfg_with_products, data_dir, client):
|
||||
"""Empty prompt returns 422."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": " ",
|
||||
"results": [{"model": "m", "response": "r", "elapsed_ms": 1, "error": None}],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
|
||||
"""422 when every result has an error (nothing to push)."""
|
||||
payload = {
|
||||
"product_id": "peregrine",
|
||||
"prompt": "Analyze:",
|
||||
"results": [
|
||||
{"model": "m", "response": "", "elapsed_ms": 0, "error": "timed out"},
|
||||
],
|
||||
}
|
||||
resp = client.post("/api/imitate/push-corrections", json=payload)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# -- _extract_sample helper ----------------------------------------------------
|
||||
|
def test_extract_sample_list():
    result = _imitate_module._extract_sample(
        [{"title": "A", "description": "B"}],
        text_fields=["title", "description"],
    )
    assert "A" in result["text"]
    assert "B" in result["text"]


def test_extract_sample_empty_list():
    result = _imitate_module._extract_sample([], text_fields=["title"])
    assert result == {}


def test_extract_sample_respects_index():
    items = [{"title": "First"}, {"title": "Second"}]
    result = _imitate_module._extract_sample(items, ["title"], sample_index=1)
    assert "Second" in result["text"]


def test_extract_sample_clamps_index():
    items = [{"title": "Only"}]
    result = _imitate_module._extract_sample(items, ["title"], sample_index=99)
    assert "Only" in result["text"]
||||
|
|
@@ -5,83 +5,83 @@ These functions are stdlib-only and safe to test without an IMAP connection.
|
|||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
from app.utils import extract_body, strip_html
|
||||
from app.label_tool import _extract_body, _strip_html
|
||||
|
||||
|
||||
# ── strip_html ──────────────────────────────────────────────────────────────
|
||||
# ── _strip_html ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_strip_html_removes_tags():
|
||||
assert strip_html("<p>Hello <b>world</b></p>") == "Hello world"
|
||||
assert _strip_html("<p>Hello <b>world</b></p>") == "Hello world"
|
||||
|
||||
|
||||
def test_strip_html_skips_script_content():
|
||||
result = strip_html("<script>doEvil()</script><p>real</p>")
|
||||
result = _strip_html("<script>doEvil()</script><p>real</p>")
|
||||
assert "doEvil" not in result
|
||||
assert "real" in result
|
||||
|
||||
|
||||
def test_strip_html_skips_style_content():
|
||||
result = strip_html("<style>.foo{color:red}</style><p>visible</p>")
|
||||
result = _strip_html("<style>.foo{color:red}</style><p>visible</p>")
|
||||
assert ".foo" not in result
|
||||
assert "visible" in result
|
||||
|
||||
|
||||
def test_strip_html_handles_br_as_newline():
|
||||
result = strip_html("line1<br>line2")
|
||||
result = _strip_html("line1<br>line2")
|
||||
assert "line1" in result
|
||||
assert "line2" in result
|
||||
|
||||
|
def test_strip_html_decodes_entities():
    # convert_charrefs=True on HTMLParser handles &amp; etc.
    result = strip_html("<p>Hello &amp; welcome</p>")
    result = _strip_html("<p>Hello &amp; welcome</p>")
    assert "&amp;" not in result
    assert "Hello" in result
    assert "welcome" in result
||||
|
||||
|
||||
def test_strip_html_empty_string():
|
||||
assert strip_html("") == ""
|
||||
assert _strip_html("") == ""
|
||||
|
||||
|
||||
def test_strip_html_plain_text_passthrough():
|
||||
assert strip_html("no tags here") == "no tags here"
|
||||
assert _strip_html("no tags here") == "no tags here"
|
||||
|
||||
|
||||
# ── extract_body ────────────────────────────────────────────────────────────
|
||||
# ── _extract_body ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_extract_body_prefers_plain_over_html():
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg.attach(MIMEText("plain body", "plain"))
|
||||
msg.attach(MIMEText("<html><body>html body</body></html>", "html"))
|
||||
assert extract_body(msg) == "plain body"
|
||||
assert _extract_body(msg) == "plain body"
|
||||
|
||||
|
||||
def test_extract_body_falls_back_to_html_when_no_plain():
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg.attach(MIMEText("<html><body><p>HTML only email</p></body></html>", "html"))
|
||||
result = extract_body(msg)
|
||||
result = _extract_body(msg)
|
||||
assert "HTML only email" in result
|
||||
assert "<" not in result # no raw HTML tags leaked through
|
||||
|
||||
|
||||
def test_extract_body_non_multipart_html_stripped():
|
||||
msg = MIMEText("<html><body><p>Solo HTML</p></body></html>", "html")
|
||||
result = extract_body(msg)
|
||||
result = _extract_body(msg)
|
||||
assert "Solo HTML" in result
|
||||
assert "<html>" not in result
|
||||
|
||||
|
||||
def test_extract_body_non_multipart_plain_unchanged():
|
||||
msg = MIMEText("just plain text", "plain")
|
||||
assert extract_body(msg) == "just plain text"
|
||||
assert _extract_body(msg) == "just plain text"
|
||||
|
||||
|
||||
def test_extract_body_empty_message():
|
||||
msg = MIMEText("", "plain")
|
||||
assert extract_body(msg) == ""
|
||||
assert _extract_body(msg) == ""
|
||||
|
||||
|
||||
def test_extract_body_multipart_empty_returns_empty():
|
||||
msg = MIMEMultipart("alternative")
|
||||
assert extract_body(msg) == ""
|
||||
assert _extract_body(msg) == ""
|
||||
|
|
|
|||
|
|
@@ -1,454 +0,0 @@
|
|||
"""Tests for app/data/log_corpus.py — corpus receiver and labeling endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.data import log_corpus as lc
|
||||
|
||||
|
||||
VALID_TOKEN = str(uuid.uuid4())
|
||||
VALID_HOST = "testnode.local"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path, monkeypatch):
|
||||
"""Each test gets its own fresh corpus DB and config dir."""
|
||||
monkeypatch.setattr(lc, "_DATA_DIR", tmp_path)
|
||||
monkeypatch.setattr(lc, "_DB_PATH", tmp_path / "corpus.db")
|
||||
# Config dir pointing to a temp yaml with one test source
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n sources:\n"
|
||||
f" - token: \"{VALID_TOKEN}\"\n"
|
||||
f" source_host: \"{VALID_HOST}\"\n"
|
||||
f" owner: TestOwner\n"
|
||||
f" consent_date: \"2026-05-11\"\n"
|
||||
f" consent_method: signal_chat\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
lc._init_db()
|
||||
|
||||
|
@pytest.fixture()
def client():
    from fastapi import FastAPI
    app = FastAPI()
    app.include_router(lc.router, prefix="/api/corpus")
    return TestClient(app)
||||
|
||||
|
def _batch(batch_type="raw_entries", entries=None, source_host=VALID_HOST):
    return {
        "batch_version": 1,
        "batch_id": str(uuid.uuid4()),
        "pushed_at": "2026-05-11T10:00:00Z",
        "source_host": source_host,
        "batch_type": batch_type,
        "watermark_from": 0,
        "watermark_to": 5,
        "entries": entries or [
            {
                "entry_id": str(uuid.uuid4()),
                "source_id": "sonarr",
                "timestamp_iso": "2026-05-11T09:58:00Z",
                "severity": "ERROR",
                "text": "Connection refused to indexer",
                "matched_patterns": [],
            }
        ],
    }
||||
|
||||
|
||||
# ── Receive endpoint ───────────────────────────────────────────────────────────
|
||||
|
def test_receive_missing_auth(client):
    resp = client.post("/api/corpus/log-batch", json=_batch())
    assert resp.status_code == 401


def test_receive_invalid_token(client):
    resp = client.post(
        "/api/corpus/log-batch",
        json=_batch(),
        headers={"Authorization": "Bearer bad-token"},
    )
    assert resp.status_code == 403


def test_receive_valid_batch(client):
    resp = client.post(
        "/api/corpus/log-batch",
        json=_batch(),
        headers={"Authorization": f"Bearer {VALID_TOKEN}"},
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["received"] is True
    assert data["entries_stored"] == 1


def test_receive_stores_source_host_from_token_not_payload(client):
    """source_host is always taken from the DB lookup, not the payload."""
    payload = _batch(source_host="attacker-injected-host")
    resp = client.post(
        "/api/corpus/log-batch",
        json=payload,
        headers={"Authorization": f"Bearer {VALID_TOKEN}"},
    )
    assert resp.status_code == 200
    entries_resp = client.get("/api/corpus/entries")
    entry = entries_resp.json()["entries"][0]
    assert entry["source_host"] == VALID_HOST
||||
|
||||
|
||||
def test_receive_skips_empty_text_entries(client):
|
||||
payload = _batch(entries=[
|
||||
{"entry_id": "e1", "source_id": "svc", "severity": "ERROR", "text": ""},
|
||||
{"entry_id": "e2", "source_id": "svc", "severity": "ERROR", "text": " "},
|
||||
{"entry_id": "e3", "source_id": "svc", "severity": "ERROR", "text": "real error"},
|
||||
])
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.json()["entries_stored"] == 1
|
||||
|
||||
|
||||
def test_receive_incident_bundle(client):
|
||||
payload = _batch(batch_type="incident_bundles", entries=[
|
||||
{"id": "inc-1", "label": "plex crash", "issue_type": "plex",
|
||||
"started_at": "2026-05-11T09:00:00", "ended_at": "2026-05-11T09:30:00",
|
||||
"notes": "audio dropped", "created_at": "2026-05-11T09:35:00",
|
||||
"severity": "high", "text": "plex crash"},
|
||||
])
|
||||
resp = client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries_stored"] == 1
|
||||
|
||||
|
||||
# ── Labeling endpoints ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_label_entry(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
|
||||
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Sonarr lost connection to its indexer — restart the service.",
|
||||
"known_pattern": "y",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["labeled"] is True
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"state": "labeled"}).json()["entries"]
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["failure_type"] == "software"
|
||||
|
||||
|
def test_label_entry_invalid_failure_type(client):
    client.post(
        "/api/corpus/log-batch",
        json=_batch(),
        headers={"Authorization": f"Bearer {VALID_TOKEN}"},
    )
    entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
    resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={"failure_type": "aliens"})
    assert resp.status_code == 422


def test_label_entry_missing_failure_type(client):
    client.post(
        "/api/corpus/log-batch",
        json=_batch(),
        headers={"Authorization": f"Bearer {VALID_TOKEN}"},
    )
    entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
    resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={})
    assert resp.status_code == 422


def test_label_entry_not_found(client):
    resp = client.post("/api/corpus/entries/nonexistent/label", json={"failure_type": "software"})
    assert resp.status_code == 404


def test_skip_entry(client):
    client.post(
        "/api/corpus/log-batch",
        json=_batch(),
        headers={"Authorization": f"Bearer {VALID_TOKEN}"},
    )
    entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
    resp = client.post(f"/api/corpus/entries/{entry_id}/skip")
    assert resp.status_code == 200

    unlabeled = client.get("/api/corpus/entries").json()["entries"]
    assert len(unlabeled) == 0
||||
|
||||
|
||||
# ── Stats ──────────────────────────────────────────────────────────────────────
|
||||
|
def test_stats_empty(client):
    stats = client.get("/api/corpus/stats").json()
    assert stats["total_entries"] == 0
    assert stats["batch_count"] == 0


def test_stats_after_receive(client):
    client.post(
        "/api/corpus/log-batch",
        json=_batch(),
        headers={"Authorization": f"Bearer {VALID_TOKEN}"},
    )
    stats = client.get("/api/corpus/stats").json()
    assert stats["total_entries"] == 1
    assert stats["batch_count"] == 1
    assert stats["by_label_state"].get("unlabeled", 0) == 1
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_excludes_unlabeled(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_includes_labeled(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Sonarr lost connection to indexer.",
|
||||
})
|
||||
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.status_code == 200
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
assert len(lines) == 1
|
||||
record = json.loads(lines[0])
|
||||
assert record["output"] == "Sonarr lost connection to indexer."
|
||||
assert record["metadata"]["failure_type"] == "software"
|
||||
|
||||
|
||||
def test_export_excludes_pii_flagged(client):
|
||||
client.post(
|
||||
"/api/corpus/log-batch",
|
||||
json=_batch(),
|
||||
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
|
||||
)
|
||||
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
|
||||
client.post(f"/api/corpus/entries/{entry_id}/label", json={
|
||||
"failure_type": "software",
|
||||
"plain_explanation": "Contains username — should not export.",
|
||||
"pii_flagged": True,
|
||||
})
|
||||
|
||||
resp = client.get("/api/corpus/export")
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
# ── Pipeline ingest endpoint ───────────────────────────────────────────────────
|
||||
|
||||
def _make_pipeline_file(directory: Path, name: str, lines: list[dict]) -> Path:
|
||||
"""Write a JSONL pipeline log file to directory."""
|
||||
p = directory / name
|
||||
p.write_text("\n".join(json.dumps(l) for l in lines), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
_PIPELINE_LINE = {
|
||||
"ts": "2026-05-17T10:00:00Z",
|
||||
"level": "INFO",
|
||||
"logger": "scripts.pipeline.purple_carrot_scraper",
|
||||
"msg": "Fetched recipe page",
|
||||
"extra": {"url": "https://example.com/recipe/1", "status": 200},
|
||||
}
|
||||
|
||||
|
||||
def test_pipeline_ingest_returns_404_when_dir_not_configured(client, tmp_path):
|
||||
"""No pipeline_ingest_dir in config — endpoint returns 404."""
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_pipeline_ingest_empty_dir(client, tmp_path, monkeypatch):
|
||||
"""Configured dir exists but is empty — returns zeros, no error."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ingested_files"] == 0
|
||||
assert data["skipped_files"] == 0
|
||||
assert data["entries_stored"] == 0
|
||||
|
||||
|
||||
def test_pipeline_ingest_ingests_valid_file(client, tmp_path, monkeypatch):
|
||||
"""Valid JSONL file is ingested; entries appear in corpus."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [
|
||||
_PIPELINE_LINE,
|
||||
{**_PIPELINE_LINE, "msg": "Saved 3 recipes", "level": "INFO"},
|
||||
])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ingested_files"] == 1
|
||||
assert data["entries_stored"] == 2
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert len(entries) == 2
|
||||
assert all(e["source_host"] == "pipeline_scrape" for e in entries)
|
||||
|
||||
|
||||
def test_pipeline_ingest_source_id_from_logger(client, tmp_path, monkeypatch):
|
||||
"""source_id is populated from the 'logger' field of each log line."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "run_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest")
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert entries[0]["source_id"] == "scripts.pipeline.purple_carrot_scraper"
|
||||
|
||||
|
||||
def test_pipeline_ingest_idempotent(client, tmp_path, monkeypatch):
|
||||
"""Calling the endpoint twice does not re-ingest already-processed files."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest")
|
||||
resp2 = client.post("/api/corpus/pipeline-ingest")
|
||||
|
||||
data = resp2.json()
|
||||
assert data["ingested_files"] == 0
|
||||
assert data["skipped_files"] == 1
|
||||
assert data["entries_stored"] == 0
|
||||
|
||||
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
|
||||
assert len(entries) == 1 # still just the one from the first ingest
|
||||
|
||||
|
||||
def test_pipeline_ingest_skips_non_jsonl(client, tmp_path, monkeypatch):
|
||||
"""Non-.jsonl files in the dir are silently ignored."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
(ingest_dir / "notes.txt").write_text("this is not a log file")
|
||||
(ingest_dir / "run.csv").write_text("a,b,c\n1,2,3")
|
||||
_make_pipeline_file(ingest_dir, "valid_20260517.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.json()["ingested_files"] == 1
|
||||
|
||||
|
||||
def test_pipeline_ingest_skips_malformed_lines(client, tmp_path, monkeypatch):
|
||||
"""Lines that are not valid JSON are skipped; valid lines in the same file still land."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
p = ingest_dir / "mixed_20260517.jsonl"
|
||||
p.write_text(
|
||||
json.dumps(_PIPELINE_LINE) + "\n"
|
||||
"this is not json\n"
|
||||
+ json.dumps({**_PIPELINE_LINE, "msg": "another valid line"}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
resp = client.post("/api/corpus/pipeline-ingest")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries_stored"] == 2 # 2 valid lines, 1 skipped
|
||||
|
||||
|
||||
def test_pipeline_ingest_new_file_after_first_run(client, tmp_path, monkeypatch):
|
||||
"""A new file added after the first ingest is picked up on the next call."""
|
||||
ingest_dir = tmp_path / "pipeline_logs"
|
||||
ingest_dir.mkdir()
|
||||
_make_pipeline_file(ingest_dir, "run_a.jsonl", [_PIPELINE_LINE])
|
||||
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
(config_dir / "label_tool.yaml").write_text(
|
||||
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
|
||||
)
|
||||
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
|
||||
|
||||
client.post("/api/corpus/pipeline-ingest") # ingest run_a.jsonl
|
||||
|
||||
_make_pipeline_file(ingest_dir, "run_b.jsonl", [
|
||||
{**_PIPELINE_LINE, "msg": "Second run line"},
|
||||
])
|
||||
|
||||
resp2 = client.post("/api/corpus/pipeline-ingest")
|
||||
data = resp2.json()
|
||||
assert data["ingested_files"] == 1
|
||||
assert data["skipped_files"] == 1
|
||||
assert data["entries_stored"] == 1
|
||||
|
|
@@ -1,627 +0,0 @@
|
|||
"""Tests for app/models.py — /api/models/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
||||
|
@pytest.fixture(autouse=True)
def reset_models_globals(tmp_path):
    """Redirect module-level dirs to tmp_path and reset download progress."""
    from app import models as models_module

    prev_models = models_module._MODELS_DIR
    prev_cf_text = models_module._CF_TEXT_MODELS_DIR
    prev_queue = models_module._QUEUE_DIR
    prev_progress = dict(models_module._download_progress)

    models_dir = tmp_path / "models"
    queue_dir = tmp_path / "data"
    models_dir.mkdir()
    queue_dir.mkdir()

    models_module.set_models_dir(models_dir)
    models_module.set_cf_text_models_dir(tmp_path / "cf-text-models")
    models_module.set_queue_dir(queue_dir)
    models_module._download_progress = {}

    yield

    models_module.set_models_dir(prev_models)
    models_module.set_cf_text_models_dir(prev_cf_text)
    models_module.set_queue_dir(prev_queue)
    models_module._download_progress = prev_progress
||||
|
||||
|
@pytest.fixture
def client():
    from app.api import app
    return TestClient(app)


def _make_hf_response(repo_id: str = "org/model", pipeline_tag: str = "text-classification") -> dict:
    """Minimal HF API response payload."""
    return {
        "modelId": repo_id,
        "pipeline_tag": pipeline_tag,
        "tags": ["pytorch", pipeline_tag],
        "downloads": 42000,
        "siblings": [
            {"rfilename": "pytorch_model.bin", "size": 500_000_000},
        ],
        "cardData": {"description": "A test model description."},
    }


def _queue_one(client, repo_id: str = "org/model") -> dict:
    """Helper: POST to /queue and return the created entry."""
    r = client.post("/api/models/queue", json={
        "repo_id": repo_id,
        "pipeline_tag": "text-classification",
        "adapter_recommendation": "ZeroShotAdapter",
    })
    assert r.status_code == 201, r.text
    return r.json()
||||
|
||||
|
||||
# ── GET /lookup ────────────────────────────────────────────────────────────────
|
||||
|
def test_lookup_invalid_repo_id_returns_422_no_slash(client):
    """repo_id without a '/' should be rejected with 422."""
    r = client.get("/api/models/lookup", params={"repo_id": "noslash"})
    assert r.status_code == 422


def test_lookup_invalid_repo_id_returns_422_whitespace(client):
    """repo_id containing whitespace should be rejected with 422."""
    r = client.get("/api/models/lookup", params={"repo_id": "org/model name"})
    assert r.status_code == 422


def test_lookup_hf_404_returns_404(client):
    """HF API returning 404 should surface as HTTP 404."""
    mock_resp = MagicMock()
    mock_resp.status_code = 404

    with patch("app.models.httpx.get", return_value=mock_resp):
        r = client.get("/api/models/lookup", params={"repo_id": "org/nonexistent"})

    assert r.status_code == 404


def test_lookup_hf_network_error_returns_502(client):
    """Network error reaching HF API should return 502."""
    import httpx as _httpx

    with patch("app.models.httpx.get", side_effect=_httpx.RequestError("timeout")):
        r = client.get("/api/models/lookup", params={"repo_id": "org/model"})

    assert r.status_code == 502
||||
|
||||
|
||||
def test_lookup_returns_correct_shape(client):
|
||||
"""Successful lookup returns all required fields."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("org/mymodel", "text-classification")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "org/mymodel"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["repo_id"] == "org/mymodel"
|
||||
assert data["pipeline_tag"] == "text-classification"
|
||||
assert data["adapter_recommendation"] == "ZeroShotAdapter"
|
||||
assert data["model_size_bytes"] == 500_000_000
|
||||
assert data["downloads"] == 42000
|
||||
assert data["already_installed"] is False
|
||||
assert data["already_queued"] is False
|
||||
|
||||
|
||||
def test_lookup_unknown_pipeline_tag_returns_null_adapter_and_incompatible(client):
|
||||
"""An unrecognised pipeline_tag yields adapter_recommendation=null and compatible=False."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("org/m", "reinforcement-learning")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "org/m"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["adapter_recommendation"] is None
|
||||
assert data["compatible"] is False
|
||||
assert data["role"] is None
|
||||
assert data["service"] is None
|
||||
assert "CircuitForge model ecosystem" in data["warning"]
|
||||
|
||||
|
||||
def test_lookup_stt_tag_returns_compatible_with_cf_stt_service(client):
|
||||
"""automatic-speech-recognition tag yields compatible=True, service=cf-stt."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("openai/whisper-base", "automatic-speech-recognition")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "openai/whisper-base"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["adapter_recommendation"] is None
|
||||
assert data["role"] == "stt"
|
||||
assert data["service"] == "cf-stt"
|
||||
assert data["warning"] is None
|
||||
|
||||
|
||||
def test_lookup_vision_tag_returns_compatible_with_cf_vision_service(client):
|
||||
"""image-classification tag yields compatible=True, service=cf-vision."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("google/siglip-base", "image-classification")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "google/siglip-base"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["role"] == "vision"
|
||||
assert data["service"] == "cf-vision"
|
||||
|
||||
|
||||
def test_lookup_audio_classification_tag_returns_cf_voice_service(client):
|
||||
"""audio-classification tag yields compatible=True, service=cf-voice."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("org/audio-model", "audio-classification")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "org/audio-model"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["role"] == "classifier"
|
||||
assert data["service"] == "cf-voice"
|
||||
|
||||
|
||||
def test_lookup_embedding_tag_returns_compatible_with_cf_core_service(client):
|
||||
"""feature-extraction tag yields compatible=True, service=cf-core."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("BAAI/bge-small-en", "feature-extraction")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "BAAI/bge-small-en"})
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["compatible"] is True
|
||||
assert data["role"] == "embedding"
|
||||
assert data["service"] == "cf-core"
|
||||
|
||||
|
||||
def test_lookup_already_queued_flag(client):
|
||||
"""already_queued is True when repo_id is in the pending queue."""
|
||||
_queue_one(client, "org/queued-model")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = _make_hf_response("org/queued-model")
|
||||
|
||||
with patch("app.models.httpx.get", return_value=mock_resp):
|
||||
r = client.get("/api/models/lookup", params={"repo_id": "org/queued-model"})
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json()["already_queued"] is True
|
||||
|
||||
|
||||
# ── GET /queue ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_queue_empty_initially(client):
|
||||
r = client.get("/api/models/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_queue_add_and_list(client):
|
||||
"""POST then GET /queue should return the entry."""
|
||||
entry = _queue_one(client, "org/my-model")
|
||||
|
||||
r = client.get("/api/models/queue")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert len(items) == 1
|
||||
assert items[0]["repo_id"] == "org/my-model"
|
||||
assert items[0]["status"] == "pending"
|
||||
assert items[0]["id"] == entry["id"]
|
||||
|
||||
|
||||
def test_queue_add_returns_entry_fields(client):
|
||||
"""POST /queue returns an entry with all expected fields."""
|
||||
entry = _queue_one(client)
|
||||
assert "id" in entry
|
||||
assert "queued_at" in entry
|
||||
assert entry["status"] == "pending"
|
||||
assert entry["pipeline_tag"] == "text-classification"
|
||||
assert entry["adapter_recommendation"] == "ZeroShotAdapter"
|
||||
|
||||
|
||||
def test_queue_preserves_role_and_service(client):
|
||||
"""POST /queue with role/service fields round-trips them through GET /queue."""
|
||||
r = client.post("/api/models/queue", json={
|
||||
"repo_id": "openai/whisper-base",
|
||||
"pipeline_tag": "automatic-speech-recognition",
|
||||
"adapter_recommendation": None,
|
||||
"role": "stt",
|
||||
"service": "cf-stt",
|
||||
})
|
||||
assert r.status_code == 201
|
||||
entry = r.json()
|
||||
assert entry["role"] == "stt"
|
||||
assert entry["service"] == "cf-stt"
|
||||
|
||||
r2 = client.get("/api/models/queue")
|
||||
items = r2.json()
|
||||
assert items[0]["role"] == "stt"
|
||||
assert items[0]["service"] == "cf-stt"
|
||||
|
||||
|
||||
# ── POST /queue — 409 duplicate ────────────────────────────────────────────────
|
||||
|
||||
def test_queue_duplicate_returns_409(client):
|
||||
"""Posting the same repo_id twice should return 409."""
|
||||
_queue_one(client, "org/dup-model")
|
||||
|
||||
r = client.post("/api/models/queue", json={
|
||||
"repo_id": "org/dup-model",
|
||||
"pipeline_tag": "text-classification",
|
||||
"adapter_recommendation": "ZeroShotAdapter",
|
||||
})
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_queue_multiple_different_models(client):
|
||||
"""Multiple distinct repo_ids should all be accepted."""
|
||||
_queue_one(client, "org/model-a")
|
||||
_queue_one(client, "org/model-b")
|
||||
_queue_one(client, "org/model-c")
|
||||
|
||||
r = client.get("/api/models/queue")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()) == 3
|
||||
|
||||
|
||||
# ── DELETE /queue/{id} — dismiss ──────────────────────────────────────────────
|
||||
|
||||
def test_queue_dismiss(client):
|
||||
"""DELETE /queue/{id} sets status=dismissed; entry not returned by GET /queue."""
|
||||
entry = _queue_one(client)
|
||||
entry_id = entry["id"]
|
||||
|
||||
r = client.delete(f"/api/models/queue/{entry_id}")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
|
||||
r2 = client.get("/api/models/queue")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json() == []
|
||||
|
||||
|
||||
def test_queue_dismiss_nonexistent_returns_404(client):
|
||||
"""DELETE /queue/{id} with unknown id returns 404."""
|
||||
r = client.delete("/api/models/queue/does-not-exist")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_queue_dismiss_allows_re_queue(client):
|
||||
"""After dismissal the same repo_id can be queued again."""
|
||||
entry = _queue_one(client, "org/requeue-model")
|
||||
client.delete(f"/api/models/queue/{entry['id']}")
|
||||
|
||||
r = client.post("/api/models/queue", json={
|
||||
"repo_id": "org/requeue-model",
|
||||
"pipeline_tag": None,
|
||||
"adapter_recommendation": None,
|
||||
})
|
||||
assert r.status_code == 201
|
||||
|
||||
|
||||
# ── POST /queue/{id}/approve ───────────────────────────────────────────────────
|
||||
|
||||
def test_approve_nonexistent_returns_404(client):
|
||||
"""Approving an unknown id returns 404."""
|
||||
r = client.post("/api/models/queue/ghost-id/approve")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_approve_non_pending_returns_409(client):
|
||||
"""Approving an entry that is not in 'pending' state returns 409."""
|
||||
from app import models as models_module
|
||||
|
||||
entry = _queue_one(client)
|
||||
# Manually flip status to 'failed'
|
||||
models_module._update_queue_entry(entry["id"], {"status": "failed"})
|
||||
|
||||
r = client.post(f"/api/models/queue/{entry['id']}/approve")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_approve_starts_download_and_returns_ok(client):
    """Approving a pending entry returns {ok: true} and starts a background thread."""
    import time

    entry = _queue_one(client)

    # Patch snapshot_download so the thread doesn't actually hit the network.
    def _fake_snapshot_download(**kwargs):
        pass

    with patch("app.models.snapshot_download", side_effect=_fake_snapshot_download):
        r = client.post(f"/api/models/queue/{entry['id']}/approve")
        assert r.status_code == 200
        assert r.json() == {"ok": True}
        # Give the background thread a moment to complete while snapshot_download is patched
        time.sleep(0.3)

    # Queue entry status should have moved to 'downloading' (or 'ready' if fast);
    # 'failed' is also accepted by the assertion below.
    from app import models as models_module
    updated = models_module._get_queue_entry(entry["id"])
    assert updated is not None, "Queue entry not found — thread may have run after fixture teardown"
    assert updated["status"] in ("downloading", "ready", "failed")
|
||||
|
||||
|
||||
# ── GET /download/stream ───────────────────────────────────────────────────────
|
||||
|
||||
def test_download_stream_idle_when_no_download(client):
|
||||
"""GET /download/stream returns a single idle event when nothing is downloading."""
|
||||
r = client.get("/api/models/download/stream")
|
||||
assert r.status_code == 200
|
||||
# SSE body should contain the idle event
|
||||
assert "idle" in r.text
|
||||
|
||||
|
||||
# ── GET /installed ─────────────────────────────────────────────────────────────
|
||||
|
||||
def test_installed_empty(client):
|
||||
"""GET /installed returns [] when models dir is empty."""
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_installed_detects_downloaded_model(client, tmp_path):
|
||||
"""A subdir with config.json is surfaced as type='downloaded'."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "org--mymodel"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "config.json").write_text(json.dumps({"model_type": "bert"}), encoding="utf-8")
|
||||
(model_dir / "model_info.json").write_text(
|
||||
json.dumps({
|
||||
"repo_id": "org/mymodel",
|
||||
"adapter_recommendation": "ZeroShotAdapter",
|
||||
"role": "classifier",
|
||||
"service": "avocet",
|
||||
}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert len(items) == 1
|
||||
assert items[0]["type"] == "downloaded"
|
||||
assert items[0]["name"] == "org--mymodel"
|
||||
assert items[0]["adapter"] == "ZeroShotAdapter"
|
||||
assert items[0]["model_id"] == "org/mymodel"
|
||||
assert items[0]["role"] == "classifier"
|
||||
assert items[0]["service"] == "avocet"
|
||||
|
||||
|
||||
def test_installed_stt_model_surfaces_role_and_service(client):
|
||||
"""A downloaded STT model's role/service are returned by GET /installed."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "openai--whisper-base"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "config.json").write_text(json.dumps({"model_type": "whisper"}), encoding="utf-8")
|
||||
(model_dir / "model_info.json").write_text(
|
||||
json.dumps({
|
||||
"repo_id": "openai/whisper-base",
|
||||
"adapter_recommendation": None,
|
||||
"role": "stt",
|
||||
"service": "cf-stt",
|
||||
}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert items[0]["role"] == "stt"
|
||||
assert items[0]["service"] == "cf-stt"
|
||||
assert items[0]["adapter"] is None
|
||||
|
||||
|
||||
def test_installed_finetuned_model_defaults_to_avocet_service(client):
|
||||
"""Fine-tuned models with no role/service in training_info default to avocet/classifier."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "my-finetuned-v2"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "training_info.json").write_text(
|
||||
json.dumps({"base_model": "microsoft/deberta-v3-base", "epochs": 3}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert items[0]["role"] == "classifier"
|
||||
assert items[0]["service"] == "avocet"
|
||||
|
||||
|
||||
def test_installed_detects_finetuned_model(client):
|
||||
"""A subdir with training_info.json is surfaced as type='finetuned'."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "my-finetuned"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "training_info.json").write_text(
|
||||
json.dumps({"base_model": "org/base", "epochs": 5}), encoding="utf-8"
|
||||
)
|
||||
|
||||
r = client.get("/api/models/installed")
|
||||
assert r.status_code == 200
|
||||
items = r.json()
|
||||
assert len(items) == 1
|
||||
assert items[0]["type"] == "finetuned"
|
||||
assert items[0]["name"] == "my-finetuned"
|
||||
|
||||
|
||||
# ── DELETE /installed/{name} ───────────────────────────────────────────────────
|
||||
|
||||
def test_delete_installed_removes_directory(client):
|
||||
"""DELETE /installed/{name} removes the directory and returns {ok: true}."""
|
||||
from app import models as models_module
|
||||
|
||||
model_dir = models_module._MODELS_DIR / "org--removeme"
|
||||
model_dir.mkdir()
|
||||
(model_dir / "config.json").write_text("{}", encoding="utf-8")
|
||||
|
||||
r = client.delete("/api/models/installed/org--removeme")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
assert not model_dir.exists()
|
||||
|
||||
|
||||
def test_delete_installed_not_found_returns_404(client):
|
||||
r = client.delete("/api/models/installed/does-not-exist")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_delete_installed_path_traversal_blocked(client):
|
||||
"""DELETE /installed/../../etc must be blocked.
|
||||
Path traversal normalises to a different URL (/api/etc); if web/dist exists
|
||||
the StaticFiles mount intercepts it and returns 405 (GET/HEAD only).
|
||||
"""
|
||||
r = client.delete("/api/models/installed/../../etc")
|
||||
assert r.status_code in (400, 404, 405, 422)
|
||||
|
||||
|
||||
def test_delete_installed_dotdot_name_blocked(client):
|
||||
"""A name containing '..' in any form must be rejected."""
|
||||
r = client.delete("/api/models/installed/..%2F..%2Fetc")
|
||||
assert r.status_code in (400, 404, 405, 422)
|
||||
|
||||
|
||||
def test_delete_installed_name_with_slash_blocked(client):
    """A name containing a literal '/' after URL decoding must be rejected."""
    from fastapi import HTTPException as _HTTPException
    from app.models import delete_installed

    # The router will see the path segment after /installed/; a second '/' would
    # be parsed as a new path segment, so we call the delete handler directly.
    with pytest.raises(Exception) as exc_info:
        delete_installed("org/traversal")

    # When the rejection is an HTTPException, it must carry a client-error status.
    if isinstance(exc_info.value, _HTTPException):
        assert exc_info.value.status_code in (400, 404)
|
||||
|
||||
|
||||
# ── Catalog registration ───────────────────────────────────────────────────────
|
||||
|
||||
_MINIMAL_YAML = """\
services:
  cf-text:
    max_mb: {max_mb}
    catalog:
      existing-model:
        path: /some/path
        vram_mb: 1000
        description: "placeholder"
"""
|
||||
|
||||
|
||||
def _make_node_yaml(tmp_path: Path, max_mb: int = 8192) -> Path:
|
||||
p = tmp_path / "testnode.yaml"
|
||||
p.write_text(_MINIMAL_YAML.format(max_mb=max_mb), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def test_catalog_registration_fp16_no_env_block(tmp_path):
|
||||
"""When model fits at FP16, no env block should be written."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/SmallModel",
|
||||
local_path=tmp_path / "org--SmallModel",
|
||||
vram_mb_fp16=4000,
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert "testnode" in updated
|
||||
content = node_yaml.read_text()
|
||||
# _catalog_key strips org prefix and lowercases: "org/SmallModel" → "smallmodel"
|
||||
assert "smallmodel:" in content
|
||||
assert "CF_TEXT_4BIT" not in content
|
||||
assert "env:" not in content
|
||||
|
||||
|
||||
def test_catalog_registration_needs_4bit_writes_env_block(tmp_path):
|
||||
"""When model only fits at 4-bit, env: CF_TEXT_4BIT: '1' must be written."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/BigModel",
|
||||
local_path=tmp_path / "org--BigModel",
|
||||
vram_mb_fp16=20000, # won't fit at FP16 on 8 GB
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert "testnode" in updated
|
||||
content = node_yaml.read_text()
|
||||
# _catalog_key: "org/BigModel" → "bigmodel"
|
||||
assert "bigmodel:" in content
|
||||
assert "env:" in content
|
||||
assert 'CF_TEXT_4BIT: "1"' in content
|
||||
assert "CF_TEXT_4BIT=1 required" in content # description note
|
||||
|
||||
|
||||
def test_catalog_registration_too_large_skipped(tmp_path):
|
||||
"""Model too large even at 4-bit should not be registered."""
|
||||
from app import models as models_module
|
||||
|
||||
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
|
||||
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
|
||||
updated = models_module._register_in_node_catalogs(
|
||||
repo_id="org/HugeModel",
|
||||
local_path=tmp_path / "org--HugeModel",
|
||||
vram_mb_fp16=80000, # 4-bit ~22 GB, still won't fit on 8 GB
|
||||
role="generator",
|
||||
)
|
||||
|
||||
assert updated == []
|
||||
content = node_yaml.read_text()
|
||||
assert "hugemodel" not in content
|
||||
|
|
@ -1,575 +0,0 @@
|
|||
"""Tests for app/nodes.py — /api/nodes-mgmt/* endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os as _os
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_nodes_globals(tmp_path):
|
||||
"""Redirect _CONFIG_DIR to tmp_path so tests never read the real config."""
|
||||
from app import nodes as nodes_module
|
||||
prev = nodes_module._CONFIG_DIR
|
||||
nodes_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
nodes_module.set_config_dir(prev)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
|
||||
cfg = {"cforch": cforch_cfg}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_profile(profiles_dir: Path, node_id: str, profile: dict) -> None:
|
||||
profiles_dir.mkdir(parents=True, exist_ok=True)
|
||||
(profiles_dir / f"{node_id}.yaml").write_text(yaml.dump(profile), encoding="utf-8")
|
||||
|
||||
|
||||
def test_nodes_module_imports():
|
||||
from app import nodes
|
||||
assert hasattr(nodes, "router")
|
||||
assert hasattr(nodes, "set_config_dir")
|
||||
|
||||
|
||||
def test_list_nodes_returns_empty_when_no_coordinator(client):
|
||||
"""No cforch config — endpoint returns empty list, not 500."""
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
|
||||
|
||||
def _fake_nodes_response(nodes_json: list, services_json: list | None = None):
|
||||
"""Build side_effect list for two httpx.get calls: nodes then services."""
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.raise_for_status = MagicMock()
|
||||
mock_nodes.json.return_value = {"nodes": nodes_json}
|
||||
|
||||
mock_services = MagicMock()
|
||||
mock_services.raise_for_status = MagicMock()
|
||||
mock_services.json.return_value = {"services": services_json or []}
|
||||
|
||||
return [mock_nodes, mock_services]
|
||||
|
||||
|
||||
def test_list_nodes_coordinator_unreachable_returns_empty(client, tmp_path):
|
||||
"""Coordinator unreachable — returns [] with no 500."""
|
||||
import httpx
|
||||
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
|
||||
with patch("httpx.get", side_effect=httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_list_nodes_merges_profile_data(client, tmp_path):
|
||||
"""Profile YAML services_assigned merged with live GPU stats."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": ["cf-text"], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
coord_nodes = [{
|
||||
"node_id": "heimdall", "online": True, "agent_url": "http://10.1.10.71:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
|
||||
"vram_used_mb": 4096, "vram_free_mb": 20480,
|
||||
"temp_c": 42.0, "utilization_pct": 15.0, "compute_cap": 8.6}],
|
||||
}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data) == 1
|
||||
node = data[0]
|
||||
assert node["node_id"] == "heimdall"
|
||||
assert node["profile_loaded"] is True
|
||||
assert node["gpus"][0]["services_assigned"] == ["cf-text"]
|
||||
assert node["gpus"][0]["vram_total_mb"] == 24576
|
||||
assert "cf-text" in node["services_catalog"]
|
||||
|
||||
|
||||
def test_list_nodes_no_profile_returns_profile_loaded_false(client, tmp_path):
|
||||
"""Node with no profile YAML — profile_loaded: false, GPU stats still returned."""
|
||||
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
|
||||
|
||||
coord_nodes = [{
|
||||
"node_id": "sif", "online": True, "agent_url": "http://10.1.10.158:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 5060 Ti", "vram_total_mb": 16384,
|
||||
"vram_used_mb": 0, "vram_free_mb": 16384,
|
||||
"temp_c": None, "utilization_pct": None, "compute_cap": 10.0}],
|
||||
}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
node = data[0]
|
||||
assert node["profile_loaded"] is False
|
||||
assert node["gpus"][0]["card"] == "RTX 5060 Ti"
|
||||
assert node["services_catalog"] == {}
|
||||
|
||||
|
||||
def test_list_nodes_marks_running_services(client, tmp_path):
|
||||
"""services_running populated from coordinator /api/services response."""
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", {
|
||||
"services": {},
|
||||
"nodes": {"heimdall": {"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": ["cf-text"], "role": "p",
|
||||
"card": "RTX 3090", "always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701"}}
|
||||
})
|
||||
|
||||
coord_nodes = [{"node_id": "heimdall", "online": True,
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
|
||||
"vram_used_mb": 8192, "vram_free_mb": 16384,
|
||||
"temp_c": 55.0, "utilization_pct": 80.0, "compute_cap": 8.6}]}]
|
||||
coord_services = [{"service": "cf-text", "node_id": "heimdall", "gpu_id": 0}]
|
||||
|
||||
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes, coord_services)):
|
||||
r = client.get("/api/nodes-mgmt/nodes")
|
||||
|
||||
data = r.json()
|
||||
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
|
||||
|
||||
|
||||
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
|
||||
|
||||
def test_get_profile_returns_parsed_yaml(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
profile = {
|
||||
"services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
|
||||
"nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
|
||||
}
|
||||
_write_profile(profiles_dir, "heimdall", profile)
|
||||
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "services" in data
|
||||
assert "cf-text" in data["services"]
|
||||
|
||||
|
||||
def test_get_profile_404_when_missing(client, tmp_path):
|
||||
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
|
||||
r = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
profiles_dir.mkdir()
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
(profiles_dir / "bad.yaml").write_text("key: [unclosed", encoding="utf-8")
|
||||
|
||||
r = client.get("/api/nodes-mgmt/nodes/bad/profile")
|
||||
assert r.status_code == 500
|
||||
|
||||
|
||||
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
|
||||
|
||||
|
||||
_BASE_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1,
|
||||
"catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3",
|
||||
"description": "", "multi_gpu": False, "env": {}}}},
|
||||
"ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
|
||||
"services": [], "role": "primary", "card": "RTX 3090",
|
||||
"always_on": True}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _setup_profile(tmp_path, profile=None):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {
|
||||
"coordinator_url": "http://fake-coord:7700",
|
||||
"profiles_dir": str(profiles_dir),
|
||||
})
|
||||
_write_profile(profiles_dir, "heimdall", profile or _BASE_PROFILE)
|
||||
return profiles_dir
|
||||
|
||||
|
||||
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
|
||||
profiles_dir = _setup_profile(tmp_path)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 200
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is True
|
||||
|
||||
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
|
||||
assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"]
|
||||
|
||||
|
||||
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
    """YAML must be written to .tmp then renamed — never written directly."""
    profiles_dir = _setup_profile(tmp_path)
    renamed_pairs: list[tuple] = []

    original_replace = _os.replace

    def capture(src, dst):
        renamed_pairs.append((str(src), str(dst)))
        original_replace(src, dst)

    with patch("os.replace", side_effect=capture), \
         patch("httpx.post", return_value=MagicMock(status_code=200)):
        client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["ollama"]},
        )

    assert any(src.endswith(".tmp") for src, dst in renamed_pairs), \
        "Expected atomic write via .tmp rename"
|
||||
|
||||
|
||||
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
|
||||
low_cap_profile = {
|
||||
**_BASE_PROFILE,
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 6.0,
|
||||
"services": [], "role": "p", "card": "GTX 1080",
|
||||
"always_on": False}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
_setup_profile(tmp_path, low_cap_profile)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "compute_cap" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
|
||||
tiny_vram_profile = {
|
||||
**_BASE_PROFILE,
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [{"id": 0, "vram_mb": 512, "compute_cap": 8.6,
|
||||
"services": [], "role": "p", "card": "old",
|
||||
"always_on": False}],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
_setup_profile(tmp_path, tiny_vram_profile)
|
||||
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["cf-text"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
assert "VRAM" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_update_services_unknown_service_returns_422(client, tmp_path):
|
||||
_setup_profile(tmp_path)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["not-a-real-service"]},
|
||||
)
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
|
||||
"""YAML saved but coordinator reload fails — ok: true, reloaded: false."""
|
||||
_setup_profile(tmp_path)
|
||||
|
||||
mock_reload = MagicMock()
|
||||
mock_reload.status_code = 500
|
||||
|
||||
with patch("httpx.post", return_value=mock_reload):
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
|
||||
json={"services": ["ollama"]},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is False
|
||||
|
||||
# ── Ollama endpoints ───────────────────────────────────────────────────────────
|
||||
|
||||
_OLLAMA_PROFILE = {
|
||||
"services": {},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_list_ollama_models_proxies_tags(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_tags = MagicMock()
|
||||
mock_tags.raise_for_status = MagicMock()
|
||||
mock_tags.json.return_value = {
|
||||
"models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
|
||||
}
|
||||
|
||||
with patch("httpx.get", return_value=mock_tags):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["models"]) == 1
|
||||
assert data["models"][0]["name"] == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
|
||||
import httpx as _httpx
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
with patch("httpx.get", side_effect=_httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "error" in data
|
||||
|
||||
|
||||
def test_pull_ollama_model_streams_sse(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"status": "pulling manifest"}',
|
||||
'{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
|
||||
'{"status": "success"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
body = r.text
|
||||
assert 'data: {"status": "pulling manifest"}' in body
|
||||
assert 'data: {"status": "success"}' in body
|
||||
|
||||
|
||||
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "permission denied" in r.text
|
||||
|
||||
|
||||
def test_delete_ollama_model_proxies_delete(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 200
|
||||
mock_del.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
|
||||
|
||||
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 404
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model")
|
||||
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── Deploy model endpoint ──────────────────────────────────────────────────────
|
||||
|
||||
_DEPLOY_PROFILE = {
|
||||
"services": {
|
||||
"cf-text": {
|
||||
"max_mb": 20000,
|
||||
"min_compute_cap": 7.0,
|
||||
"model_base_path": "/devl/Assets/LLM/cf-text/models",
|
||||
"catalog": {},
|
||||
},
|
||||
},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_deploy_model_adds_catalog_entry(client, tmp_path):
|
||||
    """Deploy endpoint should add the model to the service catalog."""
|
||||
    profiles_dir = tmp_path / "profiles"
|
||||
    _write_config(tmp_path, {
|
||||
        "coordinator_url": "http://fake-coord:7700",
|
||||
        "profiles_dir": str(profiles_dir),
|
||||
    })
|
||||
    _write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
    mock_reload = MagicMock()
|
||||
    mock_reload.status_code = 200
|
||||
|
||||
    with patch("httpx.post", return_value=mock_reload):
|
||||
        r = client.post(
|
||||
            "/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
            json={
|
||||
                "model_id": "fdtn-ai--Foundation-Sec-8B-Q4",
|
||||
                "service_type": "cf-text",
|
||||
                "vram_mb": 5180,
|
||||
                "hf_repo": "fdtn-ai/Foundation-Sec-8B-Q4_K_M-GGUF",
|
||||
            },
|
||||
        )
|
||||
|
||||
    assert r.status_code == 200
|
||||
    data = r.json()
|
||||
    assert data["ok"] is True
|
||||
    assert data["reloaded"] is True
|
||||
    assert "fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF" in data["path"]
|
||||
|
||||
    saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
|
||||
    catalog = saved["services"]["cf-text"]["catalog"]
|
||||
    assert "fdtn-ai--Foundation-Sec-8B-Q4" in catalog
|
||||
    entry = catalog["fdtn-ai--Foundation-Sec-8B-Q4"]
|
||||
    assert entry["vram_mb"] == 5180
|
||||
    assert entry["path"].endswith("fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF")
|
||||
|
||||
|
||||
def test_deploy_model_explicit_path_overrides_base(client, tmp_path):
|
||||
    """An explicit path in the request body takes precedence over model_base_path."""
|
||||
    profiles_dir = tmp_path / "profiles"
|
||||
    _write_config(tmp_path, {
|
||||
        "coordinator_url": "http://fake-coord:7700",
|
||||
        "profiles_dir": str(profiles_dir),
|
||||
    })
|
||||
    _write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
    with patch("httpx.post", return_value=MagicMock(status_code=200)):
|
||||
        r = client.post(
|
||||
            "/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
            json={
|
||||
                "model_id": "my-model",
|
||||
                "service_type": "cf-text",
|
||||
                "vram_mb": 8000,
|
||||
                "path": "/custom/path/to/model",
|
||||
            },
|
||||
        )
|
||||
|
||||
    assert r.status_code == 200
|
||||
    assert r.json()["path"] == "/custom/path/to/model"
|
||||
|
||||
|
||||
def test_deploy_model_unknown_service_returns_422(client, tmp_path):
|
||||
    """Service type not in profile → 422."""
|
||||
    profiles_dir = tmp_path / "profiles"
|
||||
    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
    _write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
|
||||
|
||||
    r = client.post(
|
||||
        "/api/nodes-mgmt/nodes/heimdall/models/deploy",
|
||||
        json={"model_id": "x", "service_type": "vllm", "vram_mb": 8000},
|
||||
    )
|
||||
    assert r.status_code == 422
|
||||
    assert "vllm" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_deploy_model_missing_profile_returns_404(client, tmp_path):
|
||||
    _write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
|
||||
    r = client.post(
|
||||
        "/api/nodes-mgmt/nodes/nonexistent/models/deploy",
|
||||
        json={"model_id": "x", "service_type": "cf-text", "vram_mb": 100},
|
||||
    )
|
||||
    assert r.status_code == 404
|
||||
|
|
@@ -1,227 +0,0 @@
|
|||
"""Tests for app/data/recipe_scan.py — recipe scan labeling endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.data import recipe_scan as rs
|
||||
|
||||
|
||||
EXTRACTED = {"title": "Shepherd's Pie", "ingredients": ["lamb", "potato"], "steps": ["brown meat", "mash potato"]}
|
||||
GROUND_TRUTH = {"title": "Shepherd's Pie", "ingredients": ["ground lamb", "mashed potato", "peas"], "steps": ["brown meat", "add veg", "mash potato", "bake"]}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path, monkeypatch):
|
||||
    monkeypatch.setattr(rs, "_DB_PATH", tmp_path / "recipe_scan.db")
|
||||
    rs._init_db()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
    from fastapi import FastAPI
|
||||
    app = FastAPI()
|
||||
    app.include_router(rs.router, prefix="/api/recipe-scan")
|
||||
    return TestClient(app)
|
||||
|
||||
|
||||
def _item(**kwargs) -> dict:
|
||||
    return {
|
||||
        "id": str(uuid.uuid4()),
|
||||
        "image_path": "/Library/Assets/kiwi/scans/pc_test.jpg",
|
||||
        "modality": kwargs.get("modality", "scanner"),
|
||||
        "source": kwargs.get("source", "purple_carrot"),
|
||||
        "extracted": kwargs.get("extracted", EXTRACTED),
|
||||
        "ground_truth": kwargs.get("ground_truth", GROUND_TRUTH),
|
||||
    }
|
||||
|
||||
|
||||
def _import(client, items: list[dict]) -> None:
|
||||
    resp = client.post("/api/recipe-scan/import", json={"items": items})
|
||||
    assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── Import ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_import_stores_items(client):
|
||||
_import(client, [_item()])
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 1
|
||||
assert stats["by_status"]["pending"] == 1
|
||||
|
||||
|
||||
def test_import_rejects_unknown_modality(client):
|
||||
bad = _item()
|
||||
bad["modality"] = "telepathy"
|
||||
resp = client.post("/api/recipe-scan/import", json={"items": [bad]})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_import_is_idempotent(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
_import(client, [item]) # same id — should not duplicate
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 1
|
||||
|
||||
|
||||
def test_import_multiple_items(client):
|
||||
_import(client, [_item(), _item(), _item()])
|
||||
assert client.get("/api/recipe-scan/stats").json()["total"] == 3
|
||||
|
||||
|
||||
# ── Next ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_next_returns_404_when_queue_empty(client):
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_next_returns_pending_item(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == item["id"]
|
||||
assert data["status"] == "pending"
|
||||
assert "extracted" in data
|
||||
assert "ground_truth" in data
|
||||
|
||||
|
||||
def test_next_skips_non_pending(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
resp = client.get("/api/recipe-scan/next")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Approve ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_approve_marks_item_approved(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/approve")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "approved"
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["by_status"]["approved"] == 1
|
||||
|
||||
|
||||
def test_approve_returns_404_for_unknown_id(client):
|
||||
resp = client.post("/api/recipe-scan/items/no-such-id/approve")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Edit ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_edit_stores_corrected_json(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
corrected = {**GROUND_TRUTH, "servings": 4}
|
||||
resp = client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/edit",
|
||||
json={"corrected": corrected},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "edited"
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["by_status"]["edited"] == 1
|
||||
|
||||
|
||||
def test_edit_requires_corrected_field(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/edit", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── Reject ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_reject_marks_item_rejected(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/reject",
|
||||
json={"reason": "OCR completely unreadable"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "rejected"
|
||||
|
||||
|
||||
def test_reject_without_reason_is_valid(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_empty_when_nothing_approved(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_includes_approved_item(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/approve")
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
assert len(lines) == 1
|
||||
pair = json.loads(lines[0])
|
||||
assert pair["id"] == item["id"]
|
||||
assert pair["modality"] == "scanner"
|
||||
assert "messages" in pair
|
||||
assert len(pair["messages"]) == 2
|
||||
assert pair["messages"][0]["role"] == "user"
|
||||
assert pair["messages"][1]["role"] == "assistant"
|
||||
|
||||
|
||||
def test_export_includes_edited_item_with_correction(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
corrected = {**GROUND_TRUTH, "servings": 4}
|
||||
client.post(
|
||||
f"/api/recipe-scan/items/{item['id']}/edit",
|
||||
json={"corrected": corrected},
|
||||
)
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
lines = [l for l in resp.text.strip().splitlines() if l]
|
||||
pair = json.loads(lines[0])
|
||||
assistant_content = json.loads(pair["messages"][1]["content"])
|
||||
assert assistant_content["servings"] == 4
|
||||
|
||||
|
||||
def test_export_excludes_rejected_items(client):
|
||||
item = _item()
|
||||
_import(client, [item])
|
||||
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
|
||||
resp = client.get("/api/recipe-scan/export")
|
||||
assert resp.text.strip() == ""
|
||||
|
||||
|
||||
# ── Stats ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_stats_counts_all_statuses(client):
|
||||
items = [_item(), _item(), _item(), _item()]
|
||||
_import(client, items)
|
||||
client.post(f"/api/recipe-scan/items/{items[0]['id']}/approve")
|
||||
client.post(f"/api/recipe-scan/items/{items[1]['id']}/edit", json={"corrected": GROUND_TRUTH})
|
||||
client.post(f"/api/recipe-scan/items/{items[2]['id']}/reject")
|
||||
stats = client.get("/api/recipe-scan/stats").json()
|
||||
assert stats["total"] == 4
|
||||
assert stats["by_status"]["pending"] == 1
|
||||
assert stats["by_status"]["approved"] == 1
|
||||
assert stats["by_status"]["edited"] == 1
|
||||
assert stats["by_status"]["rejected"] == 1
|
||||
assert stats["export_ready"] == 2 # approved + edited
|
||||
|
|
@@ -1,380 +0,0 @@
|
|||
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_sft_globals(tmp_path):
|
||||
    from app.data import corrections as corr_module
|
||||
    _prev_data = corr_module._DATA_DIR
|
||||
    _prev_cfg = corr_module._CONFIG_DIR
|
||||
    _prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
    corr_module.set_data_dir(tmp_path)
|
||||
    corr_module.set_config_dir(tmp_path)
|
||||
    corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
    yield
|
||||
    corr_module.set_data_dir(_prev_data)
|
||||
    corr_module.set_config_dir(_prev_cfg)
|
||||
    corr_module.set_default_bench_results_dir(_prev_default)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
    from app.api import app
|
||||
    return TestClient(app)
|
||||
|
||||
|
||||
def _make_record(id: str, run_id: str = "2026-04-07-143022") -> dict:
|
||||
    return {
|
||||
        "id": id, "source": "cf-orch-benchmark",
|
||||
        "benchmark_run_id": run_id, "timestamp": "2026-04-07T10:00:00Z",
|
||||
        "status": "needs_review",
|
||||
        "prompt_messages": [
|
||||
            {"role": "system", "content": "You are a helpful assistant."},
|
||||
            {"role": "user", "content": "Write a Python function that adds two numbers."},
|
||||
        ],
|
||||
        "model_response": "def add(a, b): return a - b",
|
||||
        "corrected_response": None,
|
||||
        "quality_score": 0.2, "failure_reason": "pattern_match: 0/2 matched",
|
||||
        "task_id": "code-fn", "task_type": "code",
|
||||
        "task_name": "Code: Write a Python function",
|
||||
        "model_id": "Qwen/Qwen2.5-3B", "model_name": "Qwen2.5-3B",
|
||||
        "node_id": "heimdall", "gpu_id": 0, "tokens_per_sec": 38.4,
|
||||
    }
|
||||
|
||||
|
||||
def _write_run(tmp_path, run_id: str, records: list[dict]) -> Path:
|
||||
    run_dir = tmp_path / "bench_results" / run_id
|
||||
    run_dir.mkdir(parents=True)
|
||||
    sft_path = run_dir / "sft_candidates.jsonl"
|
||||
    sft_path.write_text(
|
||||
        "\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8"
|
||||
    )
|
||||
    return sft_path
|
||||
|
||||
|
||||
def _write_config(tmp_path, bench_results_dir: Path) -> None:
|
||||
    import yaml
|
||||
    cfg = {"sft": {"bench_results_dir": str(bench_results_dir)}}
|
||||
    (tmp_path / "label_tool.yaml").write_text(
|
||||
        yaml.dump(cfg, allow_unicode=True), encoding="utf-8"
|
||||
    )
|
||||
|
||||
|
||||
# -- /api/sft/runs -------------------------------------------------------------
|
||||
|
||||
def test_runs_returns_empty_when_no_config(client):
|
||||
r = client.get("/api/sft/runs")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_runs_returns_available_runs(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
r = client.get("/api/sft/runs")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["run_id"] == "2026-04-07-143022"
|
||||
assert data[0]["candidate_count"] == 2
|
||||
assert data[0]["already_imported"] is False
|
||||
|
||||
|
||||
def test_runs_marks_already_imported(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
from app.data import corrections as sft_module
|
||||
candidates = sft_module._candidates_file()
|
||||
candidates.parent.mkdir(parents=True, exist_ok=True)
|
||||
candidates.write_text(
|
||||
json.dumps(_make_record("a", run_id="2026-04-07-143022")) + "\n",
|
||||
encoding="utf-8"
|
||||
)
|
||||
r = client.get("/api/sft/runs")
|
||||
assert r.json()[0]["already_imported"] is True
|
||||
|
||||
|
||||
# -- /api/sft/import -----------------------------------------------------------
|
||||
|
||||
def test_import_adds_records(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
r = client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"imported": 2, "skipped": 0}
|
||||
|
||||
|
||||
def test_import_is_idempotent(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})
|
||||
r = client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})
|
||||
assert r.json() == {"imported": 0, "skipped": 1}
|
||||
|
||||
|
||||
def test_import_unknown_run_returns_404(client, tmp_path):
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
r = client.post("/api/sft/import", json={"run_id": "nonexistent"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# -- /api/sft/queue ------------------------------------------------------------
|
||||
|
||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
||||
    from app.data import corrections as sft_module
|
||||
    path = sft_module._candidates_file()
|
||||
    path.parent.mkdir(parents=True, exist_ok=True)
|
||||
    path.write_text(
|
||||
        "\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8"
|
||||
    )
|
||||
|
||||
|
||||
def test_queue_returns_needs_review_only(client, tmp_path):
|
||||
    records = [
|
||||
        _make_record("a"),  # needs_review
|
||||
        {**_make_record("b"), "status": "approved"},  # should not appear
|
||||
        {**_make_record("c"), "status": "discarded"},  # should not appear
|
||||
    ]
|
||||
    _populate_candidates(tmp_path, records)
|
||||
    r = client.get("/api/sft/queue")
|
||||
    assert r.status_code == 200
|
||||
    data = r.json()
|
||||
    assert data["total"] == 1
|
||||
    assert len(data["items"]) == 1
|
||||
    assert data["items"][0]["id"] == "a"
|
||||
|
||||
|
||||
def test_queue_pagination(client, tmp_path):
|
||||
records = [_make_record(str(i)) for i in range(25)]
|
||||
_populate_candidates(tmp_path, records)
|
||||
r = client.get("/api/sft/queue?page=1&per_page=10")
|
||||
data = r.json()
|
||||
assert data["total"] == 25
|
||||
assert len(data["items"]) == 10
|
||||
r2 = client.get("/api/sft/queue?page=3&per_page=10")
|
||||
assert len(r2.json()["items"]) == 5
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
r = client.get("/api/sft/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
||||
|
||||
|
||||
# -- /api/sft/submit -----------------------------------------------------------
|
||||
|
||||
def test_submit_correct_sets_approved(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["status"] == "approved"
|
||||
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
||||
|
||||
|
||||
def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["id"] == "a"
|
||||
|
||||
|
||||
def test_submit_discard_sets_discarded(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
||||
|
||||
|
||||
def test_submit_flag_sets_model_rejected(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
||||
|
||||
|
||||
def test_submit_correct_empty_response_returns_422(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct", "corrected_response": " ",
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_submit_correct_null_response_returns_422(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct", "corrected_response": None,
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
def test_submit_unknown_id_returns_404(client, tmp_path):
|
||||
r = client.post("/api/sft/submit", json={"id": "nope", "action": "discard"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_submit_already_approved_returns_409(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [{**_make_record("a"), "status": "approved"}])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_submit_correct_stores_failure_category(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
"failure_category": "style_violation",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] == "style_violation"
|
||||
|
||||
|
||||
def test_submit_correct_null_failure_category(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] is None
|
||||
|
||||
|
||||
def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
"failure_category": "nonsense",
|
||||
})
|
||||
assert r.status_code == 422
|
||||
|
||||
|
||||
# -- /api/sft/undo -------------------------------------------------------------
|
||||
|
||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||
assert r.status_code == 200
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
||||
|
||||
|
||||
def test_undo_removes_approved_from_approved_file(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={
|
||||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
client.post("/api/sft/undo", json={"id": "a"})
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert not any(r["id"] == "a" for r in approved)
|
||||
|
||||
|
||||
def test_undo_already_needs_review_returns_409(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
# -- /api/sft/export -----------------------------------------------------------
|
||||
|
||||
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
approved = {
|
||||
**_make_record("a"),
|
||||
"status": "approved",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
"prompt_messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write a Python add function."},
|
||||
],
|
||||
}
|
||||
write_jsonl(sft_module._approved_file(), [approved])
|
||||
_populate_candidates(tmp_path, [approved])
|
||||
|
||||
r = client.get("/api/sft/export")
|
||||
assert r.status_code == 200
|
||||
assert "application/x-ndjson" in r.headers["content-type"]
|
||||
lines = [l for l in r.text.splitlines() if l.strip()]
|
||||
assert len(lines) == 1
|
||||
record = json.loads(lines[0])
|
||||
assert record["messages"][-1] == {
|
||||
"role": "assistant", "content": "def add(a, b): return a + b"
|
||||
}
|
||||
assert record["messages"][0]["role"] == "system"
|
||||
assert record["messages"][1]["role"] == "user"
|
||||
|
||||
|
||||
def test_export_excludes_non_approved(client, tmp_path):
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
||||
{**_make_record("b"), "status": "needs_review", "corrected_response": None},
|
||||
]
|
||||
write_jsonl(sft_module._approved_file(), records)
|
||||
r = client.get("/api/sft/export")
|
||||
assert r.text.strip() == ""
|
||||
|
||||
|
||||
def test_export_empty_when_no_approved_file(client):
|
||||
r = client.get("/api/sft/export")
|
||||
assert r.status_code == 200
|
||||
assert r.text.strip() == ""
|
||||
|
||||
|
||||
# -- /api/sft/stats ------------------------------------------------------------
|
||||
|
||||
def test_stats_counts_by_status(client, tmp_path):
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
_make_record("a"),
|
||||
{**_make_record("b"), "status": "approved", "corrected_response": "ok"},
|
||||
{**_make_record("c"), "status": "discarded"},
|
||||
{**_make_record("d"), "status": "model_rejected"},
|
||||
]
|
||||
_populate_candidates(tmp_path, records)
|
||||
write_jsonl(sft_module._approved_file(), [records[1]])
|
||||
r = client.get("/api/sft/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 4
|
||||
assert data["by_status"]["needs_review"] == 1
|
||||
assert data["by_status"]["approved"] == 1
|
||||
assert data["by_status"]["discarded"] == 1
|
||||
assert data["by_status"]["model_rejected"] == 1
|
||||
assert data["export_ready"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_data(client):
|
||||
r = client.get("/api/sft/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["export_ready"] == 0
|
||||
|
|
@@ -1,95 +0,0 @@
|
|||
"""Unit tests for scripts/sft_import.py — run discovery and JSONL deduplication."""
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _write_candidates(path: Path, records: list[dict]) -> None:
|
||||
    path.parent.mkdir(parents=True, exist_ok=True)
|
||||
    path.write_text("\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _make_record(id: str, run_id: str = "run1") -> dict:
|
||||
    return {
|
||||
        "id": id, "source": "cf-orch-benchmark",
|
||||
        "benchmark_run_id": run_id, "timestamp": "2026-04-07T10:00:00Z",
|
||||
        "status": "needs_review", "prompt_messages": [],
|
||||
        "model_response": "bad", "corrected_response": None,
|
||||
        "quality_score": 0.3, "failure_reason": "missing patterns",
|
||||
        "task_id": "code-fn", "task_type": "code", "task_name": "Code: fn",
|
||||
        "model_id": "Qwen/Qwen2.5-3B", "model_name": "Qwen2.5-3B",
|
||||
        "node_id": "heimdall", "gpu_id": 0, "tokens_per_sec": 38.4,
|
||||
    }
|
||||
|
||||
|
||||
def test_discover_runs_empty_when_dir_missing(tmp_path):
|
||||
from scripts.sft_import import discover_runs
|
||||
result = discover_runs(tmp_path / "nonexistent")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_discover_runs_returns_runs(tmp_path):
|
||||
from scripts.sft_import import discover_runs
|
||||
run_dir = tmp_path / "2026-04-07-143022"
|
||||
_write_candidates(run_dir / "sft_candidates.jsonl", [_make_record("a"), _make_record("b")])
|
||||
result = discover_runs(tmp_path)
|
||||
assert len(result) == 1
|
||||
assert result[0]["run_id"] == "2026-04-07-143022"
|
||||
assert result[0]["candidate_count"] == 2
|
||||
assert "sft_path" in result[0]
|
||||
|
||||
|
||||
def test_discover_runs_skips_dirs_without_sft_file(tmp_path):
|
||||
from scripts.sft_import import discover_runs
|
||||
(tmp_path / "2026-04-07-no-sft").mkdir()
|
||||
result = discover_runs(tmp_path)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_discover_runs_sorted_newest_first(tmp_path):
|
||||
    from scripts.sft_import import discover_runs
|
||||
    for name in ["2026-04-05-120000", "2026-04-07-143022", "2026-04-06-090000"]:
|
||||
        run_dir = tmp_path / name
|
||||
        _write_candidates(run_dir / "sft_candidates.jsonl", [_make_record("x")])
|
||||
    result = discover_runs(tmp_path)
|
||||
    assert [r["run_id"] for r in result] == [
|
||||
        "2026-04-07-143022", "2026-04-06-090000", "2026-04-05-120000"
|
||||
    ]
|
||||
|
||||
|
||||
def test_import_run_imports_new_records(tmp_path):
|
||||
from scripts.sft_import import import_run
|
||||
sft_path = tmp_path / "run1" / "sft_candidates.jsonl"
|
||||
_write_candidates(sft_path, [_make_record("a"), _make_record("b")])
|
||||
result = import_run(sft_path, tmp_path)
|
||||
assert result == {"imported": 2, "skipped": 0}
|
||||
dest = tmp_path / "sft_candidates.jsonl"
|
||||
lines = [json.loads(l) for l in dest.read_text().splitlines() if l.strip()]
|
||||
assert len(lines) == 2
|
||||
|
||||
|
||||
def test_import_run_deduplicates_on_id(tmp_path):
|
||||
from scripts.sft_import import import_run
|
||||
sft_path = tmp_path / "run1" / "sft_candidates.jsonl"
|
||||
_write_candidates(sft_path, [_make_record("a"), _make_record("b")])
|
||||
import_run(sft_path, tmp_path)
|
||||
result = import_run(sft_path, tmp_path) # second import
|
||||
assert result == {"imported": 0, "skipped": 2}
|
||||
dest = tmp_path / "sft_candidates.jsonl"
|
||||
lines = [l for l in dest.read_text().splitlines() if l.strip()]
|
||||
assert len(lines) == 2 # no duplicates
|
||||
|
||||
|
||||
def test_import_run_skips_records_missing_id(tmp_path, caplog):
|
||||
    import logging
|
||||
    from scripts.sft_import import import_run
|
||||
    sft_path = tmp_path / "run1" / "sft_candidates.jsonl"
|
||||
    sft_path.parent.mkdir()
|
||||
    sft_path.write_text(
|
||||
        json.dumps({"model_response": "bad", "status": "needs_review"}) + "\n"
|
||||
        + json.dumps({"id": "abc123", "model_response": "good", "status": "needs_review"}) + "\n"
|
||||
    )
|
||||
    with caplog.at_level(logging.WARNING, logger="scripts.sft_import"):
|
||||
        result = import_run(sft_path, tmp_path)
|
||||
    assert result == {"imported": 1, "skipped": 0}
|
||||
    assert "missing 'id'" in caplog.text
|
||||
|
|
@@ -1,187 +0,0 @@
|
|||
"""Tests for app/train/train.py -- /api/train/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
    from app.train import train as train_module
|
||||
    train_module.set_db_path(tmp_path / "train_jobs.db")
|
||||
    train_module.set_models_dir(tmp_path / "models")
|
||||
    train_module._running_procs.clear()
|
||||
    yield
|
||||
    train_module._running_procs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
    from app.api import app
|
||||
    return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
    events = []
|
||||
    for line in content.decode().splitlines():
|
||||
        if line.startswith("data: "):
|
||||
            events.append(json.loads(line[6:]))
|
||||
    return events
|
||||
|
||||
|
||||
def test_list_jobs_empty(client):
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"jobs": []}
|
||||
|
||||
|
||||
def test_create_job_returns_queued_record(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "classifier", "model_key": "deberta-small",
|
||||
"config_json": {"epochs": 3}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["status"] == "queued"
|
||||
assert data["type"] == "classifier"
|
||||
assert data["model_key"] == "deberta-small"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
def test_create_job_invalid_type_returns_400(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "unknown-type", "model_key": "deberta-small"})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_job_appears_in_list(client):
|
||||
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()["jobs"]) == 1
|
||||
|
||||
|
||||
def test_get_job_returns_record(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["id"] == job_id
|
||||
|
||||
|
||||
def test_get_job_404_for_unknown(client):
|
||||
r = client.get("/api/train/jobs/no-such-id")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_queued_job(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["status"] == "cancelled"
|
||||
r3 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r3.json()["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_cancel_completed_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.delete("/api/train/jobs/abc/cancel")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_cancel_terminates_running_proc(client):
|
||||
from app.train import train as train_module
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
train_module._running_procs[job_id] = mock_proc
|
||||
with train_module._db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
mock_proc.terminate.assert_called_once()
|
||||
|
||||
|
||||
def test_run_job_streams_sse(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}/run")
|
||||
assert r2.status_code == 200
|
||||
assert "text/event-stream" in r2.headers.get("content-type", "")
|
||||
events = _parse_sse(r2.content)
|
||||
assert any(e["type"] == "complete" for e in events)
|
||||
|
||||
|
||||
def test_run_job_marks_completed_in_db(client):
|
||||
from app.train import train as train_module
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "completed"
|
||||
|
||||
|
||||
def test_run_job_marks_failed_on_nonzero_exit(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_nonqueued_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.get("/api/train/jobs/xyz/run")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_run_unknown_job_returns_404(client):
|
||||
r = client.get("/api/train/jobs/no-such/run")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_results_empty_when_no_models_dir(client):
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"results": []}
|
||||
|
||||
|
||||
def test_results_returns_training_info(client, tmp_path):
|
||||
from app.train import train as train_module
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
|
||||
(models_dir / "training_info.json").write_text(json.dumps(info))
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])
|
||||
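The streaming tests above rely on `_parse_sse` to turn a `text/event-stream` body into a list of dicts by decoding every `data: ` line as JSON. The standalone sketch below repeats that logic so it can be run outside the test module. The sample body is hypothetical: only the `"complete"` event type appears in the tests, and the `"log"` event shape is an assumption.

```python
import json


def parse_sse(content: bytes) -> list[dict]:
    """Collect the JSON payload of every `data: ` line in an SSE body."""
    events = []
    for line in content.decode().splitlines():
        # Comment lines (starting with ":") and blank separators are ignored.
        if line.startswith("data: "):
            events.append(json.loads(line[len("data: "):]))
    return events


# Hypothetical stream body; the "log" event shape is assumed for illustration.
body = (
    b'data: {"type": "log", "line": "Epoch 1"}\n\n'
    b": keepalive\n\n"
    b'data: {"type": "complete"}\n\n'
)
events = parse_sse(body)
print([e["type"] for e in events])
```

This handles only single-line `data:` fields, which is all the tests need. A full SSE parser would also join multi-line `data:` fields and honour `event:` and `id:` fields.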
4
web/.gitignore
vendored
|
|
@ -22,7 +22,3 @@ dist-ssr
|
|||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
|
||||
# Local environment overrides
|
||||
.env
|
||||
|
||||
|
|
|
|||
42
web/package-lock.json
generated
|
|
@ -2676,9 +2676,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz",
|
||||
"integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==",
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -2890,9 +2890,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/defu": {
|
||||
"version": "6.1.7",
|
||||
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.7.tgz",
|
||||
"integrity": "sha512-7z22QmUWiQ/2d0KkdYmANbRUVABpZ9SNYyH5vx6PZ+nE5bcC0l7uFvEfHlyld/HcGBFTL536ClDt3DEcSlEJAQ==",
|
||||
"version": "6.1.4",
|
||||
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.4.tgz",
|
||||
"integrity": "sha512-mEQCMmwJu317oSz8CwdIOdwf3xMif1ttiM8LTufzc3g6kR+9Pe236twL8j3IYT1F7GfRgGcW6MWxzZjLIkuHIg==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
|
|
@ -3725,9 +3725,9 @@
|
|||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
|
|
@ -3769,9 +3769,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/postcss": {
|
||||
"version": "8.5.14",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
|
||||
"integrity": "sha512-SoSL4+OSEtR99LHFZQiJLkT59C5B1amGO1NzTwj7TT1qCUgUO6hxOvzkOYxD+vMrXBM3XJIKzokoERdqQq/Zmg==",
|
||||
"version": "8.5.8",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
|
||||
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "opencollective",
|
||||
|
|
@ -4325,9 +4325,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/undici": {
|
||||
"version": "7.25.0",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.25.0.tgz",
|
||||
"integrity": "sha512-xXnp4kTyor2Zq+J1FfPI6Eq3ew5h6Vl0F/8d9XU5zZQf1tX9s2Su1/3PiMmUANFULpmksxkClamIZcaUqryHsQ==",
|
||||
"version": "7.22.0",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.22.0.tgz",
|
||||
"integrity": "sha512-RqslV2Us5BrllB+JeiZnK4peryVTndy9Dnqq62S3yYRRTj0tFQCwEniUy2167skdGOy3vqRzEvl1Dm4sV2ReDg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
|
@ -4422,9 +4422,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/vite": {
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz",
|
||||
"integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==",
|
||||
"version": "7.3.1",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz",
|
||||
"integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -4921,9 +4921,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/yaml": {
|
||||
"version": "2.8.4",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.4.tgz",
|
||||
"integrity": "sha512-ml/JPOj9fOQK8RNnWojA67GbZ0ApXAUlN2UQclwv2eVgTgn7O9gg9o7paZWKMp4g0H3nTLtS9LVzhkpOFIKzog==",
|
||||
"version": "2.8.2",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz",
|
||||
"integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==",
|
||||
"license": "ISC",
|
||||
"bin": {
|
||||
"yaml": "bin.mjs"
|
||||
|
|
|
|||
|
|
@ -1,124 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import AppSidebar from './AppSidebar.vue'
|
||||
|
||||
// Minimal router so RouterLink renders without warnings
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: { template: '<div />' } },
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
{ path: '/data/label', component: { template: '<div />' } },
|
||||
{ path: '/data/fetch', component: { template: '<div />' } },
|
||||
{ path: '/data/corrections', component: { template: '<div />' } },
|
||||
{ path: '/data/imitate', component: { template: '<div />' } },
|
||||
{ path: '/eval/benchmark', component: { template: '<div />' } },
|
||||
{ path: '/eval/compare', component: { template: '<div />' } },
|
||||
{ path: '/train/jobs', component: { template: '<div />' } },
|
||||
{ path: '/train/results', component: { template: '<div />' } },
|
||||
{ path: '/settings', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
function makeFetch(signals: Record<string, boolean> = {}) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
labeled_since_last_eval: 0,
|
||||
last_eval_timestamp: null,
|
||||
last_eval_best_score: null,
|
||||
active_jobs: [],
|
||||
corrections_export_ready: 0,
|
||||
signals,
|
||||
}),
|
||||
text: async () => '',
|
||||
})
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
localStorage.clear()
|
||||
vi.stubGlobal('fetch', makeFetch())
|
||||
})
|
||||
|
||||
describe('AppSidebar structure', () => {
|
||||
it('renders section headers for Data, Eval, Train', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const text = w.text()
|
||||
expect(text).toContain('Data')
|
||||
expect(text).toContain('Eval')
|
||||
expect(text).toContain('Train')
|
||||
})
|
||||
|
||||
it('renders all sub-links', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const anchors = w.findAll('a')
|
||||
const hrefs = anchors.map(a => a.attributes('href') ?? '')
|
||||
expect(hrefs.some(h => h.includes('/data/label'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/fetch'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/corrections'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/data/imitate'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/eval/benchmark'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/eval/compare'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/train/jobs'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/train/results'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/fleet'))).toBe(true)
|
||||
expect(hrefs.some(h => h.includes('/settings'))).toBe(true)
|
||||
})
|
||||
|
||||
it('does NOT render the old /benchmark or /models links', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const anchors = w.findAll('a')
|
||||
const hrefs = anchors.map(a => a.attributes('href') ?? '')
|
||||
// Old paths must not appear as direct links (they're only redirects)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/benchmark'))).toBe(true)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/models'))).toBe(true)
|
||||
expect(hrefs.every(h => !h.endsWith('/#/stats'))).toBe(true)
|
||||
})
|
||||
|
||||
it('shows no signal badges when all signals are false', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.findAll('.signal-badge').length).toBe(0)
|
||||
})
|
||||
|
||||
it('shows signal badge on Data section when data_to_eval is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: true, eval_to_train: false, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const badges = w.findAll('.signal-badge')
|
||||
expect(badges.length).toBe(1)
|
||||
// It should be inside the Data section header
|
||||
const dataHeader = w.find('[data-section="data"]')
|
||||
expect(dataHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows signal badge on Eval section when eval_to_train is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: true, train_to_fleet: false }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalHeader = w.find('[data-section="eval"]')
|
||||
expect(evalHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows signal badge on Train section when train_to_fleet is true', async () => {
|
||||
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: true }))
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainHeader = w.find('[data-section="train"]')
|
||||
expect(trainHeader.find('.signal-badge').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('stow toggle still works', async () => {
|
||||
const w = mount(AppSidebar, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const nav = w.find('nav')
|
||||
expect(nav.classes()).not.toContain('stowed')
|
||||
await w.find('.stow-btn').trigger('click')
|
||||
expect(nav.classes()).toContain('stowed')
|
||||
})
|
||||
})
|
||||
|
|
@ -28,70 +28,12 @@
|
|||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Nav -->
|
||||
<!-- Nav items -->
|
||||
<ul class="nav-list" role="list">
|
||||
<!-- Top-level links -->
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Dashboard' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">📊</span>
|
||||
<span v-if="!stowed" class="nav-label">Dashboard</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/fleet"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Fleet' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚡</span>
|
||||
<span v-if="!stowed" class="nav-label">Fleet</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/nodes"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Nodes' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">🖥️</span>
|
||||
<span v-if="!stowed" class="nav-label">Nodes</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ① Data section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="data" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">① Data</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge"
|
||||
title="Enough new labels to run eval"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">①</span>
|
||||
<span
|
||||
v-if="signals.data_to_eval"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Eval recommended"
|
||||
aria-label="Eval recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in dataItems" :key="item.path">
|
||||
<li v-for="item in navItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
class="nav-item"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
|
|
@ -99,94 +41,10 @@
|
|||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ② Eval section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="eval" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">② Eval</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge"
|
||||
title="Strong eval result — consider finetuning"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">②</span>
|
||||
<span
|
||||
v-if="signals.eval_to_train"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Finetune recommended"
|
||||
aria-label="Finetune recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in evalItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- ③ Train section -->
|
||||
<li>
|
||||
<div class="section-header" data-section="train" aria-hidden="true">
|
||||
<template v-if="!stowed">
|
||||
<span class="section-label">③ Train</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge"
|
||||
title="Trained model ready for fleet registration"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span class="section-icon">③</span>
|
||||
<span
|
||||
v-if="signals.train_to_fleet"
|
||||
class="signal-badge signal-badge-stowed"
|
||||
title="Fleet registration recommended"
|
||||
aria-label="Fleet registration recommended"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</li>
|
||||
<li v-for="item in trainItems" :key="item.path">
|
||||
<RouterLink
|
||||
:to="item.path"
|
||||
class="nav-item nav-subitem"
|
||||
:title="stowed ? item.label : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
|
||||
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
|
||||
<!-- Divider + Settings -->
|
||||
<li class="nav-divider" aria-hidden="true" />
|
||||
<li>
|
||||
<RouterLink
|
||||
to="/settings"
|
||||
class="nav-item"
|
||||
:title="stowed ? 'Settings' : ''"
|
||||
@click="isMobile && stow()"
|
||||
>
|
||||
<span class="nav-icon" aria-hidden="true">⚙️</span>
|
||||
<span v-if="!stowed" class="nav-label">Settings</span>
|
||||
</RouterLink>
|
||||
</li>
|
||||
</ul>
|
||||
</nav>
|
||||
|
||||
<!-- Mobile hamburger button — visible when sidebar is stowed on mobile -->
|
||||
<!-- Mobile hamburger button rendered outside the sidebar so it's visible when stowed -->
|
||||
<button
|
||||
v-if="isMobile && stowed"
|
||||
class="mobile-hamburger"
|
||||
|
|
@ -203,68 +61,22 @@ import { RouterLink } from 'vue-router'
|
|||
|
||||
const LS_KEY = 'cf-avocet-nav-stowed'
|
||||
|
||||
interface NavItem {
|
||||
path: string
|
||||
icon: string
|
||||
label: string
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
const dataItems: NavItem[] = [
|
||||
{ path: '/data/label', icon: '🏷', label: 'Label' },
|
||||
{ path: '/data/fetch', icon: '📬', label: 'Fetch' },
|
||||
{ path: '/data/corrections', icon: '✏️', label: 'Corrections' },
|
||||
{ path: '/data/imitate', icon: '🪞', label: 'Imitate' },
|
||||
{ path: '/data/recipe-scan', icon: '📷', label: 'Recipe Scan' },
|
||||
const navItems = [
|
||||
{ path: '/', icon: '🃏', label: 'Label' },
|
||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||
]
|
||||
|
||||
const evalItems: NavItem[] = [
|
||||
{ path: '/eval/benchmark', icon: '📊', label: 'Benchmark' },
|
||||
{ path: '/eval/compare', icon: '🔍', label: 'Compare' },
|
||||
{ path: '/eval/embed-compare', icon: '🧮', label: 'Embed Compare' },
|
||||
]
|
||||
|
||||
const trainItems: NavItem[] = [
|
||||
{ path: '/train/jobs', icon: '🧠', label: 'Jobs' },
|
||||
{ path: '/train/results', icon: '📈', label: 'Results' },
|
||||
]
|
||||
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
|
||||
const signals = ref<DashboardSignals>({
|
||||
data_to_eval: false,
|
||||
eval_to_train: false,
|
||||
train_to_fleet: false,
|
||||
})
|
||||
|
||||
async function loadSignals() {
|
||||
try {
|
||||
const res = await fetch('/api/dashboard')
|
||||
if (res.ok) {
|
||||
const data = await res.json() as { signals?: DashboardSignals }
|
||||
if (data.signals) {
|
||||
signals.value = {
|
||||
data_to_eval: data.signals.data_to_eval ?? false,
|
||||
eval_to_train: data.signals.eval_to_train ?? false,
|
||||
train_to_fleet: data.signals.train_to_fleet ?? false,
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Non-fatal: badges simply stay hidden if API is unreachable
|
||||
}
|
||||
}
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
const winWidth = ref(window.innerWidth)
|
||||
const isMobile = computed(() => winWidth.value < 640)
|
||||
|
||||
function toggle() {
|
||||
stowed.value = !stowed.value
|
||||
localStorage.setItem(LS_KEY, String(stowed.value))
|
||||
// Update CSS variable on :root so .app-main margin-left syncs
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
}
|
||||
|
||||
|
|
@ -278,12 +90,13 @@ function onResize() { winWidth.value = window.innerWidth }
|
|||
|
||||
onMounted(() => {
|
||||
window.addEventListener('resize', onResize)
|
||||
// Apply persisted sidebar width to :root on mount
|
||||
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
|
||||
// On mobile, default to stowed
|
||||
if (isMobile.value && !localStorage.getItem(LS_KEY)) {
|
||||
stowed.value = true
|
||||
document.documentElement.style.setProperty('--sidebar-width', '56px')
|
||||
}
|
||||
loadSignals()
|
||||
})
|
||||
|
||||
onUnmounted(() => window.removeEventListener('resize', onResize))
|
||||
|
|
@ -305,15 +118,18 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar.stowed { width: 56px; }
|
||||
.sidebar.stowed {
|
||||
width: 56px;
|
||||
}
|
||||
|
||||
/* Mobile: slide in/out from left */
|
||||
.sidebar.mobile {
|
||||
box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.sidebar.mobile.stowed {
|
||||
transform: translateX(-100%);
|
||||
width: 200px;
|
||||
width: 200px; /* keep width so slide-in looks right */
|
||||
transition: transform 250ms ease, width 250ms ease;
|
||||
}
|
||||
|
||||
|
|
@ -346,7 +162,10 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.logo-icon { font-size: 1.25rem; flex-shrink: 0; }
|
||||
.logo-icon {
|
||||
font-size: 1.25rem;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.logo-name {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
|
|
@ -371,76 +190,16 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.stow-btn:hover { background: var(--color-border, #d0d7e8); }
|
||||
.stow-btn:hover {
|
||||
background: var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.nav-list {
|
||||
list-style: none;
|
||||
padding: 0.5rem 0;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
/* ── Section headers ── */
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
padding: 0.55rem 0.75rem 0.25rem;
|
||||
margin-top: 0.5rem;
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.07em;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
white-space: nowrap;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.section-icon {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
width: 24px;
|
||||
text-align: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* ── Signal badges ── */
|
||||
.signal-badge {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: var(--color-warning, #d4891a);
|
||||
flex-shrink: 0;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.signal-badge-stowed {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
}
|
||||
|
||||
/* Make the stowed section header container position:relative for the badge */
|
||||
.sidebar.stowed .section-header {
|
||||
position: relative;
|
||||
justify-content: center;
|
||||
padding: 0.55rem 0 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Nav divider ── */
|
||||
.nav-divider {
|
||||
height: 1px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
margin: 0.5rem 0.75rem;
|
||||
}
|
||||
|
||||
/* ── Nav items ── */
|
||||
.nav-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
|
@ -476,9 +235,6 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
border-radius: 0 2px 2px 0;
|
||||
}
|
||||
|
||||
/* Sub-items are indented slightly in expanded state */
|
||||
.nav-subitem { padding-left: 1.1rem; font-size: 0.875rem; }
|
||||
|
||||
.nav-icon {
|
||||
font-size: 1.1rem;
|
||||
flex-shrink: 0;
|
||||
|
|
@ -486,9 +242,12 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
|
|||
text-align: center;
|
||||
}
|
||||
|
||||
.nav-label { overflow: hidden; text-overflow: ellipsis; }
|
||||
.nav-label {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
/* Mobile hamburger */
|
||||
/* Mobile hamburger — visible when sidebar is stowed on mobile */
|
||||
.mobile-hamburger {
|
||||
position: fixed;
|
||||
top: 0.75rem;
|
||||
|
|
|
|||
|
|
@ -1,179 +0,0 @@
|
|||
import { mount } from '@vue/test-utils'
|
||||
import SftCard from './SftCard.vue'
|
||||
import type { SftQueueItem } from '../stores/sft'
|
||||
import { describe, it, expect } from 'vitest'
|
||||
|
||||
const LOW_QUALITY_ITEM: SftQueueItem = {
|
||||
id: 'abc', source: 'cf-orch-benchmark', benchmark_run_id: 'run1',
|
||||
timestamp: '2026-04-07T10:00:00Z', status: 'needs_review',
|
||||
prompt_messages: [
|
||||
{ role: 'system', content: 'You are a coding assistant.' },
|
||||
{ role: 'user', content: 'Write a Python add function.' },
|
||||
],
|
||||
model_response: 'def add(a, b): return a - b',
|
||||
corrected_response: null, quality_score: 0.2,
|
||||
failure_reason: 'pattern_match: 0/2 matched',
|
||||
failure_category: null,
|
||||
task_id: 'code-fn', task_type: 'code', task_name: 'Code: Write a function',
|
||||
model_id: 'Qwen/Qwen2.5-3B', model_name: 'Qwen2.5-3B',
|
||||
node_id: 'heimdall', gpu_id: 0, tokens_per_sec: 38.4,
|
||||
}
|
||||
|
||||
const MID_QUALITY_ITEM: SftQueueItem = { ...LOW_QUALITY_ITEM, id: 'mid', quality_score: 0.55 }
|
||||
const HIGH_QUALITY_ITEM: SftQueueItem = { ...LOW_QUALITY_ITEM, id: 'hi', quality_score: 0.72 }
|
||||
|
||||
describe('SftCard', () => {
|
||||
it('renders model name chip', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.text()).toContain('Qwen2.5-3B')
|
||||
})
|
||||
|
||||
it('renders task type chip', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.text()).toContain('code')
|
||||
})
|
||||
|
||||
it('renders failure reason', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.text()).toContain('pattern_match: 0/2 matched')
|
||||
})
|
||||
|
||||
it('renders model response', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.text()).toContain('def add(a, b): return a - b')
|
||||
})
|
||||
|
||||
it('quality chip shows numeric value for low quality', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.text()).toContain('0.20')
|
||||
})
|
||||
|
||||
it('quality chip has low-quality class when score < 0.4', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.find('[data-testid="quality-chip"]').classes()).toContain('quality-low')
|
||||
})
|
||||
|
||||
it('quality chip has mid-quality class when score is 0.4 to <0.7', () => {
|
||||
const w = mount(SftCard, { props: { item: MID_QUALITY_ITEM } })
|
||||
expect(w.find('[data-testid="quality-chip"]').classes()).toContain('quality-mid')
|
||||
})
|
||||
|
||||
it('quality chip has acceptable class when score >= 0.7', () => {
|
||||
const w = mount(SftCard, { props: { item: HIGH_QUALITY_ITEM } })
|
||||
expect(w.find('[data-testid="quality-chip"]').classes()).toContain('quality-ok')
|
||||
})
|
||||
|
||||
it('clicking Correct button emits correct', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="correct-btn"]').trigger('click')
|
||||
expect(w.emitted('correct')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('clicking Discard button then confirming emits discard', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||
expect(w.emitted('discard')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('clicking Flag Model button then confirming emits flag', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="flag-btn"]').trigger('click')
|
||||
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||
expect(w.emitted('flag')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('correction area hidden initially', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.find('[data-testid="correction-area"]').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('correction area shown when correcting prop is true', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||
expect(w.find('[data-testid="correction-area"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('renders nothing for failure reason when null', () => {
|
||||
const item = { ...LOW_QUALITY_ITEM, failure_reason: null }
|
||||
const w = mount(SftCard, { props: { item } })
|
||||
expect(w.find('.failure-reason').exists()).toBe(false)
|
||||
})
|
||||
|
||||
// ── Failure category chip-group ───────────────────────────────────
|
||||
it('failure category section hidden when not correcting and no pending action', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('failure category section shown when correcting prop is true', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('renders all six category chips when correcting', () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||
const chips = w.findAll('.category-chip')
|
||||
expect(chips).toHaveLength(6)
|
||||
})
|
||||
|
||||
it('clicking a category chip selects it (adds active class)', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||
const chip = w.find('[data-testid="category-chip-wrong_answer"]')
|
||||
await chip.trigger('click')
|
||||
expect(chip.classes()).toContain('category-chip--active')
|
||||
})
|
||||
|
||||
it('clicking the active chip again deselects it', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||
const chip = w.find('[data-testid="category-chip-hallucination"]')
|
||||
await chip.trigger('click')
|
||||
expect(chip.classes()).toContain('category-chip--active')
|
||||
await chip.trigger('click')
|
||||
expect(chip.classes()).not.toContain('category-chip--active')
|
||||
})
|
||||
|
||||
it('only one chip can be active at a time', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
|
||||
await w.find('[data-testid="category-chip-wrong_answer"]').trigger('click')
|
||||
await w.find('[data-testid="category-chip-hallucination"]').trigger('click')
|
||||
const active = w.findAll('.category-chip--active')
|
||||
expect(active).toHaveLength(1)
|
||||
expect(active[0].attributes('data-testid')).toBe('category-chip-hallucination')
|
||||
})
|
||||
|
||||
it('clicking Discard shows pending action row with category section', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(true)
|
||||
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('clicking Flag shows pending action row', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="flag-btn"]').trigger('click')
|
||||
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('confirming discard emits discard with null when no category selected', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||
expect(w.emitted('discard')).toBeTruthy()
|
||||
expect(w.emitted('discard')![0]).toEqual([null])
|
||||
})
|
||||
|
||||
it('confirming discard emits discard with selected category', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||
await w.find('[data-testid="category-chip-scoring_artifact"]').trigger('click')
|
||||
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
|
||||
expect(w.emitted('discard')![0]).toEqual(['scoring_artifact'])
|
||||
})
|
||||
|
||||
it('cancelling pending action hides the pending row', async () => {
|
||||
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
|
||||
await w.find('[data-testid="discard-btn"]').trigger('click')
|
||||
await w.find('[data-testid="cancel-pending-btn"]').trigger('click')
|
||||
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
|
@ -1,393 +0,0 @@
|
|||
<template>
|
||||
<article class="sft-card">
|
||||
<!-- Chips row -->
|
||||
<div class="chips-row">
|
||||
<span class="chip chip-model">{{ item.model_name }}</span>
|
||||
<span class="chip chip-task">{{ item.task_type }}</span>
|
||||
<span class="chip chip-node">{{ item.node_id }} · GPU {{ item.gpu_id }}</span>
|
||||
<span class="chip chip-speed">{{ item.tokens_per_sec.toFixed(1) }} tok/s</span>
|
||||
<span
|
||||
class="chip quality-chip"
|
||||
:class="qualityClass"
|
||||
data-testid="quality-chip"
|
||||
:title="qualityLabel"
|
||||
>
|
||||
{{ item.quality_score.toFixed(2) }} · {{ qualityLabel }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Failure reason -->
|
||||
<p v-if="item.failure_reason" :id="'sft-failure-' + item.id" class="failure-reason">{{ item.failure_reason }}</p>
|
||||
|
||||
<!-- Prompt (collapsible) -->
|
||||
<div class="prompt-section">
|
||||
<button
|
||||
class="prompt-toggle"
|
||||
:aria-expanded="promptExpanded"
|
||||
@click="promptExpanded = !promptExpanded"
|
||||
>
|
||||
{{ promptExpanded ? 'Hide prompt ↑' : 'Show full prompt ↓' }}
|
||||
</button>
|
||||
<div v-if="promptExpanded" class="prompt-messages">
|
||||
<div
|
||||
v-for="(msg, i) in item.prompt_messages"
|
||||
:key="i"
|
||||
class="prompt-message"
|
||||
:class="`role-${msg.role}`"
|
||||
>
|
||||
<span class="role-label">{{ msg.role }}</span>
|
||||
<pre class="message-content">{{ msg.content }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model response -->
|
||||
<div class="model-response-section">
|
||||
<p class="section-label">Model output (incorrect)</p>
|
||||
<pre class="model-response">{{ item.model_response }}</pre>
|
||||
</div>
|
||||
|
||||
<!-- Action bar -->
|
||||
<div class="action-bar">
|
||||
<button
|
||||
data-testid="correct-btn"
|
||||
class="btn-correct"
|
||||
@click="$emit('correct')"
|
||||
>✓ Correct</button>
|
||||
<button
|
||||
data-testid="discard-btn"
|
||||
class="btn-discard"
|
||||
@click="emitWithCategory('discard')"
|
||||
>✕ Discard</button>
|
||||
<button
|
||||
data-testid="flag-btn"
|
||||
class="btn-flag"
|
||||
@click="emitWithCategory('flag')"
|
||||
>⚑ Flag Model</button>
|
||||
</div>
|
||||
|
||||
<!-- Failure category selector (shown while correcting or while a discard/flag action is pending) -->
|
||||
<div
|
||||
v-if="correcting || pendingAction"
|
||||
class="failure-category-section"
|
||||
data-testid="failure-category-section"
|
||||
>
|
||||
<p class="section-label">Failure category <span class="optional-label">(optional)</span></p>
|
||||
<div class="category-chips" role="group" aria-label="Failure category">
|
||||
<button
|
||||
v-for="cat in FAILURE_CATEGORIES"
|
||||
:key="cat.value"
|
||||
type="button"
|
||||
class="category-chip"
|
||||
:class="{ 'category-chip--active': selectedCategory === cat.value }"
|
||||
:aria-pressed="selectedCategory === cat.value"
|
||||
:data-testid="'category-chip-' + cat.value"
|
||||
@click="toggleCategory(cat.value)"
|
||||
>{{ cat.label }}</button>
|
||||
</div>
|
||||
|
||||
<!-- Pending discard/flag confirm row -->
|
||||
<div v-if="pendingAction" class="pending-action-row" data-testid="pending-action-row">
|
||||
<button class="btn-confirm" @click="confirmPendingAction" data-testid="confirm-pending-btn">
|
||||
Confirm {{ pendingAction }}
|
||||
</button>
|
||||
<button class="btn-cancel-pending" @click="cancelPendingAction" data-testid="cancel-pending-btn">
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Correction area (shown when correcting = true) -->
|
||||
<div v-if="correcting" data-testid="correction-area">
|
||||
<SftCorrectionArea
|
||||
ref="correctionAreaEl"
|
||||
:described-by="item.failure_reason ? 'sft-failure-' + item.id : undefined"
|
||||
@submit="handleSubmitCorrection"
|
||||
@cancel="$emit('cancel-correction')"
|
||||
/>
|
||||
</div>
|
||||
</article>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import type { SftQueueItem, SftFailureCategory } from '../stores/sft'
|
||||
import SftCorrectionArea from './SftCorrectionArea.vue'
|
||||
|
||||
const props = defineProps<{ item: SftQueueItem; correcting?: boolean }>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
correct: []
|
||||
discard: [category: SftFailureCategory | null]
|
||||
flag: [category: SftFailureCategory | null]
|
||||
'submit-correction': [text: string, category: SftFailureCategory | null]
|
||||
'cancel-correction': []
|
||||
}>()
|
||||
|
||||
const FAILURE_CATEGORIES: { value: SftFailureCategory; label: string }[] = [
|
||||
{ value: 'scoring_artifact', label: 'Scoring artifact' },
|
||||
{ value: 'style_violation', label: 'Style violation' },
|
||||
{ value: 'partial_answer', label: 'Partial answer' },
|
||||
{ value: 'wrong_answer', label: 'Wrong answer' },
|
||||
{ value: 'format_error', label: 'Format error' },
|
||||
{ value: 'hallucination', label: 'Hallucination' },
|
||||
]
|
||||
|
||||
const promptExpanded = ref(false)
|
||||
const correctionAreaEl = ref<InstanceType<typeof SftCorrectionArea> | null>(null)
|
||||
const selectedCategory = ref<SftFailureCategory | null>(null)
|
||||
const pendingAction = ref<'discard' | 'flag' | null>(null)
|
||||
|
||||
const qualityClass = computed(() => {
|
||||
const s = props.item.quality_score
|
||||
if (s < 0.4) return 'quality-low'
|
||||
if (s < 0.7) return 'quality-mid'
|
||||
return 'quality-ok'
|
||||
})
|
||||
|
||||
const qualityLabel = computed(() => {
|
||||
const s = props.item.quality_score
|
||||
if (s < 0.4) return 'low quality'
|
||||
if (s < 0.7) return 'fair'
|
||||
return 'acceptable'
|
||||
})
|
||||
|
||||
function toggleCategory(cat: SftFailureCategory) {
|
||||
selectedCategory.value = selectedCategory.value === cat ? null : cat
|
||||
}
|
||||
|
||||
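// Despite the name, this only stages the action; the emit happens in confirmPendingAction().
|
||||
function emitWithCategory(action: 'discard' | 'flag') {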
|
||||
pendingAction.value = action
|
||||
}
|
||||
|
||||
function confirmPendingAction() {
|
||||
if (!pendingAction.value) return
|
||||
emit(pendingAction.value, selectedCategory.value)
|
||||
pendingAction.value = null
|
||||
selectedCategory.value = null
|
||||
}
|
||||
|
||||
function cancelPendingAction() {
|
||||
pendingAction.value = null
|
||||
}
|
||||
|
||||
function handleSubmitCorrection(text: string) {
|
||||
emit('submit-correction', text, selectedCategory.value)
|
||||
selectedCategory.value = null
|
||||
}
|
||||
|
||||
function resetCorrection() {
|
||||
correctionAreaEl.value?.reset()
|
||||
selectedCategory.value = null
|
||||
pendingAction.value = null
|
||||
}
|
||||
|
||||
defineExpose({ resetCorrection })
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.sft-card {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--radius-lg);
|
||||
padding: var(--space-4);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-3);
|
||||
}
|
||||
|
||||
.chips-row {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.chip {
|
||||
padding: var(--space-1) var(--space-2);
|
||||
border-radius: var(--radius-full);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 600;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.chip-model { background: var(--color-primary-light, #e8f2e7); color: var(--color-primary); }
|
||||
.chip-task { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||
.chip-node { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||
.chip-speed { background: var(--color-surface-alt); color: var(--color-text-muted); }
|
||||
|
||||
.quality-chip { color: #fff; }
|
||||
.quality-low { background: var(--color-error, #c0392b); }
|
||||
.quality-mid { background: var(--color-warning, #d4891a); }
|
||||
.quality-ok { background: var(--color-success, #3a7a32); }
|
||||
|
||||
.failure-reason {
|
||||
font-size: 0.82rem;
|
||||
color: var(--color-text-muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.prompt-toggle {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-accent);
|
||||
font-size: 0.85rem;
|
||||
cursor: pointer;
|
||||
padding: 0;
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.prompt-messages {
|
||||
margin-top: var(--space-2);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.prompt-message {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-1);
|
||||
}
|
||||
|
||||
.role-label {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
|
||||
.message-content {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 0.82rem;
|
||||
white-space: pre-wrap;
|
||||
background: var(--color-surface-alt);
|
||||
padding: var(--space-2) var(--space-3);
|
||||
border-radius: var(--radius-md);
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
font-size: 0.82rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted);
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.model-response {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 0.88rem;
|
||||
white-space: pre-wrap;
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 8%, var(--color-surface-alt));
|
||||
border-left: 3px solid var(--color-error, #c0392b);
|
||||
padding: var(--space-3);
|
||||
border-radius: var(--radius-md);
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.action-bar {
|
||||
display: flex;
|
||||
gap: var(--space-3);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.action-bar button {
|
||||
padding: var(--space-2) var(--space-4);
|
||||
border-radius: var(--radius-md);
|
||||
border: 1px solid var(--color-border);
|
||||
font-size: 0.9rem;
|
||||
cursor: pointer;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
.btn-correct { border-color: var(--color-success); color: var(--color-success); }
|
||||
.btn-correct:hover { background: color-mix(in srgb, var(--color-success) 10%, transparent); }
|
||||
|
||||
.btn-discard { border-color: var(--color-error); color: var(--color-error); }
|
||||
.btn-discard:hover { background: color-mix(in srgb, var(--color-error) 10%, transparent); }
|
||||
|
||||
.btn-flag { border-color: var(--color-warning); color: var(--color-warning); }
|
||||
.btn-flag:hover { background: color-mix(in srgb, var(--color-warning) 10%, transparent); }
|
||||
|
||||
/* ── Failure category selector ─────────────────── */
|
||||
.failure-category-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.optional-label {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 400;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
|
||||
.category-chips {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.category-chip {
|
||||
padding: var(--space-1) var(--space-3);
|
||||
border-radius: var(--radius-full);
|
||||
border: 1px solid var(--color-border);
|
||||
background: var(--color-surface-alt);
|
||||
color: var(--color-text-muted);
|
||||
font-size: 0.78rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background var(--transition), color var(--transition), border-color var(--transition);
|
||||
}
|
||||
|
||||
.category-chip:hover {
|
||||
border-color: var(--color-accent);
|
||||
color: var(--color-accent);
|
||||
background: var(--color-accent-light);
|
||||
}
|
||||
|
||||
.category-chip--active {
|
||||
background: var(--color-accent-light);
|
||||
border-color: var(--color-accent);
|
||||
color: var(--color-accent);
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.pending-action-row {
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
margin-top: var(--space-1);
|
||||
}
|
||||
|
||||
.btn-confirm {
|
||||
padding: var(--space-1) var(--space-3);
|
||||
border-radius: var(--radius-md);
|
||||
border: 1px solid var(--color-accent);
|
||||
background: var(--color-accent-light);
|
||||
color: var(--color-accent);
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.btn-confirm:hover {
|
||||
background: color-mix(in srgb, var(--color-accent) 15%, transparent);
|
||||
}
|
||||
|
||||
.btn-cancel-pending {
|
||||
padding: var(--space-1) var(--space-3);
|
||||
border-radius: var(--radius-md);
|
||||
border: 1px solid var(--color-border);
|
||||
background: none;
|
||||
color: var(--color-text-muted);
|
||||
font-size: 0.85rem;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.btn-cancel-pending:hover {
|
||||
background: var(--color-surface-alt);
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
import { mount } from '@vue/test-utils'
|
||||
import SftCorrectionArea from './SftCorrectionArea.vue'
|
||||
import { describe, it, expect } from 'vitest'
|
||||
|
||||
describe('SftCorrectionArea', () => {
|
||||
it('renders a textarea', () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
expect(w.find('textarea').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('submit button is disabled when textarea is empty', () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
const btn = w.find('[data-testid="submit-btn"]')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('submit button is disabled when textarea is whitespace only', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('textarea').setValue(' ')
|
||||
const btn = w.find('[data-testid="submit-btn"]')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('submit button is enabled when textarea has content', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('textarea').setValue('def add(a, b): return a + b')
|
||||
const btn = w.find('[data-testid="submit-btn"]')
|
||||
expect((btn.element as HTMLButtonElement).disabled).toBe(false)
|
||||
})
|
||||
|
||||
it('clicking submit emits submit with trimmed text', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('textarea').setValue(' def add(a, b): return a + b ')
|
||||
await w.find('[data-testid="submit-btn"]').trigger('click')
|
||||
expect(w.emitted('submit')?.[0]).toEqual(['def add(a, b): return a + b'])
|
||||
})
|
||||
|
||||
it('clicking cancel emits cancel', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('[data-testid="cancel-btn"]').trigger('click')
|
||||
expect(w.emitted('cancel')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('Escape key emits cancel', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('textarea').trigger('keydown', { key: 'Escape' })
|
||||
expect(w.emitted('cancel')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('Ctrl+Enter emits submit when text is non-empty', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('textarea').setValue('correct answer')
|
||||
await w.find('textarea').trigger('keydown', { key: 'Enter', ctrlKey: true })
|
||||
expect(w.emitted('submit')?.[0]).toEqual(['correct answer'])
|
||||
})
|
||||
|
||||
it('Ctrl+Enter does not emit submit when text is empty', async () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
await w.find('textarea').trigger('keydown', { key: 'Enter', ctrlKey: true })
|
||||
expect(w.emitted('submit')).toBeFalsy()
|
||||
})
|
||||
|
||||
it('omits aria-describedby when describedBy prop is not provided', () => {
|
||||
const w = mount(SftCorrectionArea)
|
||||
const textarea = w.find('textarea')
|
||||
expect(textarea.attributes('aria-describedby')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
|
@ -1,130 +0,0 @@
|
|||
<template>
|
||||
<div class="correction-area">
|
||||
<label class="correction-label" for="correction-textarea">
|
||||
Write the corrected response:
|
||||
</label>
|
||||
<textarea
|
||||
id="correction-textarea"
|
||||
ref="textareaEl"
|
||||
v-model="text"
|
||||
class="correction-textarea"
|
||||
aria-label="Write corrected response"
|
||||
aria-required="true"
|
||||
:aria-describedby="describedBy || undefined"
|
||||
placeholder="Write the response this model should have given..."
|
||||
rows="4"
|
||||
@keydown.escape="$emit('cancel')"
|
||||
@keydown.enter.ctrl.prevent="submitIfValid"
|
||||
@keydown.enter.meta.prevent="submitIfValid"
|
||||
/>
|
||||
<div class="correction-actions">
|
||||
<button
|
||||
data-testid="submit-btn"
|
||||
class="btn-submit"
|
||||
:disabled="!isValid"
|
||||
@click="submitIfValid"
|
||||
>
|
||||
Submit correction
|
||||
</button>
|
||||
<button data-testid="cancel-btn" class="btn-cancel" @click="$emit('cancel')">
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
const props = withDefaults(defineProps<{ describedBy?: string }>(), { describedBy: undefined })
|
||||
|
||||
const emit = defineEmits<{ submit: [text: string]; cancel: [] }>()
|
||||
|
||||
const text = ref('')
|
||||
const textareaEl = ref<HTMLTextAreaElement | null>(null)
|
||||
const isValid = computed(() => text.value.trim().length > 0)
|
||||
|
||||
onMounted(() => textareaEl.value?.focus())
|
||||
|
||||
function submitIfValid() {
|
||||
if (isValid.value) emit('submit', text.value.trim())
|
||||
}
|
||||
|
||||
function reset() {
|
||||
text.value = ''
|
||||
}
|
||||
|
||||
defineExpose({ reset })
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.correction-area {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-3);
|
||||
padding: var(--space-4);
|
||||
border-top: 1px solid var(--color-border);
|
||||
background: var(--color-surface-alt, var(--color-surface));
|
||||
border-radius: 0 0 var(--radius-lg) var(--radius-lg);
|
||||
}
|
||||
|
||||
.correction-label {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
|
||||
.correction-textarea {
|
||||
width: 100%;
|
||||
min-height: 7rem;
|
||||
padding: var(--space-3);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--radius-md);
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
font-family: var(--font-mono);
|
||||
font-size: 0.88rem;
|
||||
line-height: 1.5;
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
.correction-textarea:focus {
|
||||
outline: 2px solid var(--color-primary);
|
||||
outline-offset: 1px;
|
||||
}
|
||||
|
||||
.correction-actions {
|
||||
display: flex;
|
||||
gap: var(--space-3);
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.btn-submit {
|
||||
padding: var(--space-2) var(--space-4);
|
||||
background: var(--color-primary);
|
||||
color: var(--color-text-inverse, #fff);
|
||||
border: none;
|
||||
border-radius: var(--radius-md);
|
||||
font-size: 0.9rem;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.btn-submit:disabled {
|
||||
opacity: 0.45;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-submit:not(:disabled):hover {
|
||||
background: var(--color-primary-hover, var(--color-primary));
|
||||
}
|
||||
|
||||
.btn-cancel {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-text-muted);
|
||||
font-size: 0.9rem;
|
||||
cursor: pointer;
|
||||
text-decoration: underline;
|
||||
padding: 0;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import type { CatalogEntryFull } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
svcName: string
|
||||
modelName?: string
|
||||
entry?: CatalogEntryFull
|
||||
}>()
|
||||
const emit = defineEmits<{
|
||||
save: [svcName: string, modelName: string, entry: CatalogEntryFull]
|
||||
cancel: []
|
||||
}>()
|
||||
|
||||
const name = ref(props.modelName ?? '')
|
||||
const path = ref(props.entry?.path ?? '')
|
||||
const vramMb = ref(props.entry?.vram_mb ?? 0)
|
||||
const description = ref(props.entry?.description ?? '')
|
||||
const multiGpu = ref(props.entry?.multi_gpu ?? false)
|
||||
const envPairs = ref<{ k: string; v: string }[]>(
|
||||
Object.entries(props.entry?.env ?? {}).map(([k, v]) => ({ k, v }))
|
||||
)
|
||||
const formError = ref('')
|
||||
|
||||
watch([() => props.entry, () => props.modelName], ([e]) => {
|
||||
name.value = props.modelName ?? ''
|
||||
path.value = e?.path ?? ''
|
||||
vramMb.value = e?.vram_mb ?? 0
|
||||
description.value = e?.description ?? ''
|
||||
multiGpu.value = e?.multi_gpu ?? false
|
||||
envPairs.value = Object.entries(e?.env ?? {}).map(([k, v]) => ({ k, v }))
|
||||
})
|
||||
|
||||
function addEnvPair() {
|
||||
envPairs.value = [...envPairs.value, { k: '', v: '' }]
|
||||
}
|
||||
function removeEnvPair(i: number) {
|
||||
envPairs.value = envPairs.value.filter((_, idx) => idx !== i)
|
||||
}
|
||||
|
||||
function submit() {
|
||||
formError.value = ''
|
||||
if (!name.value.trim()) { formError.value = 'Model name is required.'; return }
|
||||
if (!path.value.trim()) { formError.value = 'Path is required.'; return }
|
||||
if (!vramMb.value || vramMb.value < 0) { formError.value = 'vram_mb must be a positive number.'; return }
|
||||
|
||||
const envObj: Record<string, string> = {}
|
||||
for (const { k, v } of envPairs.value) {
|
||||
if (k.trim()) envObj[k.trim()] = v
|
||||
}
|
||||
|
||||
const entry: CatalogEntryFull = { path: path.value.trim(), vram_mb: vramMb.value }
|
||||
if (description.value.trim()) entry.description = description.value.trim()
|
||||
if (multiGpu.value) entry.multi_gpu = true
|
||||
if (Object.keys(envObj).length) entry.env = envObj
|
||||
|
||||
emit('save', props.svcName, name.value.trim(), entry)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${modelName ? 'Edit' : 'Add'} catalog entry`">
|
||||
<div class="modal-box">
|
||||
<h3 class="modal-title">{{ modelName ? 'Edit' : 'Add' }} Catalog Entry — {{ svcName }}</h3>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-name">Model name</label>
|
||||
<input id="ce-name" v-model="name" class="field-input" :readonly="!!modelName" placeholder="deepseek-r1-7b" />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-path">Path</label>
|
||||
<input id="ce-path" v-model="path" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models/..." />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-vram">VRAM (MB)</label>
|
||||
<input id="ce-vram" v-model.number="vramMb" type="number" min="0" class="field-input field-input--sm" />
|
||||
</div>
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="ce-desc">Description</label>
|
||||
<input id="ce-desc" v-model="description" class="field-input" placeholder="Short description" />
|
||||
</div>
|
||||
<div class="field-row field-row--check">
|
||||
<input id="ce-mgpu" v-model="multiGpu" type="checkbox" />
|
||||
<label for="ce-mgpu">Multi-GPU span</label>
|
||||
</div>
|
||||
|
||||
<div class="env-section">
|
||||
<div class="env-header">
|
||||
<span class="field-label">Env vars</span>
|
||||
<button type="button" class="btn-link" @click="addEnvPair">+ Add</button>
|
||||
</div>
|
||||
<div v-for="(pair, i) in envPairs" :key="i" class="env-row">
|
||||
<input v-model="pair.k" class="field-input field-input--sm" placeholder="CF_TEXT_4BIT" />
|
||||
<span>=</span>
|
||||
<input v-model="pair.v" class="field-input field-input--sm" placeholder="1" />
|
||||
<button type="button" class="btn-icon" @click="removeEnvPair(i)" aria-label="Remove">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
|
||||
<button class="btn-primary" @click="submit">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.modal-backdrop {
|
||||
position: fixed; inset: 0;
|
||||
background: rgba(0,0,0,0.5);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
z-index: 200;
|
||||
}
|
||||
.modal-box {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
width: 100%; max-width: 500px;
|
||||
max-height: 90vh; overflow-y: auto;
|
||||
display: flex; flex-direction: column; gap: 0.75rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.field-row { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.field-row--check { gap: 0.4rem; color: var(--color-text); }
|
||||
.field-label { min-width: 8rem; font-size: 0.85rem; color: var(--color-text-muted); }
|
||||
.field-input {
|
||||
flex: 1;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.field-input--sm { flex: 0 0 8rem; }
|
||||
.env-section { display: flex; flex-direction: column; gap: 0.35rem; }
|
||||
.env-header { display: flex; align-items: center; justify-content: space-between; }
|
||||
.env-row { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.btn-link { background: none; border: none; color: var(--app-primary); cursor: pointer; font-size: 0.8rem; padding: 0; }
|
||||
.btn-link:hover { color: var(--app-primary-hover); }
|
||||
.btn-icon { background: none; border: none; color: var(--color-text-muted); cursor: pointer; padding: 0 0.2rem; font-size: 0.85rem; }
|
||||
.btn-icon:hover { color: var(--color-error); }
|
||||
.form-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
|
||||
.btn-primary {
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 1rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--app-primary-hover); }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.75rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
</style>
|
||||
|
|
@@ -1,129 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import ServiceBadge from './ServiceBadge.vue'
|
||||
import type { GpuEntry, ServiceInfo } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
gpu: GpuEntry
|
||||
nodeId: string
|
||||
profileLoaded: boolean
|
||||
servicesCatalog: Record<string, ServiceInfo>
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const saving = ref(false)
|
||||
const saveError = ref('')
|
||||
|
||||
const vramPct = computed(() => {
|
||||
if (!props.gpu.vram_total_mb) return 0
|
||||
return Math.round((props.gpu.vram_used_mb / props.gpu.vram_total_mb) * 100)
|
||||
})
|
||||
|
||||
function serviceState(svcName: string): 'running' | 'stopped' | 'assigned-only' | 'available' | 'incompatible' | 'unknown' {
|
||||
const svc = props.servicesCatalog[svcName]
|
||||
if (!svc) return 'unknown'
|
||||
const cap = props.gpu.compute_cap ?? 0
|
||||
if (cap < svc.min_compute_cap) return 'incompatible'
|
||||
if (props.gpu.services_running.includes(svcName)) return 'running'
|
||||
if (props.gpu.services_assigned.includes(svcName)) return 'assigned-only'
|
||||
return 'available'
|
||||
}
|
||||
|
||||
async function toggleService(svcName: string) {
|
||||
if (!props.profileLoaded || saving.value) return
|
||||
const current = [...props.gpu.services_assigned]
|
||||
const removing = current.includes(svcName)
|
||||
if (removing && !confirm(`Remove ${svcName} from GPU ${props.gpu.gpu_id}?`)) return
|
||||
const next = removing ? current.filter(s => s !== svcName) : [...current, svcName]
|
||||
|
||||
saving.value = true
|
||||
saveError.value = ''
|
||||
try {
|
||||
const r = await fetch(
|
||||
`/api/nodes-mgmt/nodes/${props.nodeId}/gpu/${props.gpu.gpu_id}/services`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ services: next }),
|
||||
},
|
||||
)
|
||||
if (!r.ok) {
|
||||
const data = await r.json().catch(() => ({}))
|
||||
throw new Error((data as { detail?: string }).detail ?? `HTTP ${r.status}`)
|
||||
}
|
||||
const data = await r.json() as { ok: boolean; reloaded: boolean; warnings: string[] }
|
||||
if (data.warnings?.length) saveError.value = `Saved (warning: ${data.warnings.join(', ')})`
|
||||
emit('updated')
|
||||
} catch (e) {
|
||||
saveError.value = e instanceof Error ? e.message : 'Failed to update services'
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="gpu-row">
|
||||
<div class="gpu-info">
|
||||
<span class="gpu-label">GPU {{ gpu.gpu_id }}: {{ gpu.card }}</span>
|
||||
<span v-if="gpu.compute_cap != null" class="gpu-meta">sm{{ gpu.compute_cap }}</span>
|
||||
<span v-if="gpu.temp_c != null" class="gpu-meta">{{ gpu.temp_c }}°C</span>
|
||||
<span v-if="gpu.utilization_pct != null" class="gpu-meta">{{ gpu.utilization_pct }}%</span>
|
||||
</div>
|
||||
|
||||
<div class="vram-wrap">
|
||||
<div
|
||||
class="vram-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="gpu.vram_used_mb"
|
||||
aria-valuemin="0"
|
||||
:aria-valuemax="gpu.vram_total_mb"
|
||||
:aria-label="`VRAM: ${gpu.vram_used_mb} of ${gpu.vram_total_mb} MB used`"
|
||||
>
|
||||
<div class="vram-fill" :style="{ width: `${vramPct}%` }" />
|
||||
</div>
|
||||
<span class="vram-text">{{ gpu.vram_used_mb }} / {{ gpu.vram_total_mb }} MB ({{ vramPct }}%)</span>
|
||||
</div>
|
||||
|
||||
<div v-if="profileLoaded" class="services-row" aria-label="Service assignments">
|
||||
<ServiceBadge
|
||||
v-for="(_, svcName) in servicesCatalog"
|
||||
:key="String(svcName)"
|
||||
:service-name="String(svcName)"
|
||||
:state="serviceState(String(svcName))"
|
||||
:assigned="gpu.services_assigned.includes(String(svcName))"
|
||||
:disabled="saving"
|
||||
@toggle="toggleService(String(svcName))"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="saveError" class="save-msg" role="alert">{{ saveError }}</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.gpu-row {
|
||||
padding: 0.5rem 0.75rem;
|
||||
border-radius: 4px;
|
||||
background: var(--color-surface-alt);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
.gpu-info { display: flex; gap: 0.75rem; align-items: center; flex-wrap: wrap; font-size: 0.875rem; }
|
||||
.gpu-label { font-weight: 500; color: var(--color-text); }
|
||||
.gpu-meta { color: var(--color-text-muted); font-size: 0.8rem; }
|
||||
.vram-wrap { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.vram-bar {
|
||||
flex: 1;
|
||||
height: 8px;
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.vram-fill { height: 100%; background: var(--app-primary); transition: width 0.3s; }
|
||||
.vram-text { font-size: 0.75rem; color: var(--color-text-muted); white-space: nowrap; }
|
||||
.services-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
|
||||
.save-msg { color: var(--color-warning); font-size: 0.8rem; }
|
||||
</style>
|
||||
|
|
@@ -1,134 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
interface CatalogEntry {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description: string
|
||||
multi_gpu: boolean
|
||||
}
|
||||
|
||||
interface ServiceProfile {
|
||||
catalog: Record<string, CatalogEntry>
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
}
|
||||
|
||||
interface NodeProfile {
|
||||
services: Record<string, ServiceProfile>
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
}>()
|
||||
|
||||
const profile = ref<NodeProfile | null>(null)
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
async function fetchProfile() {
|
||||
fetchAbort?.abort()
|
||||
fetchAbort = new AbortController()
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
|
||||
signal: fetchAbort.signal,
|
||||
})
|
||||
if (r.status === 404) { profile.value = null; return }
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
profile.value = await r.json() as NodeProfile
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
error.value = e instanceof Error ? e.message : 'Failed to load profile'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(fetchProfile)
|
||||
onUnmounted(() => { fetchAbort?.abort() })
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="hf-panel">
|
||||
<h3 class="panel-title">Model Catalog</h3>
|
||||
<p class="hf-hint">
|
||||
To download a new HuggingFace model,
|
||||
<a href="#/fleet" class="hf-link">go to Fleet</a>.
|
||||
Models downloaded there are automatically registered in node catalogs.
|
||||
</p>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading catalog...</span>
|
||||
</div>
|
||||
<div v-if="error" class="panel-error" role="alert">{{ error }}</div>
|
||||
<div v-else-if="!loading && !profile" class="panel-empty">No profile loaded for this node.</div>
|
||||
<div v-else-if="!loading && profile" class="catalog-body">
|
||||
<div
|
||||
v-for="(svcInfo, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-section"
|
||||
>
|
||||
<h4 class="svc-name">{{ svcName }}</h4>
|
||||
<ul class="catalog-list" role="list">
|
||||
<li
|
||||
v-if="!Object.keys(svcInfo.catalog ?? {}).length"
|
||||
class="catalog-empty"
|
||||
>
|
||||
No models in catalog.
|
||||
</li>
|
||||
<li
|
||||
v-for="(entry, modelName) in (svcInfo.catalog ?? {})"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.hf-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 6px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.panel-title { margin: 0 0 0.5rem; font-size: 0.9rem; color: var(--color-text); }
|
||||
.hf-hint { font-size: 0.8rem; color: var(--color-text-muted); margin: 0 0 0.75rem; }
|
||||
.hf-link { color: var(--app-primary); }
|
||||
.hf-link:hover { color: var(--app-primary-hover); }
|
||||
.svc-section { margin-bottom: 0.75rem; }
|
||||
.svc-name {
|
||||
margin: 0 0 0.25rem;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.2rem; }
|
||||
.catalog-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.catalog-model { font-family: var(--font-mono, monospace); flex: 1; }
|
||||
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--color-text-muted); font-size: 0.75rem; flex: 2; }
|
||||
.catalog-empty, .panel-empty { color: var(--color-text-muted); font-size: 0.875rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
</style>
|
||||
|
|
@@ -1,148 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import GpuRow from './GpuRow.vue'
|
||||
import OllamaModelPanel from './OllamaModelPanel.vue'
|
||||
import ProfileEditorPanel from './ProfileEditorPanel.vue'
|
||||
import type { NodeSummary, FullProfile } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{ node: NodeSummary }>()
|
||||
const emit = defineEmits<{ updated: [] }>()
|
||||
|
||||
const showOllama = ref(false)
|
||||
const showEditor = ref(false)
|
||||
const loadedProfile = ref<FullProfile | null>(null)
|
||||
const profileLoading = ref(false)
|
||||
const profileError = ref('')
|
||||
|
||||
async function openEditor() {
|
||||
if (showEditor.value) { showEditor.value = false; return }
|
||||
profileLoading.value = true
|
||||
profileError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.node.node_id}/profile`)
|
||||
if (r.status === 404) {
|
||||
loadedProfile.value = null
|
||||
} else if (!r.ok) {
|
||||
throw new Error(`HTTP ${r.status}`)
|
||||
} else {
|
||||
loadedProfile.value = await r.json() as FullProfile
|
||||
}
|
||||
showEditor.value = true
|
||||
} catch (e) {
|
||||
profileError.value = e instanceof Error ? e.message : 'Failed to load profile'
|
||||
} finally {
|
||||
profileLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function onProfileSaved() {
|
||||
showEditor.value = false
|
||||
emit('updated')
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="node-card" :class="{ offline: !node.online }">
|
||||
<header class="node-card-header">
|
||||
<div class="node-identity">
|
||||
<span
|
||||
class="status-dot"
|
||||
:class="node.online ? 'online' : 'offline'"
|
||||
:aria-label="node.online ? 'Online' : 'Offline'"
|
||||
role="img"
|
||||
/>
|
||||
<h2 class="node-name">{{ node.node_id }}</h2>
|
||||
<span class="node-agent">{{ node.agent_url }}</span>
|
||||
</div>
|
||||
<div class="node-actions">
|
||||
<button
|
||||
v-if="node.profile_loaded"
|
||||
class="btn-secondary btn-sm"
|
||||
@click="showOllama = !showOllama"
|
||||
>
|
||||
{{ showOllama ? 'Hide Ollama' : 'Ollama' }}
|
||||
</button>
|
||||
<button
|
||||
class="btn-secondary btn-sm"
|
||||
:disabled="profileLoading"
|
||||
@click="openEditor"
|
||||
>
|
||||
{{ profileLoading ? 'Loading…' : node.profile_loaded ? (showEditor ? 'Close Editor' : 'Edit Profile') : 'Create Profile' }}
|
||||
</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div v-if="!node.profile_loaded" class="no-profile" role="status">
|
||||
No profile configured for this node. GPU stats are visible; service assignment is disabled.
|
||||
</div>
|
||||
|
||||
<div class="gpu-list">
|
||||
<GpuRow
|
||||
v-for="gpu in node.gpus"
|
||||
:key="gpu.gpu_id"
|
||||
:gpu="gpu"
|
||||
:node-id="node.node_id"
|
||||
:profile-loaded="node.profile_loaded"
|
||||
:services-catalog="node.services_catalog"
|
||||
@updated="emit('updated')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<OllamaModelPanel v-if="showOllama" :node-id="node.node_id" />
|
||||
<div v-if="profileError" class="profile-load-error" role="alert">{{ profileError }}</div>
|
||||
<ProfileEditorPanel
|
||||
v-if="showEditor"
|
||||
:node-id="node.node_id"
|
||||
:initial-profile="loadedProfile"
|
||||
@saved="onProfileSaved"
|
||||
@close="showEditor = false"
|
||||
/>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.node-card {
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
.node-card.offline { opacity: 0.65; }
|
||||
.node-card-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.node-identity { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; }
|
||||
.node-name { margin: 0; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.node-agent { color: var(--color-text-muted); font-size: 0.8rem; font-family: var(--font-mono, monospace); }
|
||||
.status-dot { width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0; }
|
||||
.status-dot.online { background: var(--color-success); }
|
||||
.status-dot.offline { background: var(--color-warning); }
|
||||
.node-actions { display: flex; gap: 0.5rem; flex-shrink: 0; }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.65rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-sm { font-size: 0.8rem; padding: 0.25rem 0.6rem; }
|
||||
.no-profile {
|
||||
padding: 0.6rem 0.75rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text-muted);
|
||||
font-size: 0.875rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.gpu-list { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
.profile-load-error { color: var(--color-error); font-size: 0.8rem; margin-top: 0.5rem; }
|
||||
</style>
|
||||
|
|
@@ -1,242 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const props = defineProps<{ nodeId: string }>()
|
||||
|
||||
interface OllamaModel {
|
||||
name: string
|
||||
size: number
|
||||
modified_at: string
|
||||
}
|
||||
|
||||
const models = ref<OllamaModel[]>([])
|
||||
const loading = ref(true)
|
||||
const loadError = ref('')
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
|
||||
// AbortController for the SSE pull stream
|
||||
const abortCtrl = ref<AbortController | null>(null)
|
||||
|
||||
// AbortController for the one-shot fetchModels request
|
||||
let fetchAbort: AbortController | null = null
|
||||
|
||||
async function fetchModels() {
|
||||
fetchAbort?.abort()
|
||||
fetchAbort = new AbortController()
|
||||
loading.value = true
|
||||
loadError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`, {
|
||||
signal: fetchAbort.signal,
|
||||
})
|
||||
const data = await r.json() as { models?: OllamaModel[]; error?: string }
|
||||
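if (data.error) { loadError.value = data.error; return }
// Surface non-2xx responses that carry no `error` field instead of showing an empty list
if (!r.ok) throw new Error(`HTTP ${r.status}`)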
|
||||
models.value = data.models ?? []
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
loadError.value = e instanceof Error ? e.message : 'Failed to load models'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function doPull() {
|
||||
const name = pullName.value.trim()
|
||||
if (!name || pulling.value) return
|
||||
pulling.value = true
|
||||
pullStatus.value = 'Starting...'
|
||||
pullError.value = ''
|
||||
pullPct.value = 0
|
||||
|
||||
const ctrl = new AbortController()
|
||||
abortCtrl.value = ctrl
|
||||
|
||||
try {
|
||||
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name }),
|
||||
signal: ctrl.signal,
|
||||
})
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
|
||||
if (!resp.body) throw new Error('No response body')
|
||||
|
||||
const reader = resp.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buf = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += decoder.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue
|
||||
try {
|
||||
const evt = JSON.parse(line.slice(6)) as {
|
||||
status?: string; error?: string; total?: number; completed?: number
|
||||
}
|
||||
if (evt.error) {
|
||||
pullError.value = evt.error
|
||||
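// A bare `break` would only leave the inner `for` loop; stop the whole pull instead.
reader.cancel().catch(() => {})
return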
|
||||
}
|
||||
if (evt.status) pullStatus.value = evt.status
|
||||
if (evt.total && evt.completed) {
|
||||
pullPct.value = Math.round((evt.completed / evt.total) * 100)
|
||||
}
|
||||
if (evt.status === 'success') {
|
||||
pullStatus.value = 'Done!'
|
||||
pullName.value = ''
|
||||
break
|
||||
}
|
||||
} catch { /* skip malformed line */ }
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh model list after the stream closes (success or benign end)
|
||||
await fetchModels()
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name === 'AbortError') return
|
||||
pullError.value = e instanceof Error ? e.message : 'Pull failed'
|
||||
} finally {
|
||||
pulling.value = false
|
||||
abortCtrl.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteModel(name: string) {
|
||||
if (!confirm(`Delete model "${name}" from node ${props.nodeId}?`)) return
|
||||
try {
|
||||
const r = await fetch(
|
||||
`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/${encodeURIComponent(name)}`,
|
||||
{ method: 'DELETE' },
|
||||
)
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
await fetchModels()
|
||||
} catch (e) {
|
||||
loadError.value = e instanceof Error ? e.message : 'Delete failed'
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes: number): string {
|
||||
return (bytes / 1e9).toFixed(1) + ' GB'
|
||||
}
|
||||
|
||||
onMounted(fetchModels)
|
||||
onUnmounted(() => {
|
||||
abortCtrl.value?.abort()
|
||||
fetchAbort?.abort()
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="ollama-panel">
|
||||
<h3 class="panel-title">Ollama Models</h3>
|
||||
|
||||
<form class="pull-form" @submit.prevent="doPull">
|
||||
<input
|
||||
v-model="pullName"
|
||||
type="text"
|
||||
placeholder="nomic-embed-text, llama3.2:3b, ..."
|
||||
:disabled="pulling"
|
||||
aria-label="Model name to pull from Ollama"
|
||||
class="pull-input"
|
||||
/>
|
||||
<button type="submit" :disabled="pulling || !pullName.trim()" class="btn-primary btn-sm">
|
||||
{{ pulling ? 'Pulling...' : 'Pull' }}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div v-if="pulling || pullStatus" class="pull-progress" aria-live="polite">
|
||||
<div
|
||||
class="progress-bar"
|
||||
role="progressbar"
|
||||
:aria-valuenow="pullPct"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
:aria-label="`Pull progress: ${pullStatus}`"
|
||||
>
|
||||
<div class="progress-fill" :style="{ width: `${pullPct}%` }" />
|
||||
</div>
|
||||
<span class="progress-label">{{ pullStatus }}{{ pullPct > 0 ? ` (${pullPct}%)` : '' }}</span>
|
||||
</div>
|
||||
|
||||
<div v-if="pullError" class="pull-error" role="alert">
|
||||
{{ pullError }}
|
||||
<span v-if="pullError.includes('permission denied')">
|
||||
— Remove the partial file on the node and retry.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div aria-live="polite" aria-atomic="true" class="sr-announce">
|
||||
<span v-if="loading">Loading...</span>
|
||||
</div>
|
||||
<div v-if="loadError" class="panel-error" role="alert">{{ loadError }}</div>
|
||||
<ul v-if="!loading && !loadError" class="model-list" role="list">
|
||||
<li v-if="!models.length" class="model-empty">No Ollama models installed on this node.</li>
|
||||
<li v-for="m in models" :key="m.name" class="model-item">
|
||||
<span class="model-name">{{ m.name }}</span>
|
||||
<span class="model-size">{{ formatSize(m.size) }}</span>
|
||||
<button
|
||||
class="btn-danger btn-xs"
|
||||
@click="deleteModel(m.name)"
|
||||
:aria-label="`Delete ${m.name}`"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.ollama-panel {
|
||||
margin-top: 0.75rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 6px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.panel-title { margin: 0 0 0.75rem; font-size: 0.9rem; color: var(--color-text); }
|
||||
.pull-form { display: flex; gap: 0.5rem; margin-bottom: 0.5rem; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.pull-progress { margin-bottom: 0.5rem; }
|
||||
.progress-bar {
|
||||
height: 8px;
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
.progress-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
|
||||
.progress-label { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.pull-error, .panel-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.sr-announce { min-height: 1.2em; }
|
||||
.panel-loading { color: var(--color-text-muted); font-size: 0.875rem; }
|
||||
.model-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.model-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.model-name { flex: 1; font-family: var(--font-mono, monospace); }
|
||||
.model-size { color: var(--color-text-muted); font-size: 0.8rem; }
|
||||
.model-empty { color: var(--color-text-muted); font-size: 0.875rem; padding: 0.25rem 0; }
|
||||
</style>
|
||||
|
|
@@ -1,597 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import type { FullProfile, ServiceDefinition, CatalogEntryFull } from '../../types/nodes'
|
||||
import ServiceFormModal from './ServiceFormModal.vue'
|
||||
import CatalogEntryFormModal from './CatalogEntryFormModal.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
nodeId: string
|
||||
initialProfile: FullProfile | null
|
||||
}>()
|
||||
const emit = defineEmits<{ saved: []; close: [] }>()
|
||||
|
||||
// Deep-clone initial profile so edits don't mutate the parent's data
|
||||
const profile = ref<FullProfile>(
|
||||
props.initialProfile
|
||||
? JSON.parse(JSON.stringify(props.initialProfile))
|
||||
: { services: {}, nodes: {} }
|
||||
)
|
||||
|
||||
const saving = ref(false)
|
||||
const generating = ref(false)
|
||||
const opError = ref('')
|
||||
const expandedSvcs = ref<Set<string>>(new Set())
|
||||
|
||||
// Service modal
|
||||
const showSvcModal = ref(false)
|
||||
const editingSvcName = ref<string | undefined>()
|
||||
const editingSvcDef = ref<ServiceDefinition | undefined>()
|
||||
|
||||
// Catalog modal
|
||||
const showCatalogModal = ref(false)
|
||||
const catalogTargetSvc = ref('')
|
||||
const editingModelName = ref<string | undefined>()
|
||||
const editingEntry = ref<CatalogEntryFull | undefined>()
|
||||
|
||||
// ── Generate nodes section from coordinator ────────────────────────────────────
|
||||
|
||||
async function generate() {
|
||||
generating.value = true
|
||||
opError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile/generate`, { method: 'POST' })
|
||||
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
|
||||
const generated = await r.json() as FullProfile
|
||||
// Merge: keep current services edits, replace nodes section
|
||||
profile.value = { ...generated, services: profile.value.services }
|
||||
} catch (e) {
|
||||
opError.value = e instanceof Error ? e.message : 'Generate failed'
|
||||
} finally {
|
||||
generating.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Save full profile ──────────────────────────────────────────────────────────
|
||||
|
||||
async function save() {
|
||||
saving.value = true
|
||||
opError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ profile: profile.value }),
|
||||
})
|
||||
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
|
||||
emit('saved')
|
||||
} catch (e) {
|
||||
opError.value = e instanceof Error ? e.message : 'Save failed'
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Service CRUD ───────────────────────────────────────────────────────────────
|
||||
|
||||
function openAddService() {
|
||||
editingSvcName.value = undefined
|
||||
editingSvcDef.value = undefined
|
||||
showSvcModal.value = true
|
||||
}
|
||||
|
||||
function openEditService(name: string) {
|
||||
editingSvcName.value = name
|
||||
editingSvcDef.value = JSON.parse(JSON.stringify(profile.value.services[name]))
|
||||
showSvcModal.value = true
|
||||
}
|
||||
|
||||
function onServiceSaved(name: string, def: ServiceDefinition) {
|
||||
profile.value = { ...profile.value, services: { ...profile.value.services, [name]: def } }
|
||||
expandedSvcs.value = new Set([...expandedSvcs.value, name])
|
||||
showSvcModal.value = false
|
||||
}
|
||||
|
||||
function deleteService(name: string) {
|
||||
if (!confirm(`Remove service "${name}" from this profile?`)) return
|
||||
const svcs = { ...profile.value.services }
|
||||
delete svcs[name]
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
expandedSvcs.value = new Set([...expandedSvcs.value].filter(s => s !== name))
|
||||
}
|
||||
|
||||
function toggleSvc(name: string) {
|
||||
const s = new Set(expandedSvcs.value)
|
||||
if (s.has(name)) s.delete(name); else s.add(name)
|
||||
expandedSvcs.value = s
|
||||
}
|
||||
|
||||
// ── Catalog CRUD ───────────────────────────────────────────────────────────────
|
||||
|
||||
function openAddCatalogEntry(svcName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = undefined
|
||||
editingEntry.value = undefined
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function openEditCatalogEntry(svcName: string, modelName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = modelName
|
||||
editingEntry.value = JSON.parse(JSON.stringify(profile.value.services[svcName].catalog![modelName]))
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
function onCatalogEntrySaved(svcName: string, modelName: string, entry: CatalogEntryFull) {
|
||||
const svcs = { ...profile.value.services }
|
||||
const svc = { ...svcs[svcName], catalog: { ...(svcs[svcName].catalog ?? {}), [modelName]: entry } }
|
||||
svcs[svcName] = svc
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
showCatalogModal.value = false
|
||||
}
|
||||
|
||||
function deleteCatalogEntry(svcName: string, modelName: string) {
|
||||
if (!confirm(`Remove model "${modelName}" from ${svcName} catalog?`)) return
|
||||
const svcs = { ...profile.value.services }
|
||||
const catalog = { ...(svcs[svcName].catalog ?? {}) }
|
||||
delete catalog[modelName]
|
||||
svcs[svcName] = { ...svcs[svcName], catalog }
|
||||
profile.value = { ...profile.value, services: svcs }
|
||||
}
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
function gpuList() {
|
||||
return (profile.value.nodes[props.nodeId]?.gpus ?? [])
|
||||
}
|
||||
|
||||
function serviceCount() {
|
||||
return Object.keys(profile.value.services).length
|
||||
}
|
||||
|
||||
// ── Ollama model suggestions ───────────────────────────────────────────────────
|
||||
|
||||
interface OllamaModel { name: string; size: number }
|
||||
const ollamaModels = ref<OllamaModel[]>([])
|
||||
const ollamaLoading = ref(false)
|
||||
|
||||
onMounted(async () => {
|
||||
ollamaLoading.value = true
|
||||
try {
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
|
||||
if (r.ok) {
|
||||
const d = await r.json() as { models?: OllamaModel[] }
|
||||
ollamaModels.value = d.models ?? []
|
||||
}
|
||||
} catch { /* Ollama offline — silently skip */ }
|
||||
finally { ollamaLoading.value = false }
|
||||
})
|
||||
|
||||
function ollamaNotInCatalog(svcName: string): OllamaModel[] {
|
||||
const catalog = profile.value.services[svcName]?.catalog ?? {}
|
||||
return ollamaModels.value.filter(m => !(m.name in catalog))
|
||||
}
|
||||
|
||||
function openAddFromOllama(svcName: string, modelName: string) {
|
||||
catalogTargetSvc.value = svcName
|
||||
editingModelName.value = modelName
|
||||
editingEntry.value = {
|
||||
path: profile.value.services[svcName]?.model_base_path
|
||||
? `${profile.value.services[svcName].model_base_path}/${modelName}`
|
||||
: '',
|
||||
vram_mb: 0,
|
||||
}
|
||||
showCatalogModal.value = true
|
||||
}
|
||||
|
||||
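// Takes a size in bytes (despite the name) and formats it with binary units.
function formatMb(bytes: number): string {
// Threshold matches the GB divisor below, so sizes just under 1 GiB are shown in MB
return bytes >= 1_073_741_824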
|
||||
? `${(bytes / 1_073_741_824).toFixed(1)} GB`
|
||||
: `${Math.round(bytes / 1_048_576)} MB`
|
||||
}
|
||||
|
||||
// ── Pull model onto node ───────────────────────────────────────────────────────
|
||||
|
||||
const pullName = ref('')
|
||||
const pulling = ref(false)
|
||||
const pullStatus = ref('')
|
||||
const pullPct = ref(0)
|
||||
const pullError = ref('')
|
||||
let pullAbort: AbortController | null = null
|
||||
|
||||
async function doPull() {
|
||||
const name = pullName.value.trim()
|
||||
if (!name || pulling.value) return
|
||||
pulling.value = true
|
||||
pullStatus.value = 'Starting…'
|
||||
pullError.value = ''
|
||||
pullPct.value = 0
|
||||
pullAbort?.abort()
|
||||
pullAbort = new AbortController()
|
||||
|
||||
try {
|
||||
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name }),
|
||||
signal: pullAbort.signal,
|
||||
})
|
||||
if (!resp.ok || !resp.body) {
|
||||
pullError.value = `HTTP ${resp.status}`
|
||||
return
|
||||
}
|
||||
const reader = resp.body.getReader()
|
||||
const dec = new TextDecoder()
|
||||
let buf = ''
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += dec.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data:')) continue
|
||||
try {
|
||||
const d = JSON.parse(line.slice(5)) as {
|
||||
status?: string; completed?: number; total?: number; error?: string; done?: boolean
|
||||
}
|
||||
if (d.error) { pullError.value = d.error; return }
|
||||
pullStatus.value = d.status ?? ''
|
||||
if (d.total && d.total > 0) pullPct.value = Math.round((d.completed ?? 0) / d.total * 100)
|
||||
if (d.done) {
|
||||
pullName.value = ''
|
||||
pullPct.value = 100
|
||||
// Refresh Ollama model list so new model appears in suggest chips
|
||||
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
|
||||
if (r.ok) { const d2 = await r.json() as { models?: OllamaModel[] }; ollamaModels.value = d2.models ?? [] }
|
||||
}
|
||||
} catch { /* skip malformed SSE line */ }
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name !== 'AbortError') pullError.value = e.message
|
||||
} finally {
|
||||
pulling.value = false
|
||||
if (pullPct.value === 100) setTimeout(() => { pullStatus.value = ''; pullPct.value = 0 }, 2000)
|
||||
}
|
||||
}
|
||||
</script>
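Review note on `doPull`: the streaming loop buffers partial chunks, splits on newlines, and parses only `data:` lines. A minimal sketch of that buffering as a standalone, testable function follows. The names `PullEvent`, `createSseLineParser`, and `percent` are hypothetical and are not part of this PR; the event fields mirror the inline type used in `doPull`.

```typescript
// Hypothetical helper, not part of the PR: incremental parser for `data:` lines.
// The event shape mirrors the inline type used inside doPull().
interface PullEvent {
  status?: string
  completed?: number
  total?: number
  error?: string
  done?: boolean
}

function createSseLineParser() {
  let buf = '' // holds the trailing partial line between chunks

  return function feed(chunk: string): PullEvent[] {
    buf += chunk
    const lines = buf.split('\n')
    buf = lines.pop() ?? '' // last element is incomplete until a newline arrives

    const events: PullEvent[] = []
    for (const line of lines) {
      if (!line.startsWith('data:')) continue
      try {
        events.push(JSON.parse(line.slice(5)) as PullEvent)
      } catch {
        // Malformed payload: skip the line, as the component does
      }
    }
    return events
  }
}

// Progress as an integer percentage, or null when the total is unknown.
function percent(ev: PullEvent): number | null {
  if (!ev.total || ev.total <= 0) return null
  return Math.round(((ev.completed ?? 0) / ev.total) * 100)
}
```

Keeping the parser free of `fetch` and `TextDecoder` means it can be exercised with plain strings; the component would still decode each chunk with `{ stream: true }` before feeding it in.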
|
||||
|
||||
<template>
|
||||
<section class="pep" aria-label="Profile editor">
|
||||
<!-- Header -->
|
||||
<div class="pep-header">
|
||||
<div class="pep-title-row">
|
||||
<h3 class="pep-title">Profile — {{ nodeId }}</h3>
|
||||
<span class="pep-svc-count">{{ serviceCount() }} service{{ serviceCount() === 1 ? '' : 's' }}</span>
|
||||
</div>
|
||||
<div class="pep-actions">
|
||||
<button class="btn-secondary btn-sm" :disabled="generating" @click="generate">
|
||||
{{ generating ? 'Refreshing…' : 'Refresh Hardware' }}
|
||||
</button>
|
||||
<button class="btn-primary btn-sm" :disabled="saving" @click="save">
|
||||
{{ saving ? 'Saving…' : 'Save Profile' }}
|
||||
</button>
|
||||
<button class="btn-icon-lg" aria-label="Close editor" @click="emit('close')">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="opError" class="pep-error" role="alert">{{ opError }}</div>
|
||||
|
||||
<!-- Meta fields -->
|
||||
<div class="pep-meta">
|
||||
<label class="meta-label" for="pep-vram">vram_total_mb</label>
|
||||
<input id="pep-vram" v-model.number="profile.vram_total_mb" type="number" min="0" class="meta-input" />
|
||||
<label class="meta-label" for="pep-evict">eviction_timeout_s</label>
|
||||
<input id="pep-evict" v-model.number="profile.eviction_timeout_s" type="number" min="0" step="0.5" class="meta-input" />
|
||||
</div>
|
||||
|
||||
<!-- Hardware summary -->
|
||||
<div v-if="gpuList().length" class="hw-section">
|
||||
<span class="hw-label">Hardware</span>
|
||||
<span v-for="g in gpuList()" :key="g.id" class="hw-gpu">
|
||||
GPU {{ g.id }}: {{ g.card || 'unknown' }} · {{ g.vram_mb }} MB · sm{{ g.compute_cap ?? '?' }}
|
||||
</span>
|
||||
<span v-if="!gpuList().length" class="hw-none">No hardware data — click Refresh Hardware.</span>
|
||||
</div>
|
||||
<div v-else class="hw-section">
|
||||
<span class="hw-none">No hardware data — click Refresh Hardware to seed from coordinator.</span>
|
||||
</div>
|
||||
|
||||
<!-- Services -->
|
||||
<div class="svcs-header">
|
||||
<span class="svcs-title">Services</span>
|
||||
<button class="btn-secondary btn-sm" @click="openAddService">+ Add Service</button>
|
||||
</div>
|
||||
|
||||
<div v-if="serviceCount() === 0" class="svcs-empty">
|
||||
No services defined. Add a service to configure what can run on this node.
|
||||
</div>
|
||||
|
||||
<ul class="svcs-list" role="list">
|
||||
<li
|
||||
v-for="(def, svcName) in profile.services"
|
||||
:key="String(svcName)"
|
||||
class="svc-item"
|
||||
>
|
||||
<!-- Service row header -->
|
||||
<div class="svc-row">
|
||||
<button
|
||||
class="svc-toggle"
|
||||
:aria-expanded="expandedSvcs.has(String(svcName))"
|
||||
@click="toggleSvc(String(svcName))"
|
||||
>
|
||||
<span class="svc-arrow">{{ expandedSvcs.has(String(svcName)) ? '▾' : '▸' }}</span>
|
||||
<span class="svc-name">{{ svcName }}</span>
|
||||
</button>
|
||||
<span class="svc-badges">
|
||||
<span class="badge">{{ def.max_mb }} MB</span>
|
||||
<span class="badge">p{{ def.priority }}</span>
|
||||
<span v-if="def.shared" class="badge badge--blue">shared</span>
|
||||
<span v-if="def.managed" class="badge badge--dim">managed</span>
|
||||
<span v-if="def.catalog" class="badge badge--dim">{{ Object.keys(def.catalog).length }} models</span>
|
||||
</span>
|
||||
<div class="svc-btns">
|
||||
<button class="btn-secondary btn-xs" @click="openEditService(String(svcName))">Edit</button>
|
||||
<button class="btn-danger btn-xs" @click="deleteService(String(svcName))">Delete</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Expanded catalog -->
|
||||
<div v-if="expandedSvcs.has(String(svcName))" class="svc-detail">
|
||||
<div class="svc-detail-meta">
|
||||
<span v-if="def.min_compute_cap">min sm{{ def.min_compute_cap }}</span>
|
||||
<span v-if="def.max_concurrent">max_concurrent: {{ def.max_concurrent }}</span>
|
||||
<span v-if="def.idle_stop_after_s">idle_stop: {{ def.idle_stop_after_s }}s</span>
|
||||
<span v-if="def.always_on" class="badge badge--blue">always_on</span>
|
||||
</div>
|
||||
|
||||
<!-- Ollama model suggestions + pull -->
|
||||
<div class="ollama-suggest">
|
||||
<div class="suggest-row">
|
||||
<span class="suggest-label">On node (Ollama):</span>
|
||||
<span v-if="ollamaLoading" class="suggest-loading">loading…</span>
|
||||
<template v-else-if="ollamaNotInCatalog(String(svcName)).length">
|
||||
<button
|
||||
v-for="m in ollamaNotInCatalog(String(svcName))"
|
||||
:key="m.name"
|
||||
class="suggest-chip"
|
||||
:title="`Add ${m.name} (${formatMb(m.size)}) to this service catalog`"
|
||||
@click="openAddFromOllama(String(svcName), m.name)"
|
||||
>
|
||||
+ {{ m.name }} <span class="chip-size">{{ formatMb(m.size) }}</span>
|
||||
</button>
|
||||
</template>
|
||||
<span v-else-if="!ollamaLoading" class="suggest-none">All Ollama models already in catalog.</span>
|
||||
</div>
|
||||
|
||||
<!-- Pull model onto this node -->
|
||||
<div class="pull-row">
|
||||
<input
|
||||
v-model="pullName"
|
||||
class="pull-input"
|
||||
placeholder="Pull model on node (e.g. llama3:8b)"
|
||||
:disabled="pulling"
|
||||
@keyup.enter="doPull"
|
||||
/>
|
||||
<button class="btn-pull" :disabled="pulling || !pullName.trim()" @click="doPull">
|
||||
{{ pulling ? 'Pulling…' : 'Pull' }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="pulling || pullPct > 0" class="pull-progress">
|
||||
<div class="pull-bar"><div class="pull-fill" :style="{ width: pullPct + '%' }" /></div>
|
||||
<span class="pull-status">{{ pullStatus }}</span>
|
||||
</div>
|
||||
<div v-if="pullError" class="pull-err" role="alert">{{ pullError }}</div>
|
||||
</div>
|
||||
|
||||
<div class="catalog-header">
|
||||
<span class="catalog-title">Catalog</span>
|
||||
<button class="btn-link" @click="openAddCatalogEntry(String(svcName))">+ Add Model</button>
|
||||
</div>
|
||||
|
||||
<div v-if="!def.catalog || !Object.keys(def.catalog).length" class="catalog-empty">
|
||||
No catalog entries. Only services like cf-text need a catalog.
|
||||
</div>
|
||||
<ul v-else class="catalog-list" role="list">
|
||||
<li
|
||||
v-for="(entry, modelName) in def.catalog"
|
||||
:key="String(modelName)"
|
||||
class="catalog-item"
|
||||
>
|
||||
<span class="catalog-model">{{ modelName }}</span>
|
||||
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
|
||||
<span v-if="entry.multi_gpu" class="badge badge--dim">multi-gpu</span>
|
||||
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
|
||||
<div class="catalog-btns">
|
||||
<button class="btn-secondary btn-xs" @click="openEditCatalogEntry(String(svcName), String(modelName))">Edit</button>
|
||||
<button class="btn-danger btn-xs" @click="deleteCatalogEntry(String(svcName), String(modelName))">✕</button>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<!-- Service form modal -->
|
||||
<ServiceFormModal
|
||||
v-if="showSvcModal"
|
||||
:service-name="editingSvcName"
|
||||
:definition="editingSvcDef"
|
||||
@save="onServiceSaved"
|
||||
@cancel="showSvcModal = false"
|
||||
/>
|
||||
|
||||
<!-- Catalog entry form modal -->
|
||||
<CatalogEntryFormModal
|
||||
v-if="showCatalogModal"
|
||||
:svc-name="catalogTargetSvc"
|
||||
:model-name="editingModelName"
|
||||
:entry="editingEntry"
|
||||
@save="onCatalogEntrySaved"
|
||||
@cancel="showCatalogModal = false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.pep {
|
||||
margin-top: 0.75rem;
|
||||
padding: 1rem;
|
||||
border: 1px solid var(--color-primary);
|
||||
border-radius: 6px;
|
||||
background: var(--color-surface-raised);
|
||||
color: var(--color-text);
|
||||
}
|
||||
.pep-header {
|
||||
display: flex; align-items: center; justify-content: space-between; gap: 0.5rem;
|
||||
margin-bottom: 0.75rem; flex-wrap: wrap;
|
||||
}
|
||||
.pep-title-row { display: flex; align-items: baseline; gap: 0.5rem; }
|
||||
.pep-title { margin: 0; font-size: 0.95rem; font-weight: 600; color: var(--color-text); }
|
||||
.pep-svc-count { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.pep-actions { display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap; }
|
||||
.pep-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
|
||||
.pep-meta {
|
||||
display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap;
|
||||
padding: 0.5rem; background: var(--color-surface-alt); border-radius: 4px; margin-bottom: 0.75rem;
|
||||
}
|
||||
.meta-label { font-size: 0.8rem; color: var(--color-text-muted); }
|
||||
.meta-input {
|
||||
width: 7rem; background: var(--color-surface); border: 1px solid var(--color-border);
|
||||
border-radius: 4px; padding: 0.2rem 0.4rem; color: var(--color-text); font-size: 0.8rem;
|
||||
}
|
||||
.hw-section {
|
||||
display: flex; flex-wrap: wrap; align-items: center; gap: 0.5rem;
|
||||
font-size: 0.8rem; color: var(--color-text-muted);
|
||||
padding: 0.4rem 0.5rem; border-radius: 4px; background: var(--color-surface-alt);
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
.hw-label { font-weight: 600; color: var(--color-text); }
|
||||
.hw-gpu { font-family: monospace; color: var(--color-text); }
|
||||
.hw-none { font-style: italic; }
|
||||
.svcs-header {
|
||||
display: flex; align-items: center; justify-content: space-between;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.svcs-title { font-size: 0.85rem; font-weight: 600; color: var(--color-text); }
|
||||
.svcs-empty { color: var(--color-text-muted); font-size: 0.85rem; padding: 0.5rem 0; }
|
||||
.svcs-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.4rem; }
|
||||
.svc-item { border: 1px solid var(--color-border); border-radius: 4px; overflow: hidden; }
|
||||
.svc-row {
|
||||
display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0.5rem;
|
||||
background: var(--color-surface-alt); flex-wrap: wrap;
|
||||
}
|
||||
.svc-toggle {
|
||||
display: flex; align-items: center; gap: 0.35rem;
|
||||
background: none; border: none; cursor: pointer; color: var(--color-text); padding: 0; flex: 1; min-width: 0;
|
||||
}
|
||||
.svc-arrow { font-size: 0.7rem; color: var(--color-text-muted); }
|
||||
.svc-name { font-size: 0.875rem; font-weight: 500; font-family: monospace; }
|
||||
.svc-badges { display: flex; gap: 0.3rem; flex-wrap: wrap; }
|
||||
.svc-btns { display: flex; gap: 0.3rem; margin-left: auto; }
|
||||
.svc-detail { padding: 0.5rem 0.75rem; display: flex; flex-direction: column; gap: 0.5rem; background: var(--color-surface-raised); }
|
||||
.svc-detail-meta {
|
||||
display: flex; gap: 0.5rem; flex-wrap: wrap;
|
||||
font-size: 0.78rem; color: var(--color-text-muted);
|
||||
}
|
||||
.ollama-suggest {
|
||||
display: flex; flex-direction: column; gap: 0.35rem;
|
||||
padding: 0.4rem 0.5rem;
|
||||
background: var(--color-primary-light);
|
||||
border: 1px solid var(--color-border-light);
|
||||
border-radius: 4px;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
.suggest-row { display: flex; flex-wrap: wrap; align-items: center; gap: 0.35rem; }
|
||||
.suggest-label { color: var(--color-text-muted); font-weight: 500; white-space: nowrap; }
|
||||
.suggest-loading { color: var(--color-text-muted); font-style: italic; }
|
||||
.suggest-none { color: var(--color-text-muted); font-style: italic; }
|
||||
.suggest-chip {
|
||||
display: inline-flex; align-items: center; gap: 0.25rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 3px;
|
||||
color: var(--color-text);
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
transition: border-color 0.15s, background 0.15s;
|
||||
}
|
||||
.suggest-chip:hover { border-color: var(--app-primary); background: var(--color-surface-alt); }
|
||||
.chip-size { color: var(--color-text-muted); font-size: 0.72rem; }
|
||||
.pull-row { display: flex; gap: 0.4rem; align-items: center; }
|
||||
.pull-input {
|
||||
flex: 1;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
color: var(--color-text);
|
||||
font-size: 0.78rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
.pull-input:disabled { opacity: 0.5; }
|
||||
.btn-pull {
|
||||
padding: 0.25rem 0.6rem;
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.btn-pull:hover:not(:disabled) { background: var(--app-primary-hover); }
|
||||
.btn-pull:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.pull-progress { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.pull-bar {
|
||||
flex: 1; height: 6px;
|
||||
background: var(--color-border);
|
||||
border-radius: 3px; overflow: hidden;
|
||||
}
|
||||
.pull-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
|
||||
.pull-status { color: var(--color-text-muted); font-size: 0.72rem; white-space: nowrap; max-width: 14rem; overflow: hidden; text-overflow: ellipsis; }
|
||||
.pull-err { color: var(--color-error); font-size: 0.75rem; }
|
||||
.catalog-header { display: flex; align-items: center; justify-content: space-between; }
|
||||
.catalog-title { font-size: 0.8rem; font-weight: 600; color: var(--color-text-muted); text-transform: uppercase; letter-spacing: 0.05em; }
|
||||
.catalog-empty { font-size: 0.8rem; color: var(--color-text-muted); font-style: italic; }
|
||||
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.25rem; }
|
||||
.catalog-item {
|
||||
display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap;
|
||||
padding: 0.25rem 0.5rem; background: var(--color-surface-alt); border-radius: 3px; font-size: 0.8rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.catalog-model { font-family: monospace; flex: 1; min-width: 12rem; }
|
||||
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
|
||||
.catalog-desc { color: var(--color-text-muted); flex: 2; font-size: 0.75rem; }
|
||||
.catalog-btns { display: flex; gap: 0.25rem; margin-left: auto; }
|
||||
.badge {
|
||||
padding: 0.1rem 0.4rem; border-radius: 3px; font-size: 0.72rem;
|
||||
background: var(--color-surface); border: 1px solid var(--color-border); color: var(--color-text);
|
||||
}
|
||||
.badge--blue { border-color: var(--color-primary); color: var(--color-primary); background: var(--color-primary-light); }
|
||||
.badge--dim { opacity: 0.75; }
|
||||
.btn-link { background: none; border: none; color: var(--color-accent); cursor: pointer; font-size: 0.8rem; padding: 0; }
|
||||
.btn-link:hover { color: var(--color-accent-hover); }
|
||||
.btn-primary {
|
||||
background: var(--color-primary); color: var(--color-text-inverse); border: none;
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--color-primary-hover); }
|
||||
.btn-primary:disabled { opacity: 0.6; cursor: not-allowed; }
|
||||
.btn-secondary {
|
||||
background: transparent; border: 1px solid var(--color-border); color: var(--color-text);
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
.btn-secondary:disabled { opacity: 0.6; cursor: not-allowed; }
|
||||
.btn-danger {
|
||||
background: transparent; border: 1px solid var(--color-error); color: var(--color-error);
|
||||
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
|
||||
}
|
||||
.btn-danger:hover { background: var(--color-surface-alt); }
|
||||
.btn-sm { padding: 0.3rem 0.6rem; }
|
||||
.btn-xs { padding: 0.15rem 0.4rem; }
|
||||
.btn-icon-lg { background: none; border: none; color: var(--color-text-muted); cursor: pointer; font-size: 1rem; padding: 0.2rem 0.3rem; }
|
||||
.btn-icon-lg:hover { color: var(--color-text); }
|
||||
</style>
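Review note on `formatMb`: the helper receives a byte count (Ollama reports `size` in bytes) and renders GB or MB. A standalone sketch with the binary unit constants named once is shown below. `formatBytes`, `MIB`, and `GIB` are hypothetical names, not identifiers from this PR.

```typescript
// Hypothetical standalone version of the byte formatter; names are illustrative.
const MIB = 1_048_576
const GIB = 1_073_741_824

function formatBytes(bytes: number): string {
  // Use the same constant for the branch and the divisor so the GB branch
  // never renders a value below 1.0
  return bytes >= GIB
    ? `${(bytes / GIB).toFixed(1)} GB`
    : `${Math.round(bytes / MIB)} MB`
}
```

Sharing one threshold constant between the comparison and the division avoids edge cases where a size just under 1 GiB is shown as a fractional GB value.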
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
type ServiceState =
|
||||
| 'running'
|
||||
| 'stopped'
|
||||
| 'assigned-only'
|
||||
| 'available'
|
||||
| 'incompatible'
|
||||
| 'vram-tight'
|
||||
| 'unknown'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName: string
|
||||
state: ServiceState
|
||||
assigned: boolean
|
||||
disabled?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{ toggle: [] }>()
|
||||
|
||||
const STATE_LABELS: Record<ServiceState, string> = {
|
||||
running: 'Running',
|
||||
stopped: 'Stopped',
|
||||
'assigned-only': 'Assigned',
|
||||
available: 'Available',
|
||||
incompatible: 'Incompatible',
|
||||
'vram-tight': 'VRAM tight',
|
||||
unknown: 'Unknown',
|
||||
}
|
||||
|
||||
const STATE_ICONS: Record<ServiceState, string> = {
|
||||
running: '▶',
|
||||
stopped: '⏹',
|
||||
'assigned-only': '📌',
|
||||
available: '○',
|
||||
incompatible: '✕',
|
||||
'vram-tight': '⚠',
|
||||
unknown: '?',
|
||||
}
|
||||
|
||||
function handleToggle() {
|
||||
if (!props.disabled && props.state !== 'incompatible') emit('toggle')
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<button
|
||||
class="service-badge"
|
||||
:class="[`state-${state}`, { assigned, 'is-disabled': disabled || state === 'incompatible' }]"
|
||||
:aria-pressed="assigned"
|
||||
:aria-label="`${serviceName}: ${STATE_LABELS[state] ?? state}${assigned ? ' (assigned)' : ''}`"
|
||||
:disabled="disabled || state === 'incompatible'"
|
||||
@click="handleToggle"
|
||||
>
|
||||
<span class="badge-icon" aria-hidden="true">{{ STATE_ICONS[state] ?? '?' }}</span>
|
||||
<span class="badge-name">{{ serviceName }}</span>
|
||||
<span class="badge-state">{{ STATE_LABELS[state] ?? state }}</span>
|
||||
</button>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.service-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.3rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--color-border);
|
||||
background: var(--color-surface);
|
||||
color: var(--color-text);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.1s, border-color 0.1s;
|
||||
}
|
||||
.service-badge:hover:not(.is-disabled) { opacity: 0.8; }
|
||||
.service-badge.is-disabled { cursor: not-allowed; opacity: 0.5; }
|
||||
.service-badge.state-running { border-color: var(--color-success); }
|
||||
.service-badge.state-stopped { border-color: var(--color-warning); }
|
||||
.service-badge.state-assigned-only { border-color: var(--color-info); }
|
||||
.service-badge.state-incompatible { border-color: var(--color-error); }
|
||||
.service-badge.state-vram-tight { border-color: var(--color-warning); }
|
||||
.badge-state { color: var(--color-text-muted); }
|
||||
</style>
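Review note on the service badge: the component disables itself when `disabled` is set or the state is `incompatible`, and builds its `aria-label` from the state label table. The same rules can be sketched as pure functions, which makes them easy to unit test outside the template. `isBadgeInteractive` and `badgeAriaLabel` are hypothetical names; the state union and labels are copied from the component.

```typescript
// State union and labels copied from the component; helper names are hypothetical.
type ServiceState =
  | 'running'
  | 'stopped'
  | 'assigned-only'
  | 'available'
  | 'incompatible'
  | 'vram-tight'
  | 'unknown'

const STATE_LABELS: Record<ServiceState, string> = {
  running: 'Running',
  stopped: 'Stopped',
  'assigned-only': 'Assigned',
  available: 'Available',
  incompatible: 'Incompatible',
  'vram-tight': 'VRAM tight',
  unknown: 'Unknown',
}

// A badge can be toggled unless it is explicitly disabled or incompatible.
function isBadgeInteractive(state: ServiceState, disabled = false): boolean {
  return !disabled && state !== 'incompatible'
}

// Accessible label: "<service>: <state label>" plus an assignment suffix.
function badgeAriaLabel(serviceName: string, state: ServiceState, assigned: boolean): string {
  return `${serviceName}: ${STATE_LABELS[state]}${assigned ? ' (assigned)' : ''}`
}
```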
|
||||
|
|
@ -1,231 +0,0 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import type { ServiceDefinition } from '../../types/nodes'
|
||||
|
||||
const props = defineProps<{
|
||||
serviceName?: string
|
||||
definition?: ServiceDefinition
|
||||
}>()
|
||||
const emit = defineEmits<{
|
||||
save: [name: string, def: ServiceDefinition]
|
||||
cancel: []
|
||||
}>()
|
||||
|
||||
const name = ref(props.serviceName ?? '')
|
||||
const maxMb = ref(props.definition?.max_mb ?? 0)
|
||||
const priority = ref(props.definition?.priority ?? 1)
|
||||
const minCap = ref(props.definition?.min_compute_cap ?? 0)
|
||||
const prefCap = ref<number | ''>(props.definition?.preferred_compute_cap ?? '')
|
||||
const shared = ref(props.definition?.shared ?? false)
|
||||
const maxConcurrent = ref<number | ''>(props.definition?.max_concurrent ?? '')
|
||||
const idleStop = ref<number | ''>(props.definition?.idle_stop_after_s ?? '')
|
||||
const alwaysOn = ref(props.definition?.always_on ?? false)
|
||||
const modelBasePath = ref(props.definition?.model_base_path ?? '')
|
||||
const hasManaged = ref(!!props.definition?.managed)
|
||||
const managedJson = ref(
|
||||
props.definition?.managed ? JSON.stringify(props.definition.managed, null, 2) : ''
|
||||
)
|
||||
const formError = ref('')
|
||||
|
||||
watch(() => props.definition, (d) => {
|
||||
name.value = props.serviceName ?? ''
|
||||
maxMb.value = d?.max_mb ?? 0
|
||||
priority.value = d?.priority ?? 1
|
||||
minCap.value = d?.min_compute_cap ?? 0
|
||||
prefCap.value = d?.preferred_compute_cap ?? ''
|
||||
shared.value = d?.shared ?? false
|
||||
maxConcurrent.value = d?.max_concurrent ?? ''
|
||||
idleStop.value = d?.idle_stop_after_s ?? ''
|
||||
alwaysOn.value = d?.always_on ?? false
|
||||
modelBasePath.value = d?.model_base_path ?? ''
|
||||
hasManaged.value = !!d?.managed
|
||||
managedJson.value = d?.managed ? JSON.stringify(d.managed, null, 2) : ''
|
||||
})
|
||||
|
||||
const managedJsonError = computed(() => {
|
||||
if (!hasManaged.value || !managedJson.value.trim()) return ''
|
||||
try { JSON.parse(managedJson.value); return '' }
|
||||
catch { return 'Invalid JSON' }
|
||||
})
|
||||
|
||||
function submit() {
|
||||
formError.value = ''
|
||||
if (!name.value.trim()) { formError.value = 'Service name is required.'; return }
|
||||
if (!maxMb.value || maxMb.value <= 0) { formError.value = 'max_mb must be > 0.'; return }
|
||||
if (managedJsonError.value) { formError.value = 'Fix the managed JSON before saving.'; return }
|
||||
|
||||
const def: ServiceDefinition = { max_mb: maxMb.value, priority: priority.value }
|
||||
if (minCap.value) def.min_compute_cap = minCap.value
|
||||
if (prefCap.value !== '') def.preferred_compute_cap = Number(prefCap.value)
|
||||
if (shared.value) def.shared = true
|
||||
if (maxConcurrent.value !== '') def.max_concurrent = Number(maxConcurrent.value)
|
||||
if (idleStop.value !== '') def.idle_stop_after_s = Number(idleStop.value)
|
||||
if (alwaysOn.value) def.always_on = true
|
||||
if (modelBasePath.value.trim()) def.model_base_path = modelBasePath.value.trim()
|
||||
if (hasManaged.value && managedJson.value.trim()) {
|
||||
def.managed = JSON.parse(managedJson.value)
|
||||
}
|
||||
// Preserve existing catalog when editing
|
||||
if (props.definition?.catalog) def.catalog = props.definition.catalog
|
||||
|
||||
emit('save', name.value.trim(), def)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${serviceName ? 'Edit' : 'Add'} service`">
|
||||
<div class="modal-box">
|
||||
<h3 class="modal-title">{{ serviceName ? 'Edit' : 'Add' }} Service</h3>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-name">Service name</label>
|
||||
<input id="sf-name" v-model="name" class="field-input" :readonly="!!serviceName" placeholder="cf-text" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-maxmb">max_mb</label>
|
||||
<input id="sf-maxmb" v-model.number="maxMb" type="number" min="0" class="field-input field-input--sm" />
|
||||
<span class="field-hint">VRAM ceiling</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-prio">priority</label>
|
||||
<input id="sf-prio" v-model.number="priority" type="number" min="1" max="10" class="field-input field-input--sm" />
|
||||
<span class="field-hint">1 = highest</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-mincap">min_compute_cap</label>
|
||||
<input id="sf-mincap" v-model.number="minCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="0.0" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-prefcap">preferred_cap</label>
|
||||
<input id="sf-prefcap" v-model="prefCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="optional" />
|
||||
</div>
|
||||
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-shared" v-model="shared" type="checkbox" />
|
||||
<label for="sf-shared">shared (multiple concurrent users)</label>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-maxcon">max_concurrent</label>
|
||||
<input id="sf-maxcon" v-model="maxConcurrent" type="number" min="1" class="field-input field-input--sm" placeholder="optional" />
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-idle">idle_stop_after_s</label>
|
||||
<input id="sf-idle" v-model="idleStop" type="number" min="0" class="field-input field-input--sm" placeholder="optional" />
|
||||
<span class="field-hint">seconds</span>
|
||||
</div>
|
||||
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-always" v-model="alwaysOn" type="checkbox" />
|
||||
<label for="sf-always">always_on (never evict)</label>
|
||||
</div>
|
||||
|
||||
<div class="field-row">
|
||||
<label class="field-label" for="sf-base">model_base_path</label>
|
||||
<input id="sf-base" v-model="modelBasePath" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models (optional)" />
|
||||
</div>
|
||||
|
||||
<div class="managed-section">
|
||||
<div class="field-row field-row--check">
|
||||
<input id="sf-has-managed" v-model="hasManaged" type="checkbox" />
|
||||
<label for="sf-has-managed">Has managed process config</label>
|
||||
</div>
|
||||
<div v-if="hasManaged" class="managed-body">
|
||||
<label class="field-label" for="sf-managed">managed (JSON)</label>
|
||||
<textarea
|
||||
id="sf-managed"
|
||||
v-model="managedJson"
|
||||
class="field-textarea"
|
||||
rows="6"
|
||||
spellcheck="false"
|
||||
placeholder='{"type": "process", "exec_path": "...", "args_template": "...", "port": 8008, "host_port": 8008}'
|
||||
/>
|
||||
<span v-if="managedJsonError" class="json-error" role="alert">{{ managedJsonError }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
|
||||
<button class="btn-primary" @click="submit">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.modal-backdrop {
|
||||
position: fixed; inset: 0;
|
||||
background: rgba(0,0,0,0.5);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
z-index: 200;
|
||||
}
|
||||
.modal-box {
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
width: 100%; max-width: 540px;
|
||||
max-height: 90vh; overflow-y: auto;
|
||||
display: flex; flex-direction: column; gap: 0.65rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
|
||||
.field-row { display: flex; align-items: center; gap: 0.5rem; }
|
||||
.field-row--check { gap: 0.4rem; font-size: 0.875rem; color: var(--color-text); }
|
||||
.field-label { min-width: 9rem; font-size: 0.85rem; color: var(--color-text-muted); flex-shrink: 0; }
|
||||
.field-hint { font-size: 0.75rem; color: var(--color-text-muted); }
|
||||
.field-input {
|
||||
flex: 1;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.3rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
.field-input--sm { flex: 0 0 8rem; }
|
||||
.managed-section { display: flex; flex-direction: column; gap: 0.4rem; border-top: 1px solid var(--color-border); padding-top: 0.5rem; }
|
||||
.managed-body { display: flex; flex-direction: column; gap: 0.3rem; }
|
||||
.field-textarea {
|
||||
width: 100%;
|
||||
background: var(--color-surface-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.5rem;
|
||||
color: var(--color-text);
|
||||
font-size: 0.8rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
resize: vertical;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.json-error { color: var(--color-error); font-size: 0.78rem; }
|
||||
.form-error { color: var(--color-error); font-size: 0.8rem; }
|
||||
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
|
||||
.btn-primary {
|
||||
background: var(--app-primary);
|
||||
color: var(--color-text-inverse);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 1rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-primary:hover { background: var(--app-primary-hover); }
|
||||
.btn-secondary {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text);
|
||||
border-radius: 4px;
|
||||
padding: 0.4rem 0.75rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.btn-secondary:hover { background: var(--color-surface-alt); }
|
||||
</style>
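Review note on `submit()`: the form validates the name and `max_mb`, then copies only the optional fields that were actually set. A sketch of that logic as a pure function is below. `ServiceDefinitionSketch` is a local stand-in whose shape is assumed from the fields this component reads; the real `ServiceDefinition` type lives in `../../types/nodes` and is not shown in this diff. `ServiceFormValues` and `buildServiceDefinition` are hypothetical names.

```typescript
// Local stand-in for the project's ServiceDefinition type (assumed shape).
interface ServiceDefinitionSketch {
  max_mb: number
  priority: number
  min_compute_cap?: number
  shared?: boolean
  max_concurrent?: number
  idle_stop_after_s?: number
  always_on?: boolean
  managed?: unknown
}

// Hypothetical plain-object view of the form refs.
interface ServiceFormValues {
  name: string
  maxMb: number
  priority: number
  minCap: number
  shared: boolean
  maxConcurrent: number | ''
  idleStop: number | ''
  alwaysOn: boolean
  managedJson: string // empty string means "no managed config"
}

type BuildResult =
  | { ok: true; name: string; def: ServiceDefinitionSketch }
  | { ok: false; error: string }

function buildServiceDefinition(f: ServiceFormValues): BuildResult {
  const name = f.name.trim()
  if (!name) return { ok: false, error: 'Service name is required.' }
  if (!f.maxMb || f.maxMb <= 0) return { ok: false, error: 'max_mb must be > 0.' }

  // Required fields first; optional fields are added only when set
  const def: ServiceDefinitionSketch = { max_mb: f.maxMb, priority: f.priority }
  if (f.minCap) def.min_compute_cap = f.minCap
  if (f.shared) def.shared = true
  if (f.maxConcurrent !== '') def.max_concurrent = Number(f.maxConcurrent)
  if (f.idleStop !== '') def.idle_stop_after_s = Number(f.idleStop)
  if (f.alwaysOn) def.always_on = true

  if (f.managedJson.trim()) {
    try {
      def.managed = JSON.parse(f.managedJson)
    } catch {
      return { ok: false, error: 'Fix the managed JSON before saving.' }
    }
  }
  return { ok: true, name, def }
}
```

Returning a result object instead of mutating `formError` keeps the validation independent of Vue reactivity; the component could map `error` onto `formError.value` and emit `save` on success.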
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
// src/composables/useSftKeyboard.ts
|
||||
import { onUnmounted, getCurrentInstance } from 'vue'
|
||||
|
||||
interface Options {
|
||||
onCorrect: () => void
|
||||
onDiscard: () => void
|
||||
onFlag: () => void
|
||||
onEscape: () => void
|
||||
onSubmit: () => void
|
||||
isEditing: () => boolean // returns true when correction area is open
|
||||
}
|
||||
|
||||
export function useSftKeyboard(opts: Options) {
|
||||
function handler(e: KeyboardEvent) {
|
||||
// Never intercept keys while focus is in an input element
|
||||
if (e.target instanceof HTMLInputElement) return
|
||||
|
||||
// Skip textareas too: the correction textarea handles its own keys (including Ctrl+Enter)
|
||||
if (e.target instanceof HTMLTextAreaElement) return
|
||||
|
||||
const k = e.key.toLowerCase()
|
||||
|
||||
// When the correction area is open, only handle Escape
|
||||
if (opts.isEditing()) {
|
||||
if (k === 'escape') opts.onEscape()
|
||||
return
|
||||
}
|
||||
|
||||
if (k === 'c') { opts.onCorrect(); return }
|
||||
if (k === 'd') { opts.onDiscard(); return }
|
||||
if (k === 'f') { opts.onFlag(); return }
|
||||
if (k === 'escape') { opts.onEscape(); return }
|
||||
}
|
||||
|
||||
window.addEventListener('keydown', handler)
|
||||
const cleanup = () => window.removeEventListener('keydown', handler)
|
||||
|
||||
if (getCurrentInstance()) {
|
||||
onUnmounted(cleanup)
|
||||
}
|
||||
|
||||
return { cleanup }
|
||||
}
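The branching in `handler` above can be read as a small decision table. The following is a minimal sketch of that table as a pure function, which may be easier to unit-test than a `window` listener. The names `resolveSftKey`, `FocusKind`, and `SftKeyAction` are hypothetical and do not appear in the original file.

```typescript
// Hypothetical types: where focus currently is, and which callback would fire.
type FocusKind = 'input' | 'textarea' | 'other'
type SftKeyAction = 'correct' | 'discard' | 'flag' | 'escape' | null

// Mirrors the order of checks in the composable's handler:
// 1. ignore keys typed into inputs and textareas,
// 2. while editing, only Escape is handled,
// 3. otherwise map c / d / f / Escape to an action.
function resolveSftKey(key: string, focus: FocusKind, isEditing: boolean): SftKeyAction {
  if (focus === 'input' || focus === 'textarea') return null

  const k = key.toLowerCase()

  if (isEditing) return k === 'escape' ? 'escape' : null

  switch (k) {
    case 'c': return 'correct'
    case 'd': return 'discard'
    case 'f': return 'flag'
    case 'escape': return 'escape'
    default: return null
  }
}
```

A caller could map the returned action onto the matching `opts.on…` callback, keeping the DOM-specific `instanceof` checks at the edge.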
|
||||
|
|
@ -1,53 +1,19 @@
|
|||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import LabelView from '../views/LabelView.vue'
|
||||
|
||||
// Lazy-loaded views
|
||||
const DashboardView = () => import('../views/DashboardView.vue')
|
||||
const LabelView = () => import('../views/LabelView.vue')
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const CorrectionsView = () => import('../views/CorrectionsView.vue')
|
||||
const ImitateView = () => import('../views/ImitateView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const CompareView = () => import('../views/CompareView.vue')
|
||||
const TrainJobsView = () => import('../views/TrainJobsView.vue')
|
||||
const TrainResultsView = () => import('../views/TrainResultsView.vue')
|
||||
const ModelsView = () => import('../views/ModelsView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const NodeManagementView = () => import('../views/NodeManagementView.vue')
|
||||
|
||||
export const routes = [
|
||||
// ── Top-level ────────────────────────────────────────────
|
||||
{ path: '/', component: DashboardView, meta: { title: 'Dashboard' } },
|
||||
{ path: '/fleet', component: ModelsView, meta: { title: 'Fleet' } },
|
||||
{ path: '/nodes', component: NodeManagementView, meta: { title: 'Nodes' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
|
||||
// ── Data domain ──────────────────────────────────────────
|
||||
{ path: '/data/label', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/data/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/data/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
|
||||
{ path: '/data/imitate', component: ImitateView, meta: { title: 'Imitate' } },
|
||||
{ path: '/data/recipe-scan', component: () => import('../views/RecipeScanView.vue'), meta: { title: 'Recipe Scan' } },
|
||||
|
||||
// ── Eval domain ──────────────────────────────────────────
|
||||
{ path: '/eval/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/eval/compare', component: CompareView, meta: { title: 'Compare' } },
|
||||
{ path: '/eval/embed-compare', component: () => import('../views/EmbedCompareView.vue'), meta: { title: 'Embed Compare' } },
|
||||
|
||||
// ── Train domain ─────────────────────────────────────────
|
||||
{ path: '/train/jobs', component: TrainJobsView, meta: { title: 'Training Jobs' } },
|
||||
{ path: '/train/results', component: TrainResultsView, meta: { title: 'Training Results' } },
|
||||
|
||||
// ── Backward-compat redirects ────────────────────────────
|
||||
{ path: '/benchmark', redirect: '/eval/benchmark' },
|
||||
{ path: '/models', redirect: '/fleet' },
|
||||
{ path: '/stats', redirect: '/' },
|
||||
{ path: '/label', redirect: '/data/label' },
|
||||
{ path: '/fetch', redirect: '/data/fetch' },
|
||||
{ path: '/corrections', redirect: '/data/corrections' },
|
||||
{ path: '/imitate', redirect: '/data/imitate' },
|
||||
]
|
||||
// Views are lazy-loaded to keep the initial bundle small
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const StatsView = () => import('../views/StatsView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes,
|
||||
routes: [
|
||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
],
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
import { describe, it, expect } from 'vitest'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
|
||||
// Import the raw routes array so we can test structure without mounting App
|
||||
import { routes } from './index'
|
||||
|
||||
describe('router routes', () => {
|
||||
it('exports a routes array', () => {
|
||||
expect(Array.isArray(routes)).toBe(true)
|
||||
})
|
||||
|
||||
it('has / pointing to DashboardView', () => {
|
||||
const root = routes.find(r => r.path === '/')
|
||||
expect(root).toBeDefined()
|
||||
// Component should be async (lazy) or have a name
|
||||
expect(root?.component).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /fleet route', () => {
|
||||
const r = routes.find(r => r.path === '/fleet')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/label route', () => {
|
||||
const r = routes.find(r => r.path === '/data/label')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/fetch route', () => {
|
||||
const r = routes.find(r => r.path === '/data/fetch')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/corrections route', () => {
|
||||
const r = routes.find(r => r.path === '/data/corrections')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /data/imitate route', () => {
|
||||
const r = routes.find(r => r.path === '/data/imitate')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /eval/benchmark route', () => {
|
||||
const r = routes.find(r => r.path === '/eval/benchmark')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /eval/compare route', () => {
|
||||
const r = routes.find(r => r.path === '/eval/compare')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /train/jobs route', () => {
|
||||
const r = routes.find(r => r.path === '/train/jobs')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /train/results route', () => {
|
||||
const r = routes.find(r => r.path === '/train/results')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has /settings route', () => {
|
||||
const r = routes.find(r => r.path === '/settings')
|
||||
expect(r).toBeDefined()
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /benchmark to /eval/benchmark', () => {
|
||||
const r = routes.find(r => r.path === '/benchmark')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/eval/benchmark')
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /models to /fleet', () => {
|
||||
const r = routes.find(r => r.path === '/models')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/fleet')
|
||||
})
|
||||
|
||||
it('has backward-compat redirect from /stats to /', () => {
|
||||
const r = routes.find(r => r.path === '/stats')
|
||||
expect(r).toBeDefined()
|
||||
expect((r as { redirect?: string }).redirect).toBe('/')
|
||||
})
|
||||
|
||||
it('can create a functional router instance', () => {
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes,
|
||||
})
|
||||
expect(router).toBeDefined()
|
||||
})
|
||||
})
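The redirect assertions above check one hop at a time. As a rough illustration, the sketch below follows a redirect chain over a plain route table until it reaches a route without a `redirect`. It is not how vue-router resolves routes internally; `RouteLike`, `legacyRoutes`, and `resolveLegacyPath` are hypothetical names, and the table only copies a few entries from the routes array shown in this diff.

```typescript
// Hypothetical minimal route shape: just the two fields the tests inspect.
interface RouteLike {
  path: string
  redirect?: string
}

// A subset of the route table from this diff, reduced to path/redirect.
const legacyRoutes: RouteLike[] = [
  { path: '/' },
  { path: '/fleet' },
  { path: '/eval/benchmark' },
  { path: '/benchmark', redirect: '/eval/benchmark' },
  { path: '/models', redirect: '/fleet' },
  { path: '/stats', redirect: '/' },
]

// Follow redirects until a concrete route is found.
// Returns null for unknown paths or chains longer than maxHops (likely a loop).
function resolveLegacyPath(routes: RouteLike[], path: string, maxHops = 5): string | null {
  let current = path
  for (let hop = 0; hop <= maxHops; hop++) {
    const match = routes.find(r => r.path === current)
    if (!match) return null
    if (match.redirect === undefined) return current
    current = match.redirect
  }
  return null
}
```

A helper like this could back a single table-driven test instead of one `it` block per legacy path.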
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
import { setActivePinia, createPinia } from 'pinia'
|
||||
import { useSftStore } from './sft'
|
||||
import type { SftQueueItem } from './sft'
|
||||
import { beforeEach, describe, it, expect } from 'vitest'
|
||||
|
||||
function makeMockItem(overrides: Partial<SftQueueItem> = {}): SftQueueItem {
|
||||
return {
|
||||
id: 'abc',
|
||||
source: 'cf-orch-benchmark',
|
||||
benchmark_run_id: 'run1',
|
||||
timestamp: '2026-04-07T10:00:00Z',
|
||||
status: 'needs_review',
|
||||
prompt_messages: [
|
||||
{ role: 'system', content: 'You are a coding assistant.' },
|
||||
{ role: 'user', content: 'Write a Python add function.' },
|
||||
],
|
||||
model_response: 'def add(a, b): return a - b',
|
||||
corrected_response: null,
|
||||
quality_score: 0.2,
|
||||
failure_reason: 'pattern_match: 0/2 matched',
|
||||
failure_category: null,
|
||||
task_id: 'code-fn',
|
||||
task_type: 'code',
|
||||
task_name: 'Code: Write a Python function',
|
||||
model_id: 'Qwen/Qwen2.5-3B',
|
||||
model_name: 'Qwen2.5-3B',
|
||||
node_id: 'heimdall',
|
||||
gpu_id: 0,
|
||||
tokens_per_sec: 38.4,
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
describe('useSftStore', () => {
|
||||
beforeEach(() => setActivePinia(createPinia()))
|
||||
|
||||
it('starts with empty queue', () => {
|
||||
const store = useSftStore()
|
||||
expect(store.queue).toEqual([])
|
||||
expect(store.current).toBeNull()
|
||||
})
|
||||
|
||||
it('current returns first item', () => {
|
||||
const store = useSftStore()
|
||||
store.queue = [makeMockItem()]
|
||||
expect(store.current?.id).toBe('abc')
|
||||
})
|
||||
|
||||
it('removeCurrentFromQueue removes first item', () => {
|
||||
const store = useSftStore()
|
||||
const second = makeMockItem({ id: 'def' })
|
||||
store.queue = [makeMockItem(), second]
|
||||
store.removeCurrentFromQueue()
|
||||
expect(store.queue[0].id).toBe('def')
|
||||
})
|
||||
|
||||
it('restoreItem adds to front of queue', () => {
|
||||
const store = useSftStore()
|
||||
const second = makeMockItem({ id: 'def' })
|
||||
store.queue = [second]
|
||||
store.restoreItem(makeMockItem())
|
||||
expect(store.queue[0].id).toBe('abc')
|
||||
expect(store.queue[1].id).toBe('def')
|
||||
})
|
||||
|
||||
it('setLastAction records the action', () => {
|
||||
const store = useSftStore()
|
||||
store.setLastAction('discard', makeMockItem())
|
||||
expect(store.lastAction?.type).toBe('discard')
|
||||
expect(store.lastAction?.item.id).toBe('abc')
|
||||
})
|
||||
|
||||
it('clearLastAction nulls lastAction', () => {
|
||||
const store = useSftStore()
|
||||
store.setLastAction('flag', makeMockItem())
|
||||
store.clearLastAction()
|
||||
expect(store.lastAction).toBeNull()
|
||||
})
|
||||
})
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
// src/stores/sft.ts
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
export type SftFailureCategory =
|
||||
| 'scoring_artifact'
|
||||
| 'style_violation'
|
||||
| 'partial_answer'
|
||||
| 'wrong_answer'
|
||||
| 'format_error'
|
||||
| 'hallucination'
|
||||
|
||||
export interface SftQueueItem {
|
||||
id: string
|
||||
source: 'cf-orch-benchmark'
|
||||
benchmark_run_id: string
|
||||
timestamp: string
|
||||
status: 'needs_review' | 'approved' | 'discarded' | 'model_rejected'
|
||||
prompt_messages: { role: string; content: string }[]
|
||||
model_response: string
|
||||
corrected_response: string | null
|
||||
quality_score: number // 0.0 to 1.0
|
||||
failure_reason: string | null
|
||||
failure_category: SftFailureCategory | null
|
||||
task_id: string
|
||||
task_type: string
|
||||
task_name: string
|
||||
model_id: string
|
||||
model_name: string
|
||||
node_id: string
|
||||
gpu_id: number
|
||||
tokens_per_sec: number
|
||||
}
|
||||
|
||||
export interface SftLastAction {
|
||||
type: 'correct' | 'discard' | 'flag'
|
||||
item: SftQueueItem
|
||||
failure_category?: SftFailureCategory | null
|
||||
}
|
||||
|
||||
export const useSftStore = defineStore('sft', () => {
|
||||
const queue = ref<SftQueueItem[]>([])
|
||||
const totalRemaining = ref(0)
|
||||
const lastAction = ref<SftLastAction | null>(null)
|
||||
|
||||
const current = computed(() => queue.value[0] ?? null)
|
||||
|
||||
function removeCurrentFromQueue() {
|
||||
queue.value.shift()
|
||||
}
|
||||
|
||||
function setLastAction(
|
||||
type: SftLastAction['type'],
|
||||
item: SftQueueItem,
|
||||
failure_category?: SftFailureCategory | null,
|
||||
) {
|
||||
lastAction.value = { type, item, failure_category }
|
||||
}
|
||||
|
||||
function clearLastAction() {
|
||||
lastAction.value = null
|
||||
}
|
||||
|
||||
function restoreItem(item: SftQueueItem) {
|
||||
queue.value.unshift(item)
|
||||
}
|
||||
|
||||
return {
|
||||
queue, totalRemaining, lastAction, current,
|
||||
removeCurrentFromQueue, setLastAction, clearLastAction, restoreItem,
|
||||
}
|
||||
})
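The store exposes `removeCurrentFromQueue`, `setLastAction`, `clearLastAction`, and `restoreItem` separately, which suggests an undo flow assembled by the caller. The caller is not part of this hunk, so the following is only a sketch of how those four operations might combine, written as a plain class without Pinia so it runs on its own. `ReviewQueue`, `act`, and `undo` are hypothetical names, and `Item` is a stand-in for `SftQueueItem`.

```typescript
// Stand-in for SftQueueItem: only the id matters for this sketch.
interface Item {
  id: string
}

interface LastAction {
  type: 'correct' | 'discard' | 'flag'
  item: Item
}

class ReviewQueue {
  queue: Item[] = []
  lastAction: LastAction | null = null

  // Same rule as the store's `current` computed: head of the queue or null.
  get current(): Item | null {
    return this.queue[0] ?? null
  }

  // Assumed caller flow: remove the head, then remember it for a possible undo.
  act(type: LastAction['type']): void {
    const item = this.current
    if (!item) return
    this.queue.shift()               // removeCurrentFromQueue()
    this.lastAction = { type, item } // setLastAction(type, item)
  }

  // Assumed undo flow: put the item back at the front and forget the action.
  undo(): boolean {
    if (!this.lastAction) return false
    this.queue.unshift(this.lastAction.item) // restoreItem(item)
    this.lastAction = null                   // clearLastAction()
    return true
  }
}
```

Because only one `lastAction` is kept, this design supports a single level of undo; a second `undo()` call is a no-op.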
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
export interface GpuEntry {
|
||||
gpu_id: number
|
||||
card: string
|
||||
vram_total_mb: number
|
||||
vram_used_mb: number
|
||||
vram_free_mb: number
|
||||
temp_c: number | null
|
||||
utilization_pct: number | null
|
||||
compute_cap: number | null
|
||||
services_assigned: string[]
|
||||
services_running: string[]
|
||||
}
|
||||
|
||||
export interface ServiceInfo {
|
||||
min_compute_cap: number
|
||||
max_mb: number
|
||||
catalog_size: number
|
||||
}
|
||||
|
||||
export interface NodeSummary {
|
||||
node_id: string
|
||||
online: boolean
|
||||
agent_url: string
|
||||
gpus: GpuEntry[]
|
||||
profile_loaded: boolean
|
||||
services_catalog: Record<string, ServiceInfo>
|
||||
}
|
||||
|
||||
// ── Full profile types (for profile editor) ────────────────────────────────────
|
||||
|
||||
export interface ServiceManaged {
|
||||
type: string
|
||||
exec_path?: string
|
||||
args_template?: string
|
||||
port?: number
|
||||
host_port?: number
|
||||
base_port?: number
|
||||
health_path?: string
|
||||
cwd?: string
|
||||
adopt?: boolean
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export interface CatalogEntryFull {
|
||||
path: string
|
||||
vram_mb: number
|
||||
description?: string
|
||||
multi_gpu?: boolean
|
||||
env?: Record<string, string>
|
||||
}
|
||||
|
||||
export interface ServiceDefinition {
|
||||
max_mb: number
|
||||
priority: number
|
||||
min_compute_cap?: number
|
||||
preferred_compute_cap?: number
|
||||
shared?: boolean
|
||||
max_concurrent?: number
|
||||
idle_stop_after_s?: number
|
||||
always_on?: boolean
|
||||
model_base_path?: string
|
||||
managed?: ServiceManaged
|
||||
catalog?: Record<string, CatalogEntryFull>
|
||||
}
|
||||
|
||||
export interface NodeHardwareGpu {
|
||||
id: number
|
||||
vram_mb: number
|
||||
compute_cap?: number
|
||||
card?: string
|
||||
role?: string
|
||||
services?: string[]
|
||||
}
|
||||
|
||||
export interface NodeHardwareEntry {
|
||||
local_model_root?: string
|
||||
agent_url?: string
|
||||
gpus: NodeHardwareGpu[]
|
||||
}
|
||||
|
||||
export interface FullProfile {
|
||||
schema_version?: number
|
||||
name?: string
|
||||
vram_total_mb?: number
|
||||
eviction_timeout_s?: number
|
||||
services: Record<string, ServiceDefinition>
|
||||
nodes: Record<string, NodeHardwareEntry>
|
||||
model_size_hints?: Record<string, string>
|
||||
}
|
||||
|
|
@ -1,987 +0,0 @@
|
|||
<template>
|
||||
<div class="assignments-tab">
|
||||
|
||||
<!-- ── Toast ───────────────────────────────────────────── -->
|
||||
<div v-if="toast" class="toast" :class="toast.type" role="status" aria-live="polite">
|
||||
{{ toast.message }}
|
||||
</div>
|
||||
|
||||
<!-- ── Assignments section ─────────────────────────────── -->
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Task Assignments</h2>
|
||||
<button class="btn-primary btn-sm" @click="openNewAssignment">+ New Assignment</button>
|
||||
</div>
|
||||
|
||||
<div class="filter-row">
|
||||
<label for="product-filter" class="filter-label">Product</label>
|
||||
<select id="product-filter" v-model="productFilter" class="filter-select">
|
||||
<option value="">All products</option>
|
||||
<option v-for="p in allProducts" :key="p" :value="p">{{ p }}</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div v-if="assignmentsLoading" class="empty-state">Loading assignments…</div>
|
||||
<div v-else-if="assignmentsError" class="error-notice" role="alert">{{ assignmentsError }}</div>
|
||||
<div v-else-if="filteredGroups.length === 0" class="empty-state">No assignments yet. Add one above.</div>
|
||||
<div v-else class="product-groups">
|
||||
<div v-for="group in filteredGroups" :key="group.product" class="product-group">
|
||||
<h3 class="product-name">{{ group.product.toUpperCase() }}</h3>
|
||||
<div class="assignment-list">
|
||||
<div v-for="a in group.assignments" :key="`${a.product}/${a.task}`" class="assignment-row">
|
||||
<div class="assignment-main">
|
||||
<span class="task-id">{{ a.task }}</span>
|
||||
<span
|
||||
class="model-name"
|
||||
:title="a.model_id"
|
||||
>{{ displayModelId(a) }}</span>
|
||||
<span v-if="a.vram_mb" class="chip chip-vram">{{ formatVram(a.vram_mb) }}</span>
|
||||
<span v-if="a.service_type" class="chip" :class="serviceChipClass(a.service_type)">{{ a.service_type }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Node deployment status -->
|
||||
<div v-if="deploymentMap[`${a.product}/${a.task}`]" class="node-statuses">
|
||||
<span
|
||||
v-for="ns in deploymentMap[`${a.product}/${a.task}`]"
|
||||
:key="ns.node_id"
|
||||
class="node-badge-wrap"
|
||||
>
|
||||
<span
|
||||
class="node-badge"
|
||||
:class="ns.status"
|
||||
:title="`${ns.node_id}: ${ns.status}`"
|
||||
>
|
||||
<span class="node-icon">{{ nodeIcon(ns.status) }}</span>
|
||||
{{ ns.node_id }}
|
||||
</span>
|
||||
<button
|
||||
v-if="ns.status === 'absent'"
|
||||
class="btn-deploy"
|
||||
:disabled="deploying.has(`${a.product}/${a.task}/${ns.node_id}`)"
|
||||
:title="`Register ${a.model_id} in ${ns.node_id} catalog`"
|
||||
@click="deployModel(a, ns.node_id)"
|
||||
>{{ deploying.has(`${a.product}/${a.task}/${ns.node_id}`) ? '…' : 'Register' }}</button>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="assignment-actions">
|
||||
<button
|
||||
v-if="editingKey !== `${a.product}/${a.task}`"
|
||||
class="btn-ghost btn-sm"
|
||||
@click="startEdit(a)"
|
||||
>Edit</button>
|
||||
<button
|
||||
class="btn-ghost btn-sm btn-danger"
|
||||
@click="deleteAssignment(a.product, a.task)"
|
||||
>Delete</button>
|
||||
</div>
|
||||
|
||||
<!-- Inline edit form -->
|
||||
<div v-if="editingKey === `${a.product}/${a.task}`" class="inline-edit">
|
||||
<select v-model="editDraft.model_id" class="edit-select" aria-label="Model">
|
||||
<option value="" disabled>Select model…</option>
|
||||
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
|
||||
{{ m.alias || truncate(m.model_id, 40) }}
|
||||
</option>
|
||||
</select>
|
||||
<input
|
||||
v-model="editDraft.description"
|
||||
type="text"
|
||||
class="edit-input"
|
||||
placeholder="Description (optional)"
|
||||
/>
|
||||
<div class="inline-edit-btns">
|
||||
<button class="btn-primary btn-sm" :disabled="!editDraft.model_id" @click="saveEdit(a)">Save</button>
|
||||
<button class="btn-ghost btn-sm" @click="editingKey = null">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ── Model Registry section ───────────────────────────── -->
|
||||
<div class="section-header section-header-mt">
|
||||
<h2 class="section-title">Model Registry</h2>
|
||||
<button class="btn-primary btn-sm" @click="showRegisterModal = true">Register Model</button>
|
||||
</div>
|
||||
|
||||
<div v-if="registryLoading" class="empty-state">Loading model registry…</div>
|
||||
<div v-else-if="registryError" class="error-notice" role="alert">{{ registryError }}</div>
|
||||
<div v-else-if="registryModels.length === 0" class="empty-state">No models registered yet.</div>
|
||||
<div v-else class="registry-table-wrap">
|
||||
<table class="registry-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Alias</th>
|
||||
<th>Model ID</th>
|
||||
<th>VRAM</th>
|
||||
<th>Service</th>
|
||||
<th class="col-hf">HF Repo</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="m in registryModels" :key="m.model_id">
|
||||
<td>{{ m.alias || '—' }}</td>
|
||||
<td>
|
||||
<span class="truncated" :title="m.model_id">{{ truncate(m.model_id, 36) }}</span>
|
||||
</td>
|
||||
<td>{{ formatVram(m.vram_mb) }}</td>
|
||||
<td><span class="chip" :class="serviceChipClass(m.service_type)">{{ m.service_type }}</span></td>
|
||||
<td class="col-hf">
|
||||
<a
|
||||
v-if="m.hf_repo"
|
||||
:href="`https://huggingface.co/${m.hf_repo}`"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="hf-link"
|
||||
>{{ truncate(m.hf_repo, 30) }}</a>
|
||||
<span v-else class="text-muted">—</span>
|
||||
</td>
|
||||
<td>
|
||||
<button class="btn-ghost btn-sm btn-danger" @click="deleteModel(m.model_id)">Delete</button>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- ── New Assignment modal ─────────────────────────────── -->
|
||||
<div v-if="showNewAssignmentModal" class="modal-backdrop" @click.self="showNewAssignmentModal = false">
|
||||
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-new-assignment-title">
|
||||
<h3 id="modal-new-assignment-title" class="modal-title">New Assignment</h3>
|
||||
<label class="form-label">Product</label>
|
||||
<input
|
||||
v-model="newAssignment.product"
|
||||
list="product-list"
|
||||
class="form-input"
|
||||
placeholder="e.g. peregrine"
|
||||
autocomplete="off"
|
||||
/>
|
||||
<datalist id="product-list">
|
||||
<option v-for="p in allProducts" :key="p" :value="p" />
|
||||
</datalist>
|
||||
|
||||
<label class="form-label">Task ID</label>
|
||||
<input
|
||||
v-model="newAssignment.task"
|
||||
type="text"
|
||||
class="form-input"
|
||||
placeholder="e.g. cover_letter"
|
||||
/>
|
||||
|
||||
<label class="form-label">Model</label>
|
||||
<select v-model="newAssignment.model_id" class="form-select">
|
||||
<option value="" disabled>Select from registry…</option>
|
||||
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
|
||||
{{ m.alias || truncate(m.model_id, 50) }}
|
||||
</option>
|
||||
</select>
|
||||
|
||||
<label class="form-label">Description <span class="optional">(optional)</span></label>
|
||||
<input
|
||||
v-model="newAssignment.description"
|
||||
type="text"
|
||||
class="form-input"
|
||||
placeholder="Human-readable note for operators"
|
||||
/>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!newAssignment.product || !newAssignment.task || !newAssignment.model_id || saving"
|
||||
@click="saveNewAssignment"
|
||||
>{{ saving ? 'Saving…' : 'Save' }}</button>
|
||||
<button class="btn-ghost" @click="showNewAssignmentModal = false">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ── Register Model modal ─────────────────────────────── -->
|
||||
<div v-if="showRegisterModal" class="modal-backdrop" @click.self="showRegisterModal = false">
|
||||
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-register-title">
|
||||
<h3 id="modal-register-title" class="modal-title">Register Model</h3>
|
||||
|
||||
<label class="form-label">Model ID <span class="hint">(HuggingFace slug, e.g. ibm-granite/granite-4.1-8b)</span></label>
|
||||
<input v-model="newModel.model_id" type="text" class="form-input" placeholder="org/model-name" />
|
||||
|
||||
<label class="form-label">Alias <span class="optional">(optional, short name for assignments)</span></label>
|
||||
<input v-model="newModel.alias" type="text" class="form-input" placeholder="e.g. granite-8b" />
|
||||
|
||||
<label class="form-label">Service type</label>
|
||||
<select v-model="newModel.service_type" class="form-select">
|
||||
<option value="" disabled>Select service…</option>
|
||||
<option value="cf-text">cf-text — Language Models</option>
|
||||
<option value="cf-stt">cf-stt — Speech Recognition</option>
|
||||
<option value="cf-tts">cf-tts — Text to Speech</option>
|
||||
<option value="cf-vision">cf-vision — Vision / VLM</option>
|
||||
<option value="cf-image">cf-image — Image Generation</option>
|
||||
<option value="cf-voice">cf-voice — Audio Classification</option>
|
||||
<option value="vllm">vllm — vLLM inference</option>
|
||||
<option value="ollama">ollama — Ollama inference</option>
|
||||
</select>
|
||||
|
||||
<label class="form-label">VRAM required (MB)</label>
|
||||
<input v-model.number="newModel.vram_mb" type="number" min="0" class="form-input" placeholder="e.g. 16384" />
|
||||
|
||||
<label class="form-label">HF Repo <span class="optional">(optional)</span></label>
|
||||
<input v-model="newModel.hf_repo" type="text" class="form-input" placeholder="org/repo-name" />
|
||||
|
||||
<label class="form-label">Description <span class="optional">(optional)</span></label>
|
||||
<input v-model="newModel.description" type="text" class="form-input" placeholder="Human-readable note" />
|
||||
|
||||
<div class="modal-actions">
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!newModel.model_id || !newModel.service_type || !newModel.vram_mb || saving"
|
||||
@click="saveNewModel"
|
||||
>{{ saving ? 'Saving…' : 'Register' }}</button>
|
||||
<button class="btn-ghost" @click="showRegisterModal = false">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
// ── Types ──────────────────────────────────────────────
|
||||
|
||||
interface AssignmentNode {
|
||||
node_id: string
|
||||
status: 'present' | 'absent' | 'vram_tight'
|
||||
}
|
||||
|
||||
interface DeployingKey {
|
||||
nodeId: string
|
||||
assignmentKey: string
|
||||
}
|
||||
|
||||
interface Assignment {
|
||||
product: string
|
||||
task: string
|
||||
model_id: string
|
||||
description: string
|
||||
alias?: string
|
||||
service_type?: string
|
||||
vram_mb?: number
|
||||
nodes?: AssignmentNode[]
|
||||
}
|
||||
|
||||
interface RegistryModel {
|
||||
model_id: string
|
||||
alias: string
|
||||
service_type: string
|
||||
vram_mb: number
|
||||
hf_repo: string
|
||||
description: string
|
||||
}
|
||||
|
||||
interface ProductGroup {
|
||||
product: string
|
||||
assignments: Assignment[]
|
||||
}
|
||||
|
||||
interface Toast {
|
||||
message: string
|
||||
type: 'success' | 'error'
|
||||
}
|
||||
|
||||
// ── State ──────────────────────────────────────────────
|
||||
|
||||
const assignments = ref<Assignment[]>([])
|
||||
const assignmentsLoading = ref(false)
|
||||
const assignmentsError = ref<string | null>(null)
|
||||
|
||||
const registryModels = ref<RegistryModel[]>([])
|
||||
const registryLoading = ref(false)
|
||||
const registryError = ref<string | null>(null)
|
||||
|
||||
const productFilter = ref('')
|
||||
const editingKey = ref<string | null>(null)
|
||||
const editDraft = ref({ model_id: '', description: '' })
|
||||
|
||||
const showNewAssignmentModal = ref(false)
|
||||
const newAssignment = ref({ product: '', task: '', model_id: '', description: '' })
|
||||
|
||||
const showRegisterModal = ref(false)
|
||||
const newModel = ref({ model_id: '', alias: '', service_type: '', vram_mb: 0, hf_repo: '', description: '' })
|
||||
|
||||
const saving = ref(false)
|
||||
const toast = ref<Toast | null>(null)
|
||||
let toastTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const deploying = ref<Set<string>>(new Set())
|
||||
|
||||
// ── Derived ────────────────────────────────────────────
|
||||
|
||||
const allProducts = computed(() => {
|
||||
const seen = new Set<string>()
|
||||
for (const a of assignments.value) seen.add(a.product)
|
||||
return [...seen].sort()
|
||||
})
|
||||
|
||||
const deploymentMap = computed(() => {
|
||||
const map: Record<string, AssignmentNode[]> = {}
|
||||
for (const a of assignments.value) {
|
||||
if (a.nodes) map[`${a.product}/${a.task}`] = a.nodes
|
||||
}
|
||||
return map
|
||||
})
|
||||
|
||||
const filteredGroups = computed((): ProductGroup[] => {
|
||||
const filtered = productFilter.value
|
||||
? assignments.value.filter(a => a.product === productFilter.value)
|
||||
: assignments.value
|
||||
|
||||
const byProduct: Record<string, Assignment[]> = {}
|
||||
for (const a of filtered) {
|
||||
if (!byProduct[a.product]) byProduct[a.product] = []
|
||||
byProduct[a.product].push(a)
|
||||
}
|
||||
return Object.keys(byProduct)
|
||||
.sort()
|
||||
.map(product => ({ product, assignments: byProduct[product] }))
|
||||
})
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────
|
||||
|
||||
function truncate(s: string, max: number): string {
|
||||
return s.length > max ? s.slice(0, max - 1) + '…' : s
|
||||
}
|
||||
|
||||
function displayModelId(a: Assignment): string {
|
||||
if (a.alias) return a.alias
|
||||
const id = a.model_id
|
||||
// Show only the model name part (after /) and truncate long slugs
|
||||
const short = id.includes('/') ? id.split('/').slice(1).join('/') : id
|
||||
return truncate(short, 36)
|
||||
}
|
||||
|
||||
function formatVram(mb: number | undefined): string {
|
||||
if (!mb) return ''
|
||||
if (mb >= 1024) return `${(mb / 1024).toFixed(1)} GB`
|
||||
return `${mb} MB`
|
||||
}
|
||||
|
||||
function serviceChipClass(service: string): string {
|
||||
return `chip-service-${service.replace(/[^a-z0-9]/g, '-')}`
|
||||
}
|
||||
|
||||
function nodeIcon(status: string): string {
|
||||
if (status === 'present') return '✓'
|
||||
if (status === 'vram_tight') return '~'
|
||||
return '✗'
|
||||
}
|
||||
|
||||
function showToast(message: string, type: 'success' | 'error' = 'success') {
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
toast.value = { message, type }
|
||||
toastTimer = setTimeout(() => { toast.value = null }, 3500)
|
||||
}
|
||||
|
||||
function openNewAssignment() {
|
||||
newAssignment.value = { product: '', task: '', model_id: '', description: '' }
|
||||
showNewAssignmentModal.value = true
|
||||
}
|
||||
|
||||
function startEdit(a: Assignment) {
|
||||
editingKey.value = `${a.product}/${a.task}`
|
||||
editDraft.value = { model_id: a.model_id, description: a.description }
|
||||
}
|
||||
|
||||
// ── API ────────────────────────────────────────────────
|
||||
|
||||
async function loadAssignments() {
|
||||
assignmentsLoading.value = true
|
||||
assignmentsError.value = null
|
||||
try {
|
||||
// Fetch both list and deployment status in parallel
|
||||
const [listRes, statusRes] = await Promise.all([
|
||||
fetch('/api/cforch/assignments'),
|
||||
fetch('/api/cforch/assignments/deployment-status'),
|
||||
])
|
||||
if (!listRes.ok) throw new Error(`HTTP ${listRes.status}`)
|
||||
const list: Assignment[] = (await listRes.json()).assignments ?? []
|
||||
|
||||
// Merge deployment status into assignments if available
|
||||
if (statusRes.ok) {
|
||||
const statusList: Assignment[] = (await statusRes.json()).deployment_status ?? []
|
||||
const statusMap: Record<string, AssignmentNode[]> = {}
|
||||
for (const s of statusList) {
|
||||
statusMap[`${s.product}/${s.task}`] = s.nodes ?? []
|
||||
}
|
||||
for (const a of list) {
|
||||
a.nodes = statusMap[`${a.product}/${a.task}`] ?? []
|
||||
// Enrich with service_type, vram_mb, and alias from the status payload
|
||||
const s = statusList.find(x => x.product === a.product && x.task === a.task)
|
||||
if (s) {
|
||||
a.service_type = s.service_type
|
||||
a.vram_mb = s.vram_mb
|
||||
a.alias = s.alias
|
||||
}
|
||||
}
|
||||
}
|
||||
assignments.value = list
|
||||
} catch (e) {
|
||||
assignmentsError.value = `Could not load assignments: ${e}`
|
||||
} finally {
|
||||
assignmentsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
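The merge step in `loadAssignments` can be read as a join on the composite key `product/task`. The following is a minimal sketch of that join as a pure function, not the component's actual code: the `AssignmentLite` and `NodeStatus` shapes, the status values (guessed from the badge class names), and the sample data are all assumptions made for illustration.

```typescript
// Hypothetical, trimmed-down shapes; the component's real Assignment type is not shown in this diff.
interface NodeStatus { node_id: string; status: 'present' | 'absent' | 'vram_tight' }
interface AssignmentLite {
  product: string
  task: string
  model_id: string
  nodes?: NodeStatus[]
  service_type?: string
  vram_mb?: number
}

// Same composite key the component uses for statusMap and editingKey.
const keyOf = (a: { product: string; task: string }): string => `${a.product}/${a.task}`

// Index the status payload once, then do one Map lookup per assignment.
// Returns new objects instead of mutating `list`.
function mergeDeploymentStatus(
  list: AssignmentLite[],
  statusList: AssignmentLite[],
): AssignmentLite[] {
  const byKey = new Map<string, AssignmentLite>()
  for (const s of statusList) byKey.set(keyOf(s), s)
  return list.map(a => {
    const s = byKey.get(keyOf(a))
    if (!s) return { ...a, nodes: [] } // no status entry: empty node list
    return { ...a, nodes: s.nodes ?? [], service_type: s.service_type, vram_mb: s.vram_mb }
  })
}

// Made-up sample data.
const merged = mergeDeploymentStatus(
  [
    { product: 'avocet', task: 'label', model_id: 'm1' },
    { product: 'avocet', task: 'score', model_id: 'm2' },
  ],
  [
    {
      product: 'avocet', task: 'label', model_id: 'm1',
      service_type: 'vllm', vram_mb: 4096,
      nodes: [{ node_id: 'n1', status: 'present' }],
    },
  ],
)
console.log(merged.map(a => `${keyOf(a)}: ${a.nodes?.length ?? 0} node(s)`))
```

Indexing the status list once avoids the repeated `find` call inside the loop; for the handful of assignments a UI like this shows, the difference is mostly one of readability.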
async function loadRegistry() {
  registryLoading.value = true
  registryError.value = null
  try {
    const res = await fetch('/api/cforch/model-registry')
    if (!res.ok) throw new Error(`HTTP ${res.status}`)
    registryModels.value = (await res.json()).models ?? []
  } catch (e) {
    registryError.value = `Could not load model registry: ${e}`
  } finally {
    registryLoading.value = false
  }
}

async function saveNewAssignment() {
  saving.value = true
  try {
    const res = await fetch('/api/cforch/assignments', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(newAssignment.value),
    })
    if (!res.ok) throw new Error(await res.text())
    showNewAssignmentModal.value = false
    showToast('Assignment saved')
    await loadAssignments()
  } catch (e) {
    showToast(`Save failed: ${e}`, 'error')
  } finally {
    saving.value = false
  }
}

async function saveEdit(a: Assignment) {
  saving.value = true
  try {
    const res = await fetch('/api/cforch/assignments', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        product: a.product,
        task: a.task,
        model_id: editDraft.value.model_id,
        description: editDraft.value.description,
      }),
    })
    if (!res.ok) throw new Error(await res.text())
    editingKey.value = null
    showToast('Assignment updated')
    await loadAssignments()
  } catch (e) {
    showToast(`Update failed: ${e}`, 'error')
  } finally {
    saving.value = false
  }
}

async function deleteAssignment(product: string, task: string) {
  if (!confirm(`Delete assignment ${product}.${task}?`)) return
  try {
    const res = await fetch(
      `/api/cforch/assignments/${encodeURIComponent(product)}/${encodeURIComponent(task)}`,
      { method: 'DELETE' },
    )
    if (!res.ok) throw new Error(await res.text())
    showToast('Assignment deleted')
    await loadAssignments()
  } catch (e) {
    showToast(`Delete failed: ${e}`, 'error')
  }
}

async function saveNewModel() {
  saving.value = true
  try {
    const res = await fetch('/api/cforch/model-registry', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(newModel.value),
    })
    if (!res.ok) throw new Error(await res.text())
    showRegisterModal.value = false
    showToast('Model registered')
    await loadRegistry()
  } catch (e) {
    showToast(`Register failed: ${e}`, 'error')
  } finally {
    saving.value = false
  }
}

async function deleteModel(model_id: string) {
  if (!confirm(`Remove ${model_id} from the registry?`)) return
  try {
    const res = await fetch(
      `/api/cforch/model-registry/${encodeURIComponent(model_id)}`,
      { method: 'DELETE' },
    )
    if (!res.ok) throw new Error(await res.text())
    showToast('Model removed')
    await loadRegistry()
  } catch (e) {
    showToast(`Delete failed: ${e}`, 'error')
  }
}

async function deployModel(a: Assignment, nodeId: string) {
  const key = `${a.product}/${a.task}/${nodeId}`
  if (deploying.value.has(key)) return

  // Look up hf_repo from registry for cleaner path construction
  const regEntry = registryModels.value.find(m => m.model_id === a.model_id)
  const hf_repo = regEntry?.hf_repo ?? ''
  const service_type = a.service_type ?? regEntry?.service_type ?? ''
  const vram_mb = a.vram_mb ?? regEntry?.vram_mb ?? 0
  const description = regEntry?.alias ? `${regEntry.alias} (via assignments)` : ''

  if (!service_type) {
    showToast(`No service type for model ${a.model_id}`, 'error')
    return
  }

  deploying.value = new Set([...deploying.value, key])
  try {
    const res = await fetch(`/api/nodes-mgmt/nodes/${encodeURIComponent(nodeId)}/models/deploy`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ model_id: a.model_id, service_type, vram_mb, hf_repo, description }),
    })
    if (!res.ok) throw new Error(await res.text())
    const data = await res.json()
    showToast(`Registered ${a.model_id} on ${nodeId} at ${data.path}`)

    // Optimistic update: flip node to 'present' immediately so the Register button
    // disappears before the coordinator reload confirms. loadAssignments() reconciles
    // with real server state on the next round-trip.
    assignments.value = assignments.value.map(asgn => {
      if (asgn.product !== a.product || asgn.task !== a.task) return asgn
      return {
        ...asgn,
        nodes: (asgn.nodes ?? []).map(ns =>
          ns.node_id === nodeId ? { ...ns, status: 'present' as const } : ns
        ),
      }
    })

    await loadAssignments()
  } catch (e) {
    showToast(`Deploy failed: ${e}`, 'error')
  } finally {
    deploying.value = new Set([...deploying.value].filter(k => k !== key))
  }
}

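The optimistic update inside `deployModel` is an immutable map over the assignment list. As a rough sketch, it could be pulled out into a standalone helper like the one below. `markNodePresent`, the `Row` and `NodeEntry` shapes, the status values, and the sample rows are hypothetical names and data chosen for illustration, not part of the component.

```typescript
// Assumed status values, inferred from the .node-badge class names in the stylesheet.
type NodeState = 'present' | 'absent' | 'vram_tight'
interface NodeEntry { node_id: string; status: NodeState }
interface Row { product: string; task: string; nodes?: NodeEntry[] }

// Returns a new array in which only the targeted node of the targeted
// assignment is flipped to 'present'. Untouched rows keep their identity,
// so reactive consumers can skip them cheaply.
function markNodePresent(rows: Row[], product: string, task: string, nodeId: string): Row[] {
  return rows.map(row => {
    if (row.product !== product || row.task !== task) return row
    return {
      ...row,
      nodes: (row.nodes ?? []).map(n =>
        n.node_id === nodeId ? { ...n, status: 'present' as const } : n
      ),
    }
  })
}

// Made-up sample rows.
const rowsBefore: Row[] = [
  {
    product: 'avocet', task: 'label',
    nodes: [{ node_id: 'n1', status: 'absent' }, { node_id: 'n2', status: 'absent' }],
  },
  { product: 'avocet', task: 'score', nodes: [{ node_id: 'n1', status: 'absent' }] },
]
const rowsAfter = markNodePresent(rowsBefore, 'avocet', 'label', 'n1')
console.log(rowsAfter[0]?.nodes?.map(n => `${n.node_id}=${n.status}`))
```

Because the input array is never mutated, a later reload from the server can replace the optimistic state wholesale without leaving stale edits behind.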
onMounted(() => {
  loadAssignments()
  loadRegistry()
})
</script>

<style scoped>
.assignments-tab {
  display: flex;
  flex-direction: column;
  gap: 1.25rem;
}

/* ── Toast ── */
.toast {
  position: fixed;
  bottom: 1.5rem;
  right: 1.5rem;
  padding: 0.65rem 1.1rem;
  border-radius: 0.5rem;
  font-size: 0.88rem;
  font-weight: 500;
  z-index: 200;
  box-shadow: 0 2px 8px rgba(0,0,0,0.15);
}
.toast.success {
  background: var(--color-success, #2a8050);
  color: #fff;
}
.toast.error {
  background: var(--color-danger, #b03030);
  color: #fff;
}

/* ── Section headers ── */
.section-header {
  display: flex;
  align-items: center;
  justify-content: space-between;
  gap: 1rem;
}
.section-header-mt {
  margin-top: 1.5rem;
}
.section-title {
  font-size: 1rem;
  font-weight: 600;
  color: var(--app-primary, #2A6080);
  margin: 0;
}

/* ── Filter row ── */
.filter-row {
  display: flex;
  align-items: center;
  gap: 0.6rem;
}
.filter-label {
  font-size: 0.85rem;
  color: var(--color-text-muted, #6b7a99);
}
.filter-select {
  padding: 0.3rem 0.6rem;
  font-size: 0.85rem;
  border: 1px solid var(--color-border, #d0d7e8);
  border-radius: 0.4rem;
  background: var(--color-surface, #fff);
  color: var(--color-text, #1a2030);
}

/* ── Product groups ── */
.product-groups {
  display: flex;
  flex-direction: column;
  gap: 1rem;
}
.product-group {}
.product-name {
  font-size: 0.75rem;
  font-weight: 700;
  letter-spacing: 0.08em;
  color: var(--color-text-muted, #6b7a99);
  text-transform: uppercase;
  margin: 0 0 0.4rem;
}
.assignment-list {
  display: flex;
  flex-direction: column;
  gap: 0.4rem;
}

/* ── Assignment rows ── */
.assignment-row {
  background: var(--color-surface-raised, #f0f4fa);
  border: 1px solid var(--color-border, #d0d7e8);
  border-radius: 0.5rem;
  padding: 0.65rem 0.85rem;
  display: flex;
  flex-direction: column;
  gap: 0.4rem;
}
.assignment-main {
  display: flex;
  align-items: center;
  gap: 0.5rem;
  flex-wrap: wrap;
}
.task-id {
  font-family: var(--font-mono, monospace);
  font-size: 0.88rem;
  font-weight: 600;
  color: var(--color-text, #1a2030);
  min-width: 0;
}
.model-name {
  font-size: 0.85rem;
  color: var(--color-text-muted, #6b7a99);
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
  max-width: 280px;
  cursor: default;
}
.assignment-actions {
  display: flex;
  gap: 0.4rem;
  flex-wrap: wrap;
}

/* ── Node status badges ── */
.node-statuses {
  display: flex;
  gap: 0.35rem;
  flex-wrap: wrap;
}
.node-badge-wrap {
  display: inline-flex;
  align-items: center;
  gap: 0.2rem;
}
.node-badge {
  display: inline-flex;
  align-items: center;
  gap: 0.2rem;
  font-size: 0.78rem;
  padding: 0.15rem 0.5rem;
  border-radius: 0.35rem;
  font-weight: 500;
}
.node-badge.present {
  background: color-mix(in srgb, var(--color-success, #2a8050) 15%, transparent);
  color: var(--color-success, #2a8050);
  border: 1px solid color-mix(in srgb, var(--color-success, #2a8050) 30%, transparent);
}
.node-badge.absent {
  background: color-mix(in srgb, var(--color-danger, #b03030) 12%, transparent);
  color: var(--color-danger, #b03030);
  border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
}
.node-badge.vram_tight {
  background: color-mix(in srgb, #c08030 15%, transparent);
  color: #8a5500;
  border: 1px solid color-mix(in srgb, #c08030 30%, transparent);
}
.node-icon {
  font-size: 0.85em;
}
.btn-deploy {
  padding: 0.1rem 0.4rem;
  font-size: 0.72rem;
  font-weight: 600;
  background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
  color: var(--app-primary, #2A6080);
  border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 30%, transparent);
  border-radius: 0.3rem;
  cursor: pointer;
  white-space: nowrap;
  transition: background 0.15s;
}
.btn-deploy:hover:not(:disabled) {
  background: color-mix(in srgb, var(--app-primary, #2A6080) 22%, transparent);
}
.btn-deploy:disabled { opacity: 0.5; cursor: default; }

/* ── Inline edit ── */
.inline-edit {
  display: flex;
  flex-wrap: wrap;
  gap: 0.4rem;
  padding-top: 0.35rem;
  border-top: 1px solid var(--color-border, #d0d7e8);
}
.edit-select,
.edit-input {
  flex: 1;
  min-width: 160px;
  padding: 0.35rem 0.55rem;
  font-size: 0.85rem;
  border: 1px solid var(--color-border, #d0d7e8);
  border-radius: 0.4rem;
  background: var(--color-surface, #fff);
  color: var(--color-text, #1a2030);
}
.inline-edit-btns {
  display: flex;
  gap: 0.35rem;
  align-items: center;
}

/* ── Registry table ── */
.registry-table-wrap {
  overflow-x: auto;
  border-radius: 0.5rem;
  border: 1px solid var(--color-border, #d0d7e8);
}
.registry-table {
  width: 100%;
  border-collapse: collapse;
  font-size: 0.85rem;
}
.registry-table th {
  text-align: left;
  padding: 0.5rem 0.75rem;
  font-size: 0.78rem;
  font-weight: 600;
  color: var(--color-text-muted, #6b7a99);
  background: var(--color-surface-raised, #f0f4fa);
  border-bottom: 1px solid var(--color-border, #d0d7e8);
  white-space: nowrap;
}
.registry-table td {
  padding: 0.5rem 0.75rem;
  border-bottom: 1px solid var(--color-border, #d0d7e8);
  vertical-align: middle;
}
.registry-table tbody tr:last-child td {
  border-bottom: none;
}
.truncated {
  display: inline-block;
  max-width: 220px;
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
  vertical-align: bottom;
  cursor: default;
}
.hf-link {
  color: var(--app-primary, #2A6080);
  text-decoration: none;
  font-size: 0.82rem;
}
.hf-link:hover { text-decoration: underline; }
.text-muted { color: var(--color-text-muted, #6b7a99); }

/* ── Chips ── */
.chip {
  display: inline-block;
  padding: 0.15rem 0.5rem;
  border-radius: 0.35rem;
  font-size: 0.75rem;
  font-weight: 600;
  white-space: nowrap;
}
.chip-vram {
  background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
  color: var(--app-primary, #2A6080);
  border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 25%, transparent);
}
/* service chips — match ModelsView convention */
.chip-service-cf-text { background: #e8f0fe; color: #1a5276; border: 1px solid #a9c4e8; }
.chip-service-cf-stt { background: #eaf6ea; color: #1e6b3a; border: 1px solid #a2d9b1; }
.chip-service-cf-tts { background: #fdf3e3; color: #7d4e00; border: 1px solid #e8c98a; }
.chip-service-cf-vision { background: #f3e8fd; color: #5b2d8e; border: 1px solid #c8a0e8; }
.chip-service-cf-image { background: #fce8f0; color: #8e1a4f; border: 1px solid #e8a0c0; }
.chip-service-cf-voice { background: #e8f8fc; color: #0a5c6e; border: 1px solid #88d0e0; }
.chip-service-vllm { background: #f5ece0; color: #7a3800; border: 1px solid #d4a87a; }
.chip-service-ollama { background: #eeeeee; color: #444; border: 1px solid #ccc; }

/* ── Buttons ── */
.btn-primary {
  padding: 0.45rem 1rem;
  background: var(--app-primary, #2A6080);
  color: #fff;
  border: none;
  border-radius: 0.4rem;
  font-size: 0.85rem;
  font-weight: 600;
  cursor: pointer;
  transition: opacity 0.15s;
}
.btn-primary:disabled { opacity: 0.5; cursor: default; }
.btn-primary:not(:disabled):hover { opacity: 0.88; }

.btn-ghost {
  padding: 0.35rem 0.75rem;
  background: transparent;
  border: 1px solid var(--color-border, #d0d7e8);
  border-radius: 0.4rem;
  font-size: 0.82rem;
  color: var(--color-text-muted, #6b7a99);
  cursor: pointer;
  transition: background 0.15s;
}
.btn-ghost:hover { background: var(--color-surface-raised, #e4ebf5); }
.btn-ghost.btn-danger { color: var(--color-danger, #b03030); border-color: color-mix(in srgb, var(--color-danger, #b03030) 30%, transparent); }
.btn-ghost.btn-danger:hover { background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent); }

.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }

/* ── Empty / error states ── */
.empty-state {
  padding: 1.5rem;
  text-align: center;
  color: var(--color-text-muted, #6b7a99);
  font-size: 0.9rem;
  background: var(--color-surface-raised, #f0f4fa);
  border: 1px dashed var(--color-border, #d0d7e8);
  border-radius: 0.5rem;
}
.error-notice {
  padding: 0.75rem 1rem;
  background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent);
  color: var(--color-danger, #b03030);
  border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
  border-radius: 0.4rem;
  font-size: 0.87rem;
}

/* ── Modal ── */
.modal-backdrop {
  position: fixed;
  inset: 0;
  background: rgba(0,0,0,0.35);
  display: flex;
  align-items: center;
  justify-content: center;
  z-index: 100;
  padding: 1rem;
}
.modal {
  background: var(--color-surface, #fff);
  border-radius: 0.65rem;
  padding: 1.5rem;
  width: 100%;
  max-width: 480px;
  display: flex;
  flex-direction: column;
  gap: 0.65rem;
  box-shadow: 0 8px 32px rgba(0,0,0,0.18);
  max-height: 90vh;
  overflow-y: auto;
}
.modal-title {
  font-size: 1rem;
  font-weight: 700;
  color: var(--app-primary, #2A6080);
  margin: 0 0 0.25rem;
}
.form-label {
  font-size: 0.82rem;
  font-weight: 600;
  color: var(--color-text-muted, #6b7a99);
}
.form-input,
.form-select {
  padding: 0.4rem 0.65rem;
  font-size: 0.88rem;
  border: 1px solid var(--color-border, #d0d7e8);
  border-radius: 0.4rem;
  background: var(--color-surface, #fff);
  color: var(--color-text, #1a2030);
  width: 100%;
  box-sizing: border-box;
}
.form-input:focus, .form-select:focus {
  outline: 2px solid var(--app-primary, #2A6080);
  outline-offset: 1px;
}
.modal-actions {
  display: flex;
  gap: 0.5rem;
  justify-content: flex-end;
  margin-top: 0.25rem;
}
.optional, .hint {
  font-weight: 400;
  color: var(--color-text-muted, #6b7a99);
  font-size: 0.78rem;
}

/* ── Responsive ── */
@media (max-width: 600px) {
  .assignment-main { flex-direction: column; align-items: flex-start; }
  .col-hf { display: none; }
  .model-name { max-width: 100%; }
  .modal { padding: 1rem; }
}
</style>

@ -1,82 +0,0 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import BenchmarkView from './BenchmarkView.vue'

beforeEach(() => {
  vi.stubGlobal('fetch', vi.fn().mockImplementation((url: string) => {
    // LlmEvalTab calls /api/cforch/models and expects { models: CfOrchModel[] }
    if (url.includes('/api/cforch/models')) {
      return Promise.resolve({
        ok: true,
        json: async () => ({ models: [] }),
        text: async () => '',
      })
    }
    // Default: satisfies ClassifierTab (/api/benchmark/results, /api/benchmark/models,
    // /api/finetune/status), StyleTab (/api/style/models, /api/style/results),
    // and any other tab that tolerates empty arrays/objects.
    return Promise.resolve({
      ok: true,
      json: async () => ({ models: {}, categories: {}, tasks: [], types: [], results: [] }),
      text: async () => '',
    })
  }))
  vi.stubGlobal('EventSource', class {
    onmessage = null
    onerror = null
    close() {}
  })
})

describe('BenchmarkView', () => {
  it('renders page title "Benchmark"', async () => {
    const w = mount(BenchmarkView)
    await flushPromises()
    expect(w.text()).toContain('Benchmark')
  })

  it('has mode buttons: Classifier, LLM Eval, Writing Style', async () => {
    const w = mount(BenchmarkView)
    await flushPromises()
    const text = w.text()
    expect(text).toContain('Classifier')
    expect(text).toContain('LLM Eval')
    expect(text).toContain('Writing Style')
  })

  it('does NOT have a Compare mode button', async () => {
    const w = mount(BenchmarkView)
    await flushPromises()
    const buttons = w.findAll('.mode-btn')
    const labels = buttons.map(b => b.text())
    expect(labels.every(l => !l.includes('Compare'))).toBe(true)
  })

  it('shows Classifier tab by default', async () => {
    const w = mount(BenchmarkView)
    await flushPromises()
    // ClassifierTab has a .classifier-tab root
    expect(w.find('.classifier-tab').exists()).toBe(true)
  })

  it('switches to LlmEvalTab when LLM Eval clicked', async () => {
    const w = mount(BenchmarkView)
    await flushPromises()
    const llmBtn = w.findAll('.mode-btn').find(b => b.text().includes('LLM Eval'))!
    await llmBtn.trigger('click')
    await flushPromises()
    expect(w.find('.llm-eval-tab').exists()).toBe(true)
    expect(w.find('.classifier-tab').exists()).toBe(false)
    expect(llmBtn.classes()).toContain('active')
  })

  it('switches to StyleTab when Writing Style clicked', async () => {
    const w = mount(BenchmarkView)
    await flushPromises()
    const styleBtn = w.findAll('.mode-btn').find(b => b.text().includes('Writing Style'))!
    await styleBtn.trigger('click')
    await flushPromises()
    expect(w.find('.style-tab').exists()).toBe(true)
    expect(w.find('.classifier-tab').exists()).toBe(false)
  })
})

@ -2,48 +2,432 @@
  <div class="bench-view">
    <header class="bench-header">
      <h1 class="page-title">🏁 Benchmark</h1>
      <div class="header-actions">
        <label class="slow-toggle" :class="{ disabled: running }">
          <input type="checkbox" v-model="includeSlow" :disabled="running" />
          Include slow models
        </label>
        <button
          class="btn-run"
          :disabled="running"
          @click="startBenchmark"
        >
          {{ running ? '⏳ Running…' : results ? '🔄 Re-run' : '▶ Run Benchmark' }}
        </button>
        <button
          v-if="running"
          class="btn-cancel"
          @click="cancelBenchmark"
        >
          ✕ Cancel
        </button>
      </div>
    </header>

    <!-- Mode toggle -->
    <div class="mode-toggle" role="group" aria-label="Benchmark mode">
      <button
        class="mode-btn"
        :class="{ active: benchMode === 'classifier' }"
        @click="benchMode = 'classifier'"
      >Classifier</button>
      <button
        class="mode-btn"
        :class="{ active: benchMode === 'llm' }"
        @click="benchMode = 'llm'"
      >🤖 LLM Eval</button>
      <button
        class="mode-btn"
        :class="{ active: benchMode === 'style' }"
        @click="benchMode = 'style'"
      >✍️ Writing Style</button>
      <button
        class="mode-btn"
        :class="{ active: benchMode === 'plans' }"
        @click="benchMode = 'plans'"
      >📐 Planning</button>
    <!-- Trained models badge row -->
    <div v-if="fineTunedModels.length > 0" class="trained-models-row">
      <span class="trained-label">Trained:</span>
      <span
        v-for="m in fineTunedModels"
        :key="m.name"
        class="trained-badge"
        :title="m.base_model_id ? `Base: ${m.base_model_id} · ${m.sample_count ?? '?'} samples` : m.name"
      >
        {{ m.name }}
        <span v-if="m.val_macro_f1 != null" class="trained-f1">
          F1 {{ (m.val_macro_f1 * 100).toFixed(1) }}%
        </span>
      </span>
    </div>

    <ClassifierTab v-if="benchMode === 'classifier'" />
    <LlmEvalTab v-if="benchMode === 'llm'" />
    <StyleTab v-if="benchMode === 'style'" />
    <PlansBenchTab v-if="benchMode === 'plans'" />
    <!-- Progress log -->
    <div v-if="running || runLog.length" class="run-log">
      <div class="run-log-title">
        <span>{{ running ? '⏳ Running benchmark…' : runCancelled ? '⏹ Cancelled' : runError ? '❌ Failed' : '✅ Done' }}</span>
        <button class="btn-ghost" @click="runLog = []; runError = ''; runCancelled = false">Clear</button>
      </div>
      <div class="log-lines" ref="logEl">
        <div
          v-for="(line, i) in runLog"
          :key="i"
          class="log-line"
          :class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
        >{{ line }}</div>
      </div>
      <p v-if="runError" class="run-error">{{ runError }}</p>
    </div>

    <!-- Loading -->
    <div v-if="loading" class="status-notice">Loading…</div>

    <!-- No results yet -->
    <div v-else-if="!results" class="status-notice empty">
      <p>No benchmark results yet.</p>
      <p class="hint">Click <strong>Run Benchmark</strong> to score all default models against your labeled data.</p>
    </div>

    <!-- Results -->
    <template v-else>
      <p class="meta-line">
        <span>{{ results.sample_count.toLocaleString() }} labeled emails</span>
        <span class="sep">·</span>
        <span>{{ modelCount }} model{{ modelCount === 1 ? '' : 's' }}</span>
        <span class="sep">·</span>
        <span>{{ formatDate(results.timestamp) }}</span>
      </p>

      <!-- Macro-F1 chart -->
      <section class="chart-section">
        <h2 class="chart-title">Macro-F1 (higher = better)</h2>
        <div class="bar-chart">
          <div v-for="row in f1Rows" :key="row.name" class="bar-row">
            <span class="bar-label" :title="row.name">{{ row.name }}</span>
            <div class="bar-track">
              <div
                class="bar-fill"
                :style="{ width: `${row.pct}%`, background: scoreColor(row.value) }"
              />
            </div>
            <span class="bar-value" :style="{ color: scoreColor(row.value) }">
              {{ row.value.toFixed(3) }}
            </span>
          </div>
        </div>
      </section>

      <!-- Latency chart -->
      <section class="chart-section">
        <h2 class="chart-title">Latency (ms / email, lower = better)</h2>
        <div class="bar-chart">
          <div v-for="row in latencyRows" :key="row.name" class="bar-row">
            <span class="bar-label" :title="row.name">{{ row.name }}</span>
            <div class="bar-track">
              <div
                class="bar-fill latency-fill"
                :style="{ width: `${row.pct}%` }"
              />
            </div>
            <span class="bar-value">{{ row.value.toFixed(1) }} ms</span>
          </div>
        </div>
      </section>

      <!-- Per-label F1 heatmap -->
      <section class="chart-section">
        <h2 class="chart-title">Per-label F1</h2>
        <div class="heatmap-scroll">
          <table class="heatmap">
            <thead>
              <tr>
                <th class="hm-label-col">Label</th>
                <th v-for="name in modelNames" :key="name" class="hm-model-col" :title="name">
                  {{ name }}
                </th>
              </tr>
            </thead>
            <tbody>
              <tr v-for="label in labelNames" :key="label">
                <td class="hm-label-cell">
                  <span class="hm-emoji">{{ LABEL_META[label]?.emoji ?? '🏷️' }}</span>
                  {{ label.replace(/_/g, '\u00a0') }}
                </td>
                <td
                  v-for="name in modelNames"
                  :key="name"
                  class="hm-value-cell"
                  :style="{ background: heatmapBg(f1For(name, label)), color: heatmapFg(f1For(name, label)) }"
                  :title="`${name} / ${label}: F1 ${f1For(name, label).toFixed(3)}, support ${supportFor(name, label)}`"
                >
                  {{ f1For(name, label).toFixed(2) }}
                </td>
              </tr>
            </tbody>
          </table>
        </div>
        <p class="heatmap-hint">Hover a cell for its exact F1 and support. Color: 🟢 ≥ 0.7 · 🟡 0.4–0.7 · 🔴 < 0.4</p>
      </section>
    </template>

    <!-- Fine-tune section -->
    <details class="ft-section">
      <summary class="ft-summary">Fine-tune a model</summary>
      <div class="ft-body">
        <div class="ft-controls">
          <label class="ft-field">
            <span class="ft-field-label">Model</span>
            <select v-model="ftModel" class="ft-select" :disabled="ftRunning">
              <option value="deberta-small">deberta-small (100M, fast)</option>
              <option value="bge-m3">bge-m3 (600M — stop Peregrine vLLM first)</option>
            </select>
          </label>
          <label class="ft-field">
            <span class="ft-field-label">Epochs</span>
            <input
              v-model.number="ftEpochs"
              type="number" min="1" max="20"
              class="ft-epochs"
              :disabled="ftRunning"
            />
          </label>
          <button
            class="btn-run ft-run-btn"
            :disabled="ftRunning"
            @click="startFinetune"
          >
            {{ ftRunning ? '⏳ Training…' : '▶ Run fine-tune' }}
          </button>
          <button
            v-if="ftRunning"
            class="btn-cancel"
            @click="cancelFinetune"
          >
            ✕ Cancel
          </button>
        </div>

        <div v-if="ftRunning || ftLog.length || ftError" class="run-log ft-log">
          <div class="run-log-title">
            <span>{{ ftRunning ? '⏳ Training…' : ftCancelled ? '⏹ Cancelled' : ftError ? '❌ Failed' : '✅ Done' }}</span>
            <button class="btn-ghost" @click="ftLog = []; ftError = ''; ftCancelled = false">Clear</button>
          </div>
          <div class="log-lines" ref="ftLogEl">
            <div
              v-for="(line, i) in ftLog"
              :key="i"
              class="log-line"
              :class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
            >{{ line }}</div>
          </div>
          <p v-if="ftError" class="run-error">{{ ftError }}</p>
        </div>
      </div>
    </details>
  </div>
</template>

<script setup lang="ts">
import { ref } from 'vue'
import ClassifierTab from './ClassifierTab.vue'
import LlmEvalTab from './LlmEvalTab.vue'
import StyleTab from './StyleTab.vue'
import PlansBenchTab from './PlansBenchTab.vue'
import { ref, computed, onMounted, nextTick } from 'vue'
import { useApiFetch, useApiSSE } from '../composables/useApi'

type BenchMode = 'classifier' | 'llm' | 'style' | 'plans'
const benchMode = ref<BenchMode>('classifier')
// ── Label metadata (same as StatsView) ──────────────────────────────────────
const LABEL_META: Record<string, { emoji: string }> = {
  interview_scheduled: { emoji: '🗓️' },
  offer_received: { emoji: '🎉' },
  rejected: { emoji: '❌' },
  positive_response: { emoji: '👍' },
  survey_received: { emoji: '📋' },
  neutral: { emoji: '⬜' },
  event_rescheduled: { emoji: '🔄' },
  digest: { emoji: '📰' },
  new_lead: { emoji: '🤝' },
  hired: { emoji: '🎊' },
}

// ── Types ────────────────────────────────────────────────────────────────────
interface FineTunedModel {
  name: string
  base_model_id?: string
  val_macro_f1?: number
  timestamp?: string
  sample_count?: number
}

interface PerLabel { f1: number; precision: number; recall: number; support: number }
interface ModelResult {
  macro_f1: number
  accuracy: number
  latency_ms: number
  per_label: Record<string, PerLabel>
}
interface BenchResults {
  timestamp: string | null
  sample_count: number
  models: Record<string, ModelResult>
}

// ── State ────────────────────────────────────────────────────────────────────
const results = ref<BenchResults | null>(null)
const loading = ref(true)
const running = ref(false)
const runLog = ref<string[]>([])
const runError = ref('')
const includeSlow = ref(false)
const logEl = ref<HTMLElement | null>(null)

// Fine-tune state
const fineTunedModels = ref<FineTunedModel[]>([])
const ftModel = ref('deberta-small')
const ftEpochs = ref(5)
const ftRunning = ref(false)
const ftLog = ref<string[]>([])
const ftError = ref('')
const ftLogEl = ref<HTMLElement | null>(null)

const runCancelled = ref(false)
const ftCancelled = ref(false)

async function cancelBenchmark() {
  await fetch('/api/benchmark/cancel', { method: 'POST' }).catch(() => {})
}

async function cancelFinetune() {
  await fetch('/api/finetune/cancel', { method: 'POST' }).catch(() => {})
}

// ── Derived ──────────────────────────────────────────────────────────────────
|
||||
const modelNames = computed(() => Object.keys(results.value?.models ?? {}))
|
||||
const modelCount = computed(() => modelNames.value.length)
|
||||
|
||||
const labelNames = computed(() => {
|
||||
const canonical = Object.keys(LABEL_META)
|
||||
const inResults = new Set(
|
||||
modelNames.value.flatMap(n => Object.keys(results.value!.models[n].per_label))
|
||||
)
|
||||
return [...canonical.filter(l => inResults.has(l)), ...[...inResults].filter(l => !canonical.includes(l))]
|
||||
})
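The ordering used by `labelNames` can be read in isolation: canonical labels that appear in the results come first, in canonical order, followed by any labels the results contain that the canonical list does not know about. A minimal sketch, assuming a standalone helper (`orderLabels` is a hypothetical name, not part of this PR):

```typescript
// Hypothetical standalone version of the label ordering; not the component's code.
function orderLabels(canonical: string[], seen: string[]): string[] {
  const inResults = new Set(seen)
  const canonicalSet = new Set(canonical)
  // Canonical labels keep their canonical order, but only if they occur in the results.
  const known = canonical.filter(l => inResults.has(l))
  // Labels missing from the canonical list are appended in first-seen order.
  const extras = Array.from(inResults).filter(l => !canonicalSet.has(l))
  return [...known, ...extras]
}

console.log(orderLabels(['interview', 'hired', 'rejected'], ['spam', 'hired', 'interview']))
```

Using a `Set` for the canonical lookup avoids the repeated `includes` scan, which only matters if the label list grows large.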
|
||||
|
||||
const f1Rows = computed(() => {
|
||||
if (!results.value) return []
|
||||
const rows = modelNames.value.map(name => ({
|
||||
name,
|
||||
value: results.value!.models[name].macro_f1,
|
||||
}))
|
||||
rows.sort((a, b) => b.value - a.value)
|
||||
const max = rows[0]?.value || 1
|
||||
return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
|
||||
})
|
||||
|
||||
const latencyRows = computed(() => {
|
||||
if (!results.value) return []
|
||||
const rows = modelNames.value.map(name => ({
|
||||
name,
|
||||
value: results.value!.models[name].latency_ms,
|
||||
}))
|
||||
rows.sort((a, b) => a.value - b.value) // fastest first
|
||||
const max = rows[rows.length - 1]?.value || 1
|
||||
return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
|
||||
})
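`f1Rows` and `latencyRows` share one pattern: sort the models, then express each value as a percentage of the largest one so the bar widths are comparable. A hedged sketch of that pattern as a single pure function (`toBarRows` and `BarRow` are illustrative names, not part of this PR):

```typescript
interface BarRow { name: string; value: number; pct: number }

// Hypothetical helper: sorts rows and scales each value against the largest one.
function toBarRows(values: Record<string, number>, ascending: boolean): BarRow[] {
  const rows = Object.keys(values).map(name => ({ name, value: values[name] }))
  rows.sort((a, b) => (ascending ? a.value - b.value : b.value - a.value))
  // Fall back to 1 so an empty or all-zero input never divides by zero.
  const max = rows.reduce((m, r) => Math.max(m, r.value), 0) || 1
  return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
}

const f1 = toBarRows({ a: 0.4, b: 0.8, c: 0.2 }, false)  // best score first
const latency = toBarRows({ a: 40, b: 10 }, true)        // fastest first
console.log(f1, latency)
```

Taking the maximum with `reduce` rather than by position keeps the scaling independent of the sort direction.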
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
function f1For(model: string, label: string): number {
|
||||
return results.value?.models[model]?.per_label[label]?.f1 ?? 0
|
||||
}
|
||||
function supportFor(model: string, label: string): number {
|
||||
return results.value?.models[model]?.per_label[label]?.support ?? 0
|
||||
}
|
||||
|
||||
function scoreColor(v: number): string {
|
||||
if (v >= 0.7) return 'var(--color-success, #4CAF50)'
|
||||
if (v >= 0.4) return 'var(--app-accent, #B8622A)'
|
||||
return 'var(--color-error, #ef4444)'
|
||||
}
|
||||
|
||||
function heatmapBg(v: number): string {
|
||||
// Pick a band color by F1 (red below 0.4, orange below 0.7, green otherwise) and mix it with dark navy
|
||||
if (v >= 0.7) return `color-mix(in srgb, #4CAF50 ${Math.round(v * 100)}%, #1a2338 ${Math.round((1 - v) * 80)}%)`
|
||||
if (v >= 0.4) return `color-mix(in srgb, #FF9800 ${Math.round(v * 120)}%, #1a2338 40%)`
|
||||
return `color-mix(in srgb, #ef4444 ${Math.min(100, Math.round(v * 200 + 30))}%, #1a2338 60%)`
|
||||
}
|
||||
function heatmapFg(v: number): string {
|
||||
return v >= 0.5 ? '#fff' : 'rgba(255,255,255,0.75)'
|
||||
}
|
||||
|
||||
function formatDate(iso: string | null): string {
|
||||
if (!iso) return 'unknown date'
|
||||
const d = new Date(iso)
|
||||
return d.toLocaleString(undefined, { dateStyle: 'medium', timeStyle: 'short' })
|
||||
}
|
||||
|
||||
// ── Data loading ─────────────────────────────────────────────────────────────
|
||||
async function loadResults() {
|
||||
loading.value = true
|
||||
const { data } = await useApiFetch<BenchResults>('/api/benchmark/results')
|
||||
loading.value = false
|
||||
if (data && Object.keys(data.models).length > 0) {
|
||||
results.value = data
|
||||
}
|
||||
}
|
||||
|
||||
// ── Benchmark run ─────────────────────────────────────────────────────────────
|
||||
function startBenchmark() {
|
||||
if (running.value) return // avoid opening a second stream while a run is active
|
||||
running.value = true
|
||||
runLog.value = []
|
||||
runError.value = ''
|
||||
runCancelled.value = false
|
||||
|
||||
const url = `/api/benchmark/run${includeSlow.value ? '?include_slow=true' : ''}`
|
||||
useApiSSE(
|
||||
url,
|
||||
async (event) => {
|
||||
if (event.type === 'progress' && typeof event.message === 'string') {
|
||||
runLog.value.push(event.message)
|
||||
await nextTick()
|
||||
logEl.value?.scrollTo({ top: logEl.value.scrollHeight, behavior: 'smooth' })
|
||||
}
|
||||
if (event.type === 'error' && typeof event.message === 'string') {
|
||||
runError.value = event.message
|
||||
}
|
||||
if (event.type === 'cancelled') {
|
||||
running.value = false
|
||||
runCancelled.value = true
|
||||
}
|
||||
},
|
||||
async () => {
|
||||
running.value = false
|
||||
await loadResults()
|
||||
},
|
||||
() => {
|
||||
running.value = false
|
||||
if (!runError.value) runError.value = 'Connection lost'
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
async function loadFineTunedModels() {
|
||||
const { data } = await useApiFetch<FineTunedModel[]>('/api/finetune/status')
|
||||
if (Array.isArray(data)) fineTunedModels.value = data
|
||||
}
|
||||
|
||||
function startFinetune() {
|
||||
if (ftRunning.value) return
|
||||
ftRunning.value = true
|
||||
ftLog.value = []
|
||||
ftError.value = ''
|
||||
ftCancelled.value = false
|
||||
|
||||
const params = new URLSearchParams({ model: ftModel.value, epochs: String(ftEpochs.value) })
|
||||
useApiSSE(
|
||||
`/api/finetune/run?${params}`,
|
||||
async (event) => {
|
||||
if (event.type === 'progress' && typeof event.message === 'string') {
|
||||
ftLog.value.push(event.message)
|
||||
await nextTick()
|
||||
ftLogEl.value?.scrollTo({ top: ftLogEl.value.scrollHeight, behavior: 'smooth' })
|
||||
}
|
||||
if (event.type === 'error' && typeof event.message === 'string') {
|
||||
ftError.value = event.message
|
||||
}
|
||||
if (event.type === 'cancelled') {
|
||||
ftRunning.value = false
|
||||
ftCancelled.value = true
|
||||
}
|
||||
},
|
||||
async () => {
|
||||
ftRunning.value = false
|
||||
await loadFineTunedModels()
|
||||
startBenchmark() // auto-trigger benchmark to refresh charts
|
||||
},
|
||||
() => {
|
||||
ftRunning.value = false
|
||||
if (!ftError.value) ftError.value = 'Connection lost'
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadResults()
|
||||
loadFineTunedModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -59,6 +443,9 @@ const benchMode = ref<BenchMode>('classifier')
|
|||
.bench-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
|
|
@ -69,41 +456,391 @@ const benchMode = ref<BenchMode>('classifier')
|
|||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Mode toggle (segmented control) ── */
|
||||
.mode-toggle {
|
||||
display: inline-flex;
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.slow-toggle {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
.slow-toggle.disabled { opacity: 0.5; pointer-events: none; }
|
||||
|
||||
.btn-run {
|
||||
padding: 0.45rem 1.1rem;
|
||||
border-radius: 0.375rem;
|
||||
border: none;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-run:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.45rem 0.9rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-text-secondary, #6b7a99);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.btn-cancel:hover {
|
||||
background: color-mix(in srgb, var(--color-text-secondary, #6b7a99) 12%, transparent);
|
||||
}
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
align-self: flex-start;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.mode-btn {
|
||||
padding: 0.4rem 1.1rem;
|
||||
font-size: 0.85rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
font-weight: 500;
|
||||
.run-log-title {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.btn-ghost {
|
||||
background: none;
|
||||
border: none;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
font-size: 0.78rem;
|
||||
padding: 0.1rem 0.3rem;
|
||||
border-radius: 0.2rem;
|
||||
}
|
||||
.btn-ghost:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
.log-lines {
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #fff);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.1rem;
|
||||
}
|
||||
|
||||
.mode-btn:not(:last-child) {
|
||||
border-right: 1px solid var(--color-border, #d0d7e8);
|
||||
.log-line { color: var(--color-text, #1a2338); line-height: 1.5; }
|
||||
.log-line.log-error { color: var(--color-error, #ef4444); }
|
||||
|
||||
.run-error {
|
||||
margin: 0;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: color-mix(in srgb, var(--color-error, #ef4444) 10%, transparent);
|
||||
color: var(--color-error, #ef4444);
|
||||
font-size: 0.82rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
.mode-btn.active {
|
||||
/* ── Status notices ─────────────────────────────────────── */
|
||||
.status-notice {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.9rem;
|
||||
padding: 1rem;
|
||||
}
|
||||
.status-notice.empty {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 3rem 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
.hint { font-size: 0.85rem; opacity: 0.75; }
|
||||
|
||||
/* ── Meta line ──────────────────────────────────────────── */
|
||||
.meta-line {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
align-items: center;
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.sep { opacity: 0.4; }
|
||||
|
||||
/* ── Chart sections ─────────────────────────────────────── */
|
||||
.chart-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.chart-title {
|
||||
font-size: 0.95rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Bar charts ─────────────────────────────────────────── */
|
||||
.bar-chart {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
|
||||
.bar-row {
|
||||
display: grid;
|
||||
grid-template-columns: 14rem 1fr 5rem;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.bar-label {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.76rem;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.bar-track {
|
||||
height: 16px;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-radius: 99px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.bar-fill {
|
||||
height: 100%;
|
||||
border-radius: 99px;
|
||||
transition: width 0.5s cubic-bezier(0.16, 1, 0.3, 1);
|
||||
}
|
||||
|
||||
.latency-fill { background: var(--app-primary, #2A6080); opacity: 0.65; }
|
||||
|
||||
.bar-value {
|
||||
text-align: right;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.8rem;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
/* ── Heatmap ────────────────────────────────────────────── */
|
||||
.heatmap-scroll {
|
||||
overflow-x: auto;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.heatmap {
|
||||
border-collapse: collapse;
|
||||
min-width: 100%;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.hm-label-col {
|
||||
text-align: left;
|
||||
min-width: 11rem;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
font-weight: 600;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.hm-model-col {
|
||||
min-width: 5rem;
|
||||
max-width: 8rem;
|
||||
padding: 0.4rem 0.5rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.7rem;
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.hm-label-cell {
|
||||
padding: 0.35rem 0.6rem;
|
||||
background: var(--color-surface, #fff);
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
white-space: nowrap;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.74rem;
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.hm-emoji { margin-right: 0.3rem; }
|
||||
|
||||
.hm-value-cell {
|
||||
padding: 0.35rem 0.5rem;
|
||||
text-align: center;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
border-top: 1px solid rgba(255,255,255,0.08);
|
||||
cursor: default;
|
||||
transition: filter 0.15s;
|
||||
}
|
||||
.hm-value-cell:hover { filter: brightness(1.15); }
|
||||
|
||||
.heatmap-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Mobile tweaks ──────────────────────────────────────── */
|
||||
@media (max-width: 600px) {
|
||||
.bar-row { grid-template-columns: 9rem 1fr 4rem; }
|
||||
.bar-label { font-size: 0.7rem; }
|
||||
.bench-header { flex-direction: column; align-items: flex-start; }
|
||||
}
|
||||
|
||||
/* ── Trained models badge row ──────────────────────────── */
|
||||
.trained-models-row {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.6rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.trained-label {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.trained-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 1rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.76rem;
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.mode-btn:not(.active):hover {
|
||||
.trained-f1 {
|
||||
background: rgba(255,255,255,0.2);
|
||||
border-radius: 0.75rem;
|
||||
padding: 0.05rem 0.35rem;
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
/* ── Fine-tune section ──────────────────────────────────── */
|
||||
.ft-section {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.ft-summary {
|
||||
padding: 0.65rem 0.9rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
user-select: none;
|
||||
list-style: none;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
.ft-summary::-webkit-details-marker { display: none; }
|
||||
.ft-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
||||
details[open] .ft-summary::before { content: '▼ '; }
|
||||
|
||||
.ft-body {
|
||||
padding: 0.75rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.ft-controls {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.75rem;
|
||||
align-items: flex-end;
|
||||
}
|
||||
|
||||
.ft-field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.ft-field-label {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.ft-select {
|
||||
padding: 0.35rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
min-width: 220px;
|
||||
}
|
||||
.ft-select:disabled { opacity: 0.55; }
|
||||
|
||||
.ft-epochs {
|
||||
width: 64px;
|
||||
padding: 0.35rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
text-align: center;
|
||||
}
|
||||
.ft-epochs:disabled { opacity: 0.55; }
|
||||
|
||||
.ft-run-btn { align-self: flex-end; }
|
||||
|
||||
.ft-log { margin-top: 0; }
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.mode-btn { padding: 0.4rem 0.65rem; font-size: 0.78rem; }
|
||||
.ft-controls { flex-direction: column; align-items: stretch; }
|
||||
.ft-select { min-width: 0; width: 100%; }
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -1,722 +0,0 @@
|
|||
<template>
|
||||
<div class="compare-tab">
|
||||
|
||||
<!-- Source toggle -->
|
||||
<div class="source-toggle" role="group" aria-label="Prompt source">
|
||||
<button class="source-btn" :class="{ active: promptSource === 'tasks' }" @click="promptSource = 'tasks'">
|
||||
📋 cf-orch Tasks
|
||||
</button>
|
||||
<button class="source-btn" :class="{ active: promptSource === 'style' }" @click="promptSource = 'style'">
|
||||
✍️ Writing Style Prompts
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Task selector (cf-orch tasks) -->
|
||||
<details v-if="promptSource === 'tasks'" class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">📋 Pick a Task</span>
|
||||
<span class="picker-badge">{{ cmpSelectedTask ? cmpSelectedTask.name : 'None selected' }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks…</div>
|
||||
<div v-else-if="llmTasks.length === 0" class="picker-empty">No tasks found — check cforch config.</div>
|
||||
<template v-else>
|
||||
<div v-for="(tasks, type) in llmTasksByType" :key="type" class="picker-category">
|
||||
<span class="picker-cat-name picker-cat-section">{{ type }}</span>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="t in tasks" :key="t.id" class="picker-model-row">
|
||||
<input
|
||||
type="radio"
|
||||
name="cmp-task"
|
||||
:checked="cmpSelectedTask?.id === t.id"
|
||||
@change="selectCmpTask(t)"
|
||||
/>
|
||||
<span class="picker-model-name" :title="t.name">{{ t.name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Writing style prompt selector -->
|
||||
<details v-if="promptSource === 'style'" class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">✍️ Pick a Writing Style Prompt</span>
|
||||
<span class="picker-badge">{{ selectedVoicePrompt ? selectedVoicePrompt.tag : 'None selected' }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div class="picker-model-list style-prompt-list">
|
||||
<label v-for="vp in STYLE_PROMPTS" :key="vp.tag" class="picker-model-row style-prompt-row">
|
||||
<input
|
||||
type="radio"
|
||||
name="cmp-style-prompt"
|
||||
:checked="selectedVoicePrompt?.tag === vp.tag"
|
||||
@change="selectVoicePrompt(vp)"
|
||||
/>
|
||||
<span class="style-prompt-tag">{{ vp.tag }}</span>
|
||||
<span class="style-prompt-title">{{ vp.thread_title }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Prompt editor + model picker (shown once a prompt source is ready) -->
|
||||
<template v-if="promptSource === 'tasks' ? !!cmpSelectedTask : !!selectedVoicePrompt">
|
||||
<label class="prompt-label" for="cmp-prompt">Prompt</label>
|
||||
<textarea
|
||||
id="cmp-prompt"
|
||||
class="cmp-prompt-editor"
|
||||
v-model="cmpPrompt"
|
||||
rows="6"
|
||||
/>
|
||||
|
||||
<!-- LLM model picker (ollama + vllm + cf-text) -->
|
||||
<details class="model-picker" open>
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">🤖 LLM Models</span>
|
||||
<span class="picker-badge">{{ cmpSelectedModels.size }} / {{ llmSelectableModels.length }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<label class="picker-cat-header">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="cmpSelectedModels.size === llmSelectableModels.length"
|
||||
:indeterminate="cmpSelectedModels.size > 0 && cmpSelectedModels.size < llmSelectableModels.length"
|
||||
@change="toggleAllCmpModels(($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-cat-name">All LLM models</span>
|
||||
</label>
|
||||
<div v-for="(models, service) in llmModelsByService" :key="service" class="picker-category">
|
||||
<span class="picker-cat-section">{{ service }}</span>
|
||||
<div class="picker-model-list">
|
||||
<label v-for="m in models" :key="m.id" class="picker-model-row">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="cmpSelectedModels.has(m.id)"
|
||||
@change="toggleCmpModel(m.id, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-model-name">{{ m.name }}</span>
|
||||
<span class="picker-adapter-type">{{ m.tags.slice(0, 2).join(', ') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Run controls -->
|
||||
<div class="run-controls">
|
||||
<button
|
||||
class="btn-run"
|
||||
:disabled="cmpRunning || cmpSelectedModels.size === 0"
|
||||
@click="startCompare"
|
||||
>{{ cmpRunning ? '⏳ Running…' : '⚖️ Compare Models' }}</button>
|
||||
<button v-if="cmpRunning" class="btn-cancel" @click="cancelCompare">✕ Cancel</button>
|
||||
</div>
|
||||
|
||||
<!-- Progress log -->
|
||||
<div v-if="cmpLog.length > 0" class="run-log">
|
||||
<div class="log-lines">
|
||||
<div v-for="(line, i) in cmpLog" :key="i" class="log-line">{{ line }}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Side-by-side results -->
|
||||
<template v-if="cmpResults.length > 0">
|
||||
<h2 class="chart-title">Side-by-Side Responses</h2>
|
||||
<div class="cmp-results-grid">
|
||||
<div
|
||||
v-for="r in cmpResults"
|
||||
:key="r.model"
|
||||
class="cmp-result-card"
|
||||
:class="{ 'cmp-error': !!r.error }"
|
||||
>
|
||||
<div class="cmp-result-header">
|
||||
<span class="cmp-model-name">{{ r.model }}</span>
|
||||
<span class="cmp-meta">
|
||||
<template v-if="r.error"><span class="err-badge">error</span></template>
|
||||
<template v-else>{{ (r.elapsed_ms / 1000).toFixed(1) }}s</template>
|
||||
</span>
|
||||
</div>
|
||||
<pre v-if="r.error" class="cmp-error-text">{{ r.error }}</pre>
|
||||
<pre v-else class="cmp-response">{{ r.response }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</template>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useApiFetch } from '../composables/useApi'
|
||||
|
||||
// ── Types ───────────────────────────────────────────────────────────────────
|
||||
interface CfOrchTask {
|
||||
id: string
|
||||
name: string
|
||||
type: string
|
||||
prompt: string
|
||||
system: string
|
||||
}
|
||||
|
||||
interface CfOrchModel {
|
||||
name: string
|
||||
id: string
|
||||
service: string
|
||||
tags: string[]
|
||||
vram_estimate_mb?: number
|
||||
}
|
||||
|
||||
interface CmpResult {
|
||||
model: string
|
||||
response: string
|
||||
elapsed_ms: number
|
||||
error: string | null
|
||||
}
|
||||
|
||||
interface VoicePrompt {
|
||||
tag: string
|
||||
thread_title: string
|
||||
thread_body: string
|
||||
}
|
||||
|
||||
// ── Writing style prompts (mirrors TEST_PROMPTS in benchmark_style.py) ──────
|
||||
const STYLE_SYSTEM = "You are a writing assistant. Your job is to write a Reddit reply that matches the user's voice — casual, direct, community-first. No em dashes. No filler phrases. No semicolons. Short punchy sentences."
|
||||
|
||||
const STYLE_PROMPTS: VoicePrompt[] = [
|
||||
{
|
||||
tag: 'selfhosted_ai_fatigue',
|
||||
thread_title: "Anyone else getting tired of re-explaining their setup every time an AI model forgets?",
|
||||
thread_body: "Every session I start over. My whole hardware setup, what tools I use, what I've already tried. It's exhausting. There has to be a better way.",
|
||||
},
|
||||
{
|
||||
tag: 'privacy_local_llm',
|
||||
thread_title: "What's the point of running local LLMs if the apps still phone home?",
|
||||
thread_body: "I went through all the trouble of setting up ollama and now I find out the frontend I'm using is sending telemetry. Kind of defeats the purpose.",
|
||||
},
|
||||
{
|
||||
tag: 'solarpunk_tech',
|
||||
thread_title: "What does solarpunk computing actually look like in practice?",
|
||||
thread_body: "I keep seeing the aesthetic but not a lot of concrete examples of people living it out with their tech choices. What does it mean day to day?",
|
||||
},
|
||||
{
|
||||
tag: 'nd_tools',
|
||||
thread_title: "Tools that actually help with executive function vs ones that just add friction",
|
||||
thread_body: "I've tried a dozen productivity apps and most of them require more executive function to maintain than they save. What actually sticks for you?",
|
||||
},
|
||||
{
|
||||
tag: 'data_ownership',
|
||||
thread_title: "Who actually owns your data when you use a 'free' AI tool?",
|
||||
thread_body: "Read the ToS on three different AI assistants today. In all three cases your inputs can be used for training, shared with partners, and retained indefinitely. Is this just accepted now?",
|
||||
},
|
||||
{
|
||||
tag: 'digital_culture',
|
||||
thread_title: "The internet used to feel like it belonged to everyone. What happened?",
|
||||
thread_body: "I grew up on forums, IRC, personal homepages. Now everything is a platform owned by someone trying to extract value from the community that built it.",
|
||||
},
|
||||
]
|
||||
|
||||
// ── State ───────────────────────────────────────────────────────────────────
|
||||
const llmTasks = ref<CfOrchTask[]>([])
|
||||
const llmTasksLoading = ref(false)
|
||||
const llmModels = ref<CfOrchModel[]>([])
|
||||
|
||||
const promptSource = ref<'tasks' | 'style'>('tasks')
|
||||
const cmpSelectedTask = ref<CfOrchTask | null>(null)
|
||||
const selectedVoicePrompt = ref<VoicePrompt | null>(null)
|
||||
const cmpSystemPrompt = ref('')
|
||||
const cmpPrompt = ref('')
|
||||
const cmpSelectedModels = ref<Set<string>>(new Set())
|
||||
const cmpRunning = ref(false)
|
||||
const cmpLog = ref<string[]>([])
|
||||
const cmpResults = ref<CmpResult[]>([])
|
||||
const cmpEventSource = ref<EventSource | null>(null)
|
||||
|
||||
// ── Computed ────────────────────────────────────────────────────────────────
|
||||
const LLM_SERVICES = new Set(['ollama', 'vllm', 'cf-text'])
|
||||
|
||||
const llmSelectableModels = computed(() =>
|
||||
llmModels.value.filter(m => LLM_SERVICES.has(m.service))
|
||||
)
|
||||
|
||||
/** Group selectable models by service for the picker UI */
|
||||
const llmModelsByService = computed((): Record<string, CfOrchModel[]> => {
|
||||
const groups: Record<string, CfOrchModel[]> = {}
|
||||
for (const m of llmSelectableModels.value) {
|
||||
if (!groups[m.service]) groups[m.service] = []
|
||||
groups[m.service].push(m)
|
||||
}
|
||||
return groups
|
||||
})
|
||||
|
||||
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
|
||||
const groups: Record<string, CfOrchTask[]> = {}
|
||||
for (const t of llmTasks.value) {
|
||||
if (!groups[t.type]) groups[t.type] = []
|
||||
groups[t.type].push(t)
|
||||
}
|
||||
return groups
|
||||
})
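`llmModelsByService` and `llmTasksByType` are the same grouping loop with a different key. A minimal sketch of a shared helper, assuming a generic `groupBy` function (a hypothetical name; this PR writes the loop out twice instead):

```typescript
// Hypothetical generic helper: buckets items by a string key, preserving input order.
function groupBy<T>(items: T[], keyOf: (item: T) => string): Record<string, T[]> {
  const groups: Record<string, T[]> = {}
  for (const item of items) {
    const key = keyOf(item)
    if (!groups[key]) groups[key] = []
    groups[key].push(item)
  }
  return groups
}

// Illustrative data only; the real models come from /api/cforch/models.
const sample = [
  { id: 'm1', service: 'ollama' },
  { id: 'm2', service: 'vllm' },
  { id: 'm3', service: 'ollama' },
]
console.log(groupBy(sample, m => m.service))
```

Inside a component this would be wrapped as `computed(() => groupBy(llmTasks.value, t => t.type))`, so the grouping is recomputed only when the task list changes.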
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────────
|
||||
function selectCmpTask(t: CfOrchTask) {
|
||||
cmpSelectedTask.value = t
|
||||
cmpPrompt.value = t.prompt || ''
|
||||
cmpSystemPrompt.value = t.system || ''
|
||||
cmpResults.value = []
|
||||
cmpLog.value = []
|
||||
}
|
||||
|
||||
function selectVoicePrompt(vp: VoicePrompt) {
|
||||
selectedVoicePrompt.value = vp
|
||||
cmpPrompt.value = `Thread: ${vp.thread_title}\n\n${vp.thread_body}\n\nWrite a reply:`
|
||||
cmpSystemPrompt.value = STYLE_SYSTEM
|
||||
cmpResults.value = []
|
||||
cmpLog.value = []
|
||||
}
|
||||
|
||||
function toggleCmpModel(id: string, checked: boolean) {
|
||||
const next = new Set(cmpSelectedModels.value)
|
||||
if (checked) next.add(id)
|
||||
else next.delete(id)
|
||||
cmpSelectedModels.value = next
|
||||
}
|
||||
|
||||
function toggleAllCmpModels(checked: boolean) {
|
||||
cmpSelectedModels.value = checked
|
||||
? new Set(llmSelectableModels.value.map(m => m.id))
|
||||
: new Set()
|
||||
}
|
||||
|
||||
// ── Data loaders ──────────────────────────────────────────────────────────────
|
||||
async function loadLlmTasks() {
|
||||
llmTasksLoading.value = true
|
||||
const { data } = await useApiFetch<{ tasks: CfOrchTask[]; types: string[] }>('/api/cforch/tasks')
|
||||
llmTasksLoading.value = false
|
||||
if (data?.tasks) {
|
||||
llmTasks.value = data.tasks
|
||||
}
|
||||
}
|
||||
|
||||
async function loadLlmModels() {
|
||||
const { data } = await useApiFetch<{ models: CfOrchModel[] }>('/api/cforch/models')
|
||||
if (data?.models) {
|
||||
llmModels.value = data.models
|
||||
cmpSelectedModels.value = new Set(
|
||||
data.models.filter(m => LLM_SERVICES.has(m.service)).map(m => m.id)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run / cancel ──────────────────────────────────────────────────────────────
|
||||
function startCompare() {
|
||||
if (!cmpPrompt.value.trim() || cmpSelectedModels.value.size === 0) return
|
||||
cmpRunning.value = true
|
||||
cmpResults.value = []
|
||||
cmpLog.value = []
|
||||
|
||||
const params = new URLSearchParams({
|
||||
prompt: cmpPrompt.value,
|
||||
model_ids: [...cmpSelectedModels.value].join(','),
|
||||
system: cmpSystemPrompt.value,
|
||||
})
|
||||
|
||||
const es = new EventSource(`/api/imitate/run?${params}`)
|
||||
cmpEventSource.value = es
|
||||
|
||||
es.onmessage = (event: MessageEvent) => {
|
||||
try {
|
||||
const msg = JSON.parse(event.data)
|
||||
if (msg.type === 'start') {
|
||||
cmpLog.value.push(`Comparing ${msg.total_models} models…`)
|
||||
} else if (msg.type === 'model_start') {
|
||||
cmpLog.value.push(`→ ${msg.model}…`)
|
||||
} else if (msg.type === 'model_done') {
|
||||
const status = msg.error
|
||||
? `✕ ${msg.error}`
|
||||
: `✓ ${(msg.elapsed_ms / 1000).toFixed(1)}s`
|
||||
cmpLog.value.push(` ${msg.model}: ${status}`)
|
||||
cmpResults.value.push({
|
||||
model: msg.model,
|
||||
response: msg.response,
|
||||
elapsed_ms: msg.elapsed_ms,
|
||||
error: msg.error ?? null,
|
||||
})
|
||||
} else if (msg.type === 'complete') {
|
||||
cmpRunning.value = false
|
||||
es.close()
|
||||
cmpEventSource.value = null
|
||||
}
|
||||
} catch { /* ignore malformed frames */ }
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
cmpLog.value.push('Connection error.')
|
||||
cmpRunning.value = false
|
||||
es.close()
|
||||
cmpEventSource.value = null
|
||||
}
|
||||
}
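The `onmessage` handler in `startCompare` mixes JSON parsing, log formatting, and state updates. The formatting part can be tested without an `EventSource` if it is pulled out into a pure function. A sketch under stated assumptions: the frame fields (`type`, `total_models`, `model`, `elapsed_ms`, `error`) are inferred from the handler above, the real server contract may differ, and `frameToLogLine` is a hypothetical name.

```typescript
// Frame shape inferred from the handler above; the real server contract may differ.
interface CompareFrame {
  type?: string
  total_models?: number
  model?: string
  elapsed_ms?: number
  error?: string | null
}

// Hypothetical pure helper: turns one raw SSE data payload into a log line,
// or null when the frame is malformed or carries nothing to log.
function frameToLogLine(raw: string): string | null {
  let parsed: unknown
  try {
    parsed = JSON.parse(raw)
  } catch {
    return null // malformed JSON is ignored, as in the original handler
  }
  if (typeof parsed !== 'object' || parsed === null) return null
  const msg = parsed as CompareFrame
  switch (msg.type) {
    case 'start':
      return `Comparing ${msg.total_models} models…`
    case 'model_start':
      return `→ ${msg.model}…`
    case 'model_done': {
      const status = msg.error
        ? `✕ ${msg.error}`
        : `✓ ${((msg.elapsed_ms ?? 0) / 1000).toFixed(1)}s`
      return `${msg.model}: ${status}`
    }
    default:
      return null // 'complete' and unknown types produce no log line
  }
}

console.log(frameToLogLine('{"type":"model_done","model":"m","elapsed_ms":1500}'))
```

With a helper like this, the handler would reduce to pushing the returned line when it is not null, recording the result on `model_done`, and closing the stream on `complete`.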
|
||||
|
||||
function cancelCompare() {
|
||||
cmpEventSource.value?.close()
|
||||
cmpEventSource.value = null
|
||||
cmpRunning.value = false
|
||||
cmpLog.value.push('Cancelled.')
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadLlmTasks()
|
||||
loadLlmModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.compare-tab {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
/* ── Source toggle ──────────────────────────────────────── */
|
||||
.source-toggle {
|
||||
display: inline-flex;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
align-self: flex-start;
|
||||
}
|
||||
|
||||
.source-btn {
|
||||
padding: 0.4rem 1rem;
|
||||
font-size: 0.83rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
font-weight: 500;
|
||||
border: none;
|
||||
background: var(--color-surface, #fff);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
.source-btn:not(:last-child) { border-right: 1px solid var(--color-border, #d0d7e8); }
|
||||
.source-btn.active { background: var(--app-primary, #2A6080); color: #fff; }
|
||||
.source-btn:not(.active):hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
|
||||
/* ── Voice prompt list ──────────────────────────────────── */
|
||||
.style-prompt-list { flex-direction: column !important; flex-wrap: nowrap !important; padding-left: 0 !important; gap: 0.4rem !important; }
|
||||
|
||||
.style-prompt-row {
|
||||
flex-direction: column !important;
|
||||
align-items: flex-start !important;
|
||||
gap: 0.15rem !important;
|
||||
padding: 0.5rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.35rem;
|
||||
background: var(--color-surface, #f4f7fc);
|
||||
cursor: pointer;
|
||||
transition: background 0.1s;
|
||||
}
|
||||
.style-prompt-row:hover { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.style-prompt-row:has(input:checked) {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
.style-prompt-row input { display: none; }
|
||||
|
||||
.style-prompt-tag {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.72rem;
|
||||
color: var(--app-primary, #2A6080);
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.style-prompt-title {
|
||||
font-size: 0.83rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
/* ── Buttons ────────────────────────────────────────────── */
|
||||
.btn-run {
|
||||
padding: 0.45rem 1.1rem;
|
||||
border-radius: 0.375rem;
|
||||
border: none;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-run:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.45rem 0.9rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-text-secondary, #6b7a99);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-cancel:hover {
|
||||
background: color-mix(in srgb, var(--color-text-secondary, #6b7a99) 12%, transparent);
|
||||
}
|
||||
|
||||
/* ── Run controls row ───────────────────────────────────── */
|
||||
.run-controls {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.log-lines {
|
||||
max-height: 160px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #fff);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.1rem;
|
||||
}
|
||||
|
||||
.log-line { color: var(--color-text, #1a2338); line-height: 1.5; }
|
||||
|
||||
/* ── Chart title ────────────────────────────────────────── */
|
||||
.chart-title {
|
||||
font-size: 0.95rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Model Picker ───────────────────────────────────────── */
|
||||
.model-picker {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.picker-summary {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
padding: 0.65rem 0.9rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
list-style: none;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
.picker-summary::-webkit-details-marker { display: none; }
|
||||
.picker-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
||||
details[open] .picker-summary::before { content: '▼ '; }
|
||||
|
||||
.picker-title {
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-badge {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 1rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.picker-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.picker-loading, .picker-empty {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.picker-category {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.picker-cat-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.45rem;
|
||||
font-size: 0.82rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text, #1a2338);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.picker-cat-name { /* inherits from cat-header or section */ }
|
||||
|
||||
.picker-cat-section {
|
||||
font-weight: 600;
|
||||
font-size: 0.82rem;
|
||||
padding: 0.35rem 0;
|
||||
display: block;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-model-list {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.35rem 0.75rem;
|
||||
padding-left: 1.4rem;
|
||||
}
|
||||
|
||||
.picker-model-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
font-size: 0.82rem;
|
||||
cursor: pointer;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-model-name {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
max-width: 18ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.picker-adapter-type {
|
||||
font-size: 0.68rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
padding: 0.05rem 0.3rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
/* ── Prompt editor ──────────────────────────────────────── */
|
||||
.prompt-label {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.cmp-prompt-editor {
|
||||
width: 100%;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.85rem;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
color: var(--color-text, #1a2338);
|
||||
resize: vertical;
|
||||
line-height: 1.5;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.cmp-prompt-editor:focus {
|
||||
outline: 2px solid var(--app-primary, #2A6080);
|
||||
outline-offset: -1px;
|
||||
}
|
||||
|
||||
/* ── Results grid ───────────────────────────────────────── */
|
||||
.cmp-results-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
|
||||
gap: 1rem;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.cmp-result-card {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
background: var(--color-surface, #f0f4fc);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.cmp-result-card.cmp-error {
|
||||
border-color: #fca5a5;
|
||||
}
|
||||
|
||||
.cmp-result-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.cmp-model-name {
|
||||
font-size: 0.82rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.cmp-meta {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
flex-shrink: 0;
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
|
||||
.err-badge {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
border-radius: 9999px;
|
||||
padding: 0.1rem 0.45rem;
|
||||
font-size: 0.7rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.cmp-response, .cmp-error-text {
|
||||
padding: 0.75rem;
|
||||
font-size: 0.82rem;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.cmp-error-text { color: #b91c1c; }
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.picker-model-list { padding-left: 0; }
|
||||
.picker-model-name { max-width: 14ch; }
|
||||
}
|
||||
</style>
|
||||
|
|
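The `onmessage` handler in `startCompare` folds each server-sent frame into the run log and the results list. A minimal sketch of that same folding logic as a pure function, which could be unit-tested without stubbing `EventSource`, might look like the following. The names `CompareFrame`, `CompareState`, and `applyFrame` are hypothetical and do not appear in the deleted file; the frame shapes are inferred only from the fields the handler reads.

```typescript
// Hypothetical frame shapes, inferred from the fields the original handler reads.
type CompareFrame =
  | { type: 'start'; total_models: number }
  | { type: 'model_start'; model: string }
  | { type: 'model_done'; model: string; response?: string; elapsed_ms: number; error?: string | null }
  | { type: 'complete' }

interface CompareResult {
  model: string
  response: string | undefined
  elapsed_ms: number
  error: string | null
}

interface CompareState {
  running: boolean
  log: string[]
  results: CompareResult[]
}

// Folds one raw SSE payload into a new state object; the input state is never mutated.
function applyFrame(state: CompareState, raw: string): CompareState {
  let parsed: unknown
  try {
    parsed = JSON.parse(raw)
  } catch {
    return state // ignore malformed frames, as the original handler does
  }
  if (typeof parsed !== 'object' || parsed === null) return state
  const msg = parsed as CompareFrame

  switch (msg.type) {
    case 'start':
      return { ...state, log: [...state.log, `Comparing ${msg.total_models} models…`] }
    case 'model_start':
      return { ...state, log: [...state.log, `→ ${msg.model}…`] }
    case 'model_done': {
      const status = msg.error
        ? `✕ ${msg.error}`
        : `✓ ${(msg.elapsed_ms / 1000).toFixed(1)}s`
      return {
        ...state,
        log: [...state.log, ` ${msg.model}: ${status}`],
        results: [
          ...state.results,
          { model: msg.model, response: msg.response, elapsed_ms: msg.elapsed_ms, error: msg.error ?? null },
        ],
      }
    }
    case 'complete':
      return { ...state, running: false }
    default:
      return state // unknown frame types are skipped
  }
}
```

With a helper like this, the component's handler would reduce to assigning the returned state back to its refs and closing the stream once `running` becomes false.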
@ -1,31 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import CompareView from './CompareView.vue'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ tasks: [], types: [], models: [] }),
|
||||
text: async () => '',
|
||||
}))
|
||||
vi.stubGlobal('EventSource', class {
|
||||
onmessage = null
|
||||
onerror = null
|
||||
close() {}
|
||||
})
|
||||
})
|
||||
|
||||
describe('CompareView', () => {
|
||||
it('renders page title "Compare"', async () => {
|
||||
const w = mount(CompareView)
|
||||
await flushPromises()
|
||||
expect(w.find('h1.page-title').text()).toContain('Compare')
|
||||
})
|
||||
|
||||
it('wraps CompareTab component', async () => {
|
||||
const w = mount(CompareView)
|
||||
await flushPromises()
|
||||
// CompareTab renders a .compare-tab root div
|
||||
expect(w.find('.compare-tab').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
<template>
|
||||
<div class="compare-view">
|
||||
<header class="compare-header">
|
||||
<h1 class="page-title">🔍 Compare</h1>
|
||||
</header>
|
||||
<CompareTab />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import CompareTab from './CompareTab.vue'
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.compare-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.compare-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,328 +0,0 @@
|
|||
<template>
|
||||
<div class="corrections-view">
|
||||
<header class="cv-header">
|
||||
<span class="queue-count">
|
||||
<template v-if="loading">Loading…</template>
|
||||
<template v-else-if="store.totalRemaining > 0">
|
||||
{{ store.totalRemaining }} remaining
|
||||
</template>
|
||||
<span v-else class="queue-empty-label">All caught up</span>
|
||||
</span>
|
||||
<div class="header-actions">
|
||||
<button @click="handleUndo" :disabled="!store.lastAction" class="btn-action">↩ Undo</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- States -->
|
||||
<div v-if="loading" class="skeleton-card" aria-label="Loading candidates" />
|
||||
|
||||
<div v-else-if="apiError" class="error-display" role="alert">
|
||||
<p>Couldn't reach Avocet API.</p>
|
||||
<button @click="fetchBatch" class="btn-action">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!store.current" class="empty-state">
|
||||
<p>No candidates need review.</p>
|
||||
<p class="empty-hint">Import a benchmark run from the Settings tab to get started.</p>
|
||||
</div>
|
||||
|
||||
<template v-else>
|
||||
<div class="card-wrapper">
|
||||
<SftCard
|
||||
:item="store.current"
|
||||
:correcting="correcting"
|
||||
@correct="startCorrection"
|
||||
@discard="handleDiscard"
|
||||
@flag="handleFlag"
|
||||
@submit-correction="handleCorrect"
|
||||
@cancel-correction="correcting = false"
|
||||
ref="sftCardEl"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Stats footer -->
|
||||
<footer v-if="stats" class="stats-footer">
|
||||
<span class="stat">✓ {{ stats.by_status?.approved ?? 0 }} approved</span>
|
||||
<span class="stat">✕ {{ stats.by_status?.discarded ?? 0 }} discarded</span>
|
||||
<span class="stat">⚑ {{ stats.by_status?.model_rejected ?? 0 }} flagged</span>
|
||||
<a
|
||||
v-if="(stats.export_ready ?? 0) > 0"
|
||||
:href="exportUrl"
|
||||
download
|
||||
class="btn-export"
|
||||
>
|
||||
⬇ Export {{ stats.export_ready }} corrections
|
||||
</a>
|
||||
</footer>
|
||||
|
||||
<!-- Undo toast (inline — UndoToast.vue uses label store's LastAction shape, not SFT's) -->
|
||||
<div v-if="store.lastAction" class="undo-toast">
|
||||
<span>Last: {{ store.lastAction.type }}</span>
|
||||
<button @click="handleUndo" class="btn-undo">↩ Undo</button>
|
||||
<button @click="store.clearLastAction()" class="btn-dismiss">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { useSftStore } from '../stores/sft'
|
||||
import type { SftFailureCategory } from '../stores/sft'
|
||||
import { useSftKeyboard } from '../composables/useSftKeyboard'
|
||||
import SftCard from '../components/SftCard.vue'
|
||||
|
||||
const store = useSftStore()
|
||||
const loading = ref(false)
|
||||
const apiError = ref(false)
|
||||
const correcting = ref(false)
|
||||
const stats = ref<Record<string, any> | null>(null)
|
||||
const exportUrl = '/api/sft/export'
|
||||
const sftCardEl = ref<InstanceType<typeof SftCard> | null>(null)
|
||||
|
||||
useSftKeyboard({
|
||||
onCorrect: () => { if (store.current && !correcting.value) correcting.value = true },
|
||||
onDiscard: () => { if (store.current && !correcting.value) handleDiscard() },
|
||||
onFlag: () => { if (store.current && !correcting.value) handleFlag() },
|
||||
onEscape: () => { correcting.value = false },
|
||||
onSubmit: () => {}, // no-op here: submission arrives through SftCard's submit-correction event
|
||||
isEditing: () => correcting.value,
|
||||
})
|
||||
|
||||
async function fetchBatch() {
|
||||
loading.value = true
|
||||
apiError.value = false
|
||||
try {
|
||||
const res = await fetch('/api/sft/queue?per_page=20')
|
||||
if (!res.ok) throw new Error('API error')
|
||||
const data = await res.json()
|
||||
store.queue = data.items
|
||||
store.totalRemaining = data.total
|
||||
} catch {
|
||||
apiError.value = true
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchStats() {
|
||||
try {
|
||||
const res = await fetch('/api/sft/stats')
|
||||
if (res.ok) stats.value = await res.json()
|
||||
} catch { /* ignore */ }
|
||||
}
|
||||
|
||||
function startCorrection() {
|
||||
correcting.value = true
|
||||
}
|
||||
|
||||
async function handleCorrect(text: string, category: SftFailureCategory | null = null) {
|
||||
if (!store.current) return
|
||||
const item = store.current
|
||||
correcting.value = false
|
||||
try {
|
||||
const body: Record<string, unknown> = { id: item.id, action: 'correct', corrected_response: text }
|
||||
if (category != null) body.failure_category = category
|
||||
const res = await fetch('/api/sft/submit', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
})
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
store.removeCurrentFromQueue()
|
||||
store.setLastAction('correct', item, category)
|
||||
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
||||
fetchStats()
|
||||
if (store.queue.length < 5) fetchBatch()
|
||||
} catch (err) {
|
||||
console.error('handleCorrect failed:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDiscard(category: SftFailureCategory | null = null) {
|
||||
if (!store.current) return
|
||||
const item = store.current
|
||||
try {
|
||||
const body: Record<string, unknown> = { id: item.id, action: 'discard' }
|
||||
if (category != null) body.failure_category = category
|
||||
const res = await fetch('/api/sft/submit', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
})
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
store.removeCurrentFromQueue()
|
||||
store.setLastAction('discard', item, category)
|
||||
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
||||
fetchStats()
|
||||
if (store.queue.length < 5) fetchBatch()
|
||||
} catch (err) {
|
||||
console.error('handleDiscard failed:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function handleFlag(category: SftFailureCategory | null = null) {
|
||||
if (!store.current) return
|
||||
const item = store.current
|
||||
try {
|
||||
const body: Record<string, unknown> = { id: item.id, action: 'flag' }
|
||||
if (category != null) body.failure_category = category
|
||||
const res = await fetch('/api/sft/submit', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
})
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
store.removeCurrentFromQueue()
|
||||
store.setLastAction('flag', item, category)
|
||||
store.totalRemaining = Math.max(0, store.totalRemaining - 1)
|
||||
fetchStats()
|
||||
if (store.queue.length < 5) fetchBatch()
|
||||
} catch (err) {
|
||||
console.error('handleFlag failed:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function handleUndo() {
|
||||
if (!store.lastAction) return
|
||||
const action = store.lastAction
|
||||
const { item } = action
|
||||
try {
|
||||
const res = await fetch('/api/sft/undo', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ id: item.id }),
|
||||
})
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
store.restoreItem(item)
|
||||
store.totalRemaining++
|
||||
store.clearLastAction()
|
||||
fetchStats()
|
||||
} catch (err) {
|
||||
// Backend did not restore — clear the undo UI without restoring queue state
|
||||
console.error('handleUndo failed:', err)
|
||||
store.clearLastAction()
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchBatch()
|
||||
fetchStats()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.corrections-view {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 100dvh;
|
||||
padding: var(--space-4);
|
||||
gap: var(--space-4);
|
||||
max-width: 760px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.cv-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.queue-count {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
.queue-empty-label { color: var(--color-text-muted); }
|
||||
|
||||
.btn-action {
|
||||
padding: var(--space-2) var(--space-3);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--radius-md);
|
||||
background: var(--color-surface-raised);
|
||||
cursor: pointer;
|
||||
font-size: 0.88rem;
|
||||
}
|
||||
|
||||
.btn-action:disabled { opacity: 0.4; cursor: not-allowed; }
|
||||
|
||||
.skeleton-card {
|
||||
height: 320px;
|
||||
background: var(--color-surface-alt);
|
||||
border-radius: var(--radius-lg);
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.skeleton-card { animation: none; }
|
||||
}
|
||||
|
||||
.error-display, .empty-state {
|
||||
text-align: center;
|
||||
padding: var(--space-12);
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
|
||||
.empty-hint { font-size: 0.88rem; margin-top: var(--space-2); }
|
||||
|
||||
.stats-footer {
|
||||
display: flex;
|
||||
gap: var(--space-4);
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
padding: var(--space-3) 0;
|
||||
border-top: 1px solid var(--color-border-light);
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
|
||||
.btn-export {
|
||||
margin-left: auto;
|
||||
padding: var(--space-2) var(--space-3);
|
||||
background: var(--color-primary);
|
||||
color: var(--color-text-inverse, #fff);
|
||||
border-radius: var(--radius-md);
|
||||
text-decoration: none;
|
||||
font-size: 0.88rem;
|
||||
}
|
||||
|
||||
.undo-toast {
|
||||
position: fixed;
|
||||
bottom: var(--space-6);
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--space-3);
|
||||
background: var(--color-surface-raised);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--radius-md);
|
||||
padding: var(--space-3) var(--space-4);
|
||||
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.btn-undo {
|
||||
background: var(--color-primary);
|
||||
color: var(--color-text-inverse, #fff);
|
||||
border: none;
|
||||
border-radius: var(--radius-sm);
|
||||
padding: var(--space-1) var(--space-3);
|
||||
cursor: pointer;
|
||||
font-size: 0.88rem;
|
||||
}
|
||||
|
||||
.btn-dismiss {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-text-muted);
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
}
|
||||
</style>
|
||||
|
|
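`handleCorrect`, `handleDiscard`, and `handleFlag` in the view above build almost identical request bodies before POSTing to `/api/sft/submit`. A minimal sketch of how that shared step could be factored out is shown below. `SftAction` and `buildSubmitBody` are hypothetical names, `id` is assumed to be a string, and a plain `string` stands in for the store's `SftFailureCategory` type, whose definition is not shown in this diff.

```typescript
// Hypothetical union of the three action values the handlers send.
type SftAction = 'correct' | 'discard' | 'flag'

// Hypothetical helper: builds the JSON body the three handlers POST to /api/sft/submit.
// `category` uses string as a stand-in for SftFailureCategory (an assumption).
function buildSubmitBody(
  id: string,
  action: SftAction,
  opts: { correctedResponse?: string; category?: string | null } = {},
): Record<string, unknown> {
  const body: Record<string, unknown> = { id, action }
  // Only the "correct" action carries replacement text.
  if (action === 'correct') body.corrected_response = opts.correctedResponse ?? ''
  // Mirror the original handlers: omit failure_category when it is null or undefined.
  if (opts.category != null) body.failure_category = opts.category
  return body
}
```

Each handler would then differ only in the action name and in the `correcting` reset that `handleCorrect` performs first, so the fetch, queue update, and stats refresh could live in one place.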
@ -1,119 +0,0 @@
|
|||
import { mount, flushPromises } from '@vue/test-utils'
|
||||
import { createRouter, createWebHashHistory } from 'vue-router'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import DashboardView from './DashboardView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: { template: '<div />' } },
|
||||
{ path: '/eval/benchmark', component: { template: '<div />' } },
|
||||
{ path: '/train/jobs', component: { template: '<div />' } },
|
||||
{ path: '/fleet', component: { template: '<div />' } },
|
||||
],
|
||||
})
|
||||
|
||||
const baseDashboard = {
|
||||
labeled_since_last_eval: 0,
|
||||
last_eval_timestamp: null,
|
||||
last_eval_best_score: null,
|
||||
active_jobs: [],
|
||||
corrections_export_ready: 0,
|
||||
signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false },
|
||||
}
|
||||
|
||||
function mockFetch(overrides: Partial<typeof baseDashboard> = {}) {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ ...baseDashboard, ...overrides }),
|
||||
text: async () => '',
|
||||
}))
|
||||
}
|
||||
|
||||
beforeEach(() => mockFetch())
|
||||
|
||||
describe('DashboardView', () => {
|
||||
it('renders page title', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.text()).toContain('Dashboard')
|
||||
})
|
||||
|
||||
it('shows three stage cards: Data, Eval, Train', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.stage-card[data-stage="data"]').exists()).toBe(true)
|
||||
expect(w.find('.stage-card[data-stage="eval"]').exists()).toBe(true)
|
||||
expect(w.find('.stage-card[data-stage="train"]').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows labeled_since_last_eval count in Data card', async () => {
|
||||
mockFetch({ labeled_since_last_eval: 42 })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.stage-card[data-stage="data"]').text()).toContain('42')
|
||||
})
|
||||
|
||||
it('does NOT show Run Eval CTA when data_to_eval is false', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const dataCard = w.find('.stage-card[data-stage="data"]')
|
||||
expect(dataCard.find('.cta-btn').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('shows Run Eval CTA when data_to_eval is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: true, eval_to_train: false, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const dataCard = w.find('.stage-card[data-stage="data"]')
|
||||
expect(dataCard.find('.cta-btn').exists()).toBe(true)
|
||||
expect(dataCard.find('.cta-btn').text()).toContain('Run Eval')
|
||||
})
|
||||
|
||||
it('shows Queue Finetune CTA when eval_to_train is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: true, train_to_fleet: false } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalCard = w.find('.stage-card[data-stage="eval"]')
|
||||
expect(evalCard.find('.cta-btn').text()).toContain('Queue Finetune')
|
||||
})
|
||||
|
||||
it('shows Register in Fleet CTA when train_to_fleet is true', async () => {
|
||||
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: true } })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainCard = w.find('.stage-card[data-stage="train"]')
|
||||
expect(trainCard.find('.cta-btn').text()).toContain('Register in Fleet')
|
||||
})
|
||||
|
||||
it('shows active job status pills in Train card', async () => {
|
||||
mockFetch({ active_jobs: [{ id: 'j1', type: 'classifier', model_key: 'deberta-v3', status: 'running' }] })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const trainCard = w.find('.stage-card[data-stage="train"]')
|
||||
expect(trainCard.find('.status-pill').exists()).toBe(true)
|
||||
expect(trainCard.text()).toContain('deberta-v3')
|
||||
})
|
||||
|
||||
it('shows last eval score in Eval card when present', async () => {
|
||||
mockFetch({ last_eval_best_score: 0.821 })
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
const evalCard = w.find('.stage-card[data-stage="eval"]')
|
||||
expect(evalCard.text()).toContain('82.1%')
|
||||
})
|
||||
|
||||
it('shows error state when API call fails', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 503, text: async () => '' }))
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.error-notice').exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('shows refresh button', async () => {
|
||||
const w = mount(DashboardView, { global: { plugins: [router] } })
|
||||
await flushPromises()
|
||||
expect(w.find('.refresh-btn').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
|
@ -1,406 +0,0 @@
|
|||
<template>
|
||||
<div class="dashboard-view">
|
||||
<header class="dashboard-header">
|
||||
<h1 class="page-title">📊 Dashboard</h1>
|
||||
<button class="refresh-btn" :disabled="loading" @click="load" aria-label="Refresh dashboard">
|
||||
🔄
|
||||
</button>
|
||||
</header>
|
||||
|
||||
<div v-if="loading && !data" class="loading-state">Loading…</div>
|
||||
|
||||
<div v-if="error" class="error-notice" role="alert">
|
||||
{{ error }}
|
||||
<button class="btn-retry" @click="load">Retry</button>
|
||||
</div>
|
||||
|
||||
<div v-if="data" class="flywheel-grid">
|
||||
|
||||
<!-- ① Data card -->
|
||||
<div class="stage-card" data-stage="data">
|
||||
<div class="card-header">
|
||||
<span class="card-step">①</span>
|
||||
<h2 class="card-title">Data</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p class="card-metric">
|
||||
<strong class="metric-value">{{ data.labeled_since_last_eval.toLocaleString() }}</strong>
|
||||
<span class="metric-label"> labeled since last eval</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ② Eval card -->
|
||||
<div class="stage-card" data-stage="eval">
|
||||
<div class="card-header">
|
||||
<span class="card-step">②</span>
|
||||
<h2 class="card-title">Eval</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="bench-run-table">
|
||||
<div
|
||||
v-for="(run, type) in data.recent_bench_runs"
|
||||
:key="type"
|
||||
class="bench-run-row"
|
||||
>
|
||||
<span class="bench-type-label">{{ BENCH_LABELS[type as BenchType] ?? type }}</span>
|
||||
<span class="bench-run-time" :class="{ 'metric-muted': !run.timestamp }">
|
||||
{{ run.timestamp ? formatBenchTs(run.timestamp) : '—' }}
|
||||
</span>
|
||||
<span v-if="run.score != null" class="bench-run-score">
|
||||
{{ formatScore(run.score) }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="data.signals.eval_to_train" class="card-cta">
|
||||
<RouterLink to="/train/jobs" class="cta-btn">Queue Finetune</RouterLink>
|
||||
</div>
|
||||
<div v-if="data.signals.data_to_eval" class="card-cta">
|
||||
<RouterLink to="/eval/benchmark" class="cta-btn">Run Eval</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ③ Train card -->
|
||||
<div class="stage-card" data-stage="train">
|
||||
<div class="card-header">
|
||||
<span class="card-step">③</span>
|
||||
<h2 class="card-title">Train</h2>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<template v-if="data.active_jobs.length > 0">
|
||||
<div
|
||||
v-for="job in data.active_jobs"
|
||||
:key="job.id"
|
||||
class="job-row"
|
||||
>
|
||||
<span class="job-key">{{ job.model_key }}</span>
|
||||
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<p v-else class="card-metric metric-muted">No active jobs</p>
|
||||
|
||||
<p v-if="data.corrections_export_ready > 0" class="card-metric">
|
||||
<strong class="metric-value">{{ data.corrections_export_ready }}</strong>
|
||||
<span class="metric-label"> corrections ready</span>
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="data.signals.train_to_fleet" class="card-cta">
|
||||
<RouterLink to="/fleet" class="cta-btn">Register in Fleet</RouterLink>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { RouterLink } from 'vue-router'
|
||||
|
||||
interface ActiveJob {
|
||||
id: string
|
||||
type: string
|
||||
model_key: string
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
|
||||
}
|
||||
|
||||
interface DashboardSignals {
|
||||
data_to_eval: boolean
|
||||
eval_to_train: boolean
|
||||
train_to_fleet: boolean
|
||||
}
|
||||
|
||||
interface BenchRun {
|
||||
timestamp: string | null
|
||||
metric: string | null
|
||||
score: number | null
|
||||
}
|
||||
|
||||
type BenchType = 'classifier' | 'llm' | 'style' | 'plans'
|
||||
|
||||
interface DashboardData {
|
||||
labeled_since_last_eval: number
|
||||
last_eval_timestamp: string | null
|
||||
last_eval_best_score: number | null
|
||||
active_jobs: ActiveJob[]
|
||||
corrections_export_ready: number
|
||||
recent_bench_runs: Record<BenchType, BenchRun>
|
||||
signals: DashboardSignals
|
||||
}
|
||||
|
||||
const BENCH_LABELS: Record<BenchType, string> = {
|
||||
classifier: 'Classifier',
|
||||
llm: 'LLM Eval',
|
||||
style: 'Style',
|
||||
plans: 'Planning',
|
||||
}
|
||||
|
||||
const data = ref<DashboardData | null>(null)
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
function formatBenchTs(ts: string): string {
|
||||
const date = new Date(ts)
|
||||
if (!isNaN(date.getTime())) {
|
||||
const diff = Date.now() - date.getTime()
|
||||
const mins = Math.floor(diff / 60000)
|
||||
if (mins < 1) return 'just now'
|
||||
if (mins < 60) return `${mins}m ago`
|
||||
const hrs = Math.floor(mins / 60)
|
||||
if (hrs < 24) return `${hrs}h ago`
|
||||
return `${Math.floor(hrs / 24)}d ago`
|
||||
}
|
||||
// Non-ISO: show as-is (plans bench uses "YYYY-MM-DD HH:MM")
|
||||
return ts.length > 16 ? ts.slice(0, 16) : ts
|
||||
}
|
||||
|
||||
function formatScore(score: number): string {
|
||||
return `${(score * 100).toFixed(1)}%`
|
||||
}
|
||||
|
||||
async function load() {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const res = await fetch('/api/dashboard')
|
||||
if (!res.ok) {
|
||||
error.value = `Could not load dashboard (HTTP ${res.status}).`
|
||||
return
|
||||
}
|
||||
data.value = await res.json() as DashboardData
|
||||
} catch {
|
||||
error.value = 'Network error. Is the Avocet API running?'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => load())
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.dashboard-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.dashboard-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
padding: 0.3rem 0.5rem;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
|
||||
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
/* ── Flywheel grid ── */
|
||||
.flywheel-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
@media (max-width: 680px) {
|
||||
.flywheel-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── Stage cards ── */
|
||||
.stage-card {
|
||||
background: var(--color-surface-raised, #f5f7fc);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-lg, 1rem);
|
||||
padding: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
padding-bottom: 0.6rem;
|
||||
}
|
||||
|
||||
.card-step {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.card-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.card-metric {
|
||||
margin: 0;
|
||||
font-size: 0.875rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 1.05rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.metric-muted { color: var(--color-text-muted, #4a5c7a); }
|
||||
|
||||
.card-cta { margin-top: auto; }
|
||||
|
||||
.cta-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
text-align: center;
|
||||
padding: 0.5rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 0.375rem;
|
||||
text-decoration: none;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.cta-btn:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 85%, black); }
|
||||
|
||||
/* ── Bench run table ── */
|
||||
.bench-run-table {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.bench-run-row {
|
||||
display: grid;
|
||||
grid-template-columns: 6rem 1fr auto;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.bench-type-label {
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.bench-run-time {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.bench-run-score {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
color: var(--app-primary, #2A6080);
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
|
||||
padding: 0.1rem 0.35rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
/* ── Job pills ── */
|
||||
.job-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.job-key {
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.status-pill {
|
||||
font-size: 0.75rem;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: 100px;
|
||||
font-weight: 600;
|
||||
flex-shrink: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
}
|
||||
|
||||
.status-pill.status-running { background: #d4f4e0; color: #1a7a3a; }
|
||||
.status-pill.status-queued { background: #fef3cd; color: #856404; }
|
||||
.status-pill.status-failed { background: #fde8e8; color: #842029; }
|
||||
.status-pill.status-completed { background: #e0f0ff; color: #0c5481; }
|
||||
|
||||
/* ── State indicators ── */
|
||||
.loading-state {
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.error-notice {
|
||||
background: #fde8e8;
|
||||
color: #842029;
|
||||
border: 1px solid #f5c2c7;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.75rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.btn-retry {
|
||||
background: transparent;
|
||||
border: 1px solid currentColor;
|
||||
border-radius: 0.25rem;
|
||||
color: inherit;
|
||||
cursor: pointer;
|
||||
font-size: 0.75rem;
|
||||
padding: 0.2rem 0.5rem;
|
||||
margin-left: auto;
|
||||
}
|
||||
</style>
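The `formatBenchTs` helper in the dashboard script above turns a timestamp into a short relative label ("5m ago", "3h ago") and falls back to the raw string when the value cannot be parsed. A minimal sketch of the same logic follows, with the clock passed in so it can be checked deterministically. The name `formatRelative` and its `now` parameter are assumptions made for illustration and do not appear in the component.

```typescript
// Sketch only: `formatRelative` and its `now` parameter are hypothetical,
// introduced here so the relative-time logic can be exercised without a real clock.
function formatRelative(ts: string, now: number = Date.now()): string {
  const date = new Date(ts)
  if (Number.isNaN(date.getTime())) {
    // Unparseable input: show at most the first 16 characters, as the original does.
    return ts.length > 16 ? ts.slice(0, 16) : ts
  }
  const mins = Math.floor((now - date.getTime()) / 60_000)
  if (mins < 1) return 'just now'
  if (mins < 60) return `${mins}m ago`
  const hrs = Math.floor(mins / 60)
  if (hrs < 24) return `${hrs}h ago`
  return `${Math.floor(hrs / 24)}d ago`
}
```

Passing `now` explicitly keeps the function pure. A component would most likely call it with the default and let `Date.now()` supply the current time.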
|
||||
|
|
@ -1,705 +0,0 @@
|
|||
<template>
|
||||
<div class="embed-compare-page">
|
||||
<!-- Step indicator (non-interactive) -->
|
||||
<ol class="step-indicator" aria-label="Setup progress">
|
||||
<li :class="{ complete: corpus.length > 0 }">Corpus</li>
|
||||
<li :class="{ complete: queries.length > 0 }">Queries</li>
|
||||
<li :class="{ complete: selectedModels.length > 0 }">Models</li>
|
||||
<li :class="{ complete: hasResults }">Run & Rate</li>
|
||||
</ol>
|
||||
|
||||
<!-- Persistent aria-live region — always in DOM, never v-if -->
|
||||
<div
|
||||
ref="liveRegion"
|
||||
class="sr-live"
|
||||
aria-live="polite"
|
||||
aria-atomic="true"
|
||||
v-text="liveMessage"
|
||||
/>
|
||||
|
||||
<!-- ① Corpus section -->
|
||||
<section class="card" aria-labelledby="corpus-heading">
|
||||
<h2 id="corpus-heading">① Corpus</h2>
|
||||
<div class="corpus-controls">
|
||||
<div class="field">
|
||||
<label for="corpus-paste">Paste chunks (one per line)</label>
|
||||
<textarea
|
||||
id="corpus-paste"
|
||||
v-model="rawCorpus"
|
||||
rows="6"
|
||||
placeholder="Paste one chunk per line, or use Import below..."
|
||||
@change="onCorpusPaste"
|
||||
/>
|
||||
</div>
|
||||
<div class="import-row">
|
||||
<label for="imitate-product-select">Import from product</label>
|
||||
<select id="imitate-product-select" v-model="selectedProduct">
|
||||
<option value="">-- select product --</option>
|
||||
<option
|
||||
v-for="p in imitateProducts"
|
||||
:key="p.id"
|
||||
:value="p.id"
|
||||
>{{ p.name }}</option>
|
||||
</select>
|
||||
<button
|
||||
class="btn-secondary"
|
||||
:disabled="!selectedProduct || importing"
|
||||
@click="importCorpus"
|
||||
>
|
||||
{{ importing ? 'Importing…' : 'Import' }}
|
||||
</button>
|
||||
<span v-if="importError" class="error-text" role="alert">{{ importError }}</span>
|
||||
</div>
|
||||
<p v-if="corpus.length > 0" class="corpus-count">
|
||||
{{ corpus.length }} chunk{{ corpus.length === 1 ? '' : 's' }} loaded.
|
||||
</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- ② Queries section -->
|
||||
<section class="card" aria-labelledby="queries-heading">
|
||||
<h2 id="queries-heading">② Queries</h2>
|
||||
<div class="field">
|
||||
<label for="query-input">Enter queries (one per line)</label>
|
||||
<textarea
|
||||
id="query-input"
|
||||
v-model="rawQueries"
|
||||
rows="4"
|
||||
placeholder="One query per line..."
|
||||
@change="onQueriesChange"
|
||||
/>
|
||||
</div>
|
||||
<p v-if="queries.length > 0" class="query-count">
|
||||
{{ queries.length }} quer{{ queries.length === 1 ? 'y' : 'ies' }}.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- ③ Model selection -->
|
||||
<section class="card" aria-labelledby="models-heading">
|
||||
<h2 id="models-heading">③ Models</h2>
|
||||
<p v-if="loadingModels" class="muted">Loading models from Ollama…</p>
|
||||
<p v-else-if="modelsError" class="error-text" role="alert">{{ modelsError }}</p>
|
||||
<ul v-else class="model-list" role="list">
|
||||
<li v-for="m in availableModels" :key="m.name">
|
||||
<label class="model-checkbox">
|
||||
<input
|
||||
type="checkbox"
|
||||
:value="m.name"
|
||||
v-model="selectedModels"
|
||||
/>
|
||||
{{ m.name }}
|
||||
<span class="model-size muted" aria-label="model size">
|
||||
{{ formatBytes(m.size) }}
|
||||
</span>
|
||||
</label>
|
||||
</li>
|
||||
</ul>
|
||||
<p v-if="availableModels.length === 0 && !loadingModels && !modelsError" class="muted">
|
||||
No Ollama models found. Pull an embedding model first.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- ④ Run controls -->
|
||||
<section class="card run-controls" aria-labelledby="run-heading">
|
||||
<h2 id="run-heading">④ Run</h2>
|
||||
<div class="run-row">
|
||||
<div class="field-inline">
|
||||
<label for="top-k-input">Results per query</label>
|
||||
<input
|
||||
id="top-k-input"
|
||||
type="number"
|
||||
v-model.number="topK"
|
||||
min="1"
|
||||
max="20"
|
||||
style="width: 5rem"
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
class="btn-primary"
|
||||
:disabled="!canRun || running"
|
||||
@click="startRun"
|
||||
>
|
||||
{{ running ? 'Running…' : 'Run' }}
|
||||
</button>
|
||||
<button
|
||||
v-if="running"
|
||||
class="btn-danger"
|
||||
aria-label="Cancel embedding run"
|
||||
@click="cancelRun"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
<p v-if="!canRun && !running" class="muted">
|
||||
Fill corpus, at least one query, and select at least one model to run.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<!-- Results -->
|
||||
<section
|
||||
v-if="hasResults"
|
||||
class="card results-section"
|
||||
aria-labelledby="results-heading"
|
||||
>
|
||||
<h2 id="results-heading">Results</h2>
|
||||
|
||||
<!-- Query pagination -->
|
||||
<div class="query-nav" role="navigation" aria-label="Query navigation">
|
||||
<button
|
||||
class="btn-secondary"
|
||||
aria-label="Previous query"
|
||||
:disabled="currentQueryIdx === 0"
|
||||
@click="currentQueryIdx--"
|
||||
>‹</button>
|
||||
<span class="query-counter">
|
||||
Query {{ currentQueryIdx + 1 }} of {{ uniqueQueries.length }}:
|
||||
<em>{{ uniqueQueries[currentQueryIdx] }}</em>
|
||||
</span>
|
||||
<button
|
||||
class="btn-secondary"
|
||||
aria-label="Next query"
|
||||
:disabled="currentQueryIdx >= uniqueQueries.length - 1"
|
||||
@click="currentQueryIdx++"
|
||||
>›</button>
|
||||
</div>
|
||||
|
||||
<!-- Results table: one column per model -->
|
||||
<div class="table-wrap">
|
||||
<table class="results-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th scope="col" class="rank-col">#</th>
|
||||
<th
|
||||
v-for="model in selectedModels"
|
||||
:key="model"
|
||||
scope="col"
|
||||
>{{ model }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="rank in topK" :key="rank">
|
||||
<td class="rank-col muted">{{ rank }}</td>
|
||||
<td
|
||||
v-for="model in selectedModels"
|
||||
:key="model"
|
||||
class="hit-cell"
|
||||
>
|
||||
<template v-if="getHit(currentQueryIdx, model, rank - 1) as hit">
|
||||
<div class="hit-text">{{ hit.text }}</div>
|
||||
<!-- Visual score bar: decorative only -->
|
||||
<div class="score-row">
|
||||
<div class="score-bar-wrap" aria-hidden="true">
|
||||
<div class="score-bar" :style="{ width: `${hit.score * 100}%` }" />
|
||||
</div>
|
||||
<span class="score-label">{{ hit.score.toFixed(3) }}</span>
|
||||
</div>
|
||||
<!-- Rating buttons -->
|
||||
<div class="rating-row">
|
||||
<button
|
||||
class="rate-btn"
|
||||
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant' }"
|
||||
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant'"
|
||||
aria-label="Mark as relevant"
|
||||
@click="rate(currentQueryIdx, model, hit, 'relevant')"
|
||||
>
|
||||
👍 Relevant
|
||||
</button>
|
||||
<button
|
||||
class="rate-btn rate-btn-neg"
|
||||
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant' }"
|
||||
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant'"
|
||||
aria-label="Mark as not relevant"
|
||||
@click="rate(currentQueryIdx, model, hit, 'not_relevant')"
|
||||
>
|
||||
👎 Not relevant
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
<span v-else class="muted">—</span>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Export -->
|
||||
<section
|
||||
v-if="hasResults"
|
||||
class="card export-section"
|
||||
aria-labelledby="export-heading"
|
||||
>
|
||||
<h2 id="export-heading">Export Ratings</h2>
|
||||
<div class="export-row">
|
||||
<fieldset class="export-format-group">
|
||||
<legend>Format</legend>
|
||||
<label><input type="radio" v-model="exportFormat" value="csv" /> CSV</label>
|
||||
<label><input type="radio" v-model="exportFormat" value="json" /> JSON</label>
|
||||
</fieldset>
|
||||
<button class="btn-secondary" @click="exportRatings">Export</button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
|
||||
// ── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
interface OllamaModel { name: string; size: number }
|
||||
interface ImitateProduct { id: string; name: string }
|
||||
interface HitResult { chunk_idx: number; text: string; score: number }
|
||||
interface ResultEvent {
|
||||
type: 'result'
|
||||
query_idx: number
|
||||
query: string
|
||||
model: string
|
||||
hits: HitResult[]
|
||||
}
|
||||
|
||||
// ── State ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
const rawCorpus = ref('')
|
||||
const corpus = ref<string[]>([])
|
||||
const rawQueries = ref('')
|
||||
const queries = ref<string[]>([])
|
||||
const selectedModels = ref<string[]>([])
|
||||
const topK = ref(5)
|
||||
const availableModels = ref<OllamaModel[]>([])
|
||||
const loadingModels = ref(false)
|
||||
const modelsError = ref('')
|
||||
const imitateProducts = ref<ImitateProduct[]>([])
|
||||
const selectedProduct = ref('')
|
||||
const importing = ref(false)
|
||||
const importError = ref('')
|
||||
const running = ref(false)
|
||||
const liveMessage = ref('')
|
||||
const resultEvents = ref<ResultEvent[]>([])
|
||||
const runController = ref<AbortController | null>(null)
|
||||
|
||||
const currentQueryIdx = ref(0)
|
||||
const exportFormat = ref<'csv' | 'json'>('csv')
|
||||
|
||||
type RatingMap = Record<string, Record<string, Record<number, 'relevant' | 'not_relevant'>>>
|
||||
const ratings = ref<RatingMap>({})
|
||||
|
||||
const uniqueQueries = computed(() => {
|
||||
const seen = new Set<string>()
|
||||
const out: string[] = []
|
||||
for (const e of resultEvents.value) {
|
||||
if (!seen.has(e.query)) { seen.add(e.query); out.push(e.query) }
|
||||
}
|
||||
return out
|
||||
})
|
||||
|
||||
const hasResults = computed(() => resultEvents.value.length > 0)
|
||||
const canRun = computed(
|
||||
() => corpus.value.length > 0 && queries.value.length > 0 && selectedModels.value.length > 0
|
||||
)
|
||||
|
||||
// ── Corpus helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
function onCorpusPaste() {
|
||||
const chunks = rawCorpus.value.split('\n').map(l => l.trim()).filter(Boolean)
|
||||
corpus.value = chunks
|
||||
if (chunks.length > 0) {
|
||||
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded.`
|
||||
}
|
||||
}
|
||||
|
||||
function onQueriesChange() {
|
||||
queries.value = rawQueries.value.split('\n').map(l => l.trim()).filter(Boolean)
|
||||
}
|
||||
|
||||
async function importCorpus() {
|
||||
if (!selectedProduct.value) return
|
||||
importing.value = true
|
||||
importError.value = ''
|
||||
try {
|
||||
const r = await fetch(`/api/imitate/products/${selectedProduct.value}/sample-chunks`)
|
||||
if (!r.ok) {
|
||||
const text = await r.text()
|
||||
throw new Error(text || `HTTP ${r.status}`)
|
||||
}
|
||||
const data = await r.json() as { chunks?: string[] }
|
||||
const chunks = data.chunks ?? []
|
||||
corpus.value = chunks
|
||||
rawCorpus.value = chunks.join('\n')
|
||||
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded from import.`
|
||||
} catch (err) {
|
||||
importError.value = String(err)
|
||||
} finally {
|
||||
importing.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Model loading ─────────────────────────────────────────────────────────────
|
||||
|
||||
async function loadModels() {
|
||||
loadingModels.value = true
|
||||
modelsError.value = ''
|
||||
try {
|
||||
const r = await fetch('/api/embed-bench/models')
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
const data = await r.json() as { models: OllamaModel[] }
|
||||
availableModels.value = data.models
|
||||
} catch (err) {
|
||||
modelsError.value = `Failed to load models: ${err}`
|
||||
} finally {
|
||||
loadingModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Run ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
async function startRun() {
|
||||
if (!canRun.value) return
|
||||
running.value = true
|
||||
resultEvents.value = []
|
||||
liveMessage.value = 'Starting embedding run…'
|
||||
runController.value = new AbortController()
|
||||
|
||||
try {
|
||||
const resp = await fetch('/api/embed-bench/run', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
corpus: corpus.value,
|
||||
queries: queries.value,
|
||||
models: selectedModels.value,
|
||||
top_k: topK.value,
|
||||
}),
|
||||
signal: runController.value.signal,
|
||||
})
|
||||
|
||||
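// fetch() resolves on HTTP errors too; fail fast before reading the stream.
if (!resp.ok || !resp.body) throw new Error(`HTTP ${resp.status}`)
const reader = resp.body.getReader()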
|
||||
const decoder = new TextDecoder()
|
||||
let buf = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buf += decoder.decode(value, { stream: true })
|
||||
const lines = buf.split('\n')
|
||||
buf = lines.pop() ?? ''
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue
|
||||
const event = JSON.parse(line.slice(6))
|
||||
if (event.type === 'progress') {
|
||||
liveMessage.value = event.msg
|
||||
} else if (event.type === 'result') {
|
||||
resultEvents.value.push(event as ResultEvent)
|
||||
} else if (event.type === 'done') {
|
||||
liveMessage.value = 'Run complete.'
|
||||
} else if (event.type === 'error') {
|
||||
liveMessage.value = `Error: ${event.msg}`
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if ((err as Error).name !== 'AbortError') {
|
||||
liveMessage.value = `Run failed: ${err}`
|
||||
}
|
||||
} finally {
|
||||
running.value = false
|
||||
runController.value = null
|
||||
}
|
||||
}
|
||||
|
||||
function cancelRun() {
|
||||
runController.value?.abort()
|
||||
liveMessage.value = 'Run cancelled.'
|
||||
}
|
||||
|
||||
// ── Utilities ─────────────────────────────────────────────────────────────────
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (bytes < 1_000_000) return `${(bytes / 1000).toFixed(0)} KB`
|
||||
if (bytes < 1_000_000_000) return `${(bytes / 1_000_000).toFixed(0)} MB`
|
||||
return `${(bytes / 1_000_000_000).toFixed(1)} GB`
|
||||
}
|
||||
|
||||
function getHit(queryIdx: number, model: string, rank: number): HitResult | null {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
if (!query) return null
|
||||
const ev = resultEvents.value.find(e => e.query === query && e.model === model)
|
||||
return ev?.hits[rank] ?? null
|
||||
}
|
||||
|
||||
function getRating(queryIdx: number, model: string, chunkIdx: number): string | undefined {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
return ratings.value[query]?.[model]?.[chunkIdx]
|
||||
}
|
||||
|
||||
async function rate(
|
||||
queryIdx: number,
|
||||
model: string,
|
||||
hit: HitResult,
|
||||
rating: 'relevant' | 'not_relevant',
|
||||
) {
|
||||
const query = uniqueQueries.value[queryIdx]
|
||||
// Optimistic update
|
||||
if (!ratings.value[query]) ratings.value[query] = {}
|
||||
if (!ratings.value[query][model]) ratings.value[query][model] = {}
|
||||
ratings.value[query][model][hit.chunk_idx] = rating
|
||||
|
||||
try {
|
||||
const r = await fetch('/api/embed-bench/rate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
query,
model,
chunk_text: hit.text,
chunk_idx: hit.chunk_idx,
rating,
}),
})
// fetch() only rejects on network failure, so surface HTTP errors explicitly.
if (!r.ok) throw new Error(`HTTP ${r.status}`)
|
||||
liveMessage.value = `Rated chunk ${hit.chunk_idx + 1} as ${rating}.`
|
||||
} catch (err) {
|
||||
liveMessage.value = `Rating failed: ${err}`
|
||||
}
|
||||
}
|
||||
|
||||
async function exportRatings() {
|
||||
const r = await fetch(`/api/embed-bench/export?format=${exportFormat.value}`)
|
||||
if (!r.ok) {
|
||||
liveMessage.value = `Export failed: HTTP ${r.status}`
|
||||
return
|
||||
}
|
||||
const blob = await r.blob()
|
||||
const disposition = r.headers.get('Content-Disposition') ?? ''
|
||||
const filenameMatch = disposition.match(/filename="([^"]+)"/)
|
||||
const filename = filenameMatch ? filenameMatch[1] : `embed_comparison.${exportFormat.value}`
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = filename
|
||||
a.click()
|
||||
URL.revokeObjectURL(url)
|
||||
liveMessage.value = `Exported ${filename}.`
|
||||
}
|
||||
|
||||
// ── Lifecycle ─────────────────────────────────────────────────────────────────
|
||||
|
||||
onMounted(() => {
|
||||
loadModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.embed-compare-page {
|
||||
padding: var(--space-4, 1.5rem);
|
||||
max-width: 1100px;
|
||||
}
|
||||
|
||||
/* Step indicator */
|
||||
.step-indicator {
|
||||
display: flex;
|
||||
gap: 0;
|
||||
list-style: none;
|
||||
margin: 0 0 var(--space-4, 1.5rem);
|
||||
padding: 0;
|
||||
border-bottom: 2px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.step-indicator li {
|
||||
padding: 0.4rem 1rem;
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-muted, #4a5c7a);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -2px;
|
||||
}
|
||||
.step-indicator li.complete {
|
||||
color: var(--app-primary, #2A6080);
|
||||
border-bottom-color: var(--app-primary, #2A6080);
|
||||
}
|
||||
|
||||
/* Accessibility: screen-reader live region — visually hidden but always present */
|
||||
.sr-live {
|
||||
position: absolute;
|
||||
width: 1px; height: 1px;
|
||||
overflow: hidden;
|
||||
clip: rect(0 0 0 0);
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* Cards */
|
||||
.card {
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-md, 0.5rem);
|
||||
padding: var(--space-4, 1.5rem);
|
||||
margin-bottom: var(--space-4, 1.5rem);
|
||||
}
|
||||
.card h2 {
|
||||
font-size: 1rem;
|
||||
font-weight: 700;
|
||||
margin: 0 0 var(--space-3, 1rem);
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.field { display: flex; flex-direction: column; gap: 0.3rem; margin-bottom: 0.75rem; }
|
||||
.field label { font-size: 0.85rem; font-weight: 600; }
|
||||
textarea, input[type="number"] {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
padding: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
.corpus-controls { display: flex; flex-direction: column; gap: 0.5rem; }
|
||||
.import-row {
|
||||
display: flex; flex-wrap: wrap; gap: 0.5rem; align-items: center;
|
||||
}
|
||||
.import-row label { font-size: 0.85rem; font-weight: 600; }
|
||||
.corpus-count, .query-count { font-size: 0.875rem; color: var(--app-primary, #2A6080); margin: 0; }
|
||||
|
||||
.model-list { list-style: none; padding: 0; margin: 0; display: flex; flex-wrap: wrap; gap: 0.5rem; }
|
||||
.model-checkbox {
|
||||
display: flex; align-items: center; gap: 0.4rem;
|
||||
font-size: 0.875rem; cursor: pointer;
|
||||
padding: 0.3rem 0.6rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
}
|
||||
.model-size { font-size: 0.75rem; }
|
||||
|
||||
.run-row { display: flex; flex-wrap: wrap; gap: 0.75rem; align-items: flex-end; }
|
||||
.field-inline { display: flex; align-items: center; gap: 0.4rem; }
|
||||
.field-inline label { font-size: 0.85rem; font-weight: 600; white-space: nowrap; }
|
||||
|
||||
.btn-primary, .btn-secondary, .btn-danger {
|
||||
padding: 0.4rem 1rem;
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
border: 1px solid transparent;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
.btn-primary { background: var(--app-primary, #2A6080); color: #fff; }
|
||||
.btn-primary:hover:not(:disabled) { filter: brightness(1.1); }
|
||||
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-secondary { background: var(--color-surface, #f0f4fb); color: var(--color-text, #1a2338); border-color: var(--color-border, #d0d7e8); }
|
||||
.btn-secondary:hover:not(:disabled) { background: var(--color-border, #d0d7e8); }
|
||||
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-danger { background: var(--color-error, #c0392b); color: #fff; }
|
||||
|
||||
.muted { color: var(--color-text-muted, #4a5c7a); font-size: 0.875rem; }
|
||||
.error-text { color: var(--color-error, #c0392b); font-size: 0.875rem; }
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.import-row { flex-direction: column; align-items: flex-start; }
|
||||
.run-row { flex-direction: column; }
|
||||
.model-list { flex-direction: column; }
|
||||
}
|
||||
|
||||
/* Results table */
|
||||
.table-wrap { overflow-x: auto; }
|
||||
.results-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.results-table thead th {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 2px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.5rem 0.75rem;
|
||||
text-align: left;
|
||||
font-weight: 700;
|
||||
white-space: nowrap;
|
||||
z-index: 1;
|
||||
}
|
||||
.results-table td {
|
||||
padding: 0.5rem 0.75rem;
|
||||
vertical-align: top;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
.rank-col { width: 2rem; text-align: center; }
|
||||
|
||||
.hit-text { margin-bottom: 0.25rem; line-height: 1.4; }
|
||||
|
||||
.score-row { display: flex; align-items: center; gap: 0.4rem; margin-bottom: 0.25rem; }
|
||||
.score-bar-wrap {
|
||||
flex: 1;
|
||||
height: 6px;
|
||||
background: var(--color-border, #d0d7e8);
|
||||
border-radius: 3px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.score-bar {
|
||||
height: 100%;
|
||||
background: var(--app-primary, #2A6080);
|
||||
border-radius: 3px;
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
.score-label { font-size: 0.75rem; color: var(--color-text-muted, #4a5c7a); min-width: 3rem; text-align: right; }
|
||||
|
||||
.rating-row { display: flex; gap: 0.4rem; flex-wrap: wrap; }
|
||||
.rate-btn {
|
||||
padding: 0.2rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: var(--radius-sm, 0.25rem);
|
||||
background: var(--color-surface, #f0f4fb);
|
||||
color: var(--color-text, #1a2338);
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, border-color 0.15s;
|
||||
}
|
||||
.rate-btn.active {
|
||||
background: color-mix(in srgb, var(--app-primary, #2A6080) 20%, transparent);
|
||||
border-color: var(--app-primary, #2A6080);
|
||||
font-weight: 700;
|
||||
}
|
||||
.rate-btn-neg.active {
|
||||
background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent);
|
||||
border-color: var(--color-error, #c0392b);
|
||||
}
|
||||
|
||||
/* Query nav */
|
||||
.query-nav {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.query-counter { font-size: 0.875rem; flex: 1; }
|
||||
|
||||
/* Export */
|
||||
.export-row { display: flex; gap: 1rem; align-items: center; flex-wrap: wrap; }
|
||||
.export-format-group {
|
||||
border: none;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
.export-format-group legend {
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
margin-bottom: 0.25rem;
|
||||
float: left;
|
||||
margin-right: 0.5rem;
|
||||
}
|
||||
.export-format-group label { font-size: 0.875rem; display: flex; align-items: center; gap: 0.3rem; }
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.results-table thead th,
|
||||
.results-table td { padding: 0.35rem 0.4rem; font-size: 0.8rem; }
|
||||
.query-nav { flex-direction: column; align-items: flex-start; }
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.score-bar { transition: none; }
|
||||
}
|
||||
</style>
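The `startRun` function in the component above reads the response body as a stream, keeps any incomplete trailing line in a buffer, and parses each complete `data: ` line as JSON. The sketch below isolates that buffering step as a pure function. The names `drainSseBuffer` and `StreamEvent` are hypothetical, and skipping malformed payloads is an assumption of this sketch: the original calls `JSON.parse` without a guard.

```typescript
// Hypothetical event shape for illustration; real payloads are defined by the server.
interface StreamEvent { type: string; [key: string]: unknown }

// Splits buffered text into complete lines, parses `data: ` lines as JSON,
// and returns the unparsed tail so the caller can prepend it to the next chunk.
function drainSseBuffer(buf: string): { events: StreamEvent[]; rest: string } {
  const lines = buf.split('\n')
  const rest = lines.pop() ?? '' // last element is an incomplete line (or '')
  const events: StreamEvent[] = []
  for (const line of lines) {
    if (!line.startsWith('data: ')) continue
    try {
      events.push(JSON.parse(line.slice(6)) as StreamEvent)
    } catch {
      // Assumption of this sketch: skip malformed payloads rather than abort the run.
    }
  }
  return { events, rest }
}
```

In a read loop, the caller would append each decoded chunk to `rest`, call `drainSseBuffer` again, and dispatch the returned events by `type`.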
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
<template>
|
||||
<EmbedCompareTab />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import EmbedCompareTab from './EmbedCompareTab.vue'
|
||||
</script>
|
||||
Some files were not shown because too many files have changed in this diff.