fix: restore real plans_bench.py (was accidentally stubbed)
This commit is contained in:
parent
bccb385f61
commit
d432026fd7
1 changed file with 306 additions and 13 deletions
|
|
@ -1,30 +1,323 @@
|
||||||
"""Avocet -- Plans benchmark integration API (stub).
|
"""Avocet — CF planning benchmark integration API.
|
||||||
|
|
||||||
Placeholder module so that app/eval/cforch.py can import and include
|
Wraps scripts/benchmark_plans.py and exposes it via the Avocet API.
|
||||||
this router. Full implementation follows in a subsequent task.
|
Connection config (coordinator_url, python_bin) is read from label_tool.yaml under the
|
||||||
|
`cforch:` or `plans_bench:` key (both optional; falls back to built-in defaults).
|
||||||
|
|
||||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
All endpoints are registered on `router` (FastAPI APIRouter).
|
||||||
api.py (via the eval aggregator) includes this router at
|
api.py includes this router with prefix="/api/plans-bench".
|
||||||
prefix="/api/plans-bench".
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import subprocess as _subprocess
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter
|
import httpx
|
||||||
|
import yaml
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).parent.parent
|
||||||
|
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||||
|
_BENCH_RUNNING: bool = False
|
||||||
|
_bench_proc: Any = None
|
||||||
|
|
||||||
|
_BENCH_SCRIPT = _ROOT / "scripts" / "benchmark_plans.py"
|
||||||
|
_RESULTS_DIR = _ROOT / "data" / "plans_bench_results"
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
# ── Registered model shortcuts (mirrors benchmark_plans.MODEL_REGISTRY) ────────
|
||||||
|
# Kept here so the UI can list them without importing the script.
|
||||||
|
|
||||||
|
MODEL_REGISTRY: dict[str, str] = {
|
||||||
|
"llama3.2-3b": "Llama 3.2 3B Instruct (local via cf-text)",
|
||||||
|
"llama3.2-1b": "Llama 3.2 1B Instruct (local via cf-text)",
|
||||||
|
"mistral-7b": "Mistral 7B Instruct (local via cf-text)",
|
||||||
|
"phi3-mini": "Phi-3 Mini 3.8B (local via cf-text)",
|
||||||
|
"qwen2.5-3b": "Qwen 2.5 3B Instruct (local via cf-text)",
|
||||||
|
}
|
||||||
|
|
||||||
|
RUBRIC_LABELS: dict[str, str] = {
|
||||||
|
"task_structure": "Task structure (checkboxes + commits)",
|
||||||
|
"tier_awareness": "Tier awareness (Free/Paid/Premium/Ultra)",
|
||||||
|
"privacy_pillar": "Privacy pillar (local-first, no logging)",
|
||||||
|
"safety_pillar": "Safety pillar (human approval, reversibility)",
|
||||||
|
"accessibility": "Accessibility (ND/adaptive users)",
|
||||||
|
"license_split": "License awareness (MIT vs BSL)",
|
||||||
|
"file_paths": "File paths (plausible project paths)",
|
||||||
|
"cf_conventions": "CF conventions (conda, manage.sh, /Library/…)",
|
||||||
|
"length_ok": "Response length (200–2500 words)",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Testability seam ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def set_config_dir(path: Path | None) -> None:
|
def set_config_dir(path: Path | None) -> None:
|
||||||
"""Override config directory -- used by tests."""
|
|
||||||
global _CONFIG_DIR
|
global _CONFIG_DIR
|
||||||
_CONFIG_DIR = path
|
_CONFIG_DIR = path
|
||||||
|
|
||||||
|
|
||||||
@router.get("/status")
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||||
def get_plans_bench_status() -> dict:
|
|
||||||
"""Return placeholder status for the plans benchmark module."""
|
def _config_file() -> Path:
    """Return the path of label_tool.yaml.

    Uses the override directory set via ``set_config_dir()`` when one is
    active, otherwise the ``config`` directory under the project root.
    """
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config() -> dict:
    """Load connection settings for the planning benchmark from label_tool.yaml.

    ``coordinator_url`` and ``python_bin`` are looked up first under the
    ``cforch:`` section, then under ``plans_bench:``, and finally fall back to
    built-in defaults. Loading is best-effort: an unreadable or malformed file,
    a non-mapping document root, or a non-mapping section is logged (where
    applicable) and treated as empty rather than raised to the caller.

    Returns:
        A dict with exactly the keys ``coordinator_url`` and ``python_bin``.
    """
    f = _config_file()
    raw: Any = {}
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
            logger.warning("Failed to parse plans_bench config %s: %s", f, exc)
            raw = {}

    if not isinstance(raw, dict):
        # e.g. a YAML list or scalar at the document root -- nothing usable.
        logger.warning("Ignoring plans_bench config %s: top level is not a mapping", f)
        raw = {}

    def section(name: str) -> dict:
        # A missing, null, or non-mapping section behaves like an empty one.
        value = raw.get(name)
        return value if isinstance(value, dict) else {}

    cforch_cfg = section("cforch")
    bench_cfg = section("plans_bench")

    return {
        "coordinator_url": cforch_cfg.get(
            "coordinator_url",
            bench_cfg.get("coordinator_url", "http://10.1.10.71:7700"),
        ),
        "python_bin": cforch_cfg.get(
            "python_bin",
            bench_cfg.get("python_bin", "/devl/miniconda3/envs/cf/bin/python"),
        ),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _results_file(run_id: str) -> Path:
    """Return the JSON results path for *run_id* inside the results directory."""
    filename = f"{run_id}.json"
    return _RESULTS_DIR / filename
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/models")
def get_models() -> dict:
    """Return registered model shortcuts, the live cf-orch catalog, and rubric labels.

    The catalog lookup is best-effort: any failure while fetching or decoding
    it is logged and ``cforch_models`` is returned as an empty list.
    """
    coordinator_url = _load_config()["coordinator_url"]
    catalog_url = f"{coordinator_url}/api/services/cf-text/catalog"

    cforch_models: list[dict] = []
    try:
        response = httpx.get(catalog_url, timeout=5.0)
        response.raise_for_status()
        # Only catalog entries that are mappings are surfaced to the UI.
        cforch_models = [
            {
                "id": model_id,
                "name": model_id,
                "vram_mb": entry.get("vram_mb"),
                "description": entry.get("description", ""),
            }
            for model_id, entry in response.json().items()
            if isinstance(entry, dict)
        ]
    except Exception as exc:
        logger.warning("Failed to fetch cf-orch catalog: %s", exc)

    registry = [
        {"key": key, "description": description}
        for key, description in MODEL_REGISTRY.items()
    ]
    return {
        "registry": registry,
        "cforch_models": cforch_models,
        "coordinator_url": coordinator_url,
        "rubric_labels": RUBRIC_LABELS,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /run ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/run")
def run_plans_benchmark(
    models: str = Query(..., description="Comma-separated model IDs (registry keys or cf-orch model names)"),
    prompt_ids: str = Query("", description="Comma-separated prompt IDs to run (empty = all 10)"),
    use_cforch: bool = Query(True, description="Route inference through cf-orch coordinator"),
    api_base: str = Query("", description="Direct API base URL when not using cf-orch"),
    workers: int = Query(1, ge=1, le=8, description="Number of models to benchmark concurrently"),
) -> StreamingResponse:
    """Spawn benchmark_plans.py and stream its output as SSE events.

    Events emitted on the stream (each a ``data: <json>`` frame):

    * ``progress`` -- one per non-empty line of the script's stdout/stderr.
    * ``result``   -- parsed JSON results plus ``run_id`` after a successful run.
    * ``complete`` -- the run finished with exit code 0 and an output file.
    * ``error``    -- script missing, spawn failure, non-zero exit, or no output.

    Results are written by the script to data/plans_bench_results/<run_id>.json.
    If the stream is closed before the script finishes, the subprocess is
    terminated so it does not keep running unattended.

    Raises:
        HTTPException: 409 if a benchmark is already running; 400 if no model
            is given or a model / prompt / api_base value starts with ``-``.
    """
    if _BENCH_RUNNING:
        raise HTTPException(409, "A planning benchmark is already running")

    cfg = _load_config()
    python_bin = cfg["python_bin"]
    coordinator_url = cfg["coordinator_url"]

    model_keys = [m.strip() for m in models.split(",") if m.strip()]
    if not model_keys:
        raise HTTPException(400, "At least one model key is required")
    prompt_keys = [p.strip() for p in prompt_ids.split(",") if p.strip()]
    direct_api_base = api_base.strip()

    # These values come straight from the query string and are appended to the
    # subprocess argv. A leading "-" would let a caller smuggle extra options
    # (for example a different --output path) into benchmark_plans.py.
    for value in (*model_keys, *prompt_keys, direct_api_base):
        if value.startswith("-"):
            raise HTTPException(400, f"Invalid value (must not start with '-'): {value}")

    run_id = datetime.now(tz=timezone.utc).strftime("plans_%Y-%m-%d_%H%M%S")
    output_path = _results_file(run_id)
    _RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    cmd = [python_bin, str(_BENCH_SCRIPT)]
    if len(model_keys) > 1:
        cmd.extend(["--compare"] + model_keys)
    else:
        cmd.extend(["--model", model_keys[0]])

    if use_cforch:
        cmd.extend(["--cforch", "--cforch-url", coordinator_url])
    elif direct_api_base:
        cmd.extend(["--api-base", direct_api_base])

    cmd.extend(["--verbose", "--output", str(output_path)])
    if workers > 1:
        cmd.extend(["--workers", str(workers)])
    if prompt_keys:
        # Skipped when the parameter held only separators, so the script never
        # receives a bare "--prompts" with no values.
        cmd.extend(["--prompts"] + prompt_keys)

    def sse_event(payload: dict) -> str:
        # One Server-Sent Events frame carrying a JSON payload.
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        global _BENCH_RUNNING, _bench_proc

        if not _BENCH_SCRIPT.exists():
            yield sse_event({'type': 'error', 'message': f'benchmark_plans.py not found at {_BENCH_SCRIPT}'})
            return

        _BENCH_RUNNING = True
        try:
            proc = _subprocess.Popen(
                cmd,
                stdout=_subprocess.PIPE,
                stderr=_subprocess.STDOUT,
                text=True,
                bufsize=1,
                cwd=str(_ROOT),
            )
            _bench_proc = proc
            try:
                for line in proc.stdout:
                    line = line.rstrip()
                    if line:
                        yield sse_event({'type': 'progress', 'message': line})
                proc.wait()

                if proc.returncode != 0:
                    yield sse_event({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})
                elif not output_path.exists():
                    # Exit code 0 but nothing written: report that explicitly
                    # instead of the misleading "exited with code 0".
                    yield sse_event({'type': 'error', 'message': f'Benchmark finished but produced no output file at {output_path}'})
                else:
                    try:
                        results = json.loads(output_path.read_text(encoding="utf-8"))
                        yield sse_event({'type': 'result', 'run_id': run_id, 'results': results})
                    except Exception as exc:
                        logger.warning("Failed to read plans benchmark output: %s", exc)
                    yield sse_event({'type': 'complete'})
            finally:
                # Also reached when the generator is closed mid-stream; make
                # sure the child and its pipe do not outlive the response.
                try:
                    if proc.poll() is None:
                        proc.terminate()
                        try:
                            proc.wait(timeout=5)
                        except _subprocess.TimeoutExpired:
                            proc.kill()
                            proc.wait()
                    if proc.stdout is not None:
                        proc.stdout.close()
                except OSError as exc:
                    logger.warning("Failed to clean up plans benchmark process: %s", exc)
                # Only clear the handle if it still refers to this run.
                if _bench_proc is proc:
                    _bench_proc = None
        except Exception as exc:
            yield sse_event({'type': 'error', 'message': str(exc)})
        finally:
            _BENCH_RUNNING = False

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /results ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/results")
def list_results() -> list[dict]:
    """List past planning benchmark runs, newest first.

    Runs are the ``plans_*.json`` files in the results directory, ordered by
    file name in descending order. Each entry carries ``run_id``, ``filename``,
    a display ``date``, the ``models`` (top-level keys of the results file) and
    ``avg_score`` -- the mean ``total_score`` over all entries without an
    ``error``, rounded to three decimals.

    Listing is best-effort: a file that cannot be read or summarised is still
    listed, with no models and a score of 0.0, and the failure is logged.
    """
    if not _RESULTS_DIR.exists():
        return []

    runs: list[dict] = []
    for f in sorted(_RESULTS_DIR.glob("plans_*.json"), reverse=True):
        run_id = f.stem
        try:
            data: dict = json.loads(f.read_text(encoding="utf-8"))
            model_keys = list(data.keys())
            # Average total_score across all models and prompts
            all_scores = [
                r["total_score"]
                for results in data.values()
                for r in results
                if not r.get("error")
            ]
            avg_score = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
        except Exception as exc:
            # Keep the run visible in the list, but leave a trace of why its
            # summary is empty instead of swallowing the failure silently.
            logger.warning("Could not summarise plans benchmark result %s: %s", f, exc)
            model_keys = []
            avg_score = 0.0

        # Parse display date from run_id (plans_2026-04-27_143022)
        try:
            started = datetime.strptime(run_id, "plans_%Y-%m-%d_%H%M%S")
            display_date = started.strftime("%Y-%m-%d %H:%M")
        except ValueError:
            # Not a timestamped run id -- show it verbatim.
            display_date = run_id

        runs.append({
            "run_id": run_id,
            "filename": f.name,
            "date": display_date,
            "models": model_keys,
            "avg_score": avg_score,
        })

    return runs
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/results/latest")
def get_latest_results() -> dict:
    """Return the most recent planning benchmark results dict.

    "Most recent" is the ``plans_*.json`` file whose name sorts last.

    Raises:
        HTTPException: 404 when the results directory is missing or holds no
            matching file; 500 when the newest file cannot be read or parsed.
    """
    candidates = sorted(_RESULTS_DIR.glob("plans_*.json")) if _RESULTS_DIR.exists() else []
    if not candidates:
        raise HTTPException(404, "No benchmark results found")

    newest = candidates[-1]
    try:
        payload = newest.read_text(encoding="utf-8")
        return json.loads(payload)
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/results/{run_id}")
def get_results_by_run_id(run_id: str) -> dict:
    """Return planning benchmark results for a specific run.

    Args:
        run_id: Identifier of the run; must start with ``plans_`` and be a
            bare file stem (no path separators).

    Raises:
        HTTPException: 400 for a malformed ``run_id``, 404 when no results
            file exists for it, 500 when the file cannot be read or parsed.
    """
    # run_id is interpolated into a filesystem path, so refuse anything that
    # could step outside the results directory (or a NUL the OS would reject).
    has_path_chars = any(ch in run_id for ch in ("/", "\\", "\0"))
    if not run_id.startswith("plans_") or has_path_chars:
        raise HTTPException(400, "Invalid run_id — expected plans_YYYY-MM-DD_HHMMSS")

    f = _results_file(run_id)
    if not f.exists():
        raise HTTPException(404, f"Results not found: {run_id}")
    try:
        return json.loads(f.read_text(encoding="utf-8"))
    except Exception as exc:
        raise HTTPException(500, f"Failed to read results: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /cancel ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/cancel")
def cancel_plans_benchmark() -> dict:
    """Terminate the running planning benchmark subprocess.

    Clears the module-level running flag and process handle whether or not
    the terminate call succeeds; a failure to terminate is only logged.

    Raises:
        HTTPException: 404 when no benchmark is marked as running.
    """
    global _BENCH_RUNNING, _bench_proc

    if not _BENCH_RUNNING:
        raise HTTPException(404, "No planning benchmark is currently running")

    proc = _bench_proc
    if proc is not None:
        try:
            proc.terminate()
        except Exception as exc:
            logger.warning("Failed to terminate plans benchmark: %s", exc)

    _BENCH_RUNNING, _bench_proc = False, None
    return {"status": "cancelled"}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue