- BenchmarkView.vue: convert from monolithic view to tabbed shell; each tab is now its own component (ClassifierTab, CompareTab, LlmEvalTab, StyleTab, VoiceTab) - StyleTab + VoiceTab: new benchmark modes for style and voice model evaluation - app/style.py: FastAPI router for style imitation benchmarks - app/voice.py: FastAPI router for voice benchmark endpoints - scripts/benchmark_style.py + benchmark_voice.py: headless runner scripts
427 lines
17 KiB
Python
427 lines
17 KiB
Python
"""Avocet — Writing style benchmark integration API.
|
|
|
|
Wraps scripts/benchmark_style.py and exposes it via the Avocet API.
|
|
Connection config (coordinator_url, ollama_url, python_bin) is read
|
|
from label_tool.yaml under the `cforch:` key — the same block used
|
|
by cforch.py, so no new config section is needed.
|
|
|
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
|
api.py includes this router with prefix="/api/style".
|
|
|
|
Module-level globals (_BENCH_RUNNING, _bench_proc) follow the same
|
|
testability pattern as cforch.py.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import subprocess as _subprocess
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
import yaml
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Parent of the directory that contains this file.
_ROOT = Path(__file__).parent.parent

# Optional config-directory override; tests set it via set_config_dir().
_CONFIG_DIR: Path | None = None

# Process-wide run state, read and written by the /run and /cancel handlers.
_BENCH_RUNNING: bool = False
_bench_proc: Any = None

# Runner script spawned by /run, and the directory style_*.json results are read from.
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_style.py"
_RESULTS_DIR = _ROOT / "benchmark_results"

# Router on which every endpoint in this module is registered.
router = APIRouter()
|
|
|
|
|
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
|
|
|
def set_config_dir(path: Path | None) -> None:
    """Override the directory label_tool.yaml is read from.

    Intended as a test seam. Passing ``None`` clears the override so the
    default location is used again.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
|
|
|
|
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
|
|
|
def _config_file() -> Path:
    """Return the path to label_tool.yaml.

    Uses the directory set through set_config_dir() when one is set,
    otherwise ``<root>/config``.
    """
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
|
|
|
|
|
|
def _load_config() -> dict:
    """Read connection settings from the ``cforch:`` section of label_tool.yaml.

    Returns:
        A dict with exactly three keys: "coordinator_url", "ollama_url" and
        "python_bin". A key absent from the file falls back to its built-in
        default.

    A missing, unreadable, malformed or unexpectedly shaped config file is
    never fatal: the problem is logged as a warning and defaults are used.
    """
    defaults = {
        "coordinator_url": "http://10.1.10.71:7700",
        "ollama_url": "http://localhost:11434",
        "python_bin": "/devl/miniconda3/envs/cf/bin/python",
    }

    path = _config_file()
    section: dict = {}
    if path.exists():
        try:
            raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse style config %s: %s", path, exc)
        except (OSError, UnicodeDecodeError) as exc:
            # The file can vanish or be unreadable between exists() and the read.
            logger.warning("Failed to read style config %s: %s", path, exc)
        else:
            # Valid YAML may still have a list/scalar at the top level or
            # under `cforch`; calling .get() on those would raise.
            candidate = raw.get("cforch") if isinstance(raw, dict) else None
            if isinstance(candidate, dict):
                section = candidate
            elif candidate or not isinstance(raw, dict):
                logger.warning(
                    "Ignoring style config %s: expected a mapping under 'cforch'", path
                )

    return {key: section.get(key, fallback) for key, fallback in defaults.items()}
