From d416ef8aa4e84b743002935507b54643f918a760 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Sun, 17 May 2026 11:23:55 -0700
Subject: [PATCH] feat(imitate): task-model assignment routing via cf-orch

Add _resolve_task_model() helper that looks up a product.task assignment
from the coordinator and resolves its service_type from the model registry.

Add task_ids param to run_imitate() (comma-separated "product/task" strings)
so the imitate harness can dispatch to models chosen by the assignment layer
rather than requiring explicit model IDs.
---
NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy — confirm context-line whitespace against
app/data/imitate.py (or apply with --ignore-whitespace) before merging.

 app/data/imitate.py | 97 ++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 91 insertions(+), 6 deletions(-)

diff --git a/app/data/imitate.py b/app/data/imitate.py
index 354aeab..1d453d2 100644
--- a/app/data/imitate.py
+++ b/app/data/imitate.py
@@ -94,6 +94,42 @@ def _cforch_url() -> str:
     return cforch.get("coordinator_url") or "http://localhost:7700"
 
 
+def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
+    """Return {model_id, service_type} for a product.task assignment, or None if not found.
+
+    Calls GET {cforch_base}/api/assignments and uses the first entry matching product+task.
+    service_type is read from GET /api/model-registry ("cf-text" if that reply is unsuccessful or lacks it).
+    Returns None (never raises) on any failure — callers emit a 'model_done' error event instead.
+    """
+    try:
+        asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
+        asgn_resp.raise_for_status()
+        assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
+        match = next(
+            (a for a in assignments if a.get("product") == product and a.get("task") == task),
+            None,
+        )
+        if match is None:
+            return None
+        model_id: str = match.get("model_id", "")
+        if not model_id:
+            return None
+
+        # Look up service_type from the model registry; an unsuccessful reply keeps the default
+        reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
+        service_type = "cf-text"  # sensible default
+        if reg_resp.is_success:
+            models: list[dict] = reg_resp.json().get("models", []) or []
+            reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
+            if reg_entry:
+                service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
+
+        return {"model_id": model_id, "service_type": service_type}
+    except Exception as exc:
+        logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
+        return None
+
+
 def _cforch_catalog(cforch_base: str) -> list[dict]:
     """Fetch the live cf-text catalog from cf-orch.
 
@@ -476,13 +512,19 @@ def run_imitate(
     prompt: str = "",
     model_ids: str = "",  # comma-separated ollama model IDs
     cf_text_model_ids: str = "",  # comma-separated cf-text model IDs (via cf-orch)
+    task_ids: str = "",  # comma-separated "product/task" strings — resolved via assignments
     temperature: float = 0.7,
     product_id: str = "",
     system: str = "",  # optional system prompt
     image_url: str = "",  # optional image URL for vision models
     session: "Any" = Depends(_get_imitate_session),
 ) -> StreamingResponse:
-    """Run a prompt through selected ollama models and stream results as SSE.
+    """Run a prompt through selected models and stream results as SSE.
+
+    Models can be selected three ways (combinable):
+      - model_ids: explicit ollama model IDs
+      - cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
+      - task_ids: "product/task" strings resolved via the coordinator assignments table
 
     If image_url is provided, the image is downloaded once and passed to every
     model as a base64-encoded blob — allowing vision-capable local models to
@@ -494,8 +536,37 @@ def run_imitate(
 
     ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
     cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
+    raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
+
+    # Resolve task assignments to concrete model IDs, routing to the right service.
+    # Unresolved tasks emit a model_done error event at stream time; malformed specs are only logged.
+    if raw_task_ids:
+        cforch_base = _cforch_url()
+        for task_spec in raw_task_ids:
+            parts = task_spec.split("/", 1)
+            if len(parts) != 2:
+                logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
+                continue
+            product_name, task_name = parts
+            resolved = _resolve_task_model(cforch_base, product_name, task_name)
+            if resolved is None:
+                logger.warning("No assignment found for task %r", task_spec)
+                # Park a sentinel in cftext_ids; the stream generator filters it out and emits
+                # a model_done error event for this task_spec instead of dispatching a model.
+                cftext_ids.append(f"__task_unresolved__:{task_spec}")
+                continue
+            mid = resolved["model_id"]
+            svc = resolved["service_type"]
+            if svc == "ollama":
+                if mid not in ollama_ids:
+                    ollama_ids.append(mid)
+            else:
+                # cf-text, vllm, and any other cf-orch-managed service
+                if mid not in cftext_ids:
+                    cftext_ids.append(mid)
+
     if not ollama_ids and not cftext_ids:
-        raise HTTPException(422, "model_ids or cf_text_model_ids is required")
+        raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
 
     cfg = _load_imitate_config()
     ollama_base = _ollama_url(cfg)
@@ -539,11 +610,25 @@ def run_imitate(
             yield _sse({"type": "model_done", **result})
 
         # cf-text models via cf-orch — fan out in parallel when multiple models selected
-        if cftext_ids:
+        # Partition the list: real cf-text IDs vs unresolved-task sentinels (reported first, never dispatched).
+        cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
+        cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
+        for sentinel in cftext_unresolved:
+            task_spec = sentinel.split(":", 1)[1]
+            result = {
+                "model": task_spec,
+                "response": "",
+                "elapsed_ms": 0,
+                "error": f"No assignment configured for task '{task_spec}'",
+            }
+            results.append(result)
+            yield _sse({"type": "model_done", **result})
+
+        if cftext_real:
             from concurrent.futures import ThreadPoolExecutor, as_completed
 
             # Announce all models upfront so the UI can show loading states immediately
-            for model_id in cftext_ids:
+            for model_id in cftext_real:
                 yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
 
             _user_id: str | None = getattr(session, "user_id", None)
@@ -551,13 +636,13 @@ def run_imitate(
             if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
                 _user_id = None
 
-            with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
+            with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
                 future_to_model = {
                     pool.submit(
                         _run_cftext, cforch_base, mid, prompt, system_ctx,
                         temperature, 180.0, _user_id,
                     ): mid
-                    for mid in cftext_ids
+                    for mid in cftext_real
                 }
                 for future in as_completed(future_to_model):
                     model_id = future_to_model[future]