feat: add embed-bench rate and export endpoints

This commit is contained in:
pyr0ball 2026-05-11 08:07:17 -07:00
parent 32e3b2a0dd
commit 1ad7ba322a
2 changed files with 136 additions and 0 deletions

View file

@ -212,3 +212,82 @@ def run_embed_bench(req: RunRequest) -> StreamingResponse:
_RUN_ACTIVE = False _RUN_ACTIVE = False
return StreamingResponse(_generate(), media_type="text/event-stream") return StreamingResponse(_generate(), media_type="text/event-stream")
# ── POST /rate ────────────────────────────────────────────────────────────────
# Closed set of rating labels accepted by POST /rate; a set gives O(1)
# membership checks in the request validator.
_VALID_RATINGS = {
    "relevant",
    "not_relevant",
}
class RatingRequest(BaseModel):
    """Request body for ``POST /rate``: one rating for one retrieved chunk.

    ``rating`` is restricted to the values in ``_VALID_RATINGS``; the other
    fields only receive pydantic's type validation.
    """

    query: str
    model: str
    chunk_text: str
    chunk_idx: int
    rating: str  # checked against _VALID_RATINGS by rating_valid below

    @field_validator("rating")
    @classmethod
    def rating_valid(cls, v: str) -> str:
        """Return *v* unchanged when it is an allowed rating label.

        Raises:
            ValueError: if *v* is not a member of ``_VALID_RATINGS``.
        """
        if v not in _VALID_RATINGS:
            # Sort before formatting: iteration order of a set of strings
            # varies between processes (hash randomization), which would make
            # the client-visible message differ from run to run.
            raise ValueError(f"rating must be one of {sorted(_VALID_RATINGS)}")
        return v
@router.post("/rate")
def rate_result(req: RatingRequest) -> dict:
    """Persist a single rating as one line of the JSONL ratings file.

    The record is stamped with the current UTC time (ISO-8601) and appended
    to the file returned by ``_ratings_path()``; the parent directory is
    created on demand.

    Returns:
        ``{"ok": True}`` once the line has been written.
    """
    stamped_at = datetime.now(timezone.utc).isoformat()
    # Keyword order fixes the key order of the serialized JSON object.
    record = dict(
        timestamp=stamped_at,
        query=req.query,
        model=req.model,
        chunk_idx=req.chunk_idx,
        chunk_text=req.chunk_text,
        rating=req.rating,
    )

    target = _ratings_path()
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(record)}\n")
    return {"ok": True}
