avocet/app/eval/embed_bench.py

214 lines
7.2 KiB
Python

"""Avocet — embedding model comparison harness.
Exposes FastAPI routes under /api/embed-bench (mounted via app/eval/cforch.py).
All computation is local: no LLM inference, Ollama only. MIT tier throughout.
"""
from __future__ import annotations
import csv
import io
import json
import logging
import math
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import httpx
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Directory two levels above the one containing this file; base for the
# default config and data paths used in this module.
_ROOT = Path(__file__).parent.parent.parent
# When set, config and ratings files are looked up in this directory instead.
_CONFIG_DIR: Path | None = None # override via set_config_dir() in tests
# True while a /run stream is being generated; /run answers 409 while it is set.
_RUN_ACTIVE: bool = False
# Default ratings file location, used when _CONFIG_DIR is not set.
_RATINGS_FILE = _ROOT / "data" / "embed_bench_ratings.jsonl"
# Router that the route decorators in this module register handlers on.
router = APIRouter()
# ── Testability seam ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
    """Redirect config and ratings lookups to *path*; pass None to restore defaults."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
# ── Internal helpers ──────────────────────────────────────────────────────────
def _config_file() -> Path:
    """Return the path of label_tool.yaml, honouring the test override directory."""
    # Without an override the file lives under <root>/config.
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
def _load_config() -> dict[str, Any]:
    """Load the YAML config file as a dict.

    Returns an empty dict when the file is missing, unreadable, not valid
    YAML, empty, or when its top-level document is not a mapping. Every
    failure except "file missing" is logged as a warning; nothing is raised.
    """
    f = _config_file()
    # Read first and handle the failure, rather than exists()-then-read: the
    # file can vanish (or be a directory / undecodable) between the two steps.
    try:
        text = f.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to read embed_bench config %s: %s", f, exc)
        return {}
    try:
        loaded = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse embed_bench config %s: %s", f, exc)
        return {}
    if loaded is None:
        # Empty file or a document containing only comments.
        return {}
    if not isinstance(loaded, dict):
        # A list or scalar at top level would break every cfg.get(...) caller.
        logger.warning(
            "Ignoring embed_bench config %s: top-level YAML is %s, expected a mapping",
            f,
            type(loaded).__name__,
        )
        return {}
    return loaded
def _ollama_url() -> str:
    """Return the Ollama base URL from config.

    Lookup order: ``embed_bench.ollama_url``, then ``cforch.ollama_url``,
    then the built-in default ``http://localhost:11434``. A section that is
    missing or not a mapping, and a URL that is missing, null, empty or not a
    string, are all treated as "not configured", so the result is always a
    non-empty string.
    """
    cfg = _load_config()
    if not isinstance(cfg, dict):
        cfg = {}
    for section in ("embed_bench", "cforch"):
        section_cfg = cfg.get(section)
        if not isinstance(section_cfg, dict):
            continue
        url = section_cfg.get("ollama_url")
        # dict.get(key, default) only falls back when the key is absent, so a
        # present-but-null/empty value must be rejected explicitly here.
        if isinstance(url, str) and url:
            return url
    return "http://localhost:11434"
def _ratings_path() -> Path:
    """Return the ratings JSONL path, honouring the test override directory."""
    override_dir = _CONFIG_DIR
    if override_dir is None:
        return _RATINGS_FILE
    return override_dir / "embed_bench_ratings.jsonl"
def _cosine(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of vectors *a* and *b*.

    Raises ValueError when the vectors differ in length. Returns 0.0 when
    either vector has zero magnitude (including two empty vectors).
    """
    len_a, len_b = len(a), len(b)
    if len_a != len_b:
        raise ValueError(
            f"Embedding dimension mismatch: {len_a} vs {len_b}"
        )
    norm_a = math.sqrt(sum(component * component for component in a))
    norm_b = math.sqrt(sum(component * component for component in b))
    inner = sum(left * right for left, right in zip(a, b))
    # Similarity is undefined for a zero vector; report "no similarity".
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return inner / (norm_a * norm_b)
# ── GET /models ───────────────────────────────────────────────────────────────
@router.get("/models")
def get_models() -> dict:
    """Return the models listed by the configured Ollama instance.

    Queries ``<ollama>/api/tags`` and returns ``{"models": [...], "ollama_url": ...}``
    where each model entry carries ``name`` and ``size``. No filtering by
    model capability is applied. The lookup is best-effort: on any transport,
    HTTP-status or payload problem a warning is logged and ``models`` is empty.
    """
    ollama = _ollama_url()
    models: list[dict] = []
    try:
        resp = httpx.get(f"{ollama}/api/tags", timeout=5.0)
        resp.raise_for_status()
        payload = resp.json()
    except httpx.HTTPStatusError as exc:
        logger.warning("Ollama /api/tags returned HTTP %s: %s", exc.response.status_code, exc)
    except (httpx.RequestError, httpx.InvalidURL) as exc:
        logger.warning("Failed to reach Ollama for model list: %s", exc)
    except ValueError as exc:
        # Response.json() raises a ValueError subclass on a non-JSON body.
        logger.warning("Ollama /api/tags returned invalid JSON: %s", exc)
    else:
        # Tolerate a non-object payload and a null/missing "models" key.
        entries = payload.get("models") if isinstance(payload, dict) else None
        models = [
            {"name": entry.get("name", ""), "size": entry.get("size", 0)}
            for entry in entries or []
            if isinstance(entry, dict)
        ]
    return {"models": models, "ollama_url": ollama}
