feat: add embed-bench rate and export endpoints
This commit is contained in:
parent
32e3b2a0dd
commit
1ad7ba322a
2 changed files with 136 additions and 0 deletions
|
|
@ -212,3 +212,82 @@ def run_embed_bench(req: RunRequest) -> StreamingResponse:
|
||||||
_RUN_ACTIVE = False
|
_RUN_ACTIVE = False
|
||||||
|
|
||||||
return StreamingResponse(_generate(), media_type="text/event-stream")
|
return StreamingResponse(_generate(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /rate ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_VALID_RATINGS = {"relevant", "not_relevant"}
|
||||||
|
|
||||||
|
|
||||||
|
class RatingRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
model: str
|
||||||
|
chunk_text: str
|
||||||
|
chunk_idx: int
|
||||||
|
rating: str
|
||||||
|
|
||||||
|
@field_validator("rating")
|
||||||
|
@classmethod
|
||||||
|
def rating_valid(cls, v: str) -> str:
|
||||||
|
if v not in _VALID_RATINGS:
|
||||||
|
raise ValueError(f"rating must be one of {_VALID_RATINGS}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/rate")
|
||||||
|
def rate_result(req: RatingRequest) -> dict:
|
||||||
|
"""Append one rating to the JSONL ratings file."""
|
||||||
|
entry = {
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"query": req.query,
|
||||||
|
"model": req.model,
|
||||||
|
"chunk_idx": req.chunk_idx,
|
||||||
|
"chunk_text": req.chunk_text,
|
||||||
|
"rating": req.rating,
|
||||||
|
}
|
||||||
|
path = _ratings_path()
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with path.open("a", encoding="utf-8") as fh:
|
||||||
|
fh.write(json.dumps(entry) + "\n")
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /export ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CSV_FIELDS = ["timestamp", "query", "model", "chunk_idx", "chunk_text", "rating"]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/export")
|
||||||
|
def export_ratings(format: str = "csv") -> Any:
|
||||||
|
"""Download ratings as CSV or JSON."""
|
||||||
|
path = _ratings_path()
|
||||||
|
rows: list[dict] = []
|
||||||
|
if path.exists():
|
||||||
|
for raw in path.read_text(encoding="utf-8").splitlines():
|
||||||
|
raw = raw.strip()
|
||||||
|
if raw:
|
||||||
|
try:
|
||||||
|
rows.append(json.loads(raw))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
if format == "json":
|
||||||
|
content = json.dumps(rows, ensure_ascii=False, indent=2)
|
||||||
|
return StreamingResponse(
|
||||||
|
iter([content]),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.json"'},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default: CSV
|
||||||
|
buf = io.StringIO()
|
||||||
|
writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS, extrasaction="ignore")
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(rows)
|
||||||
|
return StreamingResponse(
|
||||||
|
iter([buf.getvalue()]),
|
||||||
|
media_type="text/csv",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="embed_comparison_{date_str}.csv"'},
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -175,3 +175,60 @@ def test_run_two_models_returns_two_result_events_per_query(client, tmp_path):
|
||||||
models_seen = {e["model"] for e in result_events}
|
models_seen = {e["model"] for e in result_events}
|
||||||
assert "nomic-embed-text" in models_seen
|
assert "nomic-embed-text" in models_seen
|
||||||
assert "mxbai-embed-large" in models_seen
|
assert "mxbai-embed-large" in models_seen
|
||||||
|
|
||||||
|
|
||||||
|
# ── rate + export ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_rate_appends_jsonl_line(client, tmp_path):
|
||||||
|
r = client.post("/api/embed-bench/rate", json={
|
||||||
|
"query": "test query",
|
||||||
|
"model": "nomic-embed-text",
|
||||||
|
"chunk_text": "some text",
|
||||||
|
"chunk_idx": 2,
|
||||||
|
"rating": "relevant",
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"ok": True}
|
||||||
|
ratings_file = tmp_path / "embed_bench_ratings.jsonl"
|
||||||
|
assert ratings_file.exists()
|
||||||
|
line = json.loads(ratings_file.read_text().strip())
|
||||||
|
assert line["query"] == "test query"
|
||||||
|
assert line["rating"] == "relevant"
|
||||||
|
assert line["chunk_idx"] == 2
|
||||||
|
assert "timestamp" in line
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_csv_two_rows(client, tmp_path):
|
||||||
|
for i in range(2):
|
||||||
|
client.post("/api/embed-bench/rate", json={
|
||||||
|
"query": f"q{i}", "model": "nomic-embed-text",
|
||||||
|
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "relevant",
|
||||||
|
})
|
||||||
|
r = client.get("/api/embed-bench/export?format=csv")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert "text/csv" in r.headers["content-type"]
|
||||||
|
lines = r.text.strip().splitlines()
|
||||||
|
assert len(lines) == 3 # header + 2 rows
|
||||||
|
assert "query" in lines[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_json_two_entries(client, tmp_path):
|
||||||
|
for i in range(2):
|
||||||
|
client.post("/api/embed-bench/rate", json={
|
||||||
|
"query": f"q{i}", "model": "nomic-embed-text",
|
||||||
|
"chunk_text": f"chunk {i}", "chunk_idx": i, "rating": "not_relevant",
|
||||||
|
})
|
||||||
|
r = client.get("/api/embed-bench/export?format=json")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 2
|
||||||
|
assert data[0]["rating"] == "not_relevant"
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_empty_returns_csv_header_only(client):
|
||||||
|
r = client.get("/api/embed-bench/export?format=csv")
|
||||||
|
assert r.status_code == 200
|
||||||
|
lines = r.text.strip().splitlines()
|
||||||
|
assert len(lines) == 1 # header only
|
||||||
|
assert "query" in lines[0]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue