avocet/app/style.py
pyr0ball ddb56efb89 refactor(bench): extract benchmark tabs — classifier, compare, llm-eval, style, voice
- BenchmarkView.vue: convert from monolithic view to tabbed shell; each tab is
  now its own component (ClassifierTab, CompareTab, LlmEvalTab, StyleTab, VoiceTab)
- StyleTab + VoiceTab: new benchmark modes for style and voice model evaluation
- app/style.py: FastAPI router for style imitation benchmarks
- app/voice.py: FastAPI router for voice benchmark endpoints
- scripts/benchmark_style.py + benchmark_voice.py: headless runner scripts
2026-04-24 14:56:17 -07:00

427 lines
17 KiB
Python

"""Avocet — Writing style benchmark integration API.
Wraps scripts/benchmark_style.py and exposes it via the Avocet API.
Connection config (coordinator_url, ollama_url, python_bin) is read
from label_tool.yaml under the `cforch:` key — the same block used
by cforch.py, so no new config section is needed.
All endpoints are registered on `router` (a FastAPI APIRouter).
api.py includes this router with prefix="/api/style".
Module-level globals (_BENCH_RUNNING, _bench_proc) follow the same
testability pattern as cforch.py.
"""
from __future__ import annotations
import json
import logging
import subprocess as _subprocess
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import httpx
import yaml
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# Module logger; used below for non-fatal warnings (config parse failures,
# unreachable model sources, unreadable result files).
logger = logging.getLogger(__name__)
# Directory two levels above this file; the script, results, config and data
# paths used in this module are all built from it.
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
# True while a benchmark stream is active; /run answers 409 when it is set.
_BENCH_RUNNING: bool = False
# Handle of the running benchmark subprocess (None when idle); /cancel
# terminates it.
_bench_proc: Any = None
# Headless runner script spawned by /run.
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_style.py"
# Directory scanned for style_*.json result files.
_RESULTS_DIR = _ROOT / "benchmark_results"
# Router that every endpoint in this module registers on.
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
    """Redirect config lookup to *path*; pass ``None`` to restore the default.

    Test seam: while an override is set, ``label_tool.yaml`` is read from
    *path* instead of the ``config`` directory under the project root.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
    """Return the path of ``label_tool.yaml``, honouring the test override."""
    # The override (when set) replaces the whole <root>/config directory.
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
def _load_config() -> dict:
    """Read the ``cforch:`` section of label_tool.yaml, filling in defaults.

    Loading is best-effort: a missing, unreadable or malformed file, or a
    file whose top level / ``cforch`` value is not a mapping, is logged and
    treated as "no overrides" instead of raising.

    Returns:
        A dict with exactly the keys ``coordinator_url``, ``ollama_url`` and
        ``python_bin``. A key that is absent or explicitly ``null`` in the
        file takes its built-in default.
    """
    defaults = {
        "coordinator_url": "http://10.1.10.71:7700",
        "ollama_url": "http://localhost:11434",
        "python_bin": "/devl/miniconda3/envs/cf/bin/python",
    }
    config_path = _config_file()
    section: dict = {}
    if config_path.exists():
        try:
            raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        except (OSError, yaml.YAMLError) as exc:
            logger.warning("Failed to parse style config %s: %s", config_path, exc)
        else:
            # safe_load may return any YAML type (None for an empty file, a
            # list, a scalar); only a mapping can carry a cforch section.
            candidate = raw.get("cforch") if isinstance(raw, dict) else None
            if isinstance(candidate, dict):
                section = candidate
            elif (raw is not None and not isinstance(raw, dict)) or candidate:
                logger.warning(
                    "Ignoring style config %s: expected a mapping with a 'cforch' mapping",
                    config_path,
                )
    resolved: dict = {}
    for key, default in defaults.items():
        value = section.get(key)
        # An explicit null is treated like a missing key so callers never
        # receive None for a URL or interpreter path.
        resolved[key] = default if value is None else value
    return resolved
# ── GET /models ────────────────────────────────────────────────────────────────
@router.get("/models")
def get_models() -> dict:
    """List the models that can be benchmarked, grouped by where they live.

    - ``ollama``: queried live from Ollama's ``/api/tags`` so freshly
      downloaded models show up without any extra step.
    - ``cf_text``: read from the cf-orch coordinator's cf-text catalog
      endpoint.

    Each source is best-effort: if its request or parsing fails, a warning
    is logged and whatever was collected for that source so far is returned.
    """
    settings = _load_config()

    # -- Ollama: live tag listing -------------------------------------------
    from_ollama: list[dict] = []
    try:
        tags_resp = httpx.get(f"{settings['ollama_url']}/api/tags", timeout=5.0)
        tags_resp.raise_for_status()
        for item in tags_resp.json().get("models", []):
            model_name = item.get("name", "")
            if not model_name:
                continue  # unnamed entries cannot be selected, skip them
            raw_size = item.get("size", 0)
            # Size is reported in bytes; a missing/zero size becomes None.
            size_mb = round(raw_size / (1024 * 1024)) if raw_size else None
            from_ollama.append({
                "id": model_name,
                "name": model_name,
                "source": "ollama",
                "size_mb": size_mb,
                "vram_mb": None,
            })
    except Exception as exc:
        logger.warning("Failed to fetch ollama models: %s", exc)

    # -- cf-text: catalog served by the cf-orch coordinator -----------------
    from_cftext: list[dict] = []
    try:
        catalog_resp = httpx.get(
            f"{settings['coordinator_url']}/api/services/cf-text/catalog",
            timeout=5.0,
        )
        catalog_resp.raise_for_status()
        for catalog_id, details in catalog_resp.json().items():
            if not isinstance(details, dict):
                continue  # only mapping entries describe a model
            from_cftext.append({
                "id": catalog_id,
                "name": catalog_id,
                "source": "cf-text",
                "vram_mb": details.get("vram_mb"),
                "description": details.get("description", ""),
            })
    except Exception as exc:
        logger.warning("Failed to fetch cf-text catalog: %s", exc)

    return {"ollama": from_ollama, "cf_text": from_cftext}
# ── GET /run ───────────────────────────────────────────────────────────────────
@router.get("/run")
def run_style_benchmark(
    models: str = Query("", description="Comma-separated model IDs (empty = all)"),
    use_cforch: bool = Query(False),
    max_vram: int = Query(7200, description="Max VRAM MB for cf-orch OOM filter"),
    include_large: bool = Query(False, description="Include large (30B+) ollama models"),
    workers: int = Query(1, description="Parallel workers — run N models simultaneously"),
) -> StreamingResponse:
    """Spawn benchmark_style.py and stream its output as Server-Sent Events.

    Every non-empty line the script prints (stdout and stderr are merged) is
    forwarded as a ``type: progress`` event. When the script exits with code
    0, the ``style_*.json`` file that sorts last by name is sent as a
    ``type: result`` event, followed by ``type: complete``. A missing script,
    a spawn failure or a non-zero exit code produces a ``type: error`` event.

    If the stream is abandoned before the script finishes (for example the
    client disconnects), the child process is terminated so it cannot
    outlive the "running" flag.

    Raises:
        HTTPException: 409 if a benchmark is already running.
    """
    global _BENCH_RUNNING, _bench_proc
    if _BENCH_RUNNING:
        raise HTTPException(409, "A writing style benchmark is already running")
    cfg = _load_config()
    python_bin = cfg["python_bin"]

    def sse(event: dict) -> str:
        """Frame *event* as one SSE ``data:`` message."""
        return f"data: {json.dumps(event)}\n\n"

    def generate():
        global _BENCH_RUNNING, _bench_proc
        if not _BENCH_SCRIPT.exists():
            yield sse({
                "type": "error",
                "message": f"benchmark_style.py not found at {_BENCH_SCRIPT}",
            })
            return

        cmd = [python_bin, str(_BENCH_SCRIPT), "run"]
        if models:
            # Normalise "a, b,,c" -> "a,b,c" before handing it to the script.
            cmd.extend(["--models", ",".join(m.strip() for m in models.split(",") if m.strip())])
        if use_cforch:
            cmd.extend(["--cforch", "--cforch-url", cfg["coordinator_url"],
                        "--max-vram", str(max_vram)])
        if include_large:
            cmd.append("--include-large")
        if workers > 1:
            cmd.extend(["--workers", str(workers)])

        _BENCH_RUNNING = True
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,
                text=True,
                bufsize=1,  # line-buffered so progress is streamed promptly
                cwd=str(_ROOT),
            )
            _bench_proc = proc
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield sse({"type": "progress", "message": line})
                proc.wait()
                if proc.returncode == 0:
                    result_files = sorted(_RESULTS_DIR.glob("style_*.json"))
                    if result_files:
                        try:
                            results = json.loads(result_files[-1].read_text(encoding="utf-8"))
                            yield sse({
                                "type": "result",
                                "results": results,
                                "filename": result_files[-1].name,
                            })
                        except Exception as exc:
                            logger.warning("Failed to read style results: %s", exc)
                    yield sse({"type": "complete"})
                else:
                    yield sse({
                        "type": "error",
                        "message": f"Process exited with code {proc.returncode}",
                    })
            finally:
                # Reached on normal completion, on an error, and when the
                # generator is closed early. In the last two cases the child
                # may still be running: stop it and release the pipe so it
                # cannot keep running after the flags below are cleared.
                # Cleanup must not raise, because this may execute while the
                # generator is being closed.
                try:
                    if proc.poll() is None:
                        proc.terminate()
                        try:
                            proc.wait(timeout=5)
                        except _subprocess.TimeoutExpired:
                            proc.kill()
                            proc.wait()
                    if proc.stdout is not None:
                        proc.stdout.close()
                except Exception as exc:
                    logger.warning("Failed to clean up style benchmark process: %s", exc)
                _bench_proc = None
        except Exception as exc:
            yield sse({"type": "error", "message": str(exc)})
        finally:
            _BENCH_RUNNING = False

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
# ── GET /results ───────────────────────────────────────────────────────────────
@router.get("/results")
def list_results() -> list[dict]:
    """List past writing style benchmark runs, newest first.

    Returns one lightweight summary per ``style_*.json`` file: ``filename``,
    a display ``date``, ``model_count`` and ``top_score`` (highest
    ``avg_score`` in the run, rounded to one decimal). Files are ordered by
    name, descending. Use /results/{filename} for full model-level detail.

    A file that cannot be read, parsed or summarised is still listed, with
    ``model_count`` and ``top_score`` of 0, so one bad file never fails the
    whole listing.
    """
    if not _RESULTS_DIR.exists():
        return []
    runs: list[dict] = []
    for f in sorted(_RESULTS_DIR.glob("style_*.json"), reverse=True):
        date_str = f.stem.removeprefix("style_")  # e.g. 2026-04-22_1502
        try:
            date_part, time_part = date_str.split("_")
            display_date = f"{date_part} {time_part[:2]}:{time_part[2:]}"
        except Exception:
            # Unexpected name shape: show the raw suffix instead.
            display_date = date_str
        try:
            results = json.loads(f.read_text(encoding="utf-8"))
            # Rounding happens inside the try so a non-numeric avg_score
            # degrades to the fallback instead of raising out of the handler;
            # a null avg_score counts as 0.
            top_score = round(
                max((r.get("avg_score") or 0 for r in results), default=0),
                1,
            )
            model_count = len(results)
        except Exception:
            top_score = 0
            model_count = 0
        runs.append({
            "filename": f.name,
            "date": display_date,
            "model_count": model_count,
            "top_score": top_score,
        })
    return runs
@router.get("/results/latest")
def get_latest_results() -> list[dict]:
    """Return the parsed contents of the newest writing style result file.

    "Newest" is the ``style_*.json`` file that sorts last by name.

    Raises:
        HTTPException: 404 when the results directory is missing or holds no
            matching file; 500 when the file cannot be read or parsed.
    """
    candidates = sorted(_RESULTS_DIR.glob("style_*.json")) if _RESULTS_DIR.exists() else []
    if not candidates:
        raise HTTPException(404, "No benchmark results found")
    newest = candidates[-1]
    try:
        return json.loads(newest.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
@router.get("/results/{filename}")
def get_results_by_filename(filename: str) -> list[dict]:
    """Return writing style benchmark results for a specific run file.

    Args:
        filename: Bare file name of the form ``style_*.json``. Names that
            contain a path separator are rejected so the lookup can never
            leave the results directory.

    Raises:
        HTTPException: 400 for a malformed name, 404 when the file does not
            exist, 500 when it cannot be read or parsed.
    """
    well_formed = filename.startswith("style_") and filename.endswith(".json")
    # The name comes straight from the request; refuse anything that could be
    # interpreted as a path rather than a single file name.
    if not well_formed or "/" in filename or "\\" in filename:
        raise HTTPException(400, "Invalid filename — expected style_*.json")
    f = _RESULTS_DIR / filename
    if not f.exists():
        raise HTTPException(404, f"Results file not found: {filename}")
    try:
        return json.loads(f.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
# ── POST /send-to-corrections ──────────────────────────────────────────────────
# Request body for POST /send-to-corrections: names the benchmark run to
# import and, optionally, which of its models to include.
class SendToCorrectionsRequest(BaseModel):
    filename: str # style_YYYY-MM-DD_HHMM.json — the source run file
    # Only results whose model_id appears in this list are imported.
    model_ids: list[str] = [] # empty = all models in the run
@router.post("/send-to-corrections")
def send_to_corrections(req: SendToCorrectionsRequest) -> dict:
    """Push writing style benchmark outputs into the SFT corrections queue.

    Each prompt result of the selected models becomes one SFT candidate with
    status ``needs_review``, appended to ``data/sft_candidates.jsonl``.
    Candidate ids are deterministic (UUIDv5 of run id + model id + prompt
    tag), so importing the same run again adds nothing new.

    Args:
        req: Source run file name and optional model filter.

    Returns:
        ``{"imported": <candidates written>, "skipped": <candidates that
        were already in the queue>}``. Models excluded by the filter are not
        counted as skipped.

    Raises:
        HTTPException: 400 for a malformed file name, 404 when the run file
            does not exist, 500 when it cannot be read or parsed.
    """
    name = req.filename
    # The name comes from the request body; accept a bare style_*.json file
    # name only, never something containing a path separator.
    if (
        not name.startswith("style_")
        or not name.endswith(".json")
        or "/" in name
        or "\\" in name
    ):
        raise HTTPException(400, "Invalid filename")
    src = _RESULTS_DIR / name
    if not src.exists():
        raise HTTPException(404, f"Results file not found: {req.filename}")
    try:
        run_results: list[dict] = json.loads(src.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc

    # Resolve sft_candidates.jsonl path (same logic as sft.py)
    sft_data_dir = _ROOT / "data"
    sft_file = sft_data_dir / "sft_candidates.jsonl"

    # Load existing IDs to deduplicate
    existing_ids: set[str] = set()
    if sft_file.exists():
        for line in sft_file.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                existing_ids.add(json.loads(line)["id"])
            except (ValueError, KeyError, TypeError):
                # Malformed or id-less queue lines cannot collide with a new
                # candidate, so they are deliberately ignored.
                continue

    run_id = name.removesuffix(".json")  # style_2026-04-22_1502
    timestamp = datetime.now(tz=timezone.utc).isoformat()
    new_candidates: list[dict] = []
    skipped = 0
    for model_result in run_results:
        model_id = model_result.get("model_id", "")
        if req.model_ids and model_id not in req.model_ids:
            continue
        for pr in model_result.get("prompt_results") or []:
            tag = pr.get("tag", "")
            # Stable id: deterministic hash of run + model + prompt tag
            candidate_id = str(uuid.uuid5(
                uuid.NAMESPACE_URL,
                f"style-benchmark/{run_id}/{model_id}/{tag}",
            ))
            if candidate_id in existing_ids:
                skipped += 1
                continue
            # Scores are stored on a 0-100 scale; a missing or null score
            # counts as 0.
            score_pct = (pr.get("score") or 0.0) / 100.0
            signals = pr.get("signals") or {}
            # Build the prompt message list matching the benchmark's actual request
            prompt_messages = [
                {"role": "system", "content": _STYLE_SYSTEM_PROMPT},
                {"role": "user", "content": pr.get("user_prompt", tag)},
            ]
            new_candidates.append({
                "id": candidate_id,
                "source": "style-benchmark",
                "benchmark_run_id": run_id,
                "timestamp": timestamp,
                "status": "needs_review",
                "prompt_messages": prompt_messages,
                "model_response": pr.get("output", ""),
                "corrected_response": None,
                "quality_score": round(score_pct, 4),
                "failure_reason": _build_failure_reason(pr, signals),
                "failure_category": None,
                "task_id": f"style/{tag}",
                "task_type": "style-match",
                "task_name": tag.replace("_", " ").title(),
                "model_id": model_id,
                "model_name": model_id,
                "node_id": "",
                "gpu_id": 0,
                "tokens_per_sec": 0,
            })
            # Also guards against the same (model, tag) pair appearing twice
            # within this run file.
            existing_ids.add(candidate_id)

    if new_candidates:
        sft_data_dir.mkdir(parents=True, exist_ok=True)
        with open(sft_file, "a", encoding="utf-8") as fh:
            for c in new_candidates:
                fh.write(json.dumps(c) + "\n")
    return {"imported": len(new_candidates), "skipped": skipped}
# Excerpt of the system prompt used in benchmark_style.py, reproduced here so
# that every SFT candidate built by send_to_corrections() records a system
# message alongside the user prompt.
# NOTE(review): this is a hand-maintained copy; confirm it still matches
# benchmark_style.py whenever that script's prompt changes.
_STYLE_SYSTEM_PROMPT = (
    "You are a writing assistant. Your job is to write a Reddit reply that matches "
    "the voice, tone, and style of the provided samples exactly.\n\n"
    "Voice characteristics:\n"
    "- Casual engineer tone. Short punchy sentences.\n"
    "- No em dashes. No semicolons. No filler phrases.\n"
    "- Direct. Opinionated. Community-first."
)
def _build_failure_reason(pr: dict, signals: dict) -> str | None:
    """Summarise style violations of one prompt result as a readable string.

    Args:
        pr: A prompt result; only its ``output`` is inspected. A missing,
            ``None`` or whitespace-only output is reported as "empty output".
        signals: Style signals for that output (``em_dash_count``,
            ``semicolon_count``, ``filler_hits``). ``None`` or missing/null
            entries are treated as "no violation".

    Returns:
        The violations joined with ``"; "``, or ``None`` when there are none.
    """
    signals = signals or {}
    reasons: list[str] = []

    em_dashes = signals.get("em_dash_count") or 0
    if em_dashes > 0:
        reasons.append(f"{em_dashes} em dash(es)")

    semicolons = signals.get("semicolon_count") or 0
    if semicolons > 0:
        reasons.append(f"{semicolons} semicolon(s)")

    filler_hits = signals.get("filler_hits")
    if filler_hits:
        # str() keeps the join from failing on a non-string hit.
        reasons.append(f"filler phrases: {', '.join(str(hit) for hit in filler_hits)}")

    # `or ""` covers a result that stored output as null rather than "".
    if not (pr.get("output") or "").strip():
        reasons.append("empty output")

    return "; ".join(reasons) if reasons else None
# ── POST /cancel ───────────────────────────────────────────────────────────────
@router.post("/cancel")
def cancel_style_benchmark() -> dict:
    """Stop the writing style benchmark that is currently running.

    Asks the tracked subprocess (if any) to terminate, then clears the
    module-level running state. A failure to terminate is logged, not raised.

    Raises:
        HTTPException: 404 when no benchmark is running.
    """
    global _BENCH_RUNNING, _bench_proc
    if not _BENCH_RUNNING:
        raise HTTPException(404, "No writing style benchmark is currently running")

    active_proc = _bench_proc
    if active_proc is not None:
        try:
            active_proc.terminate()
        except Exception as exc:
            logger.warning("Failed to terminate style benchmark: %s", exc)

    # Reset state whether or not terminate() succeeded.
    _BENCH_RUNNING, _bench_proc = False, None
    return {"status": "cancelled"}