# ── POST /run ─────────────────────────────────────────────────────────────────
class RunRequest(BaseModel):
    """Request body for POST /run.

    ``corpus``, ``queries`` and ``models`` must each contain at least one
    entry, and ``top_k`` must not be negative; violations are reported as
    validation errors.
    """

    # Text chunks that are embedded and ranked against each query.
    corpus: list[str]
    # Query strings; each is embedded and scored against the whole corpus.
    queries: list[str]
    # Names of the models to compare.
    models: list[str]
    # Maximum number of hits returned per query and model.
    top_k: int = 5
    # Optional per-request Ollama base URL; empty means "use the configured one".
    ollama_url: str = ""

    @field_validator("corpus")
    @classmethod
    def corpus_nonempty(cls, v: list[str]) -> list[str]:
        """Reject an empty corpus."""
        if not v:
            raise ValueError("corpus must not be empty")
        return v

    @field_validator("queries")
    @classmethod
    def queries_nonempty(cls, v: list[str]) -> list[str]:
        """Reject an empty query list."""
        if not v:
            raise ValueError("queries must not be empty")
        return v

    @field_validator("models")
    @classmethod
    def models_nonempty(cls, v: list[str]) -> list[str]:
        """Reject an empty model list."""
        if not v:
            raise ValueError("models must contain at least one model name")
        return v

    @field_validator("top_k")
    @classmethod
    def top_k_nonnegative(cls, v: int) -> int:
        """Reject a negative top_k.

        The value is used as a slice bound, so a negative number would drop
        hits from the end of the ranking instead of limiting its length.
        """
        if v < 0:
            raise ValueError("top_k must not be negative")
        return v
def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]:
    """Batch-embed *texts* via ``<ollama>/v1/embeddings``.

    Returns one vector per input text, taken from the response's ``data``
    list in the order received. An empty *texts* list returns ``[]`` without
    contacting the server.

    Raises:
        httpx.HTTPError: on transport failures or a non-success HTTP status.
        ValueError: when the body is not JSON, or does not contain exactly
            one ``embedding`` entry per input text.
    """
    if not texts:
        return []
    resp = httpx.post(
        f"{ollama}/v1/embeddings",
        json={"model": model, "input": texts},
        timeout=120.0,
    )
    resp.raise_for_status()
    payload = resp.json()
    data = payload.get("data") if isinstance(payload, dict) else None
    # Callers pair vectors with texts positionally, so a short or missing
    # list must fail loudly instead of being silently truncated later.
    if not isinstance(data, list) or len(data) != len(texts):
        got = len(data) if isinstance(data, list) else 0
        raise ValueError(
            f"Expected {len(texts)} embeddings from {model}, got {got}"
        )
    try:
        return [item["embedding"] for item in data]
    except (KeyError, TypeError) as exc:
        raise ValueError(f"Malformed embedding response from {model}") from exc
