feat: add embed-bench models endpoint and register router in aggregator

This commit is contained in:
pyr0ball 2026-05-07 09:01:25 -07:00
parent 5ea77da97d
commit 5939c67b9f
3 changed files with 60 additions and 0 deletions

View file

@ -18,12 +18,14 @@ from app.cforch import router as _cforch_router
from app.style import router as _style_router from app.style import router as _style_router
from app.voice import router as _voice_router from app.voice import router as _voice_router
from app.plans_bench import router as _plans_router from app.plans_bench import router as _plans_router
from app.eval.embed_bench import router as _embed_router
router = APIRouter() router = APIRouter()
router.include_router(_cforch_router, prefix="/cforch") router.include_router(_cforch_router, prefix="/cforch")
router.include_router(_style_router, prefix="/style") router.include_router(_style_router, prefix="/style")
router.include_router(_voice_router, prefix="/voice") router.include_router(_voice_router, prefix="/voice")
router.include_router(_plans_router, prefix="/plans-bench") router.include_router(_plans_router, prefix="/plans-bench")
router.include_router(_embed_router, prefix="/embed-bench")
def set_config_dir(path) -> None: def set_config_dir(path) -> None:
@ -32,7 +34,9 @@ def set_config_dir(path) -> None:
import app.style as _style_mod import app.style as _style_mod
import app.voice as _voice_mod import app.voice as _voice_mod
import app.plans_bench as _plans_mod import app.plans_bench as _plans_mod
import app.eval.embed_bench as _embed_mod
_cforch_mod.set_config_dir(path) _cforch_mod.set_config_dir(path)
_style_mod.set_config_dir(path) _style_mod.set_config_dir(path)
_voice_mod.set_config_dir(path) _voice_mod.set_config_dir(path)
_plans_mod.set_config_dir(path) _plans_mod.set_config_dir(path)
_embed_mod.set_config_dir(path)

View file

@ -83,3 +83,23 @@ def _cosine(a: list[float], b: list[float]) -> float:
if mag_a == 0.0 or mag_b == 0.0: if mag_a == 0.0 or mag_b == 0.0:
return 0.0 return 0.0
return dot / (mag_a * mag_b) return dot / (mag_a * mag_b)
# ── GET /models ───────────────────────────────────────────────────────────────
@router.get("/models")
def get_models() -> dict:
    """List the models reported by the configured Ollama instance.

    Queries ``<ollama_url>/api/tags`` and reduces each reported entry to its
    ``name`` and ``size``. No filtering by model type is applied here.

    The lookup is best-effort: any failure (connection error, non-2xx status,
    unexpected payload) is logged as a warning and the models collected up to
    that point -- normally none -- are returned instead of raising.

    Returns:
        A dict with two keys: ``models``, a list of ``{"name", "size"}``
        dicts, and ``ollama_url``, the base URL that was queried.
    """
    base_url = _ollama_url()
    found: list[dict] = []
    try:
        tags_resp = httpx.get(f"{base_url}/api/tags", timeout=5.0)
        tags_resp.raise_for_status()
        payload = tags_resp.json()
        for item in payload.get("models", []):
            # Append one entry at a time so a malformed item does not discard
            # the entries that were already collected before it.
            summary = {"name": item.get("name", ""), "size": item.get("size", 0)}
            found.append(summary)
    except Exception as exc:
        logger.warning("Failed to list Ollama models: %s", exc)
    return {"models": found, "ollama_url": base_url}

View file

@ -53,3 +53,39 @@ def test_cosine_opposite():
def test_cosine_zero_vector_returns_zero():
    """A zero-magnitude operand yields a similarity of 0.0."""
    from app.eval.embed_bench import _cosine

    zero_vec = [0.0, 0.0]
    unit_vec = [1.0, 0.0]
    similarity = _cosine(zero_vec, unit_vec)
    assert similarity == pytest.approx(0.0)
# ── models endpoint ────────────────────────────────────────────────────────────
def test_models_returns_list_with_mock(client, tmp_path):
    """GET /api/embed-bench/models relays the entries from Ollama's tags endpoint."""
    import yaml

    # Config file pointing the app at a local Ollama instance.
    config_file = tmp_path / "label_tool.yaml"
    config_file.write_text(
        yaml.dump({"cforch": {"ollama_url": "http://localhost:11434"}})
    )

    # Canned /api/tags payload with two models.
    tags_payload = {
        "models": [
            {"name": "nomic-embed-text", "size": 274302480},
            {"name": "mxbai-embed-large", "size": 669000000},
        ]
    }
    fake_resp = MagicMock(status_code=200)
    fake_resp.json.return_value = tags_payload
    fake_resp.raise_for_status = MagicMock()

    with patch("httpx.get", return_value=fake_resp):
        response = client.get("/api/embed-bench/models")

    assert response.status_code == 200
    body = response.json()
    assert isinstance(body["models"], list)
    reported_names = {model["name"] for model in body["models"]}
    assert "nomic-embed-text" in reported_names
def test_models_returns_empty_on_ollama_error(client, tmp_path):
    """GET /api/embed-bench/models still answers 200 with no models when Ollama is down."""
    import httpx

    # Simulate a refused connection on the outgoing Ollama request.
    connect_failure = httpx.ConnectError("refused")
    with patch("httpx.get", side_effect=connect_failure):
        response = client.get("/api/embed-bench/models")

    assert response.status_code == 200
    body = response.json()
    assert body["models"] == []