|
|
|
|
|
|
# ── GET /models ────────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/models")
def get_models() -> dict:
    """List the models a style benchmark can target, grouped by source.

    Returns a dict with two lists:

    - "ollama": one entry per named model reported by the Ollama
      ``/api/tags`` endpoint, queried on every call.
    - "cf_text": one entry per dict-valued item in the coordinator's
      cf-text catalog.

    Each lookup is best-effort: a failure is logged as a warning and whatever
    was collected for that source so far (possibly nothing) is returned.
    """
    settings = _load_config()
    bytes_per_mb = 1024 * 1024

    # Ollama: live query of the tags endpoint.
    from_ollama: list[dict] = []
    try:
        tags = httpx.get(f"{settings['ollama_url']}/api/tags", timeout=5.0)
        tags.raise_for_status()
        for item in tags.json().get("models", []):
            model_name = item.get("name", "")
            if not model_name:
                continue
            raw_size = item.get("size", 0)
            # A missing or zero size is reported as None rather than 0 MB.
            size_mb = round(raw_size / bytes_per_mb) if raw_size else None
            from_ollama.append({
                "id": model_name,
                "name": model_name,
                "source": "ollama",
                "size_mb": size_mb,
                "vram_mb": None,
            })
    except Exception as exc:
        logger.warning("Failed to fetch ollama models: %s", exc)

    # cf-text: catalog served by the cf-orch coordinator.
    from_catalog: list[dict] = []
    try:
        catalog = httpx.get(
            f"{settings['coordinator_url']}/api/services/cf-text/catalog",
            timeout=5.0,
        )
        catalog.raise_for_status()
        for catalog_id, details in catalog.json().items():
            if not isinstance(details, dict):
                continue
            from_catalog.append({
                "id": catalog_id,
                "name": catalog_id,
                "source": "cf-text",
                "vram_mb": details.get("vram_mb"),
                "description": details.get("description", ""),
            })
    except Exception as exc:
        logger.warning("Failed to fetch cf-text catalog: %s", exc)

    return {"ollama": from_ollama, "cf_text": from_catalog}
