From 1ad7ba322af3c033cdd96c32c681b05bf6ee9da0 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Mon, 11 May 2026 08:07:17 -0700
Subject: [PATCH] feat: add embed-bench rate and export endpoints

---
 app/eval/embed_bench.py   | 101 +++++++++++++++++++++++++++++++++++++++
 tests/test_embed_bench.py |  69 +++++++++++++++++++++++++++
 2 files changed, 170 insertions(+)

diff --git a/app/eval/embed_bench.py b/app/eval/embed_bench.py
index f56abd4..bef333e 100644
--- a/app/eval/embed_bench.py
+++ b/app/eval/embed_bench.py
@@ -212,3 +212,104 @@ def run_embed_bench(req: RunRequest) -> StreamingResponse:
             _RUN_ACTIVE = False
 
     return StreamingResponse(_generate(), media_type="text/event-stream")
+
+
+# ── POST /rate ────────────────────────────────────────────────────────────────
+
+_VALID_RATINGS = {"relevant", "not_relevant"}
+
+
+class RatingRequest(BaseModel):
+    """Request body for POST /rate: one relevance rating for one chunk."""
+
+    query: str
+    model: str
+    chunk_text: str
+    chunk_idx: int
+    rating: str
+
+    @field_validator("rating")
+    @classmethod
+    def rating_valid(cls, v: str) -> str:
+        """Reject any rating that is not in ``_VALID_RATINGS``."""
+        if v not in _VALID_RATINGS:
+            # Sort for a stable message; set iteration order is not fixed.
+            raise ValueError(f"rating must be one of {sorted(_VALID_RATINGS)}")
+        return v
+
+
+@router.post("/rate")
+def rate_result(req: RatingRequest) -> dict:
+    """Append one rating to the JSONL ratings file.
+
+    Creates the parent directory if needed, then appends a single JSON
+    object (the request fields plus a UTC ISO-8601 timestamp) on its own line.
+    """
+    entry = {
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "query": req.query,
+        "model": req.model,
+        "chunk_idx": req.chunk_idx,
+        "chunk_text": req.chunk_text,
+        "rating": req.rating,
+    }
+    path = _ratings_path()
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("a", encoding="utf-8") as fh:
+        fh.write(json.dumps(entry) + "\n")
+    return {"ok": True}
+
+
+# ── GET /export ───────────────────────────────────────────────────────────────
+
+_CSV_FIELDS = ["timestamp", "query", "model", "chunk_idx", "chunk_text", "rating"]
+
+
+@router.get("/export")
+def export_ratings(format: str = "csv") -> Any:
+    """Download ratings as CSV or JSON.
+
+    ``format="json"`` returns a JSON array; any other value returns CSV with
+    the ``_CSV_FIELDS`` columns. A missing ratings file yields an empty
+    export. Blank lines, lines that are not valid JSON, and lines that do not
+    decode to a JSON object are skipped so one bad line cannot break the
+    download.
+    """
+    rows: list[dict] = []
+    try:
+        with _ratings_path().open(encoding="utf-8") as fh:
+            for raw in fh:
+                raw = raw.strip()
+                if not raw:
+                    continue
+                try:
+                    row = json.loads(raw)
+                except json.JSONDecodeError:
+                    continue  # best-effort: skip corrupt lines
+                # csv.DictWriter needs mappings; a bare list or number
+                # would make writerows() fail.
+                if isinstance(row, dict):
+                    rows.append(row)
+    except FileNotFoundError:
+        pass  # nothing rated yet: export is empty
+
+    date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
+
+    if format == "json":
+        content = json.dumps(rows, ensure_ascii=False, indent=2)
+        return StreamingResponse(
+            iter([content]),
+            media_type="application/json",
+            headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
+        )
+
+    # Default: CSV
+    buf = io.StringIO()
+    writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
+    writer.writeheader()
+    writer.writerows(rows)
+    return StreamingResponse(
+        iter([buf.getvalue()]),
+        media_type="text/csv",
+        headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
+    )
diff --git a/tests/test_embed_bench.py b/tests/test_embed_bench.py
index 800fd61..ea23564 100644
--- a/tests/test_embed_bench.py
+++ b/tests/test_embed_bench.py
@@ -175,3 +175,72 @@ def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
     models_seen = {e["model"] for e in result_events}
     assert "nomic-embed-text" in models_seen
     assert "mxbai-embed-large" in models_seen
+
+
+# ── rate + export ──────────────────────────────────────────────────────────────
+
+def test_rate_appends_jsonl_line(client, tmp_path):
+    r = client.post("/api/embed-bench/rate", json={
+        "query": "test query",
+        "model": "nomic-embed-text",
+        "chunk_text": "some text",
+        "chunk_idx": 2,
+        "rating": "relevant",
+    })
+    assert r.status_code == 200
+    assert r.json() == {"ok": True}
+    ratings_file = tmp_path / "embed_bench_ratings.jsonl"
+    assert ratings_file.exists()
+    line = json.loads(ratings_file.read_text(encoding="utf-8").strip())
+    assert line["query"] == "test query"
+    assert line["rating"] == "relevant"
+    assert line["chunk_idx"] == 2
+    assert "timestamp" in line
+
+
+def test_export_csv_two_rows(client, tmp_path):
+    for i in range(2):
+        client.post("/api/embed-bench/rate", json={
+            "query": f"q{i}", "model": "nomic-embed-text",
+            "chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "relevant",
+        })
+    r = client.get("/api/embed-bench/export?format=csv")
+    assert r.status_code == 200
+    assert "text/csv" in r.headers["content-type"]
+    lines = r.text.strip().splitlines()
+    assert len(lines) == 3  # header + 2 rows
+    assert "query" in lines[0]
+
+
+def test_export_json_two_entries(client, tmp_path):
+    for i in range(2):
+        client.post("/api/embed-bench/rate", json={
+            "query": f"q{i}", "model": "nomic-embed-text",
+            "chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "not_relevant",
+        })
+    r = client.get("/api/embed-bench/export?format=json")
+    assert r.status_code == 200
+    data = r.json()
+    assert isinstance(data, list)
+    assert len(data) == 2
+    assert data[0]["rating"] == "not_relevant"
+
+
+def test_export_empty_returns_csv_header_only(client):
+    r = client.get("/api/embed-bench/export?format=csv")
+    assert r.status_code == 200
+    lines = r.text.strip().splitlines()
+    assert len(lines) == 1  # header only
+    assert "query" in lines[0]
+
+
+def test_export_skips_malformed_and_non_object_lines(client, tmp_path):
+    ratings_file = tmp_path / "embed_bench_ratings.jsonl"
+    good = {"query": "q", "model": "m", "chunk_idx": 0, "chunk_text": "c", "rating": "relevant"}
+    ratings_file.write_text(
+        json.dumps(good) + "\nnot json\n[1, 2]\n\n", encoding="utf-8"
+    )
+    r = client.get("/api/embed-bench/export?format=csv")
+    assert r.status_code == 200
+    lines = r.text.strip().splitlines()
+    assert len(lines) == 2  # header + the single well-formed row