def _sse(event: dict) -> str:
    """Serialise *event* as one server-sent-event ``data:`` frame."""
    payload = json.dumps(event)
    # A blank line terminates the frame.
    return "data: " + payload + "\n\n"
def _top_hits(
    q_vec: list[float],
    corpus: list[str],
    corpus_vecs: list[list[float]],
    top_k: int,
) -> list[dict[str, Any]]:
    """Rank corpus chunks by cosine similarity to *q_vec*; return the best *top_k*.

    Raises ValueError (from _cosine) when a corpus vector's dimension differs
    from the query vector's.
    """
    hits = [
        {"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)}
        for i, (chunk, cv) in enumerate(zip(corpus, corpus_vecs))
    ]
    # list.sort is stable, so equally-scored chunks keep their corpus order.
    hits.sort(key=lambda h: h["score"], reverse=True)
    # A negative bound would slice from the tail; treat it as "no hits".
    return hits[: max(top_k, 0)]


@router.post("/run")
def run_embed_bench(req: RunRequest) -> StreamingResponse:
    """Embed corpus + queries with each model; stream SSE results.

    Emits ``progress`` events per model, one ``result`` event per
    (model, query) pair with the top-scoring chunks, ``error`` events for
    anything that fails, and a final ``done`` event. A failure for one model
    or query is reported and skipped; it does not end the stream.

    Raises:
        HTTPException: 409 when another run is already streaming.
    """
    # NOTE(review): the flag is checked here but only set once the generator
    # starts running, so two requests arriving together can both pass this
    # check. Confirm whether that window matters for this deployment.
    if _RUN_ACTIVE:
        raise HTTPException(409, "An embedding benchmark run is already active")
    # NOTE(review): req.ollama_url comes from the request body and is used as
    # the target of server-side HTTP calls; confirm callers are trusted.
    ollama = req.ollama_url or _ollama_url()

    def _generate():
        global _RUN_ACTIVE
        _RUN_ACTIVE = True
        try:
            total_models = len(req.models)
            expected = len(req.corpus)
            for model_idx, model in enumerate(req.models, start=1):
                yield _sse({
                    "type": "progress",
                    "msg": f"Indexing corpus with {model} ({model_idx}/{total_models})...",
                })
                try:
                    corpus_vecs = _embed_texts(ollama, model, req.corpus)
                except Exception as exc:
                    yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"})
                    continue
                # zip() below would silently drop chunks on a short response.
                if len(corpus_vecs) != expected:
                    yield _sse({
                        "type": "error",
                        "msg": (
                            f"Ollama error for {model}: expected {expected} "
                            f"corpus embeddings, got {len(corpus_vecs)}"
                        ),
                    })
                    continue
                yield _sse({
                    "type": "progress",
                    "msg": f"Running queries with {model}...",
                })
                for q_idx, query in enumerate(req.queries):
                    try:
                        q_vecs = _embed_texts(ollama, model, [query])
                    except Exception as exc:
                        yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"})
                        continue
                    if not q_vecs:
                        yield _sse({
                            "type": "error",
                            "msg": f"Query embed error ({model}): no embedding returned",
                        })
                        continue
                    try:
                        hits = _top_hits(q_vecs[0], req.corpus, corpus_vecs, req.top_k)
                    except ValueError as exc:
                        # Dimension mismatch between query and corpus vectors.
                        yield _sse({"type": "error", "msg": f"Scoring error ({model}): {exc}"})
                        continue
                    yield _sse({
                        "type": "result",
                        "query_idx": q_idx,
                        "query": query,
                        "model": model,
                        "hits": hits,
                    })
            yield _sse({"type": "done"})
        finally:
            # Always release the run flag, including on client disconnect.
            _RUN_ACTIVE = False

    return StreamingResponse(_generate(), media_type="text/event-stream")