feat(imitate): task-model assignment routing via cf-orch

Add _resolve_task_model() helper that looks up a product.task assignment
from the coordinator and resolves its service_type from the model registry.
Add task_ids param to run_imitate() (comma-separated "product/task" strings)
so the imitate harness can dispatch to models chosen by the assignment layer
rather than requiring explicit model IDs.
This commit is contained in:
pyr0ball 2026-05-17 11:23:55 -07:00
parent 79b9ccbd3d
commit d416ef8aa4

View file

@ -94,6 +94,42 @@ def _cforch_url() -> str:
return cforch.get("coordinator_url") or "http://localhost:7700" return cforch.get("coordinator_url") or "http://localhost:7700"
def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
"""Return {model_id, service_type} for a product.task assignment, or None if not found.
Calls GET coordinator/api/assignments and filters by product+task.
The model registry entry is fetched separately to get service_type.
Returns None (not raises) callers emit a 'model_done' error event instead.
"""
try:
asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
asgn_resp.raise_for_status()
assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
match = next(
(a for a in assignments if a.get("product") == product and a.get("task") == task),
None,
)
if match is None:
return None
model_id: str = match.get("model_id", "")
if not model_id:
return None
# Look up service_type from model registry
reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
service_type = "cf-text" # sensible default
if reg_resp.is_success:
models: list[dict] = reg_resp.json().get("models", []) or []
reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
if reg_entry:
service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
return {"model_id": model_id, "service_type": service_type}
except Exception as exc:
logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
return None
def _cforch_catalog(cforch_base: str) -> list[dict]: def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch. """Fetch the live cf-text catalog from cf-orch.
@ -476,13 +512,19 @@ def run_imitate(
prompt: str = "", prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch) cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
task_ids: str = "", # comma-separated "product/task" strings — resolved via assignments
temperature: float = 0.7, temperature: float = 0.7,
product_id: str = "", product_id: str = "",
system: str = "", # optional system prompt system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models image_url: str = "", # optional image URL for vision models
session: "Any" = Depends(_get_imitate_session), session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse: ) -> StreamingResponse:
"""Run a prompt through selected ollama models and stream results as SSE. """Run a prompt through selected models and stream results as SSE.
Models can be selected three ways (combinable):
- model_ids: explicit ollama model IDs
- cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
- task_ids: "product/task" strings resolved via the coordinator assignments table
If image_url is provided, the image is downloaded once and passed to every If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to model as a base64-encoded blob allowing vision-capable local models to
@ -494,8 +536,37 @@ def run_imitate(
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()] ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()] cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
# Resolve task assignments to concrete model IDs, routing to the right service.
# Models that fail to resolve emit an error event at run time (non-fatal).
if raw_task_ids:
cforch_base = _cforch_url()
for task_spec in raw_task_ids:
parts = task_spec.split("/", 1)
if len(parts) != 2:
logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
continue
product_name, task_name = parts
resolved = _resolve_task_model(cforch_base, product_name, task_name)
if resolved is None:
logger.warning("No assignment found for task %r", task_spec)
# Emit error at stream time via a sentinel in cftext_ids with a special label.
# We instead store the failed task_spec to emit a model_done error.
cftext_ids.append(f"__task_unresolved__:{task_spec}")
continue
mid = resolved["model_id"]
svc = resolved["service_type"]
if svc == "ollama":
if mid not in ollama_ids:
ollama_ids.append(mid)
else:
# cf-text, vllm, and any other cf-orch-managed service
if mid not in cftext_ids:
cftext_ids.append(mid)
if not ollama_ids and not cftext_ids: if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids or cf_text_model_ids is required") raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
cfg = _load_imitate_config() cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg) ollama_base = _ollama_url(cfg)
@ -539,11 +610,25 @@ def run_imitate(
yield _sse({"type": "model_done", **result}) yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected # cf-text models via cf-orch — fan out in parallel when multiple models selected
if cftext_ids: # Partition the list: real cf-text IDs vs unresolved-task sentinels.
cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
for sentinel in cftext_unresolved:
task_spec = sentinel.split(":", 1)[1]
result = {
"model": task_spec,
"response": "",
"elapsed_ms": 0,
"error": f"No assignment configured for task '{task_spec}'",
}
results.append(result)
yield _sse({"type": "model_done", **result})
if cftext_real:
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately # Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_ids: for model_id in cftext_real:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"}) yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
_user_id: str | None = getattr(session, "user_id", None) _user_id: str | None = getattr(session, "user_id", None)
@ -551,13 +636,13 @@ def run_imitate(
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"): if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
_user_id = None _user_id = None
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool: with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
future_to_model = { future_to_model = {
pool.submit( pool.submit(
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature, _run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
180.0, _user_id, 180.0, _user_id,
): mid ): mid
for mid in cftext_ids for mid in cftext_real
} }
for future in as_completed(future_to_model): for future in as_completed(future_to_model):
model_id = future_to_model[future] model_id = future_to_model[future]