|
|
|
|
|
|
# ── GET /run ───────────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/run")
def run_style_benchmark(
    models: str = Query("", description="Comma-separated model IDs (empty = all)"),
    use_cforch: bool = Query(False),
    max_vram: int = Query(7200, description="Max VRAM MB for cf-orch OOM filter"),
    include_large: bool = Query(False, description="Include large (30B+) ollama models"),
    workers: int = Query(1, description="Parallel workers — run N models simultaneously"),
) -> StreamingResponse:
    """Spawn benchmark_style.py and stream its output as SSE events.

    Each event is a ``data: <json>`` frame whose ``type`` is one of:

    - ``progress``: one non-empty line of the script's combined stdout/stderr.
    - ``result``: parsed JSON of the newest style_*.json file, sent after a
      zero exit code when such a file exists and can be read.
    - ``complete``: sent after a zero exit code.
    - ``error``: script missing, non-zero exit code, or an exception while
      spawning or streaming.

    Raises:
        HTTPException: 409 if a run is already flagged as in progress.
    """
    global _BENCH_RUNNING, _bench_proc

    if _BENCH_RUNNING:
        raise HTTPException(409, "A writing style benchmark is already running")

    cfg = _load_config()
    python_bin = cfg["python_bin"]

    def _sse(payload: dict) -> str:
        """Format one server-sent-event frame."""
        return f"data: {json.dumps(payload)}\n\n"

    def _stop(proc: Any) -> None:
        """Make sure the child is gone and its pipe is closed."""
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except _subprocess.TimeoutExpired:
                # SIGTERM was ignored; force the issue so no orphan is left.
                proc.kill()
                proc.wait()
        if proc.stdout is not None:
            proc.stdout.close()

    def generate():
        global _BENCH_RUNNING, _bench_proc

        if not _BENCH_SCRIPT.exists():
            yield _sse({
                "type": "error",
                "message": f"benchmark_style.py not found at {_BENCH_SCRIPT}",
            })
            return

        cmd = [python_bin, str(_BENCH_SCRIPT), "run"]
        if models:
            cmd.extend(["--models", ",".join(m.strip() for m in models.split(",") if m.strip())])
        if use_cforch:
            cmd.extend(["--cforch", "--cforch-url", cfg["coordinator_url"],
                        "--max-vram", str(max_vram)])
        if include_large:
            cmd.append("--include-large")
        if workers > 1:
            cmd.extend(["--workers", str(workers)])

        _BENCH_RUNNING = True
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,
                text=True,
                bufsize=1,
                cwd=str(_ROOT),
            )
            _bench_proc = proc
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield _sse({"type": "progress", "message": line})
                proc.wait()
                if proc.returncode == 0:
                    result_files = sorted(_RESULTS_DIR.glob("style_*.json"))
                    if result_files:
                        newest = result_files[-1]
                        try:
                            results = json.loads(newest.read_text(encoding="utf-8"))
                            yield _sse({
                                "type": "result",
                                "results": results,
                                "filename": newest.name,
                            })
                        except Exception as exc:
                            logger.warning("Failed to read style results: %s", exc)
                    yield _sse({"type": "complete"})
                else:
                    yield _sse({
                        "type": "error",
                        "message": f"Process exited with code {proc.returncode}",
                    })
            finally:
                # Also reached when the generator is closed before the script
                # finishes (e.g. the consumer stops reading). Without this the
                # child would keep running with nobody reading its output.
                try:
                    _stop(proc)
                except Exception as exc:
                    logger.warning("Failed to clean up style benchmark process: %s", exc)
                _bench_proc = None
        except Exception as exc:
            yield _sse({"type": "error", "message": str(exc)})
        finally:
            _BENCH_RUNNING = False

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
# ── GET /results ───────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/results")
def list_results() -> list[dict]:
    """List past writing style benchmark runs, newest first.

    Returns:
        One summary dict per style_*.json file in the results directory, with
        keys "filename", "date", "model_count" and "top_score". An empty list
        is returned when the directory does not exist.

    File names are expected to look like ``style_2026-04-22_1502.json``; the
    date is shown as ``2026-04-22 15:02``. Names that do not split into
    exactly two ``_``-separated parts are shown unformatted. A file that
    cannot be read or parsed is still listed, with zero count and score.
    Use /results/{filename} to fetch full model-level detail.
    """
    if not _RESULTS_DIR.exists():
        return []

    runs: list[dict] = []
    for path in sorted(_RESULTS_DIR.glob("style_*.json"), reverse=True):
        stamp = path.stem.removeprefix("style_")
        try:
            date_part, time_part = stamp.split("_")
            display_date = f"{date_part} {time_part[:2]}:{time_part[2:]}"
        except ValueError:
            display_date = stamp

        try:
            results = json.loads(path.read_text(encoding="utf-8"))
            # Ignore null / non-numeric scores: max() over a single such
            # value would hand it straight to round() below and raise.
            scores = [
                score
                for score in (entry.get("avg_score", 0) for entry in results)
                if isinstance(score, (int, float))
            ]
            top_score = max(scores, default=0)
            model_count = len(results)
        except Exception:
            # Best-effort listing: one unreadable file must not hide the rest.
            top_score = 0
            model_count = 0

        runs.append({
            "filename": path.name,
            "date": display_date,
            "model_count": model_count,
            "top_score": round(top_score, 1),
        })

    return runs
