avocet/app/imitate.py
pyr0ball cc24cd0d7d feat(imitate): parallel cf-text fanout workers + signal-based cold-start detection
Backend:
- Run all cf-text model allocations concurrently via ThreadPoolExecutor + as_completed
- Announce model_start events upfront so the UI can show loading states immediately
- Replace timer-based startup polling with coordinator state signals: waits for
  state=="running" (success) or state=="stopped" (fail-fast) on the matching
  node/gpu instance; falls back to health poll after 6 consecutive probe misses
- Add /api/cforch/catalog endpoint: fetches live cf-text model list from cf-orch,
  filtering out proxy entries (ollama://, vllm://, http://) so only loadable models
  are returned

Frontend (ImitateView.vue):
- Show per-model loading spinners as results arrive via SSE stream
- Display cold-start badge when coordinator signals the model was freshly loaded
2026-04-24 14:56:09 -07:00

624 lines
23 KiB
Python

"""Avocet — Imitate tab API.
Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs (ollama, or cf-text models allocated via cf-orch), and streams the
responses back to the UI. Results can be pushed into the SFT corrections queue
for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Two directory levels above this file.
_ROOT = Path(__file__).parent.parent
# Directory holding label_tool.yaml; None means "use _ROOT / config" (see _config_file).
_CONFIG_DIR: Path | None = None
# Directory that receives sft_candidates.jsonl; replaceable via set_data_dir().
_DATA_DIR: Path = _ROOT / "data"
# Router on which every endpoint in this module is registered.
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
    """Redirect config lookups to *path*; pass None to go back to the default directory."""
    global _CONFIG_DIR
    _CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
    """Replace the directory into which this module writes its data files."""
    global _DATA_DIR
    _DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
    """Return the path of label_tool.yaml, preferring the override directory when set."""
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
def _load_config_section(section: str) -> dict:
    """Return the *section* mapping from label_tool.yaml, or {} when unavailable.

    An empty dict is returned when the file is missing, cannot be read or
    parsed, does not hold a mapping at the top level, or when the section is
    absent or is not itself a mapping. Read and parse problems are logged as
    warnings rather than raised, so callers always receive a dict.
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        # OSError covers the file vanishing or becoming unreadable after exists().
        logger.warning("Failed to parse %s config %s: %s", section, f, exc)
        return {}
    if not isinstance(raw, dict):
        logger.warning("Ignoring %s: top-level YAML value is not a mapping", f)
        return {}
    value = raw.get(section)
    # A section that is a list or scalar would break every caller's .get().
    return value if isinstance(value, dict) else {}


def _load_imitate_config() -> dict:
    """Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
    return _load_config_section("imitate")


def _load_cforch_config() -> dict:
    """Read label_tool.yaml and return the cforch sub-dict (or {} if absent)."""
    return _load_config_section("cforch")
def _ollama_url(cfg: dict) -> str:
    """Pick the ollama base URL: imitate config first, then cforch config, then localhost."""
    fallback_cfg = _load_cforch_config()
    for candidate in (cfg.get("ollama_url"), fallback_cfg.get("ollama_url")):
        if candidate:
            return candidate
    return "http://localhost:11434"
def _cforch_url() -> str:
    """Return the configured cf-orch coordinator URL, defaulting to localhost:7700."""
    configured = _load_cforch_config().get("coordinator_url")
    return configured if configured else "http://localhost:7700"
def _cforch_catalog(cforch_base: str, node_id: str = "heimdall") -> list[dict]:
    """Fetch the live cf-text catalog from cf-orch.

    Entries whose ``path`` contains ``://`` (for example ollama://, vllm://,
    http://) are proxy entries served by their own services; they are skipped
    so that only models cf-text can load directly are returned.

    Args:
        cforch_base: Base URL of the cf-orch coordinator.
        node_id: Node whose catalog is requested; sent as the ``node_id``
            query parameter.

    Returns:
        A list of ``{"id", "vram_mb", "description"}`` dicts. The list is
        empty when the coordinator cannot be reached or answers with anything
        other than a JSON object; the reason is logged as a warning.
    """
    try:
        resp = httpx.get(
            f"{cforch_base}/api/services/cf-text/catalog",
            params={"node_id": node_id},
            timeout=5.0,
        )
        resp.raise_for_status()
        raw = resp.json()
    except Exception as exc:
        # Best-effort: the caller shows an empty catalog rather than failing.
        logger.warning("Could not fetch cf-orch catalog: %s", exc)
        return []
    if not isinstance(raw, dict):
        logger.warning(
            "Could not fetch cf-orch catalog: unexpected payload type %s",
            type(raw).__name__,
        )
        return []

    result = []
    for model_id, entry in raw.items():
        if not isinstance(entry, dict):
            continue
        # A null or non-string path must not abort the whole listing.
        path = str(entry.get("path") or "")
        # Skip proxy entries — they're routed through other services
        if "://" in path:
            continue
        result.append({
            "id": model_id,
            "vram_mb": entry.get("vram_mb", 0),
            "description": entry.get("description", ""),
        })
    return result
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
    """Report whether the health endpoint under *base_url* answers with a truthy JSON body.

    Any failure while building the URL, fetching it or decoding the reply
    counts as offline.
    """
    try:
        health_url = f"{base_url.rstrip('/')}{health_path}"
        payload = _http_get_json(health_url, timeout=2)
    except Exception:
        return False
    return bool(payload)
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
    """Return the JSONL file inside the data directory that collects SFT candidates."""
    return _DATA_DIR.joinpath("sft_candidates.jsonl")
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
    """Download an image URL and return it as a base64 string for ollama.

    Only ``http`` and ``https`` URLs are fetched. The URL originates from a
    request parameter, and ``urlopen`` would otherwise also open ``file:``
    and other schemes, exposing local files to whoever calls the endpoint.

    Returns empty string on any failure — a missing image is non-fatal;
    the model will still run against the text prompt alone.
    """
    # Function-scope import: nothing else in this module needs it.
    from urllib.parse import urlsplit

    try:
        scheme = urlsplit(image_url).scheme.lower()
        if scheme not in ("http", "https"):
            raise ValueError(f"unsupported URL scheme {scheme!r}")
        req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
        with urlopen(req, timeout=10) as resp:
            return base64.b64encode(resp.read()).decode("ascii")
    except Exception as exc:
        logger.warning("Failed to fetch image %s: %s", image_url, exc)
        return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing — streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _release_cftext_allocation(cforch_base: str, allocation_id: str) -> None:
    """Best-effort release of a cf-orch allocation; failures are logged, never raised."""
    if not allocation_id:
        return
    try:
        httpx.delete(
            f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
            timeout=5.0,
        )
    except Exception as exc:
        logger.warning("Could not release cf-text allocation %s: %s", allocation_id, exc)


def _wait_for_cftext_ready(
    cforch_base: str,
    service_url: str,
    model_id: str,
    node_id: str,
    gpu_id: int | None,
    startup_timeout_s: float,
) -> None:
    """Block until the freshly started cf-text instance can serve requests.

    Polls the coordinator status every two seconds. Returns once the instance
    matching node_id/gpu_id reports state "running". Raises RuntimeError as
    soon as it reports "stopped", or when startup_timeout_s elapses.
    """
    deadline = time.monotonic() + startup_timeout_s
    probe_misses = 0
    while time.monotonic() < deadline:
        instance = None
        listed = False
        try:
            status = httpx.get(
                f"{cforch_base}/api/services/cf-text/status", timeout=5.0
            )
            if status.is_success:
                instances = status.json().get("instances", [])
                instance = next(
                    (i for i in instances
                     if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
                    None,
                )
                # Set last, so a malformed payload is not counted as a miss.
                listed = True
        except Exception as exc:
            # A failed probe is retried on the next iteration.
            logger.debug("cf-text status probe failed: %s", exc)

        if instance:
            probe_misses = 0
            state = instance.get("state", "")
            if state == "running":
                return
            if state == "stopped":
                raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
        elif listed:
            probe_misses += 1
            if probe_misses >= 6:
                # Coordinator hasn't registered instance yet — fall back to health poll
                try:
                    if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
                        return
                except Exception as exc:
                    logger.debug("cf-text health probe failed: %s", exc)
        time.sleep(2.0)
    raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")


def _run_cftext(
    cforch_base: str,
    model_id: str,
    prompt: str,
    system: str,
    temperature: float,
    startup_timeout_s: float = 180.0,
    max_tokens: int = 300,
) -> tuple[str, int, bool]:
    """Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).

    cold_started=True means the service was launched from scratch (caller may log this).
    Cold-start detection uses coordinator state signals (running/stopped) rather than
    only polling the service health endpoint — this fails fast on model load errors
    instead of waiting out the full timeout.

    Args:
        cforch_base: Base URL of the cf-orch coordinator.
        model_id: Model requested from cf-text and named in the chat request.
        prompt: User message content.
        system: Optional system message; omitted when empty.
        temperature: Sampling temperature for the chat completion.
        startup_timeout_s: Upper bound on waiting for a cold start.
        max_tokens: ``max_tokens`` value sent with the chat completion request.

    Raises:
        RuntimeError: Allocation failed, the model failed to load or timed out
            while starting, or generation failed. The allocation, if one was
            granted, is released before the error propagates.
    """
    # Allocate
    try:
        alloc_resp = httpx.post(
            f"{cforch_base}/api/services/cf-text/allocate",
            json={
                "model_candidates": [model_id],
                "caller": "avocet",
                "pipeline": "imitate",
            },
            timeout=30.0,
        )
        alloc_resp.raise_for_status()
        data = alloc_resp.json()
    except Exception as exc:
        raise RuntimeError(f"cf-text allocation failed for {model_id!r}: {exc}") from exc
    if not isinstance(data, dict):
        raise RuntimeError(f"cf-text allocation for {model_id!r} returned an unexpected payload")

    allocation_id: str = data.get("allocation_id", "")
    # From here on the allocation is released exactly once, whatever happens.
    try:
        service_url = data.get("url")
        if not service_url:
            raise RuntimeError(f"cf-text allocation for {model_id!r} returned no service url")
        node_id: str = data.get("node_id", "")
        gpu_id: int | None = data.get("gpu_id")
        cold_started = bool(data.get("started", False)) and not data.get("warm", True)

        # Wait for ready using coordinator state signals
        if cold_started:
            _wait_for_cftext_ready(
                cforch_base, service_url, model_id, node_id, gpu_id, startup_timeout_s
            )

        # Generate
        messages: list[dict] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        # Monotonic clock: a wall-clock adjustment cannot distort the duration.
        t0 = time.monotonic()
        try:
            gen_resp = httpx.post(
                f"{service_url}/v1/chat/completions",
                json={
                    "model": model_id,
                    "messages": messages,
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "stream": False,
                },
                timeout=120.0,
            )
            gen_resp.raise_for_status()
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            content = gen_resp.json()["choices"][0]["message"]["content"]
            return content.strip(), elapsed_ms, cold_started
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc
    finally:
        _release_cftext_allocation(cforch_base, allocation_id)
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
    """Return every configured CF product together with a live reachability flag.

    Non-mapping entries in the ``products`` config list are skipped. A product
    without a ``base_url`` is reported offline without being probed.
    """
    configured = _load_imitate_config().get("products", []) or []

    def _describe(entry: dict) -> dict:
        """Build the response record for one product config entry."""
        url = entry.get("base_url", "")
        if url:
            online = _is_online(url, entry.get("health_path", "/api/health"))
        else:
            online = False
        return {
            "id": entry.get("id", ""),
            "name": entry.get("name", ""),
            "icon": entry.get("icon", "📦"),
            "description": entry.get("description", ""),
            "base_url": url,
            "online": online,
        }

    return {"products": [_describe(entry) for entry in configured if isinstance(entry, dict)]}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
def _render_prompt(template: str, text: str, item: Any) -> str:
    """Fill ``{text}`` and ``{field_name}`` placeholders of *template* in one pass.

    A single pass means substituted content is never re-scanned, so sample
    text that happens to contain ``{something}`` is inserted verbatim.
    """
    # Function-scope import: nothing else in this module needs it.
    import re

    substitutions: dict[str, str] = {}
    if isinstance(item, dict):
        for field, val in item.items():
            substitutions[f"{{{field}}}"] = str(val) if val is not None else ""
    # {text} always refers to the extracted text, even when the item itself
    # carries a field named "text".
    substitutions["{text}"] = text
    # Longest tokens first so no placeholder can shadow a longer one.
    tokens = sorted(substitutions, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(token) for token in tokens))
    return pattern.sub(lambda found: substitutions[found.group(0)], template)


@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
    """Fetch a real sample from the given product's API.

    Args:
        product_id: ``id`` of a product entry in the imitate config.
        index: Position of the wanted sample in the product's response.

    Returns:
        The extracted text, the rendered prompt, optional ``system_prompt``
        and ``image_url`` taken from the item, and the raw item itself.

    Raises:
        HTTPException: 404 for an unknown product or an empty sample list,
            422 when the product lacks ``base_url`` or ``sample_endpoint``,
            503 when the product API is unreachable, 502 when its response
            cannot be read.
    """
    cfg = _load_imitate_config()
    products_raw = cfg.get("products", []) or []
    product: dict | None = next(
        (p for p in products_raw if isinstance(p, dict) and p.get("id") == product_id),
        None,
    )
    if product is None:
        raise HTTPException(404, f"Product '{product_id}' not in config")

    # "or" also covers keys that are present but null in the YAML.
    base_url = (product.get("base_url") or "").rstrip("/")
    endpoint = product.get("sample_endpoint") or ""
    if not base_url or not endpoint:
        raise HTTPException(422, "Product missing base_url or sample_endpoint")
    url = f"{base_url}{endpoint}"

    try:
        raw = _http_get_json(url, timeout=5)
    except URLError as exc:
        raise HTTPException(503, f"Product API unreachable: {exc}") from exc
    except Exception as exc:
        raise HTTPException(502, f"Bad response from product API: {exc}") from exc

    text_fields = product.get("text_fields", []) or []
    sample_key = product.get("sample_key") or None
    extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
    if not extracted:
        raise HTTPException(404, "No sample items returned by product API")

    item = extracted.get("item", {})
    prompt_template = product.get("prompt_template")
    if prompt_template is None:
        prompt_template = "{text}"
    prompt = _render_prompt(str(prompt_template), extracted["text"], item)

    # Expose system_prompt and image_url if the product API returns them.
    # system_prompt: Peregrine, Snipe (vision analysis instructions)
    # image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
    # "or" keeps a JSON null from turning into the literal string "None".
    system_prompt = str(item.get("system_prompt") or "") if isinstance(item, dict) else ""
    image_url = str(item.get("image_url") or "") if isinstance(item, dict) else ""
    return {
        "product_id": product_id,
        "sample_index": index,
        "text": extracted["text"],
        "prompt": prompt,
        "system_prompt": system_prompt,
        "image_url": image_url,
        "raw_item": item,
    }
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
    """Expose the coordinator's current cf-text model list under the ``models`` key."""
    coordinator = _cforch_url()
    return {"models": _cforch_catalog(coordinator)}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
