feat(imitate): task-model assignment routing via cf-orch
Add _resolve_task_model() helper that looks up a product.task assignment from the coordinator and resolves its service_type from the model registry. Add task_ids param to run_imitate() (comma-separated "product/task" strings) so the imitate harness can dispatch to models chosen by the assignment layer rather than requiring explicit model IDs.
This commit is contained in:
parent
79b9ccbd3d
commit
d416ef8aa4
1 changed file with 91 additions and 6 deletions
|
|
@ -94,6 +94,42 @@ def _cforch_url() -> str:
|
|||
return cforch.get("coordinator_url") or "http://localhost:7700"
|
||||
|
||||
|
||||
def _resolve_task_model(cforch_base: str, product: str, task: str) -> dict | None:
    """Return {model_id, service_type} for a product.task assignment, or None if not found.

    Calls GET {cforch_base}/api/assignments and picks the first entry whose
    "product" and "task" fields equal the given values. The model registry
    (GET {cforch_base}/api/model-registry) is then consulted separately to
    obtain the model's service_type.

    The registry lookup is best-effort: any failure there (request error,
    non-2xx status, unparsable body, model not listed, empty service_type)
    falls back to "cf-text" instead of discarding an assignment that has
    already been resolved.

    Args:
        cforch_base: Coordinator base URL; "/api/..." paths are appended to it.
        product: Value matched against each assignment's "product" field.
        task: Value matched against each assignment's "task" field.

    Returns:
        {"model_id": ..., "service_type": ...} on success. None when the
        assignments request fails, no assignment matches, or the matching
        assignment has no model_id.

    Returns None (not raises) — callers emit a 'model_done' error event instead.
    """
    try:
        asgn_resp = httpx.get(f"{cforch_base}/api/assignments", timeout=5.0)
        asgn_resp.raise_for_status()
        assignments: list[dict] = asgn_resp.json().get("assignments", []) or []
        match = next(
            (a for a in assignments if a.get("product") == product and a.get("task") == task),
            None,
        )
        if match is None:
            return None
        model_id: str = match.get("model_id", "")
        if not model_id:
            return None
    except Exception as exc:
        # Best-effort boundary: any failure while locating the assignment means
        # "not resolved"; the caller reports it to the client.
        logger.warning("Task resolution failed for %s.%s: %s", product, task, exc)
        return None

    # Look up service_type from model registry. This runs in its own try block
    # so a registry outage degrades to the default rather than throwing away
    # the model_id found above.
    service_type = "cf-text"  # sensible default
    try:
        reg_resp = httpx.get(f"{cforch_base}/api/model-registry", timeout=5.0)
        if reg_resp.is_success:
            models: list[dict] = reg_resp.json().get("models", []) or []
            reg_entry = next((m for m in models if m.get("model_id") == model_id), None)
            if reg_entry:
                # A missing or empty service_type also falls back to the default.
                service_type = reg_entry.get("service_type", "cf-text") or "cf-text"
    except Exception as exc:
        logger.warning(
            "Model registry lookup failed for %s (%s.%s); defaulting to cf-text: %s",
            model_id, product, task, exc,
        )

    return {"model_id": model_id, "service_type": service_type}