|
|
|
|
|
|
@router.get("/results/latest")
def get_latest_results() -> list[dict]:
    """Return the parsed contents of the newest style_*.json result file.

    "Newest" is the last file name in sorted order.

    Raises:
        HTTPException: 404 when the results directory is missing or holds no
            style_*.json file; 500 when the file cannot be read or parsed.
    """
    candidates = sorted(_RESULTS_DIR.glob("style_*.json")) if _RESULTS_DIR.exists() else []
    if not candidates:
        raise HTTPException(404, "No benchmark results found")

    newest = candidates[-1]
    try:
        return json.loads(newest.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
|
|
|
|
|
@router.get("/results/{filename}")
def get_results_by_filename(filename: str) -> list[dict]:
    """Return writing style benchmark results for a specific run file.

    Args:
        filename: Bare file name of the run, of the form ``style_*.json``.

    Raises:
        HTTPException: 400 when the name does not match ``style_*.json`` or
            contains a path separator; 404 when the file does not exist;
            500 when it cannot be read or parsed.
    """
    well_formed = filename.startswith("style_") and filename.endswith(".json")
    # The prefix/suffix test alone would accept names such as
    # "style_x/../../other.json"; only a bare file name may be looked up.
    has_separator = "/" in filename or "\\" in filename
    if not well_formed or has_separator:
        raise HTTPException(400, "Invalid filename — expected style_*.json")

    f = _RESULTS_DIR / filename
    if not f.exists():
        raise HTTPException(404, f"Results file not found: {filename}")
    try:
        return json.loads(f.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
|
|
|
|
|
# ── POST /send-to-corrections ──────────────────────────────────────────────────
|
|
|
|
class SendToCorrectionsRequest(BaseModel):
    """Request body for POST /send-to-corrections."""

    # Source run file, named style_YYYY-MM-DD_HHMM.json.
    filename: str
    # Models to import; an empty list selects every model in the run.
    model_ids: list[str] = []
|
|
|
|
|
|
@router.post("/send-to-corrections")
def send_to_corrections(req: SendToCorrectionsRequest) -> dict:
    """Push writing style benchmark outputs into the SFT corrections queue.

    Each prompt result of the selected models becomes one candidate with
    status 'needs_review', appended to ``data/sft_candidates.jsonl``. The
    candidate id is a UUIDv5 derived from run id + model id + prompt tag, so
    importing the same run again adds nothing new.

    Args:
        req: Source run file name and an optional model filter (an empty
            ``model_ids`` selects every model in the run).

    Returns:
        ``{"imported": <candidates written>, "skipped": <duplicates skipped>}``.

    Raises:
        HTTPException: 400 for a file name that is not a bare ``style_*.json``
            name; 404 when the run file does not exist; 500 when it cannot be
            read, parsed, or is not a JSON list.
    """
    name = req.filename
    well_formed = name.startswith("style_") and name.endswith(".json")
    # The name comes straight from the request body, so path separators must
    # be rejected to keep the lookup inside the results directory.
    if not well_formed or "/" in name or "\\" in name:
        raise HTTPException(400, "Invalid filename")

    src = _RESULTS_DIR / name
    if not src.exists():
        raise HTTPException(404, f"Results file not found: {req.filename}")

    try:
        run_results = json.loads(src.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
    if not isinstance(run_results, list):
        raise HTTPException(500, "Failed to read results: expected a JSON list")

    # Resolve sft_candidates.jsonl path (same logic as sft.py)
    sft_data_dir = _ROOT / "data"
    sft_file = sft_data_dir / "sft_candidates.jsonl"

    # Load existing IDs to deduplicate; malformed lines are ignored.
    existing_ids: set[str] = set()
    if sft_file.exists():
        for line in sft_file.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                existing_ids.add(json.loads(line)["id"])
            except (ValueError, KeyError, TypeError):
                continue

    run_id = name.removesuffix(".json")  # style_2026-04-22_1502
    timestamp = datetime.now(tz=timezone.utc).isoformat()

    new_candidates: list[dict] = []
    skipped = 0
    for model_result in run_results:
        if not isinstance(model_result, dict):
            continue
        model_id = model_result.get("model_id", "")
        if req.model_ids and model_id not in req.model_ids:
            continue

        for pr in model_result.get("prompt_results") or []:
            if not isinstance(pr, dict):
                continue
            tag = pr.get("tag") or ""

            # Stable id: deterministic hash of run + model + prompt tag
            candidate_id = str(uuid.uuid5(
                uuid.NAMESPACE_URL,
                f"style-benchmark/{run_id}/{model_id}/{tag}",
            ))
            if candidate_id in existing_ids:
                skipped += 1
                continue

            # A null or non-numeric score is stored as 0.0 instead of raising.
            raw_score = pr.get("score", 0.0)
            score_pct = raw_score / 100.0 if isinstance(raw_score, (int, float)) else 0.0
            signals = pr.get("signals") or {}

            # Build the prompt message list matching the benchmark's actual request
            prompt_messages = [
                {"role": "system", "content": _STYLE_SYSTEM_PROMPT},
                {"role": "user", "content": pr.get("user_prompt", tag)},
            ]

            new_candidates.append({
                "id": candidate_id,
                "source": "style-benchmark",
                "benchmark_run_id": run_id,
                "timestamp": timestamp,
                "status": "needs_review",
                "prompt_messages": prompt_messages,
                "model_response": pr.get("output", ""),
                "corrected_response": None,
                "quality_score": round(score_pct, 4),
                "failure_reason": _build_failure_reason(pr, signals),
                "failure_category": None,
                "task_id": f"style/{tag}",
                "task_type": "style-match",
                "task_name": tag.replace("_", " ").title(),
                "model_id": model_id,
                "model_name": model_id,
                "node_id": "",
                "gpu_id": 0,
                "tokens_per_sec": 0,
            })
            # Also guards against the same tag appearing twice within this run.
            existing_ids.add(candidate_id)

    if new_candidates:
        sft_data_dir.mkdir(parents=True, exist_ok=True)
        with open(sft_file, "a", encoding="utf-8") as fh:
            fh.writelines(json.dumps(c) + "\n" for c in new_candidates)

    return {"imported": len(new_candidates), "skipped": skipped}
|
|
|
|
|
|
# System prompt stored with each SFT candidate so the record captures the
# generation context. Per the original note this is an excerpt of the prompt
# used in benchmark_style.py — keep the two in sync (not verifiable from here).
_STYLE_SYSTEM_PROMPT = "\n".join((
    "You are a writing assistant. Your job is to write a Reddit reply that matches "
    "the voice, tone, and style of the provided samples exactly.",
    "",
    "Voice characteristics:",
    "- Casual engineer tone. Short punchy sentences.",
    "- No em dashes. No semicolons. No filler phrases.",
    "- Direct. Opinionated. Community-first.",
))
|
|
|
|
|
|
def _build_failure_reason(pr: dict, signals: dict) -> str | None:
|
|
"""Return a human-readable failure reason string if there are violations."""
|
|
reasons = []
|
|
if signals.get("em_dash_count", 0) > 0:
|
|
reasons.append(f"{signals['em_dash_count']} em dash(es)")
|
|
if signals.get("semicolon_count", 0) > 0:
|
|
reasons.append(f"{signals['semicolon_count']} semicolon(s)")
|
|
if signals.get("filler_hits"):
|
|
reasons.append(f"filler phrases: {', '.join(signals['filler_hits'])}")
|
|
if not pr.get("output", "").strip():
|
|
reasons.append("empty output")
|
|
return "; ".join(reasons) if reasons else None
|
|
|
|
|
|
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
|
|
|
@router.post("/cancel")
def cancel_style_benchmark() -> dict:
    """Terminate the running writing style benchmark subprocess.

    The run state is cleared even when terminating fails; the failure is
    only logged.

    Raises:
        HTTPException: 404 when no benchmark is flagged as running.
    """
    global _BENCH_RUNNING, _bench_proc

    if not _BENCH_RUNNING:
        raise HTTPException(404, "No writing style benchmark is currently running")

    active = _bench_proc
    if active is not None:
        try:
            active.terminate()
        except Exception as exc:
            logger.warning("Failed to terminate style benchmark: %s", exc)

    _BENCH_RUNNING, _bench_proc = False, None
    return {"status": "cancelled"}
|