# ── GET /export ───────────────────────────────────────────────────────────────
# Column order of the CSV export; matches the keys written by POST /rate.
_CSV_FIELDS = [
    "timestamp",
    "query",
    "model",
    "chunk_idx",
    "chunk_text",
    "rating",
]
@router.get("/export")
def export_ratings(format: str = "csv") -> Any:
    """Download every stored rating as a CSV or JSON attachment.

    Args:
        format: ``"json"`` selects a pretty-printed JSON array; any other
            value (including the default ``"csv"``) produces CSV with the
            columns listed in ``_CSV_FIELDS``. The name shadows the builtin
            but is the public query-parameter name, so it is kept.

    Returns:
        A ``StreamingResponse`` carrying the export, with a
        ``Content-Disposition`` filename containing today's UTC date. A
        missing ratings file yields an empty export (header-only CSV or
        ``[]``) rather than an error.
    """
    # EAFP: read directly instead of exists()-then-read, which could race
    # with the file being removed between the two calls.
    try:
        text = _ratings_path().read_text(encoding="utf-8")
    except FileNotFoundError:
        text = ""

    rows: list[dict] = []
    # JSONL records are delimited by "\n" only. str.splitlines() would also
    # break on U+2028, U+0085 and similar characters, which may legally
    # appear unescaped inside a JSON string and would split one record into
    # two unparseable halves.
    for raw in text.split("\n"):
        raw = raw.strip()
        if not raw:
            continue
        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            # Best-effort export: a corrupt line is skipped on purpose so
            # the rest of the ratings stay downloadable.
            continue
        # A line can be valid JSON without being an object (e.g. ``[]`` or
        # ``42``); csv.DictWriter would raise AttributeError on such a row.
        if isinstance(record, dict):
            rows.append(record)

    date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    if format == "json":
        content = json.dumps(rows, ensure_ascii=False, indent=2)
        return StreamingResponse(
            iter([content]),
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
        )

    # Default: CSV. Keys outside _CSV_FIELDS are dropped (extrasaction) and
    # missing keys are written as empty cells.
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
    writer.writeheader()
    writer.writerows(rows)
    return StreamingResponse(
        iter([buf.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
    )

View file

@ -175,3 +175,60 @@ def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
models_seen = {e["model"] for e in result_events} models_seen = {e["model"] for e in result_events}
assert "nomic-embed-text" in models_seen assert "nomic-embed-text" in models_seen
assert "mxbai-embed-large" in models_seen assert "mxbai-embed-large" in models_seen
# ── rate + export ──────────────────────────────────────────────────────────────
def test_rate_appends_jsonl_line(client, tmp_path):
    """POST /rate answers ``{"ok": True}`` and writes one JSONL record."""
    payload = {
        "query": "test query",
        "model": "nomic-embed-text",
        "chunk_text": "some text",
        "chunk_idx": 2,
        "rating": "relevant",
    }
    response = client.post("/api/embed-bench/rate", json=payload)
    assert response.status_code == 200
    assert response.json() == {"ok": True}

    ratings_file = tmp_path / "embed_bench_ratings.jsonl"
    assert ratings_file.exists()

    # Exactly one record is expected, so the whole file parses as one object.
    record = json.loads(ratings_file.read_text().strip())
    assert record["query"] == "test query"
    assert record["rating"] == "relevant"
    assert record["chunk_idx"] == 2
    assert "timestamp" in record
def test_export_csv_two_rows(client, tmp_path):
    """Two stored ratings export as a CSV header plus two data rows."""
    # NOTE(review): tmp_path is not used directly here; presumably the client
    # fixture stores ratings under it — confirm before removing.
    for idx in range(2):
        payload = {
            "query": f"q{idx}",
            "model": "nomic-embed-text",
            "chunk_text": f"chunk {idx}",
            "chunk_idx": idx,
            "rating": "relevant",
        }
        client.post("/api/embed-bench/rate", json=payload)

    response = client.get("/api/embed-bench/export?format=csv")
    assert response.status_code == 200
    assert "text/csv" in response.headers["content-type"]

    csv_lines = response.text.strip().splitlines()
    assert len(csv_lines) == 3  # header + 2 rows
    assert "query" in csv_lines[0]
def test_export_json_two_entries(client, tmp_path):
    """Two stored ratings export as a JSON list of two objects."""
    # NOTE(review): tmp_path is not used directly here; presumably the client
    # fixture stores ratings under it — confirm before removing.
    for idx in range(2):
        payload = {
            "query": f"q{idx}",
            "model": "nomic-embed-text",
            "chunk_text": f"chunk {idx}",
            "chunk_idx": idx,
            "rating": "not_relevant",
        }
        client.post("/api/embed-bench/rate", json=payload)

    response = client.get("/api/embed-bench/export?format=json")
    assert response.status_code == 200

    exported = response.json()
    assert isinstance(exported, list)
    assert len(exported) == 2
    assert exported[0]["rating"] == "not_relevant"
def test_export_empty_returns_csv_header_only(client):
    """With no ratings stored, the CSV export is just the header line."""
    response = client.get("/api/embed-bench/export?format=csv")
    assert response.status_code == 200

    csv_lines = response.text.strip().splitlines()
    assert len(csv_lines) == 1  # header only
    assert "query" in csv_lines[0]