Compare commits

...

75 commits
v0.4.0 ... main

Author SHA1 Message Date
c123492a1c docs: add LLM development disclosure to README
Humans own design, architecture, code review, testing, and
verification. LLMs are part of our development workflow.
Links to circuitforge.tech/positions for our full position.
2026-05-28 08:20:17 -07:00
391ebb3cd1 feat(recipe-scan): labeling UI for Kiwi vision training pipeline (closes #65)
- POST /api/recipe-scan/import — bulk ingest from Kiwi scanner pipeline, idempotent by item id
- GET /api/recipe-scan/next — oldest-first pending item for review
- POST /api/recipe-scan/items/{id}/approve|edit|reject — label actions
- GET /api/recipe-scan/stats — counts by status and modality
- GET /api/recipe-scan/export — JSONL training pairs (messages chat format, Option B: correction prompt + extracted draft → corrected ground truth)
- GET /api/recipe-scan/image — path-traversal-safe image serving from /Library/Assets/kiwi/
- SQLite at data/recipe_scan.db with WAL mode; separate from corpus.db lifecycle
- set_db_path() testability seam; 18 tests, all passing
- RecipeScanView.vue: two-column review UI (image left, JSON diff right), keyboard shortcuts A/E/R, toast feedback, stats header, export download
- Route /data/recipe-scan and sidebar nav entry added
2026-05-17 12:22:15 -07:00
9bb88b168f feat(corpus): pipeline log ingest from shared dir (closes #67)
Pull-side companion to kiwi#141. Ingests structured JSONL pipeline logs
from /Library/Assets/logs/pipeline/ into the log corpus for Turnstone
logreading model training.

- app/data/log_corpus.py: add ingested_pipeline_files tracking table,
  _pipeline_ingest_dir() config helper, _ingest_one_file() parser, and
  POST /api/corpus/pipeline-ingest endpoint
- source_host = "pipeline_scrape"; source_id from logger field; extra
  dict stored as matched_patterns; batch_type = "pipeline_log"
- Idempotent by filename: skips files already in ingested_pipeline_files
- config/label_tool.yaml.example: add corpus section with pipeline_ingest_dir
  and push sources comment block
- tests/test_log_corpus.py: 8 new tests covering ingest, idempotency,
  non-JSONL filtering, malformed line resilience, incremental runs
2026-05-17 11:28:33 -07:00
13ca082a43 chore(models): refresh model registries with current cluster catalog
Replace stale llama/mistral/phi model refs with models active on the
cluster: deepseek-r1 (1.5b, 7b-4bit, 0528-qwen3-8b-gguf), granite-4.1-8b,
qwen2.5 (3b, 7b), capybarahermes-2.5-mistral-7b, darwin-9b-opus. Update
benchmark_plans.py doc examples to match.
2026-05-17 11:24:03 -07:00
d416ef8aa4 feat(imitate): task-model assignment routing via cf-orch
Add _resolve_task_model() helper that looks up a product.task assignment
from the coordinator and resolves its service_type from the model registry.
Add task_ids param to run_imitate() (comma-separated "product/task" strings)
so the imitate harness can dispatch to models chosen by the assignment layer
rather than requiring explicit model IDs.
2026-05-17 11:23:55 -07:00
79b9ccbd3d feat(fleet): profile editor, assignments tab, node management polish
Backend:
- app/nodes.py: fix coordinator response envelope (.get("nodes"/"services"))
- app/nodes.py: add PUT /nodes/{id}/profile (atomic YAML write + reload)
- app/nodes.py: add POST /nodes/{id}/profile/generate (coordinator-seeded skeleton)
- tests/test_nodes.py: fix mock envelopes; add deploy model + profile tests

Frontend:
- NodeManagementView: tab bar switching nodes / assignments panels
- AssignmentsTab: full product.task → model routing UI (add/edit/delete)
- ProfileEditorPanel: full YAML profile editor with GPU + service sections
- CatalogEntryFormModal: add/edit model catalog entries per service
- ServiceFormModal: add/edit service config blocks
- NodeCard, GpuRow, ServiceBadge, OllamaModelPanel, HfNodeModelPanel: polish pass
- ModelsView: model download additions
- nodes.ts: extend types for full profile editing (ServiceManaged, CatalogEntryFull)
2026-05-17 11:23:47 -07:00
e93afec271 fix(tests): resolve 5 pre-existing test failures on main (closes #56)
- app/models.py: add set_cf_text_models_dir() testability seam
- tests/test_models.py: redirect _CF_TEXT_MODELS_DIR in reset_models_globals
  fixture so list_installed() count tests are not polluted by real NFS models
- app/cforch.py: fix get_results() return type annotation list → dict
- tests/test_cforch.py: give _BENCH_RUNNING=True test a mock proc with
  poll()=None so the stale-flag check correctly returns 409; patch
  _select.select in streaming tests (select requires fileno(), iter() doesn't)
- tests/test_finetune.py: mark GPU integration test @pytest.mark.gpu
- pytest.ini: register gpu and slow markers
2026-05-17 11:21:58 -07:00
cac91dd8a2 docs: bump version badge to match latest Forgejo release 2026-05-17 11:19:13 -07:00
2b990a603a feat: log corpus receiver — accept Turnstone push batches and label for logreading fine-tune
Adds corpus.db (corpus_sources, corpus_batches, corpus_entries), a FastAPI router
at /api/corpus with receive/label/skip/stats/export endpoints, and seeds consent
tokens for xanderland + orchard nodes from label_tool.yaml. PII flag excludes
entries from JSONL export. Closes avocet#61.
2026-05-11 17:07:54 -07:00
9fdaeeb3d6 feat: multi-bench dashboard, API path migration, benchmark reliability fixes
- dashboard: eval card now shows last run + score for all bench types
  (classifier, LLM, style, plans) via new _get_recent_bench_runs()
- dashboard: skip cforch LLM-bench list summaries when scanning for
  classifier best_macro_f1 (fixes _find_latest_classifier_bench)
- cforch: stale _BENCH_RUNNING flag now auto-resets if process exited;
  idle timeout (120s via select) kills hung benchmark if node crashes
- api: add /api/finetune/{run,cancel} backward-compat shims while
  ClassifierTab fine-tune section is migrated to TrainJobsView
- ClassifierTab: migrate all /api/benchmark/* paths to /api/cforch/*;
  fix null-safety on results.models access; load fine-tuned models from
  /api/train/results instead of /api/finetune/status
- CompareTab: extend model picker to include vllm + cf-text alongside
  ollama, grouped by service; pre-select all LLM_SERVICES on load
- LlmEvalTab: null-safety on quality_by_task_type lookups
- models: AVOCET_MODELS_DIR env var overrides default models/ path
2026-05-11 09:05:12 -07:00
71bf88d09b feat: implement results table, rating buttons, export UI, and a11y polish 2026-05-11 08:16:52 -07:00
bc4ca1095c feat: add embed-compare route, sidebar nav entry, and full input UI 2026-05-11 08:14:30 -07:00
b6aed3dd1b chore: add pagepiper imitate entry and embed_bench section to config example 2026-05-11 08:11:30 -07:00
1ad7ba322a feat: add embed-bench rate and export endpoints 2026-05-11 08:07:17 -07:00
32e3b2a0dd feat: add embed-bench run endpoint with SSE streaming 2026-05-07 09:05:34 -07:00
12117ad0c6 fix: narrow exception types in get_models, fix patch targets in tests, add type annotation 2026-05-07 09:03:37 -07:00
5939c67b9f feat: add embed-bench models endpoint and register router in aggregator 2026-05-07 09:01:25 -07:00
5ea77da97d fix: add _cosine dimension guard, fix return type annotation, add zero-vector test 2026-05-07 08:59:24 -07:00
276bdadb92 feat: add embed_bench module scaffold and _cosine() helper 2026-05-07 08:37:18 -07:00
6f9aad126e docs(readme): landing page rewrite — three-stage pipeline explained, full CLI reference, data flow diagram, label table 2026-05-06 08:51:46 -07:00
258bbdc0af chore(deps): fix 10 Dependabot CVEs — vite 7.3.2, defu 6.1.7, yaml 2.8.4, picomatch 4.0.4, undici 7.25.0 2026-05-06 08:41:05 -07:00
32872d1ec6 fix: assigned-only state, remove dead HfNodeModelPanel prop, deduplicate yaml example 2026-05-05 22:11:02 -07:00
1521198cb1 fix: code quality fixes from review (SSE abort, aria-live, shared types, type safety)
- Add AbortController to SSE pull stream in OllamaModelPanel; abort on unmount
- Fix SSE loop: break on success/error events, call fetchModels() after the loop
- Add AbortController to fetchModels() and fetchProfile() one-shot fetches
- Add onUnmounted cleanup to both panel components
- Extract GpuEntry, ServiceInfo, NodeSummary to web/src/types/nodes.ts
- Remove duplicate interface definitions from NodeCard, GpuRow, NodeManagementView
- Fix aria-live regions: persistent container with v-if on inner span (avoids
  screen reader announcement miss on initial mount)
- Tighten STATE_LABELS/STATE_ICONS to Record<ServiceState, string> for exhaustiveness
- Add explicit (await r.json()) as NodeSummary[] cast in fetchNodes()
2026-05-05 21:35:13 -07:00
8dda040480 fix: move /nodes route immediately after /fleet per spec 2026-05-05 21:29:35 -07:00
bf675ed1f6 feat: add OllamaModelPanel and HfNodeModelPanel Vue components 2026-05-05 21:24:38 -07:00
0efd1aedbe feat: add NodeCard, GpuRow, ServiceBadge Vue components 2026-05-05 21:24:32 -07:00
4c225b94f5 feat: add /nodes route, AppSidebar nav item, and NodeManagementView 2026-05-05 21:24:27 -07:00
1cd9c5d455 fix: move json import to module scope in nodes.py 2026-05-05 21:01:32 -07:00
5702a7190b feat: add Ollama list/pull-SSE/delete endpoints 2026-05-05 20:41:29 -07:00
55b017ba3b fix: log coordinator reload failures in update_gpu_services
- Replace bare `except Exception: pass` with `except Exception as exc` and a
  logger.warning call that surfaces node_id and the exception for diagnostics.
- Move `import os as _os` from mid-file (between test functions) to the
  top-level import block to satisfy PEP 8 and linter expectations.
2026-05-05 20:36:08 -07:00
f952ec8971 feat: add profile endpoint and GPU service assignment with compatibility check 2026-05-05 20:33:41 -07:00
fd8cb622a1 feat: add GET /api/nodes-mgmt/nodes/{node_id}/profile endpoint 2026-05-05 20:31:22 -07:00
47cb9f661f fix: narrow exception handling in list_nodes, move mock imports to top
- Remove redundant httpx.ConnectError from nodes except clause (it's a
  subclass of HTTPError so the tuple catch was redundant)
- Narrow services except clause from bare Exception to httpx.HTTPError,
  add logger.warning with coordinator_url for debuggability
- Move `from unittest.mock import MagicMock, patch` from mid-file to
  the top-of-file import block with the other stdlib/third-party imports
2026-05-05 20:18:50 -07:00
c2de9e53da feat: implement GET /api/nodes-mgmt/nodes with coordinator proxy and profile merge 2026-05-05 20:16:06 -07:00
c039ea4698 fix: remove unused imports and em dash in nodes.py scaffold
- Drop unused StreamingResponse import from app/nodes.py (will be
  re-added in Task 2 when the SSE endpoint is implemented)
- Replace em dash with colon in _get_ollama_url HTTPException detail
- Remove unused os and unittest.mock imports from test_nodes.py
  (mock imports will return in Task 2 tests)
2026-05-05 19:59:32 -07:00
95afddb772 feat: add nodes.py scaffold with set_config_dir and router mount
- Create app/nodes.py with _CONFIG_DIR testability seam, _load_config,
  _profiles_dir, _profile_path, _load_profile, _get_ollama_url helpers,
  and stub list_nodes endpoint returning [] when no coordinator_url is set
- Mount nodes router at /api/nodes-mgmt in app/api.py
- Add profiles_dir comment to config/label_tool.yaml.example cforch section
- Create tests/test_nodes.py with autouse fixture and two passing tests
2026-05-05 19:35:28 -07:00
cbe8c0f03e feat(benchmark): wire EmbeddingKNNAdapter into MODEL_REGISTRY; add embed_model config
- Add embed_model: nomic-embed-text to config/label_tool.yaml (local, gitignored)
- Add # embed_model: commented example to config/label_tool.yaml.example
- Add pyyaml>=6.0 to requirements.txt (explicit dep for _resolve_urls yaml.safe_load)
- Add params assertion to test_embed_knn_nomic_registry_entry
2026-05-05 14:05:45 -07:00
5df33b0f41 feat(benchmark): wire EmbeddingKNNAdapter into MODEL_REGISTRY as embed-knn-nomic 2026-05-05 12:43:48 -07:00
41584de5df fix(benchmark): guard empty exemplars, warn on malformed JSON in build_exemplars_from_jsonl 2026-05-05 12:41:46 -07:00
1d4c07e4a0 feat(benchmark): add build_exemplars_from_jsonl() for k-NN seed 2026-05-05 11:43:12 -07:00
e823b5e76d fix(classifier): majority-vote key, partial-load guard, sparse label test 2026-05-05 11:39:24 -07:00
88bc6bed67 feat(classifier): implement EmbeddingKNNAdapter.classify() with k-NN vote 2026-05-05 08:04:54 -07:00
4a64a6686d fix(classifier): atomic embed assignment, logging on orch failure, guard double load 2026-05-05 07:53:15 -07:00
f2f150b4fb feat(classifier): implement EmbeddingKNNAdapter.load() and unload() 2026-05-05 07:12:53 -07:00
72449561cf feat(classifier): add EmbeddingKNNAdapter skeleton and constructor tests 2026-05-05 06:08:21 -07:00
c177fb1628 fix(classifier): quality fixes for DEFAULT_EXEMPLARS — remove forward __all__ entry, tighten tests, fix survey exemplar 2026-05-04 20:03:18 -07:00
3be5055e31 feat(classifier): add DEFAULT_EXEMPLARS for embedding k-NN fallback 2026-05-04 17:44:44 -07:00
78b64d007d feat(classifier): add _cosine() helper for embedding similarity 2026-05-04 17:41:45 -07:00
bce932461a feat: plans benchmark harness — model scoring for CF planning prompts
Adds benchmark_plans.py script, plans_bench API router, PlansBenchTab Vue
component, and registers /api/plans-bench in api.py. Also extends models
registry (cf-text catalog integration), cforch client, LlmEvalTab, and
ModelsView with cf-orch fleet support. Wires Planning mode into BenchmarkView.
2026-05-02 23:36:04 -07:00
e11db5ccd9 fix: align train job/results API envelope, config_json key, progress SSE, dashboard model_key
- GET /api/train/jobs now returns {"jobs":[...]} instead of bare array
- GET /api/train/results now returns {"results":[...]} instead of bare array
- POST /api/train/jobs body key renamed config -> config_json to match Pydantic model
- SSE log handler now handles 'progress' event type (backend never emits 'log')
- Dashboard _get_active_jobs() adds model_key to SELECT and return dict
- corrections.py docstring updated: both /api/corrections and /api/sft prefixes noted
- test_train.py assertions updated to unwrap new envelope shapes
2026-05-02 21:22:18 -07:00
13d1a394d5 fix: add loading state, widen nullable types, add API response guard in TrainResultsView 2026-05-02 20:49:34 -07:00
b077371107 feat: add TrainResultsView with training history table and Fleet registration links 2026-05-02 20:46:03 -07:00
53b25b27ab fix: surface cancel errors, fix SSE sentinel scroll, add missing test coverage in TrainJobsView 2026-05-02 20:33:03 -07:00
e014da2dec feat: add TrainJobsView with job queue, form submission, cancel, and SSE log streaming 2026-05-02 20:28:19 -07:00
c48db45d91 test: fix async flush and add mode-switch coverage in BenchmarkView 2026-05-02 19:35:02 -07:00
d0ba75b995 feat: extract CompareView at /eval/compare; remove Compare tab from BenchmarkView 2026-05-02 18:03:13 -07:00
a134af8b7b feat: add DashboardView with flywheel stage cards and CTA nudges 2026-05-02 16:50:24 -07:00
6ef6f06023 feat: restructure AppSidebar into two-domain nav with section headers and flywheel signal badges 2026-05-02 13:52:45 -07:00
5bdb095235 feat: restructure router into /data/* /eval/* /train/* domains with backward-compat redirects
- Export named `routes` array from router/index.ts for testability
- Move label/fetch/corrections/imitate under /data/* namespace
- Move benchmark/compare under /eval/* namespace
- Add /train/jobs and /train/results under /train/* namespace
- Add / -> DashboardView and /fleet -> ModelsView (replaces old / -> LabelView)
- Add backward-compat redirects for all old flat paths (/benchmark, /models, /stats, /label, /fetch, /corrections, /imitate)
- Add stub views for DashboardView, CompareView, TrainJobsView, TrainResultsView (implemented in later tasks)
- Add router.test.ts: 16 tests covering route structure and redirect targets
2026-05-02 13:00:04 -07:00
0904967320 feat: slim api.py to factory-only; all domain routes in dedicated modules
Replace 149-line api.py (with inline helpers, JSONL utilities, and ad-hoc
router registrations) with a 57-line pure factory. All business logic was
already extracted to domain modules in B1-B7; this removes the dead code
and adds the /api/corrections/* prefix alongside the /api/sft/* backward-
compat alias. Smoke tests updated to cover the new /api/corrections/ingest
and /api/dashboard routes.
2026-05-02 09:55:58 -07:00
8fda821e15 feat: add POST /ingest endpoint to corrections API with Bearer auth
Adds IngestRequest model and POST /api/sft/ingest route to
app/data/corrections.py. Sibling CF products (Peregrine, Kiwi, etc.)
can push pre-approved corrections via Bearer token auth
(AVOCET_INGESTION_SECRET). Records land as status=approved in both
sft_candidates.jsonl and sft_approved.jsonl immediately.

7 tests in tests/test_data_corrections.py cover 503 (secret unset),
401 (missing/malformed header), 403 (wrong secret), happy-path writes
to both files, and optional label field.
2026-05-02 09:07:10 -07:00
0853ed7d56 fix: add logger.warning to silent except blocks in dashboard._find_latest_eval 2026-05-01 23:36:19 -07:00
aa742bcfc0 feat: add GET /api/dashboard flywheel aggregate endpoint 2026-05-01 23:30:04 -07:00
32d3436bbd fix: path traversal guard, python_bin config, completed_at on Popen failure 2026-05-01 23:24:00 -07:00
766fbafa02 feat: build SQLite-backed train job queue in app/train/train.py
Replaces the ad-hoc _running_procs dict in api.py with a persistent,
inspectable SQLite job queue. Removes old /api/finetune/* routes and
_best_cuda_device from api.py. Adds /api/train/* routes (list, create,
get, cancel, run SSE, results). 16 new tests all passing.
2026-05-01 23:05:11 -07:00
d432026fd7 fix: restore real plans_bench.py (was accidentally stubbed) 2026-05-01 22:25:22 -07:00
bccb385f61 feat: build app/eval/cforch.py aggregating eval benchmark routers 2026-05-01 22:23:06 -07:00
d74ad3f972 feat: move imitate API into app/data/imitate.py 2026-05-01 22:12:19 -07:00
99ea39fe38 feat: move SFT corrections API into app/data/corrections.py 2026-05-01 22:02:22 -07:00
2054866ff1 feat: extract fetch routes and IMAP helpers into app/data/fetch.py 2026-05-01 21:57:31 -07:00
cbec776ef1 fix: restore ensure_ascii=False in utils jsonl helpers; remove dead _last_action from api.py 2026-05-01 20:59:44 -07:00
167d7351e3 feat: extract label queue API into app/data/label.py 2026-05-01 18:48:14 -07:00
6689ff07b1 chore: gitignore .worktrees/ directory 2026-05-01 12:25:23 -07:00
0745bc3f70 refactor: import detect_byok from cf-core, remove local copy 2026-04-25 16:45:47 -07:00
2891606765 feat(cloud_session): add session resolution + forward user_id to cf-orch imitate
app/cloud_session.py:
- Thin wrapper around cf_core.cloud_session.CloudSessionFactory
- BYOK detection reads ~/.config/circuitforge/llm.yaml (same path as other products)
- get_session: FastAPI dependency, returns CloudUser (user_id, tier, has_byok)
- require_tier: dependency factory for tier-gated routes

app/imitate.py:
- _run_cftext gains user_id: str | None param; non-None values included in
  the cf-orch ServiceAllocateRequest so premium users get their custom models
- run_imitate injects session via Depends(_get_imitate_session); extracts user_id,
  filters out local/anon sessions (they get the shared catalog), passes real
  cloud user_id to the ThreadPoolExecutor fanout
- _get_imitate_session wraps get_session with a try/except so imitate keeps
  working in envs where cloud_session deps aren't installed
2026-04-24 16:41:45 -07:00
85 changed files with 17511 additions and 2544 deletions

View file

@ -17,3 +17,7 @@ CF_LICENSE_KEY=CFG-AVCT-xxxx-xxxx-xxxx
# Set one of these to use a cloud LLM instead of a local model.
# ANTHROPIC_API_KEY=sk-ant-...
# OPENAI_API_KEY=sk-...
# ── HuggingFace (required for gated/terms-restricted model downloads) ─────────
# Generate at https://huggingface.co/settings/tokens and accept model terms first.
# HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx

7
.gitignore vendored
View file

@ -8,6 +8,9 @@ __pycache__/
config/label_tool.yaml
# Data files (user-generated, not for version control)
data/corpus.db
data/corpus.db-wal
data/corpus.db-shm
data/email_score.jsonl
data/email_label_queue.jsonl
data/email_compare_sample.jsonl
@ -20,3 +23,7 @@ data/sft_approved.jsonl
# Claude context — BSL 1.1, keep out of version control
CLAUDE.md
docs/superpowers/
.superpowers/
# Git worktrees
.worktrees/

183
README.md
View file

@ -1,22 +1,120 @@
# Avocet — Email Classifier Training Tool
<div align="center">
<img src="docs/avocet-logo.svg" alt="Avocet" height="96" />
> *Part of the CircuitForge LLC internal infrastructure suite.*
# Avocet
**Status:** Internal beta — label tool and benchmark harness complete. Used to build training data for Peregrine's email classifier.
**Email classifier training tool — label, benchmark, fine-tune.**
[![Status: Internal Beta](https://img.shields.io/badge/status-internal%20beta-blue)]()
[![Version](https://img.shields.io/badge/version-0.5.0-green)](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet/releases)
[![License: BSL 1.1](https://img.shields.io/badge/license-BSL%201.1-orange)](LICENSE)
[![Stack: Vue 3 + FastAPI](https://img.shields.io/badge/stack-Vue%203%20%2B%20FastAPI-brightgreen)]()
[![CircuitForge](https://img.shields.io/badge/by-CircuitForge-black)](https://circuitforge.tech)
</div>
---
## What it does
## What is Avocet?
Avocet is the data pipeline for building and benchmarking email classifiers. It has two layers:
Avocet is the internal data pipeline Circuit Forge uses to build, evaluate, and fine-tune email classifiers. It implements a three-stage workflow: human labelers review emails one at a time in a drag-to-bucket UI and produce a ground-truth dataset; the benchmark harness scores any number of HuggingFace zero-shot models against that dataset and produces a ranked comparison; and the fine-tune harness adapts the best-scoring base model to the labeled distribution. The output feeds directly into Peregrine's email classification layer. No LLM API key required for the label tool or benchmark — all inference runs locally via HuggingFace Transformers.
**No LLM required.** Avocet uses zero-shot HuggingFace classification models — no API key, no cloud inference, no GPU required for the label tool. The benchmark harness can optionally export LLM-labeled emails from a Peregrine staging DB, but human labeling via the card-stack UI is the primary workflow.
---
**Layer 1 — Label tool**
Card-stack UI for building ground-truth classifier benchmark data. Fetch emails from one or more IMAP accounts (with targeted date-range and sender/subject filters), review them card-by-card, and label each with a job-search category. Labeled output feeds the benchmark harness.
## Quick Start
**Layer 2 — Benchmark harness**
Scores HuggingFace zero-shot classification models against the labeled dataset. Supports slow/large model inclusion, visual side-by-side comparison on live emails, and export of LLM-labeled emails from a Peregrine staging DB.
```bash
git clone https://git.opensourcesolarpunk.com/Circuit-Forge/avocet.git
cd avocet
# Copy config template and fill in your IMAP credentials
cp config/label_tool.yaml.example config/label_tool.yaml
# Start the label tool (Vue SPA + FastAPI, port 8503)
./manage.sh start
./manage.sh open
```
---
## Features
- **Drag-to-bucket label UI** — ASMR-style card interface; drag emails into labeled buckets or discard without queuing noise into the training set
- **Targeted IMAP fetch** — pull emails by date range, sender, or subject filter across multiple accounts without flooding the queue
- **Email classifier benchmark** — score any HuggingFace zero-shot model against your labeled JSONL; side-by-side comparison on live IMAP emails
- **Planning benchmark** — evaluate LLMs on structured planning tasks; compare models head-to-head with verbose diff output
- **Writing style benchmark** — compare Ollama models on writing style coherence; scan local disk for existing outputs
- **Fine-tune harness** — HuggingFace Transformers fine-tuning from labeled ground truth; classifier adapter interface for swapping backends at runtime
- **Local inference first** — no API key required; GPU optional; designed to run on developer hardware
- **Hot-reload dev mode** — uvicorn `--reload` + Vite HMR (hot module replacement) for fast iteration on both API and UI
---
## CLI Reference
All operations go through `manage.sh`.
### Label Tool
```bash
./manage.sh start # Build Vue SPA and start FastAPI on port 8503
./manage.sh stop # Stop FastAPI server
./manage.sh restart # Stop, rebuild, and restart
./manage.sh status # Show running state and port
./manage.sh logs # Tail the API log
./manage.sh open # Open http://localhost:8503 in browser
./manage.sh dev # Hot-reload: uvicorn --reload + Vite HMR
./manage.sh test # Run pytest suite
```
### Email Classifier Benchmark
```bash
./manage.sh benchmark [args] # Run benchmark_classifier.py
./manage.sh list-models # List available zero-shot models
./manage.sh score # Score models against labeled JSONL
./manage.sh score --include-slow # Include large/slow models
./manage.sh compare --limit 30 # Side-by-side comparison on live IMAP emails
```
### Planning Benchmark
```bash
./manage.sh plans-bench [args] # Run benchmark_plans.py
./manage.sh plans-list # List available models
./manage.sh plans-run <model> [args] # Run a single model (verbose)
./manage.sh plans-compare <m1> <m2> [...] # Compare models side-by-side
```
### Writing Style Benchmark
```bash
./manage.sh style-bench [args] # Run benchmark_style.py
./manage.sh style-list # List available Ollama models
./manage.sh style-run [args] # Run writing style benchmark
./manage.sh style-last # Print most recent benchmark report
```
---
## Data Flow
```
IMAP accounts
→ fetch (targeted or wide)
→ email_label_queue.jsonl
email_label_queue.jsonl
→ label tool drag-to-bucket UI
→ email_score.jsonl (ground truth)
email_score.jsonl
→ benchmark harness
→ model rankings
best model
→ fine-tune harness
→ Peregrine classifier adapter
```
---
@ -38,69 +136,42 @@ Scores HuggingFace zero-shot classification models against the labeled dataset.
## Stack
| Layer | Tech |
|-------|------|
| Label UI | Streamlit (port 8503, auto-increments on collision) |
| Layer | Technology |
|-------|-----------|
| Label UI | Vue 3 SPA (Vite) |
| API | FastAPI + uvicorn (port 8503) |
| Benchmark | Python + HuggingFace Transformers |
| Email fetch | IMAP (multi-account, targeted date/sender/subject filter) |
| Data | JSONL (`data/email_label_queue.jsonl`, `data/email_score.jsonl`) |
| Config | `config/label_tool.yaml` (gitignored — see `.example`) |
Conda environments:
- `job-seeker` — label tool UI
- `job-seeker-classifiers` — benchmark harness (separate env for heavy deps)
| Runtime | SQLite |
| Config | `config/label_tool.yaml` (gitignored — `.example` committed) |
---
## Running
## Logo
```bash
./manage.sh start # start label tool UI (port collision-safe from 8503)
./manage.sh stop # stop
./manage.sh restart # restart
./manage.sh status # show running state and port
./manage.sh logs # tail label tool log
./manage.sh open # open in browser
```
Benchmark:
```bash
./manage.sh benchmark --list-models # list available zero-shot models
./manage.sh score # score models against labeled JSONL
./manage.sh score --include-slow # include large/slow models
./manage.sh compare --limit 30 # visual comparison on live IMAP emails
```
Dev:
```bash
./manage.sh test # run pytest suite
```
The Avocet logo (`avocet_v1_poly.svg`) lives in the shared graphics repo. Copy it to `docs/avocet-logo.svg` to render correctly in this README.
---
## Data flow
## About
```
IMAP accounts → fetch (targeted or wide) → email_label_queue.jsonl
→ label tool card UI → email_score.jsonl
→ benchmark harness → model rankings
→ best model → Peregrine classifier adapter
```
Avocet is internal CircuitForge infrastructure, open source as a reference implementation. It is not a user-facing product. The primary consumer is [Peregrine](https://git.opensourcesolarpunk.com/Circuit-Forge/peregrine), CircuitForge's job-search pipeline tool.
Targeted fetch: date range + sender/subject filter for pulling historical emails on specific senders or topics without flooding the queue.
Docs: [docs.circuitforge.tech/avocet](https://docs.circuitforge.tech/avocet)
Discard: removes an email from the queue without writing to the score file — for emails that don't belong in the training set.
## Forgejo-primary
---
## Classifier adapters
`app/classifier_adapters.py` provides a common interface for swapping classifier backends. Falls back to the label name when no `LABEL_DESCRIPTIONS` entry is configured for a label (RerankerAdapter).
Avocet is developed and maintained on Forgejo at [git.opensourcesolarpunk.com/Circuit-Forge/avocet](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet). GitHub and Codeberg are read-only mirrors.
---
## License
BSL 1.1 — internal tool, not user-facing.
[Business Source License 1.1](LICENSE) — classifier training is an AI feature under the CircuitForge licensing model.
© 2026 Circuit Forge LLC
Free for personal non-commercial self-hosting. Commercial use or SaaS re-hosting requires a paid license. Converts to MIT after 4 years.
Humans own design, architecture, code review, testing, and verification. LLMs are part of our development workflow. [Our positions on LLM use →](https://circuitforge.tech/positions)
© 2026 Circuit Forge LLC — Privacy · Safety · Accessibility

View file

@ -1,623 +1,95 @@
"""Avocet — FastAPI REST layer.
"""Avocet -- FastAPI app factory.
JSONL read/write helpers and FastAPI app instance.
Endpoints and static file serving are added in subsequent tasks.
Mounts all domain routers and serves the Vue SPA.
All business logic lives in the domain modules below.
"""
from __future__ import annotations
import hashlib
import json
import os
import subprocess as _subprocess
import yaml
from pathlib import Path
from datetime import datetime, timezone
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
_ROOT = Path(__file__).parent.parent
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
_CONFIG_DIR: Path | None = None # None = use real path
# Process registry for running jobs — used by cancel endpoints.
# Keys: "benchmark" | "finetune". Values: the live Popen object.
_running_procs: dict = {}
_cancelled_jobs: set = set()
def set_data_dir(path: Path) -> None:
"""Override data directory — used by tests."""
global _DATA_DIR
_DATA_DIR = path
def _best_cuda_device() -> str:
"""Return the index of the GPU with the most free VRAM as a string.
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
if nvidia-smi is unavailable or no GPUs are found. Restricting the
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
PyTorch DataParallel from replicating the model across all GPUs, which
would OOM the GPU with less headroom.
"""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True,
timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
def set_models_dir(path: Path) -> None:
"""Override models directory — used by tests."""
global _MODELS_DIR
_MODELS_DIR = path
def set_config_dir(path: Path | None) -> None:
"""Override config directory — used by tests."""
global _CONFIG_DIR
_CONFIG_DIR = path
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def reset_last_action() -> None:
"""Reset undo state — used by tests."""
global _last_action
_last_action = None
def _queue_file() -> Path:
return _DATA_DIR / "email_label_queue.jsonl"
def _score_file() -> Path:
return _DATA_DIR / "email_score.jsonl"
def _discarded_file() -> Path:
return _DATA_DIR / "discarded.jsonl"
def _read_jsonl(path: Path) -> list[dict]:
if not path.exists():
return []
lines = path.read_text(encoding="utf-8").splitlines()
return [json.loads(l) for l in lines if l.strip()]
def _write_jsonl(path: Path, records: list[dict]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
text = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
path.write_text(text + "\n" if records else "", encoding="utf-8")
def _append_jsonl(path: Path, record: dict) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
def _item_id(item: dict) -> str:
"""Stable content-hash ID — matches label_tool.py _entry_key dedup logic."""
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
def _normalize(item: dict) -> dict:
"""Normalize JSONL item to the Vue frontend schema.
label_tool.py stores: subject, body, from_addr, date, account (no id).
The Vue app expects: id, subject, body, from, date, source.
Both old (from_addr/account) and new (from/source) field names are handled.
"""
return {
"id": item.get("id") or _item_id(item),
"subject": item.get("subject", ""),
"body": item.get("body", ""),
"from": item.get("from") or item.get("from_addr", ""),
"date": item.get("date", ""),
"source": item.get("source") or item.get("account", ""),
}
from fastapi import FastAPI
app = FastAPI(title="Avocet API")
from app.sft import router as sft_router
app.include_router(sft_router, prefix="/api/sft")
# -- Domain routers --------------------------------------------------------
from app.models import router as models_router
import app.models as _models_module
app.include_router(models_router, prefix="/api/models")
from app.data.label import router as label_router
app.include_router(label_router, prefix="/api")
from app.cforch import router as cforch_router
app.include_router(cforch_router, prefix="/api/cforch")
from app.data.fetch import router as fetch_router
app.include_router(fetch_router, prefix="/api")
from app.imitate import router as imitate_router
from app.data.corrections import router as corrections_router
app.include_router(corrections_router, prefix="/api/corrections")
# Backward-compat alias -- remove when Vue SPA is updated to /api/corrections/*
app.include_router(corrections_router, prefix="/api/sft")
from app.data.imitate import router as imitate_router
app.include_router(imitate_router, prefix="/api/imitate")
from app.style import router as style_router
app.include_router(style_router, prefix="/api/style")
from app.eval.cforch import router as eval_router
app.include_router(eval_router, prefix="/api")
from app.train.train import router as train_router
app.include_router(train_router, prefix="/api/train")
from app.plans_bench import router as plans_bench_router
app.include_router(plans_bench_router, prefix="/api/plans-bench")
# In-memory last-action store (single user, local tool — in-memory is fine)
_last_action: dict | None = None
# -- Backward-compat shims (ClassifierTab still uses old /api/finetune/* paths)
# Remove once ClassifierTab fine-tune section is migrated to TrainJobsView.
@app.get("/api/queue")
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
items = _read_jsonl(_queue_file())
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
class LabelRequest(BaseModel):
id: str
label: str
@app.post("/api/label")
def post_label(req: LabelRequest):
global _last_action
items = _read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": req.label,
"labeled_at": datetime.now(timezone.utc).isoformat()}
_append_jsonl(_score_file(), record)
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "label", "item": match, "label": req.label}
return {"ok": True}
class SkipRequest(BaseModel):
id: str
@app.post("/api/skip")
def post_skip(req: SkipRequest):
global _last_action
items = _read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
_write_jsonl(_queue_file(), reordered)
_last_action = {"type": "skip", "item": match}
return {"ok": True}
class DiscardRequest(BaseModel):
id: str
@app.post("/api/discard")
def post_discard(req: DiscardRequest):
global _last_action
items = _read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": "__discarded__",
"discarded_at": datetime.now(timezone.utc).isoformat()}
_append_jsonl(_discarded_file(), record)
_write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "discard", "item": match}
return {"ok": True}
@app.delete("/api/label/undo")
def delete_undo():
global _last_action
if not _last_action:
raise HTTPException(404, "No action to undo")
action = _last_action
item = action["item"] # always the original clean queue item
# Perform file operations FIRST — only clear _last_action on success
if action["type"] == "label":
records = _read_jsonl(_score_file())
if not records:
raise HTTPException(409, "Score file is empty — cannot undo label")
_write_jsonl(_score_file(), records[:-1])
items = _read_jsonl(_queue_file())
_write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "discard":
records = _read_jsonl(_discarded_file())
if not records:
raise HTTPException(409, "Discarded file is empty — cannot undo discard")
_write_jsonl(_discarded_file(), records[:-1])
items = _read_jsonl(_queue_file())
_write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "skip":
items = _read_jsonl(_queue_file())
item_id = _normalize(item)["id"]
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
_write_jsonl(_queue_file(), items)
# Clear AFTER all file operations succeed
_last_action = None
return {"undone": {"type": action["type"], "item": _normalize(item)}}
# Label metadata — 10 labels matching label_tool.py
_LABEL_META = [
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
{"name": "rejected", "emoji": "\u274c", "color": "#F44336", "key": "3"},
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
{"name": "neutral", "emoji": "\u2b1c", "color": "#607D8B", "key": "6"},
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
]
@app.get("/api/config/labels")
def get_labels():
return _LABEL_META
@app.get("/api/config")
def get_config():
f = _config_file()
if not f.exists():
return {"accounts": [], "max_per_account": 500}
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
class ConfigPayload(BaseModel):
accounts: list[dict]
max_per_account: int = 500
@app.post("/api/config")
def post_config(payload: ConfigPayload):
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
encoding="utf-8")
tmp.rename(f)
return {"ok": True}
@app.get("/api/stats")
def get_stats():
records = _read_jsonl(_score_file())
counts: dict[str, int] = {}
for r in records:
lbl = r.get("label", "")
if lbl:
counts[lbl] = counts.get(lbl, 0) + 1
benchmark_results: dict = {}
benchmark_path = _DATA_DIR / "benchmark_results.json"
if benchmark_path.exists():
try:
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
except Exception:
pass
return {
"total": len(records),
"counts": counts,
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
"benchmark_results": benchmark_results,
}
@app.get("/api/stats/download")
def download_stats():
from fastapi.responses import FileResponse
if not _score_file().exists():
raise HTTPException(404, "No score file")
return FileResponse(
str(_score_file()),
filename="email_score.jsonl",
media_type="application/jsonlines",
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
)
class AccountTestRequest(BaseModel):
account: dict
@app.post("/api/accounts/test")
def test_account(req: AccountTestRequest):
from app.imap_fetch import test_connection
ok, message, count = test_connection(req.account)
return {"ok": ok, "message": message, "count": count}
from fastapi.responses import StreamingResponse
# ---------------------------------------------------------------------------
# Benchmark endpoints
# ---------------------------------------------------------------------------
@app.get("/api/benchmark/models")
def get_benchmark_models() -> dict:
"""Return installed models grouped by adapter_type category."""
models_dir: Path = _models_module._MODELS_DIR
categories: dict[str, list[dict]] = {
"ZeroShotAdapter": [],
"RerankerAdapter": [],
"GenerationAdapter": [],
"Unknown": [],
}
if models_dir.exists():
for sub in models_dir.iterdir():
if not sub.is_dir():
continue
info_path = sub / "model_info.json"
adapter_type = "Unknown"
repo_id: str | None = None
if info_path.exists():
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown"
repo_id = info.get("repo_id")
except Exception:
pass
bucket = adapter_type if adapter_type in categories else "Unknown"
entry: dict = {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type}
categories[bucket].append(entry)
return {"categories": categories}
@app.get("/api/benchmark/results")
def get_benchmark_results():
"""Return the most recently saved benchmark results, or an empty envelope."""
path = _DATA_DIR / "benchmark_results.json"
if not path.exists():
return {"models": {}, "sample_count": 0, "timestamp": None}
return json.loads(path.read_text())
@app.get("/api/benchmark/run")
def run_benchmark(include_slow: bool = False, model_names: str = ""):
"""Spawn the benchmark script and stream stdout as SSE progress events."""
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
cmd = [python_bin, script, "--score", "--save"]
if include_slow:
cmd.append("--include-slow")
if model_names:
names = [n.strip() for n in model_names.split(",") if n.strip()]
if names:
cmd.extend(["--models"] + names)
def generate():
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
)
_running_procs["benchmark"] = proc
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0:
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
elif "benchmark" in _cancelled_jobs:
_cancelled_jobs.discard("benchmark")
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_running_procs.pop("benchmark", None)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
# ---------------------------------------------------------------------------
# Finetune endpoints
# ---------------------------------------------------------------------------
@app.get("/api/finetune/status")
def get_finetune_status():
"""Scan models/ for training_info.json files. Returns [] if none exist."""
models_dir = _MODELS_DIR
if not models_dir.exists():
return []
results = []
for sub in models_dir.iterdir():
if not sub.is_dir():
continue
info_path = sub / "training_info.json"
if not info_path.exists():
continue
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception:
pass
return results
from fastapi import Query
from fastapi.responses import StreamingResponse as _StreamingResponse
@app.get("/api/finetune/run")
def run_finetune_endpoint(
model: str = "deberta-small",
epochs: int = 5,
score: list[str] = Query(default=[]),
):
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
script = str(_ROOT / "scripts" / "finetune_classifier.py")
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
data_root = _DATA_DIR.resolve()
for score_file in score:
resolved = (_DATA_DIR / score_file).resolve()
if not str(resolved).startswith(str(data_root)):
raise HTTPException(400, f"Invalid score path: {score_file!r}")
cmd.extend(["--score", str(resolved)])
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
# single device prevents DataParallel from replicating the model across all
# GPUs, which would force a full copy onto the more memory-constrained device.
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
best_gpu = _best_cuda_device()
if best_gpu:
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
def generate():
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
env=proc_env,
)
_running_procs["finetune"] = proc
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0:
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
elif "finetune" in _cancelled_jobs:
_cancelled_jobs.discard("finetune")
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_running_procs.pop("finetune", None)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.post("/api/benchmark/cancel")
def cancel_benchmark():
"""Kill the running benchmark subprocess. 404 if none is running."""
proc = _running_procs.get("benchmark")
if proc is None:
raise HTTPException(404, "No benchmark is running")
_cancelled_jobs.add("benchmark")
proc.terminate()
try:
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
proc.kill()
return {"status": "cancelled"}
def finetune_run_compat(model: str = Query(...), epochs: int = Query(5)) -> _StreamingResponse:
"""Shim: create a classifier train job and immediately stream it."""
from app.train.train import create_job, run_job, CreateJobRequest
job = create_job(CreateJobRequest(type="classifier", model_key=model, config_json={"epochs": epochs}))
return run_job(job["id"])
@app.post("/api/finetune/cancel")
def cancel_finetune():
"""Kill the running fine-tune subprocess. 404 if none is running."""
proc = _running_procs.get("finetune")
if proc is None:
raise HTTPException(404, "No finetune is running")
_cancelled_jobs.add("finetune")
proc.terminate()
try:
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
proc.kill()
return {"status": "cancelled"}
def finetune_cancel_compat() -> dict:
"""Shim: cancel the most recent running classifier job."""
from app.train.train import _db, _init_db, cancel_job
from fastapi import HTTPException
_init_db()
with _db() as conn:
row = conn.execute(
"SELECT id FROM jobs WHERE type='classifier' AND status='running' ORDER BY started_at DESC LIMIT 1"
).fetchone()
if row is None:
return {"status": "nothing_running"}
return cancel_job(row["id"])
from app.data.log_corpus import router as log_corpus_router
app.include_router(log_corpus_router, prefix="/api/corpus")
@app.get("/api/fetch/stream")
def fetch_stream(
accounts: str = Query(default=""),
days_back: int = Query(default=90, ge=1, le=365),
limit: int = Query(default=150, ge=1, le=1000),
mode: str = Query(default="wide"),
):
from app.imap_fetch import fetch_account_stream
from app.data.recipe_scan import router as recipe_scan_router
app.include_router(recipe_scan_router, prefix="/api/recipe-scan")
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
config = get_config() # reuse existing endpoint logic
selected = [a for a in config["accounts"] if a.get("name") in selected_names]
from app.dashboard import router as dashboard_router
app.include_router(dashboard_router, prefix="/api")
def generate():
known_keys = {_item_id(x) for x in _read_jsonl(_queue_file())}
total_added = 0
from app.models import router as models_router
app.include_router(models_router, prefix="/api/models")
for acc in selected:
try:
batch_emails: list[dict] = []
for event in fetch_account_stream(acc, days_back, limit, known_keys):
if event["type"] == "done":
batch_emails = event.pop("emails", [])
total_added += event["added"]
yield f"data: {json.dumps(event)}\n\n"
# Write new emails to queue after each account
if batch_emails:
existing = _read_jsonl(_queue_file())
_write_jsonl(_queue_file(), existing + batch_emails)
except Exception as exc:
error_event = {"type": "error", "account": acc.get("name", "?"),
"message": str(exc)}
yield f"data: {json.dumps(error_event)}\n\n"
from app.nodes import router as nodes_router
app.include_router(nodes_router, prefix="/api/nodes-mgmt")
queue_size = len(_read_jsonl(_queue_file()))
complete = {"type": "complete", "total_added": total_added, "queue_size": queue_size}
yield f"data: {json.dumps(complete)}\n\n"
# -- Static SPA -- MUST be last (catches all unmatched paths) ---------------
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
# Static SPA — MUST be last (catches all unmatched paths)
_ROOT = Path(__file__).parent.parent
_DIST = _ROOT / "web" / "dist"
if _DIST.exists():
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
# Serve index.html with no-cache so browsers always fetch fresh HTML after rebuilds.
# Hashed assets (/assets/index-abc123.js) can be cached forever — they change names
# when content changes (standard Vite cache-busting strategy).
_NO_CACHE = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache"}
@app.get("/")

View file

@ -16,13 +16,18 @@ import json
import logging
import os
import re
import select as _select
import subprocess as _subprocess
import tempfile
from pathlib import Path
from typing import Any
from typing import Any, Optional
import urllib.parse
import yaml
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@ -75,9 +80,31 @@ def _load_cforch_config() -> dict:
"license_key": _coalesce(file_cfg.get("license_key", ""), "CF_LICENSE_KEY"),
"ollama_url": _coalesce(file_cfg.get("ollama_url", ""), "OLLAMA_HOST"),
"ollama_model": _coalesce(file_cfg.get("ollama_model", ""), "OLLAMA_MODEL"),
"judge_url": _coalesce(file_cfg.get("judge_url", ""), "CF_JUDGE_URL"),
"hf_token": _coalesce(file_cfg.get("hf_token", ""), "HF_TOKEN"),
}
def _validate_service_url(url: str, param_name: str) -> str:
"""Validate that a URL is a well-formed http/https URL with a hostname.
Guards against SSRF: only http/https is allowed; the URL must have a
non-empty host. Does not enforce an allowlist call sites are internal
tooling, not a public API.
"""
if not url:
return url
try:
parsed = urllib.parse.urlparse(url)
except Exception:
raise HTTPException(400, f"{param_name}: not a valid URL")
if parsed.scheme not in ("http", "https"):
raise HTTPException(400, f"{param_name}: URL must start with http:// or https://")
if not parsed.hostname:
raise HTTPException(400, f"{param_name}: URL has no hostname")
return url
def _strip_ansi(text: str) -> str:
"""Remove ANSI escape codes from a string."""
return re.sub(r'\x1b\[[0-9;]*m', '', text)
@ -147,54 +174,151 @@ def get_tasks() -> dict:
# ── GET /models ────────────────────────────────────────────────────────────────
# Services and roles surfaced in the benchmark model picker.
# Covers all cf-orch service types that benchmark.py can route tasks to.
_BENCH_SERVICES = frozenset({
"cf-text", "vllm", # LLM text generation
"cf-stt", # speech-to-text
"cf-tts", # text-to-speech
"cf-vision", # image classification / embedding
"cf-voice", # audio context classification
})
_BENCH_ROLES = frozenset({
"generator", "vlm", # LLM roles
"stt", "alm", # speech recognition
"tts", # speech synthesis
"vision", "embedding", # image understanding
"classifier", # audio classification (cf-voice)
})
@router.get("/models")
def get_models() -> dict:
"""Return model list from bench_models.yaml."""
"""Return model list from bench_models.yaml merged with locally installed models.
bench_models.yaml entries are listed first and take precedence; any installed
model whose repo_id is already present in the YAML is skipped. Only models
whose service is in _BENCH_SERVICES (cf-text, vllm, cf-stt, cf-tts, cf-vision,
cf-voice) are surfaced from the installed registry.
"""
cfg = _load_cforch_config()
models_path = cfg.get("bench_models", "")
if not models_path:
return {"models": []}
models: list[dict] = []
bench_ids: set[str] = set()
if models_path:
p = Path(models_path)
if not p.exists():
return {"models": []}
if p.exists():
try:
raw = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse bench_models.yaml %s: %s", p, exc)
return {"models": []}
models_raw = raw.get("models", []) or []
models: list[dict] = []
for m in models_raw:
raw = {}
for m in (raw.get("models", []) or []):
if not isinstance(m, dict):
continue
model_id = m.get("id", "")
models.append({
"name": m.get("name", ""),
"id": m.get("id", ""),
"id": model_id,
"service": m.get("service", "ollama"),
"tags": m.get("tags", []) or [],
"vram_estimate_mb": m.get("vram_estimate_mb", 0),
})
if model_id:
bench_ids.add(model_id)
# Merge installed generator models not already in bench_models.yaml.
try:
from app.models import list_installed # local import avoids circular dependency at module load
for installed in list_installed():
model_id: str = installed.get("model_id") or ""
service: str = installed.get("service") or ""
role: str = installed.get("role") or ""
if not model_id:
continue
if service not in _BENCH_SERVICES or role not in _BENCH_ROLES:
continue
if model_id in bench_ids:
continue
display_name = model_id.split("/", 1)[-1] if "/" in model_id else model_id
models.append({
"name": display_name,
"id": model_id,
"service": service,
"tags": [role],
"vram_estimate_mb": installed.get("vram_mb") or 0,
})
bench_ids.add(model_id)
except Exception as exc:
logger.warning("Could not merge installed models into model list: %s", exc)
return {"models": models}
# ── GET /run ───────────────────────────────────────────────────────────────────
@router.get("/nodes")
def get_nodes() -> dict:
"""Proxy the coordinator's /api/nodes list, returning node_id + online status.
Online is inferred from last_heartbeat: any node with a recent heartbeat is online.
Returns an empty list if the coordinator is unreachable.
"""
cfg = _load_cforch_config()
coordinator_url = cfg.get("coordinator_url", "").rstrip("/")
if not coordinator_url:
return {"nodes": []}
try:
import httpx as _httpx
resp = _httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
resp.raise_for_status()
raw_nodes = resp.json().get("nodes", [])
return {
"nodes": [
{
"node_id": n.get("node_id", ""),
"online": n.get("last_heartbeat") is not None,
"gpus": [
{
"gpu_id": g.get("gpu_id"),
"name": g.get("name", ""),
"vram_total_mb": g.get("vram_total_mb", 0),
"vram_free_mb": g.get("vram_free_mb", 0),
}
for g in n.get("gpus", [])
],
}
for n in raw_nodes
]
}
except Exception as exc:
logger.warning("Could not fetch nodes from coordinator: %s", exc)
return {"nodes": []}
@router.get("/run")
def run_benchmark(
task_ids: str = "",
model_ids: str = "",
model_tags: str = "",
coordinator_url: str = "",
ollama_url: str = "",
judge_url: str = "",
judge_backend: str = "chat",
workers: int = 1,
node_ids: str = "",
) -> StreamingResponse:
"""Spawn cf-orch benchmark.py and stream stdout as SSE progress events."""
global _BENCH_RUNNING, _bench_proc
# Check if the process is actually still alive; reset stale flag if not.
if _BENCH_RUNNING:
if _bench_proc is not None and _bench_proc.poll() is None:
raise HTTPException(409, "A benchmark is already running")
_BENCH_RUNNING = False
_bench_proc = None
cfg = _load_cforch_config()
bench_script = cfg.get("bench_script", "")
@ -205,6 +329,13 @@ def run_benchmark(
cfg_coordinator = cfg.get("coordinator_url", "")
cfg_ollama = cfg.get("ollama_url", "")
cfg_license_key = cfg.get("license_key", "")
cfg_judge_url = cfg.get("judge_url", "")
# Validate URL params before spawning the subprocess.
# _validate_service_url raises HTTPException on bad input (caught by FastAPI before streaming starts).
_validate_service_url(coordinator_url, "coordinator_url")
_validate_service_url(ollama_url, "ollama_url")
_validate_service_url(judge_url, "judge_url")
def generate():
global _BENCH_RUNNING, _bench_proc
@ -213,16 +344,68 @@ def run_benchmark(
yield f"data: {json.dumps({'type': 'error', 'message': 'bench_script not configured or not found'})}\n\n"
return
# Build effective models file: bench_models.yaml + any installed models
# whose IDs were selected but are absent from the YAML (e.g. downloaded
# via the Models view). Written to a temp file so benchmark.py sees one
# unified list; cleaned up in the finally block.
effective_models_file = bench_models
_tmp_models_path: str | None = None
if model_ids and bench_models and Path(bench_models).exists():
requested_ids = set(model_ids.split(","))
try:
raw_bench = yaml.safe_load(Path(bench_models).read_text(encoding="utf-8")) or {}
bench_entries: list[dict] = raw_bench.get("models", []) or []
bench_id_set = {m.get("id", "") for m in bench_entries if isinstance(m, dict)}
missing_ids = requested_ids - bench_id_set
if missing_ids:
from app.models import list_installed
installed_map = {
m["model_id"]: m
for m in list_installed()
if m.get("model_id") and m.get("service") in _BENCH_SERVICES
}
extra: list[dict] = []
for mid in missing_ids:
if mid in installed_map:
inst = installed_map[mid]
entry: dict[str, Any] = {
"id": mid,
"name": mid.split("/", 1)[-1] if "/" in mid else mid,
"service": inst.get("service", "cf-text"),
"vram_estimate_mb": inst.get("vram_mb") or 0,
"tags": [inst.get("role", "generator")],
"temperature": 0.0,
}
local_path = inst.get("path", "") or inst.get("local_path", "")
if local_path:
entry["model_path"] = local_path
extra.append(entry)
if extra:
merged = {"models": bench_entries + extra}
tf = tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False,
prefix="avocet_bench_models_",
)
yaml.dump(merged, tf)
tf.close()
_tmp_models_path = tf.name
effective_models_file = _tmp_models_path
except Exception as exc:
logger.warning("Could not merge installed models into temp bench file: %s", exc)
cmd = [
python_bin,
bench_script,
"--tasks", bench_tasks,
"--models", bench_models,
"--models", effective_models_file,
"--output", results_dir,
]
if task_ids:
cmd.extend(["--filter-tasks"] + task_ids.split(","))
if model_ids:
cmd.extend(["--filter-models"] + model_ids.split(","))
if model_tags:
cmd.extend(["--filter-tags"] + model_tags.split(","))
@ -233,6 +416,15 @@ def run_benchmark(
cmd.extend(["--coordinator", effective_coordinator])
if effective_ollama:
cmd.extend(["--ollama-url", effective_ollama])
effective_judge = judge_url if judge_url else cfg_judge_url
if effective_judge:
cmd.extend(["--judge-url", effective_judge])
if judge_backend and judge_backend != "chat":
cmd.extend(["--judge-backend", judge_backend])
if workers > 1:
cmd.extend(["--workers", str(workers)])
if node_ids:
cmd.extend(["--nodes"] + node_ids.split(","))
# Pass license key as env var so subprocess can authenticate with cf-orch
proc_env = {**os.environ}
@ -250,8 +442,23 @@ def run_benchmark(
env=proc_env,
)
_bench_proc = proc
_IDLE_TIMEOUT_S = 120 # kill if no output for 2 minutes (node crash)
try:
for line in proc.stdout:
while True:
ready = _select.select([proc.stdout], [], [], _IDLE_TIMEOUT_S)
if not ready[0]:
# No output for IDLE_TIMEOUT_S — node likely crashed
proc.terminate()
try:
proc.wait(timeout=5)
except _subprocess.TimeoutExpired:
proc.kill()
msg = f"Benchmark timed out — no output for {_IDLE_TIMEOUT_S}s (cluster node may have crashed)"
yield f"data: {json.dumps({'type': 'error', 'message': msg})}\n\n"
break
line = proc.stdout.readline()
if not line:
break
line = _strip_ansi(line.rstrip())
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
@ -273,6 +480,11 @@ def run_benchmark(
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
finally:
_BENCH_RUNNING = False
if _tmp_models_path:
try:
os.unlink(_tmp_models_path)
except OSError:
pass
return StreamingResponse(
generate(),
@ -295,6 +507,7 @@ def get_cforch_config() -> dict:
"coordinator_url": cfg.get("coordinator_url", ""),
"ollama_url": cfg.get("ollama_url", ""),
"ollama_model": cfg.get("ollama_model", ""),
"judge_url": cfg.get("judge_url", ""),
"license_key_set": bool(cfg.get("license_key", "")),
"source": "env" if not _config_file().exists() else "yaml+env",
}
@ -335,3 +548,106 @@ def cancel_benchmark() -> dict:
_BENCH_RUNNING = False
_bench_proc = None
return {"status": "cancelled"}
# ── Coordinator proxy helpers ──────────────────────────────────────────────────
def _coordinator_url() -> str:
"""Return coordinator base URL from config, or raise 503 if not configured."""
url = _load_cforch_config().get("coordinator_url", "").rstrip("/")
if not url:
raise HTTPException(503, "cf-orch coordinator_url not configured")
return url
def _coordinator_get(path: str) -> Any:
"""GET from coordinator, return parsed JSON body. Raises HTTPException on error."""
import httpx as _httpx
try:
resp = _httpx.get(f"{_coordinator_url()}{path}", timeout=10.0)
except Exception as exc:
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
if not resp.is_success:
raise HTTPException(resp.status_code, resp.text)
return resp.json()
async def _coordinator_post(path: str, body: dict) -> Any:
import httpx as _httpx
try:
async with _httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(f"{_coordinator_url()}{path}", json=body)
except Exception as exc:
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
if not resp.is_success:
raise HTTPException(resp.status_code, resp.text)
return resp.json()
async def _coordinator_delete(path: str) -> Any:
import httpx as _httpx
try:
async with _httpx.AsyncClient(timeout=10.0) as client:
resp = await client.delete(f"{_coordinator_url()}{path}")
except Exception as exc:
raise HTTPException(502, f"Coordinator unreachable: {exc}") from exc
if not resp.is_success:
raise HTTPException(resp.status_code, resp.text)
return resp.json()
# ── GET /assignments/deployment-status ───────────────────────────────────────
@router.get("/assignments/deployment-status")
def get_deployment_status() -> Any:
return _coordinator_get("/api/assignments/deployment-status")
# ── /assignments ──────────────────────────────────────────────────────────────
@router.get("/assignments")
def list_assignments() -> Any:
return _coordinator_get("/api/assignments")
class AssignmentBody(BaseModel):
product: str
task: str
model_id: str
description: str = ""
@router.post("/assignments")
async def upsert_assignment(body: AssignmentBody) -> Any:
return await _coordinator_post("/api/assignments", body.model_dump())
@router.delete("/assignments/{product}/{task}")
async def delete_assignment(product: str, task: str) -> Any:
return await _coordinator_delete(f"/api/assignments/{urllib.parse.quote(product, safe='')}/{urllib.parse.quote(task, safe='')}")
# ── /model-registry ────────────────────────────────────────────────────────────
@router.get("/model-registry")
def list_model_registry() -> Any:
return _coordinator_get("/api/model-registry")
class ModelRegistryBody(BaseModel):
model_id: str
service_type: str
vram_mb: int
description: str = ""
hf_repo: str = ""
alias: str = ""
@router.post("/model-registry")
async def upsert_model_registry(body: ModelRegistryBody) -> Any:
return await _coordinator_post("/api/model-registry", body.model_dump())
@router.delete("/model-registry/{model_id:path}")
async def delete_model_registry(model_id: str) -> Any:
return await _coordinator_delete(f"/api/model-registry/{urllib.parse.quote(model_id, safe='')}")

34
app/cloud_session.py Normal file
View file

@ -0,0 +1,34 @@
"""
Avocet cloud session thin wrapper around cf_core.cloud_session.
Usage in FastAPI routes:
from app.cloud_session import get_session, require_tier, CloudUser
from fastapi import Depends
@router.get("/api/imitate")
def imitate(session: CloudUser = Depends(get_session)):
# session.user_id — Directus UUID (cloud) or "local" (self-hosted)
# session.tier — free | paid | premium | ultra | local
# session.has_byok — True if user has a configured LLM backend
...
@router.post("/api/custom-models")
def list_custom_models(session: CloudUser = Depends(require_tier("premium"))):
...
"""
from __future__ import annotations
import os
from circuitforge_core.cloud_session import CloudSessionFactory, CloudUser, detect_byok
__all__ = ["CloudUser", "get_session", "require_tier"]
_factory = CloudSessionFactory(
product="avocet",
byok_detector=detect_byok,
)
get_session = _factory.dependency()
require_tier = _factory.require_tier

282
app/dashboard.py Normal file
View file

@ -0,0 +1,282 @@
"""Avocet -- dashboard aggregate API.
GET /api/dashboard returns the current flywheel state:
labeled_since_last_eval -- items labeled after the most recent bench run
last_eval_timestamp -- ISO timestamp of newest bench_results summary
last_eval_best_score -- best macro_f1 from that summary
active_jobs -- jobs with status queued or running
corrections_pending -- sft_candidates with status=needs_review
corrections_export_ready -- approved sft candidates with non-blank correction
recent_bench_runs -- most-recent timestamp + score per bench type
signals -- computed booleans for UI nudge indicators
Thresholds in label_tool.yaml pipeline: section:
pipeline:
data_eval_threshold: 50 # labeled items since last bench to trigger nudge
eval_train_threshold: 0.05 # improvement delta needed before retraining (future)
"""
from __future__ import annotations
import json
import logging
import yaml
from pathlib import Path
from fastapi import APIRouter
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
router = APIRouter()
_DEFAULT_DATA_EVAL_THRESHOLD = 50
_DEFAULT_EVAL_TRAIN_THRESHOLD = 0.05
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_thresholds() -> tuple[int, float]:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
pipeline = raw.get("pipeline", {}) or {}
return (
int(pipeline.get("data_eval_threshold", _DEFAULT_DATA_EVAL_THRESHOLD)),
float(pipeline.get("eval_train_threshold", _DEFAULT_EVAL_TRAIN_THRESHOLD)),
)
except Exception as exc:
logger.warning("Failed to read pipeline thresholds: %s", exc)
return _DEFAULT_DATA_EVAL_THRESHOLD, _DEFAULT_EVAL_TRAIN_THRESHOLD
def _load_score_records() -> list[dict]:
path = _DATA_DIR / "email_score.jsonl"
if not path.exists():
return []
records = []
for line in path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
records.append(json.loads(line))
except json.JSONDecodeError:
pass
return records
def _find_latest_classifier_bench(results_dir_override: str = "") -> tuple[str | None, float | None]:
"""Return (iso_timestamp, best_macro_f1) from the newest bench_results summary.
Checks results_dir from cforch config if set, then falls back to
_ROOT/bench_results/. Returns (None, None) if no results exist.
"""
candidates = []
if results_dir_override:
candidates.append(Path(results_dir_override))
else:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
if rd:
candidates.append(Path(rd))
except Exception as exc:
logger.warning("Failed to read cforch.results_dir from config: %s", exc)
candidates.append(_ROOT / "bench_results")
for rdir in candidates:
if not rdir.exists():
continue
subdirs = sorted([d for d in rdir.iterdir() if d.is_dir()], key=lambda d: d.name)
for subdir in reversed(subdirs):
summary = subdir / "summary.json"
if summary.exists():
try:
data = json.loads(summary.read_text(encoding="utf-8"))
if not isinstance(data, dict):
continue # cforch LLM-bench summaries are lists; skip
ts = data.get("timestamp") or subdir.name
score = data.get("best_macro_f1") or data.get("macro_f1")
return ts, (float(score) if isinstance(score, (int, float)) else None)
except Exception as exc:
logger.warning("Failed to parse summary.json at %s: %s", summary, exc)
return None, None
# Keep old name as alias so existing callers in tests still work.
_find_latest_eval = _find_latest_classifier_bench
def _count_corrections() -> tuple[int, int]:
"""Return (pending_count, export_ready_count)."""
pending = 0
export_ready = 0
candidates_path = _DATA_DIR / "sft_candidates.jsonl"
approved_path = _DATA_DIR / "sft_approved.jsonl"
if candidates_path.exists():
for line in candidates_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
r = json.loads(line)
if r.get("status") == "needs_review":
pending += 1
except json.JSONDecodeError:
pass
if approved_path.exists():
for line in approved_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
r = json.loads(line)
if (r.get("status") == "approved"
and r.get("corrected_response")
and str(r["corrected_response"]).strip()):
export_ready += 1
except json.JSONDecodeError:
pass
return pending, export_ready
def _get_active_jobs() -> list[dict]:
"""Query train SQLite DB for queued/running jobs. Returns [] if DB absent."""
try:
from app.train.train import _DB_PATH, _db, _init_db
if not _DB_PATH.exists():
return []
_init_db()
with _db() as conn:
rows = conn.execute(
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
).fetchall()
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
except Exception as exc:
logger.warning("Failed to query train jobs DB: %s", exc)
return []
def _count_labeled_since(since_ts: str | None) -> int:
records = _load_score_records()
if since_ts is None:
return len(records)
return sum(1 for r in records if r.get("labeled_at", "") > since_ts)
def _get_recent_bench_runs() -> dict:
"""Return most-recent run summary for each bench type.
Each entry: {"timestamp": str|None, "metric": str|None, "score": float|None}
"""
runs: dict[str, dict] = {
"classifier": {"timestamp": None, "metric": "macro_f1", "score": None},
"llm": {"timestamp": None, "metric": None, "score": None},
"style": {"timestamp": None, "metric": None, "score": None},
"plans": {"timestamp": None, "metric": "avg_score", "score": None},
}
# ── Classifier: bench_results/<run>/summary.json ──────────────────────
clf_ts, clf_score = _find_latest_classifier_bench()
if clf_ts:
runs["classifier"]["timestamp"] = clf_ts
runs["classifier"]["score"] = clf_score
# ── LLM bench + Style: benchmark_results/ ─────────────────────────────
f = _config_file()
bench_dir: Path | None = None
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
rd = (raw.get("cforch", {}) or {}).get("results_dir", "")
if rd:
bench_dir = Path(rd)
except Exception:
pass
if bench_dir is None:
bench_dir = _ROOT / "benchmark_results"
if bench_dir.exists():
llm_files = sorted(
[p for p in bench_dir.glob("*.json") if not p.name.startswith("style_")],
key=lambda p: p.stat().st_mtime, reverse=True,
)
if llm_files:
try:
data = json.loads(llm_files[0].read_text(encoding="utf-8"))
runs["llm"]["timestamp"] = data.get("timestamp") or llm_files[0].stem
except Exception:
pass
style_files = sorted(bench_dir.glob("style_*.json"), reverse=True)
if style_files:
try:
data = json.loads(style_files[0].read_text(encoding="utf-8"))
if isinstance(data, list) and data:
runs["style"]["timestamp"] = data[0].get("timestamp") or style_files[0].stem
except Exception:
pass
# ── Plans bench: data/plans_bench_results/plans_*.json ────────────────
plans_dir = _DATA_DIR / "plans_bench_results"
if plans_dir.exists():
plans_files = sorted(plans_dir.glob("plans_*.json"), reverse=True)
if plans_files:
run_id = plans_files[0].stem
try:
d: dict = json.loads(plans_files[0].read_text(encoding="utf-8"))
all_scores = [
r["total_score"]
for results in d.values()
for r in results
if isinstance(r, dict) and not r.get("error")
]
avg = round(sum(all_scores) / len(all_scores), 3) if all_scores else None
try:
date_part = run_id.removeprefix("plans_")
date, time_part = date_part.split("_")
ts_display = f"{date} {time_part[:2]}:{time_part[2:4]}"
except Exception:
ts_display = run_id
runs["plans"]["timestamp"] = ts_display
runs["plans"]["score"] = avg
except Exception:
pass
return runs
@router.get("/dashboard")
def get_dashboard() -> dict:
data_threshold, _train_threshold = _load_thresholds()
last_ts, last_score = _find_latest_classifier_bench()
labeled_since = _count_labeled_since(last_ts)
corrections_pending, corrections_export_ready = _count_corrections()
active_jobs = _get_active_jobs()
recent_bench = _get_recent_bench_runs()
return {
"labeled_since_last_eval": labeled_since,
"last_eval_timestamp": last_ts,
"last_eval_best_score": last_score,
"active_jobs": active_jobs,
"corrections_pending": corrections_pending,
"corrections_export_ready": corrections_export_ready,
"recent_bench_runs": recent_bench,
"signals": {
"data_to_eval": labeled_since >= data_threshold,
"eval_to_train": False, # future: implement delta-F1 comparison
"train_to_fleet": False, # future: implement fleet sync signal
},
}

0
app/data/__init__.py Normal file
View file

393
app/data/corrections.py Normal file
View file

@ -0,0 +1,393 @@
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
All endpoints are registered on `router` (a FastAPI APIRouter).
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
testability pattern as api.py -- override them via set_data_dir() and
set_config_dir() in test fixtures.
"""
from __future__ import annotations
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, Header, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
router = APIRouter()
# -- Testability seams ---------------------------------------------------------
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Internal helpers ----------------------------------------------------------
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
def set_default_bench_results_dir(path: str) -> None:
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
global _DEFAULT_BENCH_RESULTS_DIR
_DEFAULT_BENCH_RESULTS_DIR = path
def _get_bench_results_dir() -> Path:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
d = raw.get("sft", {}).get("bench_results_dir", "")
if d:
return Path(d)
except yaml.YAMLError as exc:
logger.warning("Failed to parse SFT config %s: %s", f, exc)
return Path(_DEFAULT_BENCH_RESULTS_DIR)
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _approved_file() -> Path:
return _DATA_DIR / "sft_approved.jsonl"
def _read_candidates() -> list[dict]:
return read_jsonl(_candidates_file())
def _write_candidates(records: list[dict]) -> None:
write_jsonl(_candidates_file(), records)
def _is_exportable(r: dict) -> bool:
"""Return True if an approved record is ready to include in SFT export."""
return (
r.get("status") == "approved"
and bool(r.get("corrected_response"))
and str(r["corrected_response"]).strip() != ""
)
# -- GET /runs -----------------------------------------------------------------
@router.get("/runs")
def get_runs():
"""List available benchmark runs in the configured bench_results_dir."""
from scripts.sft_import import discover_runs
bench_dir = _get_bench_results_dir()
existing = _read_candidates()
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
imported_run_ids = {
r["benchmark_run_id"]
for r in existing
if r.get("benchmark_run_id") is not None
}
runs = discover_runs(bench_dir)
return [
{
"run_id": r["run_id"],
"timestamp": r["timestamp"],
"candidate_count": r["candidate_count"],
"already_imported": r["run_id"] in imported_run_ids,
}
for r in runs
]
# -- POST /import --------------------------------------------------------------
class ImportRequest(BaseModel):
run_id: str
@router.post("/import")
def post_import(req: ImportRequest):
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
from scripts.sft_import import discover_runs, import_run
bench_dir = _get_bench_results_dir()
runs = discover_runs(bench_dir)
run = next((r for r in runs if r["run_id"] == req.run_id), None)
if run is None:
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
return import_run(run["sft_path"], _DATA_DIR)
# -- GET /queue ----------------------------------------------------------------
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
"""Return paginated needs_review candidates."""
records = _read_candidates()
pending = [r for r in records if r.get("status") == "needs_review"]
start = (page - 1) * per_page
return {
"items": pending[start:start + per_page],
"total": len(pending),
"page": page,
"per_page": per_page,
}
# -- POST /submit --------------------------------------------------------------
FailureCategory = Literal[
"scoring_artifact",
"style_violation",
"partial_answer",
"wrong_answer",
"format_error",
"hallucination",
]
class SubmitRequest(BaseModel):
id: str
action: Literal["correct", "discard", "flag"]
corrected_response: str | None = None
failure_category: FailureCategory | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
"""Record a reviewer decision for one SFT candidate."""
if req.action == "correct":
if not req.corrected_response or not req.corrected_response.strip():
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
if record.get("status") != "needs_review":
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
if req.action == "correct":
records[idx] = {
**record,
"status": "approved",
"corrected_response": req.corrected_response,
"failure_category": req.failure_category,
}
_write_candidates(records)
append_jsonl(_approved_file(), records[idx])
elif req.action == "discard":
records[idx] = {**record, "status": "discarded"}
_write_candidates(records)
else: # flag
records[idx] = {**record, "status": "model_rejected"}
_write_candidates(records)
return {"ok": True}
# -- POST /undo ----------------------------------------------------------------
class UndoRequest(BaseModel):
id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
"""Restore a previously actioned candidate back to needs_review."""
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
old_status = record.get("status")
if old_status == "needs_review":
raise HTTPException(409, "Record is already in needs_review state")
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
_write_candidates(records)
# If it was approved, remove from the approved file too
if old_status == "approved":
approved = read_jsonl(_approved_file())
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
return {"ok": True}
# -- GET /export ---------------------------------------------------------------
@router.get("/export")
def get_export() -> StreamingResponse:
"""Stream approved records as SFT-ready JSONL for download."""
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
def generate():
for r in exportable:
record = {
"messages": r.get("prompt_messages", []) + [
{"role": "assistant", "content": r["corrected_response"]}
]
}
yield json.dumps(record) + "\n"
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
},
)
# -- GET /stats ----------------------------------------------------------------
@router.get("/stats")
def get_stats() -> dict[str, object]:
"""Return counts by status, model, and task type."""
records = _read_candidates()
by_status: dict[str, int] = {}
by_model: dict[str, int] = {}
by_task_type: dict[str, int] = {}
for r in records:
status = r.get("status", "unknown")
by_status[status] = by_status.get(status, 0) + 1
model = r.get("model_name", "unknown")
by_model[model] = by_model.get(model, 0) + 1
task_type = r.get("task_type", "unknown")
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
approved = read_jsonl(_approved_file())
export_ready = sum(1 for r in approved if _is_exportable(r))
return {
"total": len(records),
"by_status": by_status,
"by_model": by_model,
"by_task_type": by_task_type,
"export_ready": export_ready,
}
# -- GET /config ---------------------------------------------------------------
@router.get("/config")
def get_sft_config() -> dict:
"""Return the current SFT configuration (bench_results_dir)."""
f = _config_file()
if not f.exists():
return {"bench_results_dir": ""}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
return {"bench_results_dir": ""}
sft_section = raw.get("sft") or {}
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
class SftConfigPayload(BaseModel):
bench_results_dir: str
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
"""Write the bench_results_dir setting to the config file."""
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
raw = raw or {}
except yaml.YAMLError:
raw = {}
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
tmp.rename(f)
return {"ok": True}
# -- POST /ingest --------------------------------------------------------------
class IngestRequest(BaseModel):
source: str # e.g. "peregrine", "kiwi"
task_type: str # e.g. "email_classification", "recipe_suggestion"
prompt: str # the prompt that was sent to the LLM
response: str # the LLM's original response
correction: str # the human-corrected response
label: str | None = None # optional label/category
@router.post("/ingest")
def post_ingest(
req: IngestRequest,
authorization: str | None = Header(default=None),
) -> dict:
"""Ingest a correction from a sibling CF product.
Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>
Creates a sft_candidates record with status='approved' (pre-approved by
the calling product -- human review already happened upstream). Also writes
to sft_approved.jsonl so it is immediately included in export counts.
Returns {"ok": True, "id": "<uuid>"}.
"""
expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
if not expected_secret:
raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(401, "Missing or malformed Authorization header")
token = authorization.removeprefix("Bearer ").strip()
if token != expected_secret:
raise HTTPException(403, "Invalid ingestion secret")
record_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
record = {
"id": record_id,
"source": req.source,
"task_type": req.task_type,
"status": "approved",
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": req.response,
"corrected_response": req.correction,
"label": req.label,
"timestamp": now,
"benchmark_run_id": None,
}
append_jsonl(_candidates_file(), record)
append_jsonl(_approved_file(), record)
return {"ok": True, "id": record_id}

243
app/data/fetch.py Normal file
View file

@ -0,0 +1,243 @@
"""Avocet -- IMAP fetch utilities and fetch API routes.
All IMAP helper functions (from app/imap_fetch.py) plus the
/api/accounts/test and /api/fetch/stream endpoints.
"""
from __future__ import annotations
import email as _email_lib
import hashlib
import imaplib
import json
import yaml
from datetime import datetime, timedelta
from email.header import decode_header as _raw_decode
from pathlib import Path
from typing import Iterator
from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import extract_body, read_jsonl, write_jsonl
_ROOT = Path(__file__).parent.parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
router = APIRouter()
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _queue_file() -> Path:
return _DATA_DIR / "email_label_queue.jsonl"
def _get_config_accounts() -> list[dict]:
f = _config_file()
if not f.exists():
return []
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return raw.get("accounts", [])
# ── IMAP decode helpers ───────────────────────────────────────────────────────
def _decode_str(value: str | None) -> str:
if not value:
return ""
parts = _raw_decode(value)
out = []
for part, enc in parts:
if isinstance(part, bytes):
out.append(part.decode(enc or "utf-8", errors="replace"))
else:
out.append(str(part))
return " ".join(out).strip()
def entry_key(e: dict) -> str:
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
# ── Wide search terms ────────────────────────────────────────────────────────
_WIDE_TERMS = [
"interview", "phone screen", "video call", "zoom link", "schedule a call",
"offer letter", "job offer", "offer of employment", "pleased to offer",
"unfortunately", "not moving forward", "other candidates", "regret to inform",
"no longer", "decided not to", "decided to go with",
"opportunity", "interested in your background", "reached out", "great fit",
"exciting role", "love to connect",
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
"application received", "thank you for applying", "application confirmation",
"you applied", "your application for",
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
"new jobs", "job alert",
"came across your profile", "reaching out about", "great fit for a role",
"exciting opportunity",
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
"application", "recruiter", "recruiting", "hiring", "candidate",
]
# ── Public API ────────────────────────────────────────────────────────────────
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
host = acc.get("host", "")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc.get("username", "")
password = acc.get("password", "")
folder = acc.get("folder", "INBOX")
if not host or not username or not password:
return False, "Host, username, and password are all required.", None
try:
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
_, data = conn.select(folder, readonly=True)
count_raw = data[0].decode() if data and data[0] else "0"
count = int(count_raw) if count_raw.isdigit() else 0
conn.logout()
return True, f"Connected — {count:,} message(s) in {folder}.", count
except Exception as exc:
return False, str(exc), None
def fetch_account_stream(
acc: dict,
days_back: int,
limit: int,
known_keys: set[str],
) -> Iterator[dict]:
"""Generator — yields progress dicts while fetching emails via IMAP.
Mutates `known_keys` in place for cross-account dedup within one fetch session.
Yields event dicts with "type" key:
{"type": "start", "account": str, "total_uids": int}
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
"""
name = acc.get("name", acc.get("username", "?"))
host = acc.get("host", "imap.gmail.com")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc["username"]
password = acc["password"]
folder = acc.get("folder", "INBOX")
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
conn.select(folder, readonly=True)
seen_uids: dict[bytes, None] = {}
for term in _WIDE_TERMS:
try:
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
for uid in (data[0] or b"").split():
seen_uids[uid] = None
except Exception:
pass
uids = list(seen_uids.keys())[: limit * 3]
yield {"type": "start", "account": name, "total_uids": len(uids)}
emails: list[dict] = []
skipped = 0
for i, uid in enumerate(uids):
if len(emails) >= limit:
break
if i % 5 == 0:
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
try:
_, raw_data = conn.fetch(uid, "(RFC822)")
if not raw_data or not raw_data[0]:
continue
msg = _email_lib.message_from_bytes(raw_data[0][1])
subj = _decode_str(msg.get("Subject", ""))
from_addr = _decode_str(msg.get("From", ""))
date = _decode_str(msg.get("Date", ""))
body = extract_body(msg)[:800]
entry = {"subject": subj, "body": body, "from_addr": from_addr,
"date": date, "account": name}
k = entry_key(entry)
if k not in known_keys:
known_keys.add(k)
emails.append(entry)
else:
skipped += 1
except Exception:
skipped += 1
try:
conn.logout()
except Exception:
pass
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
"emails": emails}
class AccountTestRequest(BaseModel):
account: dict
@router.post("/accounts/test")
def test_account_route(req: AccountTestRequest) -> dict:
ok, message, count = test_connection(req.account)
return {"ok": ok, "message": message, "count": count}
@router.get("/fetch/stream")
def fetch_stream(
accounts: str = Query(default=""),
days_back: int = Query(default=90, ge=1, le=365),
limit: int = Query(default=150, ge=1, le=1000),
mode: str = Query(default="wide"),
) -> StreamingResponse:
selected_names = {n.strip() for n in accounts.split(",") if n.strip()}
all_accounts = _get_config_accounts()
selected = [a for a in all_accounts if a.get("name") in selected_names]
def generate():
known_keys = {entry_key(x) for x in read_jsonl(_queue_file())}
total_added = 0
for acc in selected:
try:
batch_emails: list[dict] = []
for event in fetch_account_stream(acc, days_back, limit, known_keys):
if event["type"] == "done":
batch_emails = event.pop("emails", [])
total_added += event["added"]
yield f"data: {json.dumps(event)}\n\n"
if batch_emails:
existing = read_jsonl(_queue_file())
write_jsonl(_queue_file(), existing + batch_emails)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'account': acc.get('name', '?'), 'message': str(exc)})}\n\n"
queue_size = len(read_jsonl(_queue_file()))
yield f"data: {json.dumps({'type': 'complete', 'total_added': total_added, 'queue_size': queue_size})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

729
app/data/imitate.py Normal file
View file

@ -0,0 +1,729 @@
"""Avocet — Imitate tab API.
Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs (ollama), and streams responses back to the UI. Results can be
pushed into the SFT corrections queue for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_imitate_config() -> dict:
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse imitate config %s: %s", f, exc)
return {}
return raw.get("imitate", {}) or {}
def _load_cforch_config() -> dict:
"""Read cforch section for ollama_url fallback."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
return {}
return raw.get("cforch", {}) or {}
def _ollama_url(cfg: dict) -> str:
cforch = _load_cforch_config()
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
def _cforch_url() -> str:
cforch = _load_cforch_config()
return cforch.get("coordinator_url") or "http://localhost:7700"
def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
"""Return {model_id, service_type} for a product.task assignment, or None if not found.
Calls GET coordinator/api/assignments and filters by product+task.
The model registry entry is fetched separately to get service_type.
Returns None (not raises) callers emit a 'model_done' error event instead.
"""
try:
asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
asgn_resp.raise_for_status()
assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
match = next(
(a for a in assignments if a.get("product") == product and a.get("task") == task),
None,
)
if match is None:
return None
model_id: str = match.get("model_id", "")
if not model_id:
return None
# Look up service_type from model registry
reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
service_type = "cf-text" # sensible default
if reg_resp.is_success:
models: list[dict] = reg_resp.json().get("models", []) or []
reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
if reg_entry:
service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
return {"model_id": model_id, "service_type": service_type}
except Exception as exc:
logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
return None
def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch.
Filters out proxy entries (ollama://, vllm://, http://) those models are
served by their own services and should not be allocated via cf-text.
Returns only models with real file-system paths that cf-text can load directly.
"""
try:
resp = httpx.get(
f"{cforch_base}/api/services/cf-text/catalog",
params={"node_id": "heimdall"},
timeout=5.0,
)
resp.raise_for_status()
raw = resp.json()
result = []
for model_id, entry in raw.items():
if not isinstance(entry, dict):
continue
path = entry.get("path", "")
# Skip proxy entries — they're routed through other services
if "://" in path:
continue
result.append({
"id": model_id,
"vram_mb": entry.get("vram_mb", 0),
"description": entry.get("description", ""),
})
return result
except Exception as exc:
logger.warning("Could not fetch cf-orch catalog: %s", exc)
return []
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
"""Return True if the product's health endpoint responds OK."""
try:
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
return bool(data)
except Exception:
return False
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
"""Download an image URL and return it as a base64 string for ollama.
Returns empty string on any failure a missing image is non-fatal;
the model will still run against the text prompt alone.
"""
try:
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
with urlopen(req, timeout=10) as resp:
return base64.b64encode(resp.read()).decode("ascii")
except Exception as exc:
logger.warning("Failed to fetch image %s: %s", image_url, exc)
return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _run_cftext(
cforch_base: str,
model_id: str,
prompt: str,
system: str,
temperature: float,
startup_timeout_s: float = 180.0,
user_id: str | None = None,
) -> tuple[str, int, bool]:
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
Raises RuntimeError on allocation failure or generation error.
cold_started=True means the service was launched from scratch (caller may log this).
Cold-start detection uses coordinator state signals (running/stopped) rather than
polling the service health endpoint this fails fast on model load errors instead
of waiting out the full timeout.
"""
# Allocate
alloc_resp = httpx.post(
f"{cforch_base}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "imitate",
**({"user_id": user_id} if user_id else {}),
},
timeout=30.0,
)
alloc_resp.raise_for_status()
data = alloc_resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
cold_started = data.get("started", False) and not data.get("warm", True)
# Wait for ready using coordinator state signals
if cold_started:
deadline = time.monotonic() + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
try:
status = httpx.get(
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
break
elif state == "stopped":
if allocation_id:
httpx.delete(
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
timeout=5.0,
)
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
else:
probe_misses += 1
if probe_misses >= 6:
# Coordinator hasn't registered instance yet — fall back to health poll
try:
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
break
except Exception:
pass
except RuntimeError:
raise
except Exception:
pass
time.sleep(2.0)
else:
if allocation_id:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
# Generate
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
t0 = time.time()
try:
gen_resp = httpx.post(
f"{service_url}/v1/chat/completions",
json={
"model": model_id,
"messages": messages,
"max_tokens": 300,
"temperature": temperature,
"stream": False,
},
timeout=120.0,
)
gen_resp.raise_for_status()
elapsed_ms = int((time.time() - t0) * 1000)
content = gen_resp.json()["choices"][0]["message"]["content"]
return content.strip(), elapsed_ms, cold_started
except Exception as exc:
elapsed_ms = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
finally:
if allocation_id:
try:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
except Exception:
pass
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
"""List configured CF products with live online status."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
products = []
for p in products_raw:
if not isinstance(p, dict):
continue
base_url = p.get("base_url", "")
products.append({
"id": p.get("id", ""),
"name": p.get("name", ""),
"icon": p.get("icon", "📦"),
"description": p.get("description", ""),
"base_url": base_url,
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
})
return {"products": products}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
"""Fetch a real sample from the given product's API."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
product: dict | None = None
for p in products_raw:
if isinstance(p, dict) and p.get("id") == product_id:
product = p
break
if product is None:
raise HTTPException(404, f"Product '{product_id}' not in config")
base_url = product.get("base_url", "").rstrip("/")
endpoint = product.get("sample_endpoint", "")
if not base_url or not endpoint:
raise HTTPException(422, "Product missing base_url or sample_endpoint")
url = f"{base_url}{endpoint}"
try:
raw = _http_get_json(url, timeout=5)
except URLError as exc:
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
except Exception as exc:
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
text_fields = product.get("text_fields", []) or []
sample_key = product.get("sample_key") or None
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
if not extracted:
raise HTTPException(404, "No sample items returned by product API")
prompt_template = product.get("prompt_template", "{text}")
prompt = prompt_template.replace("{text}", extracted["text"])
# Also substitute any {field_name} placeholders from the raw item fields.
item = extracted.get("item", {})
for field, val in item.items():
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
# Expose system_prompt and image_url if the product API returns them.
# system_prompt: Peregrine, Snipe (vision analysis instructions)
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
item = extracted.get("item", {})
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
return {
"product_id": product_id,
"sample_index": index,
"text": extracted["text"],
"prompt": prompt,
"system_prompt": system_prompt,
"image_url": image_url,
"raw_item": item,
}
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
"""Return the live cf-text model catalog from cf-orch coordinator."""
models = _cforch_catalog(_cforch_url())
return {"models": models}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
"""Optional session dependency — returns None when cloud_session is unavailable."""
try:
from app.cloud_session import get_session
return get_session(request, response)
except Exception:
return None
@router.get("/run")
def run_imitate(
prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
task_ids: str = "", # comma-separated "product/task" strings — resolved via assignments
temperature: float = 0.7,
product_id: str = "",
system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models
session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
"""Run a prompt through selected models and stream results as SSE.
Models can be selected three ways (combinable):
- model_ids: explicit ollama model IDs
- cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
- task_ids: "product/task" strings resolved via the coordinator assignments table
If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to
evaluate listing photos the same way Snipe's background task pipeline does.
"""
if not prompt.strip():
raise HTTPException(422, "prompt is required")
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
# Resolve task assignments to concrete model IDs, routing to the right service.
# Models that fail to resolve emit an error event at run time (non-fatal).
if raw_task_ids:
cforch_base = _cforch_url()
for task_spec in raw_task_ids:
parts = task_spec.split("/", 1)
if len(parts) != 2:
logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
continue
product_name, task_name = parts
resolved = _resolve_task_model(cforch_base, product_name, task_name)
if resolved is None:
logger.warning("No assignment found for task %r", task_spec)
# Emit error at stream time via a sentinel in cftext_ids with a special label.
# We instead store the failed task_spec to emit a model_done error.
cftext_ids.append(f"__task_unresolved__:{task_spec}")
continue
mid = resolved["model_id"]
svc = resolved["service_type"]
if svc == "ollama":
if mid not in ollama_ids:
ollama_ids.append(mid)
else:
# cf-text, vllm, and any other cf-orch-managed service
if mid not in cftext_ids:
cftext_ids.append(mid)
if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg)
cforch_base = _cforch_url()
system_ctx = system.strip() or ""
total_models = len(ollama_ids) + len(cftext_ids)
# Download image once before streaming — shared across ollama vision models
images: list[str] = []
if image_url.strip():
b64 = _fetch_image_b64(image_url.strip())
if b64:
images = [b64]
def generate():
results: list[dict] = []
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
# Ollama models
for model_id in ollama_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
try:
response, elapsed_ms = _run_ollama_streaming(
ollama_base, model_id, prompt, temperature,
system=system_ctx, images=images or None,
)
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected
# Partition the list: real cf-text IDs vs unresolved-task sentinels.
cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
for sentinel in cftext_unresolved:
task_spec = sentinel.split(":", 1)[1]
result = {
"model": task_spec,
"response": "",
"elapsed_ms": 0,
"error": f"No assignment configured for task '{task_spec}'",
}
results.append(result)
yield _sse({"type": "model_done", **result})
if cftext_real:
from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_real:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
_user_id: str | None = getattr(session, "user_id", None)
# Only forward real cloud user IDs — skip local/anon sessions
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
_user_id = None
with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
future_to_model = {
pool.submit(
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
180.0, _user_id,
): mid
for mid in cftext_real
}
for future in as_completed(future_to_model):
model_id = future_to_model[future]
try:
response, elapsed_ms, cold_started = future.result()
if cold_started:
yield _sse({"type": "model_coldstart", "model": model_id})
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
yield _sse({"type": "complete", "results": results})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
model: str
response: str
elapsed_ms: int
error: str | None = None
class PushCorrectionsRequest(BaseModel):
product_id: str
prompt: str
results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
"""Append imitate results to sft_candidates.jsonl for human review."""
if not req.prompt.strip():
raise HTTPException(422, "prompt is required")
if not req.results:
raise HTTPException(422, "results list is empty")
ts = datetime.now(timezone.utc).isoformat()
records = []
for r in req.results:
if r.error or not r.response.strip():
continue
records.append({
"id": str(uuid.uuid4()),
"source": "imitate",
"product_id": req.product_id,
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": r.response,
"model_id": r.model,
"elapsed_ms": r.elapsed_ms,
"status": "pending",
"created_at": ts,
})
if not records:
raise HTTPException(422, "No non-error results to push")
dest = _candidates_file()
dest.parent.mkdir(parents=True, exist_ok=True)
for record in records:
append_jsonl(dest, record)
return {"pushed": len(records)}

222
app/data/label.py Normal file
View file

@ -0,0 +1,222 @@
"""Avocet -- label queue API.
All label/skip/discard/undo/stats/config endpoints.
Extracted from app/api.py as part of the v2 domain split.
"""
from __future__ import annotations
import hashlib
import json
import yaml
from datetime import datetime, timezone
from pathlib import Path
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
_ROOT = Path(__file__).parent.parent.parent
_DATA_DIR: Path = _ROOT / "data"
_CONFIG_DIR: Path | None = None
_last_action: dict | None = None
router = APIRouter()
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def reset_last_action() -> None:
global _last_action
_last_action = None
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _queue_file() -> Path:
return _DATA_DIR / "email_label_queue.jsonl"
def _score_file() -> Path:
return _DATA_DIR / "email_score.jsonl"
def _discarded_file() -> Path:
return _DATA_DIR / "discarded.jsonl"
def _item_id(item: dict) -> str:
key = (item.get("subject", "") + (item.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
def _normalize(item: dict) -> dict:
return {
"id": item.get("id") or _item_id(item),
"subject": item.get("subject", ""),
"body": item.get("body", ""),
"from": item.get("from") or item.get("from_addr", ""),
"date": item.get("date", ""),
"source": item.get("source") or item.get("account", ""),
}
_LABEL_META = [
{"name": "interview_scheduled", "emoji": "\U0001f4c5", "color": "#4CAF50", "key": "1"},
{"name": "offer_received", "emoji": "\U0001f389", "color": "#2196F3", "key": "2"},
{"name": "rejected", "emoji": "", "color": "#F44336", "key": "3"},
{"name": "positive_response", "emoji": "\U0001f44d", "color": "#FF9800", "key": "4"},
{"name": "survey_received", "emoji": "\U0001f4cb", "color": "#9C27B0", "key": "5"},
{"name": "neutral", "emoji": "", "color": "#607D8B", "key": "6"},
{"name": "event_rescheduled", "emoji": "\U0001f504", "color": "#FF5722", "key": "7"},
{"name": "digest", "emoji": "\U0001f4f0", "color": "#00BCD4", "key": "8"},
{"name": "new_lead", "emoji": "\U0001f91d", "color": "#009688", "key": "9"},
{"name": "hired", "emoji": "\U0001f38a", "color": "#FFC107", "key": "h"},
]
@router.get("/queue")
def get_queue(limit: int = Query(default=10, ge=1, le=50)):
items = read_jsonl(_queue_file())
return {"items": [_normalize(x) for x in items[:limit]], "total": len(items)}
class LabelRequest(BaseModel):
id: str
label: str
@router.post("/label")
def post_label(req: LabelRequest):
global _last_action
items = read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": req.label,
"labeled_at": datetime.now(timezone.utc).isoformat()}
append_jsonl(_score_file(), record)
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "label", "item": match, "label": req.label}
return {"ok": True}
class SkipRequest(BaseModel):
id: str
@router.post("/skip")
def post_skip(req: SkipRequest):
global _last_action
items = read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
reordered = [x for x in items if _normalize(x)["id"] != req.id] + [match]
write_jsonl(_queue_file(), reordered)
_last_action = {"type": "skip", "item": match}
return {"ok": True}
class DiscardRequest(BaseModel):
id: str
@router.post("/discard")
def post_discard(req: DiscardRequest):
global _last_action
items = read_jsonl(_queue_file())
match = next((x for x in items if _normalize(x)["id"] == req.id), None)
if not match:
raise HTTPException(404, f"Item {req.id!r} not found in queue")
record = {**match, "label": "__discarded__",
"discarded_at": datetime.now(timezone.utc).isoformat()}
append_jsonl(_discarded_file(), record)
write_jsonl(_queue_file(), [x for x in items if _normalize(x)["id"] != req.id])
_last_action = {"type": "discard", "item": match}
return {"ok": True}
@router.delete("/label/undo")
def delete_undo():
global _last_action
if not _last_action:
raise HTTPException(404, "No action to undo")
action = _last_action
item = action["item"]
if action["type"] == "label":
records = read_jsonl(_score_file())
if not records:
raise HTTPException(409, "Score file is empty -- cannot undo label")
write_jsonl(_score_file(), records[:-1])
items = read_jsonl(_queue_file())
write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "discard":
records = read_jsonl(_discarded_file())
if not records:
raise HTTPException(409, "Discarded file is empty -- cannot undo discard")
write_jsonl(_discarded_file(), records[:-1])
items = read_jsonl(_queue_file())
write_jsonl(_queue_file(), [item] + items)
elif action["type"] == "skip":
items = read_jsonl(_queue_file())
item_id = _normalize(item)["id"]
items = [item] + [x for x in items if _normalize(x)["id"] != item_id]
write_jsonl(_queue_file(), items)
_last_action = None
return {"undone": {"type": action["type"], "item": _normalize(item)}}
@router.get("/config/labels")
def get_labels():
return _LABEL_META
@router.get("/config")
def get_config():
f = _config_file()
if not f.exists():
return {"accounts": [], "max_per_account": 500}
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return {"accounts": raw.get("accounts", []), "max_per_account": raw.get("max_per_account", 500)}
class ConfigPayload(BaseModel):
accounts: list[dict]
max_per_account: int = 500
@router.post("/config")
def post_config(payload: ConfigPayload):
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(payload.model_dump(), allow_unicode=True, sort_keys=False),
encoding="utf-8")
tmp.rename(f)
return {"ok": True}
@router.get("/stats")
def get_stats():
records = read_jsonl(_score_file())
counts: dict[str, int] = {}
for r in records:
lbl = r.get("label", "")
if lbl:
counts[lbl] = counts.get(lbl, 0) + 1
benchmark_results: dict = {}
benchmark_path = _DATA_DIR / "benchmark_results.json"
if benchmark_path.exists():
try:
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
except Exception:
pass
return {
"total": len(records),
"counts": counts,
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
"benchmark_results": benchmark_results,
}
@router.get("/stats/download")
def download_stats():
if not _score_file().exists():
raise HTTPException(404, "No score file")
return FileResponse(
str(_score_file()),
filename="email_score.jsonl",
media_type="application/jsonlines",
headers={"Content-Disposition": 'attachment; filename="email_score.jsonl"'},
)

462
app/data/log_corpus.py Normal file
View file

@ -0,0 +1,462 @@
"""Avocet — Log Corpus receiver and labeling API.
Receives push batches from consented Turnstone nodes, stores entries for labeling,
and exports labeled data as JSONL for the logreading fine-tune pipeline.
DB: data/corpus.db (separate from train_jobs.db different lifecycle)
Auth: Bearer token validated against corpus_sources table (seeded from label_tool.yaml).
All endpoints registered on `router`. api.py includes this with prefix="/api/corpus".
"""
from __future__ import annotations
import json
import logging
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator
import yaml
from fastapi import APIRouter, Depends, HTTPException
from fastapi.requests import Request
from fastapi.responses import StreamingResponse
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
_DB_PATH: Path = _ROOT / "data" / "corpus.db"
_PIPELINE_SOURCE_HOST = "pipeline_scrape"
_SCHEMA = """
CREATE TABLE IF NOT EXISTS corpus_sources (
token TEXT PRIMARY KEY,
source_host TEXT NOT NULL,
owner TEXT NOT NULL,
consent_date TEXT NOT NULL,
consent_method TEXT NOT NULL,
active INTEGER NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS corpus_batches (
id TEXT PRIMARY KEY,
source_host TEXT NOT NULL,
batch_type TEXT NOT NULL,
received_at TEXT NOT NULL,
entry_count INTEGER NOT NULL,
watermark_from TEXT,
watermark_to TEXT,
raw_json TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS corpus_entries (
id TEXT PRIMARY KEY,
batch_id TEXT NOT NULL REFERENCES corpus_batches(id),
source_host TEXT NOT NULL,
origin_entry_id TEXT,
timestamp_iso TEXT,
severity TEXT,
source_id TEXT,
text TEXT NOT NULL,
matched_patterns TEXT DEFAULT '[]',
label_state TEXT NOT NULL DEFAULT 'unlabeled',
failure_type TEXT,
plain_explanation TEXT,
known_pattern TEXT,
labeled_at TEXT,
labeled_by TEXT DEFAULT 'alan',
pii_flagged INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_ce_label_state ON corpus_entries(label_state);
CREATE INDEX IF NOT EXISTS idx_ce_source ON corpus_entries(source_host);
CREATE INDEX IF NOT EXISTS idx_ce_severity ON corpus_entries(severity);
CREATE TABLE IF NOT EXISTS ingested_pipeline_files (
filename TEXT PRIMARY KEY,
ingested_at TEXT NOT NULL,
entry_count INTEGER NOT NULL
);
"""
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR, _DB_PATH
_DATA_DIR = path
_DB_PATH = path / "corpus.db"
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(str(_DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _init_db() -> None:
with _db() as conn:
conn.executescript(_SCHEMA)
_seed_sources(conn)
def _pipeline_ingest_dir() -> Path | None:
"""Return the configured pipeline log ingest directory, or None if unset."""
f = _config_file()
if not f.exists():
return None
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
return None
val = raw.get("corpus", {}).get("pipeline_ingest_dir", "") or ""
return Path(val) if val else None
def _load_corpus_config() -> list[dict]:
f = _config_file()
if not f.exists():
return []
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse corpus config: %s", exc)
return []
return raw.get("corpus", {}).get("sources", []) or []
def _seed_sources(conn: sqlite3.Connection) -> None:
for src in _load_corpus_config():
conn.execute(
"INSERT OR IGNORE INTO corpus_sources (token, source_host, owner, consent_date, consent_method) "
"VALUES (?, ?, ?, ?, ?)",
(src["token"], src["source_host"], src["owner"],
src["consent_date"], src["consent_method"]),
)
def _validate_token(token: str, conn: sqlite3.Connection) -> str:
"""Return source_host for token, or raise 403."""
row = conn.execute(
"SELECT source_host FROM corpus_sources WHERE token = ? AND active = 1",
(token,),
).fetchone()
if row is None:
raise HTTPException(status_code=403, detail="Unknown or revoked consent token")
return row["source_host"]
def _extract_bearer(request: Request) -> str:
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Bearer token required")
return auth.removeprefix("Bearer ").strip()
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
# ── Startup ────────────────────────────────────────────────────────────────────
_init_db()
# ── POST /api/corpus/log-batch ─────────────────────────────────────────────────
@router.post("/log-batch")
def receive_batch(request: Request, payload: dict) -> dict:
"""Accept a push batch from a Turnstone node."""
token = _extract_bearer(request)
batch_type = payload.get("batch_type", "raw_entries")
entries_raw = payload.get("entries", [])
batch_id = payload.get("batch_id") or str(uuid.uuid4())
with _db() as conn:
source_host = _validate_token(token, conn)
conn.execute(
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, "
"watermark_from, watermark_to, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(batch_id, source_host, batch_type, _now_iso(), len(entries_raw),
str(payload.get("watermark_from", "")),
str(payload.get("watermark_to", "")),
json.dumps(payload)),
)
stored = 0
for entry in entries_raw:
text = entry.get("text", "").strip()
if not text:
continue
conn.execute(
"INSERT OR IGNORE INTO corpus_entries "
"(id, batch_id, source_host, origin_entry_id, timestamp_iso, severity, "
"source_id, text, matched_patterns) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(str(uuid.uuid4()), batch_id, source_host,
entry.get("entry_id") or entry.get("id"),
entry.get("timestamp_iso"),
entry.get("severity"),
entry.get("source_id"),
text,
json.dumps(entry.get("matched_patterns", []))),
)
stored += 1
logger.info("Received batch %s from %s: %d/%d entries stored",
batch_id, source_host, stored, len(entries_raw))
return {"received": True, "batch_id": batch_id, "entries_stored": stored}
# ── GET /api/corpus/entries ────────────────────────────────────────────────────
@router.get("/entries")
def list_entries(
state: str = "unlabeled",
source_host: str | None = None,
limit: int = 25,
) -> dict:
"""Return entries for labeling. Default: unlabeled entries, oldest first."""
with _db() as conn:
query = "SELECT * FROM corpus_entries WHERE label_state = ?"
params: list = [state]
if source_host:
query += " AND source_host = ?"
params.append(source_host)
query += " ORDER BY rowid LIMIT ?"
params.append(min(limit, 100))
rows = conn.execute(query, params).fetchall()
return {"entries": [dict(r) for r in rows], "count": len(rows)}
# ── POST /api/corpus/entries/{id}/label ───────────────────────────────────────
@router.post("/entries/{entry_id}/label")
def label_entry(entry_id: str, body: dict) -> dict:
"""Submit a label for a corpus entry."""
failure_type = body.get("failure_type")
plain_explanation = body.get("plain_explanation", "").strip()
known_pattern = body.get("known_pattern")
pii_flagged = int(bool(body.get("pii_flagged", False)))
if not failure_type:
raise HTTPException(status_code=422, detail="failure_type is required")
valid_types = {"hardware", "software", "network", "security", "application", "none", "other"}
if failure_type not in valid_types:
raise HTTPException(status_code=422, detail=f"failure_type must be one of {sorted(valid_types)}")
with _db() as conn:
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
if row is None:
raise HTTPException(status_code=404, detail="Entry not found")
conn.execute(
"UPDATE corpus_entries SET label_state='labeled', failure_type=?, plain_explanation=?, "
"known_pattern=?, labeled_at=?, pii_flagged=? WHERE id=?",
(failure_type, plain_explanation, known_pattern, _now_iso(), pii_flagged, entry_id),
)
return {"labeled": True, "entry_id": entry_id}
# ── POST /api/corpus/entries/{id}/skip ────────────────────────────────────────
@router.post("/entries/{entry_id}/skip")
def skip_entry(entry_id: str) -> dict:
with _db() as conn:
row = conn.execute("SELECT id FROM corpus_entries WHERE id = ?", (entry_id,)).fetchone()
if row is None:
raise HTTPException(status_code=404, detail="Entry not found")
conn.execute(
"UPDATE corpus_entries SET label_state='skipped' WHERE id=?", (entry_id,)
)
return {"skipped": True, "entry_id": entry_id}
# ── GET /api/corpus/stats ──────────────────────────────────────────────────────
@router.get("/stats")
def get_stats() -> dict:
with _db() as conn:
total = conn.execute("SELECT COUNT(*) FROM corpus_entries").fetchone()[0]
by_state = {
r["label_state"]: r["cnt"]
for r in conn.execute(
"SELECT label_state, COUNT(*) AS cnt FROM corpus_entries GROUP BY label_state"
).fetchall()
}
by_source = {
r["source_host"]: r["cnt"]
for r in conn.execute(
"SELECT source_host, COUNT(*) AS cnt FROM corpus_entries GROUP BY source_host"
).fetchall()
}
by_severity = {
r["severity"]: r["cnt"]
for r in conn.execute(
"SELECT severity, COUNT(*) AS cnt FROM corpus_entries "
"WHERE severity IS NOT NULL GROUP BY severity"
).fetchall()
}
batch_count = conn.execute("SELECT COUNT(*) FROM corpus_batches").fetchone()[0]
return {
"total_entries": total,
"batch_count": batch_count,
"by_label_state": by_state,
"by_source": by_source,
"by_severity": by_severity,
}
# ── GET /api/corpus/export ────────────────────────────────────────────────────
@router.get("/export")
def export_labeled() -> StreamingResponse:
"""Stream labeled, non-PII entries as JSONL for SFT harness."""
with _db() as conn:
rows = conn.execute(
"SELECT source_host, source_id, severity, text, failure_type, plain_explanation, known_pattern "
"FROM corpus_entries "
"WHERE label_state = 'labeled' AND pii_flagged = 0 AND plain_explanation != ''"
"ORDER BY rowid"
).fetchall()
def _generate():
for row in rows:
record = {
"input": row["text"],
"output": row["plain_explanation"],
"metadata": {
"failure_type": row["failure_type"],
"source": row["source_host"],
"source_id": row["source_id"],
"severity": row["severity"],
"known_pattern": row["known_pattern"],
},
}
yield json.dumps(record) + "\n"
return StreamingResponse(
_generate(),
media_type="application/x-ndjson",
headers={"Content-Disposition": "attachment; filename=log_corpus_labeled.jsonl"},
)
# ── POST /api/corpus/pipeline-ingest ─────────────────────────────────────────
def _ingest_one_file(conn: sqlite3.Connection, path: Path) -> int:
"""Parse a pipeline JSONL file and insert entries. Returns count stored."""
batch_id = str(uuid.uuid4())
lines = path.read_text(encoding="utf-8").splitlines()
entries_raw: list[dict] = []
for line in lines:
line = line.strip()
if not line:
continue
try:
entries_raw.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed line in %s", path.name)
conn.execute(
"INSERT INTO corpus_batches (id, source_host, batch_type, received_at, entry_count, raw_json) "
"VALUES (?, ?, ?, ?, ?, ?)",
(batch_id, _PIPELINE_SOURCE_HOST, "pipeline_log", _now_iso(),
len(entries_raw), json.dumps({"file": path.name})),
)
stored = 0
for entry in entries_raw:
text = (entry.get("msg") or "").strip()
if not text:
continue
conn.execute(
"INSERT OR IGNORE INTO corpus_entries "
"(id, batch_id, source_host, timestamp_iso, severity, source_id, text, matched_patterns) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(str(uuid.uuid4()), batch_id, _PIPELINE_SOURCE_HOST,
entry.get("ts"),
entry.get("level"),
entry.get("logger"),
text,
json.dumps([entry["extra"]] if entry.get("extra") else [])),
)
stored += 1
conn.execute(
"INSERT INTO ingested_pipeline_files (filename, ingested_at, entry_count) VALUES (?, ?, ?)",
(path.name, _now_iso(), stored),
)
return stored
@router.post("/pipeline-ingest")
def pipeline_ingest() -> dict:
"""Walk the configured pipeline log directory and ingest new JSONL files.
Skips files already recorded in ingested_pipeline_files. Safe to call
repeatedly idempotent by filename.
"""
ingest_dir = _pipeline_ingest_dir()
if ingest_dir is None:
raise HTTPException(404, "pipeline_ingest_dir not configured in label_tool.yaml")
ingested = 0
skipped = 0
total_stored = 0
files_detail: list[dict] = []
with _db() as conn:
already_done: set[str] = {
row[0]
for row in conn.execute("SELECT filename FROM ingested_pipeline_files").fetchall()
}
for path in sorted(ingest_dir.glob("*.jsonl")):
if path.name in already_done:
skipped += 1
continue
stored = _ingest_one_file(conn, path)
ingested += 1
total_stored += stored
files_detail.append({"file": path.name, "entries_stored": stored})
logger.info("Pipeline ingest: %d files ingested, %d skipped, %d entries stored",
ingested, skipped, total_stored)
return {
"ingested_files": ingested,
"skipped_files": skipped,
"entries_stored": total_stored,
"files": files_detail,
}

313
app/data/recipe_scan.py Normal file
View file

@ -0,0 +1,313 @@
"""Avocet — Recipe scan labeling API (avocet#65).
Receives recipe scan items from the Kiwi pipeline (scanner/phone image +
docuvision OCR extraction + ground-truth structured recipe), presents them
for human review, and exports approved/edited pairs in the messages chat
format for the vision fine-tune harness.
DB: data/recipe_scan.db (separate from corpus.db different lifecycle)
No auth required local admin tool, not a push endpoint.
All endpoints registered on `router`. api.py includes this with
prefix="/api/recipe-scan".
"""
from __future__ import annotations
import json
import logging
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Literal
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_DB_PATH: Path = _ROOT / "data" / "recipe_scan.db"
_VALID_MODALITIES = {"scanner", "phone", "handwritten"}
_VALID_STATUSES = {"pending", "approved", "edited", "rejected"}
_SCHEMA = """
CREATE TABLE IF NOT EXISTS recipe_scan_items (
id TEXT PRIMARY KEY,
image_path TEXT NOT NULL,
modality TEXT NOT NULL DEFAULT 'scanner',
source TEXT NOT NULL DEFAULT 'purple_carrot',
extracted TEXT NOT NULL,
ground_truth TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
corrected TEXT,
labeled_at TEXT,
rejected_reason TEXT
);
CREATE INDEX IF NOT EXISTS idx_rsi_status ON recipe_scan_items(status);
CREATE INDEX IF NOT EXISTS idx_rsi_modality ON recipe_scan_items(modality);
"""
router = APIRouter()
# ── Testability seam ──────────────────────────────────────────────────────────
def set_db_path(path: Path) -> None:
global _DB_PATH
_DB_PATH = path
# ── Internal helpers ──────────────────────────────────────────────────────────
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(str(_DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _init_db() -> None:
with _db() as conn:
conn.executescript(_SCHEMA)
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _build_training_pair(row: sqlite3.Row) -> dict:
"""Build a messages-format training pair from a labeled row.
user message: correction prompt + the docuvision-extracted JSON draft.
Trains the model to review and correct an existing extraction, which is
more data-efficient than producing from scratch when OCR is usually close.
assistant message: the approved ground truth (or human-corrected JSON).
"""
target_str = row["corrected"] if row["corrected"] else row["ground_truth"]
extracted = json.loads(row["extracted"])
target = json.loads(target_str)
user_content = (
"Review and correct this recipe extraction. "
"Return valid JSON with fields: title, description, ingredients, steps, "
"prep_time, cook_time, servings.\n\n"
f"Extraction to review:\n{json.dumps(extracted, ensure_ascii=False, indent=2)}"
)
return {
"id": row["id"],
"modality": row["modality"],
"source": row["source"],
"image_path": row["image_path"],
"messages": [
{"role": "user", "content": user_content},
{"role": "assistant", "content": json.dumps(target, ensure_ascii=False)},
],
}
_init_db()
# ── POST /import ───────────────────────────────────────────────────────────────
class ImportItem(BaseModel):
id: str = ""
image_path: str
modality: Literal["scanner", "phone", "handwritten"] = "scanner"
source: str = "purple_carrot"
extracted: dict
ground_truth: dict
@field_validator("id", mode="before")
@classmethod
def default_id(cls, v: str) -> str:
return v or str(uuid.uuid4())
class ImportRequest(BaseModel):
items: list[ImportItem]
@router.post("/import")
def import_items(body: ImportRequest) -> dict:
"""Bulk-import scan items from the Kiwi pipeline. Idempotent by item id."""
stored = 0
with _db() as conn:
for item in body.items:
result = conn.execute(
"INSERT OR IGNORE INTO recipe_scan_items "
"(id, image_path, modality, source, extracted, ground_truth) "
"VALUES (?, ?, ?, ?, ?, ?)",
(item.id, item.image_path, item.modality, item.source,
json.dumps(item.extracted), json.dumps(item.ground_truth)),
)
stored += result.rowcount
return {"imported": stored, "total_submitted": len(body.items)}
# ── GET /next ─────────────────────────────────────────────────────────────────
@router.get("/next")
def get_next() -> dict:
"""Return the next pending item for review, oldest-first."""
with _db() as conn:
row = conn.execute(
"SELECT * FROM recipe_scan_items WHERE status = 'pending' ORDER BY rowid LIMIT 1"
).fetchone()
if row is None:
raise HTTPException(404, "No pending items in queue")
return {
**dict(row),
"extracted": json.loads(row["extracted"]),
"ground_truth": json.loads(row["ground_truth"]),
}
# ── POST /items/{id}/approve ──────────────────────────────────────────────────
@router.post("/items/{item_id}/approve")
def approve_item(item_id: str) -> dict:
"""Mark item as approved — extracted JSON is close enough to ground truth."""
with _db() as conn:
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
if row is None:
raise HTTPException(404, "Item not found")
conn.execute(
"UPDATE recipe_scan_items SET status='approved', labeled_at=? WHERE id=?",
(_now_iso(), item_id),
)
return {"status": "approved", "id": item_id}
# ── POST /items/{id}/edit ─────────────────────────────────────────────────────
class EditBody(BaseModel):
corrected: dict
@router.post("/items/{item_id}/edit")
def edit_item(item_id: str, body: EditBody) -> dict:
"""Approve with a human-corrected JSON. corrected overrides extracted in export."""
with _db() as conn:
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
if row is None:
raise HTTPException(404, "Item not found")
conn.execute(
"UPDATE recipe_scan_items SET status='edited', corrected=?, labeled_at=? WHERE id=?",
(json.dumps(body.corrected), _now_iso(), item_id),
)
return {"status": "edited", "id": item_id}
# ── POST /items/{id}/reject ───────────────────────────────────────────────────
class RejectBody(BaseModel):
reason: str = ""
@router.post("/items/{item_id}/reject")
def reject_item(item_id: str, body: RejectBody = RejectBody()) -> dict:
"""Reject item — extraction too broken to use for training."""
with _db() as conn:
row = conn.execute("SELECT id FROM recipe_scan_items WHERE id = ?", (item_id,)).fetchone()
if row is None:
raise HTTPException(404, "Item not found")
conn.execute(
"UPDATE recipe_scan_items SET status='rejected', rejected_reason=?, labeled_at=? WHERE id=?",
(body.reason or None, _now_iso(), item_id),
)
return {"status": "rejected", "id": item_id}
# ── GET /stats ────────────────────────────────────────────────────────────────
@router.get("/stats")
def get_stats() -> dict:
with _db() as conn:
total = conn.execute("SELECT COUNT(*) FROM recipe_scan_items").fetchone()[0]
by_status = {
r["status"]: r["cnt"]
for r in conn.execute(
"SELECT status, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY status"
).fetchall()
}
by_modality = {
r["modality"]: r["cnt"]
for r in conn.execute(
"SELECT modality, COUNT(*) AS cnt FROM recipe_scan_items GROUP BY modality"
).fetchall()
}
export_ready = conn.execute(
"SELECT COUNT(*) FROM recipe_scan_items WHERE status IN ('approved', 'edited')"
).fetchone()[0]
return {
"total": total,
"by_status": by_status,
"by_modality": by_modality,
"export_ready": export_ready,
}
# ── GET /export ───────────────────────────────────────────────────────────────
@router.get("/export")
def export_pairs() -> StreamingResponse:
"""Stream approved/edited items as JSONL training pairs (messages format)."""
with _db() as conn:
rows = conn.execute(
"SELECT * FROM recipe_scan_items WHERE status IN ('approved', 'edited') ORDER BY rowid"
).fetchall()
def _generate():
for row in rows:
yield json.dumps(_build_training_pair(row), ensure_ascii=False) + "\n"
return StreamingResponse(
_generate(),
media_type="application/x-ndjson",
headers={"Content-Disposition": "attachment; filename=recipe_scan_pairs.jsonl"},
)
# ── GET /image ────────────────────────────────────────────────────────────────
_IMAGE_ROOT = Path("/Library/Assets/kiwi")
@router.get("/image")
def serve_image(path: str) -> StreamingResponse:
"""Serve a scan image from /Library/Assets/kiwi/.
path must resolve within /Library/Assets/kiwi/ rejects traversal attempts.
"""
try:
resolved = Path(path).resolve()
_IMAGE_ROOT.resolve() # ensure root itself is valid
resolved.relative_to(_IMAGE_ROOT.resolve())
except (ValueError, OSError):
raise HTTPException(403, "Path outside allowed image directory")
if not resolved.exists():
raise HTTPException(404, "Image not found")
suffix = resolved.suffix.lower()
media_types = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".webp": "image/webp"}
media_type = media_types.get(suffix, "application/octet-stream")
return StreamingResponse(
open(resolved, "rb"),
media_type=media_type,
headers={"Cache-Control": "public, max-age=86400"},
)

0
app/eval/__init__.py Normal file
View file

44
app/eval/cforch.py Normal file
View file

@ -0,0 +1,44 @@
"""Avocet -- eval router aggregator.
Collects benchmark sub-routers into a single importable `router`
for the api.py factory. Each sub-router retains its established prefix
so no frontend URL changes are needed.
Route prefixes when mounted at /api in api.py:
/api/cforch/* -- cf-orch benchmark routes
/api/style/* -- writing style benchmark routes
/api/voice/* -- voice benchmark routes
/api/plans-bench/* -- plans benchmark routes
"""
from __future__ import annotations
from pathlib import Path
from fastapi import APIRouter
from app.cforch import router as _cforch_router
from app.style import router as _style_router
from app.voice import router as _voice_router
from app.plans_bench import router as _plans_router
from app.eval.embed_bench import router as _embed_router
router = APIRouter()
router.include_router(_cforch_router, prefix="/cforch")
router.include_router(_style_router, prefix="/style")
router.include_router(_voice_router, prefix="/voice")
router.include_router(_plans_router, prefix="/plans-bench")
router.include_router(_embed_router, prefix="/embed-bench")
def set_config_dir(path: Path | None) -> None:
"""Propagate config dir override to all sub-modules -- used by tests."""
import app.cforch as _cforch_mod
import app.style as _style_mod
import app.voice as _voice_mod
import app.plans_bench as _plans_mod
import app.eval.embed_bench as _embed_mod
_cforch_mod.set_config_dir(path)
_style_mod.set_config_dir(path)
_voice_mod.set_config_dir(path)
_plans_mod.set_config_dir(path)
_embed_mod.set_config_dir(path)

293
app/eval/embed_bench.py Normal file
View file

@ -0,0 +1,293 @@
"""Avocet — embedding model comparison harness.
Exposes FastAPI routes under /api/embed-bench (mounted via app/eval/cforch.py).
All computation is local: no LLM inference, Ollama only. MIT tier throughout.
"""
from __future__ import annotations
import csv
import io
import json
import logging
import math
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import httpx
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_CONFIG_DIR: Path | None = None # override via set_config_dir() in tests
_RUN_ACTIVE: bool = False
_RATINGS_FILE = _ROOT / "data" / "embed_bench_ratings.jsonl"
router = APIRouter()
# ── Testability seam ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# ── Internal helpers ──────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_config() -> dict[str, Any]:
f = _config_file()
if not f.exists():
return {}
try:
return yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse embed_bench config %s: %s", f, exc)
return {}
def _ollama_url() -> str:
cfg = _load_config()
embed_cfg = cfg.get("embed_bench", {}) or {}
cforch_cfg = cfg.get("cforch", {}) or {}
return (
embed_cfg.get("ollama_url")
or cforch_cfg.get("ollama_url", "http://localhost:11434")
)
def _ratings_path() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "embed_bench_ratings.jsonl"
return _RATINGS_FILE
def _cosine(a: list[float], b: list[float]) -> float:
if len(a) != len(b):
raise ValueError(
f"Embedding dimension mismatch: {len(a)} vs {len(b)}"
)
dot = sum(x * y for x, y in zip(a, b))
mag_a = math.sqrt(sum(x * x for x in a))
mag_b = math.sqrt(sum(x * x for x in b))
if mag_a == 0.0 or mag_b == 0.0:
return 0.0
return dot / (mag_a * mag_b)
# ── GET /models ───────────────────────────────────────────────────────────────
@router.get("/models")
def get_models() -> dict:
"""Return Ollama embedding models available on the configured instance."""
ollama = _ollama_url()
models: list[dict] = []
try:
resp = httpx.get(f"{ollama}/api/tags", timeout=5.0)
resp.raise_for_status()
for entry in resp.json().get("models", []):
models.append({
"name": entry.get("name", ""),
"size": entry.get("size", 0),
})
except httpx.HTTPStatusError as exc:
logger.warning("Ollama /api/tags returned HTTP %s: %s", exc.response.status_code, exc)
except httpx.RequestError as exc:
logger.warning("Failed to reach Ollama for model list: %s", exc)
return {"models": models, "ollama_url": ollama}
# ── POST /run ─────────────────────────────────────────────────────────────────
class RunRequest(BaseModel):
corpus: list[str]
queries: list[str]
models: list[str]
top_k: int = 5
ollama_url: str = ""
@field_validator("corpus")
@classmethod
def corpus_nonempty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("corpus must not be empty")
return v
@field_validator("queries")
@classmethod
def queries_nonempty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("queries must not be empty")
return v
@field_validator("models")
@classmethod
def models_nonempty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("models must contain at least one model name")
return v
def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]:
"""Batch-embed texts via Ollama /v1/embeddings. Returns one vector per text."""
resp = httpx.post(
f"{ollama}/v1/embeddings",
json={"model": model, "input": texts},
timeout=120.0,
)
resp.raise_for_status()
data = resp.json().get("data", [])
return [item["embedding"] for item in data]
def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n"
@router.post("/run")
def run_embed_bench(req: RunRequest) -> StreamingResponse:
"""Embed corpus + queries with each model; stream SSE results."""
global _RUN_ACTIVE
if _RUN_ACTIVE:
raise HTTPException(409, "An embedding benchmark run is already active")
ollama = req.ollama_url or _ollama_url()
def _generate():
global _RUN_ACTIVE
_RUN_ACTIVE = True
try:
for model_idx, model in enumerate(req.models, start=1):
yield _sse({
"type": "progress",
"msg": f"Indexing corpus with {model} ({model_idx}/{len(req.models)})...",
})
try:
corpus_vecs = _embed_texts(ollama, model, req.corpus)
except Exception as exc:
yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"})
continue
yield _sse({
"type": "progress",
"msg": f"Running queries with {model}...",
})
for q_idx, query in enumerate(req.queries):
try:
q_vecs = _embed_texts(ollama, model, [query])
except Exception as exc:
yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"})
continue
q_vec = q_vecs[0]
scored = sorted(
[
{"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)}
for i, (chunk, cv) in enumerate(zip(req.corpus, corpus_vecs))
],
key=lambda h: h["score"],
reverse=True,
)[: req.top_k]
yield _sse({
"type": "result",
"query_idx": q_idx,
"query": query,
"model": model,
"hits": scored,
})
yield _sse({"type": "done"})
finally:
_RUN_ACTIVE = False
return StreamingResponse(_generate(), media_type="text/event-stream")
# ── POST /rate ────────────────────────────────────────────────────────────────
_VALID_RATINGS = {"relevant", "not_relevant"}
class RatingRequest(BaseModel):
query: str
model: str
chunk_text: str
chunk_idx: int
rating: str
@field_validator("rating")
@classmethod
def rating_valid(cls, v: str) -> str:
if v not in _VALID_RATINGS:
raise ValueError(f"rating must be one of {_VALID_RATINGS}")
return v
@router.post("/rate")
def rate_result(req: RatingRequest) -> dict:
"""Append one rating to the JSONL ratings file."""
entry = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"query": req.query,
"model": req.model,
"chunk_idx": req.chunk_idx,
"chunk_text": req.chunk_text,
"rating": req.rating,
}
path = _ratings_path()
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as fh:
fh.write(json.dumps(entry) + "\n")
return {"ok": True}
# ── GET /export ───────────────────────────────────────────────────────────────
_CSV_FIELDS = ["timestamp", "query", "model", "chunk_idx", "chunk_text", "rating"]
@router.get("/export")
def export_ratings(format: str = "csv") -> Any:
"""Download ratings as CSV or JSON."""
path = _ratings_path()
rows: list[dict] = []
if path.exists():
for raw in path.read_text(encoding="utf-8").splitlines():
raw = raw.strip()
if raw:
try:
rows.append(json.loads(raw))
except json.JSONDecodeError:
pass
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
if format == "json":
content = json.dumps(rows, ensure_ascii=False, indent=2)
return StreamingResponse(
iter([content]),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
)
# Default: CSV
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
writer.writeheader()
writer.writerows(rows)
return StreamingResponse(
iter([buf.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
)

View file

@ -1,158 +1,9 @@
"""Avocet — IMAP fetch utilities.
Shared between app/api.py (FastAPI SSE endpoint) and the label UI.
No Streamlit imports here stdlib + imaplib only.
"""
from __future__ import annotations
import email as _email_lib
import hashlib
import imaplib
from datetime import datetime, timedelta
from email.header import decode_header as _raw_decode
from typing import Any, Iterator
from app.utils import extract_body, strip_html # noqa: F401 (strip_html re-exported for callers)
# ── IMAP decode helpers ───────────────────────────────────────────────────────
def _decode_str(value: str | None) -> str:
if not value:
return ""
parts = _raw_decode(value)
out = []
for part, enc in parts:
if isinstance(part, bytes):
out.append(part.decode(enc or "utf-8", errors="replace"))
else:
out.append(str(part))
return " ".join(out).strip()
def entry_key(e: dict) -> str:
"""Stable MD5 content-hash for dedup — matches label_tool.py _entry_key."""
key = (e.get("subject", "") + (e.get("body", "") or "")[:100])
return hashlib.md5(key.encode("utf-8", errors="replace")).hexdigest()
# ── Wide search terms ────────────────────────────────────────────────────────
_WIDE_TERMS = [
"interview", "phone screen", "video call", "zoom link", "schedule a call",
"offer letter", "job offer", "offer of employment", "pleased to offer",
"unfortunately", "not moving forward", "other candidates", "regret to inform",
"no longer", "decided not to", "decided to go with",
"opportunity", "interested in your background", "reached out", "great fit",
"exciting role", "love to connect",
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
"application received", "thank you for applying", "application confirmation",
"you applied", "your application for",
"reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
"job digest", "jobs you may like", "recommended jobs", "jobs for you",
"new jobs", "job alert",
"came across your profile", "reaching out about", "great fit for a role",
"exciting opportunity",
"welcome to the team", "start date", "onboarding", "first day", "we're excited to have you",
"application", "recruiter", "recruiting", "hiring", "candidate",
]
# ── Public API ────────────────────────────────────────────────────────────────
def test_connection(acc: dict) -> tuple[bool, str, int | None]:
"""Connect, login, select folder. Returns (ok, human_message, message_count|None)."""
host = acc.get("host", "")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc.get("username", "")
password = acc.get("password", "")
folder = acc.get("folder", "INBOX")
if not host or not username or not password:
return False, "Host, username, and password are all required.", None
try:
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
_, data = conn.select(folder, readonly=True)
count_raw = data[0].decode() if data and data[0] else "0"
count = int(count_raw) if count_raw.isdigit() else 0
conn.logout()
return True, f"Connected — {count:,} message(s) in {folder}.", count
except Exception as exc:
return False, str(exc), None
def fetch_account_stream(
acc: dict,
days_back: int,
limit: int,
known_keys: set[str],
) -> Iterator[dict]:
"""Generator — yields progress dicts while fetching emails via IMAP.
Mutates `known_keys` in place for cross-account dedup within one fetch session.
Yields event dicts with "type" key:
{"type": "start", "account": str, "total_uids": int}
{"type": "progress", "account": str, "fetched": int, "total_uids": int}
{"type": "done", "account": str, "added": int, "skipped": int, "emails": list}
"""
name = acc.get("name", acc.get("username", "?"))
host = acc.get("host", "imap.gmail.com")
port = int(acc.get("port", 993))
use_ssl = acc.get("use_ssl", True)
username = acc["username"]
password = acc["password"]
folder = acc.get("folder", "INBOX")
since = (datetime.now() - timedelta(days=days_back)).strftime("%d-%b-%Y")
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
conn.select(folder, readonly=True)
seen_uids: dict[bytes, None] = {}
for term in _WIDE_TERMS:
try:
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
for uid in (data[0] or b"").split():
seen_uids[uid] = None
except Exception:
pass
uids = list(seen_uids.keys())[: limit * 3]
yield {"type": "start", "account": name, "total_uids": len(uids)}
emails: list[dict] = []
skipped = 0
for i, uid in enumerate(uids):
if len(emails) >= limit:
break
if i % 5 == 0:
yield {"type": "progress", "account": name, "fetched": len(emails), "total_uids": len(uids)}
try:
_, raw_data = conn.fetch(uid, "(RFC822)")
if not raw_data or not raw_data[0]:
continue
msg = _email_lib.message_from_bytes(raw_data[0][1])
subj = _decode_str(msg.get("Subject", ""))
from_addr = _decode_str(msg.get("From", ""))
date = _decode_str(msg.get("Date", ""))
body = extract_body(msg)[:800]
entry = {"subject": subj, "body": body, "from_addr": from_addr,
"date": date, "account": name}
k = entry_key(entry)
if k not in known_keys:
known_keys.add(k)
emails.append(entry)
else:
skipped += 1
except Exception:
skipped += 1
try:
conn.logout()
except Exception:
pass
yield {"type": "done", "account": name, "added": len(emails), "skipped": skipped,
"emails": emails}
"""Backward-compat shim -- logic moved to app/data/fetch.py."""
import imaplib # noqa: F401 -- re-exported so existing patch("app.imap_fetch.imaplib...") calls still work
from app.data.fetch import ( # noqa: F401
entry_key,
fetch_account_stream,
test_connection,
_decode_str,
_WIDE_TERMS,
)

View file

@ -1,624 +1,3 @@
"""Avocet — Imitate tab API.
Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs (ollama), and streams responses back to the UI. Results can be
pushed into the SFT corrections queue for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_imitate_config() -> dict:
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse imitate config %s: %s", f, exc)
return {}
return raw.get("imitate", {}) or {}
def _load_cforch_config() -> dict:
"""Read cforch section for ollama_url fallback."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
return {}
return raw.get("cforch", {}) or {}
def _ollama_url(cfg: dict) -> str:
cforch = _load_cforch_config()
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
def _cforch_url() -> str:
cforch = _load_cforch_config()
return cforch.get("coordinator_url") or "http://localhost:7700"
def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch.
Filters out proxy entries (ollama://, vllm://, http://) those models are
served by their own services and should not be allocated via cf-text.
Returns only models with real file-system paths that cf-text can load directly.
"""
try:
resp = httpx.get(
f"{cforch_base}/api/services/cf-text/catalog",
params={"node_id": "heimdall"},
timeout=5.0,
)
resp.raise_for_status()
raw = resp.json()
result = []
for model_id, entry in raw.items():
if not isinstance(entry, dict):
continue
path = entry.get("path", "")
# Skip proxy entries — they're routed through other services
if "://" in path:
continue
result.append({
"id": model_id,
"vram_mb": entry.get("vram_mb", 0),
"description": entry.get("description", ""),
})
return result
except Exception as exc:
logger.warning("Could not fetch cf-orch catalog: %s", exc)
return []
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
"""Return True if the product's health endpoint responds OK."""
try:
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
return bool(data)
except Exception:
return False
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
"""Download an image URL and return it as a base64 string for ollama.
Returns empty string on any failure a missing image is non-fatal;
the model will still run against the text prompt alone.
"""
try:
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
with urlopen(req, timeout=10) as resp:
return base64.b64encode(resp.read()).decode("ascii")
except Exception as exc:
logger.warning("Failed to fetch image %s: %s", image_url, exc)
return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _run_cftext(
cforch_base: str,
model_id: str,
prompt: str,
system: str,
temperature: float,
startup_timeout_s: float = 180.0,
) -> tuple[str, int, bool]:
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
Raises RuntimeError on allocation failure or generation error.
cold_started=True means the service was launched from scratch (caller may log this).
Cold-start detection uses coordinator state signals (running/stopped) rather than
polling the service health endpoint this fails fast on model load errors instead
of waiting out the full timeout.
"""
# Allocate
alloc_resp = httpx.post(
f"{cforch_base}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "imitate",
},
timeout=30.0,
)
alloc_resp.raise_for_status()
data = alloc_resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
cold_started = data.get("started", False) and not data.get("warm", True)
# Wait for ready using coordinator state signals
if cold_started:
deadline = time.monotonic() + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
try:
status = httpx.get(
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
break
elif state == "stopped":
if allocation_id:
httpx.delete(
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
timeout=5.0,
)
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
else:
probe_misses += 1
if probe_misses >= 6:
# Coordinator hasn't registered instance yet — fall back to health poll
try:
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
break
except Exception:
pass
except RuntimeError:
raise
except Exception:
pass
time.sleep(2.0)
else:
if allocation_id:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
# Generate
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
t0 = time.time()
try:
gen_resp = httpx.post(
f"{service_url}/v1/chat/completions",
json={
"model": model_id,
"messages": messages,
"max_tokens": 300,
"temperature": temperature,
"stream": False,
},
timeout=120.0,
)
gen_resp.raise_for_status()
elapsed_ms = int((time.time() - t0) * 1000)
content = gen_resp.json()["choices"][0]["message"]["content"]
return content.strip(), elapsed_ms, cold_started
except Exception as exc:
elapsed_ms = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
finally:
if allocation_id:
try:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
except Exception:
pass
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
"""List configured CF products with live online status."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
products = []
for p in products_raw:
if not isinstance(p, dict):
continue
base_url = p.get("base_url", "")
products.append({
"id": p.get("id", ""),
"name": p.get("name", ""),
"icon": p.get("icon", "📦"),
"description": p.get("description", ""),
"base_url": base_url,
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
})
return {"products": products}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
"""Fetch a real sample from the given product's API."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
product: dict | None = None
for p in products_raw:
if isinstance(p, dict) and p.get("id") == product_id:
product = p
break
if product is None:
raise HTTPException(404, f"Product '{product_id}' not in config")
base_url = product.get("base_url", "").rstrip("/")
endpoint = product.get("sample_endpoint", "")
if not base_url or not endpoint:
raise HTTPException(422, "Product missing base_url or sample_endpoint")
url = f"{base_url}{endpoint}"
try:
raw = _http_get_json(url, timeout=5)
except URLError as exc:
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
except Exception as exc:
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
text_fields = product.get("text_fields", []) or []
sample_key = product.get("sample_key") or None
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
if not extracted:
raise HTTPException(404, "No sample items returned by product API")
prompt_template = product.get("prompt_template", "{text}")
prompt = prompt_template.replace("{text}", extracted["text"])
# Also substitute any {field_name} placeholders from the raw item fields.
item = extracted.get("item", {})
for field, val in item.items():
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
# Expose system_prompt and image_url if the product API returns them.
# system_prompt: Peregrine, Snipe (vision analysis instructions)
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
item = extracted.get("item", {})
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
return {
"product_id": product_id,
"sample_index": index,
"text": extracted["text"],
"prompt": prompt,
"system_prompt": system_prompt,
"image_url": image_url,
"raw_item": item,
}
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
"""Return the live cf-text model catalog from cf-orch coordinator."""
models = _cforch_catalog(_cforch_url())
return {"models": models}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
@router.get("/run")
def run_imitate(
prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
temperature: float = 0.7,
product_id: str = "",
system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models
) -> StreamingResponse:
"""Run a prompt through selected ollama models and stream results as SSE.
If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to
evaluate listing photos the same way Snipe's background task pipeline does.
"""
if not prompt.strip():
raise HTTPException(422, "prompt is required")
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg)
cforch_base = _cforch_url()
system_ctx = system.strip() or ""
total_models = len(ollama_ids) + len(cftext_ids)
# Download image once before streaming — shared across ollama vision models
images: list[str] = []
if image_url.strip():
b64 = _fetch_image_b64(image_url.strip())
if b64:
images = [b64]
def generate():
results: list[dict] = []
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
# Ollama models
for model_id in ollama_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
try:
response, elapsed_ms = _run_ollama_streaming(
ollama_base, model_id, prompt, temperature,
system=system_ctx, images=images or None,
)
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected
if cftext_ids:
from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
future_to_model = {
pool.submit(_run_cftext, cforch_base, mid, prompt, system_ctx, temperature): mid
for mid in cftext_ids
}
for future in as_completed(future_to_model):
model_id = future_to_model[future]
try:
response, elapsed_ms, cold_started = future.result()
if cold_started:
yield _sse({"type": "model_coldstart", "model": model_id})
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
yield _sse({"type": "complete", "results": results})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
model: str
response: str
elapsed_ms: int
error: str | None = None
class PushCorrectionsRequest(BaseModel):
product_id: str
prompt: str
results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
"""Append imitate results to sft_candidates.jsonl for human review."""
if not req.prompt.strip():
raise HTTPException(422, "prompt is required")
if not req.results:
raise HTTPException(422, "results list is empty")
ts = datetime.now(timezone.utc).isoformat()
records = []
for r in req.results:
if r.error or not r.response.strip():
continue
records.append({
"id": str(uuid.uuid4()),
"source": "imitate",
"product_id": req.product_id,
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": r.response,
"model_id": r.model,
"elapsed_ms": r.elapsed_ms,
"status": "pending",
"created_at": ts,
})
if not records:
raise HTTPException(422, "No non-error results to push")
dest = _candidates_file()
dest.parent.mkdir(parents=True, exist_ok=True)
for record in records:
append_jsonl(dest, record)
return {"pushed": len(records)}
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
from app.data.imitate import router # noqa: F401
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401

View file

@ -15,6 +15,7 @@ from __future__ import annotations
import json
import logging
import os
import re
import shutil
import threading
from datetime import datetime, timezone
@ -37,13 +38,17 @@ except ImportError: # pragma: no cover
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_MODELS_DIR: Path = _ROOT / "models"
_MODELS_DIR: Path = Path(
os.environ.get("AVOCET_MODELS_DIR", str(_ROOT / "models"))
)
_QUEUE_DIR: Path = _ROOT / "data"
# Service-specific model destinations.
# cf-text models land on the NFS-mounted shared asset store so every cluster
# node can reach them without a separate download. Avocet classifiers stay local
# because they are fine-tuned in-place and are only consumed by avocet itself.
# node can reach them without a separate download. Avocet classifiers default
# to a local path but can be redirected via AVOCET_MODELS_DIR — set this to
# /Library/Assets/LLM/avocet/models on NFS-connected nodes to keep all model
# weights out of the repo directory.
# Override via CF_TEXT_MODELS_DIR env var (useful for dev / non-NFS setups).
_CF_TEXT_MODELS_DIR: Path = Path(
os.environ.get("CF_TEXT_MODELS_DIR", "/Library/Assets/LLM/cf-text/models")
@ -60,6 +65,30 @@ _CF_ORCH_PROFILES_DIR: Path = Path(
router = APIRouter()
# ── HuggingFace auth ─────────────────────────────────────────────────────────
def _get_hf_token() -> str | None:
"""Return HF token from label_tool.yaml, then HF_TOKEN / HUGGING_FACE_HUB_TOKEN env vars."""
config_file = _ROOT / "config" / "label_tool.yaml"
if config_file.exists():
try:
import yaml as _yaml
raw = _yaml.safe_load(config_file.read_text(encoding="utf-8")) or {}
token = (raw.get("hf_token") or raw.get("cforch", {}).get("hf_token") or "").strip()
if token:
return token
except Exception:
pass
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None
# ── GGUF quantization detection ───────────────────────────────────────────────
# Matches quant identifiers in GGUF filenames: Q4_K_M, Q8_0, F16, IQ3_M, etc.
_QUANT_RE = re.compile(
r'[._-]((?:IQ\d|Q\d)[A-Z0-9_]*|F16|BF16)\.gguf$',
re.IGNORECASE,
)
# ── Download progress shared state ────────────────────────────────────────────
# Updated by the background download thread; read by GET /download/stream.
_download_progress: dict[str, Any] = {}
@ -91,12 +120,16 @@ _TAG_TO_INFO: dict[str, _TagInfo] = {
"audio-classification": {"adapter": None, "role": "classifier", "service": "cf-voice"},
# TTS — cf-tts text-to-speech service
"text-to-speech": {"adapter": None, "role": "tts", "service": "cf-tts"},
# Vision — cf-vision image classification / embedding / VLM service
# Vision classifiers / embedders — cf-vision (SigLIP/CLIP-style models)
"image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
"zero-shot-image-classification": {"adapter": None, "role": "vision", "service": "cf-vision"},
"image-feature-extraction": {"adapter": None, "role": "embedding", "service": "cf-vision"},
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "cf-vision"},
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "cf-vision"},
# Generative VLMs (image+text → text) — GGUF quants run via llama.cpp (cf-text).
# cf-vision is a classifier/embedder service; generative VLMs like Qwen2-VL
# and LLaVA accept image inputs but are textgen at the backend level.
# Full-precision HF-format VLMs would use vllm, but our fleet uses GGUF quants.
"image-text-to-text": {"adapter": None, "role": "vlm", "service": "cf-text"},
"visual-question-answering": {"adapter": None, "role": "vlm", "service": "cf-text"},
# Image generation — cf-image (text → image; distinct from cf-vision image understanding)
"text-to-image": {"adapter": None, "role": "image-gen", "service": "cf-image"},
# Embedding — cf-core shared embedding layer
@ -111,6 +144,11 @@ def set_models_dir(path: Path) -> None:
_MODELS_DIR = path
def set_cf_text_models_dir(path: Path) -> None:
global _CF_TEXT_MODELS_DIR
_CF_TEXT_MODELS_DIR = path
def set_queue_dir(path: Path) -> None:
global _QUEUE_DIR
_QUEUE_DIR = path
@ -197,8 +235,15 @@ def _catalog_key(repo_id: str) -> str:
ibm-granite/granite-4.1-8b granite-4.1-8b
facebook/bart-large-cnn bart-large-cnn
WithinUsAI/Opus4.7-GODs.Ghost.Codex-4B.GGuF opus4.7-gods.ghost.codex-4b
The coordinator skips catalog lookup for keys ending in ".gguf" (treats them
as direct file paths). Strip the suffix so GGUF repo names produce valid keys.
"""
return repo_id.split("/", 1)[-1].lower()
key = repo_id.split("/", 1)[-1].lower()
if key.endswith(".gguf"):
key = key[:-5]
return key
def _insert_catalog_entry(content: str, entry_lines: str) -> str:
@ -290,6 +335,15 @@ def _register_in_node_catalogs(
max_mb: int = cf_text.get("max_mb", 0)
catalog: dict = cf_text.get("catalog") or {}
# If the node has a different local model dir, remap the NFS path.
model_base = cf_text.get("model_base_path", "").rstrip("/")
if model_base:
nfs_base = str(_CF_TEXT_MODELS_DIR).rstrip("/")
model_name = local_path.name
effective_path_str = f"{model_base}/{model_name}"
else:
effective_path_str = local_path_str
# Skip if key already exists
if model_key in catalog:
logger.debug("Key %r already in %s — skipping", model_key, yaml_file.name)
@ -301,10 +355,10 @@ def _register_in_node_catalogs(
for entry in catalog.values()
if isinstance(entry, dict)
}
if local_path_str in registered_paths or any(
p.startswith(local_path_str + "/") for p in registered_paths
if effective_path_str in registered_paths or any(
p.startswith(effective_path_str + "/") for p in registered_paths
):
logger.debug("Path %s already registered in %s — skipping", local_path_str, yaml_file.name)
logger.debug("Path %s already registered in %s — skipping", effective_path_str, yaml_file.name)
continue
# Determine whether model fits at FP16 or needs 4-bit
@ -330,12 +384,18 @@ def _register_in_node_catalogs(
if needs_4bit
else f" # FP16 file-size estimate"
)
env_block = (
f" env:\n"
f" CF_TEXT_4BIT: \"1\"\n"
if needs_4bit else ""
)
entry_block = (
f" # auto-registered by avocet on download\n"
f" {model_key}:\n"
f" path: {local_path_str}\n"
f" path: {effective_path_str}\n"
f" vram_mb: {vram_for_node}{vram_comment}\n"
f" description: \"{desc}\"\n"
f"{env_block}"
)
new_content = _insert_catalog_entry(content, entry_block)
@ -388,12 +448,17 @@ def _run_download(
role: str | None = None,
service: str | None = None,
model_size_bytes: int = 0,
quant_pattern: str | None = None,
) -> None:
"""Background thread: download model via huggingface_hub.snapshot_download.
model_size_bytes is the sum of file sizes reported by the HF API (siblings).
It is used to estimate vram_mb and written to model_info.json so cf-orch can
budget VRAM when allocating a cf-text instance for this model.
quant_pattern: when set, restricts snapshot_download to only files matching
*{quant_pattern}*.gguf (plus metadata). Avoids downloading every quant variant
from GGUF-only repos like bartowski/*.
"""
global _download_progress
local_dir = _model_dir_for(repo_id, service)
@ -422,10 +487,20 @@ def _run_download(
local_dir.mkdir(parents=True, exist_ok=True)
poll_thread.start()
snapshot_download(
repo_id=repo_id,
local_dir=str(local_dir),
)
dl_kwargs: dict[str, Any] = {"repo_id": repo_id, "local_dir": str(local_dir)}
hf_token = _get_hf_token()
if hf_token:
dl_kwargs["token"] = hf_token
if quant_pattern:
# Include both cases: repos use mixed conventions (Q6_K vs q6_k).
dl_kwargs["allow_patterns"] = [
f"*{quant_pattern.upper()}*.gguf",
f"*{quant_pattern.lower()}*.gguf",
"*.json",
"README.md",
]
snapshot_download(**dl_kwargs)
# Estimate VRAM from reported file size.
# HF siblings sizes are pre-quantisation file sizes; add 10% for KV cache
@ -531,9 +606,31 @@ def lookup_model(repo_id: str) -> dict:
)
logger.warning("Unsupported pipeline_tag %r for %s", pipeline_tag, repo_id)
# Estimate model size from siblings list
# Detect GGUF files and parse quant names from siblings list.
# For GGUF-only repos (bartowski, TheBloke, etc.) this lets the UI show
# a per-quant size picker instead of downloading every variant.
siblings = data.get("siblings") or []
model_size_bytes: int = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
gguf_files: list[dict] = []
for s in siblings:
if not isinstance(s, dict):
continue
fname: str = s.get("rfilename", "")
if not fname.lower().endswith(".gguf"):
continue
m = _QUANT_RE.search(fname)
gguf_files.append({
"filename": fname,
"size": s.get("size", 0) or 0,
"quant_name": m.group(1).upper() if m else None,
})
gguf_files.sort(key=lambda f: f["size"])
# model_size_bytes: total of all siblings (for non-GGUF repos) or all GGUFs only.
# For GGUF repos the frontend will substitute the selected quant's size on submit.
if gguf_files:
model_size_bytes: int = sum(f["size"] for f in gguf_files)
else:
model_size_bytes = sum(s.get("size", 0) for s in siblings if isinstance(s, dict))
# Description: first 300 chars of card data (modelId field used as fallback)
card_data = data.get("cardData") or {}
@ -549,6 +646,7 @@ def lookup_model(repo_id: str) -> dict:
"compatible": compatible,
"warning": warning,
"model_size_bytes": model_size_bytes,
"gguf_files": gguf_files if gguf_files else None,
"description": description,
"tags": data.get("tags") or [],
"downloads": data.get("downloads") or 0,
@ -579,6 +677,9 @@ class QueueAddRequest(BaseModel):
# Stored in the queue entry so approve can pass it to _run_download
# without a second HF API round-trip.
model_size_bytes: int = 0
# GGUF quantization pattern (e.g. "Q5_K_M"). When set, snapshot_download
# restricts to *{quant_pattern}*.gguf instead of fetching all variants.
quant_pattern: str | None = None
@router.post("/queue", status_code=201)
@ -597,6 +698,7 @@ def add_to_queue(req: QueueAddRequest) -> dict:
"role": req.role,
"service": req.service,
"model_size_bytes": req.model_size_bytes,
"quant_pattern": req.quant_pattern,
"status": "pending",
"queued_at": datetime.now(timezone.utc).isoformat(),
}
@ -629,6 +731,7 @@ def approve_queue_entry(entry_id: str) -> dict:
entry.get("role"),
entry.get("service"),
entry.get("model_size_bytes", 0),
entry.get("quant_pattern"),
),
daemon=True,
name=f"model-download-{entry_id}",
@ -638,6 +741,32 @@ def approve_queue_entry(entry_id: str) -> dict:
return {"ok": True}
# ── PATCH /queue/{id} ─────────────────────────────────────────────────────────
class QueuePatchRequest(BaseModel):
service: str | None = None
role: str | None = None
@router.patch("/queue/{entry_id}")
def patch_queue_entry(entry_id: str, body: QueuePatchRequest) -> dict:
"""Update mutable fields (service, role) on a pending queue entry."""
entry = _get_queue_entry(entry_id)
if entry is None:
raise HTTPException(404, f"Queue entry {entry_id!r} not found")
if entry.get("status") != "pending":
raise HTTPException(409, f"Only pending entries can be patched (current: {entry.get('status')!r})")
updates: dict = {}
if body.service is not None:
updates["service"] = body.service
if body.role is not None:
updates["role"] = body.role
updated = _update_queue_entry(entry_id, updates)
return updated or {}
# ── DELETE /queue/{id} ─────────────────────────────────────────────────────────
@router.delete("/queue/{entry_id}")

535
app/nodes.py Normal file
View file

@ -0,0 +1,535 @@
"""Avocet — Node Management API.
Proxies cf-orch coordinator and agent APIs to expose per-node GPU state,
service affinity management, and Ollama model management.
Config is read from label_tool.yaml under the `cforch:` key.
The `profiles_dir` key (new) points to the cf-orch node profile YAML directory.
Module-level globals follow the set_config_dir() testability pattern from cforch.py.
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from urllib.parse import urlparse
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None # override in tests
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_config() -> dict:
"""Read label_tool.yaml cforch section. Returns empty dict on missing or parse error."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
return raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse config %s: %s", f, exc)
return {}
def _profiles_dir() -> Path | None:
"""Return the cf-orch node profiles directory, or None if not configured."""
cfg = _load_config()
pd = cfg.get("profiles_dir", "") or ""
if pd:
return Path(pd)
bench = cfg.get("bench_script", "") or ""
if bench:
return Path(bench).parent.parent / "profiles" / "nodes"
return None
def _profile_path(node_id: str) -> Path | None:
"""Return the path to a node's profile YAML, or None if profiles_dir is unknown."""
pd = _profiles_dir()
if pd is None:
return None
return pd / f"{node_id}.yaml"
def _load_profile(node_id: str) -> dict | None:
"""Load and parse a node profile YAML. Returns None if not found or malformed."""
p = _profile_path(node_id)
if p is None or not p.exists():
return None
try:
return yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Malformed profile YAML %s: %s", p, exc)
return None
def _get_ollama_url(node_id: str) -> str:
"""Derive Ollama URL from the node profile's agent_url (same host, port 11434)."""
profile = _load_profile(node_id)
if profile:
nodes_section = profile.get("nodes", {}) or {}
node_entry = nodes_section.get(node_id, {}) or {}
agent_url = node_entry.get("agent_url", "") or ""
if agent_url:
parsed = urlparse(agent_url)
return f"{parsed.scheme}://{parsed.hostname}:11434"
raise HTTPException(
status_code=404,
detail=f"Cannot determine Ollama URL for node {node_id}: no agent_url in profile",
)
# ── Endpoints ──────────────────────────────────────────────────────────────────
@router.get("/nodes")
def list_nodes() -> list:
"""Return all nodes with live GPU stats merged with profile YAML."""
import httpx
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
if not coordinator_url:
return []
try:
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
r.raise_for_status()
coord_nodes: list[dict] = r.json().get("nodes", [])
except httpx.HTTPError as exc:
logger.warning("Coordinator unreachable: %s", exc)
return []
try:
sr = httpx.get(f"{coordinator_url}/api/services", timeout=5.0)
sr.raise_for_status()
services_data: list[dict] = sr.json().get("services", [])
except httpx.HTTPError:
logger.warning("Services API unreachable for %s, skipping", coordinator_url)
services_data = []
# Build per-node, per-GPU running services map
running: dict[str, dict[int, list[str]]] = {}
for svc in services_data:
nid = svc.get("node_id", "")
gid = svc.get("gpu_id")
svc_name = svc.get("service", "")
if nid and gid is not None and svc_name:
running.setdefault(nid, {}).setdefault(gid, []).append(svc_name)
result = []
for node in coord_nodes:
node_id = node.get("node_id", "") or node.get("id", "")
profile = _load_profile(node_id) if node_id else None
profile_loaded = profile is not None
gpus = []
for gpu in (node.get("gpus", []) or []):
gpu_id = gpu.get("gpu_id", gpu.get("id", 0))
services_assigned: list[str] = []
if profile:
node_entry = (profile.get("nodes", {}) or {}).get(node_id, {}) or {}
for g in (node_entry.get("gpus", []) or []):
if isinstance(g, dict) and g.get("id") == gpu_id:
services_assigned = g.get("services", []) or []
break
gpus.append({
"gpu_id": gpu_id,
"card": gpu.get("card", ""),
"vram_total_mb": gpu.get("vram_total_mb", 0),
"vram_used_mb": gpu.get("vram_used_mb", 0),
"vram_free_mb": gpu.get("vram_free_mb", 0),
"temp_c": gpu.get("temp_c"),
"utilization_pct": gpu.get("utilization_pct"),
"compute_cap": gpu.get("compute_cap"),
"services_assigned": services_assigned,
"services_running": running.get(node_id, {}).get(gpu_id, []),
})
services_catalog: dict = {}
if profile:
for svc_name, svc_info in (profile.get("services", {}) or {}).items():
catalog = svc_info.get("catalog", {}) or {}
services_catalog[svc_name] = {
"min_compute_cap": svc_info.get("min_compute_cap", 0.0),
"max_mb": svc_info.get("max_mb", 0),
"catalog_size": len(catalog),
}
result.append({
"node_id": node_id,
"online": node.get("online", True),
"agent_url": node.get("agent_url", ""),
"gpus": gpus,
"profile_loaded": profile_loaded,
"services_catalog": services_catalog,
})
return result
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
"""Return the full parsed profile YAML for a node."""
p = _profile_path(node_id)
if p is None or not p.exists():
raise HTTPException(404, f"No profile found for node {node_id}")
try:
data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
raise HTTPException(500, f"Malformed profile YAML: {exc}")
return data
class UpdateServicesRequest(BaseModel):
services: list[str]
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
"""Set service assignment for a GPU with compatibility validation, then atomic write."""
import httpx
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
p = _profile_path(node_id)
if p is None or not p.exists():
raise HTTPException(404, f"No profile found for node {node_id}")
try:
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
raise HTTPException(500, f"Malformed profile YAML: {exc}")
nodes_section = profile.get("nodes", {}) or {}
node_entry = nodes_section.get(node_id, {}) or {}
gpu_list = node_entry.get("gpus", []) or []
gpu_entry = next(
(g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
None,
)
if gpu_entry is None:
raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")
gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
services_def = profile.get("services", {}) or {}
for svc_name in body.services:
if svc_name not in services_def:
raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
svc = services_def[svc_name]
min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
if gpu_compute_cap < min_cap:
raise HTTPException(
422,
f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
)
catalog = svc.get("catalog", {}) or {}
min_catalog_vram = (
min((m.get("vram_mb", 0) for m in catalog.values()), default=0)
if catalog else svc.get("max_mb", 0)
)
if gpu_vram_mb < min_catalog_vram:
raise HTTPException(
422,
f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
)
# Immutable update of GPU services list
new_gpu_list = [
({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
for g in gpu_list
]
new_profile = {
**profile,
"nodes": {
**nodes_section,
node_id: {**node_entry, "gpus": new_gpu_list},
},
}
# Atomic write: write to .tmp then rename
tmp_yaml = Path(str(p) + ".tmp")
tmp_yaml.write_text(yaml.dump(new_profile, default_flow_style=False), encoding="utf-8")
os.replace(tmp_yaml, p)
# Trigger coordinator profile reload
reloaded = False
if coordinator_url:
try:
rr = httpx.post(
f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
)
reloaded = rr.status_code < 300
except Exception as exc:
logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)
return {"ok": True, "reloaded": reloaded, "warnings": []}
# ── Profile save / generate ────────────────────────────────────────────────────
class SaveProfileRequest(BaseModel):
profile: dict
@router.put("/nodes/{node_id}/profile", status_code=200)
def save_profile(node_id: str, body: SaveProfileRequest) -> dict:
"""Write a full profile dict to disk as YAML, then trigger coordinator reload."""
p = _profile_path(node_id)
if p is None:
raise HTTPException(500, "profiles_dir not configured in label_tool.yaml")
p.parent.mkdir(parents=True, exist_ok=True)
tmp = Path(str(p) + ".tmp")
tmp.write_text(
yaml.dump(body.profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
encoding="utf-8",
)
os.replace(tmp, p)
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
reloaded = False
if coordinator_url:
try:
import httpx
rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
reloaded = rr.status_code < 300
except Exception as exc:
logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
return {"ok": True, "reloaded": reloaded}
@router.post("/nodes/{node_id}/profile/generate")
def generate_profile(node_id: str) -> dict:
"""Return a profile skeleton seeded from coordinator GPU data.
If a profile already exists, preserves its services section and only
refreshes the nodes hardware section. Never writes to disk the caller
must call PUT /profile to persist.
"""
import httpx
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
if not coordinator_url:
raise HTTPException(503, "coordinator_url not configured")
try:
r = httpx.get(f"{coordinator_url}/api/nodes", timeout=5.0)
r.raise_for_status()
coord_nodes: list[dict] = r.json().get("nodes", [])
except httpx.HTTPError as exc:
raise HTTPException(502, f"Coordinator unreachable: {exc}")
node = next((n for n in coord_nodes if n.get("node_id") == node_id), None)
if node is None:
raise HTTPException(404, f"Node {node_id!r} not found in coordinator")
gpus = [
{
"id": g.get("gpu_id", i),
"vram_mb": g.get("vram_total_mb", 0),
"compute_cap": g.get("compute_cap", 0.0),
"card": g.get("card", g.get("name", "")),
"role": "inference",
"services": [],
}
for i, g in enumerate(node.get("gpus", []))
]
vram_total = max((g["vram_mb"] for g in gpus), default=0)
existing = _load_profile(node_id) or {}
return {
"schema_version": existing.get("schema_version", 1),
"name": existing.get("name", f"node-{node_id}"),
"vram_total_mb": vram_total,
"eviction_timeout_s": existing.get("eviction_timeout_s", 10.0),
"services": existing.get("services", {}),
"nodes": {
node_id: {
"local_model_root": (
(existing.get("nodes", {}) or {})
.get(node_id, {})
.get("local_model_root", "")
),
"gpus": gpus,
}
},
"model_size_hints": existing.get("model_size_hints", {}),
}
# ── Ollama model management ────────────────────────────────────────────────────
class PullRequest(BaseModel):
name: str
@router.get("/nodes/{node_id}/models/ollama")
def list_ollama_models(node_id: str) -> dict:
"""Proxy GET {ollama_url}/api/tags for a specific node."""
import httpx
ollama_url = _get_ollama_url(node_id)
try:
r = httpx.get(f"{ollama_url}/api/tags", timeout=10.0)
r.raise_for_status()
return r.json()
except Exception as exc:
return {"error": str(exc)}
@router.post("/nodes/{node_id}/models/ollama/pull")
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
"""Stream Ollama pull progress as SSE events."""
import httpx
if not body.name:
raise HTTPException(400, "name is required")
ollama_url = _get_ollama_url(node_id)
def stream():
try:
with httpx.stream(
"POST",
f"{ollama_url}/api/pull",
json={"name": body.name, "stream": True},
timeout=300.0,
) as resp:
for line in resp.iter_lines():
if line:
yield f"data: {line}\n\n"
except Exception as exc:
yield f"data: {json.dumps({'error': str(exc)})}\n\n"
return StreamingResponse(stream(), media_type="text/event-stream")
@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
def delete_ollama_model(node_id: str, name: str) -> dict:
"""Proxy DELETE to Ollama for a specific node."""
import httpx
ollama_url = _get_ollama_url(node_id)
try:
r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
if r.status_code == 404:
raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
r.raise_for_status()
return {"ok": True}
except HTTPException:
raise
except Exception as exc:
raise HTTPException(502, f"Ollama unreachable: {exc}")
# ── Model deploy (add catalog entry) ──────────────────────────────────────────
class DeployModelRequest(BaseModel):
model_id: str
service_type: str
vram_mb: int
description: str = ""
hf_repo: str = ""
path: str = "" # explicit path; if empty, constructed from model_base_path + hf_repo slug
@router.post("/nodes/{node_id}/models/deploy", status_code=200)
def deploy_model(node_id: str, body: DeployModelRequest) -> dict:
"""Register a model in the node's service catalog.
Adds (or updates) the catalog entry for body.model_id under the given
service_type in the node's profile YAML, then triggers a coordinator reload.
Does not download the model that is the user's responsibility.
Returns the resolved path so the caller can see where the model should land.
"""
p = _profile_path(node_id)
if p is None or not p.exists():
raise HTTPException(404, f"No profile found for node {node_id!r}")
try:
profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
raise HTTPException(500, f"Malformed profile YAML: {exc}")
services_def = profile.get("services", {}) or {}
svc = services_def.get(body.service_type)
if svc is None:
raise HTTPException(
422,
f"Service '{body.service_type}' not defined in node '{node_id}' profile; "
"add it first via the profile editor",
)
# Resolve path: explicit > model_base_path + hf slug > model_id slug
model_path = body.path.strip()
if not model_path:
base = (svc.get("model_base_path", "") or "").rstrip("/")
if not base:
raise HTTPException(
422,
f"Service '{body.service_type}' has no model_base_path; supply an explicit path",
)
slug_src = body.hf_repo.strip() if body.hf_repo.strip() else body.model_id
hf_slug = slug_src.replace("/", "--")
model_path = f"{base}/{hf_slug}"
# Immutable catalog update — spread, never mutate
entry: dict = {"path": model_path, "vram_mb": body.vram_mb}
if body.description:
entry["description"] = body.description
new_catalog = {**(svc.get("catalog") or {}), body.model_id: entry}
new_svc = {**svc, "catalog": new_catalog}
new_services = {**services_def, body.service_type: new_svc}
new_profile = {**profile, "services": new_services}
# Atomic write
tmp = Path(str(p) + ".tmp")
tmp.write_text(
yaml.dump(new_profile, default_flow_style=False, allow_unicode=True, sort_keys=False),
encoding="utf-8",
)
os.replace(tmp, p)
# Trigger coordinator reload
cfg = _load_config()
coordinator_url = cfg.get("coordinator_url", "") or ""
reloaded = False
if coordinator_url:
try:
import httpx
rr = httpx.post(f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0)
reloaded = rr.status_code < 300
except Exception as exc:
logger.warning("Coordinator reload failed for %s: %s", node_id, exc)
return {"ok": True, "reloaded": reloaded, "path": model_path}

327
app/plans_bench.py Normal file
View file

@ -0,0 +1,327 @@
"""Avocet — CF planning benchmark integration API.
Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
Connection config (api_base) is read from label_tool.yaml under the
`plans_bench:` key (optional; falls back to localhost:8080).
All endpoints are registered on `router` (FastAPI APIRouter).
api.py includes this router with prefix="/api/plans-bench".
"""
from __future__ import annotations
import json
import logging
import subprocess as _subprocess
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import httpx
import yaml
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
_BENCH_RUNNING: bool = False
_bench_proc: Any = None
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"
router = APIRouter()
# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
# Kept here so the UI can list them without importing the script.
MODEL_REGISTRY: dict[str, str] = {
"deepseek-r1-1.5b": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
"deepseek-r1-7b-4bit": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
"deepseek-r1-0528-qwen3-8b-gguf": "DeepSeek R1 0528 Qwen3 8B GGUF (4 nodes)",
"deepseek-coder-6.7b-4bit": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
"granite-4.1-8b": "IBM Granite 4.1 8B, 4-bit (cf-orch catalog key)",
"qwen2.5-3b": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
"qwen2.5-7b": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
"capybarahermes-2.5-mistral-7b-gguf": "CapybaraHermes 2.5 Mistral 7B GGUF (4 nodes)",
"darwin-9b-opus-gguf": "Darwin 9B Opus GGUF -- long-form writing (3 nodes)",
}
RUBRIC_LABELS: dict[str, str] = {
"task_structure": "Task structure (checkboxes + commits)",
"tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
"privacy_pillar": "Privacy pillar (local-first, no logging)",
"safety_pillar": "Safety pillar (human approval, reversibility)",
"accessibility": "Accessibility (ND/adaptive users)",
"license_split": "License awareness (MIT vs BSL)",
"file_paths": "File paths (plausible project paths)",
"cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
"length_ok": "Response length (2002500 words)",
}
# ── Testability seam ───────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_config() -> dict:
f = _config_file()
cforch_cfg: dict = {}
bench_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
cforch_cfg = raw.get("cforch", {}) or {}
bench_cfg = raw.get("plans_bench", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
return {
"coordinator_url": cforch_cfg.get("coordinator_url",
bench_cfg.get("coordinator_url", "http://10.1.10.71:7700")),
"python_bin": cforch_cfg.get("python_bin",
bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python")),
}
def _results_file(run_id: str) -> Path:
return _RESULTS_DIR / f"{run_id}.json"
# ── GET /models ────────────────────────────────────────────────────────────────
@router.get("/models")
def get_models() -> dict:
"""Return registered model shortcuts, live cf-orch catalog, and rubric labels."""
cfg = _load_config()
cforch_models: list[dict] = []
try:
resp = httpx.get(
f"{cfg['coordinator_url']}/api/services/cf-text/catalog",
timeout=5.0,
)
resp.raise_for_status()
for model_id, entry in resp.json().items():
if isinstance(entry, dict):
cforch_models.append({
"id": model_id,
"name": model_id,
"vram_mb": entry.get("vram_mb"),
"description": entry.get("description", ""),
})
except Exception as exc:
logger.warning("Failed to fetch cf-orch catalog: %s", exc)
return {
"registry": [
{"key": k, "description": v}
for k, v in MODEL_REGISTRY.items()
],
"cforch_models": cforch_models,
"coordinator_url": cfg["coordinator_url"],
"rubric_labels": RUBRIC_LABELS,
}
# ── GET /run ───────────────────────────────────────────────────────────────────
@router.get("/run")
def run_plans_benchmark(
models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
) -> StreamingResponse:
"""Spawn benchmark_plans.py and stream stdout as SSE progress events.
On successful completion emits a `type: result` event with parsed JSON
and saves results to data/plans_bench_results/<run_id>.json.
"""
global _BENCH_RUNNING, _bench_proc
if _BENCH_RUNNING:
raise HTTPException(409, "A planning benchmark is already running")
cfg = _load_config()
python_bin = cfg["python_bin"]
coordinator_url = cfg["coordinator_url"]
model_keys = [m.strip() for m in models.split(",") if m.strip()]
if not model_keys:
raise HTTPException(400, "At least one model key is required")
run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
output_path = _results_file(run_id)
_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
def generate():
global _BENCH_RUNNING, _bench_proc
if not _BENCH_SCRIPT.exists():
yield f"data: {json.dumps({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})}\n\n"
return
cmd = [python_bin, str(_BENCH_SCRIPT)]
if len(model_keys) > 1:
cmd.extend(["--compare"] + model_keys)
else:
cmd.extend(["--model", model_keys[0]])
if use_cforch:
cmd.extend(["--cforch", "--cforch-url", coordinator_url])
elif api_base.strip():
cmd.extend(["--api-base", api_base.strip()])
cmd.extend(["--verbose", "--output", str(output_path)])
if workers > 1:
cmd.extend(["--workers", str(workers)])
if prompt_ids.strip():
cmd.extend(["--prompts"] + [p.strip() for p in prompt_ids.split(",") if p.strip()])
_BENCH_RUNNING = True
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
)
_bench_proc = proc
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0 and output_path.exists():
try:
results = json.loads(output_path.read_text(encoding="utf-8"))
yield f"data: {json.dumps({'type': 'result', 'run_id': run_id, 'results': results})}\n\n"
except Exception as exc:
logger.warning("Failed to read plans benchmark output: %s", exc)
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_bench_proc = None
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
finally:
_BENCH_RUNNING = False
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
# ── GET /results ───────────────────────────────────────────────────────────────
@router.get("/results")
def list_results() -> list[dict]:
"""List past planning benchmark runs, newest first."""
if not _RESULTS_DIR.exists():
return []
runs: list[dict] = []
for f in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
run_id = f.stem
try:
data: dict = json.loads(f.read_text(encoding="utf-8"))
model_keys = list(data.keys())
# Average total_score across all models and prompts
all_scores = [
r["total_score"]
for results in data.values()
for r in results
if not r.get("error")
]
avg_score = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
except Exception:
model_keys = []
avg_score = 0.0
# Parse display date from run_id (plans_2026-04-27_143022)
try:
date_part = run_id.removeprefix("plans_") # 2026-04-27_143022
date, time = date_part.split("_")
display_date = f"{date} {time[:2]}:{time[2:4]}"
except Exception:
display_date = run_id
runs.append({
"run_id": run_id,
"filename": f.name,
"date": display_date,
"models": model_keys,
"avg_score": avg_score,
})
return runs
@router.get("/results/latest")
def get_latest_results() -> dict:
"""Return the most recent planning benchmark results dict."""
if not _RESULTS_DIR.exists():
raise HTTPException(404, "No benchmark results found")
files = sorted(_RESULTS_DIR.glob("plans_*.json"))
if not files:
raise HTTPException(404, "No benchmark results found")
try:
return json.loads(files[-1].read_text(encoding="utf-8"))
except Exception as exc:
raise HTTPException(500, f"Failed to read results: {exc}") from exc
@router.get("/results/{run_id}")
def get_results_by_run_id(run_id: str) -> dict:
"""Return planning benchmark results for a specific run."""
if not run_id.startswith("plans_"):
raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")
f = _results_file(run_id)
if not f.exists():
raise HTTPException(404, f"Results not found: {run_id}")
try:
return json.loads(f.read_text(encoding="utf-8"))
except Exception as exc:
raise HTTPException(500, f"Failed to read results: {exc}") from exc
# ── POST /cancel ───────────────────────────────────────────────────────────────
@router.post("/cancel")
def cancel_plans_benchmark() -> dict:
"""Kill the running planning benchmark subprocess."""
global _BENCH_RUNNING, _bench_proc
if not _BENCH_RUNNING:
raise HTTPException(404, "No planning benchmark is currently running")
if _bench_proc is not None:
try:
_bench_proc.terminate()
except Exception as exc:
logger.warning("Failed to terminate plans benchmark: %s", exc)
_BENCH_RUNNING = False
_bench_proc = None
return {"status": "cancelled"}

View file

@ -1,335 +1,8 @@
"""Avocet — SFT candidate import and correction API.
All endpoints are registered on `router` (a FastAPI APIRouter).
api.py includes this router with prefix="/api/sft".
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
testability pattern as api.py override them via set_sft_data_dir() and
set_sft_config_dir() in test fixtures.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_SFT_DATA_DIR: Path = _ROOT / "data"
_SFT_CONFIG_DIR: Path | None = None
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────
def set_sft_data_dir(path: Path) -> None:
global _SFT_DATA_DIR
_SFT_DATA_DIR = path
def set_sft_config_dir(path: Path | None) -> None:
global _SFT_CONFIG_DIR
_SFT_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────
def _config_file() -> Path:
if _SFT_CONFIG_DIR is not None:
return _SFT_CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
def set_default_bench_results_dir(path: str) -> None:
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
global _DEFAULT_BENCH_RESULTS_DIR
_DEFAULT_BENCH_RESULTS_DIR = path
def _get_bench_results_dir() -> Path:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
d = raw.get("sft", {}).get("bench_results_dir", "")
if d:
return Path(d)
except yaml.YAMLError as exc:
logger.warning("Failed to parse SFT config %s: %s", f, exc)
return Path(_DEFAULT_BENCH_RESULTS_DIR)
def _candidates_file() -> Path:
return _SFT_DATA_DIR / "sft_candidates.jsonl"
def _approved_file() -> Path:
return _SFT_DATA_DIR / "sft_approved.jsonl"
def _read_candidates() -> list[dict]:
return read_jsonl(_candidates_file())
def _write_candidates(records: list[dict]) -> None:
write_jsonl(_candidates_file(), records)
def _is_exportable(r: dict) -> bool:
"""Return True if an approved record is ready to include in SFT export."""
return (
r.get("status") == "approved"
and bool(r.get("corrected_response"))
and str(r["corrected_response"]).strip() != ""
)
# ── GET /runs ──────────────────────────────────────────────────────────────
@router.get("/runs")
def get_runs():
"""List available benchmark runs in the configured bench_results_dir."""
from scripts.sft_import import discover_runs
bench_dir = _get_bench_results_dir()
existing = _read_candidates()
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
imported_run_ids = {
r["benchmark_run_id"]
for r in existing
if r.get("benchmark_run_id") is not None
}
runs = discover_runs(bench_dir)
return [
{
"run_id": r["run_id"],
"timestamp": r["timestamp"],
"candidate_count": r["candidate_count"],
"already_imported": r["run_id"] in imported_run_ids,
}
for r in runs
]
# ── POST /import ───────────────────────────────────────────────────────────
class ImportRequest(BaseModel):
run_id: str
@router.post("/import")
def post_import(req: ImportRequest):
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
from scripts.sft_import import discover_runs, import_run
bench_dir = _get_bench_results_dir()
runs = discover_runs(bench_dir)
run = next((r for r in runs if r["run_id"] == req.run_id), None)
if run is None:
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
return import_run(run["sft_path"], _SFT_DATA_DIR)
# ── GET /queue ─────────────────────────────────────────────────────────────
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
"""Return paginated needs_review candidates."""
records = _read_candidates()
pending = [r for r in records if r.get("status") == "needs_review"]
start = (page - 1) * per_page
return {
"items": pending[start:start + per_page],
"total": len(pending),
"page": page,
"per_page": per_page,
}
# ── POST /submit ───────────────────────────────────────────────────────────
FailureCategory = Literal[
"scoring_artifact",
"style_violation",
"partial_answer",
"wrong_answer",
"format_error",
"hallucination",
]
class SubmitRequest(BaseModel):
id: str
action: Literal["correct", "discard", "flag"]
corrected_response: str | None = None
failure_category: FailureCategory | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
"""Record a reviewer decision for one SFT candidate."""
if req.action == "correct":
if not req.corrected_response or not req.corrected_response.strip():
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
if record.get("status") != "needs_review":
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
if req.action == "correct":
records[idx] = {
**record,
"status": "approved",
"corrected_response": req.corrected_response,
"failure_category": req.failure_category,
}
_write_candidates(records)
append_jsonl(_approved_file(), records[idx])
elif req.action == "discard":
records[idx] = {**record, "status": "discarded"}
_write_candidates(records)
else: # flag
records[idx] = {**record, "status": "model_rejected"}
_write_candidates(records)
return {"ok": True}
# ── POST /undo ─────────────────────────────────────────────────────────────
class UndoRequest(BaseModel):
id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
"""Restore a previously actioned candidate back to needs_review."""
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
old_status = record.get("status")
if old_status == "needs_review":
raise HTTPException(409, "Record is already in needs_review state")
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
_write_candidates(records)
# If it was approved, remove from the approved file too
if old_status == "approved":
approved = read_jsonl(_approved_file())
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
return {"ok": True}
# ── GET /export ─────────────────────────────────────────────────────────────
@router.get("/export")
def get_export() -> StreamingResponse:
"""Stream approved records as SFT-ready JSONL for download."""
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
def generate():
for r in exportable:
record = {
"messages": r.get("prompt_messages", []) + [
{"role": "assistant", "content": r["corrected_response"]}
]
}
yield json.dumps(record) + "\n"
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
},
)
# ── GET /stats ──────────────────────────────────────────────────────────────
@router.get("/stats")
def get_stats() -> dict[str, object]:
"""Return counts by status, model, and task type."""
records = _read_candidates()
by_status: dict[str, int] = {}
by_model: dict[str, int] = {}
by_task_type: dict[str, int] = {}
for r in records:
status = r.get("status", "unknown")
by_status[status] = by_status.get(status, 0) + 1
model = r.get("model_name", "unknown")
by_model[model] = by_model.get(model, 0) + 1
task_type = r.get("task_type", "unknown")
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
approved = read_jsonl(_approved_file())
export_ready = sum(1 for r in approved if _is_exportable(r))
return {
"total": len(records),
"by_status": by_status,
"by_model": by_model,
"by_task_type": by_task_type,
"export_ready": export_ready,
}
# ── GET /config ─────────────────────────────────────────────────────────────
@router.get("/config")
def get_sft_config() -> dict:
"""Return the current SFT configuration (bench_results_dir)."""
f = _config_file()
if not f.exists():
return {"bench_results_dir": ""}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
return {"bench_results_dir": ""}
sft_section = raw.get("sft") or {}
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
class SftConfigPayload(BaseModel):
bench_results_dir: str
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
"""Write the bench_results_dir setting to the config file."""
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
raw = raw or {}
except yaml.YAMLError:
raw = {}
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
tmp.rename(f)
return {"ok": True}
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
from app.data.corrections import ( # noqa: F401
router,
set_data_dir as set_sft_data_dir,
set_config_dir as set_sft_config_dir,
set_default_bench_results_dir,
_DEFAULT_BENCH_RESULTS_DIR,
)

0
app/train/__init__.py Normal file
View file

339
app/train/train.py Normal file
View file

@ -0,0 +1,339 @@
"""Avocet -- train job queue API.
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
_running_procs dict in api.py with a persistent, inspectable queue.
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
GET /jobs -- list all jobs, newest first
POST /jobs -- create a new job
GET /jobs/{id} -- get one job by id
DELETE /jobs/{id}/cancel -- cancel a queued or running job
GET /jobs/{id}/run -- SSE: run the job, stream stdout
GET /results -- list completed models with training_info.json metrics
SQLite schema:
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
Testability seam: _DB_PATH global, override via set_db_path().
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import subprocess as _subprocess
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
_MODELS_DIR: Path = _ROOT / "models"
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
_running_procs: dict[str, Any] = {}
router = APIRouter()
# -- Testability seams -------------------------------------------------
def set_db_path(path: Path) -> None:
global _DB_PATH
_DB_PATH = path
def set_models_dir(path: Path) -> None:
global _MODELS_DIR
_MODELS_DIR = path
def set_config_dir(path: "Path | None") -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Config helpers ----------------------------------------------------
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_train_config() -> dict:
"""Read python_bin from label_tool.yaml.
Priority (highest to lowest):
1. label_tool.yaml train: python_bin
2. label_tool.yaml cforch: python_bin
3. Hardcoded default (classifiers conda env)
"""
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
f = _config_file()
train_cfg: dict = {}
cforch_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
train_cfg = raw.get("train", {}) or {}
cforch_cfg = raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse train config %s: %s", f, exc)
return {
"python_bin": train_cfg.get(
"python_bin",
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
),
}
# -- Database helpers --------------------------------------------------
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(str(_DB_PATH))
conn.row_factory = sqlite3.Row
try:
yield conn
conn.commit()
finally:
conn.close()
def _init_db() -> None:
"""Create jobs table if it does not exist. Called lazily per request."""
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
with _db() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
""")
def _row_to_dict(row: sqlite3.Row) -> dict:
return {k: row[k] for k in row.keys()}
# -- GPU selection (copied from api.py) --------------------------------
def _best_cuda_device() -> str:
"""Return index of GPU with most free VRAM, or empty string."""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True, timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
# -- Pydantic models ---------------------------------------------------
class CreateJobRequest(BaseModel):
type: str # "classifier" | "llm-sft"
model_key: str # e.g. "deberta-small"
config_json: dict = {}
# -- Routes ------------------------------------------------------------
@router.get("/jobs")
def list_jobs() -> dict:
_init_db()
with _db() as conn:
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
return {"jobs": [_row_to_dict(r) for r in rows]}
@router.post("/jobs")
def create_job(req: CreateJobRequest) -> dict:
if req.type not in ("classifier", "llm-sft"):
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
_init_db()
job_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES (?, ?, ?, 'queued', ?, ?)",
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
)
return {"id": job_id, "type": req.type, "model_key": req.model_key,
"status": "queued", "config_json": req.config_json,
"created_at": now, "started_at": None, "completed_at": None, "error": None}
@router.get("/jobs/{job_id}")
def get_job(job_id: str) -> dict:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
return _row_to_dict(row)
@router.delete("/jobs/{job_id}/cancel")
def cancel_job(job_id: str) -> dict:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
if row["status"] not in ("queued", "running"):
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
now = datetime.now(timezone.utc).isoformat()
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
proc = _running_procs.pop(job_id, None)
if proc is not None:
try:
proc.terminate()
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
try:
proc.kill()
except OSError:
pass
return {"status": "cancelled"}
@router.get("/jobs/{job_id}/run")
def run_job(job_id: str) -> StreamingResponse:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
if row["status"] != "queued":
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
job = _row_to_dict(row)
def generate():
cfg = _load_train_config()
python_bin = cfg["python_bin"]
config = json.loads(job["config_json"] or "{}")
model_key = job["model_key"]
epochs = config.get("epochs", 5)
if job["type"] == "classifier":
script = str(_ROOT / "scripts" / "finetune_classifier.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
data_dir = _ROOT / "data"
for sf in config.get("score_files", []):
resolved = (data_dir / sf).resolve()
if resolved.is_relative_to(data_dir.resolve()):
cmd.extend(["--score", str(resolved)])
elif job["type"] == "llm-sft":
script = str(_ROOT / "scripts" / "finetune_sft.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
else:
job_type = job["type"]
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
return
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
best_gpu = _best_cuda_device()
if best_gpu:
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
now = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
try:
proc = _subprocess.Popen(
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
)
_running_procs[job_id] = proc
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
finished_at = datetime.now(timezone.utc).isoformat()
if proc.returncode == 0:
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
(finished_at, job_id))
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
else:
err = f"Process exited with code {proc.returncode}"
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
finally:
_running_procs.pop(job_id, None)
except Exception as exc:
err = str(exc)
finished_at = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@router.get("/results")
def list_results() -> dict:
if not _MODELS_DIR.exists():
return {"results": []}
results = []
for sub in _MODELS_DIR.iterdir():
if not sub.is_dir():
continue
info_path = sub / "training_info.json"
if not info_path.exists():
continue
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception as exc:
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
return {"results": results}

View file

@ -106,7 +106,7 @@ def read_jsonl(path: Path) -> list[dict]:
def write_jsonl(path: Path, records: list[dict]) -> None:
"""Write records to a JSONL file, overwriting any existing content."""
path.parent.mkdir(parents=True, exist_ok=True)
content = "\n".join(json.dumps(r) for r in records)
content = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
path.write_text(content + ("\n" if records else ""), encoding="utf-8")
@ -114,4 +114,4 @@ def append_jsonl(path: Path, record: dict) -> None:
"""Append a single record to a JSONL file."""
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "a", encoding="utf-8") as fh:
fh.write(json.dumps(record) + "\n")
fh.write(json.dumps(record, ensure_ascii=False) + "\n")

View file

@ -41,11 +41,20 @@ cforch:
# Python interpreter with cf-orch installed
python_bin: /devl/miniconda3/envs/cf/bin/python
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST
# Connection config — override env vars CF_ORCH_URL / CF_LICENSE_KEY / OLLAMA_HOST / CF_JUDGE_URL / HF_TOKEN
# coordinator_url: http://localhost:7700
# license_key: CFG-AVCT-xxxx-xxxx-xxxx
# ollama_url: http://localhost:11434
# ollama_model: llama3.2:3b
# embed_model: nomic-embed-text # Ollama embedding model for EmbeddingKNNAdapter
# judge_url: http://10.1.10.158:8008 # Sif cf-text — LLM-as-judge secondary scorer
# judge_url: http://10.1.10.71:8008 # Heimdall cf-text (alternative)
# Or set CF_JUDGE_URL. Populates the Judge URL field in the LLM Eval UI automatically.
# hf_token: hf_xxxxxxxxxxxxxxxxxxxx # HuggingFace token — required for gated/terms-restricted models
# Directory containing per-node profile YAMLs (cf-orch node profiles).
# Default: derived from bench_script location (../../profiles/nodes).
# profiles_dir: /Library/Development/CircuitForge/circuitforge-orch/circuitforge_orch/profiles/nodes
# Imitate tab — pull real samples from sibling CF product APIs and run them
# through local LLMs to build a corrections dataset.
@ -102,11 +111,34 @@ imitate:
text_fields: [title, description, seller_info]
prompt_template: "Evaluate the trustworthiness of this listing and flag any red flags:\n\n{text}"
- id: osprey
name: Osprey
icon: "📞"
description: Gov't hold-line automation
base_url: http://localhost:8520
sample_endpoint: /api/calls/recent
text_fields: [agency, issue, notes]
prompt_template: "Draft a concise summary of this government call record:\n\n{text}"
- id: pagepiper
name: Pagepiper
icon: "📄"
description: "PDF/rulebook RAG tool: page-level text chunks"
base_url: http://localhost:8511
health_path: /api/health
sample_endpoint: /api/library
chunk_endpoint: /api/library/sample-chunks?limit=50 # requires pagepiper#6
text_fields: [title]
prompt_template: "Summarize the key rules described in this passage:\n\n{text}"
# ── Log corpus (Turnstone training data) ──────────────────────────────────────
corpus:
# Directory containing pipeline JSONL log files to ingest (pull-side).
# Files named <script>_<ts>.jsonl; one structured record per line.
# POST /api/corpus/pipeline-ingest walks this dir and imports new files.
# NFS-mounted on both Heimdall and Sif at /Library/Assets/
pipeline_ingest_dir: /Library/Assets/logs/pipeline/
# Turnstone push sources (consent-gated, token-authenticated).
# sources:
# - token: "your-bearer-token"
# source_host: "node.local"
# owner: YourName
# consent_date: "2026-05-17"
# consent_method: signal_chat
# ── Embedding model comparison harness ────────────────────────────────────────
embed_bench:
# ollama_url: http://localhost:11434 # optional; falls back to cforch.ollama_url
# top_k: 5 # default hits per model per query

View file

@ -90,6 +90,12 @@ usage() {
echo -e " ${GREEN}score [args]${NC} Shortcut: --score [args]"
echo -e " ${GREEN}compare [args]${NC} Shortcut: --compare [args]"
echo ""
echo " Planning Benchmark:"
echo -e " ${GREEN}plans-bench [args]${NC} Run benchmark_plans.py (args passed through)"
echo -e " ${GREEN}plans-list${NC} Shortcut: --list-models"
echo -e " ${GREEN}plans-run <model> [args]${NC} Run a single model (--verbose auto-added)"
echo -e " ${GREEN}plans-compare <m1> <m2> [more]${NC} Compare models side-by-side"
echo ""
echo " Writing Style Benchmark:"
echo -e " ${GREEN}style-bench [args]${NC} Run benchmark_style.py (args passed through)"
echo -e " ${GREEN}style-list${NC} List available ollama models for style bench"
@ -127,6 +133,8 @@ case "$CMD" in
fi
mkdir -p "$LOG_DIR"
API_LOG="${LOG_DIR}/api.log"
# Load .env if present — sets HF_TOKEN and other optional overrides.
[[ -f .env ]] && set -a && source .env && set +a
info "Building Vue SPA…"
(cd web && npm run build) >> "$API_LOG" 2>&1
info "Starting FastAPI on port ${API_PORT}"
@ -179,6 +187,9 @@ case "$CMD" in
mkdir -p "$LOG_DIR"
DEV_API_LOG="${LOG_DIR}/dev-api.log"
# Load .env if present — sets HF_TOKEN and other optional overrides.
[[ -f .env ]] && set -a && source .env && set +a
if [[ -f "$DEV_API_PID_FILE" ]] && kill -0 "$(<"$DEV_API_PID_FILE")" 2>/dev/null; then
warn "Dev API already running (PID $(<"$DEV_API_PID_FILE"))"
else
@ -255,6 +266,30 @@ case "$CMD" in
exec "$0" benchmark --compare "$@"
;;
plans-bench)
info "Running planning benchmark (${ENV_UI})…"
"$PYTHON_UI" scripts/benchmark_plans.py "$@"
;;
plans-list)
exec "$0" plans-bench --list-models
;;
plans-run)
if [[ $# -lt 1 ]]; then
error "Usage: ./manage.sh plans-run <model-key> [extra args]"
fi
MODEL="$1"; shift
exec "$0" plans-bench --model "$MODEL" --verbose "$@"
;;
plans-compare)
if [[ $# -lt 2 ]]; then
error "Usage: ./manage.sh plans-compare <model1> <model2> [more…]"
fi
exec "$0" plans-bench --compare "$@" --verbose
;;
style-bench)
info "Running writing style benchmark (${ENV_BM})…"
if [[ ! -x "$PYTHON_BM" ]]; then

View file

@ -3,3 +3,6 @@ testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
markers =
gpu: requires an idle GPU; excluded from default runs
slow: long-running test; excluded from default CI runs

View file

@ -3,3 +3,4 @@ pydantic>=2.0.0
uvicorn[standard]>=0.20.0
httpx>=0.24.0
pytest>=7.0.0
pyyaml>=6.0

View file

@ -39,6 +39,7 @@ from scripts.classifier_adapters import (
LABELS,
LABEL_DESCRIPTIONS,
ClassifierAdapter,
EmbeddingKNNAdapter,
FineTunedAdapter,
GLiClassAdapter,
RerankerAdapter,
@ -130,6 +131,13 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
"params": "600M",
"default": False,
},
"embed-knn-nomic": {
"adapter": EmbeddingKNNAdapter,
"model_id": "nomic-embed-text",
"params": "local-embed",
"default": False, # requires orch or ollama; use --include-slow
"kwargs": {"k": 3},
},
}
# ---------------------------------------------------------------------------
@ -184,6 +192,42 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
return found
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
"""Sample up to k_per_label formatted email texts per label from a scored JSONL.
Formats each row as 'Subject: {subject}\n\n{body[:600]}' the same format
EmbeddingKNNAdapter uses at classify() time. Rows missing the 'label' key
are skipped silently.
Returns dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
"""
result: dict[str, list[str]] = {}
p = Path(path)
with p.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
except json.JSONDecodeError as exc:
print(f"[build_exemplars] WARN: skipping malformed line: {exc}", flush=True)
continue
label = row.get("label")
if not label:
continue
subject = row.get("subject", "")
body = row.get("body", "")
if not subject and not body:
continue
texts = result.setdefault(label, [])
if len(texts) < k_per_label:
texts.append(
f"Subject: {subject}\n\n{body[:600]}"
)
return result
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
"""Return the active model registry, merged with any discovered fine-tuned models."""
active: dict[str, dict[str, Any]] = {

734
scripts/benchmark_plans.py Normal file
View file

@ -0,0 +1,734 @@
#!/usr/bin/env python
"""CF-specific planning benchmark — compare base models before fine-tuning.
Sends held-out CircuitForge planning prompts to one or more models via the
cf-text (local) or cf-orch API, then scores responses against CF-specific
rubrics. Use this to select the best base model for SFT.
Scoring rubrics (each 0-1, summed to total/N):
- task_structure : uses checkbox syntax (- [ ]), git commit steps
- tier_awareness : mentions Free/Paid/Premium/Ultra tiers
- privacy_pillar : mentions privacy/local-inference/no-logging
- safety_pillar : mentions safety, human approval, or reversibility
- accessibility : mentions ND/accessibility/adaptive needs
- license_split : mentions MIT vs BSL or open-core model
- file_paths : uses plausible file path references
- cf_conventions : uses conda run -n cf, /Library/Development/, or known CF dirs
- paired_coherence : (paired only) plan references the design doc's feature name
- length_ok : 3002500 words (under-short = hallucination risk; over-long = padding)
Usage
-----
# List available model targets
python scripts/benchmark_plans.py --list-models
# Run all held-out prompts against a single model, print report
python scripts/benchmark_plans.py --model granite-4.1-8b
# Compare two models side-by-side
python scripts/benchmark_plans.py --compare granite-4.1-8b deepseek-r1-7b-4bit
# Run with a custom API base (cf-text default: http://localhost:8080/v1)
python scripts/benchmark_plans.py --model granite-4.1-8b --api-base http://localhost:8080/v1
# Export detailed results JSON
python scripts/benchmark_plans.py --model granite-4.1-8b --output data/bench_results.json
"""
from __future__ import annotations
import argparse
import json
import re
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any
import httpx
# ── Paths ──────────────────────────────────────────────────────────────────────
_ROOT = Path(__file__).parent.parent
_DATA_DIR = _ROOT / "data"
CF_TEXT_BASE = "http://localhost:8080/v1"
CF_ORCH_BASE = "http://localhost:8090/v1"
CF_COORD_URL = "http://10.1.10.71:7700" # cf-orch coordinator (LAN)
# ── Held-out prompts ───────────────────────────────────────────────────────────
# These are NOT in the training export (no matching docs in circuitforge-plans/).
# Each prompt exercises a different CF planning domain.
HELD_OUT_PROMPTS: list[dict[str, Any]] = [
{
"id": "ho_001",
"name": "kiwi_barcode_ocr",
"domain": "feature_plan",
"prompt": (
"You are a senior engineer on Kiwi, a CircuitForge pantry-tracking product. "
"Write a detailed implementation plan for adding barcode scanning via device camera "
"and receipt OCR to the item-add flow.\n\n"
"The plan should include: file structure (create/modify), step-by-step task checklist "
"with checkboxes, any DB migrations, and git commit steps."
),
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
},
{
"id": "ho_002",
"name": "peregrine_ats_scoring",
"domain": "feature_design",
"prompt": (
"Write a design document for Peregrine: ATS keyword scoring for job applications.\n\n"
"Context: Peregrine users paste job descriptions and their resume. "
"We want to score how well the resume keywords match the JD and suggest rewrites. "
"Describe the architecture, data flow, and key design decisions."
),
"expected_signals": ["privacy_pillar", "tier_awareness", "license_split"],
},
{
"id": "ho_003",
"name": "tier_gate_local_llm",
"domain": "architecture",
"prompt": (
"Design the tier-gating architecture for a new CircuitForge product. "
"The product should:\n"
"- Default to local LLM inference for all tiers\n"
"- Unlock cloud LLM for Paid tier and above\n"
"- Keep fine-tuned model weights for Premium/Ultra only\n\n"
"Describe how the tier check integrates with the LLM router, "
"what happens when a Free user tries a Paid-tier feature, "
"and how BYOK (bring-your-own-key) fits in."
),
"expected_signals": ["tier_awareness", "privacy_pillar", "license_split"],
},
{
"id": "ho_004",
"name": "heimdall_webhook_plan",
"domain": "feature_plan",
"prompt": (
"Break the following Heimdall feature into a detailed implementation plan with "
"file structure and task checkboxes — Stripe webhook handler for subscription lifecycle.\n\n"
"Heimdall is the CircuitForge license server (FastAPI + SQLite). "
"The webhook needs to handle checkout.session.completed, "
"customer.subscription.updated, and customer.subscription.deleted events."
),
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
},
{
"id": "ho_005",
"name": "nd_accessible_onboarding",
"domain": "ux_design",
"prompt": (
"You are a product designer working on Harrier, a CircuitForge tool for "
"helping people navigate government benefits applications.\n\n"
"Design the onboarding flow for neurodivergent (ND) users. "
"Consider: ADHD time-blindness, executive function challenges, demand avoidance, "
"and rejection sensitivity. The flow should reduce cognitive load and "
"never use urgency or panic patterns."
),
"expected_signals": ["accessibility", "safety_pillar", "privacy_pillar"],
},
{
"id": "ho_006",
"name": "circuitforge_core_extraction",
"domain": "architecture",
"prompt": (
"Produce a CircuitForge-style design document for the following circuitforge-core "
"feature — shared ActivityPub federation module.\n\n"
"Background: Multiple CF products (Kiwi, Rook, Snipe) want to publish updates "
"to ActivityPub. Build it once in cf-core (MIT licensed) so all products can use it. "
"Design the module API, describe what belongs in MIT vs BSL, and note federation "
"privacy constraints."
),
"expected_signals": ["license_split", "privacy_pillar", "cf_conventions"],
},
{
"id": "ho_007",
"name": "snipe_trust_score_plan",
"domain": "feature_plan",
"prompt": (
"You are a senior engineer on Snipe, a CircuitForge eBay trust-scoring tool. "
"Write a step-by-step engineering plan for: seller trust score calculation.\n\n"
"The score should combine: feedback ratio, account age, item-specifics completeness, "
"listing photo quality, and shipping time accuracy. "
"Include file structure, test plan, and migration steps."
),
"expected_signals": ["task_structure", "file_paths", "safety_pillar"],
},
{
"id": "ho_008",
"name": "avocet_training_pipeline",
"domain": "feature_plan",
"prompt": (
"Break the following Avocet feature into a detailed implementation plan — "
"end-to-end fine-tuning pipeline from labeled JSONL to deployed GGUF model.\n\n"
"Avocet is the CircuitForge email classifier training tool. "
"The pipeline should: validate the dataset, run LoRA SFT via unsloth, "
"quantize to Q5_K_M GGUF, run the benchmark harness, and register the model "
"in the Avocet model queue if it beats the baseline."
),
"expected_signals": ["task_structure", "file_paths", "cf_conventions"],
},
{
"id": "ho_009",
"name": "privacy_data_flow",
"domain": "architecture",
"prompt": (
"Design the data privacy architecture for a CircuitForge cloud product. "
"Describe: what PII is collected, how it's stored, retention policy, "
"obfuscation strategy for cloud-side logs, and how consent is obtained "
"in plain language. The product handles job applications (resumes, cover letters)."
),
"expected_signals": ["privacy_pillar", "safety_pillar", "accessibility"],
},
{
"id": "ho_010",
"name": "git_workflow_doc",
"domain": "process_doc",
"prompt": (
"Write a developer process document for CircuitForge: conventional commit and "
"branch workflow for a BSL 1.1 open-core product.\n\n"
"Cover: commit message format (type: description), branch naming, "
"when to use feature branches vs direct main commits, "
"how the MIT/BSL split affects which commits go in which branch, "
"and how CI gates on gitleaks for secret scanning."
),
"expected_signals": ["license_split", "cf_conventions", "task_structure"],
},
]
# ── Rubric scoring ─────────────────────────────────────────────────────────────
_TASK_STRUCTURE_RE = re.compile(r"- \[ \]", re.MULTILINE)
_COMMIT_RE = re.compile(r"git commit|git add", re.IGNORECASE)
_TIER_RE = re.compile(r"\b(Free|Paid|Premium|Ultra)\s+tier|\btier\s+(Free|Paid|Premium|Ultra)", re.IGNORECASE)
_PRIVACY_RE = re.compile(r"\b(privacy|local.?inference|no.?logging|no.?pii|user.?data|data.?reten|obfuscat)", re.IGNORECASE)
_SAFETY_RE = re.compile(r"\b(human.?approv|reversib|safety|safe.?default|fail.?safe|harm)", re.IGNORECASE)
_A11Y_RE = re.compile(r"\b(neurodiverg|ND\b|accessib|adaptive|ADHD|autism|executive.?function|demand.?avoid)", re.IGNORECASE)
_LICENSE_RE = re.compile(r"\b(MIT|BSL|open.?core|proprietary|commercial.?licens)", re.IGNORECASE)
_FILE_PATH_RE = re.compile(r"(app/|tests?/|src/|scripts?/)\w[\w/.-]{3,}", re.IGNORECASE)
_CF_CONV_RE = re.compile(r"(conda run -n cf|/Library/Development/CircuitForge|circuitforge-core|manage\.sh)", re.IGNORECASE)
@dataclass
class RubricScore:
task_structure: float = 0.0
tier_awareness: float = 0.0
privacy_pillar: float = 0.0
safety_pillar: float = 0.0
accessibility: float = 0.0
license_split: float = 0.0
file_paths: float = 0.0
cf_conventions: float = 0.0
length_ok: float = 0.0
def total(self) -> float:
vals = [self.task_structure, self.tier_awareness, self.privacy_pillar,
self.safety_pillar, self.accessibility, self.license_split,
self.file_paths, self.cf_conventions, self.length_ok]
return sum(vals) / len(vals)
def as_dict(self) -> dict[str, float]:
return asdict(self)
def score_response(response: str, prompt_meta: dict[str, Any]) -> RubricScore:
words = len(response.split())
s = RubricScore()
# Task structure: needs checkboxes AND at least one commit step
checkbox_hits = len(_TASK_STRUCTURE_RE.findall(response))
has_commit = bool(_COMMIT_RE.search(response))
s.task_structure = min(1.0, checkbox_hits / 5) * 0.7 + (0.3 if has_commit else 0.0)
# Tier awareness
s.tier_awareness = min(1.0, len(_TIER_RE.findall(response)) / 2)
# Privacy pillar
s.privacy_pillar = min(1.0, len(_PRIVACY_RE.findall(response)) / 3)
# Safety pillar
s.safety_pillar = min(1.0, len(_SAFETY_RE.findall(response)) / 2)
# Accessibility
s.accessibility = min(1.0, len(_A11Y_RE.findall(response)) / 2)
# License split awareness
s.license_split = min(1.0, len(_LICENSE_RE.findall(response)) / 2)
# File paths: at least 3 plausible path references
s.file_paths = min(1.0, len(_FILE_PATH_RE.findall(response)) / 3)
# CF conventions
s.cf_conventions = min(1.0, len(_CF_CONV_RE.findall(response)) / 2)
# Length: 2002500 words is healthy; outside = partial credit
if 200 <= words <= 2500:
s.length_ok = 1.0
elif words < 200:
s.length_ok = words / 200
else:
s.length_ok = max(0.0, 1.0 - (words - 2500) / 2500)
return s
# ── Model client ───────────────────────────────────────────────────────────────
# Registry of named model targets (shorthand → {api_base, model_name})
MODEL_REGISTRY: dict[str, dict[str, str]] = {
"deepseek-r1-1.5b": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-r1-1.5b",
"description": "DeepSeek R1 1.5B distill (cf-orch catalog key)",
},
"deepseek-r1-7b-4bit": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-r1-7b-4bit",
"description": "DeepSeek R1 7B distill, 4-bit (cf-orch catalog key)",
},
"deepseek-r1-0528-qwen3-8b-gguf": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-r1-0528-qwen3-8b-gguf",
"description": "DeepSeek R1 0528 Qwen3 8B GGUF -- current reasoning model (4 nodes)",
},
"deepseek-coder-6.7b-4bit": {
"api_base": CF_TEXT_BASE,
"model": "deepseek-coder-6.7b-4bit",
"description": "DeepSeek Coder 6.7B instruct, 4-bit (cf-orch catalog key)",
},
"granite-4.1-8b": {
"api_base": CF_TEXT_BASE,
"model": "granite-4.1-8b",
"description": "IBM Granite 4.1 8B, 4-bit -- safety-trained (cf-orch catalog key)",
},
"capybarahermes-2.5-mistral-7b-gguf": {
"api_base": CF_TEXT_BASE,
"model": "capybarahermes-2.5-mistral-7b-gguf",
"description": "CapybaraHermes 2.5 Mistral 7B GGUF -- conversational/creative (4 nodes)",
},
"darwin-9b-opus-gguf": {
"api_base": CF_TEXT_BASE,
"model": "darwin-9b-opus-gguf",
"description": "Darwin 9B Opus GGUF -- high-quality long-form writing (3 nodes)",
},
"qwen2.5-3b": {
"api_base": CF_TEXT_BASE,
"model": "qwen2.5-3b",
"description": "Qwen 2.5 3B Q4 GGUF (cf-orch catalog key)",
},
"qwen2.5-7b": {
"api_base": CF_TEXT_BASE,
"model": "qwen2.5-7b",
"description": "Qwen 2.5 7B Q4 GGUF (cf-orch catalog key)",
},
}
# ── cf-orch allocation ─────────────────────────────────────────────────────────
def _cforch_allocate(
model_id: str,
cforch_url: str,
startup_timeout_s: float = 300.0,
) -> tuple[str, str] | None:
"""Allocate a cf-text instance for model_id via the cf-orch coordinator.
Returns (service_url, allocation_id) on success, None on failure.
service_url is the direct node URL exposing /v1/chat/completions.
"""
try:
resp = httpx.post(
f"{cforch_url}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "plans_benchmark",
},
timeout=120.0,
)
resp.raise_for_status()
data = resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
if data.get("started", False) and not data.get("warm", True):
# Use \n so the SSE generator sees the line immediately
print(f" [cold start] loading {model_id!r} — polling every 3s…", flush=True)
t0 = time.monotonic()
deadline = t0 + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
elapsed = time.monotonic() - t0
try:
status = httpx.get(f"{cforch_url}/api/services/cf-text/status", timeout=5.0)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
print(f" [cold start] ready in {elapsed:.0f}s", flush=True)
return service_url, allocation_id
elif state == "stopped":
print(f" [cold start] failed — service stopped after {elapsed:.0f}s", flush=True)
return None
else:
# still starting — emit keepalive so SSE stream stays alive
print(f" [cold start] state={state!r} elapsed={elapsed:.0f}s", flush=True)
else:
probe_misses += 1
print(f" [cold start] waiting… elapsed={elapsed:.0f}s", flush=True)
if probe_misses >= 6:
try:
h = httpx.get(f"{service_url}/health", timeout=3.0)
if h.is_success:
print(f" [cold start] ready via health check in {elapsed:.0f}s", flush=True)
return service_url, allocation_id
except Exception:
pass
else:
print(f" [cold start] status poll returned {status.status_code}, elapsed={elapsed:.0f}s", flush=True)
except Exception as poll_exc:
print(f" [cold start] poll error: {poll_exc} elapsed={elapsed:.0f}s", flush=True)
time.sleep(3.0)
print(f" [cold start] timed out after {time.monotonic()-t0:.0f}s", flush=True)
return None
return service_url, allocation_id
except Exception as exc:
print(f"[warn] cf-orch allocation failed for {model_id!r}: {exc}", file=sys.stderr)
return None
def _call_model_direct(service_url: str, model: str, prompt: str, timeout: int = 600) -> tuple[str, float]:
"""Call an OpenAI-compatible /v1/chat/completions on a direct service URL."""
t0 = time.monotonic()
resp = httpx.post(
f"{service_url.rstrip('/')}/v1/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 2048,
"temperature": 0.2,
},
timeout=timeout,
)
resp.raise_for_status()
latency = time.monotonic() - t0
text = resp.json()["choices"][0]["message"]["content"]
return text, latency
def _call_model(api_base: str, model: str, prompt: str, timeout: int = 180) -> tuple[str, float]:
"""Call an OpenAI-compatible /chat/completions endpoint. Returns (text, latency_s)."""
t0 = time.monotonic()
resp = httpx.post(
f"{api_base}/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 2048,
"temperature": 0.2,
},
timeout=timeout,
)
resp.raise_for_status()
latency = time.monotonic() - t0
text = resp.json()["choices"][0]["message"]["content"]
return text, latency
# ── Benchmark runner ───────────────────────────────────────────────────────────
@dataclass
class PromptResult:
prompt_id: str
prompt_name: str
model_key: str
response: str
latency_s: float
word_count: int
scores: dict[str, float]
total_score: float
error: str | None = None
def run_benchmark(
model_key: str,
model_name: str,
prompts: list[dict[str, Any]] | None = None,
verbose: bool = False,
# cf-orch path
use_cforch: bool = False,
cforch_url: str = CF_COORD_URL,
# direct path (used when not cf-orch)
api_base: str = CF_TEXT_BASE,
) -> list[PromptResult]:
"""Run all prompts through one model. Uses cf-orch allocation when use_cforch=True."""
if prompts is None:
prompts = HELD_OUT_PROMPTS
# Allocate once per model when using cf-orch
service_url: str | None = None
if use_cforch:
print(f" Allocating {model_name!r} via cf-orch…", flush=True)
alloc = _cforch_allocate(model_name, cforch_url)
if alloc is None:
# Return all prompts as errors
return [
PromptResult(
prompt_id=p["id"], prompt_name=p["name"], model_key=model_key,
response="", latency_s=0.0, word_count=0, scores={}, total_score=0.0,
error=f"cf-orch allocation failed for {model_name!r}",
)
for p in prompts
]
service_url, _alloc_id = alloc
results: list[PromptResult] = []
for p in prompts:
if verbose:
print(f" [{p['id']}] {p['name']}", end="", flush=True)
try:
if service_url:
response, latency = _call_model_direct(service_url, model_name, p["prompt"])
else:
response, latency = _call_model(api_base, model_name, p["prompt"])
rubric = score_response(response, p)
result = PromptResult(
prompt_id=p["id"],
prompt_name=p["name"],
model_key=model_key,
response=response,
latency_s=round(latency, 2),
word_count=len(response.split()),
scores=rubric.as_dict(),
total_score=round(rubric.total(), 3),
)
if verbose:
print(f"score={result.total_score:.3f} ({result.word_count}w, {latency:.1f}s)")
except Exception as exc:
result = PromptResult(
prompt_id=p["id"],
prompt_name=p["name"],
model_key=model_key,
response="",
latency_s=0.0,
word_count=0,
scores={},
total_score=0.0,
error=str(exc),
)
if verbose:
print(f"ERROR: {exc}")
results.append(result)
return results
# ── Reporting ──────────────────────────────────────────────────────────────────
def _print_single_report(results: list[PromptResult], model_key: str) -> None:
ok = [r for r in results if not r.error]
err = [r for r in results if r.error]
if not ok:
print(f"\n[{model_key}] All {len(err)} prompts failed.\n")
return
avg_total = sum(r.total_score for r in ok) / len(ok)
avg_latency = sum(r.latency_s for r in ok) / len(ok)
# Aggregate per-rubric averages
rubric_keys = list(ok[0].scores.keys())
rubric_avgs = {k: sum(r.scores.get(k, 0) for r in ok) / len(ok) for k in rubric_keys}
print(f"\n{'='*60}")
print(f" Model : {model_key}")
print(f" Prompts: {len(ok)}/{len(results)} passed ({len(err)} errors)")
print(f" Overall score : {avg_total:.3f} (avg latency {avg_latency:.1f}s)")
print(f"\n Rubric breakdown:")
for k, v in sorted(rubric_avgs.items(), key=lambda x: -x[1]):
bar = "" * int(v * 20)
print(f" {k:<22} {v:.3f} {bar}")
print(f"\n Per-prompt scores:")
for r in sorted(ok, key=lambda x: -x.total_score):
flag = "" if r.total_score < 0.3 else " "
print(f" {flag} {r.prompt_id} {r.prompt_name:<35} {r.total_score:.3f} ({r.word_count}w)")
if err:
print(f"\n Errors:")
for r in err:
print(f" {r.prompt_id} {r.prompt_name}: {r.error}")
print(f"{'='*60}\n")
def _print_comparison_table(all_results: dict[str, list[PromptResult]]) -> None:
model_keys = list(all_results.keys())
prompt_ids = [p["id"] for p in HELD_OUT_PROMPTS]
# Scores by (model, prompt_id)
score_map: dict[tuple[str, str], float] = {}
for mk, results in all_results.items():
for r in results:
score_map[(mk, r.prompt_id)] = r.total_score if not r.error else 0.0
col_w = 10
header = f"{'Prompt':<35}" + "".join(f"{mk[:col_w-1]:<{col_w}}" for mk in model_keys)
print(f"\n{'='*len(header)}")
print(" COMPARISON TABLE")
print(f"{'='*len(header)}")
print(f" {header}")
print(f" {'-'*len(header)}")
for pid in prompt_ids:
pname = next(p["name"] for p in HELD_OUT_PROMPTS if p["id"] == pid)
row = f" {pname:<35}"
best = max(score_map.get((mk, pid), 0.0) for mk in model_keys)
for mk in model_keys:
v = score_map.get((mk, pid), 0.0)
marker = "*" if v == best and len(model_keys) > 1 else " "
row += f"{v:.3f}{marker} "
print(row)
print(f" {'-'*len(header)}")
avgs_row = f" {'AVERAGE':<35}"
best_avg = -1.0
avgs: dict[str, float] = {}
for mk in model_keys:
vals = [score_map.get((mk, pid), 0.0) for pid in prompt_ids]
avgs[mk] = sum(vals) / len(vals)
best_avg = max(best_avg, avgs[mk])
for mk in model_keys:
marker = "*" if avgs[mk] == best_avg and len(model_keys) > 1 else " "
avgs_row += f"{avgs[mk]:.3f}{marker} "
print(avgs_row)
print(f"{'='*len(header)}\n")
if len(model_keys) > 1:
winner = max(avgs, key=lambda k: avgs[k])
print(f" Winner: {winner} (avg {avgs[winner]:.3f})\n")
# ── CLI ────────────────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--list-models", action="store_true",
help="Print registered model shortcuts and exit")
parser.add_argument("--model", metavar="KEY",
help="Benchmark a single model (registry key or raw model name)")
parser.add_argument("--compare", nargs="+", metavar="KEY",
help="Compare two or more models side-by-side")
parser.add_argument("--cforch", action="store_true",
help="Route inference through cf-orch coordinator (allocate per model)")
parser.add_argument("--cforch-url", default=CF_COORD_URL, metavar="URL",
help=f"cf-orch coordinator URL (default: {CF_COORD_URL})")
parser.add_argument("--api-base", default=None,
help="Direct API base URL when not using cf-orch")
parser.add_argument("--model-name", default=None,
help="Override model name sent to API (single-model runs only)")
parser.add_argument("--prompts", nargs="+", metavar="ID",
help="Run only specific prompt IDs (e.g. ho_001 ho_003)")
parser.add_argument("--output", type=Path, default=None,
help="Write detailed JSON results to this path")
parser.add_argument("--workers", type=int, default=1, metavar="N",
help="Run N models concurrently (default 1). Set to number of available nodes.")
parser.add_argument("--verbose", "-v", action="store_true",
help="Print per-prompt progress")
args = parser.parse_args()
if args.list_models:
print("\nRegistered model shortcuts:")
for key, info in MODEL_REGISTRY.items():
print(f" {key:<20} {info['description']}")
print(f"\nDefault endpoints:")
print(f" direct {CF_TEXT_BASE}")
print(f" cf-orch {CF_COORD_URL}")
return
prompts = HELD_OUT_PROMPTS
if args.prompts:
ids = set(args.prompts)
prompts = [p for p in HELD_OUT_PROMPTS if p["id"] in ids]
if not prompts:
print(f"No prompts matched IDs: {args.prompts}", file=sys.stderr)
sys.exit(1)
model_keys: list[str] = []
if args.compare:
model_keys = args.compare
elif args.model:
model_keys = [args.model]
else:
parser.print_help()
sys.exit(0)
all_results: dict[str, list[PromptResult]] = {}
print_lock = threading.Lock()
def _run_one(mk: str) -> tuple[str, list[PromptResult]]:
if mk in MODEL_REGISTRY:
reg = MODEL_REGISTRY[mk]
model_name = args.model_name or reg["model"]
direct_base = args.api_base or reg["api_base"]
else:
model_name = args.model_name or mk
direct_base = args.api_base or CF_TEXT_BASE
if args.cforch:
with print_lock:
print(f"\nRunning [{mk}] via cf-orch ({args.cforch_url}) model={model_name}")
results = run_benchmark(
mk, model_name, prompts=prompts, verbose=args.verbose,
use_cforch=True, cforch_url=args.cforch_url,
)
else:
with print_lock:
print(f"\nRunning [{mk}] → {direct_base} model={model_name}")
results = run_benchmark(
mk, model_name, prompts=prompts, verbose=args.verbose,
api_base=direct_base,
)
with print_lock:
_print_single_report(results, mk)
return mk, results
workers = max(1, args.workers)
if workers == 1 or len(model_keys) == 1:
for mk in model_keys:
mk_out, results = _run_one(mk)
all_results[mk_out] = results
else:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {pool.submit(_run_one, mk): mk for mk in model_keys}
for fut in as_completed(futures):
mk_out, results = fut.result()
all_results[mk_out] = results
if len(model_keys) > 1:
_print_comparison_table(all_results)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)
payload = {
mk: [asdict(r) for r in results]
for mk, results in all_results.items()
}
with open(args.output, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2, ensure_ascii=False)
print(f"Wrote detailed results to {args.output}")
if __name__ == "__main__":
main()

View file

@ -7,19 +7,26 @@ from __future__ import annotations
import abc
from collections import defaultdict
import httpx
import logging
from pathlib import Path
from typing import Any
__all__ = [
"LABELS",
"LABEL_DESCRIPTIONS",
"DEFAULT_EXEMPLARS",
"compute_metrics",
"ClassifierAdapter",
"ZeroShotAdapter",
"GLiClassAdapter",
"RerankerAdapter",
"FineTunedAdapter",
"EmbeddingKNNAdapter",
]
_logger = logging.getLogger(__name__)
LABELS: list[str] = [
"interview_scheduled",
"offer_received",
@ -117,6 +124,81 @@ def compute_metrics(
return result
def _cosine(a: list[float], b: list[float]) -> float:
dot = sum(x * y for x, y in zip(a, b))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
"interview_scheduled": [
"Subject: Interview Invitation\n\nWe would like to invite you for a phone screen next week.",
"Subject: Schedule a call\n\nCould you be available for a video interview on Tuesday?",
"Subject: Next Steps\n\nWe'd like to move forward with a technical interview. Please select a time.",
"Subject: Interview Details\n\nHere are the dial-in instructions for your interview tomorrow.",
],
"offer_received": [
"Subject: Offer Letter Enclosed\n\nWe are pleased to extend you an offer of employment.",
"Subject: Job Offer\n\nDear candidate, we are excited to offer you the position of Software Engineer.",
"Subject: Employment Offer\n\nPlease find attached your formal offer letter and compensation details.",
"Subject: Offer of Employment\n\nCongratulations! We would like to offer you a full-time position.",
],
"rejected": [
"Subject: Your Application\n\nAfter careful consideration, we have decided to move forward with other candidates.",
"Subject: Application Status\n\nWe regret to inform you that your application has not been selected.",
"Subject: Thank you for applying\n\nWe appreciate your interest but have chosen not to proceed.",
"Subject: Update on your candidacy\n\nWe will not be moving forward with your application at this time.",
],
"positive_response": [
"Subject: Your profile\n\nI came across your LinkedIn and think you would be a great fit for our team.",
"Subject: Exciting opportunity\n\nWe were impressed by your background and would love to connect.",
"Subject: Following up\n\nThank you for your interest — we'd like to learn more about your experience.",
"Subject: Great fit\n\nYour skills align well with what we are looking for. Let's set up a call.",
],
"survey_received": [
"Subject: Candidate Experience Survey\n\nPlease complete this brief survey about your application experience.",
"Subject: Culture Fit Assessment\n\nAs part of our process, we ask all candidates to complete a short assessment.",
"Subject: Skills Assessment\n\nWe'd like you to complete our online coding assessment before proceeding.",
"Subject: Personality Assessment\n\nPlease complete the following assessment as the next step in our process.",
"Subject: Pre-interview questionnaire\n\nBefore we schedule your interview, please complete this brief skills survey.",
],
"neutral": [
"Subject: Application Received\n\nWe have received your application and will be in touch.",
"Subject: Thank you for applying\n\nYour application is under review. We will contact you if needed.",
"Subject: Confirmation\n\nThis email confirms receipt of your application to our company.",
"Subject: Application Confirmation\n\nThank you for your interest. We will review your materials and follow up.",
],
"event_rescheduled": [
"Subject: Interview Rescheduled\n\nDue to a conflict, we need to move your interview to a new time.",
"Subject: Change of interview time\n\nWe apologize — your interview has been rescheduled to Thursday.",
"Subject: Updated interview details\n\nYour interview has been moved from Monday to Wednesday at 2pm.",
"Subject: Reschedule request\n\nWould you be available to reschedule to a different time slot?",
"Subject: New interview time\n\nYour phone screen has been moved from tomorrow to next week.",
],
"digest": [
"Subject: 15 new jobs matching your search\n\nHere are the latest job postings that match your profile.",
"Subject: Weekly Job Digest\n\nThis week's top opportunities for Software Engineers in your area.",
"Subject: Jobs you might like\n\nBased on your profile, here are some positions we recommend.",
"Subject: New jobs for you\n\nSee the latest openings from companies on your watchlist.",
],
"new_lead": [
"Subject: Exciting opportunity at our company\n\nHi, I noticed your background and think you'd be a great fit.",
"Subject: Are you open to new opportunities?\n\nI'm a recruiter reaching out about a role matching your experience.",
"Subject: Quick question\n\nWould you be interested in hearing about a senior engineering role?",
"Subject: Recruiting outreach\n\nI came across your profile and wanted to share an exciting opening.",
],
"hired": [
"Subject: Welcome to the team!\n\nWe are thrilled to have you join us. Here are your onboarding details.",
"Subject: Onboarding information\n\nCongratulations on accepting our offer. Your start date is confirmed.",
"Subject: First day information\n\nWe look forward to your first day. Please arrive at 9am and ask for HR.",
"Subject: Background check initiated\n\nAs part of your onboarding, we have initiated a background check.",
"Subject: Equipment setup\n\nYour laptop and equipment will be ready for pickup on your first day.",
],
}
class ClassifierAdapter(abc.ABC):
"""Abstract base for all email classifier adapters."""
@ -304,3 +386,148 @@ class FineTunedAdapter(ClassifierAdapter):
text = f"{subject} [SEP] {body[:400]}"
result = self._pipeline(text)
return result[0]["label"]
class EmbeddingKNNAdapter(ClassifierAdapter):
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
load():
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
Falls back to ollama_url directly if orch allocation fails or is not configured.
2. Pre-embeds all exemplar texts and stores per-label vector lists.
classify(subject, body):
Embeds the input email, computes cosine similarity against all stored exemplar
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
with the highest total similarity score among tied vote counts wins.
unload():
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
"""
def __init__(
self,
name: str,
model_id: str,
*,
k: int = 3,
orch_url: str = "",
ollama_url: str = "",
exemplar_texts: dict[str, list[str]] | None = None,
) -> None:
self._name = name
self._model_id = model_id
self._k = k
self._orch_url = orch_url
self._ollama_url = ollama_url
self._exemplar_texts: dict[str, list[str]] = (
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
)
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
self._node_url: str = ""
self._allocation_id: str = ""
self._orch_url_used: str = ""
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def _resolve_urls(self) -> tuple[str, str]:
if self._orch_url or self._ollama_url:
return self._orch_url, self._ollama_url
import yaml # noqa: PLC0415
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
cfg: dict = {}
if cfg_path.exists():
try:
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
pass
cforch = cfg.get("cforch", {}) or {}
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
resp = httpx.post(
f"{node_url}/v1/embeddings",
json={"model": self._model_id, "input": texts},
timeout=30.0,
)
resp.raise_for_status()
return [item["embedding"] for item in resp.json()["data"]]
def load(self) -> None:
if self._allocation_id or self._exemplar_embeddings:
raise RuntimeError(
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
)
orch_url, ollama_url = self._resolve_urls()
node_url = ""
orch_url_used = ""
if orch_url:
try:
resp = httpx.post(
f"{orch_url}/api/services/ollama/allocate",
json={"model": self._model_id},
timeout=15.0,
)
if resp.status_code == 200:
data = resp.json()
node_url = data["url"]
self._allocation_id = data["allocation_id"]
orch_url_used = orch_url
except Exception as exc:
_logger.warning(
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
)
if not node_url:
node_url = ollama_url
self._allocation_id = ""
orch_url_used = ""
self._node_url = node_url
self._orch_url_used = orch_url_used
try:
embeddings: dict[str, list[list[float]]] = {}
for label, texts in self._exemplar_texts.items():
embeddings[label] = self._embed(node_url, texts)
self._exemplar_embeddings = embeddings
except Exception:
self.unload()
raise
def unload(self) -> None:
if self._allocation_id and self._orch_url_used:
try:
httpx.request(
"DELETE",
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
timeout=10.0,
)
except Exception:
pass
self._exemplar_embeddings = {}
self._node_url = ""
self._allocation_id = ""
self._orch_url_used = ""
def classify(self, subject: str, body: str) -> str:
if not self._exemplar_embeddings:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
[query_vec] = self._embed(self._node_url, [text])
scored: list[tuple[float, str]] = [
(_cosine(query_vec, vec), label)
for label, vecs in self._exemplar_embeddings.items()
for vec in vecs
]
top_k = sorted(scored, reverse=True)[: self._k]
votes: dict[str, list[float]] = {}
for score, label in top_k:
votes.setdefault(label, []).append(score)
return max(
votes,
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
)

458
scripts/export_plans.py Normal file
View file

@ -0,0 +1,458 @@
"""Export circuitforge-plans/ documents as instruction-tuning JSONL pairs.
Each record is a HuggingFace chat-format example:
{
"id": "<sha256>",
"messages": [
{"role": "user", "content": "<reconstructed planning prompt>"},
{"role": "assistant", "content": "<cleaned document content>"}
],
"meta": {
"source": "peregrine/2026-03-03-feedback-button-design.md",
"product": "peregrine",
"doc_type": "design", # design | plan | spec | implementation | other
"date": "2026-03-03",
"paired_with": "...", # sibling path, or null
"word_count": 1847,
"pair_role": "context" # "context" | "target" | "standalone"
}
}
Pairing strategy
----------------
When a design doc and a plan doc share the same date + feature-name prefix,
they are treated as a pair:
- design plan: instruction = "Given this design doc, write the implementation plan."
context appended = full design doc content.
- Solo docs get a synthetic instruction from the title + first overview section.
Usage
-----
# Preview stats and 5 sample records
python scripts/export_plans.py --preview
# Write full output
python scripts/export_plans.py --output data/plan_pairs.jsonl
# Restrict to specific products
python scripts/export_plans.py --products peregrine,kiwi --output data/plan_pairs.jsonl
"""
from __future__ import annotations
import argparse
import hashlib
import json
import re
import sys
from pathlib import Path
from typing import Iterator
# ── Paths ──────────────────────────────────────────────────────────────────────
_SCRIPT_DIR = Path(__file__).parent
_AVOCET_ROOT = _SCRIPT_DIR.parent
_DEFAULT_PLANS_DIR = Path("/Library/Development/CircuitForge/circuitforge-plans")
_DEFAULT_OUTPUT = _AVOCET_ROOT / "data" / "plan_pairs.jsonl"
# ── Doc type detection ─────────────────────────────────────────────────────────
_TYPE_RE = re.compile(
r"-(design|plan|spec|implementation|specs|plans)s?$",
re.IGNORECASE,
)
_SKIP_DIRS = {"__pycache__", ".git", "node_modules"}
# Boilerplate lines to strip from document content before using as output.
_BOILERPLATE_RE = re.compile(
r"""
^\s*>\s*\*\*For\s+agentic\s+workers.* # superpowers agent hints
|^\s*>\s*REQUIRED\s+SUB-SKILL.*
|^\s*\*\*Date:\*\*.* # metadata header lines
|\*\*Status:\*\*\s*Complete.* # completed-feature noise
|\*\*Status:\*\*\s*Done.*
|\*\*Product:\*\*.*
|\*\*Repo:\*\*.*
|\*\*Tech\s+Stack:\*\*.*
|\*\*Candidate:\*\*.* # old synthetic personas
|^Candidate:.*
|^Team:.*
""",
re.VERBOSE | re.MULTILINE,
)
# Old repo/path names to normalise to current equivalents.
_PATH_NORMALIZATIONS: list[tuple[re.Pattern, str]] = [
(re.compile(r"/devl/job-seeker", re.IGNORECASE), "/Library/Development/CircuitForge/peregrine"),
(re.compile(r"\bjob-seeker\b", re.IGNORECASE), "peregrine"),
(re.compile(r"Alex Rivera", re.IGNORECASE), "[user]"),
]
# Instruction paraphrase templates per doc type.
# Each entry is (user_prefix, paired_prefix).
# {title}, {product}, {type_phrase}, {overview}, {design_context} are substituted.
_DESIGN_INSTRUCTIONS = [
"Write a design document for {product}: {title}.\n\nContext: {overview}",
"You are a software architect working on {product}. Draft a design spec for: {title}.\n\n{overview}",
"Produce a CircuitForge-style design document for the following {product} feature — {title}.\n\nBackground: {overview}",
]
_PLAN_INSTRUCTIONS = [
"Write an implementation plan for {product}: {title}.\n\nContext: {overview}",
"Break the following {product} feature into a detailed implementation plan with file structure and task checkboxes — {title}.\n\n{overview}",
"You are a senior engineer on {product}. Produce a step-by-step engineering plan for: {title}.\n\n{overview}",
]
_PAIRED_INSTRUCTIONS = [
(
"You are a software architect working on {product}, a CircuitForge product. "
"Given the following design document, write a detailed implementation plan "
"(file structure, task breakdown with checkboxes, migration steps if needed).\n\n"
"---\n{design_context}\n---"
),
(
"The following is a design spec for a {product} feature. "
"Produce a concrete implementation plan: file list, task checklist, any DB migrations needed.\n\n"
"---\n{design_context}\n---"
),
(
"Convert this {product} design document into an actionable implementation plan. "
"Include all files to create/modify, step-by-step tasks with checkboxes, and migration steps.\n\n"
"---\n{design_context}\n---"
),
]
def _doc_type(stem: str) -> str:
m = _TYPE_RE.search(stem)
if not m:
return "other"
raw = m.group(1).lower().rstrip("s")
return {"implementation": "plan"}.get(raw, raw)
def _date_feature(stem: str) -> tuple[str, str]:
"""Return (date, feature_slug) from '2026-03-03-feedback-button-design'."""
m = re.match(r"^(\d{4}-\d{2}-\d{2})-(.+?)(?:-(design|plan|spec|implementation)s?)?$", stem, re.I)
if m:
return m.group(1), m.group(2)
return "", stem
# ── Content extraction ─────────────────────────────────────────────────────────
def _extract_title(content: str) -> str:
m = re.search(r"^#\s+(.+)", content, re.MULTILINE)
return m.group(1).strip() if m else ""
def _extract_overview(content: str) -> str:
"""Return first substantive paragraph or h2 section body (≤300 chars)."""
# Superpowers plans have an explicit **Goal:** line — prefer that.
goal_m = re.search(r"\*\*Goal:\*\*\s*(.+)", content)
if goal_m:
return goal_m.group(1).strip()[:300]
# Otherwise use the body of the first h2 section.
h2_m = re.search(
r"^##\s+\d*\.?\s*.+\n([\s\S]+?)(?=^##|\Z)",
content,
re.MULTILINE,
)
if h2_m:
body = h2_m.group(1).strip()
# Strip markdown bullet/code noise for the instruction
body = re.sub(r"```[\s\S]*?```", "", body)
body = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), body)
body = re.sub(r"\*\*([^*]+)\*\*", r"\1", body)
body = re.sub(r"\s+", " ", body).strip()
return body[:300]
return ""
def _clean_content(content: str) -> str:
"""Remove boilerplate, normalize old paths/names, collapse whitespace."""
cleaned = _BOILERPLATE_RE.sub("", content)
for pattern, replacement in _PATH_NORMALIZATIONS:
cleaned = pattern.sub(replacement, cleaned)
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
return cleaned.strip()
def _quality_flags(content: str) -> list[str]:
"""Return a list of quality issue labels found in cleaned content."""
flags = []
if "Alex Rivera" in content or "[user]" in content:
flags.append("persona-residue")
if re.search(r"\bStatus:\s*(Complete|Done|Merged)\b", content):
flags.append("completed-status")
return flags
def _make_instruction(
title: str,
product: str,
doc_type: str,
overview: str,
design_context: str | None = None,
variant: int = 0,
) -> str:
"""Synthesise a natural planning prompt for this document.
variant: 0-2 selects which paraphrase template to use. Caller cycles
through all three to produce multiple training examples per document.
"""
product_label = product.replace("-", " ").title() if product else "CircuitForge"
idx = variant % 3
if design_context:
tmpl = _PAIRED_INSTRUCTIONS[idx]
return tmpl.format(
product=product_label,
design_context=design_context[:2500],
)
templates = _PLAN_INSTRUCTIONS if doc_type in ("plan",) else _DESIGN_INSTRUCTIONS
tmpl = templates[idx]
return tmpl.format(
product=product_label,
title=title,
overview=overview or "",
type_phrase="planning document",
)
def _record_id(content: str, source: str) -> str:
return hashlib.sha256(f"{source}:{content}".encode()).hexdigest()[:16]
# ── Pair discovery ─────────────────────────────────────────────────────────────
def _find_pairs(plans_dir: Path) -> dict[str, list[tuple[str, Path]]]:
"""Return {prefix_key → [(doc_type, path), ...]} for docs sharing date+feature."""
by_prefix: dict[str, list[tuple[str, Path]]] = {}
for path in plans_dir.rglob("*.md"):
if any(part in _SKIP_DIRS for part in path.parts):
continue
if path.name == "README.md":
continue
stem = path.stem
date, feature = _date_feature(stem)
if not date:
continue
key = str(path.parent / f"{date}-{feature}")
by_prefix.setdefault(key, []).append((_doc_type(stem), path))
return by_prefix
# ── Record generation ──────────────────────────────────────────────────────────
def _records_for_group(
doc_type_paths: list[tuple[str, Path]],
plans_dir: Path,
) -> Iterator[dict]:
"""Yield one or more training records for a group of related docs."""
# Separate design vs plan docs within this group
designs = [(t, p) for t, p in doc_type_paths if t in ("design", "spec")]
plans_ = [(t, p) for t, p in doc_type_paths if t in ("plan",)]
others = [(t, p) for t, p in doc_type_paths if t not in ("design", "spec", "plan")]
all_paths = doc_type_paths
if designs and plans_:
# Paired: yield a design→plan record (3 instruction variants)
design_type, design_path = designs[0]
plan_type, plan_path = plans_[0]
design_content = design_path.read_text(encoding="utf-8")
plan_content = plan_path.read_text(encoding="utf-8")
product = _product_from_path(plan_path, plans_dir)
title = _extract_title(plan_content) or plan_path.stem
cleaned = _clean_content(plan_content)
design_cleaned = _clean_content(design_content)
flags = _quality_flags(cleaned)
if len(cleaned.split()) >= 80:
rel_src = str(plan_path.relative_to(plans_dir))
rel_design = str(design_path.relative_to(plans_dir))
for variant in range(3):
instruction = _make_instruction(
title=title,
product=product,
doc_type="plan",
overview=_extract_overview(design_content),
design_context=design_cleaned,
variant=variant,
)
yield {
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
"messages": [
{"role": "user", "content": instruction},
{"role": "assistant", "content": cleaned},
],
"meta": {
"source": rel_src,
"product": product,
"doc_type": "plan",
"date": _date_feature(plan_path.stem)[0],
"paired_with": rel_design,
"word_count": len(cleaned.split()),
"pair_role": "target",
"variant": variant,
"quality_flags": flags,
},
}
# Also yield the design doc as standalone variants
all_paths = [(t, p) for t, p in all_paths if p != plan_path]
# Remaining docs as standalone records (3 instruction variants each)
for doc_type, path in all_paths:
content = path.read_text(encoding="utf-8")
cleaned = _clean_content(content)
if len(cleaned.split()) < 80:
continue
product = _product_from_path(path, plans_dir)
title = _extract_title(content) or path.stem
overview = _extract_overview(content)
flags = _quality_flags(cleaned)
rel_src = str(path.relative_to(plans_dir))
for variant in range(3):
instruction = _make_instruction(
title=title,
product=product,
doc_type=doc_type,
overview=overview,
variant=variant,
)
yield {
"id": _record_id(f"v{variant}:{cleaned}", rel_src),
"messages": [
{"role": "user", "content": instruction},
{"role": "assistant", "content": cleaned},
],
"meta": {
"source": rel_src,
"product": product,
"doc_type": doc_type,
"date": _date_feature(path.stem)[0],
"paired_with": None,
"word_count": len(cleaned.split()),
"pair_role": "standalone",
"variant": variant,
"quality_flags": flags,
},
}
def _product_from_path(path: Path, plans_dir: Path) -> str:
rel = path.relative_to(plans_dir)
return rel.parts[0] if len(rel.parts) > 1 else "shared"
# ── Main export ────────────────────────────────────────────────────────────────
def export(
plans_dir: Path,
products: list[str] | None = None,
) -> list[dict]:
groups = _find_pairs(plans_dir)
records: list[dict] = []
seen_ids: set[str] = set()
for group_key, doc_type_paths in groups.items():
# Filter by product if requested
if products:
paths = [p for _, p in doc_type_paths]
prods = {_product_from_path(p, plans_dir) for p in paths}
if not prods.intersection(products):
continue
for record in _records_for_group(doc_type_paths, plans_dir):
if record["id"] not in seen_ids:
seen_ids.add(record["id"])
records.append(record)
return records
# ── CLI ────────────────────────────────────────────────────────────────────────
def _print_stats(records: list[dict]) -> None:
from collections import Counter
products = Counter(r["meta"]["product"] for r in records)
doc_types = Counter(r["meta"]["doc_type"] for r in records)
pair_roles = Counter(r["meta"]["pair_role"] for r in records)
wc = [r["meta"]["word_count"] for r in records]
wc.sort()
print(f"\n{'='*55}")
print(f" Total records: {len(records)}")
print(f" Word counts : min={wc[0]}, median={wc[len(wc)//2]}, max={wc[-1]}")
print(f"\n By product:")
for p, n in products.most_common():
print(f" {p:<22} {n}")
print(f"\n By doc type:")
for t, n in doc_types.most_common():
print(f" {t:<22} {n}")
print(f"\n Pair roles:")
for r, n in pair_roles.most_common():
print(f" {r:<22} {n}")
print(f"{'='*55}\n")
def _print_sample(records: list[dict], n: int = 3) -> None:
import random
sample = random.sample(records, min(n, len(records)))
for i, rec in enumerate(sample, 1):
meta = rec["meta"]
user_msg = rec["messages"][0]["content"]
asst_msg = rec["messages"][1]["content"]
print(f"\n{''*55}")
print(f"SAMPLE {i}/{n} [{meta['product']} / {meta['doc_type']} / {meta['pair_role']}]")
print(f"source: {meta['source']}")
print(f"\nUSER ({len(user_msg)} chars):\n{user_msg[:500]}{'...' if len(user_msg)>500 else ''}")
print(f"\nASSISTANT ({meta['word_count']} words):\n{asst_msg[:400]}{'...' if len(asst_msg)>400 else ''}")
print(f"\n{''*55}\n")
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--plans-dir", type=Path, default=_DEFAULT_PLANS_DIR)
parser.add_argument("--output", type=Path, default=None,
help="Write JSONL to this path (omit for preview-only)")
parser.add_argument("--products", default=None,
help="Comma-separated product filter, e.g. peregrine,kiwi")
parser.add_argument("--preview", action="store_true",
help="Print stats + sample records, don't write output")
parser.add_argument("--samples", type=int, default=3,
help="Number of sample records to show in preview (default 3)")
args = parser.parse_args()
products = [p.strip() for p in args.products.split(",")] if args.products else None
print(f"Scanning {args.plans_dir}", file=sys.stderr)
records = export(args.plans_dir, products=products)
_print_stats(records)
if args.preview or args.output is None:
_print_sample(records, n=args.samples)
if args.output is None:
print("(Pass --output <path> to write JSONL)")
return
args.output.parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w", encoding="utf-8") as f:
for rec in records:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
print(f"Wrote {len(records)} records to {args.output}")
if __name__ == "__main__":
main()

View file

@ -1,23 +1,37 @@
import json
"""Smoke tests for the app factory (app/api.py).
Detailed route tests live in test_data_label.py, test_data_fetch.py,
test_data_corrections.py, test_train.py, and test_dashboard.py.
"""
import pytest
from app import api as api_module # noqa: F401
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app import api
api.set_data_dir(tmp_path)
api.reset_last_action()
yield
api.reset_last_action()
from fastapi.testclient import TestClient
def test_import():
from app import api # noqa: F401
from fastapi.testclient import TestClient
def test_app_has_required_routes():
from app.api import app
paths = {r.path for r in app.routes}
# Label routes
assert "/api/queue" in paths
assert "/api/label" in paths
assert "/api/skip" in paths
assert "/api/discard" in paths
assert "/api/label/undo" in paths
assert "/api/config/labels" in paths
assert "/api/stats" in paths
# Fetch routes
assert "/api/accounts/test" in paths
assert "/api/fetch/stream" in paths
# Train routes
assert "/api/train/jobs" in paths
assert "/api/train/results" in paths
# Dashboard
assert "/api/dashboard" in paths
# Corrections (new prefix)
assert "/api/corrections/ingest" in paths
@pytest.fixture
@ -26,536 +40,8 @@ def client():
return TestClient(app)
@pytest.fixture
def queue_with_items():
"""Write 3 test emails to the queue file."""
from app import api as api_module
items = [
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
for i in range(3)
]
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
return items
def test_queue_returns_items(client, queue_with_items):
r = client.get("/api/queue?limit=2")
assert r.status_code == 200
data = r.json()
assert len(data["items"]) == 2
assert data["total"] == 3
def test_queue_empty_when_no_file(client):
def test_queue_endpoint_reachable(client):
r = client.get("/api/queue")
assert r.status_code == 200
assert r.json() == {"items": [], "total": 0}
def test_label_appends_to_score(client, queue_with_items):
from app import api as api_module
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
assert r.status_code == 200
records = api_module._read_jsonl(api_module._score_file())
assert len(records) == 1
assert records[0]["id"] == "id0"
assert records[0]["label"] == "interview_scheduled"
assert "labeled_at" in records[0]
def test_label_removes_from_queue(client, queue_with_items):
from app import api as api_module
client.post("/api/label", json={"id": "id0", "label": "rejected"})
queue = api_module._read_jsonl(api_module._queue_file())
assert not any(x["id"] == "id0" for x in queue)
def test_label_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
assert r.status_code == 404
def test_skip_moves_to_back(client, queue_with_items):
from app import api as api_module
r = client.post("/api/skip", json={"id": "id0"})
assert r.status_code == 200
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[-1]["id"] == "id0"
assert queue[0]["id"] == "id1"
def test_skip_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/skip", json={"id": "nope"})
assert r.status_code == 404
# --- Part A: POST /api/discard ---
def test_discard_writes_to_discarded_file(client, queue_with_items):
from app import api as api_module
r = client.post("/api/discard", json={"id": "id1"})
assert r.status_code == 200
discarded = api_module._read_jsonl(api_module._discarded_file())
assert len(discarded) == 1
assert discarded[0]["id"] == "id1"
assert discarded[0]["label"] == "__discarded__"
def test_discard_removes_from_queue(client, queue_with_items):
from app import api as api_module
client.post("/api/discard", json={"id": "id1"})
queue = api_module._read_jsonl(api_module._queue_file())
assert not any(x["id"] == "id1" for x in queue)
# --- Part B: DELETE /api/label/undo ---
def test_undo_label_removes_from_score(client, queue_with_items):
from app import api as api_module
client.post("/api/label", json={"id": "id0", "label": "neutral"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
data = r.json()
assert data["undone"]["type"] == "label"
score = api_module._read_jsonl(api_module._score_file())
assert score == []
# Item should be restored to front of queue
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_discard_removes_from_discarded(client, queue_with_items):
from app import api as api_module
client.post("/api/discard", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
discarded = api_module._read_jsonl(api_module._discarded_file())
assert discarded == []
def test_undo_skip_restores_to_front(client, queue_with_items):
from app import api as api_module
client.post("/api/skip", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_with_no_action_returns_404(client):
r = client.delete("/api/label/undo")
assert r.status_code == 404
# --- Part C: GET /api/config/labels ---
def test_config_labels_returns_metadata(client):
r = client.get("/api/config/labels")
assert r.status_code == 200
labels = r.json()
assert len(labels) == 10
assert labels[0]["key"] == "1"
assert "emoji" in labels[0]
assert "color" in labels[0]
assert "name" in labels[0]
# ── /api/config ──────────────────────────────────────────────────────────────
@pytest.fixture
def config_dir(tmp_path):
"""Give the API a writable config directory."""
from app import api as api_module
api_module.set_config_dir(tmp_path)
yield tmp_path
api_module.set_config_dir(None) # reset to default
@pytest.fixture
def data_dir():
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
from app import api as api_module
return api_module._DATA_DIR
def test_get_config_returns_empty_when_no_file(client, config_dir):
r = client.get("/api/config")
assert r.status_code == 200
data = r.json()
assert data["accounts"] == []
assert data["max_per_account"] == 500
def test_post_config_writes_yaml(client, config_dir):
import yaml
payload = {
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
"use_ssl": True, "username": "u@t.com", "password": "pw",
"folder": "INBOX", "days_back": 30}],
"max_per_account": 200,
}
r = client.post("/api/config", json=payload)
assert r.status_code == 200
assert r.json()["ok"] is True
cfg_file = config_dir / "label_tool.yaml"
assert cfg_file.exists()
saved = yaml.safe_load(cfg_file.read_text())
assert saved["max_per_account"] == 200
assert saved["accounts"][0]["name"] == "Test"
def test_get_config_round_trips(client, config_dir):
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 90}], "max_per_account": 300}
client.post("/api/config", json=payload)
r = client.get("/api/config")
data = r.json()
assert data["max_per_account"] == 300
assert data["accounts"][0]["name"] == "R"
# ── /api/stats ───────────────────────────────────────────────────────────────
@pytest.fixture
def score_with_labels(tmp_path, data_dir):
"""Write a score file with 3 labels for stats tests."""
score_path = data_dir / "email_score.jsonl"
records = [
{"id": "a", "label": "interview_scheduled"},
{"id": "b", "label": "interview_scheduled"},
{"id": "c", "label": "rejected"},
]
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
return records
def test_stats_returns_counts(client, score_with_labels):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 3
assert data["counts"]["interview_scheduled"] == 2
assert data["counts"]["rejected"] == 1
def test_stats_empty_when_no_file(client, data_dir):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 0
assert data["counts"] == {}
assert data["score_file_bytes"] == 0
def test_stats_download_returns_file(client, score_with_labels):
r = client.get("/api/stats/download")
assert r.status_code == 200
assert "jsonlines" in r.headers.get("content-type", "")
def test_stats_download_404_when_no_file(client, data_dir):
r = client.get("/api/stats/download")
assert r.status_code == 404
# ── /api/accounts/test ───────────────────────────────────────────────────────
def test_account_test_missing_fields(client):
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is False
assert "required" in data["message"].lower()
def test_account_test_success(client):
from unittest.mock import MagicMock, patch
mock_conn = MagicMock()
mock_conn.select.return_value = ("OK", [b"99"])
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.post("/api/accounts/test", json={"account": {
"host": "imap.example.com", "port": 993, "use_ssl": True,
"username": "u@example.com", "password": "pw", "folder": "INBOX",
}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["count"] == 99
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
def _parse_sse(content: bytes) -> list[dict]:
"""Parse SSE response body into list of event dicts."""
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_fetch_stream_no_accounts_configured(client, config_dir):
"""With no config, stream should immediately complete with 0 added."""
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
assert r.status_code == 200
events = _parse_sse(r.content)
complete = next((e for e in events if e["type"] == "complete"), None)
assert complete is not None
assert complete["total_added"] == 0
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
"""With one configured account, stream should yield start/done/complete events."""
import yaml
from unittest.mock import MagicMock, patch
# Write a config with one account
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 30}], "max_per_account": 50}
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
mock_conn = MagicMock()
mock_conn.search.return_value = ("OK", [b"1"])
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
with patch("app.imap_fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
assert r.status_code == 200
events = _parse_sse(r.content)
types = [e["type"] for e in events]
assert "start" in types
assert "done" in types
assert "complete" in types
# ---- /api/finetune/status tests ----
def test_finetune_status_returns_empty_when_no_models_dir(client):
"""GET /api/finetune/status must return [] if models/ does not exist."""
r = client.get("/api/finetune/status")
assert r.status_code == 200
assert r.json() == []
def test_finetune_status_returns_training_info(client, tmp_path):
"""GET /api/finetune/status must return one entry per training_info.json found."""
import json as _json
from app import api as api_module
models_dir = tmp_path / "models" / "avocet-deberta-small"
models_dir.mkdir(parents=True)
info = {
"name": "avocet-deberta-small",
"base_model_id": "cross-encoder/nli-deberta-v3-small",
"val_macro_f1": 0.712,
"timestamp": "2026-03-15T12:00:00Z",
"sample_count": 401,
}
(models_dir / "training_info.json").write_text(_json.dumps(info))
api_module.set_models_dir(tmp_path / "models")
try:
r = client.get("/api/finetune/status")
assert r.status_code == 200
data = r.json()
assert any(d["name"] == "avocet-deberta-small" for d in data)
finally:
api_module.set_models_dir(api_module._ROOT / "models")
def test_finetune_run_streams_sse_events(client):
"""GET /api/finetune/run must return text/event-stream content type."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert r.status_code == 200
assert "text/event-stream" in r.headers.get("content-type", "")
def test_finetune_run_emits_complete_on_success(client):
"""GET /api/finetune/run must emit a complete event on clean exit."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter(["progress line\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '{"type": "complete"}' in r.text
def test_finetune_run_emits_error_on_nonzero_exit(client):
"""GET /api/finetune/run must emit an error event on non-zero exit."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 1
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '"type": "error"' in r.text
def test_finetune_run_passes_score_files_to_subprocess(client):
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
from unittest.mock import patch, MagicMock
captured_cmd = []
def mock_popen(cmd, **kwargs):
captured_cmd.extend(cmd)
m = MagicMock()
m.stdout = iter([])
m.returncode = 0
m.wait = MagicMock()
return m
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
assert "--score" in captured_cmd
assert captured_cmd.count("--score") == 2
# Paths are resolved to absolute — check filenames are present as substrings
assert any("run1.jsonl" in arg for arg in captured_cmd)
assert any("run2.jsonl" in arg for arg in captured_cmd)
# ---- Cancel endpoint tests ----
def test_benchmark_cancel_returns_404_when_not_running(client):
"""POST /api/benchmark/cancel must return 404 if no benchmark is running."""
from app import api as api_module
api_module._running_procs.pop("benchmark", None)
r = client.post("/api/benchmark/cancel")
assert r.status_code == 404
def test_finetune_cancel_returns_404_when_not_running(client):
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
from app import api as api_module
api_module._running_procs.pop("finetune", None)
r = client.post("/api/finetune/cancel")
assert r.status_code == 404
def test_benchmark_cancel_terminates_running_process(client):
"""POST /api/benchmark/cancel must call terminate() on the running process."""
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
api_module._running_procs["benchmark"] = mock_proc
try:
r = client.post("/api/benchmark/cancel")
assert r.status_code == 200
assert r.json()["status"] == "cancelled"
mock_proc.terminate.assert_called_once()
finally:
api_module._running_procs.pop("benchmark", None)
api_module._cancelled_jobs.discard("benchmark")
def test_finetune_cancel_terminates_running_process(client):
"""POST /api/finetune/cancel must call terminate() on the running process."""
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
api_module._running_procs["finetune"] = mock_proc
try:
r = client.post("/api/finetune/cancel")
assert r.status_code == 200
assert r.json()["status"] == "cancelled"
mock_proc.terminate.assert_called_once()
finally:
api_module._running_procs.pop("finetune", None)
api_module._cancelled_jobs.discard("finetune")
def test_benchmark_cancel_kills_process_on_timeout(client):
"""POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
import subprocess
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
api_module._running_procs["benchmark"] = mock_proc
try:
r = client.post("/api/benchmark/cancel")
assert r.status_code == 200
mock_proc.kill.assert_called_once()
finally:
api_module._running_procs.pop("benchmark", None)
api_module._cancelled_jobs.discard("benchmark")
def test_finetune_run_emits_cancelled_event(client):
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
from unittest.mock import patch, MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = -15 # SIGTERM
def mock_wait():
# Simulate cancel being called while the process is running (after discard clears stale flag)
api_module._cancelled_jobs.add("finetune")
mock_proc.wait = mock_wait
def mock_popen(cmd, **kwargs):
return mock_proc
try:
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '{"type": "cancelled"}' in r.text
assert '"type": "error"' not in r.text
finally:
api_module._cancelled_jobs.discard("finetune")
def test_benchmark_run_emits_cancelled_event(client):
"""GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
from unittest.mock import patch, MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = -15
def mock_wait():
# Simulate cancel being called while the process is running (after discard clears stale flag)
api_module._cancelled_jobs.add("benchmark")
mock_proc.wait = mock_wait
def mock_popen(cmd, **kwargs):
return mock_proc
try:
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
r = client.get("/api/benchmark/run")
assert '{"type": "cancelled"}' in r.text
assert '"type": "error"' not in r.text
finally:
api_module._cancelled_jobs.discard("benchmark")
assert "items" in r.json()
assert "total" in r.json()

View file

@ -2,11 +2,6 @@
import pytest
def test_registry_has_thirteen_models():
from scripts.benchmark_classifier import MODEL_REGISTRY
assert len(MODEL_REGISTRY) == 13
def test_registry_default_count():
from scripts.benchmark_classifier import MODEL_REGISTRY
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
@ -166,3 +161,95 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
assert "avocet-deberta-small" in models
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
# ---- build_exemplars_from_jsonl() tests ----
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
rows = [{"subject": f"S{i}", "body": f"B{i}", "label": "rejected"} for i in range(15)]
rows.append({"subject": "Hire", "body": "Welcome", "label": "hired"})
f = tmp_path / "score.jsonl"
f.write_text("\n".join(json.dumps(r) for r in rows))
result = build_exemplars_from_jsonl(str(f), k_per_label=10)
assert len(result["rejected"]) == 10
assert len(result["hired"]) == 1
assert result["rejected"][0].startswith("Subject: S")
def test_build_exemplars_formats_text_correctly(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
row = {"subject": "My Subject", "body": "My Body", "label": "neutral"}
f = tmp_path / "score.jsonl"
f.write_text(json.dumps(row))
result = build_exemplars_from_jsonl(str(f))
assert result["neutral"][0] == "Subject: My Subject\n\nMy Body"
def test_build_exemplars_skips_rows_missing_label(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
rows = [
{"subject": "A", "body": "B", "label": "neutral"},
{"subject": "No label here", "body": "Body"},
]
f = tmp_path / "score.jsonl"
f.write_text("\n".join(json.dumps(r) for r in rows))
result = build_exemplars_from_jsonl(str(f))
assert list(result.keys()) == ["neutral"]
def test_build_exemplars_truncates_body_at_600(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
row = {"subject": "S", "body": "x" * 800, "label": "neutral"}
f = tmp_path / "score.jsonl"
f.write_text(json.dumps(row))
result = build_exemplars_from_jsonl(str(f))
body_part = result["neutral"][0].split("\n\n", 1)[1]
assert len(body_part) == 600
def test_build_exemplars_skips_rows_with_no_content(tmp_path):
from scripts.benchmark_classifier import build_exemplars_from_jsonl
import json
rows = [
{"label": "neutral"}, # no subject, no body -> skip
{"subject": "S", "body": "B", "label": "neutral"}, # valid -> keep
{"label": "rejected", "subject": "", "body": ""}, # empty strings -> skip
]
f = tmp_path / "score.jsonl"
lines = [json.dumps(r) for r in rows]
f.write_text("\n".join(lines))
result = build_exemplars_from_jsonl(str(f))
assert list(result.keys()) == ["neutral"]
assert len(result["neutral"]) == 1
def test_registry_has_fourteen_models():
from scripts.benchmark_classifier import MODEL_REGISTRY
assert len(MODEL_REGISTRY) == 14
def test_embed_knn_nomic_registry_entry():
from scripts.benchmark_classifier import MODEL_REGISTRY
from scripts.classifier_adapters import EmbeddingKNNAdapter
entry = MODEL_REGISTRY["embed-knn-nomic"]
assert entry["adapter"] is EmbeddingKNNAdapter
assert entry["model_id"] == "nomic-embed-text"
assert entry["params"] == "local-embed"
assert entry["default"] is False
assert entry.get("kwargs", {}).get("k") == 3

View file

@ -14,7 +14,9 @@ from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def reset_cforch_globals(tmp_path):
"""Redirect _CONFIG_DIR to tmp_path and reset running-state globals."""
"""Redirect _CONFIG_DIR to tmp_path, reset running-state globals, and stub
list_installed to return [] so real disk model directories don't bleed into
tests that don't exercise the installed-model merge path."""
from app import cforch as cforch_module
prev_config_dir = cforch_module._CONFIG_DIR
@ -25,6 +27,7 @@ def reset_cforch_globals(tmp_path):
cforch_module._BENCH_RUNNING = False
cforch_module._bench_proc = None
with patch("app.models.list_installed", return_value=[]):
yield tmp_path
cforch_module.set_config_dir(prev_config_dir)
@ -141,12 +144,46 @@ def test_models_parses_bench_models_yaml(client, config_dir, tmp_path):
assert m["vram_estimate_mb"] == 6000
def test_models_merges_installed_generators(client, config_dir, tmp_path):
"""Installed cf-text/vllm generator models appear in the model list,
deduplicated against bench_models.yaml entries."""
models_file = tmp_path / "bench_models.yaml"
_write_models_yaml(models_file, [
{"name": "llama3", "id": "llama3:8b", "service": "ollama", "tags": [], "vram_estimate_mb": 6000},
{"name": "already-there", "id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "tags": [], "vram_estimate_mb": 8000},
])
_write_config(config_dir, {"bench_models": str(models_file)})
fake_installed = [
# should be included — cf-text generator not already in YAML
{"model_id": "meta-llama/Llama-3.1-8B", "service": "cf-text", "role": "generator", "vram_mb": 16000},
# should be deduped — repo_id matches a YAML entry
{"model_id": "ibm-granite/granite-4.1-8b", "service": "cf-text", "role": "generator", "vram_mb": 8000},
# should be excluded — classifier, not a generator
{"model_id": "cross-encoder/ms-marco-MiniLM-L6", "service": "avocet", "role": "reranker", "vram_mb": 500},
]
with patch("app.models.list_installed", return_value=fake_installed):
r = client.get("/api/cforch/models")
assert r.status_code == 200
ids = [m["id"] for m in r.json()["models"]]
assert "llama3:8b" in ids # from YAML
assert "ibm-granite/granite-4.1-8b" in ids # from YAML (not duplicated)
assert "meta-llama/Llama-3.1-8B" in ids # merged from installed
assert "cross-encoder/ms-marco-MiniLM-L6" not in ids # filtered out (reranker)
assert ids.count("ibm-granite/granite-4.1-8b") == 1 # no duplicate
# ── GET /run ───────────────────────────────────────────────────────────────────
def test_run_returns_409_when_already_running(client):
"""If _BENCH_RUNNING is True, GET /run returns 409."""
"""If a benchmark subprocess is actively running, GET /run returns 409."""
from unittest.mock import MagicMock
from app import cforch as cforch_module
mock_proc = MagicMock()
mock_proc.poll.return_value = None # process still alive
cforch_module._BENCH_RUNNING = True
cforch_module._bench_proc = mock_proc
r = client.get("/api/cforch/run")
assert r.status_code == 409
@ -180,16 +217,15 @@ def test_run_streams_progress_events(client, config_dir, tmp_path):
"python_bin": "/usr/bin/python3",
})
mock_stdout = MagicMock()
mock_stdout.readline.side_effect = ["Running task 1\n", "Running task 2\n", ""]
mock_proc = MagicMock()
mock_proc.stdout = iter(["Running task 1\n", "Running task 2\n"])
mock_proc.stdout = mock_stdout
mock_proc.returncode = 1 # non-zero so we don't need summary.json
mock_proc.wait = MagicMock()
def mock_wait():
pass
mock_proc.wait = mock_wait
with patch("app.cforch._subprocess.Popen", return_value=mock_proc):
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
r = client.get("/api/cforch/run")
assert r.status_code == 200
@ -222,12 +258,15 @@ def test_run_emits_result_on_success(client, config_dir, tmp_path):
"python_bin": "/usr/bin/python3",
})
mock_stdout = MagicMock()
mock_stdout.readline.side_effect = [""] # no output lines, immediate EOF
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.stdout = mock_stdout
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.cforch._subprocess.Popen", return_value=mock_proc):
with patch("app.cforch._subprocess.Popen", return_value=mock_proc), \
patch("app.cforch._select.select", return_value=([mock_stdout], [], [])):
r = client.get("/api/cforch/run")
assert r.status_code == 200
@ -367,3 +406,13 @@ def test_run_passes_license_key_env_to_subprocess(client, config_dir, tmp_path,
client.get("/api/cforch/run")
assert captured_env.get("CF_LICENSE_KEY") == "CFG-AVCT-ENV-ONLY-KEY"
def test_eval_cforch_router_includes_all_sub_routers():
"""eval/cforch.py router must include routes from all four sub-routers."""
from app.eval.cforch import router
paths = {r.path for r in router.routes}
assert any("/cforch/" in p for p in paths), f"no /cforch/ routes found in {paths}"
assert any("/style/" in p for p in paths), f"no /style/ routes found in {paths}"
assert any("/voice/" in p for p in paths), f"no /voice/ routes found in {paths}"
assert any("/plans-bench/" in p for p in paths), f"no /plans-bench/ routes found in {paths}"

View file

@ -268,3 +268,373 @@ def test_finetuned_adapter_unload_clears_pipeline():
assert adapter._pipeline is not None
adapter.unload()
assert adapter._pipeline is None
# ---- _cosine() tests ----
def test_cosine_identical_unit_vectors():
import math
from scripts.classifier_adapters import _cosine
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
def test_cosine_orthogonal_vectors():
from scripts.classifier_adapters import _cosine
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
def test_cosine_known_value():
import math
from scripts.classifier_adapters import _cosine
# [1,0] vs [1/sqrt(2), 1/sqrt(2)] → dot = 1/sqrt(2), both norms = 1 → 1/sqrt(2)
v = [1.0 / math.sqrt(2), 1.0 / math.sqrt(2)]
assert _cosine([1.0, 0.0], v) == pytest.approx(1.0 / math.sqrt(2))
def test_cosine_zero_vector_returns_zero():
from scripts.classifier_adapters import _cosine
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
# ---- DEFAULT_EXEMPLARS tests ----
def test_default_exemplars_covers_all_labels():
from scripts.classifier_adapters import DEFAULT_EXEMPLARS, LABELS
for label in LABELS:
assert label in DEFAULT_EXEMPLARS, f"DEFAULT_EXEMPLARS missing label: {label}"
assert len(DEFAULT_EXEMPLARS[label]) >= 4, f"{label} needs >= 4 exemplars for k=3 voting"
def test_default_exemplars_sparse_labels_have_at_least_four():
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
# These labels have very few real examples; need >= 4 so k=3 vote is meaningful
for label in ("hired", "survey_received", "event_rescheduled"):
assert len(DEFAULT_EXEMPLARS[label]) >= 4, (
f"{label} needs >= 4 exemplars for k=3 voting to work reliably"
)
def test_default_exemplars_strings_are_formatted_correctly():
from scripts.classifier_adapters import DEFAULT_EXEMPLARS
for label, texts in DEFAULT_EXEMPLARS.items():
for text in texts:
assert text.startswith("Subject: "), (
f"{label!r} exemplar missing 'Subject: ' prefix: {text[:50]!r}"
)
assert "\n\n" in text, (
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
)
# ---- EmbeddingKNNAdapter constructor tests ----
def test_embedding_knn_is_classifier_adapter():
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
adapter = EmbeddingKNNAdapter(
"test-knn", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
assert isinstance(adapter, ClassifierAdapter)
def test_embedding_knn_name_and_model_id():
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"embed-knn-nomic", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
assert adapter.name == "embed-knn-nomic"
assert adapter.model_id == "nomic-embed-text"
def test_embedding_knn_uses_default_exemplars_when_none_given():
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
def test_embedding_knn_accepts_custom_exemplars():
from scripts.classifier_adapters import EmbeddingKNNAdapter
custom = {"rejected": ["Sorry, we went with others."]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text",
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=custom,
)
assert adapter._exemplar_texts is custom
# ---- EmbeddingKNNAdapter.load() tests ----
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
"""Return a side_effect function for patching httpx.post.
Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3]
embedding per input text.
"""
def _side_effect(url, *, json=None, timeout=None, **kwargs):
from unittest.mock import MagicMock
resp = MagicMock()
resp.raise_for_status.return_value = None
if "/allocate" in url:
resp.status_code = 200
resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
else:
n = len((json or {}).get("input", []))
resp.status_code = 200
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n}
return resp
return _side_effect
def test_load_calls_allocate_then_embeds_each_label():
from unittest.mock import patch
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {
"rejected": ["We went with others"],
"hired": ["Welcome aboard!", "First day info"],
}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
post_urls = []
def capturing_mock(url, *, json=None, timeout=None, **kwargs):
post_urls.append(url)
return _make_post_mock()(url, json=json, timeout=timeout)
with patch("httpx.post", side_effect=capturing_mock):
adapter.load()
assert any("/allocate" in u for u in post_urls), "expected allocate call"
assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
assert adapter._allocation_id == "alloc-abc"
assert adapter._node_url == "http://navi:11434"
assert adapter._orch_url_used == "http://orch:7700"
assert "rejected" in adapter._exemplar_embeddings
assert "hired" in adapter._exemplar_embeddings
assert len(adapter._exemplar_embeddings["rejected"]) == 1
assert len(adapter._exemplar_embeddings["hired"]) == 2
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
def test_load_falls_back_to_ollama_when_allocate_fails():
from unittest.mock import patch, MagicMock
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
resp = MagicMock()
if "/allocate" in url:
resp.status_code = 503
resp.json.return_value = {}
else:
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
return resp
with patch("httpx.post", side_effect=failing_allocate_mock):
adapter.load()
assert adapter._allocation_id == ""
assert adapter._orch_url_used == ""
assert adapter._node_url == "http://ollama:11434"
assert "rejected" in adapter._exemplar_embeddings
def test_load_falls_back_to_ollama_when_allocate_raises():
from unittest.mock import patch, MagicMock
import httpx as _httpx
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
def raising_mock(url, *, json=None, timeout=None, **kwargs):
if "/allocate" in url:
raise _httpx.ConnectError("connection refused")
resp = MagicMock()
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
return resp
with patch("httpx.post", side_effect=raising_mock):
adapter.load()
assert adapter._allocation_id == ""
assert adapter._orch_url_used == ""
assert adapter._node_url == "http://ollama:11434"
assert "rejected" in adapter._exemplar_embeddings
# ---- EmbeddingKNNAdapter.unload() tests ----
def test_unload_releases_orch_allocation_and_clears_state():
from unittest.mock import patch, MagicMock
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
adapter._node_url = "http://navi:11434"
adapter._allocation_id = "alloc-abc"
adapter._orch_url_used = "http://orch:7700"
delete_calls = []
def mock_request(method, url, **kwargs):
delete_calls.append((method, url))
resp = MagicMock()
resp.status_code = 200
return resp
with patch("httpx.request", side_effect=mock_request):
adapter.unload()
assert len(delete_calls) == 1
method, url = delete_calls[0]
assert method == "DELETE"
assert "alloc-abc" in url
assert adapter._exemplar_embeddings == {}
assert adapter._allocation_id == ""
assert adapter._node_url == ""
assert adapter._orch_url_used == ""
def test_unload_skips_delete_on_ollama_fallback_path():
from unittest.mock import patch
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
adapter._node_url = "http://ollama:11434"
adapter._allocation_id = "" # fallback path: no allocation was made
adapter._orch_url_used = ""
delete_calls = []
with patch("httpx.request", side_effect=lambda *a, **k: delete_calls.append(a)):
adapter.unload()
assert len(delete_calls) == 0
assert adapter._exemplar_embeddings == {}
assert adapter._node_url == ""
# ---- EmbeddingKNNAdapter.classify() tests ----
def _adapter_with_embeddings(exemplar_embeddings, k=3):
"""Return a pre-loaded EmbeddingKNNAdapter (bypass load()) with given per-label vectors."""
from scripts.classifier_adapters import EmbeddingKNNAdapter
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=k,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
)
adapter._exemplar_embeddings = exemplar_embeddings
adapter._node_url = "http://navi:11434"
return adapter
def _embed_resp(vec):
"""Return a mock httpx response for /v1/embeddings returning a single vector."""
from unittest.mock import MagicMock
resp = MagicMock()
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": vec}]}
return resp
def test_classify_returns_majority_vote_label():
from unittest.mock import patch
adapter = _adapter_with_embeddings({
"rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
"neutral": [[0.0, 1.0, 0.0]],
}, k=3)
# Query [1,0,0] is closest to all three "rejected" exemplars
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
result = adapter.classify("We went with others", "Thank you for applying.")
assert result == "rejected"
def test_classify_tiebreak_by_mean_score():
from unittest.mock import patch
# k=2: each label gets exactly 1 vote → tie-break by mean similarity
# [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] ≈ 0.6 ("neutral")
adapter = _adapter_with_embeddings({
"rejected": [[1.0, 0.0]],
"neutral": [[0.6, 0.8]],
}, k=2)
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
result = adapter.classify("Rejection", "Sorry")
assert result == "rejected"
def test_classify_sparse_label_can_win():
from unittest.mock import patch
# "hired" has only 1 exemplar; with k=1, the single closest match wins
adapter = _adapter_with_embeddings({
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
"hired": [[1.0, 0.0, 0.0]],
}, k=1)
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
result = adapter.classify("Welcome aboard", "Your first day details")
assert result == "hired"
def test_classify_lazy_loads_when_not_loaded():
from unittest.mock import patch
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=1,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
assert adapter._exemplar_embeddings == {}
post_urls = []
def mock_post(url, *, json=None, timeout=None, **kwargs):
post_urls.append(url)
from unittest.mock import MagicMock
resp = MagicMock()
resp.raise_for_status.return_value = None
if "/allocate" in url:
resp.status_code = 200
resp.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
else:
n = len((json or {}).get("input", []))
resp.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * n}
return resp
with patch("httpx.post", side_effect=mock_post):
result = adapter.classify("Rejection", "Sorry")
assert result == "rejected"
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
assert adapter._exemplar_embeddings != {}
assert adapter._node_url == "http://navi:11434"

122
tests/test_dashboard.py Normal file
View file

@ -0,0 +1,122 @@
"""Tests for app/dashboard.py -- GET /api/dashboard."""
import json
import pytest
import yaml
from fastapi.testclient import TestClient
from pathlib import Path
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app import dashboard as dash_module
dash_module.set_data_dir(tmp_path)
dash_module.set_config_dir(tmp_path)
yield
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _write_score(tmp_path: Path, records: list[dict]) -> None:
(tmp_path / "email_score.jsonl").write_text(
"\n".join(json.dumps(r) for r in records) + "\n"
)
def _write_summary(tmp_path: Path, run_id: str, ts: str, score: float) -> None:
run_dir = tmp_path / "bench_results" / run_id
run_dir.mkdir(parents=True)
(run_dir / "summary.json").write_text(
json.dumps({"timestamp": ts, "best_macro_f1": score})
)
def test_dashboard_returns_expected_keys(client):
r = client.get("/api/dashboard")
assert r.status_code == 200
data = r.json()
for key in ("labeled_since_last_eval", "last_eval_timestamp", "last_eval_best_score",
"active_jobs", "corrections_pending", "corrections_export_ready", "signals"):
assert key in data, f"missing key: {key}"
for sig in ("data_to_eval", "eval_to_train", "train_to_fleet"):
assert sig in data["signals"], f"missing signal: {sig}"
def test_dashboard_empty_state(client):
r = client.get("/api/dashboard")
assert r.status_code == 200
data = r.json()
assert data["labeled_since_last_eval"] == 0
assert data["last_eval_timestamp"] is None
assert data["last_eval_best_score"] is None
assert data["active_jobs"] == []
assert data["corrections_pending"] == 0
assert data["corrections_export_ready"] == 0
def test_labeled_since_counts_all_when_no_eval(client, tmp_path):
_write_score(tmp_path, [
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T10:00:00+00:00"},
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
])
r = client.get("/api/dashboard")
assert r.json()["labeled_since_last_eval"] == 2
def test_labeled_since_filters_by_eval_timestamp(client, tmp_path):
_write_summary(tmp_path, "2026-05-01-100000", "2026-05-01T10:00:00+00:00", 0.80)
_write_score(tmp_path, [
{"id": "a", "label": "neutral", "labeled_at": "2026-05-01T09:00:00+00:00"},
{"id": "b", "label": "neutral", "labeled_at": "2026-05-01T11:00:00+00:00"},
])
(tmp_path / "label_tool.yaml").write_text(
yaml.dump({"cforch": {"results_dir": str(tmp_path / "bench_results")}})
)
r = client.get("/api/dashboard")
data = r.json()
assert data["labeled_since_last_eval"] == 1
assert abs(data["last_eval_best_score"] - 0.80) < 0.001
def test_data_to_eval_false_below_threshold(client, tmp_path):
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(10)])
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
r = client.get("/api/dashboard")
assert r.json()["signals"]["data_to_eval"] is False
def test_data_to_eval_true_at_threshold(client, tmp_path):
_write_score(tmp_path, [{"id": str(i), "label": "neutral",
"labeled_at": "2026-05-01T10:00:00+00:00"} for i in range(50)])
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"pipeline": {"data_eval_threshold": 50}}))
r = client.get("/api/dashboard")
assert r.json()["signals"]["data_to_eval"] is True
def test_corrections_pending_count(client, tmp_path):
candidates = [
{"id": "c1", "status": "needs_review"},
{"id": "c2", "status": "needs_review"},
{"id": "c3", "status": "discarded"},
]
(tmp_path / "sft_candidates.jsonl").write_text(
"\n".join(json.dumps(c) for c in candidates) + "\n"
)
r = client.get("/api/dashboard")
assert r.json()["corrections_pending"] == 2
def test_corrections_export_ready_count(client, tmp_path):
approved = [
{"id": "a1", "status": "approved", "corrected_response": "Good answer"},
{"id": "a2", "status": "approved", "corrected_response": ""},
{"id": "a3", "status": "approved", "corrected_response": "Another answer"},
]
(tmp_path / "sft_approved.jsonl").write_text(
"\n".join(json.dumps(a) for a in approved) + "\n"
)
r = client.get("/api/dashboard")
assert r.json()["corrections_export_ready"] == 2

View file

@ -0,0 +1,102 @@
"""Tests for app/data/corrections.py -- POST /api/sft/ingest.
The corrections router is mounted at prefix="/api/sft" via the app/sft.py
backward-compat shim, so ingest lives at /api/sft/ingest.
"""
import json
import pytest
from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.data import corrections as corr_module
corr_module.set_data_dir(tmp_path)
corr_module.set_config_dir(tmp_path)
yield
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
_VALID_PAYLOAD = {
"source": "peregrine",
"task_type": "email_classification",
"prompt": "Classify this email: ...",
"response": "skip",
"correction": "action_required",
"label": "action_required",
}
_SECRET = "test-secret-abc123"
def test_ingest_503_when_secret_not_configured(client, monkeypatch):
monkeypatch.delenv("AVOCET_INGESTION_SECRET", raising=False)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 503
def test_ingest_401_when_no_auth_header(client, monkeypatch):
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD)
assert r.status_code == 401
def test_ingest_401_when_malformed_header(client, monkeypatch):
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": "Token bad-format"})
assert r.status_code == 401
def test_ingest_403_when_wrong_secret(client, monkeypatch):
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": "Bearer wrong-secret"})
assert r.status_code == 403
def test_ingest_creates_approved_record(client, monkeypatch, tmp_path):
from app.data import corrections as corr_module
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
corr_module.set_data_dir(tmp_path)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert "id" in data
candidates = corr_module.read_jsonl(corr_module._candidates_file())
assert len(candidates) == 1
rec = candidates[0]
assert rec["status"] == "approved"
assert rec["source"] == "peregrine"
assert rec["corrected_response"] == "action_required"
assert rec["id"] == data["id"]
def test_ingest_also_writes_to_approved_file(client, monkeypatch, tmp_path):
from app.data import corrections as corr_module
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
corr_module.set_data_dir(tmp_path)
r = client.post("/api/sft/ingest", json=_VALID_PAYLOAD,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 200
approved = corr_module.read_jsonl(corr_module._approved_file())
assert len(approved) == 1
assert approved[0]["id"] == r.json()["id"]
def test_ingest_without_label_is_accepted(client, monkeypatch, tmp_path):
from app.data import corrections as corr_module
monkeypatch.setenv("AVOCET_INGESTION_SECRET", _SECRET)
corr_module.set_data_dir(tmp_path)
payload = {**_VALID_PAYLOAD, "label": None}
r = client.post("/api/sft/ingest", json=payload,
headers={"Authorization": f"Bearer {_SECRET}"})
assert r.status_code == 200

95
tests/test_data_fetch.py Normal file
View file

@ -0,0 +1,95 @@
"""Tests for app/data/fetch.py"""
import json
import yaml
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.data import fetch as fetch_module
fetch_module.set_data_dir(tmp_path)
fetch_module.set_config_dir(tmp_path)
yield
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _parse_sse(content: bytes) -> list[dict]:
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_account_test_missing_fields(client):
r = client.post("/api/accounts/test",
json={"account": {"host": "", "username": "", "password": ""}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is False
assert "required" in data["message"].lower()
def test_account_test_success(client):
mock_conn = MagicMock()
mock_conn.select.return_value = ("OK", [b"99"])
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.post("/api/accounts/test", json={"account": {
"host": "imap.example.com", "port": 993, "use_ssl": True,
"username": "u@example.com", "password": "pw", "folder": "INBOX",
}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["count"] == 99
def test_fetch_stream_no_accounts_configured(client, tmp_path):
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
assert r.status_code == 200
events = _parse_sse(r.content)
complete = next((e for e in events if e["type"] == "complete"), None)
assert complete is not None
assert complete["total_added"] == 0
def test_fetch_stream_with_mock_imap(client, tmp_path):
from app.data import fetch as fetch_module
fetch_module.set_config_dir(tmp_path)
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 30}], "max_per_account": 50}
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
mock_conn = MagicMock()
mock_conn.search.return_value = ("OK", [b"1"])
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
assert r.status_code == 200
events = _parse_sse(r.content)
types = [e["type"] for e in events]
assert "start" in types
assert "done" in types
assert "complete" in types
def test_entry_key_deterministic():
from app.data.fetch import entry_key
e = {"subject": "Test", "body": "Hello world"}
assert entry_key(e) == entry_key(e)
def test_entry_key_differs_by_subject():
from app.data.fetch import entry_key
a = {"subject": "A", "body": "same body"}
b = {"subject": "B", "body": "same body"}
assert entry_key(a) != entry_key(b)

219
tests/test_data_label.py Normal file
View file

@ -0,0 +1,219 @@
"""Tests for app/data/label.py"""
import json
import pytest
import yaml
from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.data import label as label_module
label_module.set_data_dir(tmp_path)
label_module.set_config_dir(tmp_path)
label_module.reset_last_action()
yield
label_module.reset_last_action()
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
@pytest.fixture
def queue_with_items(tmp_path):
from app.data import label as label_module
items = [
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
for i in range(3)
]
(label_module._DATA_DIR / "email_label_queue.jsonl").write_text(
"\n".join(json.dumps(x) for x in items) + "\n")
return items
def test_queue_returns_items(client, queue_with_items):
r = client.get("/api/queue?limit=2")
assert r.status_code == 200
data = r.json()
assert len(data["items"]) == 2
assert data["total"] == 3
def test_queue_empty_when_no_file(client):
r = client.get("/api/queue")
assert r.status_code == 200
assert r.json() == {"items": [], "total": 0}
def test_label_appends_to_score(client, queue_with_items):
from app.data import label as label_module
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
assert r.status_code == 200
records = label_module.read_jsonl(label_module._score_file())
assert len(records) == 1
assert records[0]["id"] == "id0"
assert records[0]["label"] == "interview_scheduled"
assert "labeled_at" in records[0]
def test_label_removes_from_queue(client, queue_with_items):
from app.data import label as label_module
client.post("/api/label", json={"id": "id0", "label": "rejected"})
queue = label_module.read_jsonl(label_module._queue_file())
assert not any(x["id"] == "id0" for x in queue)
def test_label_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
assert r.status_code == 404
def test_skip_moves_to_back(client, queue_with_items):
from app.data import label as label_module
r = client.post("/api/skip", json={"id": "id0"})
assert r.status_code == 200
queue = label_module.read_jsonl(label_module._queue_file())
assert queue[-1]["id"] == "id0"
assert queue[0]["id"] == "id1"
def test_skip_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/skip", json={"id": "nope"})
assert r.status_code == 404
def test_discard_writes_to_discarded_file(client, queue_with_items):
from app.data import label as label_module
r = client.post("/api/discard", json={"id": "id1"})
assert r.status_code == 200
discarded = label_module.read_jsonl(label_module._discarded_file())
assert len(discarded) == 1
assert discarded[0]["id"] == "id1"
assert discarded[0]["label"] == "__discarded__"
def test_discard_removes_from_queue(client, queue_with_items):
from app.data import label as label_module
client.post("/api/discard", json={"id": "id1"})
queue = label_module.read_jsonl(label_module._queue_file())
assert not any(x["id"] == "id1" for x in queue)
def test_undo_label_removes_from_score(client, queue_with_items):
from app.data import label as label_module
client.post("/api/label", json={"id": "id0", "label": "neutral"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
assert r.json()["undone"]["type"] == "label"
assert label_module.read_jsonl(label_module._score_file()) == []
queue = label_module.read_jsonl(label_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_discard_removes_from_discarded(client, queue_with_items):
from app.data import label as label_module
client.post("/api/discard", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
assert label_module.read_jsonl(label_module._discarded_file()) == []
def test_undo_skip_restores_to_front(client, queue_with_items):
from app.data import label as label_module
client.post("/api/skip", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
queue = label_module.read_jsonl(label_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_with_no_action_returns_404(client):
r = client.delete("/api/label/undo")
assert r.status_code == 404
def test_config_labels_returns_10_labels(client):
r = client.get("/api/config/labels")
assert r.status_code == 200
labels = r.json()
assert len(labels) == 10
assert labels[0]["key"] == "1"
for lbl in labels:
assert "emoji" in lbl and "color" in lbl and "name" in lbl
def test_get_config_returns_empty_when_no_file(client):
r = client.get("/api/config")
assert r.status_code == 200
data = r.json()
assert data["accounts"] == []
assert data["max_per_account"] == 500
def test_post_config_writes_yaml(client, tmp_path):
from app.data import label as label_module
label_module.set_config_dir(tmp_path)
payload = {"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
"use_ssl": True, "username": "u@t.com", "password": "pw",
"folder": "INBOX", "days_back": 30}], "max_per_account": 200}
r = client.post("/api/config", json=payload)
assert r.status_code == 200
assert r.json()["ok"] is True
saved = yaml.safe_load((tmp_path / "label_tool.yaml").read_text())
assert saved["max_per_account"] == 200
assert saved["accounts"][0]["name"] == "Test"
def test_get_config_round_trips(client, tmp_path):
from app.data import label as label_module
label_module.set_config_dir(tmp_path)
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 90}], "max_per_account": 300}
client.post("/api/config", json=payload)
r = client.get("/api/config")
data = r.json()
assert data["max_per_account"] == 300
assert data["accounts"][0]["name"] == "R"
def test_stats_returns_counts(client, tmp_path):
from app.data import label as label_module
label_module.set_data_dir(tmp_path)
score_path = tmp_path / "email_score.jsonl"
records = [{"id": "a", "label": "interview_scheduled"},
{"id": "b", "label": "interview_scheduled"},
{"id": "c", "label": "rejected"}]
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 3
assert data["counts"]["interview_scheduled"] == 2
assert data["counts"]["rejected"] == 1
def test_stats_empty_when_no_file(client):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 0
assert data["counts"] == {}
assert data["score_file_bytes"] == 0
def test_stats_download_returns_file(client, tmp_path):
from app.data import label as label_module
label_module.set_data_dir(tmp_path)
(tmp_path / "email_score.jsonl").write_text(json.dumps({"id": "a", "label": "neutral"}) + "\n")
r = client.get("/api/stats/download")
assert r.status_code == 200
assert "jsonlines" in r.headers.get("content-type", "")
def test_stats_download_404_when_no_file(client):
r = client.get("/api/stats/download")
assert r.status_code == 404

234
tests/test_embed_bench.py Normal file
View file

@ -0,0 +1,234 @@
"""Tests for app/eval/embed_bench.py."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def reset_embed_bench_globals(tmp_path):
"""Redirect config dir to tmp_path and reset running flag."""
from app.eval import embed_bench as mod
prev_config_dir = mod._CONFIG_DIR
prev_running = mod._RUN_ACTIVE
mod.set_config_dir(tmp_path)
mod._RUN_ACTIVE = False
yield tmp_path
mod.set_config_dir(prev_config_dir)
mod._RUN_ACTIVE = prev_running
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
# ── cosine helper ──────────────────────────────────────────────────────────────
def test_cosine_identical():
from app.eval.embed_bench import _cosine
assert _cosine([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
def test_cosine_orthogonal():
from app.eval.embed_bench import _cosine
assert _cosine([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
def test_cosine_opposite():
from app.eval.embed_bench import _cosine
assert _cosine([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)
def test_cosine_zero_vector_returns_zero():
from app.eval.embed_bench import _cosine
assert _cosine([0.0, 0.0], [1.0, 0.0]) == pytest.approx(0.0)
# ── models endpoint ────────────────────────────────────────────────────────────
def test_models_returns_list_with_mock(client, tmp_path):
"""GET /api/embed-bench/models returns list from Ollama tags endpoint."""
import yaml
cfg = {"cforch": {"ollama_url": "http://localhost:11434"}}
(tmp_path / "label_tool.yaml").write_text(yaml.dump(cfg))
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"models": [
{"name": "nomic-embed-text", "size": 274302480},
{"name": "mxbai-embed-large", "size": 669000000},
]
}
mock_resp.raise_for_status = MagicMock()
with patch("app.eval.embed_bench.httpx.get", return_value=mock_resp):
r = client.get("/api/embed-bench/models")
assert r.status_code == 200
data = r.json()
assert isinstance(data["models"], list)
assert any(m["name"] == "nomic-embed-text" for m in data["models"])
def test_models_returns_empty_on_ollama_error(client, tmp_path):
"""GET /api/embed-bench/models returns empty list if Ollama unreachable."""
import httpx
with patch("app.eval.embed_bench.httpx.get", side_effect=httpx.ConnectError("refused")):
r = client.get("/api/embed-bench/models")
assert r.status_code == 200
assert r.json()["models"] == []
# ── run endpoint ───────────────────────────────────────────────────────────────
def test_run_empty_corpus_returns_422(client):
r = client.post("/api/embed-bench/run", json={
"corpus": [], "queries": ["test"], "models": ["nomic-embed-text"], "top_k": 3
})
assert r.status_code == 422
def test_run_empty_queries_returns_422(client):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk 1"], "queries": [], "models": ["nomic-embed-text"], "top_k": 3
})
assert r.status_code == 422
def test_run_empty_models_returns_422(client):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk 1"], "queries": ["test"], "models": [], "top_k": 3
})
assert r.status_code == 422
def _fake_embed_response(texts: list[str]) -> MagicMock:
"""Build a mock httpx.post response returning unit vectors for each text."""
resp = MagicMock()
resp.raise_for_status = MagicMock()
resp.json.return_value = {
"data": [{"embedding": [1.0, 0.0, 0.0] if i % 2 == 0 else [0.0, 1.0, 0.0]}
for i, _ in enumerate(texts)]
}
return resp
def _collect_sse(raw: bytes) -> list[dict]:
"""Parse SSE stream bytes into a list of decoded event dicts."""
events = []
for line in raw.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_run_single_model_returns_result_and_done(client, tmp_path):
import yaml
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk 1", "chunk 2"])):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk 1", "chunk 2"],
"queries": ["what is chunk one?"],
"models": ["nomic-embed-text"],
"top_k": 2,
})
assert r.status_code == 200
events = _collect_sse(r.content)
types = [e["type"] for e in events]
assert "result" in types
assert types[-1] == "done"
result_events = [e for e in events if e["type"] == "result"]
assert result_events[0]["model"] == "nomic-embed-text"
assert result_events[0]["query_idx"] == 0
assert len(result_events[0]["hits"]) <= 2
def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
import yaml
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk A", "chunk B"])):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk A", "chunk B"],
"queries": ["find it"],
"models": ["nomic-embed-text", "mxbai-embed-large"],
"top_k": 2,
})
events = _collect_sse(r.content)
result_events = [e for e in events if e["type"] == "result"]
models_seen = {e["model"] for e in result_events}
assert "nomic-embed-text" in models_seen
assert "mxbai-embed-large" in models_seen
# ── rate + export ──────────────────────────────────────────────────────────────
def test_rate_appends_jsonl_line(client, tmp_path):
r = client.post("/api/embed-bench/rate", json={
"query": "test query",
"model": "nomic-embed-text",
"chunk_text": "some text",
"chunk_idx": 2,
"rating": "relevant",
})
assert r.status_code == 200
assert r.json() == {"ok": True}
ratings_file = tmp_path / "embed_bench_ratings.jsonl"
assert ratings_file.exists()
line = json.loads(ratings_file.read_text().strip())
assert line["query"] == "test query"
assert line["rating"] == "relevant"
assert line["chunk_idx"] == 2
assert "timestamp" in line
def test_export_csv_two_rows(client, tmp_path):
for i in range(2):
client.post("/api/embed-bench/rate", json={
"query": f"q{i}", "model": "nomic-embed-text",
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "relevant",
})
r = client.get("/api/embed-bench/export?format=csv")
assert r.status_code == 200
assert "text/csv" in r.headers["content-type"]
lines = r.text.strip().splitlines()
assert len(lines) == 3 # header + 2 rows
assert "query" in lines[0]
def test_export_json_two_entries(client, tmp_path):
for i in range(2):
client.post("/api/embed-bench/rate", json={
"query": f"q{i}", "model": "nomic-embed-text",
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "not_relevant",
})
r = client.get("/api/embed-bench/export?format=json")
assert r.status_code == 200
data = r.json()
assert isinstance(data, list)
assert len(data) == 2
assert data[0]["rating"] == "not_relevant"
def test_export_empty_returns_csv_header_only(client):
r = client.get("/api/embed-bench/export?format=csv")
assert r.status_code == 200
lines = r.text.strip().splitlines()
assert len(lines) == 1 # header only
assert "query" in lines[0]

View file

@ -321,6 +321,7 @@ def test_load_and_prepare_data_single_path_still_works(tmp_path):
# ---- Integration test ----
@pytest.mark.gpu
def test_integration_finetune_on_example_data(tmp_path):
"""Fine-tune deberta-small on example data for 1 epoch.

View file

@ -1,4 +1,4 @@
"""Tests for app/imitate.py product registry, sample extraction, corrections push."""
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
from __future__ import annotations
import json
@ -9,10 +9,10 @@ import pytest
from fastapi.testclient import TestClient
from app.api import app
from app import imitate as _imitate_module
from app.data import imitate as _imitate_module
# ── Fixtures ───────────────────────────────────────────────────────────────────
# -- Fixtures ------------------------------------------------------------------
@pytest.fixture(autouse=True)
def reset_module_globals(tmp_path):
@ -70,7 +70,7 @@ def client() -> TestClient:
return TestClient(app, raise_server_exceptions=True)
# ── GET /products ──────────────────────────────────────────────────────────────
# -- GET /products -------------------------------------------------------------
def test_products_empty_when_no_config(config_dir, client):
"""Returns empty list when label_tool.yaml has no imitate section."""
@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client):
assert all(not p["online"] for p in resp.json()["products"])
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
# -- GET /products/{id}/sample -------------------------------------------------
def test_sample_unknown_product(cfg_with_products, client):
"""Returns 404 for a product id not in config."""
@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client):
assert resp.status_code == 404
# ── POST /push-corrections ─────────────────────────────────────────────────────
# -- POST /push-corrections ----------------------------------------------------
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
"""Successful push writes records to sft_candidates.jsonl."""
@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
assert resp.status_code == 422
# ── _extract_sample helper ─────────────────────────────────────────────────────
# -- _extract_sample helper ----------------------------------------------------
def test_extract_sample_list():
result = _imitate_module._extract_sample(

454
tests/test_log_corpus.py Normal file
View file

@ -0,0 +1,454 @@
"""Tests for app/data/log_corpus.py — corpus receiver and labeling endpoints."""
from __future__ import annotations
import json
import uuid
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from app.data import log_corpus as lc
VALID_TOKEN = str(uuid.uuid4())
VALID_HOST = "testnode.local"
@pytest.fixture(autouse=True)
def isolated_db(tmp_path, monkeypatch):
"""Each test gets its own fresh corpus DB and config dir."""
monkeypatch.setattr(lc, "_DATA_DIR", tmp_path)
monkeypatch.setattr(lc, "_DB_PATH", tmp_path / "corpus.db")
# Config dir pointing to a temp yaml with one test source
config_dir = tmp_path / "config"
config_dir.mkdir()
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n sources:\n"
f" - token: \"{VALID_TOKEN}\"\n"
f" source_host: \"{VALID_HOST}\"\n"
f" owner: TestOwner\n"
f" consent_date: \"2026-05-11\"\n"
f" consent_method: signal_chat\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
lc._init_db()
@pytest.fixture()
def client():
from fastapi import FastAPI
app = FastAPI()
app.include_router(lc.router, prefix="/api/corpus")
return TestClient(app)
def _batch(batch_type="raw_entries", entries=None, source_host=VALID_HOST):
return {
"batch_version": 1,
"batch_id": str(uuid.uuid4()),
"pushed_at": "2026-05-11T10:00:00Z",
"source_host": source_host,
"batch_type": batch_type,
"watermark_from": 0,
"watermark_to": 5,
"entries": entries or [
{
"entry_id": str(uuid.uuid4()),
"source_id": "sonarr",
"timestamp_iso": "2026-05-11T09:58:00Z",
"severity": "ERROR",
"text": "Connection refused to indexer",
"matched_patterns": [],
}
],
}
# ── Receive endpoint ───────────────────────────────────────────────────────────
def test_receive_missing_auth(client):
resp = client.post("/api/corpus/log-batch", json=_batch())
assert resp.status_code == 401
def test_receive_invalid_token(client):
resp = client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": "Bearer bad-token"},
)
assert resp.status_code == 403
def test_receive_valid_batch(client):
resp = client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["received"] is True
assert data["entries_stored"] == 1
def test_receive_stores_source_host_from_token_not_payload(client):
"""source_host is always taken from the DB lookup, not the payload."""
payload = _batch(source_host="attacker-injected-host")
resp = client.post(
"/api/corpus/log-batch",
json=payload,
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
assert resp.status_code == 200
entries_resp = client.get("/api/corpus/entries")
entry = entries_resp.json()["entries"][0]
assert entry["source_host"] == VALID_HOST
def test_receive_skips_empty_text_entries(client):
payload = _batch(entries=[
{"entry_id": "e1", "source_id": "svc", "severity": "ERROR", "text": ""},
{"entry_id": "e2", "source_id": "svc", "severity": "ERROR", "text": " "},
{"entry_id": "e3", "source_id": "svc", "severity": "ERROR", "text": "real error"},
])
resp = client.post(
"/api/corpus/log-batch",
json=payload,
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
assert resp.json()["entries_stored"] == 1
def test_receive_incident_bundle(client):
payload = _batch(batch_type="incident_bundles", entries=[
{"id": "inc-1", "label": "plex crash", "issue_type": "plex",
"started_at": "2026-05-11T09:00:00", "ended_at": "2026-05-11T09:30:00",
"notes": "audio dropped", "created_at": "2026-05-11T09:35:00",
"severity": "high", "text": "plex crash"},
])
resp = client.post(
"/api/corpus/log-batch",
json=payload,
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
assert resp.status_code == 200
assert resp.json()["entries_stored"] == 1
# ── Labeling endpoints ─────────────────────────────────────────────────────────
def test_label_entry(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={
"failure_type": "software",
"plain_explanation": "Sonarr lost connection to its indexer — restart the service.",
"known_pattern": "y",
})
assert resp.status_code == 200
assert resp.json()["labeled"] is True
entries = client.get("/api/corpus/entries", params={"state": "labeled"}).json()["entries"]
assert len(entries) == 1
assert entries[0]["failure_type"] == "software"
def test_label_entry_invalid_failure_type(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={"failure_type": "aliens"})
assert resp.status_code == 422
def test_label_entry_missing_failure_type(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
resp = client.post(f"/api/corpus/entries/{entry_id}/label", json={})
assert resp.status_code == 422
def test_label_entry_not_found(client):
resp = client.post("/api/corpus/entries/nonexistent/label", json={"failure_type": "software"})
assert resp.status_code == 404
def test_skip_entry(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
resp = client.post(f"/api/corpus/entries/{entry_id}/skip")
assert resp.status_code == 200
unlabeled = client.get("/api/corpus/entries").json()["entries"]
assert len(unlabeled) == 0
# ── Stats ──────────────────────────────────────────────────────────────────────
def test_stats_empty(client):
stats = client.get("/api/corpus/stats").json()
assert stats["total_entries"] == 0
assert stats["batch_count"] == 0
def test_stats_after_receive(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
stats = client.get("/api/corpus/stats").json()
assert stats["total_entries"] == 1
assert stats["batch_count"] == 1
assert stats["by_label_state"].get("unlabeled", 0) == 1
# ── Export ─────────────────────────────────────────────────────────────────────
def test_export_excludes_unlabeled(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
resp = client.get("/api/corpus/export")
assert resp.status_code == 200
assert resp.text.strip() == ""
def test_export_includes_labeled(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
client.post(f"/api/corpus/entries/{entry_id}/label", json={
"failure_type": "software",
"plain_explanation": "Sonarr lost connection to indexer.",
})
resp = client.get("/api/corpus/export")
assert resp.status_code == 200
lines = [l for l in resp.text.strip().splitlines() if l]
assert len(lines) == 1
record = json.loads(lines[0])
assert record["output"] == "Sonarr lost connection to indexer."
assert record["metadata"]["failure_type"] == "software"
def test_export_excludes_pii_flagged(client):
client.post(
"/api/corpus/log-batch",
json=_batch(),
headers={"Authorization": f"Bearer {VALID_TOKEN}"},
)
entry_id = client.get("/api/corpus/entries").json()["entries"][0]["id"]
client.post(f"/api/corpus/entries/{entry_id}/label", json={
"failure_type": "software",
"plain_explanation": "Contains username — should not export.",
"pii_flagged": True,
})
resp = client.get("/api/corpus/export")
assert resp.text.strip() == ""
# ── Pipeline ingest endpoint ───────────────────────────────────────────────────
def _make_pipeline_file(directory: Path, name: str, lines: list[dict]) -> Path:
"""Write a JSONL pipeline log file to directory."""
p = directory / name
p.write_text("\n".join(json.dumps(l) for l in lines), encoding="utf-8")
return p
_PIPELINE_LINE = {
"ts": "2026-05-17T10:00:00Z",
"level": "INFO",
"logger": "scripts.pipeline.purple_carrot_scraper",
"msg": "Fetched recipe page",
"extra": {"url": "https://example.com/recipe/1", "status": 200},
}
def test_pipeline_ingest_returns_404_when_dir_not_configured(client, tmp_path):
"""No pipeline_ingest_dir in config — endpoint returns 404."""
resp = client.post("/api/corpus/pipeline-ingest")
assert resp.status_code == 404
def test_pipeline_ingest_empty_dir(client, tmp_path, monkeypatch):
"""Configured dir exists but is empty — returns zeros, no error."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
resp = client.post("/api/corpus/pipeline-ingest")
assert resp.status_code == 200
data = resp.json()
assert data["ingested_files"] == 0
assert data["skipped_files"] == 0
assert data["entries_stored"] == 0
def test_pipeline_ingest_ingests_valid_file(client, tmp_path, monkeypatch):
"""Valid JSONL file is ingested; entries appear in corpus."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [
_PIPELINE_LINE,
{**_PIPELINE_LINE, "msg": "Saved 3 recipes", "level": "INFO"},
])
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
resp = client.post("/api/corpus/pipeline-ingest")
assert resp.status_code == 200
data = resp.json()
assert data["ingested_files"] == 1
assert data["entries_stored"] == 2
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
assert len(entries) == 2
assert all(e["source_host"] == "pipeline_scrape" for e in entries)
def test_pipeline_ingest_source_id_from_logger(client, tmp_path, monkeypatch):
"""source_id is populated from the 'logger' field of each log line."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
_make_pipeline_file(ingest_dir, "run_20260517.jsonl", [_PIPELINE_LINE])
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
client.post("/api/corpus/pipeline-ingest")
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
assert entries[0]["source_id"] == "scripts.pipeline.purple_carrot_scraper"
def test_pipeline_ingest_idempotent(client, tmp_path, monkeypatch):
"""Calling the endpoint twice does not re-ingest already-processed files."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
_make_pipeline_file(ingest_dir, "scraper_20260517.jsonl", [_PIPELINE_LINE])
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
client.post("/api/corpus/pipeline-ingest")
resp2 = client.post("/api/corpus/pipeline-ingest")
data = resp2.json()
assert data["ingested_files"] == 0
assert data["skipped_files"] == 1
assert data["entries_stored"] == 0
entries = client.get("/api/corpus/entries", params={"limit": 10}).json()["entries"]
assert len(entries) == 1 # still just the one from the first ingest
def test_pipeline_ingest_skips_non_jsonl(client, tmp_path, monkeypatch):
"""Non-.jsonl files in the dir are silently ignored."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
(ingest_dir / "notes.txt").write_text("this is not a log file")
(ingest_dir / "run.csv").write_text("a,b,c\n1,2,3")
_make_pipeline_file(ingest_dir, "valid_20260517.jsonl", [_PIPELINE_LINE])
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
resp = client.post("/api/corpus/pipeline-ingest")
assert resp.json()["ingested_files"] == 1
def test_pipeline_ingest_skips_malformed_lines(client, tmp_path, monkeypatch):
"""Lines that are not valid JSON are skipped; valid lines in the same file still land."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
p = ingest_dir / "mixed_20260517.jsonl"
p.write_text(
json.dumps(_PIPELINE_LINE) + "\n"
"this is not json\n"
+ json.dumps({**_PIPELINE_LINE, "msg": "another valid line"}),
encoding="utf-8",
)
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
resp = client.post("/api/corpus/pipeline-ingest")
assert resp.status_code == 200
assert resp.json()["entries_stored"] == 2 # 2 valid lines, 1 skipped
def test_pipeline_ingest_new_file_after_first_run(client, tmp_path, monkeypatch):
"""A new file added after the first ingest is picked up on the next call."""
ingest_dir = tmp_path / "pipeline_logs"
ingest_dir.mkdir()
_make_pipeline_file(ingest_dir, "run_a.jsonl", [_PIPELINE_LINE])
config_dir = tmp_path / "config"
config_dir.mkdir(exist_ok=True)
(config_dir / "label_tool.yaml").write_text(
f"corpus:\n pipeline_ingest_dir: \"{ingest_dir}\"\n sources: []\n"
)
monkeypatch.setattr(lc, "_CONFIG_DIR", config_dir)
client.post("/api/corpus/pipeline-ingest") # ingest run_a.jsonl
_make_pipeline_file(ingest_dir, "run_b.jsonl", [
{**_PIPELINE_LINE, "msg": "Second run line"},
])
resp2 = client.post("/api/corpus/pipeline-ingest")
data = resp2.json()
assert data["ingested_files"] == 1
assert data["skipped_files"] == 1
assert data["entries_stored"] == 1

View file

@ -17,6 +17,7 @@ def reset_models_globals(tmp_path):
from app import models as models_module
prev_models = models_module._MODELS_DIR
prev_cf_text = models_module._CF_TEXT_MODELS_DIR
prev_queue = models_module._QUEUE_DIR
prev_progress = dict(models_module._download_progress)
@ -26,12 +27,14 @@ def reset_models_globals(tmp_path):
queue_dir.mkdir()
models_module.set_models_dir(models_dir)
models_module.set_cf_text_models_dir(tmp_path / "cf-text-models")
models_module.set_queue_dir(queue_dir)
models_module._download_progress = {}
yield
models_module.set_models_dir(prev_models)
models_module.set_cf_text_models_dir(prev_cf_text)
models_module.set_queue_dir(prev_queue)
models_module._download_progress = prev_progress
@ -541,3 +544,84 @@ def test_delete_installed_name_with_slash_blocked(client):
except _HTTPException as exc:
assert exc.status_code in (400, 404)
raise
# ── Catalog registration ───────────────────────────────────────────────────────
_MINIMAL_YAML = """\
services:
cf-text:
max_mb: {max_mb}
catalog:
existing-model:
path: /some/path
vram_mb: 1000
description: "placeholder"
"""
def _make_node_yaml(tmp_path: Path, max_mb: int = 8192) -> Path:
p = tmp_path / "testnode.yaml"
p.write_text(_MINIMAL_YAML.format(max_mb=max_mb), encoding="utf-8")
return p
def test_catalog_registration_fp16_no_env_block(tmp_path):
"""When model fits at FP16, no env block should be written."""
from app import models as models_module
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
updated = models_module._register_in_node_catalogs(
repo_id="org/SmallModel",
local_path=tmp_path / "org--SmallModel",
vram_mb_fp16=4000,
role="generator",
)
assert "testnode" in updated
content = node_yaml.read_text()
# _catalog_key strips org prefix and lowercases: "org/SmallModel" → "smallmodel"
assert "smallmodel:" in content
assert "CF_TEXT_4BIT" not in content
assert "env:" not in content
def test_catalog_registration_needs_4bit_writes_env_block(tmp_path):
"""When model only fits at 4-bit, env: CF_TEXT_4BIT: '1' must be written."""
from app import models as models_module
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
updated = models_module._register_in_node_catalogs(
repo_id="org/BigModel",
local_path=tmp_path / "org--BigModel",
vram_mb_fp16=20000, # won't fit at FP16 on 8 GB
role="generator",
)
assert "testnode" in updated
content = node_yaml.read_text()
# _catalog_key: "org/BigModel" → "bigmodel"
assert "bigmodel:" in content
assert "env:" in content
assert 'CF_TEXT_4BIT: "1"' in content
assert "CF_TEXT_4BIT=1 required" in content # description note
def test_catalog_registration_too_large_skipped(tmp_path):
"""Model too large even at 4-bit should not be registered."""
from app import models as models_module
node_yaml = _make_node_yaml(tmp_path, max_mb=8192)
with patch.object(models_module, "_CF_ORCH_PROFILES_DIR", tmp_path):
updated = models_module._register_in_node_catalogs(
repo_id="org/HugeModel",
local_path=tmp_path / "org--HugeModel",
vram_mb_fp16=80000, # 4-bit ~22 GB, still won't fit on 8 GB
role="generator",
)
assert updated == []
content = node_yaml.read_text()
assert "hugemodel" not in content

575
tests/test_nodes.py Normal file
View file

@ -0,0 +1,575 @@
"""Tests for app/nodes.py — /api/nodes-mgmt/* endpoints."""
from __future__ import annotations
from pathlib import Path
import pytest
import yaml
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
import os as _os
@pytest.fixture(autouse=True)
def reset_nodes_globals(tmp_path):
"""Redirect _CONFIG_DIR to tmp_path so tests never read the real config."""
from app import nodes as nodes_module
prev = nodes_module._CONFIG_DIR
nodes_module.set_config_dir(tmp_path)
yield tmp_path
nodes_module.set_config_dir(prev)
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _write_config(config_dir: Path, cforch_cfg: dict) -> None:
cfg = {"cforch": cforch_cfg}
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg), encoding="utf-8")
def _write_profile(profiles_dir: Path, node_id: str, profile: dict) -> None:
profiles_dir.mkdir(parents=True, exist_ok=True)
(profiles_dir / f"{node_id}.yaml").write_text(yaml.dump(profile), encoding="utf-8")
def test_nodes_module_imports():
from app import nodes
assert hasattr(nodes, "router")
assert hasattr(nodes, "set_config_dir")
def test_list_nodes_returns_empty_when_no_coordinator(client):
"""No cforch config — endpoint returns empty list, not 500."""
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
assert r.json() == []
def _fake_nodes_response(nodes_json: list, services_json: list | None = None):
"""Build side_effect list for two httpx.get calls: nodes then services."""
mock_nodes = MagicMock()
mock_nodes.raise_for_status = MagicMock()
mock_nodes.json.return_value = {"nodes": nodes_json}
mock_services = MagicMock()
mock_services.raise_for_status = MagicMock()
mock_services.json.return_value = {"services": services_json or []}
return [mock_nodes, mock_services]
def test_list_nodes_coordinator_unreachable_returns_empty(client, tmp_path):
"""Coordinator unreachable — returns [] with no 500."""
import httpx
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
with patch("httpx.get", side_effect=httpx.ConnectError("refused")):
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
assert r.json() == []
def test_list_nodes_merges_profile_data(client, tmp_path):
"""Profile YAML services_assigned merged with live GPU stats."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", {
"services": {
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}},
},
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
"services": ["cf-text"], "role": "primary", "card": "RTX 3090",
"always_on": True}],
"agent_url": "http://10.1.10.71:7701",
}
}
})
coord_nodes = [{
"node_id": "heimdall", "online": True, "agent_url": "http://10.1.10.71:7701",
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
"vram_used_mb": 4096, "vram_free_mb": 20480,
"temp_c": 42.0, "utilization_pct": 15.0, "compute_cap": 8.6}],
}]
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
data = r.json()
assert len(data) == 1
node = data[0]
assert node["node_id"] == "heimdall"
assert node["profile_loaded"] is True
assert node["gpus"][0]["services_assigned"] == ["cf-text"]
assert node["gpus"][0]["vram_total_mb"] == 24576
assert "cf-text" in node["services_catalog"]
def test_list_nodes_no_profile_returns_profile_loaded_false(client, tmp_path):
"""Node with no profile YAML — profile_loaded: false, GPU stats still returned."""
_write_config(tmp_path, {"coordinator_url": "http://fake-coord:7700"})
coord_nodes = [{
"node_id": "sif", "online": True, "agent_url": "http://10.1.10.158:7701",
"gpus": [{"gpu_id": 0, "card": "RTX 5060 Ti", "vram_total_mb": 16384,
"vram_used_mb": 0, "vram_free_mb": 16384,
"temp_c": None, "utilization_pct": None, "compute_cap": 10.0}],
}]
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes)):
r = client.get("/api/nodes-mgmt/nodes")
assert r.status_code == 200
data = r.json()
node = data[0]
assert node["profile_loaded"] is False
assert node["gpus"][0]["card"] == "RTX 5060 Ti"
assert node["services_catalog"] == {}
def test_list_nodes_marks_running_services(client, tmp_path):
"""services_running populated from coordinator /api/services response."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", {
"services": {},
"nodes": {"heimdall": {"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
"services": ["cf-text"], "role": "p",
"card": "RTX 3090", "always_on": True}],
"agent_url": "http://10.1.10.71:7701"}}
})
coord_nodes = [{"node_id": "heimdall", "online": True,
"agent_url": "http://10.1.10.71:7701",
"gpus": [{"gpu_id": 0, "card": "RTX 3090", "vram_total_mb": 24576,
"vram_used_mb": 8192, "vram_free_mb": 16384,
"temp_c": 55.0, "utilization_pct": 80.0, "compute_cap": 8.6}]}]
coord_services = [{"service": "cf-text", "node_id": "heimdall", "gpu_id": 0}]
with patch("httpx.get", side_effect=_fake_nodes_response(coord_nodes, coord_services)):
r = client.get("/api/nodes-mgmt/nodes")
data = r.json()
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
def test_get_profile_returns_parsed_yaml(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
profile = {
"services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
"nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
}
_write_profile(profiles_dir, "heimdall", profile)
r = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
assert r.status_code == 200
data = r.json()
assert "services" in data
assert "cf-text" in data["services"]
def test_get_profile_404_when_missing(client, tmp_path):
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
r = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
assert r.status_code == 404
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
profiles_dir = tmp_path / "profiles"
profiles_dir.mkdir()
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
(profiles_dir / "bad.yaml").write_text("key: [unclosed", encoding="utf-8")
r = client.get("/api/nodes-mgmt/nodes/bad/profile")
assert r.status_code == 500
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
_BASE_PROFILE = {
"services": {
"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1,
"catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3",
"description": "", "multi_gpu": False, "env": {}}}},
"ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}},
},
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6,
"services": [], "role": "primary", "card": "RTX 3090",
"always_on": True}],
"agent_url": "http://10.1.10.71:7701",
}
}
}
def _setup_profile(tmp_path, profile=None):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", profile or _BASE_PROFILE)
return profiles_dir
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
profiles_dir = _setup_profile(tmp_path)
mock_reload = MagicMock()
mock_reload.status_code = 200
with patch("httpx.post", return_value=mock_reload):
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["cf-text"]},
)
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["reloaded"] is True
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"]
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
"""YAML must be written to .tmp then renamed — never written directly."""
profiles_dir = _setup_profile(tmp_path)
renamed_pairs: list[tuple] = []
original_replace = _os.replace
def capture(src, dst):
renamed_pairs.append((str(src), str(dst)))
original_replace(src, dst)
with patch("os.replace", side_effect=capture), \
patch("httpx.post", return_value=MagicMock(status_code=200)):
client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["ollama"]},
)
assert any(src.endswith(".tmp") for src, dst in renamed_pairs), \
"Expected atomic write via .tmp rename"
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
low_cap_profile = {
**_BASE_PROFILE,
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 6.0,
"services": [], "role": "p", "card": "GTX 1080",
"always_on": False}],
"agent_url": "http://10.1.10.71:7701",
}
}
}
_setup_profile(tmp_path, low_cap_profile)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["cf-text"]},
)
assert r.status_code == 422
assert "compute_cap" in r.json()["detail"]
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
tiny_vram_profile = {
**_BASE_PROFILE,
"nodes": {
"heimdall": {
"gpus": [{"id": 0, "vram_mb": 512, "compute_cap": 8.6,
"services": [], "role": "p", "card": "old",
"always_on": False}],
"agent_url": "http://10.1.10.71:7701",
}
}
}
_setup_profile(tmp_path, tiny_vram_profile)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["cf-text"]},
)
assert r.status_code == 422
assert "VRAM" in r.json()["detail"]
def test_update_services_unknown_service_returns_422(client, tmp_path):
_setup_profile(tmp_path)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["not-a-real-service"]},
)
assert r.status_code == 422
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
"""YAML saved but coordinator reload fails — ok: true, reloaded: false."""
_setup_profile(tmp_path)
mock_reload = MagicMock()
mock_reload.status_code = 500
with patch("httpx.post", return_value=mock_reload):
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
json={"services": ["ollama"]},
)
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["reloaded"] is False
# ── Ollama endpoints ───────────────────────────────────────────────────────────
_OLLAMA_PROFILE = {
"services": {},
"nodes": {
"heimdall": {
"gpus": [],
"agent_url": "http://10.1.10.71:7701",
}
}
}
def test_list_ollama_models_proxies_tags(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_tags = MagicMock()
mock_tags.raise_for_status = MagicMock()
mock_tags.json.return_value = {
"models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
}
with patch("httpx.get", return_value=mock_tags):
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
assert r.status_code == 200
data = r.json()
assert len(data["models"]) == 1
assert data["models"][0]["name"] == "nomic-embed-text"
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
import httpx as _httpx
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
with patch("httpx.get", side_effect=_httpx.ConnectError("refused")):
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
assert r.status_code == 200
data = r.json()
assert "error" in data
def test_pull_ollama_model_streams_sse(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_resp = MagicMock()
mock_resp.iter_lines.return_value = iter([
'{"status": "pulling manifest"}',
'{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
'{"status": "success"}',
])
with patch("httpx.stream") as mock_stream_fn:
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
json={"name": "nomic-embed-text"},
)
assert r.status_code == 200
body = r.text
assert 'data: {"status": "pulling manifest"}' in body
assert 'data: {"status": "success"}' in body
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_resp = MagicMock()
mock_resp.iter_lines.return_value = iter([
'{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}',
])
with patch("httpx.stream") as mock_stream_fn:
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
json={"name": "nomic-embed-text"},
)
assert r.status_code == 200
assert "permission denied" in r.text
def test_delete_ollama_model_proxies_delete(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_del = MagicMock()
mock_del.status_code = 200
mock_del.raise_for_status = MagicMock()
with patch("httpx.request", return_value=mock_del):
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text")
assert r.status_code == 200
assert r.json() == {"ok": True}
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
mock_del = MagicMock()
mock_del.status_code = 404
with patch("httpx.request", return_value=mock_del):
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model")
assert r.status_code == 404
# ── Deploy model endpoint ──────────────────────────────────────────────────────
_DEPLOY_PROFILE = {
"services": {
"cf-text": {
"max_mb": 20000,
"min_compute_cap": 7.0,
"model_base_path": "/devl/Assets/LLM/cf-text/models",
"catalog": {},
},
},
"nodes": {
"heimdall": {
"gpus": [],
"agent_url": "http://10.1.10.71:7701",
}
}
}
def test_deploy_model_adds_catalog_entry(client, tmp_path):
"""Deploy endpoint should add the model to the service catalog."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
mock_reload = MagicMock()
mock_reload.status_code = 200
with patch("httpx.post", return_value=mock_reload):
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
json={
"model_id": "fdtn-ai--Foundation-Sec-8B-Q4",
"service_type": "cf-text",
"vram_mb": 5180,
"hf_repo": "fdtn-ai/Foundation-Sec-8B-Q4_K_M-GGUF",
},
)
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["reloaded"] is True
assert "fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF" in data["path"]
saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
catalog = saved["services"]["cf-text"]["catalog"]
assert "fdtn-ai--Foundation-Sec-8B-Q4" in catalog
entry = catalog["fdtn-ai--Foundation-Sec-8B-Q4"]
assert entry["vram_mb"] == 5180
assert entry["path"].endswith("fdtn-ai--Foundation-Sec-8B-Q4_K_M-GGUF")
def test_deploy_model_explicit_path_overrides_base(client, tmp_path):
"""An explicit path in the request body takes precedence over model_base_path."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {
"coordinator_url": "http://fake-coord:7700",
"profiles_dir": str(profiles_dir),
})
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
with patch("httpx.post", return_value=MagicMock(status_code=200)):
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
json={
"model_id": "my-model",
"service_type": "cf-text",
"vram_mb": 8000,
"path": "/custom/path/to/model",
},
)
assert r.status_code == 200
assert r.json()["path"] == "/custom/path/to/model"
def test_deploy_model_unknown_service_returns_422(client, tmp_path):
"""Service type not in profile → 422."""
profiles_dir = tmp_path / "profiles"
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
_write_profile(profiles_dir, "heimdall", _DEPLOY_PROFILE)
r = client.post(
"/api/nodes-mgmt/nodes/heimdall/models/deploy",
json={"model_id": "x", "service_type": "vllm", "vram_mb": 8000},
)
assert r.status_code == 422
assert "vllm" in r.json()["detail"]
def test_deploy_model_missing_profile_returns_404(client, tmp_path):
_write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
r = client.post(
"/api/nodes-mgmt/nodes/nonexistent/models/deploy",
json={"model_id": "x", "service_type": "cf-text", "vram_mb": 100},
)
assert r.status_code == 404

227
tests/test_recipe_scan.py Normal file
View file

@ -0,0 +1,227 @@
"""Tests for app/data/recipe_scan.py — recipe scan labeling endpoints."""
from __future__ import annotations
import json
import uuid
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from app.data import recipe_scan as rs
EXTRACTED = {"title": "Shepherd's Pie", "ingredients": ["lamb", "potato"], "steps": ["brown meat", "mash potato"]}
GROUND_TRUTH = {"title": "Shepherd's Pie", "ingredients": ["ground lamb", "mashed potato", "peas"], "steps": ["brown meat", "add veg", "mash potato", "bake"]}
@pytest.fixture(autouse=True)
def isolated_db(tmp_path, monkeypatch):
monkeypatch.setattr(rs, "_DB_PATH", tmp_path / "recipe_scan.db")
rs._init_db()
@pytest.fixture()
def client():
from fastapi import FastAPI
app = FastAPI()
app.include_router(rs.router, prefix="/api/recipe-scan")
return TestClient(app)
def _item(**kwargs) -> dict:
return {
"id": str(uuid.uuid4()),
"image_path": "/Library/Assets/kiwi/scans/pc_test.jpg",
"modality": kwargs.get("modality", "scanner"),
"source": kwargs.get("source", "purple_carrot"),
"extracted": kwargs.get("extracted", EXTRACTED),
"ground_truth": kwargs.get("ground_truth", GROUND_TRUTH),
}
def _import(client, items: list[dict]) -> None:
resp = client.post("/api/recipe-scan/import", json={"items": items})
assert resp.status_code == 200
# ── Import ─────────────────────────────────────────────────────────────────────
def test_import_stores_items(client):
_import(client, [_item()])
stats = client.get("/api/recipe-scan/stats").json()
assert stats["total"] == 1
assert stats["by_status"]["pending"] == 1
def test_import_rejects_unknown_modality(client):
bad = _item()
bad["modality"] = "telepathy"
resp = client.post("/api/recipe-scan/import", json={"items": [bad]})
assert resp.status_code == 422
def test_import_is_idempotent(client):
item = _item()
_import(client, [item])
_import(client, [item]) # same id — should not duplicate
stats = client.get("/api/recipe-scan/stats").json()
assert stats["total"] == 1
def test_import_multiple_items(client):
_import(client, [_item(), _item(), _item()])
assert client.get("/api/recipe-scan/stats").json()["total"] == 3
# ── Next ───────────────────────────────────────────────────────────────────────
def test_next_returns_404_when_queue_empty(client):
resp = client.get("/api/recipe-scan/next")
assert resp.status_code == 404
def test_next_returns_pending_item(client):
item = _item()
_import(client, [item])
resp = client.get("/api/recipe-scan/next")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == item["id"]
assert data["status"] == "pending"
assert "extracted" in data
assert "ground_truth" in data
def test_next_skips_non_pending(client):
item = _item()
_import(client, [item])
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
resp = client.get("/api/recipe-scan/next")
assert resp.status_code == 404
# ── Approve ────────────────────────────────────────────────────────────────────
def test_approve_marks_item_approved(client):
item = _item()
_import(client, [item])
resp = client.post(f"/api/recipe-scan/items/{item['id']}/approve")
assert resp.status_code == 200
assert resp.json()["status"] == "approved"
stats = client.get("/api/recipe-scan/stats").json()
assert stats["by_status"]["approved"] == 1
def test_approve_returns_404_for_unknown_id(client):
resp = client.post("/api/recipe-scan/items/no-such-id/approve")
assert resp.status_code == 404
# ── Edit ───────────────────────────────────────────────────────────────────────
def test_edit_stores_corrected_json(client):
item = _item()
_import(client, [item])
corrected = {**GROUND_TRUTH, "servings": 4}
resp = client.post(
f"/api/recipe-scan/items/{item['id']}/edit",
json={"corrected": corrected},
)
assert resp.status_code == 200
assert resp.json()["status"] == "edited"
stats = client.get("/api/recipe-scan/stats").json()
assert stats["by_status"]["edited"] == 1
def test_edit_requires_corrected_field(client):
item = _item()
_import(client, [item])
resp = client.post(f"/api/recipe-scan/items/{item['id']}/edit", json={})
assert resp.status_code == 422
# ── Reject ─────────────────────────────────────────────────────────────────────
def test_reject_marks_item_rejected(client):
item = _item()
_import(client, [item])
resp = client.post(
f"/api/recipe-scan/items/{item['id']}/reject",
json={"reason": "OCR completely unreadable"},
)
assert resp.status_code == 200
assert resp.json()["status"] == "rejected"
def test_reject_without_reason_is_valid(client):
item = _item()
_import(client, [item])
resp = client.post(f"/api/recipe-scan/items/{item['id']}/reject")
assert resp.status_code == 200
# ── Export ─────────────────────────────────────────────────────────────────────
def test_export_empty_when_nothing_approved(client):
item = _item()
_import(client, [item])
resp = client.get("/api/recipe-scan/export")
assert resp.status_code == 200
assert resp.text.strip() == ""
def test_export_includes_approved_item(client):
item = _item()
_import(client, [item])
client.post(f"/api/recipe-scan/items/{item['id']}/approve")
resp = client.get("/api/recipe-scan/export")
lines = [l for l in resp.text.strip().splitlines() if l]
assert len(lines) == 1
pair = json.loads(lines[0])
assert pair["id"] == item["id"]
assert pair["modality"] == "scanner"
assert "messages" in pair
assert len(pair["messages"]) == 2
assert pair["messages"][0]["role"] == "user"
assert pair["messages"][1]["role"] == "assistant"
def test_export_includes_edited_item_with_correction(client):
item = _item()
_import(client, [item])
corrected = {**GROUND_TRUTH, "servings": 4}
client.post(
f"/api/recipe-scan/items/{item['id']}/edit",
json={"corrected": corrected},
)
resp = client.get("/api/recipe-scan/export")
lines = [l for l in resp.text.strip().splitlines() if l]
pair = json.loads(lines[0])
assistant_content = json.loads(pair["messages"][1]["content"])
assert assistant_content["servings"] == 4
def test_export_excludes_rejected_items(client):
item = _item()
_import(client, [item])
client.post(f"/api/recipe-scan/items/{item['id']}/reject")
resp = client.get("/api/recipe-scan/export")
assert resp.text.strip() == ""
# ── Stats ──────────────────────────────────────────────────────────────────────
def test_stats_counts_all_statuses(client):
items = [_item(), _item(), _item(), _item()]
_import(client, items)
client.post(f"/api/recipe-scan/items/{items[0]['id']}/approve")
client.post(f"/api/recipe-scan/items/{items[1]['id']}/edit", json={"corrected": GROUND_TRUTH})
client.post(f"/api/recipe-scan/items/{items[2]['id']}/reject")
stats = client.get("/api/recipe-scan/stats").json()
assert stats["total"] == 4
assert stats["by_status"]["pending"] == 1
assert stats["by_status"]["approved"] == 1
assert stats["by_status"]["edited"] == 1
assert stats["by_status"]["rejected"] == 1
assert stats["export_ready"] == 2 # approved + edited

View file

@ -1,4 +1,4 @@
"""API integration tests for app/sft.py /api/sft/* endpoints."""
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
import json
import pytest
from fastapi.testclient import TestClient
@ -7,17 +7,17 @@ from pathlib import Path
@pytest.fixture(autouse=True)
def reset_sft_globals(tmp_path):
from app import sft as sft_module
_prev_data = sft_module._SFT_DATA_DIR
_prev_cfg = sft_module._SFT_CONFIG_DIR
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR
sft_module.set_sft_data_dir(tmp_path)
sft_module.set_sft_config_dir(tmp_path)
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
from app.data import corrections as corr_module
_prev_data = corr_module._DATA_DIR
_prev_cfg = corr_module._CONFIG_DIR
_prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
corr_module.set_data_dir(tmp_path)
corr_module.set_config_dir(tmp_path)
corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
yield
sft_module.set_sft_data_dir(_prev_data)
sft_module.set_sft_config_dir(_prev_cfg)
sft_module.set_default_bench_results_dir(_prev_default)
corr_module.set_data_dir(_prev_data)
corr_module.set_config_dir(_prev_cfg)
corr_module.set_default_bench_results_dir(_prev_default)
@pytest.fixture
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
)
# ── /api/sft/runs ──────────────────────────────────────────────────────────
# -- /api/sft/runs -------------------------------------------------------------
def test_runs_returns_empty_when_no_config(client):
r = client.get("/api/sft/runs")
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
def test_runs_marks_already_imported(client, tmp_path):
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
_write_config(tmp_path, tmp_path / "bench_results")
from app import sft as sft_module
from app.data import corrections as sft_module
candidates = sft_module._candidates_file()
candidates.parent.mkdir(parents=True, exist_ok=True)
candidates.write_text(
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
assert r.json()[0]["already_imported"] is True
# ── /api/sft/import ─────────────────────────────────────────────────────────
# -- /api/sft/import -----------------------------------------------------------
def test_import_adds_records(client, tmp_path):
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
assert r.status_code == 404
# ── /api/sft/queue ──────────────────────────────────────────────────────────
# -- /api/sft/queue ------------------------------------------------------------
def _populate_candidates(tmp_path, records: list[dict]) -> None:
from app import sft as sft_module
from app.data import corrections as sft_module
path = sft_module._candidates_file()
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
# ── /api/sft/submit ─────────────────────────────────────────────────────────
# -- /api/sft/submit -----------------------------------------------------------
def test_submit_correct_sets_approved(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")])
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
"corrected_response": "def add(a, b): return a + b",
})
assert r.status_code == 200
from app import sft as sft_module
from app.data import corrections as sft_module
records = sft_module._read_candidates()
assert records[0]["status"] == "approved"
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
"id": "a", "action": "correct",
"corrected_response": "def add(a, b): return a + b",
})
from app import sft as sft_module
from app.data import corrections as sft_module
from app.utils import read_jsonl
approved = read_jsonl(sft_module._approved_file())
assert len(approved) == 1
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
assert r.status_code == 200
from app import sft as sft_module
from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "discarded"
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
assert r.status_code == 200
from app import sft as sft_module
from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
"failure_category": "style_violation",
})
assert r.status_code == 200
from app import sft as sft_module
from app.data import corrections as sft_module
records = sft_module._read_candidates()
assert records[0]["failure_category"] == "style_violation"
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
"corrected_response": "def add(a, b): return a + b",
})
assert r.status_code == 200
from app import sft as sft_module
from app.data import corrections as sft_module
records = sft_module._read_candidates()
assert records[0]["failure_category"] is None
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
assert r.status_code == 422
# ── /api/sft/undo ────────────────────────────────────────────────────────────
# -- /api/sft/undo -------------------------------------------------------------
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")])
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
r = client.post("/api/sft/undo", json={"id": "a"})
assert r.status_code == 200
from app import sft as sft_module
from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "needs_review"
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
"corrected_response": "def add(a, b): return a + b",
})
client.post("/api/sft/undo", json={"id": "a"})
from app import sft as sft_module
from app.data import corrections as sft_module
from app.utils import read_jsonl
approved = read_jsonl(sft_module._approved_file())
assert not any(r["id"] == "a" for r in approved)
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
assert r.status_code == 409
# ── /api/sft/export ──────────────────────────────────────────────────────────
# -- /api/sft/export -----------------------------------------------------------
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
from app import sft as sft_module
from app.data import corrections as sft_module
from app.utils import write_jsonl
approved = {
**_make_record("a"),
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
def test_export_excludes_non_approved(client, tmp_path):
from app import sft as sft_module
from app.data import corrections as sft_module
from app.utils import write_jsonl
records = [
{**_make_record("a"), "status": "discarded", "corrected_response": None},
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
assert r.text.strip() == ""
# ── /api/sft/stats ───────────────────────────────────────────────────────────
# -- /api/sft/stats ------------------------------------------------------------
def test_stats_counts_by_status(client, tmp_path):
from app import sft as sft_module
from app.data import corrections as sft_module
from app.utils import write_jsonl
records = [
_make_record("a"),

187
tests/test_train.py Normal file
View file

@ -0,0 +1,187 @@
"""Tests for app/train/train.py -- /api/train/* endpoints."""
import json
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.train import train as train_module
train_module.set_db_path(tmp_path / "train_jobs.db")
train_module.set_models_dir(tmp_path / "models")
train_module._running_procs.clear()
yield
train_module._running_procs.clear()
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _parse_sse(content: bytes) -> list[dict]:
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_list_jobs_empty(client):
r = client.get("/api/train/jobs")
assert r.status_code == 200
assert r.json() == {"jobs": []}
def test_create_job_returns_queued_record(client):
r = client.post("/api/train/jobs",
json={"type": "classifier", "model_key": "deberta-small",
"config_json": {"epochs": 3}})
assert r.status_code == 200
data = r.json()
assert data["status"] == "queued"
assert data["type"] == "classifier"
assert data["model_key"] == "deberta-small"
assert "id" in data
def test_create_job_invalid_type_returns_400(client):
r = client.post("/api/train/jobs",
json={"type": "unknown-type", "model_key": "deberta-small"})
assert r.status_code == 400
def test_create_job_appears_in_list(client):
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
r = client.get("/api/train/jobs")
assert r.status_code == 200
assert len(r.json()["jobs"]) == 1
def test_get_job_returns_record(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.status_code == 200
assert r2.json()["id"] == job_id
def test_get_job_404_for_unknown(client):
r = client.get("/api/train/jobs/no-such-id")
assert r.status_code == 404
def test_cancel_queued_job(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
assert r2.status_code == 200
assert r2.json()["status"] == "cancelled"
r3 = client.get(f"/api/train/jobs/{job_id}")
assert r3.json()["status"] == "cancelled"
def test_cancel_completed_job_returns_409(client):
from app.train import train as train_module
train_module._init_db()
with train_module._db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
)
r = client.delete("/api/train/jobs/abc/cancel")
assert r.status_code == 409
def test_cancel_terminates_running_proc(client):
from app.train import train as train_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
train_module._running_procs[job_id] = mock_proc
with train_module._db() as conn:
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
assert r2.status_code == 200
mock_proc.terminate.assert_called_once()
def test_run_job_streams_sse(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
r2 = client.get(f"/api/train/jobs/{job_id}/run")
assert r2.status_code == 200
assert "text/event-stream" in r2.headers.get("content-type", "")
events = _parse_sse(r2.content)
assert any(e["type"] == "complete" for e in events)
def test_run_job_marks_completed_in_db(client):
from app.train import train as train_module
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
client.get(f"/api/train/jobs/{job_id}/run")
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.json()["status"] == "completed"
def test_run_job_marks_failed_on_nonzero_exit(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 1
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
client.get(f"/api/train/jobs/{job_id}/run")
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.json()["status"] == "failed"
def test_run_nonqueued_job_returns_409(client):
from app.train import train as train_module
train_module._init_db()
with train_module._db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
)
r = client.get("/api/train/jobs/xyz/run")
assert r.status_code == 409
def test_run_unknown_job_returns_404(client):
r = client.get("/api/train/jobs/no-such/run")
assert r.status_code == 404
def test_results_empty_when_no_models_dir(client):
r = client.get("/api/train/results")
assert r.status_code == 200
assert r.json() == {"results": []}
def test_results_returns_training_info(client, tmp_path):
from app.train import train as train_module
models_dir = tmp_path / "models" / "avocet-deberta-small"
models_dir.mkdir(parents=True)
train_module.set_models_dir(tmp_path / "models")
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
(models_dir / "training_info.json").write_text(json.dumps(info))
r = client.get("/api/train/results")
assert r.status_code == 200
data = r.json()
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])

42
web/package-lock.json generated
View file

@ -2676,9 +2676,9 @@
}
},
"node_modules/brace-expansion": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz",
"integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==",
"dev": true,
"license": "MIT",
"dependencies": {
@ -2890,9 +2890,9 @@
"license": "MIT"
},
"node_modules/defu": {
"version": "6.1.4",
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.4.tgz",
"integrity": "sha512-mEQCMmwJu317oSz8CwdIOdwf3xMif1ttiM8LTufzc3g6kR+9Pe236twL8j3IYT1F7GfRgGcW6MWxzZjLIkuHIg==",
"version": "6.1.7",
"resolved": "https://registry.npmjs.org/defu/-/defu-6.1.7.tgz",
"integrity": "sha512-7z22QmUWiQ/2d0KkdYmANbRUVABpZ9SNYyH5vx6PZ+nE5bcC0l7uFvEfHlyld/HcGBFTL536ClDt3DEcSlEJAQ==",
"dev": true,
"license": "MIT"
},
@ -3725,9 +3725,9 @@
"license": "ISC"
},
"node_modules/picomatch": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"license": "MIT",
"engines": {
"node": ">=12"
@ -3769,9 +3769,9 @@
}
},
"node_modules/postcss": {
"version": "8.5.8",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
"version": "8.5.14",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
"integrity": "sha512-SoSL4+OSEtR99LHFZQiJLkT59C5B1amGO1NzTwj7TT1qCUgUO6hxOvzkOYxD+vMrXBM3XJIKzokoERdqQq/Zmg==",
"funding": [
{
"type": "opencollective",
@ -4325,9 +4325,9 @@
}
},
"node_modules/undici": {
"version": "7.22.0",
"resolved": "https://registry.npmjs.org/undici/-/undici-7.22.0.tgz",
"integrity": "sha512-RqslV2Us5BrllB+JeiZnK4peryVTndy9Dnqq62S3yYRRTj0tFQCwEniUy2167skdGOy3vqRzEvl1Dm4sV2ReDg==",
"version": "7.25.0",
"resolved": "https://registry.npmjs.org/undici/-/undici-7.25.0.tgz",
"integrity": "sha512-xXnp4kTyor2Zq+J1FfPI6Eq3ew5h6Vl0F/8d9XU5zZQf1tX9s2Su1/3PiMmUANFULpmksxkClamIZcaUqryHsQ==",
"dev": true,
"license": "MIT",
"engines": {
@ -4422,9 +4422,9 @@
}
},
"node_modules/vite": {
"version": "7.3.1",
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz",
"integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==",
"version": "7.3.2",
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz",
"integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==",
"dev": true,
"license": "MIT",
"dependencies": {
@ -4921,9 +4921,9 @@
"license": "MIT"
},
"node_modules/yaml": {
"version": "2.8.2",
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz",
"integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==",
"version": "2.8.4",
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.4.tgz",
"integrity": "sha512-ml/JPOj9fOQK8RNnWojA67GbZ0ApXAUlN2UQclwv2eVgTgn7O9gg9o7paZWKMp4g0H3nTLtS9LVzhkpOFIKzog==",
"license": "ISC",
"bin": {
"yaml": "bin.mjs"

View file

@ -0,0 +1,124 @@
import { mount, flushPromises } from '@vue/test-utils'
import { createRouter, createWebHashHistory } from 'vue-router'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import AppSidebar from './AppSidebar.vue'
// Minimal router so RouterLink renders without warnings
const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/', component: { template: '<div />' } },
{ path: '/fleet', component: { template: '<div />' } },
{ path: '/data/label', component: { template: '<div />' } },
{ path: '/data/fetch', component: { template: '<div />' } },
{ path: '/data/corrections', component: { template: '<div />' } },
{ path: '/data/imitate', component: { template: '<div />' } },
{ path: '/eval/benchmark', component: { template: '<div />' } },
{ path: '/eval/compare', component: { template: '<div />' } },
{ path: '/train/jobs', component: { template: '<div />' } },
{ path: '/train/results', component: { template: '<div />' } },
{ path: '/settings', component: { template: '<div />' } },
],
})
function makeFetch(signals: Record<string, boolean> = {}) {
return vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
labeled_since_last_eval: 0,
last_eval_timestamp: null,
last_eval_best_score: null,
active_jobs: [],
corrections_export_ready: 0,
signals,
}),
text: async () => '',
})
}
beforeEach(() => {
localStorage.clear()
vi.stubGlobal('fetch', makeFetch())
})
describe('AppSidebar structure', () => {
it('renders section headers for Data, Eval, Train', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const text = w.text()
expect(text).toContain('Data')
expect(text).toContain('Eval')
expect(text).toContain('Train')
})
it('renders all sub-links', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const anchors = w.findAll('a')
const hrefs = anchors.map(a => a.attributes('href') ?? '')
expect(hrefs.some(h => h.includes('/data/label'))).toBe(true)
expect(hrefs.some(h => h.includes('/data/fetch'))).toBe(true)
expect(hrefs.some(h => h.includes('/data/corrections'))).toBe(true)
expect(hrefs.some(h => h.includes('/data/imitate'))).toBe(true)
expect(hrefs.some(h => h.includes('/eval/benchmark'))).toBe(true)
expect(hrefs.some(h => h.includes('/eval/compare'))).toBe(true)
expect(hrefs.some(h => h.includes('/train/jobs'))).toBe(true)
expect(hrefs.some(h => h.includes('/train/results'))).toBe(true)
expect(hrefs.some(h => h.includes('/fleet'))).toBe(true)
expect(hrefs.some(h => h.includes('/settings'))).toBe(true)
})
it('does NOT render the old /benchmark or /models links', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const anchors = w.findAll('a')
const hrefs = anchors.map(a => a.attributes('href') ?? '')
// Old paths must not appear as direct links (they're only redirects)
expect(hrefs.every(h => !h.endsWith('/#/benchmark'))).toBe(true)
expect(hrefs.every(h => !h.endsWith('/#/models'))).toBe(true)
expect(hrefs.every(h => !h.endsWith('/#/stats'))).toBe(true)
})
it('shows no signal badges when all signals are false', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: false }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
expect(w.findAll('.signal-badge').length).toBe(0)
})
it('shows signal badge on Data section when data_to_eval is true', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: true, eval_to_train: false, train_to_fleet: false }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const badges = w.findAll('.signal-badge')
expect(badges.length).toBe(1)
// It should be inside the Data section header
const dataHeader = w.find('[data-section="data"]')
expect(dataHeader.find('.signal-badge').exists()).toBe(true)
})
it('shows signal badge on Eval section when eval_to_train is true', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: true, train_to_fleet: false }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const evalHeader = w.find('[data-section="eval"]')
expect(evalHeader.find('.signal-badge').exists()).toBe(true)
})
it('shows signal badge on Train section when train_to_fleet is true', async () => {
vi.stubGlobal('fetch', makeFetch({ data_to_eval: false, eval_to_train: false, train_to_fleet: true }))
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const trainHeader = w.find('[data-section="train"]')
expect(trainHeader.find('.signal-badge').exists()).toBe(true)
})
it('stow toggle still works', async () => {
const w = mount(AppSidebar, { global: { plugins: [router] } })
await flushPromises()
const nav = w.find('nav')
expect(nav.classes()).not.toContain('stowed')
await w.find('.stow-btn').trigger('click')
expect(nav.classes()).toContain('stowed')
})
})

View file

@ -28,12 +28,70 @@
</button>
</div>
<!-- Nav items -->
<!-- Nav -->
<ul class="nav-list" role="list">
<li v-for="item in navItems" :key="item.path">
<!-- Top-level links -->
<li>
<RouterLink
to="/"
class="nav-item"
:title="stowed ? 'Dashboard' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">📊</span>
<span v-if="!stowed" class="nav-label">Dashboard</span>
</RouterLink>
</li>
<li>
<RouterLink
to="/fleet"
class="nav-item"
:title="stowed ? 'Fleet' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true"></span>
<span v-if="!stowed" class="nav-label">Fleet</span>
</RouterLink>
</li>
<li>
<RouterLink
to="/nodes"
class="nav-item"
:title="stowed ? 'Nodes' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">🖥</span>
<span v-if="!stowed" class="nav-label">Nodes</span>
</RouterLink>
</li>
<!-- Data section -->
<li>
<div class="section-header" data-section="data" aria-hidden="true">
<template v-if="!stowed">
<span class="section-label"> Data</span>
<span
v-if="signals.data_to_eval"
class="signal-badge"
title="Enough new labels to run eval"
aria-label="Eval recommended"
/>
</template>
<template v-else>
<span class="section-icon"></span>
<span
v-if="signals.data_to_eval"
class="signal-badge signal-badge-stowed"
title="Eval recommended"
aria-label="Eval recommended"
/>
</template>
</div>
</li>
<li v-for="item in dataItems" :key="item.path">
<RouterLink
:to="item.path"
class="nav-item"
class="nav-item nav-subitem"
:title="stowed ? item.label : ''"
@click="isMobile && stow()"
>
@ -41,10 +99,94 @@
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
</RouterLink>
</li>
<!-- Eval section -->
<li>
<div class="section-header" data-section="eval" aria-hidden="true">
<template v-if="!stowed">
<span class="section-label"> Eval</span>
<span
v-if="signals.eval_to_train"
class="signal-badge"
title="Strong eval result — consider finetuning"
aria-label="Finetune recommended"
/>
</template>
<template v-else>
<span class="section-icon"></span>
<span
v-if="signals.eval_to_train"
class="signal-badge signal-badge-stowed"
title="Finetune recommended"
aria-label="Finetune recommended"
/>
</template>
</div>
</li>
<li v-for="item in evalItems" :key="item.path">
<RouterLink
:to="item.path"
class="nav-item nav-subitem"
:title="stowed ? item.label : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
</RouterLink>
</li>
<!-- Train section -->
<li>
<div class="section-header" data-section="train" aria-hidden="true">
<template v-if="!stowed">
<span class="section-label"> Train</span>
<span
v-if="signals.train_to_fleet"
class="signal-badge"
title="Trained model ready for fleet registration"
aria-label="Fleet registration recommended"
/>
</template>
<template v-else>
<span class="section-icon"></span>
<span
v-if="signals.train_to_fleet"
class="signal-badge signal-badge-stowed"
title="Fleet registration recommended"
aria-label="Fleet registration recommended"
/>
</template>
</div>
</li>
<li v-for="item in trainItems" :key="item.path">
<RouterLink
:to="item.path"
class="nav-item nav-subitem"
:title="stowed ? item.label : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true">{{ item.icon }}</span>
<span v-if="!stowed" class="nav-label">{{ item.label }}</span>
</RouterLink>
</li>
<!-- Divider + Settings -->
<li class="nav-divider" aria-hidden="true" />
<li>
<RouterLink
to="/settings"
class="nav-item"
:title="stowed ? 'Settings' : ''"
@click="isMobile && stow()"
>
<span class="nav-icon" aria-hidden="true"></span>
<span v-if="!stowed" class="nav-label">Settings</span>
</RouterLink>
</li>
</ul>
</nav>
<!-- Mobile hamburger button rendered outside the sidebar so it's visible when stowed -->
<!-- Mobile hamburger button visible when sidebar is stowed on mobile -->
<button
v-if="isMobile && stowed"
class="mobile-hamburger"
@ -61,25 +203,68 @@ import { RouterLink } from 'vue-router'
const LS_KEY = 'cf-avocet-nav-stowed'
const navItems = [
{ path: '/', icon: '🃏', label: 'Label' },
{ path: '/fetch', icon: '📥', label: 'Fetch' },
{ path: '/stats', icon: '📊', label: 'Stats' },
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
{ path: '/models', icon: '🤗', label: 'Models' },
{ path: '/imitate', icon: '🪞', label: 'Imitate' },
{ path: '/corrections', icon: '✍️', label: 'Corrections' },
{ path: '/settings', icon: '⚙️', label: 'Settings' },
interface NavItem {
path: string
icon: string
label: string
}
interface DashboardSignals {
data_to_eval: boolean
eval_to_train: boolean
train_to_fleet: boolean
}
const dataItems: NavItem[] = [
{ path: '/data/label', icon: '🏷', label: 'Label' },
{ path: '/data/fetch', icon: '📬', label: 'Fetch' },
{ path: '/data/corrections', icon: '✏️', label: 'Corrections' },
{ path: '/data/imitate', icon: '🪞', label: 'Imitate' },
{ path: '/data/recipe-scan', icon: '📷', label: 'Recipe Scan' },
]
const evalItems: NavItem[] = [
{ path: '/eval/benchmark', icon: '📊', label: 'Benchmark' },
{ path: '/eval/compare', icon: '🔍', label: 'Compare' },
{ path: '/eval/embed-compare', icon: '🧮', label: 'Embed Compare' },
]
const trainItems: NavItem[] = [
{ path: '/train/jobs', icon: '🧠', label: 'Jobs' },
{ path: '/train/results', icon: '📈', label: 'Results' },
]
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
const winWidth = ref(window.innerWidth)
const isMobile = computed(() => winWidth.value < 640)
const signals = ref<DashboardSignals>({
data_to_eval: false,
eval_to_train: false,
train_to_fleet: false,
})
async function loadSignals() {
try {
const res = await fetch('/api/dashboard')
if (res.ok) {
const data = await res.json() as { signals?: DashboardSignals }
if (data.signals) {
signals.value = {
data_to_eval: data.signals.data_to_eval ?? false,
eval_to_train: data.signals.eval_to_train ?? false,
train_to_fleet: data.signals.train_to_fleet ?? false,
}
}
}
} catch {
// Non-fatal: badges simply stay hidden if API is unreachable
}
}
function toggle() {
stowed.value = !stowed.value
localStorage.setItem(LS_KEY, String(stowed.value))
// Update CSS variable on :root so .app-main margin-left syncs
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
}
@ -93,13 +278,12 @@ function onResize() { winWidth.value = window.innerWidth }
onMounted(() => {
window.addEventListener('resize', onResize)
// Apply persisted sidebar width to :root on mount
document.documentElement.style.setProperty('--sidebar-width', stowed.value ? '56px' : '200px')
// On mobile, default to stowed
if (isMobile.value && !localStorage.getItem(LS_KEY)) {
stowed.value = true
document.documentElement.style.setProperty('--sidebar-width', '56px')
}
loadSignals()
})
onUnmounted(() => window.removeEventListener('resize', onResize))
@ -121,18 +305,15 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
overflow: hidden;
}
.sidebar.stowed {
width: 56px;
}
.sidebar.stowed { width: 56px; }
/* Mobile: slide in/out from left */
.sidebar.mobile {
box-shadow: 2px 0 16px rgba(0, 0, 0, 0.15);
}
.sidebar.mobile.stowed {
transform: translateX(-100%);
width: 200px; /* keep width so slide-in looks right */
width: 200px;
transition: transform 250ms ease, width 250ms ease;
}
@ -165,10 +346,7 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
white-space: nowrap;
}
.logo-icon {
font-size: 1.25rem;
flex-shrink: 0;
}
.logo-icon { font-size: 1.25rem; flex-shrink: 0; }
.logo-name {
font-family: var(--font-display, var(--font-body, sans-serif));
@ -193,16 +371,76 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
transition: background 0.15s;
}
.stow-btn:hover {
background: var(--color-border, #d0d7e8);
}
.stow-btn:hover { background: var(--color-border, #d0d7e8); }
.nav-list {
list-style: none;
padding: 0.5rem 0;
flex: 1;
overflow-y: auto;
overflow-x: hidden;
}
/* ── Section headers ── */
.section-header {
display: flex;
align-items: center;
gap: 0.4rem;
padding: 0.55rem 0.75rem 0.25rem;
margin-top: 0.5rem;
pointer-events: none;
user-select: none;
}
.section-label {
font-size: 0.7rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.07em;
color: var(--color-text-muted, #4a5c7a);
white-space: nowrap;
flex: 1;
}
.section-icon {
font-size: 0.75rem;
color: var(--color-text-muted, #4a5c7a);
width: 24px;
text-align: center;
flex-shrink: 0;
}
/* ── Signal badges ── */
.signal-badge {
width: 8px;
height: 8px;
border-radius: 50%;
background: var(--color-warning, #d4891a);
flex-shrink: 0;
display: inline-block;
}
.signal-badge-stowed {
position: absolute;
top: 4px;
right: 4px;
}
/* Make the stowed section header container position:relative for the badge */
.sidebar.stowed .section-header {
position: relative;
justify-content: center;
padding: 0.55rem 0 0.25rem;
}
/* ── Nav divider ── */
.nav-divider {
height: 1px;
background: var(--color-border, #d0d7e8);
margin: 0.5rem 0.75rem;
}
/* ── Nav items ── */
.nav-item {
display: flex;
align-items: center;
@ -238,6 +476,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
border-radius: 0 2px 2px 0;
}
/* Sub-items are indented slightly in expanded state */
.nav-subitem { padding-left: 1.1rem; font-size: 0.875rem; }
.nav-icon {
font-size: 1.1rem;
flex-shrink: 0;
@ -245,12 +486,9 @@ onUnmounted(() => window.removeEventListener('resize', onResize))
text-align: center;
}
.nav-label {
overflow: hidden;
text-overflow: ellipsis;
}
.nav-label { overflow: hidden; text-overflow: ellipsis; }
/* Mobile hamburger — visible when sidebar is stowed on mobile */
/* Mobile hamburger */
.mobile-hamburger {
position: fixed;
top: 0.75rem;

View file

@ -0,0 +1,170 @@
<script setup lang="ts">
import { ref, watch } from 'vue'
import type { CatalogEntryFull } from '../../types/nodes'
const props = defineProps<{
svcName: string
modelName?: string
entry?: CatalogEntryFull
}>()
const emit = defineEmits<{
save: [svcName: string, modelName: string, entry: CatalogEntryFull]
cancel: []
}>()
const name = ref(props.modelName ?? '')
const path = ref(props.entry?.path ?? '')
const vramMb = ref(props.entry?.vram_mb ?? 0)
const description = ref(props.entry?.description ?? '')
const multiGpu = ref(props.entry?.multi_gpu ?? false)
const envPairs = ref<{ k: string; v: string }[]>(
Object.entries(props.entry?.env ?? {}).map(([k, v]) => ({ k, v }))
)
const formError = ref('')
watch(() => props.entry, (e) => {
name.value = props.modelName ?? ''
path.value = e?.path ?? ''
vramMb.value = e?.vram_mb ?? 0
description.value = e?.description ?? ''
multiGpu.value = e?.multi_gpu ?? false
envPairs.value = Object.entries(e?.env ?? {}).map(([k, v]) => ({ k, v }))
})
function addEnvPair() {
envPairs.value = [...envPairs.value, { k: '', v: '' }]
}
function removeEnvPair(i: number) {
envPairs.value = envPairs.value.filter((_, idx) => idx !== i)
}
function submit() {
formError.value = ''
if (!name.value.trim()) { formError.value = 'Model name is required.'; return }
if (!path.value.trim()) { formError.value = 'Path is required.'; return }
if (!vramMb.value || vramMb.value < 0) { formError.value = 'vram_mb must be a positive number.'; return }
const envObj: Record<string, string> = {}
for (const { k, v } of envPairs.value) {
if (k.trim()) envObj[k.trim()] = v
}
const entry: CatalogEntryFull = { path: path.value.trim(), vram_mb: vramMb.value }
if (description.value.trim()) entry.description = description.value.trim()
if (multiGpu.value) entry.multi_gpu = true
if (Object.keys(envObj).length) entry.env = envObj
emit('save', props.svcName, name.value.trim(), entry)
}
</script>
<template>
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${modelName ? 'Edit' : 'Add'} catalog entry`">
<div class="modal-box">
<h3 class="modal-title">{{ modelName ? 'Edit' : 'Add' }} Catalog Entry {{ svcName }}</h3>
<div class="field-row">
<label class="field-label" for="ce-name">Model name</label>
<input id="ce-name" v-model="name" class="field-input" :readonly="!!modelName" placeholder="deepseek-r1-7b" />
</div>
<div class="field-row">
<label class="field-label" for="ce-path">Path</label>
<input id="ce-path" v-model="path" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models/..." />
</div>
<div class="field-row">
<label class="field-label" for="ce-vram">VRAM (MB)</label>
<input id="ce-vram" v-model.number="vramMb" type="number" min="0" class="field-input field-input--sm" />
</div>
<div class="field-row">
<label class="field-label" for="ce-desc">Description</label>
<input id="ce-desc" v-model="description" class="field-input" placeholder="Short description" />
</div>
<div class="field-row field-row--check">
<input id="ce-mgpu" v-model="multiGpu" type="checkbox" />
<label for="ce-mgpu">Multi-GPU span</label>
</div>
<div class="env-section">
<div class="env-header">
<span class="field-label">Env vars</span>
<button type="button" class="btn-link" @click="addEnvPair">+ Add</button>
</div>
<div v-for="(pair, i) in envPairs" :key="i" class="env-row">
<input v-model="pair.k" class="field-input field-input--sm" placeholder="CF_TEXT_4BIT" />
<span>=</span>
<input v-model="pair.v" class="field-input field-input--sm" placeholder="1" />
<button type="button" class="btn-icon" @click="removeEnvPair(i)" aria-label="Remove"></button>
</div>
</div>
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
<div class="modal-actions">
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
<button class="btn-primary" @click="submit">Save</button>
</div>
</div>
</div>
</template>
<style scoped>
.modal-backdrop {
position: fixed; inset: 0;
background: rgba(0,0,0,0.5);
display: flex; align-items: center; justify-content: center;
z-index: 200;
}
.modal-box {
background: var(--color-surface-raised);
border: 1px solid var(--color-border);
border-radius: 8px;
padding: 1.5rem;
width: 100%; max-width: 500px;
max-height: 90vh; overflow-y: auto;
display: flex; flex-direction: column; gap: 0.75rem;
color: var(--color-text);
}
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
.field-row { display: flex; align-items: center; gap: 0.5rem; }
.field-row--check { gap: 0.4rem; color: var(--color-text); }
.field-label { min-width: 8rem; font-size: 0.85rem; color: var(--color-text-muted); }
.field-input {
flex: 1;
background: var(--color-surface-alt);
border: 1px solid var(--color-border);
border-radius: 4px;
padding: 0.3rem 0.5rem;
color: var(--color-text);
font-size: 0.85rem;
}
.field-input--sm { flex: 0 0 8rem; }
.env-section { display: flex; flex-direction: column; gap: 0.35rem; }
.env-header { display: flex; align-items: center; justify-content: space-between; }
.env-row { display: flex; align-items: center; gap: 0.4rem; }
.btn-link { background: none; border: none; color: var(--app-primary); cursor: pointer; font-size: 0.8rem; padding: 0; }
.btn-link:hover { color: var(--app-primary-hover); }
.btn-icon { background: none; border: none; color: var(--color-text-muted); cursor: pointer; padding: 0 0.2rem; font-size: 0.85rem; }
.btn-icon:hover { color: var(--color-error); }
.form-error { color: var(--color-error); font-size: 0.8rem; }
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
.btn-primary {
background: var(--app-primary);
color: var(--color-text-inverse);
border: none;
border-radius: 4px;
padding: 0.4rem 1rem;
cursor: pointer;
font-size: 0.875rem;
}
.btn-primary:hover { background: var(--app-primary-hover); }
.btn-secondary {
background: transparent;
border: 1px solid var(--color-border);
color: var(--color-text);
border-radius: 4px;
padding: 0.4rem 0.75rem;
cursor: pointer;
font-size: 0.875rem;
}
.btn-secondary:hover { background: var(--color-surface-alt); }
</style>

View file

@ -0,0 +1,129 @@
<script setup lang="ts">
import { ref, computed } from 'vue'
import ServiceBadge from './ServiceBadge.vue'
import type { GpuEntry, ServiceInfo } from '../../types/nodes'
const props = defineProps<{
gpu: GpuEntry
nodeId: string
profileLoaded: boolean
servicesCatalog: Record<string, ServiceInfo>
}>()
const emit = defineEmits<{ updated: [] }>()
const saving = ref(false)
const saveError = ref('')
const vramPct = computed(() => {
if (!props.gpu.vram_total_mb) return 0
return Math.round((props.gpu.vram_used_mb / props.gpu.vram_total_mb) * 100)
})
function serviceState(svcName: string): 'running' | 'stopped' | 'assigned-only' | 'available' | 'incompatible' | 'unknown' {
const svc = props.servicesCatalog[svcName]
if (!svc) return 'unknown'
const cap = props.gpu.compute_cap ?? 0
if (cap < svc.min_compute_cap) return 'incompatible'
if (props.gpu.services_running.includes(svcName)) return 'running'
if (props.gpu.services_assigned.includes(svcName)) return 'assigned-only'
return 'available'
}
async function toggleService(svcName: string) {
if (!props.profileLoaded || saving.value) return
const current = [...props.gpu.services_assigned]
const removing = current.includes(svcName)
if (removing && !confirm(`Remove ${svcName} from GPU ${props.gpu.gpu_id}?`)) return
const next = removing ? current.filter(s => s !== svcName) : [...current, svcName]
saving.value = true
saveError.value = ''
try {
const r = await fetch(
`/api/nodes-mgmt/nodes/${props.nodeId}/gpu/${props.gpu.gpu_id}/services`,
{
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ services: next }),
},
)
if (!r.ok) {
const data = await r.json().catch(() => ({}))
throw new Error((data as { detail?: string }).detail ?? `HTTP ${r.status}`)
}
const data = await r.json() as { ok: boolean; reloaded: boolean; warnings: string[] }
if (data.warnings?.length) saveError.value = `Saved (warning: ${data.warnings.join(', ')})`
emit('updated')
} catch (e) {
saveError.value = e instanceof Error ? e.message : 'Failed to update services'
} finally {
saving.value = false
}
}
</script>
<template>
<div class="gpu-row">
<div class="gpu-info">
<span class="gpu-label">GPU {{ gpu.gpu_id }}: {{ gpu.card }}</span>
<span v-if="gpu.compute_cap != null" class="gpu-meta">sm{{ gpu.compute_cap }}</span>
<span v-if="gpu.temp_c != null" class="gpu-meta">{{ gpu.temp_c }}°C</span>
<span v-if="gpu.utilization_pct != null" class="gpu-meta">{{ gpu.utilization_pct }}%</span>
</div>
<div class="vram-wrap">
<div
class="vram-bar"
role="progressbar"
:aria-valuenow="gpu.vram_used_mb"
aria-valuemin="0"
:aria-valuemax="gpu.vram_total_mb"
:aria-label="`VRAM: ${gpu.vram_used_mb} of ${gpu.vram_total_mb} MB used`"
>
<div class="vram-fill" :style="{ width: `${vramPct}%` }" />
</div>
<span class="vram-text">{{ gpu.vram_used_mb }} / {{ gpu.vram_total_mb }} MB ({{ vramPct }}%)</span>
</div>
<div v-if="profileLoaded" class="services-row" aria-label="Service assignments">
<ServiceBadge
v-for="(_, svcName) in servicesCatalog"
:key="String(svcName)"
:service-name="String(svcName)"
:state="serviceState(String(svcName))"
:assigned="gpu.services_assigned.includes(String(svcName))"
:disabled="saving"
@toggle="toggleService(String(svcName))"
/>
</div>
<div v-if="saveError" class="save-msg" role="alert">{{ saveError }}</div>
</div>
</template>
<style scoped>
.gpu-row {
padding: 0.5rem 0.75rem;
border-radius: 4px;
background: var(--color-surface-alt);
display: flex;
flex-direction: column;
gap: 0.4rem;
}
.gpu-info { display: flex; gap: 0.75rem; align-items: center; flex-wrap: wrap; font-size: 0.875rem; }
.gpu-label { font-weight: 500; color: var(--color-text); }
.gpu-meta { color: var(--color-text-muted); font-size: 0.8rem; }
.vram-wrap { display: flex; align-items: center; gap: 0.5rem; }
.vram-bar {
flex: 1;
height: 8px;
background: var(--color-border);
border-radius: 4px;
overflow: hidden;
}
.vram-fill { height: 100%; background: var(--app-primary); transition: width 0.3s; }
.vram-text { font-size: 0.75rem; color: var(--color-text-muted); white-space: nowrap; }
.services-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
.save-msg { color: var(--color-warning); font-size: 0.8rem; }
</style>

View file

@ -0,0 +1,134 @@
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
interface CatalogEntry {
path: string
vram_mb: number
description: string
multi_gpu: boolean
}
interface ServiceProfile {
catalog: Record<string, CatalogEntry>
min_compute_cap: number
max_mb: number
}
interface NodeProfile {
services: Record<string, ServiceProfile>
}
const props = defineProps<{
nodeId: string
}>()
const profile = ref<NodeProfile | null>(null)
const loading = ref(true)
const error = ref('')
let fetchAbort: AbortController | null = null
async function fetchProfile() {
fetchAbort?.abort()
fetchAbort = new AbortController()
loading.value = true
error.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
signal: fetchAbort.signal,
})
if (r.status === 404) { profile.value = null; return }
if (!r.ok) throw new Error(`HTTP ${r.status}`)
profile.value = await r.json() as NodeProfile
} catch (e) {
if (e instanceof Error && e.name === 'AbortError') return
error.value = e instanceof Error ? e.message : 'Failed to load profile'
} finally {
loading.value = false
}
}
onMounted(fetchProfile)
onUnmounted(() => { fetchAbort?.abort() })
</script>
<template>
<section class="hf-panel">
<h3 class="panel-title">Model Catalog</h3>
<p class="hf-hint">
To download a new HuggingFace model,
<a href="#/fleet" class="hf-link">go to Fleet</a>.
Models downloaded there are automatically registered in node catalogs.
</p>
<div aria-live="polite" aria-atomic="true" class="sr-announce">
<span v-if="loading">Loading catalog...</span>
</div>
<div v-if="error" class="panel-error" role="alert">{{ error }}</div>
<div v-else-if="!loading && !profile" class="panel-empty">No profile loaded for this node.</div>
<div v-else-if="!loading && profile" class="catalog-body">
<div
v-for="(svcInfo, svcName) in profile.services"
:key="String(svcName)"
class="svc-section"
>
<h4 class="svc-name">{{ svcName }}</h4>
<ul class="catalog-list" role="list">
<li
v-if="!Object.keys(svcInfo.catalog ?? {}).length"
class="catalog-empty"
>
No models in catalog.
</li>
<li
v-for="(entry, modelName) in (svcInfo.catalog ?? {})"
:key="String(modelName)"
class="catalog-item"
>
<span class="catalog-model">{{ modelName }}</span>
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
</li>
</ul>
</div>
</div>
</section>
</template>
<style scoped>
.hf-panel {
margin-top: 0.75rem;
padding: 0.75rem;
border: 1px solid var(--color-border);
border-radius: 6px;
color: var(--color-text);
}
.panel-title { margin: 0 0 0.5rem; font-size: 0.9rem; color: var(--color-text); }
.hf-hint { font-size: 0.8rem; color: var(--color-text-muted); margin: 0 0 0.75rem; }
.hf-link { color: var(--app-primary); }
.hf-link:hover { color: var(--app-primary-hover); }
.svc-section { margin-bottom: 0.75rem; }
.svc-name {
margin: 0 0 0.25rem;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
color: var(--color-text-muted);
}
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.2rem; }
.catalog-item {
display: flex;
align-items: center;
gap: 0.5rem;
padding: 0.25rem 0.5rem;
background: var(--color-surface-alt);
border-radius: 4px;
font-size: 0.8rem;
}
.catalog-model { font-family: var(--font-mono, monospace); flex: 1; }
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
.catalog-desc { color: var(--color-text-muted); font-size: 0.75rem; flex: 2; }
.catalog-empty, .panel-empty { color: var(--color-text-muted); font-size: 0.875rem; }
.sr-announce { min-height: 1.2em; }
.panel-error { color: var(--color-error); font-size: 0.8rem; }
</style>

View file

@ -0,0 +1,148 @@
<script setup lang="ts">
import { ref } from 'vue'
import GpuRow from './GpuRow.vue'
import OllamaModelPanel from './OllamaModelPanel.vue'
import ProfileEditorPanel from './ProfileEditorPanel.vue'
import type { NodeSummary, FullProfile } from '../../types/nodes'
const props = defineProps<{ node: NodeSummary }>()
const emit = defineEmits<{ updated: [] }>()
const showOllama = ref(false)
const showEditor = ref(false)
const loadedProfile = ref<FullProfile | null>(null)
const profileLoading = ref(false)
const profileError = ref('')
async function openEditor() {
if (showEditor.value) { showEditor.value = false; return }
profileLoading.value = true
profileError.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.node.node_id}/profile`)
if (r.status === 404) {
loadedProfile.value = null
} else if (!r.ok) {
throw new Error(`HTTP ${r.status}`)
} else {
loadedProfile.value = await r.json() as FullProfile
}
showEditor.value = true
} catch (e) {
profileError.value = e instanceof Error ? e.message : 'Failed to load profile'
} finally {
profileLoading.value = false
}
}
function onProfileSaved() {
showEditor.value = false
emit('updated')
}
</script>
<template>
<section class="node-card" :class="{ offline: !node.online }">
<header class="node-card-header">
<div class="node-identity">
<span
class="status-dot"
:class="node.online ? 'online' : 'offline'"
:aria-label="node.online ? 'Online' : 'Offline'"
role="img"
/>
<h2 class="node-name">{{ node.node_id }}</h2>
<span class="node-agent">{{ node.agent_url }}</span>
</div>
<div class="node-actions">
<button
v-if="node.profile_loaded"
class="btn-secondary btn-sm"
@click="showOllama = !showOllama"
>
{{ showOllama ? 'Hide Ollama' : 'Ollama' }}
</button>
<button
class="btn-secondary btn-sm"
:disabled="profileLoading"
@click="openEditor"
>
{{ profileLoading ? 'Loading…' : node.profile_loaded ? (showEditor ? 'Close Editor' : 'Edit Profile') : 'Create Profile' }}
</button>
</div>
</header>
<div v-if="!node.profile_loaded" class="no-profile" role="status">
No profile configured for this node. GPU stats are visible; service assignment is disabled.
</div>
<div class="gpu-list">
<GpuRow
v-for="gpu in node.gpus"
:key="gpu.gpu_id"
:gpu="gpu"
:node-id="node.node_id"
:profile-loaded="node.profile_loaded"
:services-catalog="node.services_catalog"
@updated="emit('updated')"
/>
</div>
<OllamaModelPanel v-if="showOllama" :node-id="node.node_id" />
<div v-if="profileError" class="profile-load-error" role="alert">{{ profileError }}</div>
<ProfileEditorPanel
v-if="showEditor"
:node-id="node.node_id"
:initial-profile="loadedProfile"
@saved="onProfileSaved"
@close="showEditor = false"
/>
</section>
</template>
<style scoped>
.node-card {
border: 1px solid var(--color-border);
border-radius: 8px;
padding: 1rem;
background: var(--color-surface-raised);
color: var(--color-text);
}
.node-card.offline { opacity: 0.65; }
.node-card-header {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: 0.5rem;
margin-bottom: 0.75rem;
}
.node-identity { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; }
.node-name { margin: 0; font-size: 1rem; font-weight: 600; color: var(--color-text); }
.node-agent { color: var(--color-text-muted); font-size: 0.8rem; font-family: var(--font-mono, monospace); }
.status-dot { width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0; }
.status-dot.online { background: var(--color-success); }
.status-dot.offline { background: var(--color-warning); }
.node-actions { display: flex; gap: 0.5rem; flex-shrink: 0; }
.btn-secondary {
background: transparent;
border: 1px solid var(--color-border);
color: var(--color-text);
border-radius: 4px;
padding: 0.3rem 0.65rem;
cursor: pointer;
font-size: 0.8rem;
}
.btn-secondary:hover { background: var(--color-surface-alt); }
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-sm { font-size: 0.8rem; padding: 0.25rem 0.6rem; }
.no-profile {
padding: 0.6rem 0.75rem;
background: var(--color-surface-alt);
border-radius: 4px;
color: var(--color-text-muted);
font-size: 0.875rem;
margin-bottom: 0.5rem;
}
.gpu-list { display: flex; flex-direction: column; gap: 0.5rem; }
.profile-load-error { color: var(--color-error); font-size: 0.8rem; margin-top: 0.5rem; }
</style>

View file

@ -0,0 +1,242 @@
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
const props = defineProps<{ nodeId: string }>()
interface OllamaModel {
name: string
size: number
modified_at: string
}
const models = ref<OllamaModel[]>([])
const loading = ref(true)
const loadError = ref('')
const pullName = ref('')
const pulling = ref(false)
const pullStatus = ref('')
const pullPct = ref(0)
const pullError = ref('')
// AbortController for the SSE pull stream
const abortCtrl = ref<AbortController | null>(null)
// AbortController for the one-shot fetchModels request
let fetchAbort: AbortController | null = null
async function fetchModels() {
fetchAbort?.abort()
fetchAbort = new AbortController()
loading.value = true
loadError.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`, {
signal: fetchAbort.signal,
})
const data = await r.json() as { models?: OllamaModel[]; error?: string }
if (data.error) { loadError.value = data.error; return }
models.value = data.models ?? []
} catch (e) {
if (e instanceof Error && e.name === 'AbortError') return
loadError.value = e instanceof Error ? e.message : 'Failed to load models'
} finally {
loading.value = false
}
}
async function doPull() {
const name = pullName.value.trim()
if (!name || pulling.value) return
pulling.value = true
pullStatus.value = 'Starting...'
pullError.value = ''
pullPct.value = 0
const ctrl = new AbortController()
abortCtrl.value = ctrl
try {
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ name }),
signal: ctrl.signal,
})
if (!resp.ok) throw new Error(`HTTP ${resp.status}`)
if (!resp.body) throw new Error('No response body')
const reader = resp.body.getReader()
const decoder = new TextDecoder()
let buf = ''
while (true) {
const { done, value } = await reader.read()
if (done) break
buf += decoder.decode(value, { stream: true })
const lines = buf.split('\n')
buf = lines.pop() ?? ''
for (const line of lines) {
if (!line.startsWith('data: ')) continue
try {
const evt = JSON.parse(line.slice(6)) as {
status?: string; error?: string; total?: number; completed?: number
}
if (evt.error) {
pullError.value = evt.error
break
}
if (evt.status) pullStatus.value = evt.status
if (evt.total && evt.completed) {
pullPct.value = Math.round((evt.completed / evt.total) * 100)
}
if (evt.status === 'success') {
pullStatus.value = 'Done!'
pullName.value = ''
break
}
} catch { /* skip malformed line */ }
}
}
// Refresh model list after the stream closes (success or benign end)
await fetchModels()
} catch (e) {
if (e instanceof Error && e.name === 'AbortError') return
pullError.value = e instanceof Error ? e.message : 'Pull failed'
} finally {
pulling.value = false
abortCtrl.value = null
}
}
async function deleteModel(name: string) {
if (!confirm(`Delete model "${name}" from node ${props.nodeId}?`)) return
try {
const r = await fetch(
`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/${encodeURIComponent(name)}`,
{ method: 'DELETE' },
)
if (!r.ok) throw new Error(`HTTP ${r.status}`)
await fetchModels()
} catch (e) {
loadError.value = e instanceof Error ? e.message : 'Delete failed'
}
}
function formatSize(bytes: number): string {
return (bytes / 1e9).toFixed(1) + ' GB'
}
onMounted(fetchModels)
onUnmounted(() => {
abortCtrl.value?.abort()
fetchAbort?.abort()
})
</script>
<template>
<section class="ollama-panel">
<h3 class="panel-title">Ollama Models</h3>
<form class="pull-form" @submit.prevent="doPull">
<input
v-model="pullName"
type="text"
placeholder="nomic-embed-text, llama3.2:3b, ..."
:disabled="pulling"
aria-label="Model name to pull from Ollama"
class="pull-input"
/>
<button type="submit" :disabled="pulling || !pullName.trim()" class="btn-primary btn-sm">
{{ pulling ? 'Pulling...' : 'Pull' }}
</button>
</form>
<div v-if="pulling || pullStatus" class="pull-progress" aria-live="polite">
<div
class="progress-bar"
role="progressbar"
:aria-valuenow="pullPct"
aria-valuemin="0"
aria-valuemax="100"
:aria-label="`Pull progress: ${pullStatus}`"
>
<div class="progress-fill" :style="{ width: `${pullPct}%` }" />
</div>
<span class="progress-label">{{ pullStatus }}{{ pullPct > 0 ? ` (${pullPct}%)` : '' }}</span>
</div>
<div v-if="pullError" class="pull-error" role="alert">
{{ pullError }}
<span v-if="pullError.includes('permission denied')">
Remove the partial file on the node and retry.
</span>
</div>
<div aria-live="polite" aria-atomic="true" class="sr-announce">
<span v-if="loading">Loading...</span>
</div>
<div v-if="loadError" class="panel-error" role="alert">{{ loadError }}</div>
<ul v-if="!loading && !loadError" class="model-list" role="list">
<li v-if="!models.length" class="model-empty">No Ollama models installed on this node.</li>
<li v-for="m in models" :key="m.name" class="model-item">
<span class="model-name">{{ m.name }}</span>
<span class="model-size">{{ formatSize(m.size) }}</span>
<button
class="btn-danger btn-xs"
@click="deleteModel(m.name)"
:aria-label="`Delete ${m.name}`"
>
Delete
</button>
</li>
</ul>
</section>
</template>
<style scoped>
.ollama-panel {
margin-top: 0.75rem;
padding: 0.75rem;
border: 1px solid var(--color-border);
border-radius: 6px;
color: var(--color-text);
}
.panel-title { margin: 0 0 0.75rem; font-size: 0.9rem; color: var(--color-text); }
.pull-form { display: flex; gap: 0.5rem; margin-bottom: 0.5rem; }
.pull-input {
flex: 1;
padding: 0.3rem 0.5rem;
background: var(--color-surface-alt);
border: 1px solid var(--color-border);
border-radius: 4px;
color: var(--color-text);
font-size: 0.875rem;
}
.pull-progress { margin-bottom: 0.5rem; }
.progress-bar {
height: 8px;
background: var(--color-border);
border-radius: 4px;
overflow: hidden;
margin-bottom: 0.25rem;
}
.progress-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
.progress-label { font-size: 0.75rem; color: var(--color-text-muted); }
.pull-error, .panel-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
.sr-announce { min-height: 1.2em; }
.panel-loading { color: var(--color-text-muted); font-size: 0.875rem; }
.model-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.3rem; }
.model-item {
display: flex;
align-items: center;
gap: 0.5rem;
padding: 0.3rem 0.5rem;
background: var(--color-surface-alt);
border-radius: 4px;
font-size: 0.875rem;
}
.model-name { flex: 1; font-family: var(--font-mono, monospace); }
.model-size { color: var(--color-text-muted); font-size: 0.8rem; }
.model-empty { color: var(--color-text-muted); font-size: 0.875rem; padding: 0.25rem 0; }
</style>

View file

@ -0,0 +1,597 @@
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import type { FullProfile, ServiceDefinition, CatalogEntryFull } from '../../types/nodes'
import ServiceFormModal from './ServiceFormModal.vue'
import CatalogEntryFormModal from './CatalogEntryFormModal.vue'
const props = defineProps<{
nodeId: string
initialProfile: FullProfile | null
}>()
const emit = defineEmits<{ saved: []; close: [] }>()
// Deep-clone initial profile so edits don't mutate the parent's data
const profile = ref<FullProfile>(
props.initialProfile
? JSON.parse(JSON.stringify(props.initialProfile))
: { services: {}, nodes: {} }
)
const saving = ref(false)
const generating = ref(false)
const opError = ref('')
const expandedSvcs = ref<Set<string>>(new Set())
// Service modal
const showSvcModal = ref(false)
const editingSvcName = ref<string | undefined>()
const editingSvcDef = ref<ServiceDefinition | undefined>()
// Catalog modal
const showCatalogModal = ref(false)
const catalogTargetSvc = ref('')
const editingModelName = ref<string | undefined>()
const editingEntry = ref<CatalogEntryFull | undefined>()
// Generate nodes section from coordinator
async function generate() {
generating.value = true
opError.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile/generate`, { method: 'POST' })
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
const generated = await r.json() as FullProfile
// Merge: keep current services edits, replace nodes section
profile.value = { ...generated, services: profile.value.services }
} catch (e) {
opError.value = e instanceof Error ? e.message : 'Generate failed'
} finally {
generating.value = false
}
}
// Save full profile
async function save() {
saving.value = true
opError.value = ''
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/profile`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ profile: profile.value }),
})
if (!r.ok) { const d = await r.json().catch(() => ({})); throw new Error((d as {detail?: string}).detail ?? `HTTP ${r.status}`) }
emit('saved')
} catch (e) {
opError.value = e instanceof Error ? e.message : 'Save failed'
} finally {
saving.value = false
}
}
// Service CRUD
function openAddService() {
editingSvcName.value = undefined
editingSvcDef.value = undefined
showSvcModal.value = true
}
function openEditService(name: string) {
editingSvcName.value = name
editingSvcDef.value = JSON.parse(JSON.stringify(profile.value.services[name]))
showSvcModal.value = true
}
function onServiceSaved(name: string, def: ServiceDefinition) {
profile.value = { ...profile.value, services: { ...profile.value.services, [name]: def } }
expandedSvcs.value = new Set([...expandedSvcs.value, name])
showSvcModal.value = false
}
function deleteService(name: string) {
if (!confirm(`Remove service "${name}" from this profile?`)) return
const svcs = { ...profile.value.services }
delete svcs[name]
profile.value = { ...profile.value, services: svcs }
expandedSvcs.value = new Set([...expandedSvcs.value].filter(s => s !== name))
}
function toggleSvc(name: string) {
const s = new Set(expandedSvcs.value)
s.has(name) ? s.delete(name) : s.add(name)
expandedSvcs.value = s
}
// Catalog CRUD
function openAddCatalogEntry(svcName: string) {
catalogTargetSvc.value = svcName
editingModelName.value = undefined
editingEntry.value = undefined
showCatalogModal.value = true
}
function openEditCatalogEntry(svcName: string, modelName: string) {
catalogTargetSvc.value = svcName
editingModelName.value = modelName
editingEntry.value = JSON.parse(JSON.stringify(profile.value.services[svcName].catalog![modelName]))
showCatalogModal.value = true
}
function onCatalogEntrySaved(svcName: string, modelName: string, entry: CatalogEntryFull) {
const svcs = { ...profile.value.services }
const svc = { ...svcs[svcName], catalog: { ...(svcs[svcName].catalog ?? {}), [modelName]: entry } }
svcs[svcName] = svc
profile.value = { ...profile.value, services: svcs }
showCatalogModal.value = false
}
function deleteCatalogEntry(svcName: string, modelName: string) {
if (!confirm(`Remove model "${modelName}" from ${svcName} catalog?`)) return
const svcs = { ...profile.value.services }
const catalog = { ...(svcs[svcName].catalog ?? {}) }
delete catalog[modelName]
svcs[svcName] = { ...svcs[svcName], catalog }
profile.value = { ...profile.value, services: svcs }
}
// Helpers
function gpuList() {
return (profile.value.nodes[props.nodeId]?.gpus ?? [])
}
function serviceCount() {
return Object.keys(profile.value.services).length
}
// Ollama model suggestions
interface OllamaModel { name: string; size: number }
const ollamaModels = ref<OllamaModel[]>([])
const ollamaLoading = ref(false)
onMounted(async () => {
ollamaLoading.value = true
try {
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
if (r.ok) {
const d = await r.json() as { models?: OllamaModel[] }
ollamaModels.value = d.models ?? []
}
} catch { /* Ollama offline — silently skip */ }
finally { ollamaLoading.value = false }
})
function ollamaNotInCatalog(svcName: string): OllamaModel[] {
const catalog = profile.value.services[svcName]?.catalog ?? {}
return ollamaModels.value.filter(m => !(m.name in catalog))
}
function openAddFromOllama(svcName: string, modelName: string) {
catalogTargetSvc.value = svcName
editingModelName.value = modelName
editingEntry.value = {
path: profile.value.services[svcName]?.model_base_path
? `${profile.value.services[svcName].model_base_path}/${modelName}`
: '',
vram_mb: 0,
}
showCatalogModal.value = true
}
function formatMb(bytes: number): string {
return bytes >= 1_000_000_000
? `${(bytes / 1_073_741_824).toFixed(1)} GB`
: `${Math.round(bytes / 1_048_576)} MB`
}
// Pull model onto node
const pullName = ref('')
const pulling = ref(false)
const pullStatus = ref('')
const pullPct = ref(0)
const pullError = ref('')
let pullAbort: AbortController | null = null
async function doPull() {
const name = pullName.value.trim()
if (!name || pulling.value) return
pulling.value = true
pullStatus.value = 'Starting…'
pullError.value = ''
pullPct.value = 0
pullAbort?.abort()
pullAbort = new AbortController()
try {
const resp = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama/pull`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ name }),
signal: pullAbort.signal,
})
if (!resp.ok || !resp.body) {
pullError.value = `HTTP ${resp.status}`
return
}
const reader = resp.body.getReader()
const dec = new TextDecoder()
let buf = ''
while (true) {
const { done, value } = await reader.read()
if (done) break
buf += dec.decode(value, { stream: true })
const lines = buf.split('\n')
buf = lines.pop() ?? ''
for (const line of lines) {
if (!line.startsWith('data:')) continue
try {
const d = JSON.parse(line.slice(5)) as {
status?: string; completed?: number; total?: number; error?: string; done?: boolean
}
if (d.error) { pullError.value = d.error; return }
pullStatus.value = d.status ?? ''
if (d.total && d.total > 0) pullPct.value = Math.round((d.completed ?? 0) / d.total * 100)
if (d.done) {
pullName.value = ''
pullPct.value = 100
// Refresh Ollama model list so new model appears in suggest chips
const r = await fetch(`/api/nodes-mgmt/nodes/${props.nodeId}/models/ollama`)
if (r.ok) { const d2 = await r.json() as { models?: OllamaModel[] }; ollamaModels.value = d2.models ?? [] }
}
} catch { /* skip malformed SSE line */ }
}
}
} catch (e) {
if (e instanceof Error && e.name !== 'AbortError') pullError.value = e.message
} finally {
pulling.value = false
if (pullPct.value === 100) setTimeout(() => { pullStatus.value = ''; pullPct.value = 0 }, 2000)
}
}
</script>
<template>
<section class="pep" aria-label="Profile editor">
<!-- Header -->
<div class="pep-header">
<div class="pep-title-row">
<h3 class="pep-title">Profile {{ nodeId }}</h3>
<span class="pep-svc-count">{{ serviceCount() }} service{{ serviceCount() === 1 ? '' : 's' }}</span>
</div>
<div class="pep-actions">
<button class="btn-secondary btn-sm" :disabled="generating" @click="generate">
{{ generating ? 'Refreshing…' : 'Refresh Hardware' }}
</button>
<button class="btn-primary btn-sm" :disabled="saving" @click="save">
{{ saving ? 'Saving…' : 'Save Profile' }}
</button>
<button class="btn-icon-lg" aria-label="Close editor" @click="emit('close')"></button>
</div>
</div>
<div v-if="opError" class="pep-error" role="alert">{{ opError }}</div>
<!-- Meta fields -->
<div class="pep-meta">
<label class="meta-label" for="pep-vram">vram_total_mb</label>
<input id="pep-vram" v-model.number="profile.vram_total_mb" type="number" min="0" class="meta-input" />
<label class="meta-label" for="pep-evict">eviction_timeout_s</label>
<input id="pep-evict" v-model.number="profile.eviction_timeout_s" type="number" min="0" step="0.5" class="meta-input" />
</div>
<!-- Hardware summary -->
<div v-if="gpuList().length" class="hw-section">
<span class="hw-label">Hardware</span>
<span v-for="g in gpuList()" :key="g.id" class="hw-gpu">
GPU {{ g.id }}: {{ g.card || 'unknown' }} · {{ g.vram_mb }} MB · sm{{ g.compute_cap ?? '?' }}
</span>
<span v-if="!gpuList().length" class="hw-none">No hardware data click Refresh Hardware.</span>
</div>
<div v-else class="hw-section">
<span class="hw-none">No hardware data click Refresh Hardware to seed from coordinator.</span>
</div>
<!-- Services -->
<div class="svcs-header">
<span class="svcs-title">Services</span>
<button class="btn-secondary btn-sm" @click="openAddService">+ Add Service</button>
</div>
<div v-if="serviceCount() === 0" class="svcs-empty">
No services defined. Add a service to configure what can run on this node.
</div>
<ul class="svcs-list" role="list">
<li
v-for="(def, svcName) in profile.services"
:key="String(svcName)"
class="svc-item"
>
<!-- Service row header -->
<div class="svc-row">
<button
class="svc-toggle"
:aria-expanded="expandedSvcs.has(String(svcName))"
@click="toggleSvc(String(svcName))"
>
<span class="svc-arrow">{{ expandedSvcs.has(String(svcName)) ? '▾' : '▸' }}</span>
<span class="svc-name">{{ svcName }}</span>
</button>
<span class="svc-badges">
<span class="badge">{{ def.max_mb }} MB</span>
<span class="badge">p{{ def.priority }}</span>
<span v-if="def.shared" class="badge badge--blue">shared</span>
<span v-if="def.managed" class="badge badge--dim">managed</span>
<span v-if="def.catalog" class="badge badge--dim">{{ Object.keys(def.catalog).length }} models</span>
</span>
<div class="svc-btns">
<button class="btn-secondary btn-xs" @click="openEditService(String(svcName))">Edit</button>
<button class="btn-danger btn-xs" @click="deleteService(String(svcName))">Delete</button>
</div>
</div>
<!-- Expanded catalog -->
<div v-if="expandedSvcs.has(String(svcName))" class="svc-detail">
<div class="svc-detail-meta">
<span v-if="def.min_compute_cap">min sm{{ def.min_compute_cap }}</span>
<span v-if="def.max_concurrent">max_concurrent: {{ def.max_concurrent }}</span>
<span v-if="def.idle_stop_after_s">idle_stop: {{ def.idle_stop_after_s }}s</span>
<span v-if="def.always_on" class="badge badge--blue">always_on</span>
</div>
<!-- Ollama model suggestions + pull -->
<div class="ollama-suggest">
<div class="suggest-row">
<span class="suggest-label">On node (Ollama):</span>
<span v-if="ollamaLoading" class="suggest-loading">loading</span>
<template v-else-if="ollamaNotInCatalog(String(svcName)).length">
<button
v-for="m in ollamaNotInCatalog(String(svcName))"
:key="m.name"
class="suggest-chip"
:title="`Add ${m.name} (${formatMb(m.size)}) to this service catalog`"
@click="openAddFromOllama(String(svcName), m.name)"
>
+ {{ m.name }} <span class="chip-size">{{ formatMb(m.size) }}</span>
</button>
</template>
<span v-else-if="!ollamaLoading" class="suggest-none">All Ollama models already in catalog.</span>
</div>
<!-- Pull model onto this node -->
<div class="pull-row">
<input
v-model="pullName"
class="pull-input"
placeholder="Pull model on node (e.g. llama3:8b)"
:disabled="pulling"
@keyup.enter="doPull"
/>
<button class="btn-pull" :disabled="pulling || !pullName.trim()" @click="doPull">
{{ pulling ? 'Pulling…' : 'Pull' }}
</button>
</div>
<div v-if="pulling || pullPct > 0" class="pull-progress">
<div class="pull-bar"><div class="pull-fill" :style="{ width: pullPct + '%' }" /></div>
<span class="pull-status">{{ pullStatus }}</span>
</div>
<div v-if="pullError" class="pull-err" role="alert">{{ pullError }}</div>
</div>
<div class="catalog-header">
<span class="catalog-title">Catalog</span>
<button class="btn-link" @click="openAddCatalogEntry(String(svcName))">+ Add Model</button>
</div>
<div v-if="!def.catalog || !Object.keys(def.catalog).length" class="catalog-empty">
No catalog entries. Only services like cf-text need a catalog.
</div>
<ul v-else class="catalog-list" role="list">
<li
v-for="(entry, modelName) in def.catalog"
:key="String(modelName)"
class="catalog-item"
>
<span class="catalog-model">{{ modelName }}</span>
<span class="catalog-vram">{{ entry.vram_mb }} MB</span>
<span v-if="entry.multi_gpu" class="badge badge--dim">multi-gpu</span>
<span v-if="entry.description" class="catalog-desc">{{ entry.description }}</span>
<div class="catalog-btns">
<button class="btn-secondary btn-xs" @click="openEditCatalogEntry(String(svcName), String(modelName))">Edit</button>
<button class="btn-danger btn-xs" @click="deleteCatalogEntry(String(svcName), String(modelName))"></button>
</div>
</li>
</ul>
</div>
</li>
</ul>
</section>
<!-- Service form modal -->
<ServiceFormModal
v-if="showSvcModal"
:service-name="editingSvcName"
:definition="editingSvcDef"
@save="onServiceSaved"
@cancel="showSvcModal = false"
/>
<!-- Catalog entry form modal -->
<CatalogEntryFormModal
v-if="showCatalogModal"
:svc-name="catalogTargetSvc"
:model-name="editingModelName"
:entry="editingEntry"
@save="onCatalogEntrySaved"
@cancel="showCatalogModal = false"
/>
</template>
<style scoped>
.pep {
margin-top: 0.75rem;
padding: 1rem;
border: 1px solid var(--color-primary);
border-radius: 6px;
background: var(--color-surface-raised);
color: var(--color-text);
}
.pep-header {
display: flex; align-items: center; justify-content: space-between; gap: 0.5rem;
margin-bottom: 0.75rem; flex-wrap: wrap;
}
.pep-title-row { display: flex; align-items: baseline; gap: 0.5rem; }
.pep-title { margin: 0; font-size: 0.95rem; font-weight: 600; color: var(--color-text); }
.pep-svc-count { font-size: 0.75rem; color: var(--color-text-muted); }
.pep-actions { display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap; }
.pep-error { color: var(--color-error); font-size: 0.8rem; margin-bottom: 0.5rem; }
.pep-meta {
display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap;
padding: 0.5rem; background: var(--color-surface-alt); border-radius: 4px; margin-bottom: 0.75rem;
}
.meta-label { font-size: 0.8rem; color: var(--color-text-muted); }
.meta-input {
width: 7rem; background: var(--color-surface); border: 1px solid var(--color-border);
border-radius: 4px; padding: 0.2rem 0.4rem; color: var(--color-text); font-size: 0.8rem;
}
.hw-section {
display: flex; flex-wrap: wrap; align-items: center; gap: 0.5rem;
font-size: 0.8rem; color: var(--color-text-muted);
padding: 0.4rem 0.5rem; border-radius: 4px; background: var(--color-surface-alt);
margin-bottom: 0.75rem;
}
.hw-label { font-weight: 600; color: var(--color-text); }
.hw-gpu { font-family: monospace; color: var(--color-text); }
.hw-none { font-style: italic; }
.svcs-header {
display: flex; align-items: center; justify-content: space-between;
margin-bottom: 0.5rem;
}
.svcs-title { font-size: 0.85rem; font-weight: 600; color: var(--color-text); }
.svcs-empty { color: var(--color-text-muted); font-size: 0.85rem; padding: 0.5rem 0; }
.svcs-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.4rem; }
.svc-item { border: 1px solid var(--color-border); border-radius: 4px; overflow: hidden; }
.svc-row {
display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0.5rem;
background: var(--color-surface-alt); flex-wrap: wrap;
}
.svc-toggle {
display: flex; align-items: center; gap: 0.35rem;
background: none; border: none; cursor: pointer; color: var(--color-text); padding: 0; flex: 1; min-width: 0;
}
.svc-arrow { font-size: 0.7rem; color: var(--color-text-muted); }
.svc-name { font-size: 0.875rem; font-weight: 500; font-family: monospace; }
.svc-badges { display: flex; gap: 0.3rem; flex-wrap: wrap; }
.svc-btns { display: flex; gap: 0.3rem; margin-left: auto; }
.svc-detail { padding: 0.5rem 0.75rem; display: flex; flex-direction: column; gap: 0.5rem; background: var(--color-surface-raised); }
.svc-detail-meta {
display: flex; gap: 0.5rem; flex-wrap: wrap;
font-size: 0.78rem; color: var(--color-text-muted);
}
.ollama-suggest {
display: flex; flex-direction: column; gap: 0.35rem;
padding: 0.4rem 0.5rem;
background: var(--color-primary-light);
border: 1px solid var(--color-border-light);
border-radius: 4px;
font-size: 0.78rem;
}
.suggest-row { display: flex; flex-wrap: wrap; align-items: center; gap: 0.35rem; }
.suggest-label { color: var(--color-text-muted); font-weight: 500; white-space: nowrap; }
.suggest-loading { color: var(--color-text-muted); font-style: italic; }
.suggest-none { color: var(--color-text-muted); font-style: italic; }
.suggest-chip {
display: inline-flex; align-items: center; gap: 0.25rem;
padding: 0.15rem 0.45rem;
background: var(--color-surface-raised);
border: 1px solid var(--color-border);
border-radius: 3px;
color: var(--color-text);
cursor: pointer;
font-size: 0.78rem;
transition: border-color 0.15s, background 0.15s;
}
.suggest-chip:hover { border-color: var(--app-primary); background: var(--color-surface-alt); }
.chip-size { color: var(--color-text-muted); font-size: 0.72rem; }
.pull-row { display: flex; gap: 0.4rem; align-items: center; }
.pull-input {
flex: 1;
padding: 0.25rem 0.5rem;
background: var(--color-surface-raised);
border: 1px solid var(--color-border);
border-radius: 4px;
color: var(--color-text);
font-size: 0.78rem;
font-family: var(--font-mono, monospace);
}
.pull-input:disabled { opacity: 0.5; }
.btn-pull {
padding: 0.25rem 0.6rem;
background: var(--app-primary);
color: var(--color-text-inverse);
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 0.78rem;
white-space: nowrap;
}
.btn-pull:hover:not(:disabled) { background: var(--app-primary-hover); }
.btn-pull:disabled { opacity: 0.5; cursor: not-allowed; }
.pull-progress { display: flex; align-items: center; gap: 0.4rem; }
.pull-bar {
flex: 1; height: 6px;
background: var(--color-border);
border-radius: 3px; overflow: hidden;
}
.pull-fill { height: 100%; background: var(--app-primary); transition: width 0.2s; }
.pull-status { color: var(--color-text-muted); font-size: 0.72rem; white-space: nowrap; max-width: 14rem; overflow: hidden; text-overflow: ellipsis; }
.pull-err { color: var(--color-error); font-size: 0.75rem; }
.catalog-header { display: flex; align-items: center; justify-content: space-between; }
.catalog-title { font-size: 0.8rem; font-weight: 600; color: var(--color-text-muted); text-transform: uppercase; letter-spacing: 0.05em; }
.catalog-empty { font-size: 0.8rem; color: var(--color-text-muted); font-style: italic; }
.catalog-list { list-style: none; margin: 0; padding: 0; display: flex; flex-direction: column; gap: 0.25rem; }
.catalog-item {
display: flex; align-items: center; gap: 0.4rem; flex-wrap: wrap;
padding: 0.25rem 0.5rem; background: var(--color-surface-alt); border-radius: 3px; font-size: 0.8rem;
color: var(--color-text);
}
.catalog-model { font-family: monospace; flex: 1; min-width: 12rem; }
.catalog-vram { color: var(--color-text-muted); white-space: nowrap; }
.catalog-desc { color: var(--color-text-muted); flex: 2; font-size: 0.75rem; }
.catalog-btns { display: flex; gap: 0.25rem; margin-left: auto; }
.badge {
padding: 0.1rem 0.4rem; border-radius: 3px; font-size: 0.72rem;
background: var(--color-surface); border: 1px solid var(--color-border); color: var(--color-text);
}
.badge--blue { border-color: var(--color-primary); color: var(--color-primary); background: var(--color-primary-light); }
.badge--dim { opacity: 0.75; }
.btn-link { background: none; border: none; color: var(--color-accent); cursor: pointer; font-size: 0.8rem; padding: 0; }
.btn-link:hover { color: var(--color-accent-hover); }
.btn-primary {
background: var(--color-primary); color: var(--color-text-inverse); border: none;
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
}
.btn-primary:hover { background: var(--color-primary-hover); }
.btn-primary:disabled { opacity: 0.6; cursor: not-allowed; }
.btn-secondary {
background: transparent; border: 1px solid var(--color-border); color: var(--color-text);
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
}
.btn-secondary:hover { background: var(--color-surface-alt); }
.btn-secondary:disabled { opacity: 0.6; cursor: not-allowed; }
.btn-danger {
background: transparent; border: 1px solid var(--color-error); color: var(--color-error);
border-radius: 4px; cursor: pointer; font-size: 0.8rem;
}
.btn-danger:hover { background: var(--color-surface-alt); }
.btn-sm { padding: 0.3rem 0.6rem; }
.btn-xs { padding: 0.15rem 0.4rem; }
.btn-icon-lg { background: none; border: none; color: var(--color-text-muted); cursor: pointer; font-size: 1rem; padding: 0.2rem 0.3rem; }
.btn-icon-lg:hover { color: var(--color-text); }
</style>

View file

@ -0,0 +1,82 @@
<script setup lang="ts">
type ServiceState =
| 'running'
| 'stopped'
| 'assigned-only'
| 'available'
| 'incompatible'
| 'vram-tight'
| 'unknown'
const props = defineProps<{
serviceName: string
state: ServiceState
assigned: boolean
disabled?: boolean
}>()
const emit = defineEmits<{ toggle: [] }>()
const STATE_LABELS: Record<ServiceState, string> = {
running: 'Running',
stopped: 'Stopped',
'assigned-only': 'Assigned',
available: 'Available',
incompatible: 'Incompatible',
'vram-tight': 'VRAM tight',
unknown: 'Unknown',
}
const STATE_ICONS: Record<ServiceState, string> = {
running: '▶',
stopped: '⏹',
'assigned-only': '📌',
available: '○',
incompatible: '✕',
'vram-tight': '⚠',
unknown: '?',
}
function handleToggle() {
if (!props.disabled && props.state !== 'incompatible') emit('toggle')
}
</script>
<template>
<button
class="service-badge"
:class="[`state-${state}`, { assigned, 'is-disabled': disabled || state === 'incompatible' }]"
:aria-pressed="assigned"
:aria-label="`${serviceName}: ${STATE_LABELS[state] ?? state}${assigned ? ' (assigned)' : ''}`"
:disabled="disabled || state === 'incompatible'"
@click="handleToggle"
>
<span class="badge-icon" aria-hidden="true">{{ STATE_ICONS[state] ?? '?' }}</span>
<span class="badge-name">{{ serviceName }}</span>
<span class="badge-state">{{ STATE_LABELS[state] ?? state }}</span>
</button>
</template>
<style scoped>
.service-badge {
display: inline-flex;
align-items: center;
gap: 0.3rem;
padding: 0.2rem 0.5rem;
border-radius: 4px;
border: 1px solid var(--color-border);
background: var(--color-surface);
color: var(--color-text);
font-size: 0.75rem;
cursor: pointer;
transition: opacity 0.1s, border-color 0.1s;
}
.service-badge:hover:not(.is-disabled) { opacity: 0.8; }
.service-badge.is-disabled { cursor: not-allowed; opacity: 0.5; }
.service-badge.state-running { border-color: var(--color-success); }
.service-badge.state-stopped { border-color: var(--color-warning); }
.service-badge.state-assigned-only { border-color: var(--color-info); }
.service-badge.state-incompatible { border-color: var(--color-error); }
.service-badge.state-vram-tight { border-color: var(--color-warning); }
.badge-state { color: var(--color-text-muted); }
</style>

View file

@ -0,0 +1,231 @@
<script setup lang="ts">
import { ref, watch, computed } from 'vue'
import type { ServiceDefinition } from '../../types/nodes'
const props = defineProps<{
serviceName?: string
definition?: ServiceDefinition
}>()
const emit = defineEmits<{
save: [name: string, def: ServiceDefinition]
cancel: []
}>()
const name = ref(props.serviceName ?? '')
const maxMb = ref(props.definition?.max_mb ?? 0)
const priority = ref(props.definition?.priority ?? 1)
const minCap = ref(props.definition?.min_compute_cap ?? 0)
const prefCap = ref<number | ''>(props.definition?.preferred_compute_cap ?? '')
const shared = ref(props.definition?.shared ?? false)
const maxConcurrent = ref<number | ''>(props.definition?.max_concurrent ?? '')
const idleStop = ref<number | ''>(props.definition?.idle_stop_after_s ?? '')
const alwaysOn = ref(props.definition?.always_on ?? false)
const modelBasePath = ref(props.definition?.model_base_path ?? '')
const hasManaged = ref(!!props.definition?.managed)
const managedJson = ref(
props.definition?.managed ? JSON.stringify(props.definition.managed, null, 2) : ''
)
const formError = ref('')
watch(() => props.definition, (d) => {
name.value = props.serviceName ?? ''
maxMb.value = d?.max_mb ?? 0
priority.value = d?.priority ?? 1
minCap.value = d?.min_compute_cap ?? 0
prefCap.value = d?.preferred_compute_cap ?? ''
shared.value = d?.shared ?? false
maxConcurrent.value = d?.max_concurrent ?? ''
idleStop.value = d?.idle_stop_after_s ?? ''
alwaysOn.value = d?.always_on ?? false
modelBasePath.value = d?.model_base_path ?? ''
hasManaged.value = !!d?.managed
managedJson.value = d?.managed ? JSON.stringify(d.managed, null, 2) : ''
})
const managedJsonError = computed(() => {
if (!hasManaged.value || !managedJson.value.trim()) return ''
try { JSON.parse(managedJson.value); return '' }
catch { return 'Invalid JSON' }
})
function submit() {
formError.value = ''
if (!name.value.trim()) { formError.value = 'Service name is required.'; return }
if (!maxMb.value || maxMb.value <= 0) { formError.value = 'max_mb must be > 0.'; return }
if (managedJsonError.value) { formError.value = 'Fix the managed JSON before saving.'; return }
const def: ServiceDefinition = { max_mb: maxMb.value, priority: priority.value }
if (minCap.value) def.min_compute_cap = minCap.value
if (prefCap.value !== '') def.preferred_compute_cap = Number(prefCap.value)
if (shared.value) def.shared = true
if (maxConcurrent.value !== '') def.max_concurrent = Number(maxConcurrent.value)
if (idleStop.value !== '') def.idle_stop_after_s = Number(idleStop.value)
if (alwaysOn.value) def.always_on = true
if (modelBasePath.value.trim()) def.model_base_path = modelBasePath.value.trim()
if (hasManaged.value && managedJson.value.trim()) {
def.managed = JSON.parse(managedJson.value)
}
// Preserve existing catalog when editing
if (props.definition?.catalog) def.catalog = props.definition.catalog
emit('save', name.value.trim(), def)
}
</script>
<template>
<div class="modal-backdrop" role="dialog" aria-modal="true" :aria-label="`${serviceName ? 'Edit' : 'Add'} service`">
<div class="modal-box">
<h3 class="modal-title">{{ serviceName ? 'Edit' : 'Add' }} Service</h3>
<div class="field-row">
<label class="field-label" for="sf-name">Service name</label>
<input id="sf-name" v-model="name" class="field-input" :readonly="!!serviceName" placeholder="cf-text" />
</div>
<div class="field-row">
<label class="field-label" for="sf-maxmb">max_mb</label>
<input id="sf-maxmb" v-model.number="maxMb" type="number" min="0" class="field-input field-input--sm" />
<span class="field-hint">VRAM ceiling</span>
</div>
<div class="field-row">
<label class="field-label" for="sf-prio">priority</label>
<input id="sf-prio" v-model.number="priority" type="number" min="1" max="10" class="field-input field-input--sm" />
<span class="field-hint">1 = highest</span>
</div>
<div class="field-row">
<label class="field-label" for="sf-mincap">min_compute_cap</label>
<input id="sf-mincap" v-model.number="minCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="0.0" />
</div>
<div class="field-row">
<label class="field-label" for="sf-prefcap">preferred_cap</label>
<input id="sf-prefcap" v-model="prefCap" type="number" step="0.1" min="0" class="field-input field-input--sm" placeholder="optional" />
</div>
<div class="field-row field-row--check">
<input id="sf-shared" v-model="shared" type="checkbox" />
<label for="sf-shared">shared (multiple concurrent users)</label>
</div>
<div class="field-row">
<label class="field-label" for="sf-maxcon">max_concurrent</label>
<input id="sf-maxcon" v-model="maxConcurrent" type="number" min="1" class="field-input field-input--sm" placeholder="optional" />
</div>
<div class="field-row">
<label class="field-label" for="sf-idle">idle_stop_after_s</label>
<input id="sf-idle" v-model="idleStop" type="number" min="0" class="field-input field-input--sm" placeholder="optional" />
<span class="field-hint">seconds</span>
</div>
<div class="field-row field-row--check">
<input id="sf-always" v-model="alwaysOn" type="checkbox" />
<label for="sf-always">always_on (never evict)</label>
</div>
<div class="field-row">
<label class="field-label" for="sf-base">model_base_path</label>
<input id="sf-base" v-model="modelBasePath" class="field-input" placeholder="/devl/Assets/LLM/cf-text/models (optional)" />
</div>
<div class="managed-section">
<div class="field-row field-row--check">
<input id="sf-has-managed" v-model="hasManaged" type="checkbox" />
<label for="sf-has-managed">Has managed process config</label>
</div>
<div v-if="hasManaged" class="managed-body">
<label class="field-label" for="sf-managed">managed (JSON)</label>
<textarea
id="sf-managed"
v-model="managedJson"
class="field-textarea"
rows="6"
spellcheck="false"
placeholder='{"type": "process", "exec_path": "...", "args_template": "...", "port": 8008, "host_port": 8008}'
/>
<span v-if="managedJsonError" class="json-error" role="alert">{{ managedJsonError }}</span>
</div>
</div>
<div v-if="formError" class="form-error" role="alert">{{ formError }}</div>
<div class="modal-actions">
<button class="btn-secondary" @click="emit('cancel')">Cancel</button>
<button class="btn-primary" @click="submit">Save</button>
</div>
</div>
</div>
</template>
<style scoped>
.modal-backdrop {
position: fixed; inset: 0;
background: rgba(0,0,0,0.5);
display: flex; align-items: center; justify-content: center;
z-index: 200;
}
.modal-box {
background: var(--color-surface-raised);
border: 1px solid var(--color-border);
border-radius: 8px;
padding: 1.5rem;
width: 100%; max-width: 540px;
max-height: 90vh; overflow-y: auto;
display: flex; flex-direction: column; gap: 0.65rem;
color: var(--color-text);
}
.modal-title { margin: 0 0 0.25rem; font-size: 1rem; font-weight: 600; color: var(--color-text); }
.field-row { display: flex; align-items: center; gap: 0.5rem; }
.field-row--check { gap: 0.4rem; font-size: 0.875rem; color: var(--color-text); }
.field-label { min-width: 9rem; font-size: 0.85rem; color: var(--color-text-muted); flex-shrink: 0; }
.field-hint { font-size: 0.75rem; color: var(--color-text-muted); }
.field-input {
flex: 1;
background: var(--color-surface-alt);
border: 1px solid var(--color-border);
border-radius: 4px;
padding: 0.3rem 0.5rem;
color: var(--color-text);
font-size: 0.85rem;
}
.field-input--sm { flex: 0 0 8rem; }
.managed-section { display: flex; flex-direction: column; gap: 0.4rem; border-top: 1px solid var(--color-border); padding-top: 0.5rem; }
.managed-body { display: flex; flex-direction: column; gap: 0.3rem; }
.field-textarea {
width: 100%;
background: var(--color-surface-alt);
border: 1px solid var(--color-border);
border-radius: 4px;
padding: 0.4rem 0.5rem;
color: var(--color-text);
font-size: 0.8rem;
font-family: var(--font-mono, monospace);
resize: vertical;
box-sizing: border-box;
}
.json-error { color: var(--color-error); font-size: 0.78rem; }
.form-error { color: var(--color-error); font-size: 0.8rem; }
.modal-actions { display: flex; justify-content: flex-end; gap: 0.5rem; margin-top: 0.25rem; }
.btn-primary {
background: var(--app-primary);
color: var(--color-text-inverse);
border: none;
border-radius: 4px;
padding: 0.4rem 1rem;
cursor: pointer;
font-size: 0.875rem;
}
.btn-primary:hover { background: var(--app-primary-hover); }
.btn-secondary {
background: transparent;
border: 1px solid var(--color-border);
color: var(--color-text);
border-radius: 4px;
padding: 0.4rem 0.75rem;
cursor: pointer;
font-size: 0.875rem;
}
.btn-secondary:hover { background: var(--color-surface-alt); }
</style>

View file

@ -1,25 +1,53 @@
import { createRouter, createWebHashHistory } from 'vue-router'
import LabelView from '../views/LabelView.vue'
// Views are lazy-loaded to keep initial bundle small
// Lazy-loaded views
const DashboardView = () => import('../views/DashboardView.vue')
const LabelView = () => import('../views/LabelView.vue')
const FetchView = () => import('../views/FetchView.vue')
const StatsView = () => import('../views/StatsView.vue')
const BenchmarkView = () => import('../views/BenchmarkView.vue')
const SettingsView = () => import('../views/SettingsView.vue')
const CorrectionsView = () => import('../views/CorrectionsView.vue')
const ModelsView = () => import('../views/ModelsView.vue')
const ImitateView = () => import('../views/ImitateView.vue')
const BenchmarkView = () => import('../views/BenchmarkView.vue')
const CompareView = () => import('../views/CompareView.vue')
const TrainJobsView = () => import('../views/TrainJobsView.vue')
const TrainResultsView = () => import('../views/TrainResultsView.vue')
const ModelsView = () => import('../views/ModelsView.vue')
const SettingsView = () => import('../views/SettingsView.vue')
const NodeManagementView = () => import('../views/NodeManagementView.vue')
export const routes = [
// ── Top-level ────────────────────────────────────────────
{ path: '/', component: DashboardView, meta: { title: 'Dashboard' } },
{ path: '/fleet', component: ModelsView, meta: { title: 'Fleet' } },
{ path: '/nodes', component: NodeManagementView, meta: { title: 'Nodes' } },
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
// ── Data domain ──────────────────────────────────────────
{ path: '/data/label', component: LabelView, meta: { title: 'Label' } },
{ path: '/data/fetch', component: FetchView, meta: { title: 'Fetch' } },
{ path: '/data/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
{ path: '/data/imitate', component: ImitateView, meta: { title: 'Imitate' } },
{ path: '/data/recipe-scan', component: () => import('../views/RecipeScanView.vue'), meta: { title: 'Recipe Scan' } },
// ── Eval domain ──────────────────────────────────────────
{ path: '/eval/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
{ path: '/eval/compare', component: CompareView, meta: { title: 'Compare' } },
{ path: '/eval/embed-compare', component: () => import('../views/EmbedCompareView.vue'), meta: { title: 'Embed Compare' } },
// ── Train domain ─────────────────────────────────────────
{ path: '/train/jobs', component: TrainJobsView, meta: { title: 'Training Jobs' } },
{ path: '/train/results', component: TrainResultsView, meta: { title: 'Training Results' } },
// ── Backward-compat redirects ────────────────────────────
{ path: '/benchmark', redirect: '/eval/benchmark' },
{ path: '/models', redirect: '/fleet' },
{ path: '/stats', redirect: '/' },
{ path: '/label', redirect: '/data/label' },
{ path: '/fetch', redirect: '/data/fetch' },
{ path: '/corrections', redirect: '/data/corrections' },
{ path: '/imitate', redirect: '/data/imitate' },
]
export const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/', component: LabelView, meta: { title: 'Label' } },
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
{ path: '/models', component: ModelsView, meta: { title: 'Models' } },
{ path: '/imitate', component: ImitateView, meta: { title: 'Imitate' } },
{ path: '/corrections', component: CorrectionsView, meta: { title: 'Corrections' } },
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
],
routes,
})

View file

@ -0,0 +1,94 @@
import { describe, it, expect } from 'vitest'
import { createRouter, createWebHashHistory } from 'vue-router'
// Import the raw routes array so we can test structure without mounting App
import { routes } from './index'
describe('router routes', () => {
it('exports a routes array', () => {
expect(Array.isArray(routes)).toBe(true)
})
it('has / pointing to DashboardView', () => {
const root = routes.find(r => r.path === '/')
expect(root).toBeDefined()
// Component should be async (lazy) or have a name
expect(root?.component).toBeDefined()
})
it('has /fleet route', () => {
const r = routes.find(r => r.path === '/fleet')
expect(r).toBeDefined()
})
it('has /data/label route', () => {
const r = routes.find(r => r.path === '/data/label')
expect(r).toBeDefined()
})
it('has /data/fetch route', () => {
const r = routes.find(r => r.path === '/data/fetch')
expect(r).toBeDefined()
})
it('has /data/corrections route', () => {
const r = routes.find(r => r.path === '/data/corrections')
expect(r).toBeDefined()
})
it('has /data/imitate route', () => {
const r = routes.find(r => r.path === '/data/imitate')
expect(r).toBeDefined()
})
it('has /eval/benchmark route', () => {
const r = routes.find(r => r.path === '/eval/benchmark')
expect(r).toBeDefined()
})
it('has /eval/compare route', () => {
const r = routes.find(r => r.path === '/eval/compare')
expect(r).toBeDefined()
})
it('has /train/jobs route', () => {
const r = routes.find(r => r.path === '/train/jobs')
expect(r).toBeDefined()
})
it('has /train/results route', () => {
const r = routes.find(r => r.path === '/train/results')
expect(r).toBeDefined()
})
it('has /settings route', () => {
const r = routes.find(r => r.path === '/settings')
expect(r).toBeDefined()
})
it('has backward-compat redirect from /benchmark to /eval/benchmark', () => {
const r = routes.find(r => r.path === '/benchmark')
expect(r).toBeDefined()
expect((r as { redirect?: string }).redirect).toBe('/eval/benchmark')
})
it('has backward-compat redirect from /models to /fleet', () => {
const r = routes.find(r => r.path === '/models')
expect(r).toBeDefined()
expect((r as { redirect?: string }).redirect).toBe('/fleet')
})
it('has backward-compat redirect from /stats to /', () => {
const r = routes.find(r => r.path === '/stats')
expect(r).toBeDefined()
expect((r as { redirect?: string }).redirect).toBe('/')
})
it('can create a functional router instance', () => {
const router = createRouter({
history: createWebHashHistory(),
routes,
})
expect(router).toBeDefined()
})
})

89
web/src/types/nodes.ts Normal file
View file

@ -0,0 +1,89 @@
export interface GpuEntry {
gpu_id: number
card: string
vram_total_mb: number
vram_used_mb: number
vram_free_mb: number
temp_c: number | null
utilization_pct: number | null
compute_cap: number | null
services_assigned: string[]
services_running: string[]
}
export interface ServiceInfo {
min_compute_cap: number
max_mb: number
catalog_size: number
}
export interface NodeSummary {
node_id: string
online: boolean
agent_url: string
gpus: GpuEntry[]
profile_loaded: boolean
services_catalog: Record<string, ServiceInfo>
}
// ── Full profile types (for profile editor) ────────────────────────────────────
export interface ServiceManaged {
type: string
exec_path?: string
args_template?: string
port?: number
host_port?: number
base_port?: number
health_path?: string
cwd?: string
adopt?: boolean
[key: string]: unknown
}
export interface CatalogEntryFull {
path: string
vram_mb: number
description?: string
multi_gpu?: boolean
env?: Record<string, string>
}
export interface ServiceDefinition {
max_mb: number
priority: number
min_compute_cap?: number
preferred_compute_cap?: number
shared?: boolean
max_concurrent?: number
idle_stop_after_s?: number
always_on?: boolean
model_base_path?: string
managed?: ServiceManaged
catalog?: Record<string, CatalogEntryFull>
}
export interface NodeHardwareGpu {
id: number
vram_mb: number
compute_cap?: number
card?: string
role?: string
services?: string[]
}
export interface NodeHardwareEntry {
local_model_root?: string
agent_url?: string
gpus: NodeHardwareGpu[]
}
export interface FullProfile {
schema_version?: number
name?: string
vram_total_mb?: number
eviction_timeout_s?: number
services: Record<string, ServiceDefinition>
nodes: Record<string, NodeHardwareEntry>
model_size_hints?: Record<string, string>
}

View file

@ -0,0 +1,987 @@
<template>
<div class="assignments-tab">
<!-- Toast -->
<div v-if="toast" class="toast" :class="toast.type" role="status" aria-live="polite">
{{ toast.message }}
</div>
<!-- Assignments section -->
<div class="section-header">
<h2 class="section-title">Task Assignments</h2>
<button class="btn-primary btn-sm" @click="openNewAssignment">+ New Assignment</button>
</div>
<div class="filter-row">
<label for="product-filter" class="filter-label">Product</label>
<select id="product-filter" v-model="productFilter" class="filter-select">
<option value="">All products</option>
<option v-for="p in allProducts" :key="p" :value="p">{{ p }}</option>
</select>
</div>
<div v-if="assignmentsLoading" class="empty-state">Loading assignments</div>
<div v-else-if="assignmentsError" class="error-notice" role="alert">{{ assignmentsError }}</div>
<div v-else-if="filteredGroups.length === 0" class="empty-state">No assignments yet. Add one above.</div>
<div v-else class="product-groups">
<div v-for="group in filteredGroups" :key="group.product" class="product-group">
<h3 class="product-name">{{ group.product.toUpperCase() }}</h3>
<div class="assignment-list">
<div v-for="a in group.assignments" :key="`${a.product}/${a.task}`" class="assignment-row">
<div class="assignment-main">
<span class="task-id">{{ a.task }}</span>
<span
class="model-name"
:title="a.model_id"
>{{ displayModelId(a) }}</span>
<span v-if="a.vram_mb" class="chip chip-vram">{{ formatVram(a.vram_mb) }}</span>
<span v-if="a.service_type" class="chip" :class="serviceChipClass(a.service_type)">{{ a.service_type }}</span>
</div>
<!-- Node deployment status -->
<div v-if="deploymentMap[`${a.product}/${a.task}`]" class="node-statuses">
<span
v-for="ns in deploymentMap[`${a.product}/${a.task}`]"
:key="ns.node_id"
class="node-badge-wrap"
>
<span
class="node-badge"
:class="ns.status"
:title="`${ns.node_id}: ${ns.status}`"
>
<span class="node-icon">{{ nodeIcon(ns.status) }}</span>
{{ ns.node_id }}
</span>
<button
v-if="ns.status === 'absent'"
class="btn-deploy"
:disabled="deploying.has(`${a.product}/${a.task}/${ns.node_id}`)"
:title="`Register ${a.model_id} in ${ns.node_id} catalog`"
@click="deployModel(a, ns.node_id)"
>{{ deploying.has(`${a.product}/${a.task}/${ns.node_id}`) ? '…' : 'Register' }}</button>
</span>
</div>
<div class="assignment-actions">
<button
v-if="editingKey !== `${a.product}/${a.task}`"
class="btn-ghost btn-sm"
@click="startEdit(a)"
>Edit</button>
<button
class="btn-ghost btn-sm btn-danger"
@click="deleteAssignment(a.product, a.task)"
>Delete</button>
</div>
<!-- Inline edit form -->
<div v-if="editingKey === `${a.product}/${a.task}`" class="inline-edit">
<select v-model="editDraft.model_id" class="edit-select" aria-label="Model">
<option value="" disabled>Select model</option>
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
{{ m.alias || truncate(m.model_id, 40) }}
</option>
</select>
<input
v-model="editDraft.description"
type="text"
class="edit-input"
placeholder="Description (optional)"
/>
<div class="inline-edit-btns">
<button class="btn-primary btn-sm" :disabled="!editDraft.model_id" @click="saveEdit(a)">Save</button>
<button class="btn-ghost btn-sm" @click="editingKey = null">Cancel</button>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Model Registry section -->
<div class="section-header section-header-mt">
<h2 class="section-title">Model Registry</h2>
<button class="btn-primary btn-sm" @click="showRegisterModal = true">Register Model</button>
</div>
<div v-if="registryLoading" class="empty-state">Loading model registry</div>
<div v-else-if="registryError" class="error-notice" role="alert">{{ registryError }}</div>
<div v-else-if="registryModels.length === 0" class="empty-state">No models registered yet.</div>
<div v-else class="registry-table-wrap">
<table class="registry-table">
<thead>
<tr>
<th>Alias</th>
<th>Model ID</th>
<th>VRAM</th>
<th>Service</th>
<th class="col-hf">HF Repo</th>
<th></th>
</tr>
</thead>
<tbody>
<tr v-for="m in registryModels" :key="m.model_id">
<td>{{ m.alias || '—' }}</td>
<td>
<span class="truncated" :title="m.model_id">{{ truncate(m.model_id, 36) }}</span>
</td>
<td>{{ formatVram(m.vram_mb) }}</td>
<td><span class="chip" :class="serviceChipClass(m.service_type)">{{ m.service_type }}</span></td>
<td class="col-hf">
<a
v-if="m.hf_repo"
:href="`https://huggingface.co/${m.hf_repo}`"
target="_blank"
rel="noopener noreferrer"
class="hf-link"
>{{ truncate(m.hf_repo, 30) }}</a>
<span v-else class="text-muted"></span>
</td>
<td>
<button class="btn-ghost btn-sm btn-danger" @click="deleteModel(m.model_id)">Delete</button>
</td>
</tr>
</tbody>
</table>
</div>
<!-- New Assignment modal -->
<div v-if="showNewAssignmentModal" class="modal-backdrop" @click.self="showNewAssignmentModal = false">
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-new-assignment-title">
<h3 id="modal-new-assignment-title" class="modal-title">New Assignment</h3>
<label class="form-label">Product</label>
<input
v-model="newAssignment.product"
list="product-list"
class="form-input"
placeholder="e.g. peregrine"
autocomplete="off"
/>
<datalist id="product-list">
<option v-for="p in allProducts" :key="p" :value="p" />
</datalist>
<label class="form-label">Task ID</label>
<input
v-model="newAssignment.task"
type="text"
class="form-input"
placeholder="e.g. cover_letter"
/>
<label class="form-label">Model</label>
<select v-model="newAssignment.model_id" class="form-select">
<option value="" disabled>Select from registry</option>
<option v-for="m in registryModels" :key="m.model_id" :value="m.model_id">
{{ m.alias || truncate(m.model_id, 50) }}
</option>
</select>
<label class="form-label">Description <span class="optional">(optional)</span></label>
<input
v-model="newAssignment.description"
type="text"
class="form-input"
placeholder="Human-readable note for operators"
/>
<div class="modal-actions">
<button
class="btn-primary"
:disabled="!newAssignment.product || !newAssignment.task || !newAssignment.model_id || saving"
@click="saveNewAssignment"
>{{ saving ? 'Saving…' : 'Save' }}</button>
<button class="btn-ghost" @click="showNewAssignmentModal = false">Cancel</button>
</div>
</div>
</div>
<!-- Register Model modal -->
<div v-if="showRegisterModal" class="modal-backdrop" @click.self="showRegisterModal = false">
<div class="modal" role="dialog" aria-modal="true" aria-labelledby="modal-register-title">
<h3 id="modal-register-title" class="modal-title">Register Model</h3>
<label class="form-label">Model ID <span class="hint">(HuggingFace slug, e.g. ibm-granite/granite-4.1-8b)</span></label>
<input v-model="newModel.model_id" type="text" class="form-input" placeholder="org/model-name" />
<label class="form-label">Alias <span class="optional">(optional, short name for assignments)</span></label>
<input v-model="newModel.alias" type="text" class="form-input" placeholder="e.g. granite-8b" />
<label class="form-label">Service type</label>
<select v-model="newModel.service_type" class="form-select">
<option value="" disabled>Select service</option>
<option value="cf-text">cf-text Language Models</option>
<option value="cf-stt">cf-stt Speech Recognition</option>
<option value="cf-tts">cf-tts Text to Speech</option>
<option value="cf-vision">cf-vision Vision / VLM</option>
<option value="cf-image">cf-image Image Generation</option>
<option value="cf-voice">cf-voice Audio Classification</option>
<option value="vllm">vllm vLLM inference</option>
<option value="ollama">ollama Ollama inference</option>
</select>
<label class="form-label">VRAM required (MB)</label>
<input v-model.number="newModel.vram_mb" type="number" min="0" class="form-input" placeholder="e.g. 16384" />
<label class="form-label">HF Repo <span class="optional">(optional)</span></label>
<input v-model="newModel.hf_repo" type="text" class="form-input" placeholder="org/repo-name" />
<label class="form-label">Description <span class="optional">(optional)</span></label>
<input v-model="newModel.description" type="text" class="form-input" placeholder="Human-readable note" />
<div class="modal-actions">
<button
class="btn-primary"
:disabled="!newModel.model_id || !newModel.service_type || !newModel.vram_mb || saving"
@click="saveNewModel"
>{{ saving ? 'Saving…' : 'Register' }}</button>
<button class="btn-ghost" @click="showRegisterModal = false">Cancel</button>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
// Types
interface AssignmentNode {
node_id: string
status: 'present' | 'absent' | 'vram_tight'
}
interface DeployingKey {
nodeId: string
assignmentKey: string
}
interface Assignment {
product: string
task: string
model_id: string
description: string
alias?: string
service_type?: string
vram_mb?: number
nodes?: AssignmentNode[]
}
interface RegistryModel {
model_id: string
alias: string
service_type: string
vram_mb: number
hf_repo: string
description: string
}
interface ProductGroup {
product: string
assignments: Assignment[]
}
interface Toast {
message: string
type: 'success' | 'error'
}
// State
const assignments = ref<Assignment[]>([])
const assignmentsLoading = ref(false)
const assignmentsError = ref<string | null>(null)
const registryModels = ref<RegistryModel[]>([])
const registryLoading = ref(false)
const registryError = ref<string | null>(null)
const productFilter = ref('')
const editingKey = ref<string | null>(null)
const editDraft = ref({ model_id: '', description: '' })
const showNewAssignmentModal = ref(false)
const newAssignment = ref({ product: '', task: '', model_id: '', description: '' })
const showRegisterModal = ref(false)
const newModel = ref({ model_id: '', alias: '', service_type: '', vram_mb: 0, hf_repo: '', description: '' })
const saving = ref(false)
const toast = ref<Toast | null>(null)
let toastTimer: ReturnType<typeof setTimeout> | null = null
const deploying = ref<Set<string>>(new Set())
// Derived
const allProducts = computed(() => {
const seen = new Set<string>()
for (const a of assignments.value) seen.add(a.product)
return [...seen].sort()
})
const deploymentMap = computed(() => {
const map: Record<string, AssignmentNode[]> = {}
for (const a of assignments.value) {
if (a.nodes) map[`${a.product}/${a.task}`] = a.nodes
}
return map
})
const filteredGroups = computed((): ProductGroup[] => {
const filtered = productFilter.value
? assignments.value.filter(a => a.product === productFilter.value)
: assignments.value
const byProduct: Record<string, Assignment[]> = {}
for (const a of filtered) {
if (!byProduct[a.product]) byProduct[a.product] = []
byProduct[a.product].push(a)
}
return Object.keys(byProduct)
.sort()
.map(product => ({ product, assignments: byProduct[product] }))
})
// Helpers
function truncate(s: string, max: number): string {
return s.length > max ? s.slice(0, max - 1) + '…' : s
}
function displayModelId(a: Assignment): string {
if (a.alias) return a.alias
const id = a.model_id
// Show only the model name part (after /) and truncate long slugs
const short = id.includes('/') ? id.split('/').slice(1).join('/') : id
return truncate(short, 36)
}
function formatVram(mb: number | undefined): string {
if (!mb) return ''
if (mb >= 1024) return `${(mb / 1024).toFixed(1)} GB`
return `${mb} MB`
}
function serviceChipClass(service: string): string {
return `chip-service-${service.replace(/[^a-z0-9]/g, '-')}`
}
function nodeIcon(status: string): string {
if (status === 'present') return '✓'
if (status === 'vram_tight') return '~'
return '✗'
}
function showToast(message: string, type: 'success' | 'error' = 'success') {
if (toastTimer) clearTimeout(toastTimer)
toast.value = { message, type }
toastTimer = setTimeout(() => { toast.value = null }, 3500)
}
function openNewAssignment() {
newAssignment.value = { product: '', task: '', model_id: '', description: '' }
showNewAssignmentModal.value = true
}
function startEdit(a: Assignment) {
editingKey.value = `${a.product}/${a.task}`
editDraft.value = { model_id: a.model_id, description: a.description }
}
// API
async function loadAssignments() {
assignmentsLoading.value = true
assignmentsError.value = null
try {
// Fetch both list and deployment status in parallel
const [listRes, statusRes] = await Promise.all([
fetch('/api/cforch/assignments'),
fetch('/api/cforch/assignments/deployment-status'),
])
if (!listRes.ok) throw new Error(`HTTP ${listRes.status}`)
const list: Assignment[] = (await listRes.json()).assignments ?? []
// Merge deployment status into assignments if available
if (statusRes.ok) {
const statusList: Assignment[] = (await statusRes.json()).deployment_status ?? []
const statusMap: Record<string, AssignmentNode[]> = {}
for (const s of statusList) {
statusMap[`${s.product}/${s.task}`] = s.nodes ?? []
}
for (const a of list) {
a.nodes = statusMap[`${a.product}/${a.task}`] ?? []
// Enrich with service_type/vram_mb from status payload
const s = statusList.find(x => x.product === a.product && x.task === a.task)
if (s) {
a.service_type = s.service_type
a.vram_mb = s.vram_mb
a.alias = s.alias
}
}
}
assignments.value = list
} catch (e) {
assignmentsError.value = `Could not load assignments: ${e}`
} finally {
assignmentsLoading.value = false
}
}
async function loadRegistry() {
registryLoading.value = true
registryError.value = null
try {
const res = await fetch('/api/cforch/model-registry')
if (!res.ok) throw new Error(`HTTP ${res.status}`)
registryModels.value = (await res.json()).models ?? []
} catch (e) {
registryError.value = `Could not load model registry: ${e}`
} finally {
registryLoading.value = false
}
}
async function saveNewAssignment() {
saving.value = true
try {
const res = await fetch('/api/cforch/assignments', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(newAssignment.value),
})
if (!res.ok) throw new Error(await res.text())
showNewAssignmentModal.value = false
showToast('Assignment saved')
await loadAssignments()
} catch (e) {
showToast(`Save failed: ${e}`, 'error')
} finally {
saving.value = false
}
}
async function saveEdit(a: Assignment) {
saving.value = true
try {
const res = await fetch('/api/cforch/assignments', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
product: a.product,
task: a.task,
model_id: editDraft.value.model_id,
description: editDraft.value.description,
}),
})
if (!res.ok) throw new Error(await res.text())
editingKey.value = null
showToast('Assignment updated')
await loadAssignments()
} catch (e) {
showToast(`Update failed: ${e}`, 'error')
} finally {
saving.value = false
}
}
async function deleteAssignment(product: string, task: string) {
if (!confirm(`Delete assignment ${product}.${task}?`)) return
try {
const res = await fetch(
`/api/cforch/assignments/${encodeURIComponent(product)}/${encodeURIComponent(task)}`,
{ method: 'DELETE' },
)
if (!res.ok) throw new Error(await res.text())
showToast('Assignment deleted')
await loadAssignments()
} catch (e) {
showToast(`Delete failed: ${e}`, 'error')
}
}
async function saveNewModel() {
saving.value = true
try {
const res = await fetch('/api/cforch/model-registry', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(newModel.value),
})
if (!res.ok) throw new Error(await res.text())
showRegisterModal.value = false
showToast('Model registered')
await loadRegistry()
} catch (e) {
showToast(`Register failed: ${e}`, 'error')
} finally {
saving.value = false
}
}
async function deleteModel(model_id: string) {
if (!confirm(`Remove ${model_id} from the registry?`)) return
try {
const res = await fetch(
`/api/cforch/model-registry/${encodeURIComponent(model_id)}`,
{ method: 'DELETE' },
)
if (!res.ok) throw new Error(await res.text())
showToast('Model removed')
await loadRegistry()
} catch (e) {
showToast(`Delete failed: ${e}`, 'error')
}
}
async function deployModel(a: Assignment, nodeId: string) {
const key = `${a.product}/${a.task}/${nodeId}`
if (deploying.value.has(key)) return
// Look up hf_repo from registry for cleaner path construction
const regEntry = registryModels.value.find(m => m.model_id === a.model_id)
const hf_repo = regEntry?.hf_repo ?? ''
const service_type = a.service_type ?? regEntry?.service_type ?? ''
const vram_mb = a.vram_mb ?? regEntry?.vram_mb ?? 0
const description = regEntry?.alias ? `${regEntry.alias} (via assignments)` : ''
if (!service_type) {
showToast(`No service type for model ${a.model_id}`, 'error')
return
}
deploying.value = new Set([...deploying.value, key])
try {
const res = await fetch(`/api/nodes-mgmt/nodes/${encodeURIComponent(nodeId)}/models/deploy`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ model_id: a.model_id, service_type, vram_mb, hf_repo, description }),
})
if (!res.ok) throw new Error(await res.text())
const data = await res.json()
showToast(`Registered ${a.model_id} on ${nodeId} at ${data.path}`)
// Optimistic update: flip node to 'present' immediately so the Register button
// disappears before the coordinator reload confirms. loadAssignments() reconciles
// with real server state on the next round-trip.
assignments.value = assignments.value.map(asgn => {
if (asgn.product !== a.product || asgn.task !== a.task) return asgn
return {
...asgn,
nodes: (asgn.nodes ?? []).map(ns =>
ns.node_id === nodeId ? { ...ns, status: 'present' as const } : ns
),
}
})
await loadAssignments()
} catch (e) {
showToast(`Deploy failed: ${e}`, 'error')
} finally {
deploying.value = new Set([...deploying.value].filter(k => k !== key))
}
}
onMounted(() => {
loadAssignments()
loadRegistry()
})
</script>
<style scoped>
.assignments-tab {
display: flex;
flex-direction: column;
gap: 1.25rem;
}
/* ── Toast ── */
.toast {
position: fixed;
bottom: 1.5rem;
right: 1.5rem;
padding: 0.65rem 1.1rem;
border-radius: 0.5rem;
font-size: 0.88rem;
font-weight: 500;
z-index: 200;
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
}
.toast.success {
background: var(--color-success, #2a8050);
color: #fff;
}
.toast.error {
background: var(--color-danger, #b03030);
color: #fff;
}
/* ── Section headers ── */
.section-header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 1rem;
}
.section-header-mt {
margin-top: 1.5rem;
}
.section-title {
font-size: 1rem;
font-weight: 600;
color: var(--app-primary, #2A6080);
margin: 0;
}
/* ── Filter row ── */
.filter-row {
display: flex;
align-items: center;
gap: 0.6rem;
}
.filter-label {
font-size: 0.85rem;
color: var(--color-text-muted, #6b7a99);
}
.filter-select {
padding: 0.3rem 0.6rem;
font-size: 0.85rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.4rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2030);
}
/* ── Product groups ── */
.product-groups {
display: flex;
flex-direction: column;
gap: 1rem;
}
.product-group {}
.product-name {
font-size: 0.75rem;
font-weight: 700;
letter-spacing: 0.08em;
color: var(--color-text-muted, #6b7a99);
text-transform: uppercase;
margin: 0 0 0.4rem;
}
.assignment-list {
display: flex;
flex-direction: column;
gap: 0.4rem;
}
/* ── Assignment rows ── */
.assignment-row {
background: var(--color-surface-raised, #f0f4fa);
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.5rem;
padding: 0.65rem 0.85rem;
display: flex;
flex-direction: column;
gap: 0.4rem;
}
.assignment-main {
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
}
.task-id {
font-family: var(--font-mono, monospace);
font-size: 0.88rem;
font-weight: 600;
color: var(--color-text, #1a2030);
min-width: 0;
}
.model-name {
font-size: 0.85rem;
color: var(--color-text-muted, #6b7a99);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
max-width: 280px;
cursor: default;
}
.assignment-actions {
display: flex;
gap: 0.4rem;
flex-wrap: wrap;
}
/* ── Node status badges ── */
.node-statuses {
display: flex;
gap: 0.35rem;
flex-wrap: wrap;
}
.node-badge-wrap {
display: inline-flex;
align-items: center;
gap: 0.2rem;
}
.node-badge {
display: inline-flex;
align-items: center;
gap: 0.2rem;
font-size: 0.78rem;
padding: 0.15rem 0.5rem;
border-radius: 0.35rem;
font-weight: 500;
}
.node-badge.present {
background: color-mix(in srgb, var(--color-success, #2a8050) 15%, transparent);
color: var(--color-success, #2a8050);
border: 1px solid color-mix(in srgb, var(--color-success, #2a8050) 30%, transparent);
}
.node-badge.absent {
background: color-mix(in srgb, var(--color-danger, #b03030) 12%, transparent);
color: var(--color-danger, #b03030);
border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
}
.node-badge.vram_tight {
background: color-mix(in srgb, #c08030 15%, transparent);
color: #8a5500;
border: 1px solid color-mix(in srgb, #c08030 30%, transparent);
}
.node-icon {
font-size: 0.85em;
}
.btn-deploy {
padding: 0.1rem 0.4rem;
font-size: 0.72rem;
font-weight: 600;
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
color: var(--app-primary, #2A6080);
border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 30%, transparent);
border-radius: 0.3rem;
cursor: pointer;
white-space: nowrap;
transition: background 0.15s;
}
.btn-deploy:hover:not(:disabled) {
background: color-mix(in srgb, var(--app-primary, #2A6080) 22%, transparent);
}
.btn-deploy:disabled { opacity: 0.5; cursor: default; }
/* ── Inline edit ── */
.inline-edit {
display: flex;
flex-wrap: wrap;
gap: 0.4rem;
padding-top: 0.35rem;
border-top: 1px solid var(--color-border, #d0d7e8);
}
.edit-select,
.edit-input {
flex: 1;
min-width: 160px;
padding: 0.35rem 0.55rem;
font-size: 0.85rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.4rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2030);
}
.inline-edit-btns {
display: flex;
gap: 0.35rem;
align-items: center;
}
/* ── Registry table ── */
.registry-table-wrap {
overflow-x: auto;
border-radius: 0.5rem;
border: 1px solid var(--color-border, #d0d7e8);
}
.registry-table {
width: 100%;
border-collapse: collapse;
font-size: 0.85rem;
}
.registry-table th {
text-align: left;
padding: 0.5rem 0.75rem;
font-size: 0.78rem;
font-weight: 600;
color: var(--color-text-muted, #6b7a99);
background: var(--color-surface-raised, #f0f4fa);
border-bottom: 1px solid var(--color-border, #d0d7e8);
white-space: nowrap;
}
.registry-table td {
padding: 0.5rem 0.75rem;
border-bottom: 1px solid var(--color-border, #d0d7e8);
vertical-align: middle;
}
.registry-table tbody tr:last-child td {
border-bottom: none;
}
.truncated {
display: inline-block;
max-width: 220px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
vertical-align: bottom;
cursor: default;
}
.hf-link {
color: var(--app-primary, #2A6080);
text-decoration: none;
font-size: 0.82rem;
}
.hf-link:hover { text-decoration: underline; }
.text-muted { color: var(--color-text-muted, #6b7a99); }
/* ── Chips ── */
.chip {
display: inline-block;
padding: 0.15rem 0.5rem;
border-radius: 0.35rem;
font-size: 0.75rem;
font-weight: 600;
white-space: nowrap;
}
.chip-vram {
background: color-mix(in srgb, var(--app-primary, #2A6080) 12%, transparent);
color: var(--app-primary, #2A6080);
border: 1px solid color-mix(in srgb, var(--app-primary, #2A6080) 25%, transparent);
}
/* service chips — match ModelsView convention */
.chip-service-cf-text { background: #e8f0fe; color: #1a5276; border: 1px solid #a9c4e8; }
.chip-service-cf-stt { background: #eaf6ea; color: #1e6b3a; border: 1px solid #a2d9b1; }
.chip-service-cf-tts { background: #fdf3e3; color: #7d4e00; border: 1px solid #e8c98a; }
.chip-service-cf-vision { background: #f3e8fd; color: #5b2d8e; border: 1px solid #c8a0e8; }
.chip-service-cf-image { background: #fce8f0; color: #8e1a4f; border: 1px solid #e8a0c0; }
.chip-service-cf-voice { background: #e8f8fc; color: #0a5c6e; border: 1px solid #88d0e0; }
.chip-service-vllm { background: #f5ece0; color: #7a3800; border: 1px solid #d4a87a; }
.chip-service-ollama { background: #eeeeee; color: #444; border: 1px solid #ccc; }
/* ── Buttons ── */
.btn-primary {
padding: 0.45rem 1rem;
background: var(--app-primary, #2A6080);
color: #fff;
border: none;
border-radius: 0.4rem;
font-size: 0.85rem;
font-weight: 600;
cursor: pointer;
transition: opacity 0.15s;
}
.btn-primary:disabled { opacity: 0.5; cursor: default; }
.btn-primary:not(:disabled):hover { opacity: 0.88; }
.btn-ghost {
padding: 0.35rem 0.75rem;
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.4rem;
font-size: 0.82rem;
color: var(--color-text-muted, #6b7a99);
cursor: pointer;
transition: background 0.15s;
}
.btn-ghost:hover { background: var(--color-surface-raised, #e4ebf5); }
.btn-ghost.btn-danger { color: var(--color-danger, #b03030); border-color: color-mix(in srgb, var(--color-danger, #b03030) 30%, transparent); }
.btn-ghost.btn-danger:hover { background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent); }
.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }
/* ── Empty / error states ── */
.empty-state {
padding: 1.5rem;
text-align: center;
color: var(--color-text-muted, #6b7a99);
font-size: 0.9rem;
background: var(--color-surface-raised, #f0f4fa);
border: 1px dashed var(--color-border, #d0d7e8);
border-radius: 0.5rem;
}
.error-notice {
padding: 0.75rem 1rem;
background: color-mix(in srgb, var(--color-danger, #b03030) 10%, transparent);
color: var(--color-danger, #b03030);
border: 1px solid color-mix(in srgb, var(--color-danger, #b03030) 25%, transparent);
border-radius: 0.4rem;
font-size: 0.87rem;
}
/* ── Modal ── */
.modal-backdrop {
position: fixed;
inset: 0;
background: rgba(0,0,0,0.35);
display: flex;
align-items: center;
justify-content: center;
z-index: 100;
padding: 1rem;
}
.modal {
background: var(--color-surface, #fff);
border-radius: 0.65rem;
padding: 1.5rem;
width: 100%;
max-width: 480px;
display: flex;
flex-direction: column;
gap: 0.65rem;
box-shadow: 0 8px 32px rgba(0,0,0,0.18);
max-height: 90vh;
overflow-y: auto;
}
.modal-title {
font-size: 1rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0 0 0.25rem;
}
.form-label {
font-size: 0.82rem;
font-weight: 600;
color: var(--color-text-muted, #6b7a99);
}
.form-input,
.form-select {
padding: 0.4rem 0.65rem;
font-size: 0.88rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.4rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2030);
width: 100%;
box-sizing: border-box;
}
.form-input:focus, .form-select:focus {
outline: 2px solid var(--app-primary, #2A6080);
outline-offset: 1px;
}
.modal-actions {
display: flex;
gap: 0.5rem;
justify-content: flex-end;
margin-top: 0.25rem;
}
.optional, .hint {
font-weight: 400;
color: var(--color-text-muted, #6b7a99);
font-size: 0.78rem;
}
/* ── Responsive ── */
@media (max-width: 600px) {
.assignment-main { flex-direction: column; align-items: flex-start; }
.col-hf { display: none; }
.model-name { max-width: 100%; }
.modal { padding: 1rem; }
}
</style>

View file

@ -0,0 +1,82 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import BenchmarkView from './BenchmarkView.vue'
beforeEach(() => {
vi.stubGlobal('fetch', vi.fn().mockImplementation((url: string) => {
// LlmEvalTab calls /api/cforch/models and expects { models: CfOrchModel[] }
if (url.includes('/api/cforch/models')) {
return Promise.resolve({
ok: true,
json: async () => ({ models: [] }),
text: async () => '',
})
}
// Default: satisfies ClassifierTab (/api/benchmark/results, /api/benchmark/models,
// /api/finetune/status), StyleTab (/api/style/models, /api/style/results),
// and any other tab that tolerates empty arrays/objects.
return Promise.resolve({
ok: true,
json: async () => ({ models: {}, categories: {}, tasks: [], types: [], results: [] }),
text: async () => '',
})
}))
vi.stubGlobal('EventSource', class {
onmessage = null
onerror = null
close() {}
})
})
describe('BenchmarkView', () => {
it('renders page title "Benchmark"', async () => {
const w = mount(BenchmarkView)
await flushPromises()
expect(w.text()).toContain('Benchmark')
})
it('has mode buttons: Classifier, LLM Eval, Writing Style', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const text = w.text()
expect(text).toContain('Classifier')
expect(text).toContain('LLM Eval')
expect(text).toContain('Writing Style')
})
it('does NOT have a Compare mode button', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const buttons = w.findAll('.mode-btn')
const labels = buttons.map(b => b.text())
expect(labels.every(l => !l.includes('Compare'))).toBe(true)
})
it('shows Classifier tab by default', async () => {
const w = mount(BenchmarkView)
await flushPromises()
// ClassifierTab has a .classifier-tab root
expect(w.find('.classifier-tab').exists()).toBe(true)
})
it('switches to LlmEvalTab when LLM Eval clicked', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const llmBtn = w.findAll('.mode-btn').find(b => b.text().includes('LLM Eval'))!
await llmBtn.trigger('click')
await flushPromises()
expect(w.find('.llm-eval-tab').exists()).toBe(true)
expect(w.find('.classifier-tab').exists()).toBe(false)
expect(llmBtn.classes()).toContain('active')
})
it('switches to StyleTab when Writing Style clicked', async () => {
const w = mount(BenchmarkView)
await flushPromises()
const styleBtn = w.findAll('.mode-btn').find(b => b.text().includes('Writing Style'))!
await styleBtn.trigger('click')
await flushPromises()
expect(w.find('.style-tab').exists()).toBe(true)
expect(w.find('.classifier-tab').exists()).toBe(false)
})
})

View file

@ -16,22 +16,22 @@
:class="{ active: benchMode === 'llm' }"
@click="benchMode = 'llm'"
>🤖 LLM Eval</button>
<button
class="mode-btn"
:class="{ active: benchMode === 'compare' }"
@click="benchMode = 'compare'"
> Compare</button>
<button
class="mode-btn"
:class="{ active: benchMode === 'style' }"
@click="benchMode = 'style'"
> Writing Style</button>
<button
class="mode-btn"
:class="{ active: benchMode === 'plans' }"
@click="benchMode = 'plans'"
>📐 Planning</button>
</div>
<ClassifierTab v-if="benchMode === 'classifier'" />
<LlmEvalTab v-if="benchMode === 'llm'" />
<CompareTab v-if="benchMode === 'compare'" />
<StyleTab v-if="benchMode === 'style'" />
<PlansBenchTab v-if="benchMode === 'plans'" />
</div>
</template>
@ -39,10 +39,10 @@
import { ref } from 'vue'
import ClassifierTab from './ClassifierTab.vue'
import LlmEvalTab from './LlmEvalTab.vue'
import CompareTab from './CompareTab.vue'
import StyleTab from './StyleTab.vue'
import PlansBenchTab from './PlansBenchTab.vue'
type BenchMode = 'classifier' | 'llm' | 'compare' | 'style'
type BenchMode = 'classifier' | 'llm' | 'style' | 'plans'
const benchMode = ref<BenchMode>('classifier')
</script>
@ -69,7 +69,7 @@ const benchMode = ref<BenchMode>('classifier')
margin: 0;
}
/* ── Mode toggle (segmented control) ────────────────────── */
/* ── Mode toggle (segmented control) ── */
.mode-toggle {
display: inline-flex;
border: 1px solid var(--color-border, #d0d7e8);

View file

@ -325,7 +325,7 @@ function toggleCategory(models: AvailableModel[], checked: boolean) {
async function loadModelCategories() {
modelsLoading.value = true
const { data } = await useApiFetch<ModelCategoriesResponse>('/api/benchmark/models')
const { data } = await useApiFetch<ModelCategoriesResponse>('/api/cforch/models')
modelsLoading.value = false
if (data?.categories) {
modelCategories.value = data.categories
@ -342,7 +342,7 @@ const modelCount = computed(() => modelNames.value.length)
const labelNames = computed(() => {
const canonical = Object.keys(LABEL_META)
const inResults = new Set(
modelNames.value.flatMap(n => Object.keys(results.value!.models[n].per_label))
modelNames.value.flatMap(n => Object.keys(results.value?.models[n]?.per_label ?? {}))
)
return [...canonical.filter(l => inResults.has(l)), ...[...inResults].filter(l => !canonical.includes(l))]
})
@ -401,16 +401,16 @@ function formatDate(iso: string | null): string {
// Data loading
async function loadResults() {
loading.value = true
const { data } = await useApiFetch<BenchResults>('/api/benchmark/results')
const { data } = await useApiFetch<BenchResults>('/api/cforch/results')
loading.value = false
if (data && Object.keys(data.models).length > 0) {
if (data?.models && Object.keys(data.models).length > 0) {
results.value = data
}
}
async function loadFineTunedModels() {
const { data } = await useApiFetch<FineTunedModel[]>('/api/finetune/status')
if (Array.isArray(data)) fineTunedModels.value = data
const { data } = await useApiFetch<{ results: FineTunedModel[] }>('/api/train/results')
if (Array.isArray(data?.results)) fineTunedModels.value = data.results
}
// Benchmark run
@ -428,7 +428,7 @@ function startBenchmark() {
params.set('model_names', [...selectedModels.value].join(','))
}
const qs = params.toString()
const url = `/api/benchmark/run${qs ? `?${qs}` : ''}`
const url = `/api/cforch/run${qs ? `?${qs}` : ''}`
useApiSSE(
url,
async (event) => {
@ -457,7 +457,7 @@ function startBenchmark() {
}
async function cancelBenchmark() {
await fetch('/api/benchmark/cancel', { method: 'POST' }).catch(() => {})
await fetch('/api/cforch/cancel', { method: 'POST' }).catch(() => {})
}
// Fine-tune

View file

@ -71,34 +71,37 @@
rows="6"
/>
<!-- Ollama model picker -->
<!-- LLM model picker (ollama + vllm + cf-text) -->
<details class="model-picker" open>
<summary class="picker-summary">
<span class="picker-title">🤖 Ollama Models</span>
<span class="picker-badge">{{ cmpSelectedModels.size }} / {{ ollamaLlmModels.length }}</span>
<span class="picker-title">🤖 LLM Models</span>
<span class="picker-badge">{{ cmpSelectedModels.size }} / {{ llmSelectableModels.length }}</span>
</summary>
<div class="picker-body">
<label class="picker-cat-header">
<input
type="checkbox"
:checked="cmpSelectedModels.size === ollamaLlmModels.length"
:indeterminate="cmpSelectedModels.size > 0 && cmpSelectedModels.size < ollamaLlmModels.length"
:checked="cmpSelectedModels.size === llmSelectableModels.length"
:indeterminate="cmpSelectedModels.size > 0 && cmpSelectedModels.size < llmSelectableModels.length"
@change="toggleAllCmpModels(($event.target as HTMLInputElement).checked)"
/>
<span class="picker-cat-name">All ollama models</span>
<span class="picker-cat-name">All LLM models</span>
</label>
<div v-for="(models, service) in llmModelsByService" :key="service" class="picker-category">
<span class="picker-cat-section">{{ service }}</span>
<div class="picker-model-list">
<label v-for="m in ollamaLlmModels" :key="m.id" class="picker-model-row">
<label v-for="m in models" :key="m.id" class="picker-model-row">
<input
type="checkbox"
:checked="cmpSelectedModels.has(m.id)"
@change="toggleCmpModel(m.id, ($event.target as HTMLInputElement).checked)"
/>
<span class="picker-model-name">{{ m.name }}</span>
<span class="picker-adapter-type">{{ m.tags.slice(0, 3).join(', ') }}</span>
<span class="picker-adapter-type">{{ m.tags.slice(0, 2).join(', ') }}</span>
</label>
</div>
</div>
</div>
</details>
<!-- Run controls -->
@ -232,10 +235,22 @@ const cmpResults = ref<CmpResult[]>([])
const cmpEventSource = ref<EventSource | null>(null)
// Computed
const ollamaLlmModels = computed(() =>
llmModels.value.filter(m => m.service === 'ollama')
const LLM_SERVICES = new Set(['ollama', 'vllm', 'cf-text'])
const llmSelectableModels = computed(() =>
llmModels.value.filter(m => LLM_SERVICES.has(m.service))
)
/** Group selectable models by service for the picker UI */
const llmModelsByService = computed((): Record<string, CfOrchModel[]> => {
const groups: Record<string, CfOrchModel[]> = {}
for (const m of llmSelectableModels.value) {
if (!groups[m.service]) groups[m.service] = []
groups[m.service].push(m)
}
return groups
})
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
const groups: Record<string, CfOrchTask[]> = {}
for (const t of llmTasks.value) {
@ -270,7 +285,7 @@ function toggleCmpModel(id: string, checked: boolean) {
function toggleAllCmpModels(checked: boolean) {
cmpSelectedModels.value = checked
? new Set(ollamaLlmModels.value.map(m => m.id))
? new Set(llmSelectableModels.value.map(m => m.id))
: new Set()
}
@ -288,9 +303,8 @@ async function loadLlmModels() {
const { data } = await useApiFetch<{ models: CfOrchModel[] }>('/api/cforch/models')
if (data?.models) {
llmModels.value = data.models
// Pre-select all ollama models
cmpSelectedModels.value = new Set(
data.models.filter(m => m.service === 'ollama').map(m => m.id)
data.models.filter(m => LLM_SERVICES.has(m.service)).map(m => m.id)
)
}
}

View file

@ -0,0 +1,31 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import CompareView from './CompareView.vue'
beforeEach(() => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ tasks: [], types: [], models: [] }),
text: async () => '',
}))
vi.stubGlobal('EventSource', class {
onmessage = null
onerror = null
close() {}
})
})
describe('CompareView', () => {
it('renders page title "Compare"', async () => {
const w = mount(CompareView)
await flushPromises()
expect(w.find('h1.page-title').text()).toContain('Compare')
})
it('wraps CompareTab component', async () => {
const w = mount(CompareView)
await flushPromises()
// CompareTab renders a .compare-tab root div
expect(w.find('.compare-tab').exists()).toBe(true)
})
})

View file

@ -0,0 +1,36 @@
<template>
<div class="compare-view">
<header class="compare-header">
<h1 class="page-title">🔍 Compare</h1>
</header>
<CompareTab />
</div>
</template>
<script setup lang="ts">
import CompareTab from './CompareTab.vue'
</script>
<style scoped>
.compare-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 1.75rem;
}
.compare-header {
display: flex;
align-items: center;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
}
</style>

View file

@ -0,0 +1,119 @@
import { mount, flushPromises } from '@vue/test-utils'
import { createRouter, createWebHashHistory } from 'vue-router'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import DashboardView from './DashboardView.vue'
const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/', component: { template: '<div />' } },
{ path: '/eval/benchmark', component: { template: '<div />' } },
{ path: '/train/jobs', component: { template: '<div />' } },
{ path: '/fleet', component: { template: '<div />' } },
],
})
const baseDashboard = {
labeled_since_last_eval: 0,
last_eval_timestamp: null,
last_eval_best_score: null,
active_jobs: [],
corrections_export_ready: 0,
signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false },
}
function mockFetch(overrides: Partial<typeof baseDashboard> = {}) {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ ...baseDashboard, ...overrides }),
text: async () => '',
}))
}
beforeEach(() => mockFetch())
describe('DashboardView', () => {
it('renders page title', async () => {
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('Dashboard')
})
it('shows three stage cards: Data, Eval, Train', async () => {
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.stage-card[data-stage="data"]').exists()).toBe(true)
expect(w.find('.stage-card[data-stage="eval"]').exists()).toBe(true)
expect(w.find('.stage-card[data-stage="train"]').exists()).toBe(true)
})
it('shows labeled_since_last_eval count in Data card', async () => {
mockFetch({ labeled_since_last_eval: 42 })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.stage-card[data-stage="data"]').text()).toContain('42')
})
it('does NOT show Run Eval CTA when data_to_eval is false', async () => {
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: false } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const dataCard = w.find('.stage-card[data-stage="data"]')
expect(dataCard.find('.cta-btn').exists()).toBe(false)
})
it('shows Run Eval CTA when data_to_eval is true', async () => {
mockFetch({ signals: { data_to_eval: true, eval_to_train: false, train_to_fleet: false } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const dataCard = w.find('.stage-card[data-stage="data"]')
expect(dataCard.find('.cta-btn').exists()).toBe(true)
expect(dataCard.find('.cta-btn').text()).toContain('Run Eval')
})
it('shows Queue Finetune CTA when eval_to_train is true', async () => {
mockFetch({ signals: { data_to_eval: false, eval_to_train: true, train_to_fleet: false } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const evalCard = w.find('.stage-card[data-stage="eval"]')
expect(evalCard.find('.cta-btn').text()).toContain('Queue Finetune')
})
it('shows Register in Fleet CTA when train_to_fleet is true', async () => {
mockFetch({ signals: { data_to_eval: false, eval_to_train: false, train_to_fleet: true } })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const trainCard = w.find('.stage-card[data-stage="train"]')
expect(trainCard.find('.cta-btn').text()).toContain('Register in Fleet')
})
it('shows active job status pills in Train card', async () => {
mockFetch({ active_jobs: [{ id: 'j1', type: 'classifier', model_key: 'deberta-v3', status: 'running' }] })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const trainCard = w.find('.stage-card[data-stage="train"]')
expect(trainCard.find('.status-pill').exists()).toBe(true)
expect(trainCard.text()).toContain('deberta-v3')
})
it('shows last eval score in Eval card when present', async () => {
mockFetch({ last_eval_best_score: 0.821 })
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
const evalCard = w.find('.stage-card[data-stage="eval"]')
expect(evalCard.text()).toContain('82.1%')
})
it('shows error state when API call fails', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 503, text: async () => '' }))
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
})
it('shows refresh button', async () => {
const w = mount(DashboardView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.refresh-btn').exists()).toBe(true)
})
})

View file

@ -0,0 +1,406 @@
<template>
<div class="dashboard-view">
<header class="dashboard-header">
<h1 class="page-title">📊 Dashboard</h1>
<button class="refresh-btn" :disabled="loading" @click="load" aria-label="Refresh dashboard">
🔄
</button>
</header>
<div v-if="loading && !data" class="loading-state">Loading</div>
<div v-if="error" class="error-notice" role="alert">
{{ error }}
<button class="btn-retry" @click="load">Retry</button>
</div>
<div v-if="data" class="flywheel-grid">
<!-- Data card -->
<div class="stage-card" data-stage="data">
<div class="card-header">
<span class="card-step"></span>
<h2 class="card-title">Data</h2>
</div>
<div class="card-body">
<p class="card-metric">
<strong class="metric-value">{{ data.labeled_since_last_eval.toLocaleString() }}</strong>
<span class="metric-label"> labeled since last eval</span>
</p>
</div>
</div>
<!-- Eval card -->
<div class="stage-card" data-stage="eval">
<div class="card-header">
<span class="card-step"></span>
<h2 class="card-title">Eval</h2>
</div>
<div class="card-body">
<div class="bench-run-table">
<div
v-for="(run, type) in data.recent_bench_runs"
:key="type"
class="bench-run-row"
>
<span class="bench-type-label">{{ BENCH_LABELS[type as BenchType] ?? type }}</span>
<span class="bench-run-time" :class="{ 'metric-muted': !run.timestamp }">
{{ run.timestamp ? formatBenchTs(run.timestamp) : '—' }}
</span>
<span v-if="run.score != null" class="bench-run-score">
{{ formatScore(run.score) }}
</span>
</div>
</div>
</div>
<div v-if="data.signals.eval_to_train" class="card-cta">
<RouterLink to="/train/jobs" class="cta-btn">Queue Finetune</RouterLink>
</div>
<div v-if="data.signals.data_to_eval" class="card-cta">
<RouterLink to="/eval/benchmark" class="cta-btn">Run Eval</RouterLink>
</div>
</div>
<!-- Train card -->
<div class="stage-card" data-stage="train">
<div class="card-header">
<span class="card-step"></span>
<h2 class="card-title">Train</h2>
</div>
<div class="card-body">
<template v-if="data.active_jobs.length > 0">
<div
v-for="job in data.active_jobs"
:key="job.id"
class="job-row"
>
<span class="job-key">{{ job.model_key }}</span>
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
</div>
</template>
<p v-else class="card-metric metric-muted">No active jobs</p>
<p v-if="data.corrections_export_ready > 0" class="card-metric">
<strong class="metric-value">{{ data.corrections_export_ready }}</strong>
<span class="metric-label"> corrections ready</span>
</p>
</div>
<div v-if="data.signals.train_to_fleet" class="card-cta">
<RouterLink to="/fleet" class="cta-btn">Register in Fleet</RouterLink>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { RouterLink } from 'vue-router'
interface ActiveJob {
id: string
type: string
model_key: string
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
}
interface DashboardSignals {
data_to_eval: boolean
eval_to_train: boolean
train_to_fleet: boolean
}
interface BenchRun {
timestamp: string | null
metric: string | null
score: number | null
}
type BenchType = 'classifier' | 'llm' | 'style' | 'plans'
interface DashboardData {
labeled_since_last_eval: number
last_eval_timestamp: string | null
last_eval_best_score: number | null
active_jobs: ActiveJob[]
corrections_export_ready: number
recent_bench_runs: Record<BenchType, BenchRun>
signals: DashboardSignals
}
const BENCH_LABELS: Record<BenchType, string> = {
classifier: 'Classifier',
llm: 'LLM Eval',
style: 'Style',
plans: 'Planning',
}
const data = ref<DashboardData | null>(null)
const loading = ref(false)
const error = ref<string | null>(null)
function formatBenchTs(ts: string): string {
const date = new Date(ts)
if (!isNaN(date.getTime())) {
const diff = Date.now() - date.getTime()
const mins = Math.floor(diff / 60000)
if (mins < 1) return 'just now'
if (mins < 60) return `${mins}m ago`
const hrs = Math.floor(mins / 60)
if (hrs < 24) return `${hrs}h ago`
return `${Math.floor(hrs / 24)}d ago`
}
// Non-ISO: show as-is (plans bench uses "YYYY-MM-DD HH:MM")
return ts.length > 16 ? ts.slice(0, 16) : ts
}
function formatScore(score: number): string {
return `${(score * 100).toFixed(1)}%`
}
async function load() {
loading.value = true
error.value = null
try {
const res = await fetch('/api/dashboard')
if (!res.ok) {
error.value = `Could not load dashboard (HTTP ${res.status}).`
return
}
data.value = await res.json() as DashboardData
} catch {
error.value = 'Network error. Is the Avocet API running?'
} finally {
loading.value = false
}
}
onMounted(() => load())
</script>
<style scoped>
.dashboard-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 1.75rem;
}
.dashboard-header {
display: flex;
align-items: center;
gap: 0.75rem;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
flex: 1;
}
.refresh-btn {
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
cursor: pointer;
font-size: 1rem;
padding: 0.3rem 0.5rem;
transition: background 0.15s;
}
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
/* ── Flywheel grid ── */
.flywheel-grid {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 1rem;
}
@media (max-width: 680px) {
.flywheel-grid {
grid-template-columns: 1fr;
}
}
/* ── Stage cards ── */
.stage-card {
background: var(--color-surface-raised, #f5f7fc);
border: 1px solid var(--color-border, #d0d7e8);
border-radius: var(--radius-lg, 1rem);
padding: 1rem;
display: flex;
flex-direction: column;
gap: 0.75rem;
box-shadow: var(--shadow-sm);
}
.card-header {
display: flex;
align-items: center;
gap: 0.5rem;
border-bottom: 1px solid var(--color-border, #d0d7e8);
padding-bottom: 0.6rem;
}
.card-step {
font-size: 1.1rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
flex-shrink: 0;
}
.card-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1rem;
font-weight: 600;
color: var(--color-text, #1a2338);
margin: 0;
}
.card-body {
display: flex;
flex-direction: column;
gap: 0.4rem;
flex: 1;
}
.card-metric {
margin: 0;
font-size: 0.875rem;
color: var(--color-text, #1a2338);
}
.metric-value {
font-size: 1.05rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
}
.metric-label {
color: var(--color-text-muted, #4a5c7a);
}
.metric-muted { color: var(--color-text-muted, #4a5c7a); }
.card-cta { margin-top: auto; }
.cta-btn {
display: block;
width: 100%;
text-align: center;
padding: 0.5rem;
background: var(--app-primary, #2A6080);
color: #fff;
border-radius: 0.375rem;
text-decoration: none;
font-size: 0.875rem;
font-weight: 600;
transition: background 0.15s;
}
.cta-btn:hover { background: color-mix(in srgb, var(--app-primary, #2A6080) 85%, black); }
/* ── Bench run table ── */
.bench-run-table {
display: flex;
flex-direction: column;
gap: 0.3rem;
}
.bench-run-row {
display: grid;
grid-template-columns: 6rem 1fr auto;
align-items: center;
gap: 0.5rem;
font-size: 0.82rem;
}
.bench-type-label {
font-weight: 600;
color: var(--color-text, #1a2338);
font-size: 0.78rem;
}
.bench-run-time {
color: var(--color-text-secondary, #6b7a99);
font-size: 0.78rem;
}
.bench-run-score {
font-family: var(--font-mono, monospace);
font-size: 0.75rem;
font-weight: 600;
color: var(--app-primary, #2A6080);
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
padding: 0.1rem 0.35rem;
border-radius: 0.25rem;
}
/* ── Job pills ── */
.job-row {
display: flex;
align-items: center;
gap: 0.5rem;
font-size: 0.875rem;
}
.job-key {
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
color: var(--color-text, #1a2338);
}
.status-pill {
font-size: 0.75rem;
padding: 0.15rem 0.45rem;
border-radius: 100px;
font-weight: 600;
flex-shrink: 0;
background: var(--color-surface-raised, #e4ebf5);
color: var(--color-text-muted, #4a5c7a);
}
.status-pill.status-running { background: #d4f4e0; color: #1a7a3a; }
.status-pill.status-queued { background: #fef3cd; color: #856404; }
.status-pill.status-failed { background: #fde8e8; color: #842029; }
.status-pill.status-completed { background: #e0f0ff; color: #0c5481; }
/* ── State indicators ── */
.loading-state {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
}
.error-notice {
background: #fde8e8;
color: #842029;
border: 1px solid #f5c2c7;
border-radius: 0.5rem;
padding: 0.75rem 1rem;
font-size: 0.875rem;
display: flex;
align-items: center;
gap: 0.75rem;
}
.btn-retry {
background: transparent;
border: 1px solid currentColor;
border-radius: 0.25rem;
color: inherit;
cursor: pointer;
font-size: 0.75rem;
padding: 0.2rem 0.5rem;
margin-left: auto;
}
</style>

View file

@ -0,0 +1,705 @@
<template>
<div class="embed-compare-page">
<!-- Step indicator (non-interactive) -->
<ol class="step-indicator" aria-label="Setup progress">
<li :class="{ complete: corpus.length > 0 }">Corpus</li>
<li :class="{ complete: queries.length > 0 }">Queries</li>
<li :class="{ complete: selectedModels.length > 0 }">Models</li>
<li :class="{ complete: hasResults }">Run &amp; Rate</li>
</ol>
<!-- Persistent aria-live region always in DOM, never v-if -->
<div
ref="liveRegion"
class="sr-live"
aria-live="polite"
aria-atomic="true"
v-text="liveMessage"
/>
<!-- Corpus section -->
<section class="card" aria-labelledby="corpus-heading">
<h2 id="corpus-heading"> Corpus</h2>
<div class="corpus-controls">
<div class="field">
<label for="corpus-paste">Paste chunks (one per line)</label>
<textarea
id="corpus-paste"
v-model="rawCorpus"
rows="6"
placeholder="Paste one chunk per line, or use Import below..."
@change="onCorpusPaste"
/>
</div>
<div class="import-row">
<label for="imitate-product-select">Import from product</label>
<select id="imitate-product-select" v-model="selectedProduct">
<option value="">-- select product --</option>
<option
v-for="p in imitateProducts"
:key="p.id"
:value="p.id"
>{{ p.name }}</option>
</select>
<button
class="btn-secondary"
:disabled="!selectedProduct || importing"
@click="importCorpus"
>
{{ importing ? 'Importing…' : 'Import' }}
</button>
<span v-if="importError" class="error-text" role="alert">{{ importError }}</span>
</div>
<p v-if="corpus.length > 0" class="corpus-count">
{{ corpus.length }} chunk{{ corpus.length === 1 ? '' : 's' }} loaded.
</p>
</div>
</section>
<!-- Queries section -->
<section class="card" aria-labelledby="queries-heading">
<h2 id="queries-heading"> Queries</h2>
<div class="field">
<label for="query-input">Enter queries (one per line)</label>
<textarea
id="query-input"
v-model="rawQueries"
rows="4"
placeholder="One query per line..."
@change="onQueriesChange"
/>
</div>
<p v-if="queries.length > 0" class="query-count">
{{ queries.length }} quer{{ queries.length === 1 ? 'y' : 'ies' }}.
</p>
</section>
<!-- Model selection -->
<section class="card" aria-labelledby="models-heading">
<h2 id="models-heading"> Models</h2>
<p v-if="loadingModels" class="muted">Loading models from Ollama</p>
<p v-else-if="modelsError" class="error-text" role="alert">{{ modelsError }}</p>
<ul v-else class="model-list" role="list">
<li v-for="m in availableModels" :key="m.name">
<label class="model-checkbox">
<input
type="checkbox"
:value="m.name"
v-model="selectedModels"
/>
{{ m.name }}
<span class="model-size muted" aria-label="model size">
{{ formatBytes(m.size) }}
</span>
</label>
</li>
</ul>
<p v-if="availableModels.length === 0 && !loadingModels && !modelsError" class="muted">
No Ollama models found. Pull an embedding model first.
</p>
</section>
<!-- Run controls -->
<section class="card run-controls" aria-labelledby="run-heading">
<h2 id="run-heading"> Run</h2>
<div class="run-row">
<div class="field-inline">
<label for="top-k-input">Results per query</label>
<input
id="top-k-input"
type="number"
v-model.number="topK"
min="1"
max="20"
style="width: 5rem"
/>
</div>
<button
class="btn-primary"
:disabled="!canRun || running"
@click="startRun"
>
{{ running ? 'Running…' : 'Run' }}
</button>
<button
v-if="running"
class="btn-danger"
aria-label="Cancel embedding run"
@click="cancelRun"
>
Cancel
</button>
</div>
<p v-if="!canRun && !running" class="muted">
Fill corpus, at least one query, and select at least one model to run.
</p>
</section>
<!-- Results -->
<section
v-if="hasResults"
class="card results-section"
aria-labelledby="results-heading"
>
<h2 id="results-heading">Results</h2>
<!-- Query pagination -->
<div class="query-nav" role="navigation" aria-label="Query navigation">
<button
class="btn-secondary"
aria-label="Previous query"
:disabled="currentQueryIdx === 0"
@click="currentQueryIdx--"
></button>
<span class="query-counter">
Query {{ currentQueryIdx + 1 }} of {{ uniqueQueries.length }}:
<em>{{ uniqueQueries[currentQueryIdx] }}</em>
</span>
<button
class="btn-secondary"
aria-label="Next query"
:disabled="currentQueryIdx >= uniqueQueries.length - 1"
@click="currentQueryIdx++"
></button>
</div>
<!-- Results table: one column per model -->
<div class="table-wrap">
<table class="results-table">
<thead>
<tr>
<th scope="col" class="rank-col">#</th>
<th
v-for="model in selectedModels"
:key="model"
scope="col"
>{{ model }}</th>
</tr>
</thead>
<tbody>
<tr v-for="rank in topK" :key="rank">
<td class="rank-col muted">{{ rank }}</td>
<td
v-for="model in selectedModels"
:key="model"
class="hit-cell"
>
<template v-if="getHit(currentQueryIdx, model, rank - 1) as hit">
<div class="hit-text">{{ hit.text }}</div>
<!-- Visual score bar: decorative only -->
<div class="score-row">
<div class="score-bar-wrap" aria-hidden="true">
<div class="score-bar" :style="{ width: `${hit.score * 100}%` }" />
</div>
<span class="score-label">{{ hit.score.toFixed(3) }}</span>
</div>
<!-- Rating buttons -->
<div class="rating-row">
<button
class="rate-btn"
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant' }"
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'relevant'"
aria-label="Mark as relevant"
@click="rate(currentQueryIdx, model, hit, 'relevant')"
>
👍 Relevant
</button>
<button
class="rate-btn rate-btn-neg"
:class="{ active: getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant' }"
:aria-pressed="getRating(currentQueryIdx, model, hit.chunk_idx) === 'not_relevant'"
aria-label="Mark as not relevant"
@click="rate(currentQueryIdx, model, hit, 'not_relevant')"
>
👎 Not relevant
</button>
</div>
</template>
<span v-else class="muted"></span>
</td>
</tr>
</tbody>
</table>
</div>
</section>
<!-- Export -->
<section
v-if="hasResults"
class="card export-section"
aria-labelledby="export-heading"
>
<h2 id="export-heading">Export Ratings</h2>
<div class="export-row">
<fieldset class="export-format-group">
<legend>Format</legend>
<label><input type="radio" v-model="exportFormat" value="csv" /> CSV</label>
<label><input type="radio" v-model="exportFormat" value="json" /> JSON</label>
</fieldset>
<button class="btn-secondary" @click="exportRatings">Export</button>
</div>
</section>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
// Types
interface OllamaModel { name: string; size: number }
interface ImitateProduct { id: string; name: string }
interface HitResult { chunk_idx: number; text: string; score: number }
interface ResultEvent {
type: 'result'
query_idx: number
query: string
model: string
hits: HitResult[]
}
// State
const rawCorpus = ref('')
const corpus = ref<string[]>([])
const rawQueries = ref('')
const queries = ref<string[]>([])
const selectedModels = ref<string[]>([])
const topK = ref(5)
const availableModels = ref<OllamaModel[]>([])
const loadingModels = ref(false)
const modelsError = ref('')
const imitateProducts = ref<ImitateProduct[]>([])
const selectedProduct = ref('')
const importing = ref(false)
const importError = ref('')
const running = ref(false)
const liveMessage = ref('')
const resultEvents = ref<ResultEvent[]>([])
const runController = ref<AbortController | null>(null)
const currentQueryIdx = ref(0)
const exportFormat = ref<'csv' | 'json'>('csv')
type RatingMap = Record<string, Record<string, Record<number, 'relevant' | 'not_relevant'>>>
const ratings = ref<RatingMap>({})
const uniqueQueries = computed(() => {
const seen = new Set<string>()
const out: string[] = []
for (const e of resultEvents.value) {
if (!seen.has(e.query)) { seen.add(e.query); out.push(e.query) }
}
return out
})
const hasResults = computed(() => resultEvents.value.length > 0)
const canRun = computed(
() => corpus.value.length > 0 && queries.value.length > 0 && selectedModels.value.length > 0
)
// Corpus helpers
function onCorpusPaste() {
const chunks = rawCorpus.value.split('\n').map(l => l.trim()).filter(Boolean)
corpus.value = chunks
if (chunks.length > 0) {
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded.`
}
}
function onQueriesChange() {
queries.value = rawQueries.value.split('\n').map(l => l.trim()).filter(Boolean)
}
async function importCorpus() {
if (!selectedProduct.value) return
importing.value = true
importError.value = ''
try {
const r = await fetch(`/api/imitate/products/${selectedProduct.value}/sample-chunks`)
if (!r.ok) {
const text = await r.text()
throw new Error(text || `HTTP ${r.status}`)
}
const data = await r.json() as { chunks?: string[] }
const chunks = data.chunks ?? []
corpus.value = chunks
rawCorpus.value = chunks.join('\n')
liveMessage.value = `${chunks.length} chunk${chunks.length === 1 ? '' : 's'} loaded from import.`
} catch (err) {
importError.value = String(err)
} finally {
importing.value = false
}
}
// Model loading
async function loadModels() {
loadingModels.value = true
modelsError.value = ''
try {
const r = await fetch('/api/embed-bench/models')
if (!r.ok) throw new Error(`HTTP ${r.status}`)
const data = await r.json() as { models: OllamaModel[] }
availableModels.value = data.models
} catch (err) {
modelsError.value = `Failed to load models: ${err}`
} finally {
loadingModels.value = false
}
}
// Run
async function startRun() {
if (!canRun.value) return
running.value = true
resultEvents.value = []
liveMessage.value = 'Starting embedding run…'
runController.value = new AbortController()
try {
const resp = await fetch('/api/embed-bench/run', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
corpus: corpus.value,
queries: queries.value,
models: selectedModels.value,
top_k: topK.value,
}),
signal: runController.value.signal,
})
const reader = resp.body!.getReader()
const decoder = new TextDecoder()
let buf = ''
while (true) {
const { done, value } = await reader.read()
if (done) break
buf += decoder.decode(value, { stream: true })
const lines = buf.split('\n')
buf = lines.pop() ?? ''
for (const line of lines) {
if (!line.startsWith('data: ')) continue
const event = JSON.parse(line.slice(6))
if (event.type === 'progress') {
liveMessage.value = event.msg
} else if (event.type === 'result') {
resultEvents.value.push(event as ResultEvent)
} else if (event.type === 'done') {
liveMessage.value = 'Run complete.'
} else if (event.type === 'error') {
liveMessage.value = `Error: ${event.msg}`
}
}
}
} catch (err) {
if ((err as Error).name !== 'AbortError') {
liveMessage.value = `Run failed: ${err}`
}
} finally {
running.value = false
runController.value = null
}
}
function cancelRun() {
runController.value?.abort()
liveMessage.value = 'Run cancelled.'
}
// Utilities
function formatBytes(bytes: number): string {
if (bytes < 1_000_000) return `${(bytes / 1000).toFixed(0)} KB`
if (bytes < 1_000_000_000) return `${(bytes / 1_000_000).toFixed(0)} MB`
return `${(bytes / 1_000_000_000).toFixed(1)} GB`
}
function getHit(queryIdx: number, model: string, rank: number): HitResult | null {
const query = uniqueQueries.value[queryIdx]
if (!query) return null
const ev = resultEvents.value.find(e => e.query === query && e.model === model)
return ev?.hits[rank] ?? null
}
function getRating(queryIdx: number, model: string, chunkIdx: number): string | undefined {
const query = uniqueQueries.value[queryIdx]
return ratings.value[query]?.[model]?.[chunkIdx]
}
async function rate(
queryIdx: number,
model: string,
hit: HitResult,
rating: 'relevant' | 'not_relevant',
) {
const query = uniqueQueries.value[queryIdx]
// Optimistic update
if (!ratings.value[query]) ratings.value[query] = {}
if (!ratings.value[query][model]) ratings.value[query][model] = {}
ratings.value[query][model][hit.chunk_idx] = rating
try {
await fetch('/api/embed-bench/rate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
query,
model,
chunk_text: hit.text,
chunk_idx: hit.chunk_idx,
rating,
}),
})
liveMessage.value = `Rated chunk ${hit.chunk_idx + 1} as ${rating}.`
} catch (err) {
liveMessage.value = `Rating failed: ${err}`
}
}
async function exportRatings() {
const r = await fetch(`/api/embed-bench/export?format=${exportFormat.value}`)
if (!r.ok) {
liveMessage.value = `Export failed: HTTP ${r.status}`
return
}
const blob = await r.blob()
const disposition = r.headers.get('Content-Disposition') ?? ''
const filenameMatch = disposition.match(/filename="([^"]+)"/)
const filename = filenameMatch ? filenameMatch[1] : `embed_comparison.${exportFormat.value}`
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = filename
a.click()
URL.revokeObjectURL(url)
liveMessage.value = `Exported ${filename}.`
}
// Lifecycle
onMounted(() => {
loadModels()
})
</script>
<style scoped>
.embed-compare-page {
padding: var(--space-4, 1.5rem);
max-width: 1100px;
}
/* Step indicator */
.step-indicator {
display: flex;
gap: 0;
list-style: none;
margin: 0 0 var(--space-4, 1.5rem);
padding: 0;
border-bottom: 2px solid var(--color-border, #d0d7e8);
}
.step-indicator li {
padding: 0.4rem 1rem;
font-size: 0.8rem;
font-weight: 600;
color: var(--color-text-muted, #4a5c7a);
text-transform: uppercase;
letter-spacing: 0.05em;
border-bottom: 2px solid transparent;
margin-bottom: -2px;
}
.step-indicator li.complete {
color: var(--app-primary, #2A6080);
border-bottom-color: var(--app-primary, #2A6080);
}
/* Accessibility: screen-reader live region — visually hidden but always present */
.sr-live {
position: absolute;
width: 1px; height: 1px;
overflow: hidden;
clip: rect(0 0 0 0);
white-space: nowrap;
}
/* Cards */
.card {
background: var(--color-surface-raised, #e4ebf5);
border: 1px solid var(--color-border, #d0d7e8);
border-radius: var(--radius-md, 0.5rem);
padding: var(--space-4, 1.5rem);
margin-bottom: var(--space-4, 1.5rem);
}
.card h2 {
font-size: 1rem;
font-weight: 700;
margin: 0 0 var(--space-3, 1rem);
color: var(--color-text, #1a2338);
}
.field { display: flex; flex-direction: column; gap: 0.3rem; margin-bottom: 0.75rem; }
.field label { font-size: 0.85rem; font-weight: 600; }
textarea, input[type="number"] {
border: 1px solid var(--color-border, #d0d7e8);
border-radius: var(--radius-sm, 0.25rem);
padding: 0.5rem;
font-size: 0.875rem;
background: var(--color-surface, #f0f4fb);
color: var(--color-text, #1a2338);
resize: vertical;
}
.corpus-controls { display: flex; flex-direction: column; gap: 0.5rem; }
.import-row {
display: flex; flex-wrap: wrap; gap: 0.5rem; align-items: center;
}
.import-row label { font-size: 0.85rem; font-weight: 600; }
.corpus-count, .query-count { font-size: 0.875rem; color: var(--app-primary, #2A6080); margin: 0; }
.model-list { list-style: none; padding: 0; margin: 0; display: flex; flex-wrap: wrap; gap: 0.5rem; }
.model-checkbox {
display: flex; align-items: center; gap: 0.4rem;
font-size: 0.875rem; cursor: pointer;
padding: 0.3rem 0.6rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: var(--radius-sm, 0.25rem);
background: var(--color-surface, #f0f4fb);
}
.model-size { font-size: 0.75rem; }
.run-row { display: flex; flex-wrap: wrap; gap: 0.75rem; align-items: flex-end; }
.field-inline { display: flex; align-items: center; gap: 0.4rem; }
.field-inline label { font-size: 0.85rem; font-weight: 600; white-space: nowrap; }
.btn-primary, .btn-secondary, .btn-danger {
padding: 0.4rem 1rem;
border-radius: var(--radius-sm, 0.25rem);
border: 1px solid transparent;
font-size: 0.875rem;
font-weight: 600;
cursor: pointer;
transition: background 0.15s;
}
.btn-primary { background: var(--app-primary, #2A6080); color: #fff; }
.btn-primary:hover:not(:disabled) { filter: brightness(1.1); }
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-secondary { background: var(--color-surface, #f0f4fb); color: var(--color-text, #1a2338); border-color: var(--color-border, #d0d7e8); }
.btn-secondary:hover:not(:disabled) { background: var(--color-border, #d0d7e8); }
.btn-secondary:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-danger { background: var(--color-error, #c0392b); color: #fff; }
.muted { color: var(--color-text-muted, #4a5c7a); font-size: 0.875rem; }
.error-text { color: var(--color-error, #c0392b); font-size: 0.875rem; }
@media (max-width: 768px) {
.import-row { flex-direction: column; align-items: flex-start; }
.run-row { flex-direction: column; }
.model-list { flex-direction: column; }
}
/* Results table */
.table-wrap { overflow-x: auto; }
.results-table {
width: 100%;
border-collapse: collapse;
font-size: 0.875rem;
}
.results-table thead th {
position: sticky;
top: 0;
background: var(--color-surface-raised, #e4ebf5);
border-bottom: 2px solid var(--color-border, #d0d7e8);
padding: 0.5rem 0.75rem;
text-align: left;
font-weight: 700;
white-space: nowrap;
z-index: 1;
}
.results-table td {
padding: 0.5rem 0.75rem;
vertical-align: top;
border-bottom: 1px solid var(--color-border, #d0d7e8);
}
.rank-col { width: 2rem; text-align: center; }
.hit-text { margin-bottom: 0.25rem; line-height: 1.4; }
.score-row { display: flex; align-items: center; gap: 0.4rem; margin-bottom: 0.25rem; }
.score-bar-wrap {
flex: 1;
height: 6px;
background: var(--color-border, #d0d7e8);
border-radius: 3px;
overflow: hidden;
}
.score-bar {
height: 100%;
background: var(--app-primary, #2A6080);
border-radius: 3px;
transition: width 0.3s ease;
}
.score-label { font-size: 0.75rem; color: var(--color-text-muted, #4a5c7a); min-width: 3rem; text-align: right; }
.rating-row { display: flex; gap: 0.4rem; flex-wrap: wrap; }
.rate-btn {
padding: 0.2rem 0.5rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: var(--radius-sm, 0.25rem);
background: var(--color-surface, #f0f4fb);
color: var(--color-text, #1a2338);
font-size: 0.75rem;
cursor: pointer;
transition: background 0.15s, border-color 0.15s;
}
.rate-btn.active {
background: color-mix(in srgb, var(--app-primary, #2A6080) 20%, transparent);
border-color: var(--app-primary, #2A6080);
font-weight: 700;
}
.rate-btn-neg.active {
background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent);
border-color: var(--color-error, #c0392b);
}
/* Query nav */
.query-nav {
display: flex;
align-items: center;
gap: 0.5rem;
margin-bottom: 0.75rem;
flex-wrap: wrap;
}
.query-counter { font-size: 0.875rem; flex: 1; }
/* Export */
.export-row { display: flex; gap: 1rem; align-items: center; flex-wrap: wrap; }
.export-format-group {
border: none;
padding: 0;
display: flex;
gap: 0.75rem;
}
.export-format-group legend {
font-size: 0.85rem;
font-weight: 600;
margin-bottom: 0.25rem;
float: left;
margin-right: 0.5rem;
}
.export-format-group label { font-size: 0.875rem; display: flex; align-items: center; gap: 0.3rem; }
@media (max-width: 768px) {
.results-table thead th,
.results-table td { padding: 0.35rem 0.4rem; font-size: 0.8rem; }
.query-nav { flex-direction: column; align-items: flex-start; }
}
@media (prefers-reduced-motion: reduce) {
.score-bar { transition: none; }
}
</style>

View file

@ -0,0 +1,7 @@
<template>
<EmbedCompareTab />
</template>
<script setup lang="ts">
import EmbedCompareTab from './EmbedCompareTab.vue'
</script>

View file

@ -6,6 +6,8 @@
<summary class="picker-summary">
<span class="picker-title">📋 Task Selection</span>
<span class="picker-badge">{{ llmTaskBadge }}</span>
<button class="picker-bulk-btn" @click.stop.prevent="selectAllTasks()">All</button>
<button class="picker-bulk-btn" @click.stop.prevent="clearAllTasks()">None</button>
</summary>
<div class="picker-body">
<div v-if="llmTasksLoading" class="picker-loading">Loading tasks</div>
@ -44,6 +46,8 @@
<summary class="picker-summary">
<span class="picker-title">🎯 Model Selection</span>
<span class="picker-badge">{{ llmModelBadge }}</span>
<button class="picker-bulk-btn" @click.stop.prevent="selectAllModels()">All</button>
<button class="picker-bulk-btn" @click.stop.prevent="clearAllModels()">None</button>
</summary>
<div class="picker-body">
<div v-if="llmModelsLoading" class="picker-loading">Loading models</div>
@ -78,6 +82,33 @@
</div>
</details>
<!-- Node Selection -->
<div class="node-picker" v-if="llmNodes.length > 0">
<span class="node-picker-label">Nodes:</span>
<label
v-for="node in llmNodes"
:key="node.node_id"
class="node-chip"
:class="{ 'node-chip--off': !enabledNodes.has(node.node_id), 'node-chip--offline': !node.online }"
:title="node.online ? `${node.node_id} — ${node.gpus.length} GPU(s)` : `${node.node_id} — offline`"
>
<input
type="checkbox"
class="node-chip-check"
:checked="enabledNodes.has(node.node_id)"
:disabled="!node.online || llmRunning"
@change="toggleNode(node.node_id, ($event.target as HTMLInputElement).checked)"
/>
{{ node.node_id }}
<span class="node-chip-status" v-if="!node.online">offline</span>
</label>
<span class="node-picker-hint">
{{ enabledNodeIds.length === llmNodes.filter(n => n.online).length
? 'auto-routing (all nodes)'
: `restricted to: ${enabledNodeIds.join(', ')}` }}
</span>
</div>
<!-- Run Controls -->
<div class="run-controls">
<button
@ -88,6 +119,24 @@
{{ llmRunning ? '⏳ Running…' : '▶ Run LLM Eval' }}
</button>
<button v-if="llmRunning" class="btn-cancel" @click="cancelLlmBenchmark"> Cancel</button>
<input
v-model="llmJudgeUrl"
class="judge-url-input"
placeholder="Judge URL — leave empty to skip LLM judge scoring"
:disabled="llmRunning"
title="Optional: URL of a running cf-text service (e.g. http://10.1.10.158:8008). When set, each LLM response gets a secondary score from the judge model — adds a 'judge' column to results. Empty = primary quality scoring only."
/>
<label class="workers-label" title="Run this many models concurrently (requires multiple GPUs)">
<span class="workers-prefix">workers</span>
<input
v-model.number="llmWorkers"
type="number"
min="1"
max="8"
class="workers-input"
:disabled="llmRunning"
/>
</label>
<span v-if="selectedLlmTasks.size === 0 || selectedLlmModels.size === 0" class="run-hint">
Select at least one task and one model to run.
</span>
@ -119,6 +168,7 @@
<tr>
<th class="hm-label-col">Model</th>
<th class="hm-model-col">overall</th>
<th v-if="llmHasJudge" class="hm-model-col hm-judge-col">judge</th>
<th v-for="col in llmTaskTypeCols" :key="col" class="hm-model-col">{{ col }}</th>
<th class="hm-model-col">tok/s</th>
</tr>
@ -130,6 +180,12 @@
class="hm-value-cell"
:class="{ 'bt-best': llmBestByCol['overall'] === row.model_id }"
>{{ pct(row.avg_quality_score) }}</td>
<td
v-if="llmHasJudge"
class="hm-value-cell hm-judge-cell"
:class="{ 'bt-best': llmBestByCol['judge'] === row.model_id }"
title="LLM-as-judge secondary score"
>{{ row.avg_judge_score != null ? pct(row.avg_judge_score) : '—' }}</td>
<td
v-for="col in llmTaskTypeCols"
:key="col"
@ -168,6 +224,12 @@ interface CfOrchModel {
vram_estimate_mb?: number
}
interface CfOrchNode {
node_id: string
online: boolean
gpus: { gpu_id: number; name: string; vram_total_mb: number; vram_free_mb: number }[]
}
interface LlmModelResult {
model_name: string
model_id: string
@ -175,9 +237,11 @@ interface LlmModelResult {
avg_tokens_per_sec: number
avg_completion_ms: number
avg_quality_score: number
avg_judge_score: number | null
finetune_candidates: number
error_count: number
quality_by_task_type: Record<string, number>
judge_score_by_task_type?: Record<string, number>
}
// State
@ -195,6 +259,10 @@ const llmError = ref('')
const llmResults = ref<LlmModelResult[]>([])
const llmEventSource = ref<EventSource | null>(null)
const llmLogEl = ref<HTMLElement | null>(null)
const llmJudgeUrl = ref('')
const llmWorkers = ref(1)
const llmNodes = ref<CfOrchNode[]>([])
const enabledNodes = ref<Set<string>>(new Set())
// Computed
const llmTasksByType = computed((): Record<string, CfOrchTask[]> => {
@ -234,11 +302,19 @@ const llmModelBadge = computed(() => {
const llmTaskTypeCols = computed(() => {
const types = new Set<string>()
for (const r of llmResults.value) {
for (const k of Object.keys(r.quality_by_task_type)) types.add(k)
for (const k of Object.keys(r.quality_by_task_type ?? {})) types.add(k)
}
return [...types].sort()
})
const llmHasJudge = computed(() =>
llmResults.value.some(r => r.avg_judge_score != null)
)
const enabledNodeIds = computed(() =>
llmNodes.value.filter(n => n.online && enabledNodes.value.has(n.node_id)).map(n => n.node_id)
)
const llmBestByCol = computed((): Record<string, string> => {
const best: Record<string, string> = {}
if (llmResults.value.length === 0) return best
@ -249,10 +325,20 @@ const llmBestByCol = computed((): Record<string, string> => {
}
best['overall'] = bestId
if (llmHasJudge.value) {
bestId = ''; bestVal = -Infinity
for (const r of llmResults.value) {
if (r.avg_judge_score != null && r.avg_judge_score > bestVal) {
bestVal = r.avg_judge_score; bestId = r.model_id
}
}
best['judge'] = bestId
}
for (const col of llmTaskTypeCols.value) {
bestId = ''; bestVal = -Infinity
for (const r of llmResults.value) {
const v = r.quality_by_task_type[col]
const v = r.quality_by_task_type?.[col]
if (v != null && v > bestVal) { bestVal = v; bestId = r.model_id }
}
best[col] = bestId
@ -306,6 +392,15 @@ function toggleService(models: CfOrchModel[], checked: boolean) {
}
selectedLlmModels.value = next
}
function selectAllTasks() { selectedLlmTasks.value = new Set(llmTasks.value.map(t => t.id)) }
function clearAllTasks() { selectedLlmTasks.value = new Set() }
function selectAllModels() { selectedLlmModels.value = new Set(llmModels.value.map(m => m.id)) }
function clearAllModels() { selectedLlmModels.value = new Set() }
function toggleNode(id: string, checked: boolean) {
const next = new Set(enabledNodes.value)
checked ? next.add(id) : next.delete(id)
enabledNodes.value = next
}
// Data loaders
async function loadLlmTasks() {
@ -335,6 +430,21 @@ async function loadLlmResults() {
}
}
async function loadLlmConfig() {
const { data } = await useApiFetch<{ judge_url?: string }>('/api/cforch/config')
if (data?.judge_url && !llmJudgeUrl.value) {
llmJudgeUrl.value = data.judge_url
}
}
async function loadLlmNodes() {
const { data } = await useApiFetch<{ nodes: CfOrchNode[] }>('/api/cforch/nodes')
if (data?.nodes) {
llmNodes.value = data.nodes
enabledNodes.value = new Set(data.nodes.filter(n => n.online).map(n => n.node_id))
}
}
// Run / cancel
function startLlmBenchmark() {
llmRunning.value = true
@ -344,6 +454,15 @@ function startLlmBenchmark() {
const params = new URLSearchParams()
const taskIds = [...selectedLlmTasks.value].join(',')
if (taskIds) params.set('task_ids', taskIds)
const modelIds = [...selectedLlmModels.value].join(',')
if (modelIds) params.set('model_ids', modelIds)
if (llmJudgeUrl.value.trim()) params.set('judge_url', llmJudgeUrl.value.trim())
if (llmWorkers.value > 1) params.set('workers', String(llmWorkers.value))
const onlineNodeIds = llmNodes.value.filter(n => n.online).map(n => n.node_id)
const isRestricted = enabledNodeIds.value.length < onlineNodeIds.length
if (isRestricted && enabledNodeIds.value.length > 0) {
params.set('node_ids', enabledNodeIds.value.join(','))
}
const es = new EventSource(`/api/cforch/run?${params}`)
llmEventSource.value = es
@ -387,6 +506,8 @@ onMounted(() => {
loadLlmTasks()
loadLlmModels()
loadLlmResults()
loadLlmConfig()
loadLlmNodes()
})
</script>
@ -451,6 +572,43 @@ onMounted(() => {
color: var(--color-text-secondary, #6b7a99);
}
.judge-url-input {
flex: 1;
min-width: 14rem;
max-width: 24rem;
padding: 0.35rem 0.6rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2338);
font-size: 0.8rem;
font-family: var(--font-mono, monospace);
}
.judge-url-input:disabled { opacity: 0.5; }
.judge-url-input::placeholder { color: var(--color-text-secondary, #6b7a99); }
.workers-label {
display: flex;
align-items: center;
gap: 0.3rem;
font-size: 0.8rem;
color: var(--color-text-secondary, #6b7a99);
white-space: nowrap;
}
.workers-prefix { font-family: var(--font-mono, monospace); }
.workers-input {
width: 3.2rem;
padding: 0.35rem 0.4rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
background: var(--color-surface, #fff);
color: var(--color-text, #1a2338);
font-size: 0.8rem;
font-family: var(--font-mono, monospace);
text-align: center;
}
.workers-input:disabled { opacity: 0.5; }
/* ── Run log ────────────────────────────────────────────── */
.run-log {
border: 1px solid var(--color-border, #d0d7e8);
@ -592,6 +750,15 @@ onMounted(() => {
white-space: nowrap;
}
.hm-judge-col {
background: color-mix(in srgb, var(--color-surface-raised, #e4ebf5) 80%, #c6d5f5);
}
.hm-judge-cell {
background: color-mix(in srgb, var(--color-surface, #fff) 85%, #c6d5f5);
font-style: italic;
opacity: 0.9;
}
/* ── Model Picker ───────────────────────────────────────── */
.model-picker {
border: 1px solid var(--color-border, #d0d7e8);
@ -630,6 +797,24 @@ details[open] .picker-summary::before { content: '▼ '; }
margin-left: auto;
}
.picker-bulk-btn {
padding: 0.1rem 0.45rem;
font-size: 0.7rem;
font-family: var(--font-mono, monospace);
background: var(--color-surface, #fff);
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.25rem;
color: var(--color-text-secondary, #6b7a99);
cursor: pointer;
transition: background 0.12s, color 0.12s;
flex-shrink: 0;
}
.picker-bulk-btn:hover {
background: var(--app-primary, #2A6080);
color: #fff;
border-color: var(--app-primary, #2A6080);
}
.picker-body {
padding: 0.75rem;
border-top: 1px solid var(--color-border, #d0d7e8);
@ -712,4 +897,61 @@ details[open] .picker-summary::before { content: '▼ '; }
.picker-model-list { padding-left: 0; }
.picker-model-name { max-width: 14ch; }
}
/* ── Node picker ────────────────────────────────────── */
.node-picker {
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
padding: 0.5rem 0.75rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.5rem;
background: var(--color-surface-raised, #e4ebf5);
}
.node-picker-label {
font-size: 0.78rem;
font-weight: 600;
color: var(--color-text-secondary, #6b7a99);
text-transform: uppercase;
letter-spacing: 0.04em;
white-space: nowrap;
}
.node-chip {
display: inline-flex;
align-items: center;
gap: 0.3rem;
padding: 0.2rem 0.55rem;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 1rem;
background: var(--color-surface, #fff);
font-size: 0.78rem;
font-family: var(--font-mono, monospace);
color: var(--color-text, #1a2338);
cursor: pointer;
transition: background 0.12s, opacity 0.12s;
}
.node-chip--off {
opacity: 0.45;
background: transparent;
}
.node-chip--offline {
opacity: 0.35;
cursor: not-allowed;
font-style: italic;
}
.node-chip-check { accent-color: var(--app-primary, #2A6080); }
.node-chip-status {
font-size: 0.66rem;
color: var(--color-text-secondary, #6b7a99);
}
.node-picker-hint {
font-size: 0.72rem;
color: var(--color-text-secondary, #6b7a99);
font-family: var(--font-mono, monospace);
margin-left: auto;
}
</style>

View file

@ -2,6 +2,24 @@
<div class="models-view">
<h1 class="page-title">🤗 Models</h1>
<!-- Fleet tab bar -->
<div class="mode-toggle" role="group" aria-label="Fleet view">
<button
class="mode-btn"
:class="{ active: fleetTab === 'models' }"
@click="fleetTab = 'models'"
>Models</button>
<button
class="mode-btn"
:class="{ active: fleetTab === 'assignments' }"
@click="fleetTab = 'assignments'"
>Assignments</button>
</div>
<AssignmentsTab v-if="fleetTab === 'assignments'" />
<template v-if="fleetTab === 'models'">
<!-- 1. HF Lookup -->
<section class="section">
<h2 class="section-title">HuggingFace Lookup</h2>
@ -51,8 +69,31 @@
<span v-if="lookupResult.adapter_recommendation" class="chip chip-adapter">
{{ lookupResult.adapter_recommendation }}
</span>
<span v-if="lookupResult.size != null" class="preview-size">
{{ humanBytes(lookupResult.size) }}
<span v-if="selectedQuantSize > 0" class="preview-size">
{{ humanBytes(selectedQuantSize) }}
</span>
</div>
<!-- GGUF quantization picker only shown for GGUF repos -->
<div v-if="lookupResult.gguf_files?.length" class="quant-picker">
<label class="quant-label" for="quant-select">Quantization</label>
<select
id="quant-select"
v-model="selectedQuant"
class="quant-select"
aria-label="Select quantization variant"
>
<option :value="null" disabled>Select quantization</option>
<option
v-for="f in lookupResult.gguf_files"
:key="f.filename"
:value="f.quant_name ?? f.filename"
>
{{ f.quant_name ?? f.filename }} {{ humanBytes(f.size) }}
</option>
</select>
<span class="quant-hint">
Q5_K_M or Q6_K recommended for 8 GB GPUs. Q8_0 for max quality.
</span>
</div>
@ -67,7 +108,7 @@
<button
class="btn-primary btn-add-queue"
:disabled="lookupResult.already_installed || lookupResult.already_queued || addingToQueue"
:disabled="!canAddToQueue"
@click="addToQueue"
>
{{ addingToQueue ? 'Adding…' : 'Add to queue' }}
@ -99,9 +140,39 @@
<span v-if="model.role" class="chip chip-role">{{ model.role }}</span>
<span v-if="model.service" class="chip" :class="serviceChipClass(model.service)">{{ model.service }}</span>
<span v-if="model.adapter_recommendation" class="chip chip-adapter">{{ model.adapter_recommendation }}</span>
<span v-if="model.quant_pattern" class="chip chip-quant">{{ model.quant_pattern }}</span>
</div>
<!-- Allow manual service/role assignment for unrecognized pipeline tags -->
<div v-if="!model.service" class="classify-row queue-classify">
<select
class="classify-select"
:value="classifyDraft[model.id]?.service ?? ''"
@change="onServiceChange(model.id, ($event.target as HTMLSelectElement).value)"
aria-label="Assign service"
>
<option value="" disabled>Service</option>
<option v-for="svc in CLASSIFIABLE_SERVICES" :key="svc.value" :value="svc.value">{{ svc.label }}</option>
</select>
<select
class="classify-select"
:value="classifyDraft[model.id]?.role ?? ''"
:disabled="!classifyDraft[model.id]?.service"
@change="(e) => setClassifyRole(model.id, (e.target as HTMLSelectElement).value)"
aria-label="Assign role"
>
<option value="" disabled>Role</option>
<option
v-for="role in rolesForService(classifyDraft[model.id]?.service ?? '')"
:key="role"
:value="role"
>{{ role }}</option>
</select>
</div>
<div class="model-card-actions">
<button class="btn-primary btn-sm" @click="approveModel(model.id)">
<button
class="btn-primary btn-sm"
@click="approveModel(model.id, classifyDraft[model.id])"
>
Approve download
</button>
</div>
@ -244,14 +315,26 @@
</div>
</template>
</section>
</template><!-- end fleetTab === 'models' -->
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted, onUnmounted } from 'vue'
import AssignmentsTab from './AssignmentsTab.vue'
type FleetTab = 'models' | 'assignments'
const fleetTab = ref<FleetTab>('models')
// Type definitions
interface GgufFile {
filename: string
size: number
quant_name: string | null
}
interface LookupResult {
repo_id: string
pipeline_tag: string | null
@ -260,7 +343,8 @@ interface LookupResult {
service: string | null
compatible: boolean
warning: string | null
size: number | null
model_size_bytes: number
gguf_files: GgufFile[] | null
description: string | null
already_installed: boolean
already_queued: boolean
@ -274,6 +358,7 @@ interface QueuedModel {
adapter_recommendation: string | null
role: string | null
service: string | null
quant_pattern: string | null
}
interface InstalledModel {
@ -302,6 +387,26 @@ const lookupLoading = ref(false)
const lookupError = ref<string | null>(null)
const lookupResult = ref<LookupResult | null>(null)
const addingToQueue = ref(false)
const selectedQuant = ref<string | null>(null)
// Size of the selected GGUF file, or total model size for non-GGUF repos.
const selectedQuantSize = computed<number>(() => {
const r = lookupResult.value
if (!r) return 0
if (r.gguf_files?.length && selectedQuant.value) {
const f = r.gguf_files.find(f => (f.quant_name ?? f.filename) === selectedQuant.value)
return f?.size ?? r.model_size_bytes
}
return r.model_size_bytes
})
// Disable "Add to queue" when a GGUF repo but no quant chosen yet.
const canAddToQueue = computed(() => {
const r = lookupResult.value
if (!r || r.already_installed || r.already_queued || addingToQueue.value) return false
if (r.gguf_files?.length && !selectedQuant.value) return false
return true
})
const queuedModels = ref<QueuedModel[]>([])
const installedModels = ref<InstalledModel[]>([])
@ -411,6 +516,7 @@ async function doLookup() {
lookupLoading.value = true
lookupError.value = null
lookupResult.value = null
selectedQuant.value = null
try {
const res = await fetch(`/api/models/lookup?repo_id=${encodeURIComponent(repoId)}`)
@ -442,7 +548,15 @@ async function addToQueue() {
const res = await fetch('/api/models/queue', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ repo_id, pipeline_tag, adapter_recommendation, role, service }),
body: JSON.stringify({
repo_id,
pipeline_tag,
adapter_recommendation,
role,
service,
model_size_bytes: selectedQuantSize.value,
quant_pattern: selectedQuant.value,
}),
})
if (res.ok) {
lookupResult.value = { ...lookupResult.value, already_queued: true }
@ -454,8 +568,16 @@ async function addToQueue() {
}
}
async function approveModel(id: string) {
async function approveModel(id: string, draft?: { service: string; role: string }) {
try {
// If the user picked a service/role for an unrecognized model, patch it first.
if (draft?.service && draft?.role) {
await fetch(`/api/models/queue/${encodeURIComponent(id)}`, {
method: 'PATCH',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ service: draft.service, role: draft.role }),
})
}
const res = await fetch(`/api/models/queue/${encodeURIComponent(id)}/approve`, { method: 'POST' })
if (res.ok) {
await loadQueue()
@ -640,6 +762,39 @@ onUnmounted(() => {
color: var(--color-primary, #2d5a27);
}
/* ── Fleet tab bar (mode-toggle pattern from BenchmarkView) ── */
.mode-toggle {
display: inline-flex;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.5rem;
overflow: hidden;
align-self: flex-start;
}
.mode-btn {
padding: 0.4rem 1.1rem;
font-size: 0.85rem;
font-family: var(--font-body, sans-serif);
font-weight: 500;
border: none;
background: var(--color-surface, #fff);
color: var(--color-text-secondary, #6b7a99);
cursor: pointer;
transition: background 0.15s, color 0.15s;
}
.mode-btn:not(:last-child) {
border-right: 1px solid var(--color-border, #d0d7e8);
}
.mode-btn.active {
background: var(--app-primary, #2A6080);
color: #fff;
}
.mode-btn:not(.active):hover {
background: var(--color-surface-raised, #e4ebf5);
}
@media (max-width: 600px) {
.mode-btn { padding: 0.4rem 0.65rem; font-size: 0.78rem; }
}
/* ── Sections ── */
.section {
display: flex;
@ -774,6 +929,44 @@ onUnmounted(() => {
align-self: flex-start;
}
/* ── Quant picker ── */
.quant-picker {
display: flex;
flex-direction: column;
gap: 0.35rem;
}
.quant-label {
font-size: 0.8rem;
font-weight: 600;
color: var(--color-text-muted, #4a5c7a);
text-transform: uppercase;
letter-spacing: 0.04em;
}
.quant-select {
padding: 0.4rem 0.6rem;
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
background: var(--color-surface, #f0f4fb);
color: var(--color-text, #1a2338);
font-size: 0.9rem;
font-family: var(--font-mono, monospace);
cursor: pointer;
}
.quant-hint {
font-size: 0.78rem;
color: var(--color-text-muted, #4a5c7a);
}
.chip-quant {
background: color-mix(in srgb, var(--color-primary, #2A6080) 15%, transparent);
color: var(--color-primary, #2A6080);
font-family: var(--font-mono, monospace);
font-size: 0.75rem;
}
/* ── Model cards (queue + downloads) ── */
.model-card {
border: 1px solid var(--color-border, #a8b8d0);

View file

@ -0,0 +1,165 @@
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import NodeCard from '../components/nodes/NodeCard.vue'
import AssignmentsTab from './AssignmentsTab.vue'
import type { NodeSummary } from '../types/nodes'
type Tab = 'nodes' | 'assignments'
const activeTab = ref<Tab>('nodes')
const nodes = ref<NodeSummary[]>([])
const loading = ref(true)
const error = ref('')
async function fetchNodes() {
loading.value = true
error.value = ''
try {
const r = await fetch('/api/nodes-mgmt/nodes')
if (!r.ok) throw new Error(`HTTP ${r.status}`)
nodes.value = (await r.json()) as NodeSummary[]
} catch (e) {
error.value = e instanceof Error ? e.message : 'Failed to load nodes'
} finally {
loading.value = false
}
}
onMounted(fetchNodes)
</script>
<template>
<main class="fleet-page">
<header class="fleet-header">
<h1 class="fleet-title">Fleet</h1>
</header>
<!-- Tab bar -->
<nav class="tab-bar" role="tablist" aria-label="Fleet sections">
<button
id="tab-nodes"
role="tab"
:aria-selected="activeTab === 'nodes'"
:class="['tab', { active: activeTab === 'nodes' }]"
@click="activeTab = 'nodes'"
>Nodes</button>
<button
id="tab-assignments"
role="tab"
:aria-selected="activeTab === 'assignments'"
:class="['tab', { active: activeTab === 'assignments' }]"
@click="activeTab = 'assignments'"
>Assignments</button>
</nav>
<!-- Nodes tab -->
<section
v-if="activeTab === 'nodes'"
role="tabpanel"
aria-labelledby="tab-nodes"
class="tab-panel"
>
<div class="nodes-toolbar">
<button class="btn-secondary btn-sm" @click="fetchNodes" :disabled="loading">Refresh</button>
</div>
<div aria-live="polite" aria-atomic="true" class="sr-announce">
<span v-if="loading">Loading nodes...</span>
</div>
<div v-if="error" class="nodes-status nodes-error" role="alert">{{ error }}</div>
<div v-else-if="!loading && nodes.length === 0" class="nodes-status">
No nodes found. Check <code>coordinator_url</code> in config.
</div>
<div v-else-if="!loading" class="nodes-grid">
<NodeCard
v-for="node in nodes"
:key="node.node_id"
:node="node"
@updated="fetchNodes"
/>
</div>
</section>
<!-- Assignments tab -->
<section
v-else-if="activeTab === 'assignments'"
role="tabpanel"
aria-labelledby="tab-assignments"
class="tab-panel"
>
<AssignmentsTab />
</section>
</main>
</template>
<style scoped>
.fleet-page { padding: 1.5rem; }
.fleet-header {
margin-bottom: 1rem;
}
.fleet-title {
margin: 0;
font-size: 1.5rem;
color: var(--color-text);
}
/* ── Tab bar ── */
.tab-bar {
display: flex;
gap: 0;
border-bottom: 2px solid var(--color-border);
margin-bottom: 1.25rem;
}
.tab {
padding: 0.55rem 1.1rem;
font-size: 0.88rem;
font-weight: 600;
background: none;
border: none;
border-bottom: 2px solid transparent;
margin-bottom: -2px;
cursor: pointer;
color: var(--color-text-muted);
transition: color 0.15s, border-color 0.15s;
}
.tab:hover { color: var(--color-text); }
.tab.active {
color: var(--app-primary);
border-bottom-color: var(--app-primary);
}
/* ── Tab panel ── */
.tab-panel { min-height: 200px; }
/* ── Nodes toolbar ── */
.nodes-toolbar {
display: flex;
justify-content: flex-end;
margin-bottom: 1rem;
}
/* ── Nodes grid / status ── */
.nodes-grid { display: flex; flex-direction: column; gap: 1.5rem; }
.nodes-status {
color: var(--color-text-muted);
padding: 2rem;
text-align: center;
}
.nodes-error { color: var(--color-error); }
.sr-announce { min-height: 1.2em; }
/* ── Shared button ── */
.btn-secondary {
padding: 0.4rem 0.9rem;
background: var(--color-surface-alt);
border: 1px solid var(--color-border);
border-radius: 0.4rem;
font-size: 0.85rem;
color: var(--color-text);
cursor: pointer;
transition: background 0.15s;
}
.btn-secondary:hover:not(:disabled) { background: var(--color-surface-raised); }
.btn-secondary:disabled { opacity: 0.5; cursor: default; }
.btn-sm { padding: 0.3rem 0.65rem; font-size: 0.8rem; }
</style>

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,536 @@
<template>
<div class="rsv">
<!-- Header -->
<header class="rsv-header">
<h1 class="rsv-title">Recipe Scan Review</h1>
<div class="rsv-stats" v-if="stats">
<span class="stat-chip">{{ stats.by_status?.pending ?? 0 }} pending</span>
<span class="stat-chip stat-chip--ok">{{ stats.by_status?.approved ?? 0 }} approved</span>
<span class="stat-chip stat-chip--edited">{{ stats.by_status?.edited ?? 0 }} edited</span>
<span class="stat-chip stat-chip--bad">{{ stats.by_status?.rejected ?? 0 }} rejected</span>
<a
v-if="(stats.export_ready ?? 0) > 0"
:href="`${apiBase}/api/recipe-scan/export`"
download
class="btn-export"
>
Export {{ stats.export_ready }} pairs
</a>
</div>
</header>
<!-- Loading -->
<div v-if="loading" class="rsv-state" aria-label="Loading">
<div class="skeleton-block" />
</div>
<!-- Error -->
<div v-else-if="apiError" class="rsv-state rsv-error" role="alert">
<p>{{ apiError }}</p>
<button class="btn-action" @click="fetchNext">Retry</button>
</div>
<!-- Queue empty -->
<div v-else-if="!item" class="rsv-state rsv-empty">
<p>Queue is empty all items reviewed.</p>
<p class="rsv-hint">Import items from the Kiwi pipeline to continue.</p>
</div>
<!-- Review panel -->
<div v-else class="rsv-workspace">
<!-- Left: image -->
<section class="rsv-image-panel" aria-label="Scan image">
<div class="rsv-panel-label">
<span class="modality-badge">{{ item.modality }}</span>
<span class="source-badge">{{ item.source }}</span>
</div>
<div class="rsv-image-wrap">
<img
v-if="imageUrl"
:src="imageUrl"
:alt="`Recipe scan — ${item.source}`"
class="rsv-image"
/>
<div v-else class="rsv-image-placeholder">
<span>Image not available</span>
<code class="rsv-path">{{ item.image_path }}</code>
</div>
</div>
</section>
<!-- Right: JSON comparison -->
<section class="rsv-json-panel" aria-label="Extraction review">
<!-- Ground truth (read-only reference) -->
<div class="rsv-json-block">
<h2 class="rsv-json-label">Ground truth <span class="label-tag">reference</span></h2>
<pre class="rsv-json rsv-json--ground-truth" tabindex="0" aria-label="Ground truth JSON">{{ prettyJson(item.ground_truth) }}</pre>
</div>
<!-- Extracted / editable -->
<div class="rsv-json-block">
<h2 class="rsv-json-label">
Extracted
<span class="label-tag label-tag--edit">edit before approving</span>
</h2>
<textarea
v-model="draftJson"
class="rsv-json rsv-json--edit"
spellcheck="false"
aria-label="Extracted JSON — edit to correct"
:class="{ 'rsv-json--invalid': jsonError }"
/>
<p v-if="jsonError" class="rsv-json-error" role="alert">{{ jsonError }}</p>
</div>
<!-- Actions -->
<div class="rsv-actions" role="group" aria-label="Review actions">
<button
class="btn-approve"
:disabled="acting"
@click="handleApprove"
title="Extracted JSON is accurate — approve as-is (A)"
>
Approve
</button>
<button
class="btn-edit"
:disabled="acting || !!jsonError"
@click="handleEdit"
title="Approve the edited JSON in the text area (E)"
>
Approve edited
</button>
<button
class="btn-reject"
:disabled="acting"
@click="handleReject"
title="Extraction too broken to use — reject (R)"
>
Reject
</button>
</div>
</section>
</div>
<!-- Feedback toast -->
<Transition name="toast">
<div v-if="toast" class="rsv-toast" role="status" aria-live="polite">
{{ toast }}
</div>
</Transition>
</div>
</template>
<script setup lang="ts">
import { ref, computed, watch, onMounted, onUnmounted } from 'vue'
const apiBase = window.location.origin
interface RecipeScanItem {
id: string
image_path: string
modality: string
source: string
extracted: Record<string, unknown>
ground_truth: Record<string, unknown>
status: string
}
interface Stats {
total: number
by_status: Record<string, number>
by_modality: Record<string, number>
export_ready: number
}
const item = ref<RecipeScanItem | null>(null)
const stats = ref<Stats | null>(null)
const loading = ref(true)
const acting = ref(false)
const apiError = ref('')
const draftJson = ref('')
const toast = ref('')
let toastTimer: ReturnType<typeof setTimeout> | null = null
const jsonError = computed(() => {
if (!draftJson.value.trim()) return ''
try {
JSON.parse(draftJson.value)
return ''
} catch (e) {
return 'Invalid JSON — fix before approving'
}
})
const imageUrl = computed(() => {
if (!item.value) return ''
const encoded = encodeURIComponent(item.value.image_path)
return `${apiBase}/api/recipe-scan/image?path=${encoded}`
})
function prettyJson(obj: unknown): string {
return JSON.stringify(obj, null, 2)
}
function showToast(msg: string) {
toast.value = msg
if (toastTimer) clearTimeout(toastTimer)
toastTimer = setTimeout(() => { toast.value = '' }, 2500)
}
async function fetchNext() {
loading.value = true
apiError.value = ''
try {
const r = await fetch(`${apiBase}/api/recipe-scan/next`)
if (r.status === 404) {
item.value = null
} else if (!r.ok) {
throw new Error(`API error ${r.status}`)
} else {
item.value = await r.json()
draftJson.value = prettyJson(item.value!.extracted)
}
} catch (e) {
apiError.value = e instanceof Error ? e.message : 'Could not reach API'
} finally {
loading.value = false
}
}
async function fetchStats() {
try {
const r = await fetch(`${apiBase}/api/recipe-scan/stats`)
if (r.ok) stats.value = await r.json()
} catch { /* non-critical */ }
}
async function act(endpoint: string, body?: unknown) {
if (!item.value || acting.value) return
acting.value = true
try {
const r = await fetch(`${apiBase}/api/recipe-scan/items/${item.value.id}/${endpoint}`, {
method: 'POST',
headers: body ? { 'Content-Type': 'application/json' } : {},
body: body ? JSON.stringify(body) : undefined,
})
if (!r.ok) throw new Error(`API error ${r.status}`)
} catch (e) {
showToast(e instanceof Error ? e.message : 'Action failed')
acting.value = false
return
}
acting.value = false
await Promise.all([fetchNext(), fetchStats()])
}
async function handleApprove() {
showToast('Approved')
await act('approve')
}
async function handleEdit() {
if (jsonError.value) return
let corrected: unknown
try {
corrected = JSON.parse(draftJson.value)
} catch {
return
}
showToast('Saved edit')
await act('edit', { corrected })
}
async function handleReject() {
showToast('Rejected')
await act('reject')
}
// Keyboard shortcuts: A = approve, E = edit+approve, R = reject
function handleKey(e: KeyboardEvent) {
const tag = (e.target as HTMLElement)?.tagName?.toLowerCase()
if (tag === 'textarea' || tag === 'input') return
if (e.key === 'a' || e.key === 'A') handleApprove()
if (e.key === 'e' || e.key === 'E') handleEdit()
if (e.key === 'r' || e.key === 'R') handleReject()
}
watch(item, (newItem) => {
if (newItem) draftJson.value = prettyJson(newItem.extracted)
})
onMounted(() => {
fetchNext()
fetchStats()
window.addEventListener('keydown', handleKey)
})
onUnmounted(() => {
window.removeEventListener('keydown', handleKey)
if (toastTimer) clearTimeout(toastTimer)
})
</script>
<style scoped>
.rsv {
display: flex;
flex-direction: column;
height: 100%;
padding: var(--space-md, 1rem);
gap: var(--space-md, 1rem);
box-sizing: border-box;
overflow: hidden;
}
/* Header */
.rsv-header {
display: flex;
align-items: center;
gap: var(--space-md, 1rem);
flex-wrap: wrap;
}
.rsv-title {
font-size: 1.1rem;
font-weight: 600;
margin: 0;
color: var(--color-text, #fff);
}
.rsv-stats {
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
}
.stat-chip {
font-size: 0.75rem;
padding: 2px 8px;
border-radius: 12px;
background: var(--color-surface-alt, #2a2a2a);
color: var(--color-text-muted, #aaa);
}
.stat-chip--ok { background: #1a3a1a; color: #6fcf97; }
.stat-chip--edited { background: #2a2a00; color: #f2c94c; }
.stat-chip--bad { background: #3a1a1a; color: #eb5757; }
.btn-export {
font-size: 0.8rem;
padding: 4px 12px;
border-radius: 6px;
background: var(--color-accent, #4a9eff);
color: #fff;
text-decoration: none;
}
/* State panels */
.rsv-state {
flex: 1;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 0.5rem;
color: var(--color-text-muted, #aaa);
}
.rsv-error { color: var(--color-danger, #eb5757); }
.rsv-empty { font-size: 1rem; }
.rsv-hint { font-size: 0.85rem; opacity: 0.7; margin: 0; }
.skeleton-block {
width: 100%; height: 300px;
border-radius: 8px;
background: var(--color-surface-alt, #2a2a2a);
animation: pulse 1.5s ease-in-out infinite;
}
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
/* Workspace: two-column layout */
.rsv-workspace {
flex: 1;
display: grid;
grid-template-columns: 1fr 1fr;
gap: var(--space-md, 1rem);
min-height: 0;
overflow: hidden;
}
@media (max-width: 900px) {
.rsv-workspace {
grid-template-columns: 1fr;
overflow-y: auto;
}
}
/* Image panel */
.rsv-image-panel {
display: flex;
flex-direction: column;
gap: 0.5rem;
min-height: 0;
}
.rsv-panel-label {
display: flex;
gap: 0.5rem;
}
.modality-badge, .source-badge {
font-size: 0.72rem;
padding: 2px 8px;
border-radius: 10px;
background: var(--color-surface-alt, #2a2a2a);
color: var(--color-text-muted, #aaa);
text-transform: uppercase;
letter-spacing: 0.04em;
}
.rsv-image-wrap {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
background: var(--color-surface-alt, #111);
border-radius: 8px;
overflow: hidden;
min-height: 200px;
}
.rsv-image {
max-width: 100%;
max-height: 100%;
object-fit: contain;
}
.rsv-image-placeholder {
display: flex;
flex-direction: column;
align-items: center;
gap: 0.5rem;
color: var(--color-text-muted, #666);
font-size: 0.85rem;
padding: 1rem;
text-align: center;
}
.rsv-path {
font-size: 0.7rem;
word-break: break-all;
opacity: 0.6;
}
/* JSON panel */
.rsv-json-panel {
display: flex;
flex-direction: column;
gap: 0.5rem;
min-height: 0;
overflow-y: auto;
}
.rsv-json-block {
display: flex;
flex-direction: column;
gap: 0.25rem;
flex: 1;
min-height: 0;
}
.rsv-json-label {
font-size: 0.8rem;
font-weight: 600;
color: var(--color-text-muted, #aaa);
margin: 0;
display: flex;
align-items: center;
gap: 0.5rem;
}
.label-tag {
font-size: 0.68rem;
font-weight: 400;
padding: 1px 6px;
border-radius: 8px;
background: var(--color-surface-alt, #2a2a2a);
color: var(--color-text-muted, #888);
}
.label-tag--edit {
background: #2a2a00;
color: #f2c94c;
}
.rsv-json {
font-family: var(--font-mono, monospace);
font-size: 0.75rem;
line-height: 1.5;
padding: 0.75rem;
border-radius: 6px;
min-height: 120px;
flex: 1;
overflow-y: auto;
resize: vertical;
white-space: pre;
}
.rsv-json--ground-truth {
background: var(--color-surface-alt, #111);
color: var(--color-text, #ccc);
border: 1px solid var(--color-border, #333);
}
.rsv-json--edit {
background: var(--color-surface, #1a1a1a);
color: var(--color-text, #e0e0e0);
border: 1px solid var(--color-border, #444);
caret-color: var(--color-accent, #4a9eff);
outline: none;
transition: border-color 0.15s;
}
.rsv-json--edit:focus {
border-color: var(--color-accent, #4a9eff);
}
.rsv-json--invalid {
border-color: var(--color-danger, #eb5757) !important;
}
.rsv-json-error {
font-size: 0.75rem;
color: var(--color-danger, #eb5757);
margin: 0;
}
/* Action buttons */
.rsv-actions {
display: flex;
gap: 0.5rem;
padding-top: 0.25rem;
flex-wrap: wrap;
}
.btn-approve, .btn-edit, .btn-reject {
flex: 1;
min-width: 80px;
padding: 0.5rem 0.75rem;
border: none;
border-radius: 6px;
font-size: 0.85rem;
font-weight: 600;
cursor: pointer;
transition: opacity 0.15s;
}
.btn-approve, .btn-edit, .btn-reject {
opacity: 1;
}
.btn-approve:disabled, .btn-edit:disabled, .btn-reject:disabled {
opacity: 0.4;
cursor: default;
}
.btn-approve { background: #1e6e1e; color: #6fcf97; }
.btn-approve:hover:not(:disabled) { background: #256325; }
.btn-edit { background: #4a4a00; color: #f2c94c; }
.btn-edit:hover:not(:disabled) { background: #606000; }
.btn-reject { background: #6e1e1e; color: #eb8f8f; }
.btn-reject:hover:not(:disabled) { background: #7a2222; }
/* Toast */
.rsv-toast {
position: fixed;
bottom: 1.5rem;
left: 50%;
transform: translateX(-50%);
background: var(--color-surface, #222);
color: var(--color-text, #fff);
border: 1px solid var(--color-border, #444);
border-radius: 8px;
padding: 0.5rem 1.25rem;
font-size: 0.85rem;
box-shadow: 0 4px 20px rgba(0,0,0,0.4);
pointer-events: none;
z-index: 100;
}
.toast-enter-active, .toast-leave-active { transition: opacity 0.2s, transform 0.2s; }
.toast-enter-from, .toast-leave-to { opacity: 0; transform: translateX(-50%) translateY(8px); }
</style>

View file

@ -0,0 +1,161 @@
import { mount, flushPromises } from '@vue/test-utils'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import TrainJobsView from './TrainJobsView.vue'
const sampleJob = {
id: 'job-abc123',
type: 'classifier',
model_key: 'deberta-v3-small',
status: 'queued',
created_at: '2026-05-01T10:00:00Z',
config: null,
}
function makeFetch(jobs: unknown[] = []) {
return vi.fn().mockImplementation((url: string, opts?: RequestInit) => {
if ((opts?.method ?? 'GET') === 'POST') {
return Promise.resolve({
ok: true,
json: async () => ({ ...sampleJob, id: 'new-job', status: 'queued' }),
text: async () => '',
})
}
if ((opts?.method ?? 'GET') === 'DELETE') {
return Promise.resolve({ ok: true, json: async () => ({}), text: async () => '' })
}
// GET
return Promise.resolve({
ok: true,
json: async () => ({ jobs }),
text: async () => '',
})
})
}
class MockEventSource {
onmessage: ((e: MessageEvent) => void) | null = null
onerror: ((e: Event) => void) | null = null
private _url: string
constructor(url: string) { this._url = url }
close() {}
}
beforeEach(() => {
vi.stubGlobal('fetch', makeFetch([sampleJob]))
vi.stubGlobal('EventSource', MockEventSource)
})
describe('TrainJobsView', () => {
it('renders page title "Training Jobs"', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('h1.page-title').text()).toContain('Training Jobs')
})
it('renders the new job form with type selector and model key input', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('select.job-type-select').exists()).toBe(true)
expect(w.find('input.model-key-input').exists()).toBe(true)
expect(w.find('button.submit-job-btn').exists()).toBe(true)
})
it('type selector has classifier and llm-sft options', async () => {
const w = mount(TrainJobsView)
await flushPromises()
const options = w.findAll('select.job-type-select option')
const values = options.map(o => o.attributes('value') ?? o.element.textContent)
expect(values).toContain('classifier')
expect(values).toContain('llm-sft')
})
it('submit button is disabled when model key is empty', async () => {
const w = mount(TrainJobsView)
await flushPromises()
const btn = w.find('button.submit-job-btn')
expect((btn.element as HTMLButtonElement).disabled).toBe(true)
})
it('submit button is enabled when model key is entered', async () => {
const w = mount(TrainJobsView)
await flushPromises()
await w.find('input.model-key-input').setValue('deberta-v3-small')
const btn = w.find('button.submit-job-btn')
expect((btn.element as HTMLButtonElement).disabled).toBe(false)
})
it('shows job table with existing jobs', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('table.jobs-table').exists()).toBe(true)
expect(w.text()).toContain('deberta-v3-small')
})
it('shows status pill for each job', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('.status-pill').exists()).toBe(true)
expect(w.find('.status-queued').exists()).toBe(true)
})
it('shows cancel button for queued/running jobs', async () => {
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('button.cancel-btn').exists()).toBe(true)
})
it('submitting new job calls POST /api/train/jobs and refreshes', async () => {
const fetchMock = makeFetch([])
vi.stubGlobal('fetch', fetchMock)
const w = mount(TrainJobsView)
await flushPromises()
await w.find('input.model-key-input').setValue('my-model')
await w.find('button.submit-job-btn').trigger('click')
await flushPromises()
const calls = (fetchMock as ReturnType<typeof vi.fn>).mock.calls as [string, RequestInit?][]
const postCall = calls.find(([, opts]) => (opts?.method ?? 'GET') === 'POST')
expect(postCall).toBeDefined()
expect(postCall![0]).toContain('/api/train/jobs')
})
it('shows View Log button for running jobs', async () => {
vi.stubGlobal('fetch', makeFetch([{ ...sampleJob, status: 'running' }]))
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('button.view-log-btn').exists()).toBe(true)
})
it('shows error when config JSON is invalid', async () => {
const w = mount(TrainJobsView)
await flushPromises()
await w.find('input.model-key-input').setValue('my-model')
await w.find('textarea.config-textarea').setValue('{ not valid json }')
await w.find('button.submit-job-btn').trigger('click')
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
expect(w.find('.error-notice').text()).toContain('not valid')
})
it('shows error notice when jobs load fails', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: false,
status: 500,
json: async () => ({}),
text: async () => '',
}))
const w = mount(TrainJobsView)
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
expect(w.find('table.jobs-table').exists()).toBe(false)
})
it('cancel button optimistically updates job status to cancelled', async () => {
const w = mount(TrainJobsView)
await flushPromises()
await w.find('button.cancel-btn').trigger('click')
await flushPromises()
// After cancel, job should show status-cancelled pill (not status-queued)
expect(w.find('.status-cancelled').exists()).toBe(true)
expect(w.find('.status-queued').exists()).toBe(false)
})
})

View file

@ -0,0 +1,593 @@
<template>
<div class="train-jobs-view">
<header class="view-header">
<h1 class="page-title">🧠 Training Jobs</h1>
</header>
<!-- New Job form -->
<section class="section">
<h2 class="section-title">New Job</h2>
<form class="new-job-form" @submit.prevent="submitJob">
<div class="form-row">
<label class="form-label" for="job-type">Type</label>
<select
id="job-type"
v-model="form.type"
class="job-type-select form-control"
>
<option value="classifier">classifier</option>
<option value="llm-sft">llm-sft</option>
</select>
</div>
<div class="form-row">
<label class="form-label" for="model-key">Model key</label>
<input
id="model-key"
v-model.trim="form.model_key"
type="text"
class="model-key-input form-control"
placeholder="e.g. microsoft/deberta-v3-small"
autocomplete="off"
/>
</div>
<div class="form-row">
<label class="form-label" for="job-config">Config JSON <span class="form-hint">(optional)</span></label>
<textarea
id="job-config"
v-model="form.config_raw"
class="config-textarea form-control"
rows="4"
placeholder='{"learning_rate": 2e-5}'
/>
</div>
<div v-if="submitError" class="error-notice" role="alert">{{ submitError }}</div>
<button
type="submit"
class="submit-job-btn btn-primary"
:disabled="submitting || !form.model_key"
@click.prevent="submitJob"
>
{{ submitting ? 'Queuing…' : 'Queue Job' }}
</button>
</form>
</section>
<!-- Job queue table -->
<section class="section">
<h2 class="section-title">Job Queue</h2>
<div v-if="loadError" class="error-notice" role="alert">
{{ loadError }}
<button class="btn-retry" @click="loadJobs">Retry</button>
</div>
<div v-else-if="jobs.length === 0" class="empty-notice">
No training jobs yet.
</div>
<div v-else class="jobs-table-wrap">
<table class="jobs-table">
<thead>
<tr>
<th>ID</th>
<th>Type</th>
<th>Model</th>
<th>Status</th>
<th>Created</th>
<th></th>
</tr>
</thead>
<tbody>
<tr v-for="job in jobs" :key="job.id">
<td class="td-id" :title="job.id">{{ job.id.slice(0, 8) }}</td>
<td>
<span class="type-chip">{{ job.type }}</span>
</td>
<td class="td-model">{{ job.model_key }}</td>
<td>
<span class="status-pill" :class="`status-${job.status}`">{{ job.status }}</span>
</td>
<td class="td-date">{{ formatDate(job.created_at) }}</td>
<td class="td-actions">
<button
v-if="job.status === 'running'"
class="view-log-btn btn-sm"
@click="openLog(job.id)"
>
View Log
</button>
<button
v-if="job.status === 'queued' || job.status === 'running'"
class="cancel-btn btn-sm btn-danger-sm"
:disabled="cancellingId === job.id"
@click="cancelJob(job.id)"
>
{{ cancellingId === job.id ? '…' : 'Cancel' }}
</button>
</td>
</tr>
</tbody>
</table>
</div>
<div v-if="cancelError" class="error-notice" role="alert">{{ cancelError }}</div>
</section>
<!-- Log panel (SSE) -->
<section v-if="logJobId" class="section log-section">
<div class="log-header">
<h2 class="section-title">Log {{ logJobId.slice(0, 8) }}</h2>
<button class="btn-close-log" @click="closeLog"> Close</button>
</div>
<div class="log-panel" ref="logPanelEl">
<div
v-for="(line, i) in logLines"
:key="i"
class="log-line"
>{{ line }}</div>
<div v-if="logLines.length === 0" class="log-line log-muted">Connecting</div>
</div>
</section>
</div>
</template>
<script setup lang="ts">
import { ref, nextTick, onUnmounted } from 'vue'
import { useApiSSE } from '../composables/useApi'
interface TrainJob {
id: string
type: 'classifier' | 'llm-sft'
model_key: string
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'
created_at: string
config: Record<string, unknown> | null
}
const jobs = ref<TrainJob[]>([])
const loadError = ref<string | null>(null)
const submitError = ref<string | null>(null)
const submitting = ref(false)
const cancellingId = ref<string | null>(null)
const cancelError = ref<string | null>(null)
const form = ref({
type: 'classifier' as 'classifier' | 'llm-sft',
model_key: '',
config_raw: '',
})
// Log panel state
const logJobId = ref<string | null>(null)
const logLines = ref<string[]>([])
const logPanelEl = ref<HTMLElement | null>(null)
let closeSSE: (() => void) | null = null
// Data loading
async function loadJobs() {
loadError.value = null
try {
const res = await fetch('/api/train/jobs')
if (!res.ok) { loadError.value = `Failed to load jobs (HTTP ${res.status}).`; return }
const data = await res.json() as { jobs: TrainJob[] }
jobs.value = data.jobs ?? []
} catch {
loadError.value = 'Network error loading jobs.'
}
}
// Submit
async function submitJob() {
if (!form.value.model_key) return
submitError.value = null
submitting.value = true
let config: Record<string, unknown> | null = null
if (form.value.config_raw.trim()) {
try {
config = JSON.parse(form.value.config_raw) as Record<string, unknown>
} catch {
submitError.value = 'Config JSON is not valid. Fix it before submitting.'
submitting.value = false
return
}
}
try {
const res = await fetch('/api/train/jobs', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
type: form.value.type,
model_key: form.value.model_key,
config_json: config,
}),
})
if (!res.ok) {
const detail = await res.text().catch(() => '')
submitError.value = `Failed to queue job (HTTP ${res.status})${detail ? `: ${detail}` : '.'}`
return
}
const newJob = await res.json() as TrainJob
jobs.value = [newJob, ...jobs.value]
form.value = { type: 'classifier', model_key: '', config_raw: '' }
} catch {
submitError.value = 'Network error submitting job.'
} finally {
submitting.value = false
}
}
// Cancel
async function cancelJob(id: string) {
cancellingId.value = id
cancelError.value = null
try {
const res = await fetch(`/api/train/jobs/${encodeURIComponent(id)}/cancel`, { method: 'DELETE' })
if (res.ok) {
jobs.value = jobs.value.map(j =>
j.id === id ? { ...j, status: 'cancelled' as const } : j
)
} else {
cancelError.value = `Failed to cancel job (HTTP ${res.status}).`
}
} catch {
cancelError.value = 'Network error cancelling job.'
} finally {
cancellingId.value = null
}
}
// Log SSE
function openLog(id: string) {
closeLog()
logJobId.value = id
logLines.value = []
closeSSE = useApiSSE(
`/api/train/jobs/${encodeURIComponent(id)}/run`,
(data) => {
if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
logLines.value = [...logLines.value, String(data.message ?? '')]
nextTick(() => {
if (logPanelEl.value) {
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
}
})
}
if (data.type === 'error') {
logLines.value = [...logLines.value, '--- stream ended with error ---']
nextTick(() => {
if (logPanelEl.value) {
logPanelEl.value.scrollTop = logPanelEl.value.scrollHeight
}
})
}
},
() => {
logLines.value = [...logLines.value, '--- stream complete ---']
},
() => {
logLines.value = [...logLines.value, '--- connection lost ---']
},
)
}
function closeLog() {
closeSSE?.()
closeSSE = null
logJobId.value = null
logLines.value = []
}
// Helpers
function formatDate(iso: string): string {
const d = new Date(iso)
if (isNaN(d.getTime())) return iso
return d.toLocaleString(undefined, { dateStyle: 'short', timeStyle: 'short' })
}
// Lifecycle
loadJobs()
onUnmounted(() => {
closeSSE?.()
})
</script>
<style scoped>
.train-jobs-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 2rem;
}
.view-header { display: flex; align-items: center; }
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
}
.section {
display: flex;
flex-direction: column;
gap: 0.75rem;
}
.section-title {
font-size: 1rem;
font-weight: 600;
color: var(--color-text, #1a2338);
padding-bottom: 0.4rem;
border-bottom: 1px solid var(--color-border, #a8b8d0);
margin: 0;
}
.new-job-form {
display: flex;
flex-direction: column;
gap: 0.75rem;
max-width: 480px;
}
.form-row {
display: flex;
flex-direction: column;
gap: 0.3rem;
}
.form-label {
font-size: 0.85rem;
font-weight: 600;
color: var(--color-text-muted, #4a5c7a);
}
.form-hint {
font-weight: 400;
font-size: 0.78rem;
}
.form-control {
padding: 0.45rem 0.65rem;
border: 1px solid var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text, #1a2338);
font-size: 0.9rem;
font-family: var(--font-body, sans-serif);
}
.form-control:focus {
outline: 2px solid var(--app-primary, #2A6080);
outline-offset: -1px;
}
.config-textarea {
resize: vertical;
font-family: var(--font-mono, monospace);
font-size: 0.82rem;
}
.btn-primary {
padding: 0.4rem 0.9rem;
border-radius: var(--radius-md, 0.5rem);
border: 1px solid var(--app-primary, #2A6080);
background: var(--app-primary, #2A6080);
color: #fff;
font-size: 0.88rem;
font-family: var(--font-body, sans-serif);
cursor: pointer;
align-self: flex-start;
transition: opacity 0.15s;
}
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-primary:not(:disabled):hover { opacity: 0.85; }
.btn-sm {
padding: 0.2rem 0.55rem;
font-size: 0.78rem;
border-radius: 0.3rem;
cursor: pointer;
font-family: var(--font-body, sans-serif);
border: 1px solid;
transition: background 0.1s;
}
.view-log-btn {
border-color: var(--color-info, #1e6091);
background: transparent;
color: var(--color-info, #1e6091);
}
.view-log-btn:hover {
background: color-mix(in srgb, var(--color-info, #1e6091) 10%, transparent);
}
.btn-danger-sm {
border-color: var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
}
.btn-danger-sm:hover:not(:disabled) {
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
}
.btn-danger-sm:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-retry {
margin-left: 0.5rem;
padding: 0.2rem 0.55rem;
border-radius: 0.25rem;
border: 1px solid var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
cursor: pointer;
font-size: 0.82rem;
}
.error-notice {
padding: 0.6rem 0.8rem;
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
border-radius: var(--radius-md, 0.5rem);
color: var(--color-error, #c0392b);
font-size: 0.88rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.empty-notice {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
border: 1px dashed var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
}
.jobs-table-wrap { overflow-x: auto; }
.jobs-table {
width: 100%;
border-collapse: collapse;
font-size: 0.875rem;
}
.jobs-table th {
text-align: left;
padding: 0.4rem 0.6rem;
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text-muted, #4a5c7a);
font-size: 0.78rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.03em;
border-bottom: 1px solid var(--color-border, #a8b8d0);
white-space: nowrap;
}
.jobs-table td {
padding: 0.5rem 0.6rem;
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
vertical-align: middle;
}
.td-id {
font-family: var(--font-mono, monospace);
font-size: 0.78rem;
color: var(--color-text-muted, #4a5c7a);
}
.td-model {
font-family: var(--font-mono, monospace);
font-size: 0.82rem;
word-break: break-all;
}
.td-date {
font-size: 0.8rem;
color: var(--color-text-muted, #4a5c7a);
white-space: nowrap;
}
.td-actions {
display: flex;
gap: 0.35rem;
align-items: center;
flex-wrap: wrap;
}
.status-pill {
font-size: 0.68rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.04em;
padding: 0.15rem 0.45rem;
border-radius: var(--radius-full, 9999px);
white-space: nowrap;
}
.status-queued { background: var(--color-surface-alt, #dde4f0); color: var(--color-text-muted, #4a5c7a); }
.status-running { background: color-mix(in srgb, var(--color-info, #1e6091) 15%, transparent); color: var(--color-info, #1e6091); }
.status-completed { background: color-mix(in srgb, var(--color-success, #3a7a32) 15%, transparent); color: var(--color-success, #3a7a32); }
.status-failed { background: color-mix(in srgb, var(--color-error, #c0392b) 15%, transparent); color: var(--color-error, #c0392b); }
.status-cancelled { background: color-mix(in srgb, var(--color-warning, #d4891a) 15%, transparent); color: var(--color-warning, #d4891a); }
.type-chip {
font-size: 0.72rem;
font-family: var(--font-mono, monospace);
padding: 0.1rem 0.4rem;
border-radius: 0.25rem;
background: var(--color-surface-alt, #dde4f0);
color: var(--color-text, #1a2338);
}
.log-section { gap: 0.5rem; }
.log-header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 0.5rem;
}
.btn-close-log {
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.25rem;
cursor: pointer;
font-size: 0.8rem;
padding: 0.2rem 0.5rem;
color: var(--color-text-muted, #4a5c7a);
transition: background 0.1s;
}
.btn-close-log:hover { background: var(--color-surface-raised, #e4ebf5); }
.log-panel {
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.5rem;
max-height: 320px;
overflow-y: auto;
padding: 0.5rem 0.75rem;
background: var(--color-surface, #f0f4fc);
font-family: var(--font-mono, monospace);
font-size: 0.78rem;
}
.log-line {
color: var(--color-text, #1a2338);
line-height: 1.5;
white-space: pre-wrap;
word-break: break-all;
}
.log-muted { color: var(--color-text-muted, #4a5c7a); }
@media (max-width: 560px) {
.jobs-table th:nth-child(4),
.jobs-table td:nth-child(4),
.jobs-table th:nth-child(5),
.jobs-table td:nth-child(5) {
display: none;
}
}
</style>

View file

@ -0,0 +1,101 @@
import { mount, flushPromises } from '@vue/test-utils'
import { createRouter, createWebHashHistory } from 'vue-router'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import TrainResultsView from './TrainResultsView.vue'
const router = createRouter({
history: createWebHashHistory(),
routes: [
{ path: '/fleet', component: { template: '<div />' } },
],
})
const sampleResult = {
id: 'run-xyz',
job_id: 'job-abc123',
model_type: 'classifier',
base_model: 'microsoft/deberta-v3-small',
val_macro_f1: 0.847,
val_accuracy: 0.891,
sample_count: 1240,
duration_seconds: 842,
created_at: '2026-05-01T11:30:00Z',
}
function makeFetch(results: unknown[] = []) {
return vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ results }),
text: async () => '',
})
}
beforeEach(() => {
vi.stubGlobal('fetch', makeFetch([sampleResult]))
})
describe('TrainResultsView', () => {
it('renders page title "Training Results"', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('h1.page-title').text()).toContain('Training Results')
})
it('shows empty notice when there are no results', async () => {
vi.stubGlobal('fetch', makeFetch([]))
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.empty-notice').exists()).toBe(true)
})
it('renders results table when results exist', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('table.results-table').exists()).toBe(true)
})
it('shows base_model in table', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('deberta-v3-small')
})
it('shows val_macro_f1 formatted as percentage', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('84.7%')
})
it('shows val_accuracy formatted as percentage', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.text()).toContain('89.1%')
})
it('shows duration formatted as minutes and seconds', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
// 842 seconds = 14m 2s
expect(w.text()).toContain('14m')
})
it('shows Register in Fleet button for classifier results', async () => {
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('a.register-btn').exists()).toBe(true)
})
it('does NOT show Register in Fleet button for llm-sft results', async () => {
vi.stubGlobal('fetch', makeFetch([{ ...sampleResult, model_type: 'llm-sft' }]))
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('a.register-btn').exists()).toBe(false)
})
it('shows error notice when API fails', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false, status: 500, text: async () => '' }))
const w = mount(TrainResultsView, { global: { plugins: [router] } })
await flushPromises()
expect(w.find('.error-notice').exists()).toBe(true)
})
})

View file

@ -0,0 +1,296 @@
<template>
<div class="train-results-view">
<header class="view-header">
<h1 class="page-title">Training Results</h1>
<button class="refresh-btn" :disabled="loading" @click="loadResults" aria-label="Refresh">&#x1F504;</button>
</header>
<div v-if="error" class="error-notice" role="alert">
{{ error }}
<button class="btn-retry" @click="loadResults">Retry</button>
</div>
<div v-if="loading" class="loading-state" aria-live="polite">Loading</div>
<div v-if="!error && results.length === 0 && !loading" class="empty-notice">
No training results yet. Completed jobs will appear here.
</div>
<div v-if="results.length > 0" class="results-table-wrap">
<table class="results-table">
<thead>
<tr>
<th>Run</th>
<th>Type</th>
<th>Base Model</th>
<th class="th-numeric">Macro F1</th>
<th class="th-numeric">Accuracy</th>
<th class="th-numeric">Samples</th>
<th class="th-numeric">Duration</th>
<th></th>
</tr>
</thead>
<tbody>
<tr v-for="r in results" :key="r.id">
<td class="td-id" :title="r.id">{{ r.id.slice(0, 8) }}</td>
<td>
<span class="type-chip">{{ r.model_type }}</span>
</td>
<td class="td-model" :title="r.base_model">{{ shortModel(r.base_model) }}</td>
<td class="td-numeric">
<span class="metric-val" :class="scoreClass(r.val_macro_f1)">
{{ formatPct(r.val_macro_f1) }}
</span>
</td>
<td class="td-numeric">{{ formatPct(r.val_accuracy) }}</td>
<td class="td-numeric">{{ r.sample_count.toLocaleString() }}</td>
<td class="td-numeric">{{ formatDuration(r.duration_seconds) }}</td>
<td class="td-actions">
<RouterLink
v-if="r.model_type === 'classifier'"
:to="`/fleet?model=${encodeURIComponent(r.base_model)}`"
class="register-btn btn-sm-link"
>
Register in Fleet
</RouterLink>
</td>
</tr>
</tbody>
</table>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { RouterLink } from 'vue-router'
interface TrainResult {
id: string
job_id: string
model_type: string
base_model: string
val_macro_f1: number | null
val_accuracy: number | null
sample_count: number
duration_seconds: number | null
created_at: string
}
const results = ref<TrainResult[]>([])
const loading = ref(false)
const error = ref<string | null>(null)
async function loadResults() {
loading.value = true
error.value = null
try {
const res = await fetch('/api/train/results')
if (!res.ok) {
error.value = `Failed to load results (HTTP ${res.status}).`
return
}
const raw = await res.json() as { results?: TrainResult[] }
results.value = Array.isArray(raw?.results) ? raw.results : []
} catch {
error.value = 'Network error loading results.'
} finally {
loading.value = false
}
}
function formatPct(v: number | null | undefined): string {
if (v == null) return '—'
return `${(v * 100).toFixed(1)}%`
}
function formatDuration(seconds: number | null | undefined): string {
if (seconds == null) return '—'
const mins = Math.floor(seconds / 60)
const secs = seconds % 60
if (mins === 0) return `${secs}s`
return `${mins}m ${secs}s`
}
function shortModel(model: string): string {
const parts = model.split('/')
return parts[parts.length - 1] ?? model
}
function scoreClass(f1: number | null | undefined): string {
if (f1 == null) return ''
if (f1 >= 0.85) return 'score-great'
if (f1 >= 0.75) return 'score-good'
return 'score-fair'
}
onMounted(() => loadResults())
</script>
<style scoped>
.train-results-view {
max-width: 860px;
margin: 0 auto;
padding: 1.5rem 1rem 4rem;
display: flex;
flex-direction: column;
gap: 1.75rem;
}
.view-header {
display: flex;
align-items: center;
gap: 0.75rem;
}
.page-title {
font-family: var(--font-display, var(--font-body, sans-serif));
font-size: 1.4rem;
font-weight: 700;
color: var(--app-primary, #2A6080);
margin: 0;
flex: 1;
}
.refresh-btn {
background: transparent;
border: 1px solid var(--color-border, #d0d7e8);
border-radius: 0.375rem;
cursor: pointer;
font-size: 1rem;
padding: 0.3rem 0.5rem;
transition: background 0.15s;
}
.refresh-btn:hover:not(:disabled) { background: var(--color-surface-raised, #e4ebf5); }
.refresh-btn:disabled { opacity: 0.5; cursor: not-allowed; }
.error-notice {
display: flex;
align-items: center;
gap: 0.75rem;
padding: 0.75rem 1rem;
background: color-mix(in srgb, var(--color-error, #c0392b) 10%, transparent);
border: 1px solid color-mix(in srgb, var(--color-error, #c0392b) 30%, transparent);
border-radius: var(--radius-md, 0.5rem);
color: var(--color-error, #c0392b);
font-size: 0.88rem;
}
.btn-retry {
padding: 0.2rem 0.55rem;
border-radius: 0.25rem;
border: 1px solid var(--color-error, #c0392b);
background: transparent;
color: var(--color-error, #c0392b);
cursor: pointer;
font-size: 0.82rem;
}
.empty-notice {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
border: 1px dashed var(--color-border, #a8b8d0);
border-radius: var(--radius-md, 0.5rem);
}
.loading-state {
color: var(--color-text-muted, #4a5c7a);
font-size: 0.9rem;
padding: 0.75rem;
}
.results-table-wrap { overflow-x: auto; }
.results-table {
width: 100%;
border-collapse: collapse;
font-size: 0.875rem;
}
.results-table th {
text-align: left;
padding: 0.4rem 0.6rem;
background: var(--color-surface-raised, #f5f7fc);
color: var(--color-text-muted, #4a5c7a);
font-size: 0.78rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.03em;
border-bottom: 1px solid var(--color-border, #a8b8d0);
white-space: nowrap;
}
.th-numeric { text-align: right; }
.results-table td {
padding: 0.5rem 0.6rem;
border-bottom: 1px solid var(--color-border-light, #ccd5e6);
vertical-align: middle;
}
.td-id {
font-family: var(--font-mono, monospace);
font-size: 0.78rem;
color: var(--color-text-muted, #4a5c7a);
}
.td-model {
font-family: var(--font-mono, monospace);
font-size: 0.82rem;
max-width: 16ch;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.td-numeric {
text-align: right;
font-family: var(--font-mono, monospace);
font-variant-numeric: tabular-nums;
font-size: 0.82rem;
}
.td-actions { text-align: right; }
.metric-val { font-weight: 600; }
.score-great { color: var(--color-success, #3a7a32); }
.score-good { color: var(--color-warning, #d4891a); }
.score-fair { color: var(--color-text-muted, #4a5c7a); }
.type-chip {
font-size: 0.72rem;
font-family: var(--font-mono, monospace);
padding: 0.1rem 0.4rem;
border-radius: 0.25rem;
background: var(--color-surface-alt, #dde4f0);
color: var(--color-text, #1a2338);
}
.btn-sm-link {
font-size: 0.78rem;
padding: 0.2rem 0.55rem;
border-radius: 0.3rem;
border: 1px solid var(--app-primary, #2A6080);
color: var(--app-primary, #2A6080);
background: transparent;
text-decoration: none;
white-space: nowrap;
display: inline-block;
transition: background 0.1s;
}
.btn-sm-link:hover {
background: color-mix(in srgb, var(--app-primary, #2A6080) 10%, transparent);
}
@media (max-width: 600px) {
.results-table th:nth-child(6),
.results-table td:nth-child(6),
.results-table th:nth-child(7),
.results-table td:nth-child(7) {
display: none;
}
}
</style>