feat: add embed-bench run endpoint with SSE streaming

This commit is contained in:
pyr0ball 2026-05-07 09:05:34 -07:00
parent 12117ad0c6
commit 32e3b2a0dd
2 changed files with 193 additions and 0 deletions

View file

@ -105,3 +105,110 @@ def get_models() -> dict:
except httpx.RequestError as exc:
logger.warning("Failed to reach Ollama for model list: %s", exc)
return {"models": models, "ollama_url": ollama}
# ── POST /run ─────────────────────────────────────────────────────────────────
class RunRequest(BaseModel):
corpus: list[str]
queries: list[str]
models: list[str]
top_k: int = 5
ollama_url: str = ""
@field_validator("corpus")
@classmethod
def corpus_nonempty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("corpus must not be empty")
return v
@field_validator("queries")
@classmethod
def queries_nonempty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("queries must not be empty")
return v
@field_validator("models")
@classmethod
def models_nonempty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("models must contain at least one model name")
return v
def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]:
"""Batch-embed texts via Ollama /v1/embeddings. Returns one vector per text."""
resp = httpx.post(
f"{ollama}/v1/embeddings",
json={"model": model, "input": texts},
timeout=120.0,
)
resp.raise_for_status()
data = resp.json().get("data", [])
return [item["embedding"] for item in data]
def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n"
@router.post("/run")
def run_embed_bench(req: RunRequest) -> StreamingResponse:
"""Embed corpus + queries with each model; stream SSE results."""
global _RUN_ACTIVE
if _RUN_ACTIVE:
raise HTTPException(409, "An embedding benchmark run is already active")
ollama = req.ollama_url or _ollama_url()
def _generate():
global _RUN_ACTIVE
_RUN_ACTIVE = True
try:
for model_idx, model in enumerate(req.models, start=1):
yield _sse({
"type": "progress",
"msg": f"Indexing corpus with {model} ({model_idx}/{len(req.models)})...",
})
try:
corpus_vecs = _embed_texts(ollama, model, req.corpus)
except Exception as exc:
yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"})
continue
yield _sse({
"type": "progress",
"msg": f"Running queries with {model}...",
})
for q_idx, query in enumerate(req.queries):
try:
q_vecs = _embed_texts(ollama, model, [query])
except Exception as exc:
yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"})
continue
q_vec = q_vecs[0]
scored = sorted(
[
{"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)}
for i, (chunk, cv) in enumerate(zip(req.corpus, corpus_vecs))
],
key=lambda h: h["score"],
reverse=True,
)[: req.top_k]
yield _sse({
"type": "result",
"query_idx": q_idx,
"query": query,
"model": model,
"hits": scored,
})
yield _sse({"type": "done"})
finally:
_RUN_ACTIVE = False
return StreamingResponse(_generate(), media_type="text/event-stream")

View file

@ -1,6 +1,7 @@
"""Tests for app/eval/embed_bench.py."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
@ -89,3 +90,88 @@ def test_models_returns_empty_on_ollama_error(client, tmp_path):
r = client.get("/api/embed-bench/models")
assert r.status_code == 200
assert r.json()["models"] == []
# ── run endpoint ───────────────────────────────────────────────────────────────
def test_run_empty_corpus_returns_422(client):
r = client.post("/api/embed-bench/run", json={
"corpus": [], "queries": ["test"], "models": ["nomic-embed-text"], "top_k": 3
})
assert r.status_code == 422
def test_run_empty_queries_returns_422(client):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk 1"], "queries": [], "models": ["nomic-embed-text"], "top_k": 3
})
assert r.status_code == 422
def test_run_empty_models_returns_422(client):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk 1"], "queries": ["test"], "models": [], "top_k": 3
})
assert r.status_code == 422
def _fake_embed_response(texts: list[str]) -> MagicMock:
"""Build a mock httpx.post response returning unit vectors for each text."""
resp = MagicMock()
resp.raise_for_status = MagicMock()
resp.json.return_value = {
"data": [{"embedding": [1.0, 0.0, 0.0] if i % 2 == 0 else [0.0, 1.0, 0.0]}
for i, _ in enumerate(texts)]
}
return resp
def _collect_sse(raw: bytes) -> list[dict]:
"""Parse SSE stream bytes into a list of decoded event dicts."""
events = []
for line in raw.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_run_single_model_returns_result_and_done(client, tmp_path):
import yaml
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk 1", "chunk 2"])):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk 1", "chunk 2"],
"queries": ["what is chunk one?"],
"models": ["nomic-embed-text"],
"top_k": 2,
})
assert r.status_code == 200
events = _collect_sse(r.content)
types = [e["type"] for e in events]
assert "result" in types
assert types[-1] == "done"
result_events = [e for e in events if e["type"] == "result"]
assert result_events[0]["model"] == "nomic-embed-text"
assert result_events[0]["query_idx"] == 0
assert len(result_events[0]["hits"]) <= 2
def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
import yaml
(tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}}))
with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk A", "chunk B"])):
r = client.post("/api/embed-bench/run", json={
"corpus": ["chunk A", "chunk B"],
"queries": ["find it"],
"models": ["nomic-embed-text", "mxbai-embed-large"],
"top_k": 2,
})
events = _collect_sse(r.content)
result_events = [e for e in events if e["type"] == "result"]
models_seen = {e["model"] for e in result_events}
assert "nomic-embed-text" in models_seen
assert "mxbai-embed-large" in models_seen