feat(llm): add LLMRouter.embed() for batch embedding generation
Adds embed(texts, model_override, fallback_order) to LLMRouter. Only openai_compat backends are tried (Ollama/vLLM expose /v1/embeddings; anthropic and vision_service do not). Uses embedding_model from backend config when present, falls back to the chat model otherwise. Supports cf-orch allocation and raises RuntimeError when all backends are exhausted. 4 tests added (TDD: RED → GREEN), 763 total passing, no regressions.
This commit is contained in:
parent
a6d906bcbb
commit
8e2d15bcd4
2 changed files with 286 additions and 57 deletions
|
|
@ -43,6 +43,7 @@ When llm.yaml is absent, the router builds a minimal config from environment
|
|||
variables: ANTHROPIC_API_KEY, OPENAI_API_KEY / OPENAI_BASE_URL, OLLAMA_HOST.
|
||||
Ollama on localhost:11434 is always included as the lowest-cost local fallback.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import yaml
|
||||
|
|
@ -70,7 +71,8 @@ class LLMRouter:
|
|||
)
|
||||
logger.info(
|
||||
"[LLMRouter] No llm.yaml found — using env-var auto-config "
|
||||
"(backends: %s)", ", ".join(env_config["fallback_order"])
|
||||
"(backends: %s)",
|
||||
", ".join(env_config["fallback_order"]),
|
||||
)
|
||||
self.config = env_config
|
||||
|
||||
|
|
@ -103,7 +105,9 @@ class LLMRouter:
|
|||
backends["openai"] = {
|
||||
"type": "openai_compat",
|
||||
"enabled": True,
|
||||
"base_url": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
|
||||
"base_url": os.environ.get(
|
||||
"OPENAI_BASE_URL", "https://api.openai.com/v1"
|
||||
),
|
||||
"model": os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
|
||||
"api_key": os.environ.get("OPENAI_API_KEY"),
|
||||
"supports_images": True,
|
||||
|
|
@ -156,6 +160,7 @@ class LLMRouter:
|
|||
Caller MUST call ctx.__exit__(None, None, None) in a finally block.
|
||||
"""
|
||||
import os
|
||||
|
||||
orch_cfg = backend.get("cf_orch")
|
||||
if not orch_cfg:
|
||||
return None
|
||||
|
|
@ -164,6 +169,7 @@ class LLMRouter:
|
|||
return None
|
||||
try:
|
||||
from circuitforge_orch.client import CFOrchClient
|
||||
|
||||
client = CFOrchClient(orch_url)
|
||||
service = orch_cfg.get("service", "vllm")
|
||||
candidates = orch_cfg.get("model_candidates", [])
|
||||
|
|
@ -181,14 +187,21 @@ class LLMRouter:
|
|||
alloc = ctx.__enter__()
|
||||
return (ctx, alloc)
|
||||
except Exception as exc:
|
||||
logger.warning("[LLMRouter] cf_orch allocation failed, using base_url directly: %s", exc)
|
||||
logger.warning(
|
||||
"[LLMRouter] cf_orch allocation failed, using base_url directly: %s",
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
def complete(self, prompt: str, system: str | None = None,
|
||||
model_override: str | None = None,
|
||||
fallback_order: list[str] | None = None,
|
||||
images: list[str] | None = None,
|
||||
max_tokens: int | None = None) -> str:
|
||||
def complete(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str | None = None,
|
||||
model_override: str | None = None,
|
||||
fallback_order: list[str] | None = None,
|
||||
images: list[str] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a completion. Tries each backend in fallback_order.
|
||||
|
||||
|
|
@ -206,7 +219,11 @@ class LLMRouter:
|
|||
"AI inference is disabled in the public demo. "
|
||||
"Run your own instance to use AI features."
|
||||
)
|
||||
order = fallback_order if fallback_order is not None else self.config["fallback_order"]
|
||||
order = (
|
||||
fallback_order
|
||||
if fallback_order is not None
|
||||
else self.config["fallback_order"]
|
||||
)
|
||||
for name in order:
|
||||
backend = self.config["backends"][name]
|
||||
|
||||
|
|
@ -283,10 +300,14 @@ class LLMRouter:
|
|||
if images and supports_images:
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
for img in images:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{img}"},
|
||||
})
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{img}"
|
||||
},
|
||||
}
|
||||
)
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
|
@ -311,18 +332,27 @@ class LLMRouter:
|
|||
elif backend["type"] == "anthropic":
|
||||
api_key = os.environ.get(backend["api_key_env"], "")
|
||||
if not api_key:
|
||||
print(f"[LLMRouter] {name}: {backend['api_key_env']} not set, skipping")
|
||||
print(
|
||||
f"[LLMRouter] {name}: {backend['api_key_env']} not set, skipping"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
import anthropic as _anthropic
|
||||
|
||||
client = _anthropic.Anthropic(api_key=api_key)
|
||||
if images and supports_images:
|
||||
content = []
|
||||
for img in images:
|
||||
content.append({
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": img},
|
||||
})
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": img,
|
||||
},
|
||||
}
|
||||
)
|
||||
content.append({"type": "text", "text": prompt})
|
||||
else:
|
||||
content = prompt
|
||||
|
|
@ -342,6 +372,76 @@ class LLMRouter:
|
|||
|
||||
raise RuntimeError("All LLM backends exhausted")
|
||||
|
||||
def embed(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_override: str | None = None,
|
||||
fallback_order: list[str] | None = None,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of texts.
|
||||
|
||||
Only openai_compat backends are tried — Ollama and vLLM expose
|
||||
/v1/embeddings; anthropic and vision_service do not.
|
||||
|
||||
Uses ``embedding_model`` from backend config when present;
|
||||
falls back to ``model`` (the chat model) otherwise.
|
||||
|
||||
Args:
|
||||
texts: Texts to embed (batched in a single API call).
|
||||
model_override: Override the embedding model for this call.
|
||||
fallback_order: Override the backend fallback order for this call.
|
||||
|
||||
Returns:
|
||||
List of float vectors, one per input text, in input order.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all eligible backends are exhausted.
|
||||
"""
|
||||
order = (
|
||||
fallback_order
|
||||
if fallback_order is not None
|
||||
else self.config["fallback_order"]
|
||||
)
|
||||
for name in order:
|
||||
backend = self.config["backends"][name]
|
||||
if not backend.get("enabled", True):
|
||||
continue
|
||||
if backend["type"] != "openai_compat":
|
||||
continue
|
||||
|
||||
orch_ctx = orch_alloc = None
|
||||
orch_result = self._try_cf_orch_alloc(backend)
|
||||
if orch_result is not None:
|
||||
orch_ctx, orch_alloc = orch_result
|
||||
backend = {**backend, "base_url": orch_alloc.url + "/v1"}
|
||||
elif not self._is_reachable(backend["base_url"]):
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=backend["base_url"],
|
||||
api_key=backend.get("api_key") or "any",
|
||||
)
|
||||
model = model_override or backend.get(
|
||||
"embedding_model", backend["model"]
|
||||
)
|
||||
resp = client.embeddings.create(model=model, input=texts)
|
||||
print(f"[LLMRouter] embed: used backend {name} ({model})")
|
||||
return [item.embedding for item in resp.data]
|
||||
except Exception as e:
|
||||
print(f"[LLMRouter] {name}: embed error — {e}, trying next")
|
||||
continue
|
||||
finally:
|
||||
if orch_ctx is not None:
|
||||
try:
|
||||
orch_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise RuntimeError("All LLM backends exhausted for embed()")
|
||||
|
||||
|
||||
# Module-level singleton for convenience
|
||||
_router: LLMRouter | None = None
|
||||
|
|
|
|||
|
|
@ -11,69 +11,81 @@ def _make_router(config: dict) -> LLMRouter:
|
|||
|
||||
|
||||
def test_complete_uses_first_reachable_backend():
|
||||
router = _make_router({
|
||||
"fallback_order": ["local"],
|
||||
"backends": {
|
||||
"local": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "llama3",
|
||||
"supports_images": False,
|
||||
}
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["local"],
|
||||
"backends": {
|
||||
"local": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "llama3",
|
||||
"supports_images": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="hello"))]
|
||||
)
|
||||
with patch.object(router, "_is_reachable", return_value=True), \
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client):
|
||||
with (
|
||||
patch.object(router, "_is_reachable", return_value=True),
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client),
|
||||
):
|
||||
result = router.complete("say hello")
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
def test_complete_falls_back_on_unreachable_backend():
|
||||
router = _make_router({
|
||||
"fallback_order": ["unreachable", "working"],
|
||||
"backends": {
|
||||
"unreachable": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://nowhere:1/v1",
|
||||
"model": "x",
|
||||
"supports_images": False,
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["unreachable", "working"],
|
||||
"backends": {
|
||||
"unreachable": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://nowhere:1/v1",
|
||||
"model": "x",
|
||||
"supports_images": False,
|
||||
},
|
||||
"working": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "llama3",
|
||||
"supports_images": False,
|
||||
},
|
||||
},
|
||||
"working": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "llama3",
|
||||
"supports_images": False,
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="fallback"))]
|
||||
)
|
||||
|
||||
def reachable(url):
|
||||
return "nowhere" not in url
|
||||
with patch.object(router, "_is_reachable", side_effect=reachable), \
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client):
|
||||
|
||||
with (
|
||||
patch.object(router, "_is_reachable", side_effect=reachable),
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client),
|
||||
):
|
||||
result = router.complete("test")
|
||||
assert result == "fallback"
|
||||
|
||||
|
||||
def test_complete_raises_when_all_backends_exhausted():
|
||||
router = _make_router({
|
||||
"fallback_order": ["dead"],
|
||||
"backends": {
|
||||
"dead": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://nowhere:1/v1",
|
||||
"model": "x",
|
||||
"supports_images": False,
|
||||
}
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["dead"],
|
||||
"backends": {
|
||||
"dead": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://nowhere:1/v1",
|
||||
"model": "x",
|
||||
"supports_images": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
with patch.object(router, "_is_reachable", return_value=False):
|
||||
with pytest.raises(RuntimeError, match="exhausted"):
|
||||
router.complete("test")
|
||||
|
|
@ -83,6 +95,123 @@ def test_try_cf_orch_alloc_import_path():
|
|||
"""Verify lazy import points to circuitforge_orch, not circuitforge_core.resources."""
|
||||
import inspect
|
||||
from circuitforge_core.llm import router as router_module
|
||||
|
||||
src = inspect.getsource(router_module.LLMRouter._try_cf_orch_alloc)
|
||||
assert "circuitforge_orch.client" in src
|
||||
assert "circuitforge_core.resources.client" not in src
|
||||
|
||||
|
||||
def test_embed_returns_vectors_from_openai_compat_backend():
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["ollama"],
|
||||
"backends": {
|
||||
"ollama": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "mistral:7b",
|
||||
"embedding_model": "nomic-embed-text",
|
||||
"supports_images": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create.return_value = MagicMock(
|
||||
data=[
|
||||
MagicMock(embedding=[0.1, 0.2, 0.3]),
|
||||
MagicMock(embedding=[0.4, 0.5, 0.6]),
|
||||
]
|
||||
)
|
||||
with (
|
||||
patch.object(router, "_is_reachable", return_value=True),
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client),
|
||||
):
|
||||
result = router.embed(["hello world", "fireball rules"])
|
||||
|
||||
assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="nomic-embed-text",
|
||||
input=["hello world", "fireball rules"],
|
||||
)
|
||||
|
||||
|
||||
def test_embed_uses_chat_model_when_no_embedding_model_configured():
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["ollama"],
|
||||
"backends": {
|
||||
"ollama": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "llama3",
|
||||
"supports_images": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create.return_value = MagicMock(
|
||||
data=[MagicMock(embedding=[0.9, 0.8])]
|
||||
)
|
||||
with (
|
||||
patch.object(router, "_is_reachable", return_value=True),
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client),
|
||||
):
|
||||
router.embed(["test"])
|
||||
|
||||
call_kwargs = mock_client.embeddings.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "llama3"
|
||||
|
||||
|
||||
def test_embed_skips_non_openai_compat_backends():
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["anthropic", "ollama"],
|
||||
"backends": {
|
||||
"anthropic": {
|
||||
"type": "anthropic",
|
||||
"enabled": True,
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"api_key_env": "ANTHROPIC_API_KEY",
|
||||
"supports_images": True,
|
||||
},
|
||||
"ollama": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"model": "nomic-embed-text",
|
||||
"supports_images": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create.return_value = MagicMock(
|
||||
data=[MagicMock(embedding=[0.1])]
|
||||
)
|
||||
with (
|
||||
patch.object(router, "_is_reachable", return_value=True),
|
||||
patch("circuitforge_core.llm.router.OpenAI", return_value=mock_client),
|
||||
):
|
||||
result = router.embed(["hello"])
|
||||
|
||||
assert result == [[0.1]]
|
||||
|
||||
|
||||
def test_embed_raises_when_all_backends_exhausted():
|
||||
router = _make_router(
|
||||
{
|
||||
"fallback_order": ["dead"],
|
||||
"backends": {
|
||||
"dead": {
|
||||
"type": "openai_compat",
|
||||
"base_url": "http://nowhere:1/v1",
|
||||
"model": "x",
|
||||
"supports_images": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
with patch.object(router, "_is_reachable", return_value=False):
|
||||
with pytest.raises(RuntimeError, match="exhausted"):
|
||||
router.embed(["test"])
|
||||
|
|
|
|||
Loading…
Reference in a new issue