|
||||
|
||||
|
||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
||||
"""Fetch the live cf-text catalog from cf-orch.
|
||||
|
||||
|
|
@ -476,13 +512,19 @@ def run_imitate(
|
|||
prompt: str = "",
|
||||
model_ids: str = "", # comma-separated ollama model IDs
|
||||
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
||||
task_ids: str = "", # comma-separated "product/task" strings — resolved via assignments
|
||||
temperature: float = 0.7,
|
||||
product_id: str = "",
|
||||
system: str = "", # optional system prompt
|
||||
image_url: str = "", # optional image URL for vision models
|
||||
session: "Any" = Depends(_get_imitate_session),
|
||||
) -> StreamingResponse:
|
||||
"""Run a prompt through selected ollama models and stream results as SSE.
|
||||
"""Run a prompt through selected models and stream results as SSE.
|
||||
|
||||
Models can be selected three ways (combinable):
|
||||
- model_ids: explicit ollama model IDs
|
||||
- cf_text_model_ids: explicit cf-text model IDs routed via cf-orch
|
||||
- task_ids: "product/task" strings resolved via the coordinator assignments table
|
||||
|
||||
If image_url is provided, the image is downloaded once and passed to every
|
||||
model as a base64-encoded blob — allowing vision-capable local models to
|
||||
|
|
@ -494,8 +536,37 @@ def run_imitate(
|
|||
|
||||
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
||||
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
||||
raw_task_ids = [t.strip() for t in task_ids.split(",") if t.strip()]
|
||||
|
||||
# Resolve task assignments to concrete model IDs, routing to the right service.
|
||||
# Models that fail to resolve emit an error event at run time (non-fatal).
|
||||
if raw_task_ids:
|
||||
cforch_base = _cforch_url()
|
||||
for task_spec in raw_task_ids:
|
||||
parts = task_spec.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping malformed task_id %r (expected product/task)", task_spec)
|
||||
continue
|
||||
product_name, task_name = parts
|
||||
resolved = _resolve_task_model(cforch_base, product_name, task_name)
|
||||
if resolved is None:
|
||||
logger.warning("No assignment found for task %r", task_spec)
|
||||
# Emit error at stream time via a sentinel in cftext_ids with a special label.
|
||||
# We instead store the failed task_spec to emit a model_done error.
|
||||
cftext_ids.append(f"__task_unresolved__:{task_spec}")
|
||||
continue
|
||||
mid = resolved["model_id"]
|
||||
svc = resolved["service_type"]
|
||||
if svc == "ollama":
|
||||
if mid not in ollama_ids:
|
||||
ollama_ids.append(mid)
|
||||
else:
|
||||
# cf-text, vllm, and any other cf-orch-managed service
|
||||
if mid not in cftext_ids:
|
||||
cftext_ids.append(mid)
|
||||
|
||||
if not ollama_ids and not cftext_ids:
|
||||
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
|
||||
raise HTTPException(422, "model_ids, cf_text_model_ids, or task_ids is required")
|
||||
|
||||
cfg = _load_imitate_config()
|
||||
ollama_base = _ollama_url(cfg)
|
||||
|
|
@ -539,11 +610,25 @@ def run_imitate(
|
|||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
||||
if cftext_ids:
|
||||
# Partition the list: real cf-text IDs vs unresolved-task sentinels.
|
||||
cftext_real = [m for m in cftext_ids if not m.startswith("__task_unresolved__:")]
|
||||
cftext_unresolved = [m for m in cftext_ids if m.startswith("__task_unresolved__:")]
|
||||
for sentinel in cftext_unresolved:
|
||||
task_spec = sentinel.split(":", 1)[1]
|
||||
result = {
|
||||
"model": task_spec,
|
||||
"response": "",
|
||||
"elapsed_ms": 0,
|
||||
"error": f"No assignment configured for task '{task_spec}'",
|
||||
}
|
||||
results.append(result)
|
||||
yield _sse({"type": "model_done", **result})
|
||||
|
||||
if cftext_real:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Announce all models upfront so the UI can show loading states immediately
|
||||
for model_id in cftext_ids:
|
||||
for model_id in cftext_real:
|
||||
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
||||
|
||||
_user_id: str | None = getattr(session, "user_id", None)
|
||||
|
|
@ -551,13 +636,13 @@ def run_imitate(
|
|||
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
|
||||
_user_id = None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
|
||||
with ThreadPoolExecutor(max_workers=len(cftext_real)) as pool:
|
||||
future_to_model = {
|
||||
pool.submit(
|
||||
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
|
||||
180.0, _user_id,
|
||||
): mid
|
||||
for mid in cftext_ids
|
||||
for mid in cftext_real
|
||||
}
|
||||
for future in as_completed(future_to_model):
|
||||
model_id = future_to_model[future]
|
||||
|
|
|
|||
Loading…
Reference in a new issue