@router.get("/run")
def run_imitate(
    prompt: str = "",
    model_ids: str = "",  # comma-separated ollama model IDs
    cf_text_model_ids: str = "",  # comma-separated cf-text model IDs (via cf-orch)
    temperature: float = 0.7,
    product_id: str = "",
    system: str = "",  # optional system prompt
    image_url: str = "",  # optional image URL for vision models
) -> StreamingResponse:
    """Run a prompt through the selected models and stream the results as SSE.

    Ollama models run one after another; cf-text models are allocated through
    cf-orch and run concurrently, their results being emitted in completion
    order. The stream consists of a ``start`` event, a ``model_start`` and a
    ``model_done`` event per model (plus ``model_coldstart`` for cf-text
    models that had to be launched), and a final ``complete`` event carrying
    every result.

    If image_url is provided, the image is downloaded once before streaming
    begins and passed, base64-encoded, to every ollama model.

    Raises:
        HTTPException: 422 when the prompt is blank or no model is selected.
    """
    if not prompt.strip():
        raise HTTPException(422, "prompt is required")

    def _split_ids(csv: str) -> list[str]:
        """Split a comma-separated list, dropping blanks and surrounding spaces."""
        return [part.strip() for part in csv.split(",") if part.strip()]

    ollama_ids = _split_ids(model_ids)
    cftext_ids = _split_ids(cf_text_model_ids)
    if not (ollama_ids or cftext_ids):
        raise HTTPException(422, "model_ids or cf_text_model_ids is required")

    cfg = _load_imitate_config()
    ollama_base = _ollama_url(cfg)
    cforch_base = _cforch_url()
    system_ctx = system.strip() or ""
    total_models = len(ollama_ids) + len(cftext_ids)

    # Fetch the image a single time, ahead of the stream; ollama models share it.
    images: list[str] = []
    if image_url.strip():
        b64 = _fetch_image_b64(image_url.strip())
        if b64:
            images = [b64]

    def _succeeded(model: str, response: str, elapsed_ms: int) -> dict:
        """Result record for a model that produced a response."""
        return {
            "model": model,
            "response": response,
            "elapsed_ms": elapsed_ms,
            "error": None,
        }

    def _failed(model: str, exc: Exception) -> dict:
        """Result record for a model whose run raised."""
        return {
            "model": model,
            "response": "",
            "elapsed_ms": 0,
            "error": str(exc),
        }

    def generate():
        results: list[dict] = []
        yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})

        # Ollama models, sequentially
        for model_id in ollama_ids:
            yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
            try:
                response, elapsed_ms = _run_ollama_streaming(
                    ollama_base, model_id, prompt, temperature,
                    system=system_ctx, images=images or None,
                )
                result = _succeeded(model_id, response, elapsed_ms)
            except Exception as exc:
                result = _failed(model_id, exc)
            results.append(result)
            yield _sse({"type": "model_done", **result})

        # cf-text models via cf-orch, one worker thread per model
        if cftext_ids:
            from concurrent.futures import ThreadPoolExecutor, as_completed

            # Every model is announced before any work starts, so the UI can
            # show all loading states at once.
            for model_id in cftext_ids:
                yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})

            with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
                future_to_model = {}
                for mid in cftext_ids:
                    submitted = pool.submit(
                        _run_cftext, cforch_base, mid, prompt, system_ctx, temperature
                    )
                    future_to_model[submitted] = mid

                for future in as_completed(future_to_model):
                    model_id = future_to_model[future]
                    try:
                        response, elapsed_ms, cold_started = future.result()
                        if cold_started:
                            yield _sse({"type": "model_coldstart", "model": model_id})
                        result = _succeeded(model_id, response, elapsed_ms)
                    except Exception as exc:
                        result = _failed(model_id, exc)
                    results.append(result)
                    yield _sse({"type": "model_done", **result})

        yield _sse({"type": "complete", "results": results})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
    # One model's outcome as submitted to /push-corrections; the fields mirror
    # the per-model result dicts produced by run_imitate().
    model: str
    response: str
    elapsed_ms: int
    # Entries with a truthy error are skipped by push_corrections().
    error: str | None = None
