From 32e3b2a0ddade643f85bbdcbbc0484fb8c78b0ca Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Thu, 7 May 2026 09:05:34 -0700 Subject: [PATCH] feat: add embed-bench run endpoint with SSE streaming --- app/eval/embed_bench.py | 107 ++++++++++++++++++++++++++++++++++++++ tests/test_embed_bench.py | 86 ++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+) diff --git a/app/eval/embed_bench.py b/app/eval/embed_bench.py index a46e221..f56abd4 100644 --- a/app/eval/embed_bench.py +++ b/app/eval/embed_bench.py @@ -105,3 +105,110 @@ def get_models() -> dict: except httpx.RequestError as exc: logger.warning("Failed to reach Ollama for model list: %s", exc) return {"models": models, "ollama_url": ollama} + + +# ── POST /run ───────────────────────────────────────────────────────────────── + +class RunRequest(BaseModel): + corpus: list[str] + queries: list[str] + models: list[str] + top_k: int = 5 + ollama_url: str = "" + + @field_validator("corpus") + @classmethod + def corpus_nonempty(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("corpus must not be empty") + return v + + @field_validator("queries") + @classmethod + def queries_nonempty(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("queries must not be empty") + return v + + @field_validator("models") + @classmethod + def models_nonempty(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("models must contain at least one model name") + return v + + +def _embed_texts(ollama: str, model: str, texts: list[str]) -> list[list[float]]: + """Batch-embed texts via Ollama /v1/embeddings. Returns one vector per text.""" + resp = httpx.post( + f"{ollama}/v1/embeddings", + json={"model": model, "input": texts}, + timeout=120.0, + ) + resp.raise_for_status() + data = resp.json().get("data", []) + return [item["embedding"] for item in data] + + +def _sse(event: dict) -> str: + return f"data: {json.dumps(event)}\n\n" + + +@router.post("/run") +def run_embed_bench(req: RunRequest) -> StreamingResponse: + """Embed corpus + queries with each model; stream SSE results.""" + global _RUN_ACTIVE + + if _RUN_ACTIVE: + raise HTTPException(409, "An embedding benchmark run is already active") + + ollama = req.ollama_url or _ollama_url() + + def _generate(): + global _RUN_ACTIVE + _RUN_ACTIVE = True + try: + for model_idx, model in enumerate(req.models, start=1): + yield _sse({ + "type": "progress", + "msg": f"Indexing corpus with {model} ({model_idx}/{len(req.models)})...", + }) + try: + corpus_vecs = _embed_texts(ollama, model, req.corpus) + except Exception as exc: + yield _sse({"type": "error", "msg": f"Ollama error for {model}: {exc}"}) + continue + + yield _sse({ + "type": "progress", + "msg": f"Running queries with {model}...", + }) + + for q_idx, query in enumerate(req.queries): + try: + q_vecs = _embed_texts(ollama, model, [query]) + except Exception as exc: + yield _sse({"type": "error", "msg": f"Query embed error ({model}): {exc}"}) + continue + q_vec = q_vecs[0] + scored = sorted( + [ + {"chunk_idx": i, "text": chunk, "score": round(_cosine(q_vec, cv), 4)} + for i, (chunk, cv) in enumerate(zip(req.corpus, corpus_vecs)) + ], + key=lambda h: h["score"], + reverse=True, + )[: req.top_k] + yield _sse({ + "type": "result", + "query_idx": q_idx, + "query": query, + "model": model, + "hits": scored, + }) + + yield _sse({"type": "done"}) + finally: + _RUN_ACTIVE = False + + return StreamingResponse(_generate(), media_type="text/event-stream") diff --git a/tests/test_embed_bench.py b/tests/test_embed_bench.py index 4488e06..800fd61 100644 --- a/tests/test_embed_bench.py +++ b/tests/test_embed_bench.py @@ -1,6 +1,7 @@ """Tests for app/eval/embed_bench.py.""" from __future__ import annotations +import json from pathlib import Path from unittest.mock import MagicMock, patch @@ -89,3 +90,88 @@ def test_models_returns_empty_on_ollama_error(client, tmp_path): r = client.get("/api/embed-bench/models") assert r.status_code == 200 assert r.json()["models"] == [] + + +# ── run endpoint ─────────────────────────────────────────────────────────────── + +def test_run_empty_corpus_returns_422(client): + r = client.post("/api/embed-bench/run", json={ + "corpus": [], "queries": ["test"], "models": ["nomic-embed-text"], "top_k": 3 + }) + assert r.status_code == 422 + + +def test_run_empty_queries_returns_422(client): + r = client.post("/api/embed-bench/run", json={ + "corpus": ["chunk 1"], "queries": [], "models": ["nomic-embed-text"], "top_k": 3 + }) + assert r.status_code == 422 + + +def test_run_empty_models_returns_422(client): + r = client.post("/api/embed-bench/run", json={ + "corpus": ["chunk 1"], "queries": ["test"], "models": [], "top_k": 3 + }) + assert r.status_code == 422 + + +def _fake_embed_response(texts: list[str]) -> MagicMock: + """Build a mock httpx.post response returning unit vectors for each text.""" + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "data": [{"embedding": [1.0, 0.0, 0.0] if i % 2 == 0 else [0.0, 1.0, 0.0]} + for i, _ in enumerate(texts)] + } + return resp + + +def _collect_sse(raw: bytes) -> list[dict]: + """Parse SSE stream bytes into a list of decoded event dicts.""" + events = [] + for line in raw.decode().splitlines(): + if line.startswith("data: "): + events.append(json.loads(line[6:])) + return events + + +def test_run_single_model_returns_result_and_done(client, tmp_path): + import yaml + (tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}})) + + with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk 1", "chunk 2"])): + r = client.post("/api/embed-bench/run", json={ + "corpus": ["chunk 1", "chunk 2"], + "queries": ["what is chunk one?"], + "models": ["nomic-embed-text"], + "top_k": 2, + }) + + assert r.status_code == 200 + events = _collect_sse(r.content) + types = [e["type"] for e in events] + assert "result" in types + assert types[-1] == "done" + result_events = [e for e in events if e["type"] == "result"] + assert result_events[0]["model"] == "nomic-embed-text" + assert result_events[0]["query_idx"] == 0 + assert len(result_events[0]["hits"]) <= 2 + + +def test_run_two_models_returns_two_result_events_per_query(client, tmp_path): + import yaml + (tmp_path / "label_tool.yaml").write_text(yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}})) + + with patch("app.eval.embed_bench.httpx.post", return_value=_fake_embed_response(["chunk A", "chunk B"])): + r = client.post("/api/embed-bench/run", json={ + "corpus": ["chunk A", "chunk B"], + "queries": ["find it"], + "models": ["nomic-embed-text", "mxbai-embed-large"], + "top_k": 2, + }) + + events = _collect_sse(r.content) + result_events = [e for e in events if e["type"] == "result"] + models_seen = {e["model"] for e in result_events} + assert "nomic-embed-text" in models_seen + assert "mxbai-embed-large" in models_seen