class PushCorrectionsRequest(BaseModel):
    # Request body of POST /push-corrections.
    product_id: str
    # Stored as the single user message of each pushed record.
    prompt: str
    results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
    """Queue successful imitate results in sft_candidates.jsonl for human review.

    Results carrying an error or a blank response are dropped. Each remaining
    result becomes one pending record; all records share one UTC timestamp.

    Returns:
        ``{"pushed": <number of records written>}``.

    Raises:
        HTTPException: 422 when the prompt is blank, the result list is
            empty, or no result survives the filtering.
    """
    if not req.prompt.strip():
        raise HTTPException(422, "prompt is required")
    if not req.results:
        raise HTTPException(422, "results list is empty")

    ts = datetime.now(timezone.utc).isoformat()
    usable = [r for r in req.results if not r.error and r.response.strip()]
    if not usable:
        raise HTTPException(422, "No non-error results to push")

    records = [
        {
            "id": str(uuid.uuid4()),
            "source": "imitate",
            "product_id": req.product_id,
            "prompt_messages": [{"role": "user", "content": req.prompt}],
            "model_response": outcome.response,
            "model_id": outcome.model,
            "elapsed_ms": outcome.elapsed_ms,
            "status": "pending",
            "created_at": ts,
        }
        for outcome in usable
    ]

    dest = _candidates_file()
    dest.parent.mkdir(parents=True, exist_ok=True)
    for record in records:
        append_jsonl(dest, record)
    return {"pushed": len(records)}