Backend:
- Run all cf-text model allocations concurrently via ThreadPoolExecutor + as_completed
- Announce model_start events upfront so the UI can show loading states immediately
- Replace timer-based startup polling with coordinator state signals: wait for state=="running" (success) or state=="stopped" (fail-fast) on the matching node/GPU instance; fall back to a health poll after 6 consecutive probe misses
- Add /api/cforch/catalog endpoint: fetches the live cf-text model list from cf-orch, filtering out proxy entries (ollama://, vllm://, http://) so only loadable models are returned

Frontend (ImitateView.vue):
- Show per-model loading spinners as results arrive via the SSE stream
- Display a cold-start badge when the coordinator signals the model was freshly loaded
624 lines
23 KiB
Python
624 lines
23 KiB
Python
"""Avocet — Imitate tab API.

Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs — ollama models called directly, and cf-text models allocated through
the cf-orch coordinator — and streams responses back to the UI as server-sent
events. Results can be pushed into the SFT corrections queue for human review.

All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".

Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.error import URLError
|
|
from urllib.request import Request, urlopen
|
|
|
|
import httpx
|
|
import yaml
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app.utils import append_jsonl
|
|
|
|
logger = logging.getLogger(__name__)

# Two directory levels above this file (presumably the project root); the default
# config/ and data/ directories are resolved relative to it.
_ROOT = Path(__file__).parent.parent

# Optional override for the directory holding label_tool.yaml; None means
# _config_file() uses <_ROOT>/config. Set via set_config_dir() (test seam).
_CONFIG_DIR: Path | None = None

# Directory sft_candidates.jsonl is written to; override via set_data_dir().
_DATA_DIR: Path = _ROOT / "data"

# All endpoints in this module register here; per the module docstring, api.py
# mounts it with prefix="/api/imitate".
router = APIRouter()
|
|
|
|
|
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
|
|
|
def set_config_dir(path: Path | None) -> None:
|
|
global _CONFIG_DIR
|
|
_CONFIG_DIR = path
|
|
|
|
|
|
def set_data_dir(path: Path) -> None:
    """Redirect where this module writes data files (e.g. sft_candidates.jsonl).

    Test seam: _candidates_file() resolves against this directory on every call.
    """
    global _DATA_DIR
    _DATA_DIR = path
|
|
|
|
|
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
|
|
|
def _config_file() -> Path:
    """Locate label_tool.yaml: in the override dir when set, else <_ROOT>/config/."""
    config_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return config_dir / "label_tool.yaml"
|
|
|
|
|
|
def _load_imitate_config() -> dict:
    """Read label_tool.yaml and return the ``imitate`` sub-dict (or {} if absent).

    Returns {} when the file does not exist. Also returns {} — after logging a
    warning — when the file cannot be read or decoded, is not valid YAML, or when
    the document root or the ``imitate`` section is not a mapping. Callers
    always treat the result as a dict and call ``.get`` on it.
    """
    f = _config_file()
    if not f.exists():
        return {}
    # A file that exists but cannot be read must not take the endpoints down.
    try:
        text = f.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to read imitate config %s: %s", f, exc)
        return {}
    try:
        raw = yaml.safe_load(text) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse imitate config %s: %s", f, exc)
        return {}
    if not isinstance(raw, dict):
        logger.warning("Ignoring imitate config %s: top level is not a mapping", f)
        return {}
    section = raw.get("imitate", {}) or {}
    if not isinstance(section, dict):
        logger.warning("Ignoring imitate config %s: 'imitate' is not a mapping", f)
        return {}
    return section
|
|
|
|
|
|
def _load_cforch_config() -> dict:
    """Read the ``cforch`` section of label_tool.yaml (ollama_url / coordinator_url).

    Returns {} when the file does not exist. Also returns {} — after logging a
    warning instead of failing silently — when the file cannot be read or
    decoded, is not valid YAML, or when the document root or the ``cforch``
    section is not a mapping.
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        text = f.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to read cforch config %s: %s", f, exc)
        return {}
    try:
        raw = yaml.safe_load(text) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse cforch config %s: %s", f, exc)
        return {}
    if not isinstance(raw, dict):
        logger.warning("Ignoring cforch config %s: top level is not a mapping", f)
        return {}
    section = raw.get("cforch", {}) or {}
    if not isinstance(section, dict):
        logger.warning("Ignoring cforch config %s: 'cforch' is not a mapping", f)
        return {}
    return section
|
|
|
|
|
|
def _ollama_url(cfg: dict) -> str:
    """Resolve the ollama base URL.

    Precedence: ``cfg["ollama_url"]`` (the imitate section), then the cforch
    section's ``ollama_url``, then ``http://localhost:11434``. Empty or missing
    values fall through to the next source.
    """
    fallback_section = _load_cforch_config()
    for candidate in (cfg.get("ollama_url"), fallback_section.get("ollama_url")):
        if candidate:
            return candidate
    return "http://localhost:11434"
|
|
|
|
|
|
def _cforch_url() -> str:
    """Return the cf-orch coordinator URL from the cforch config section.

    Falls back to ``http://localhost:7700`` when ``coordinator_url`` is missing
    or empty.
    """
    configured = _load_cforch_config().get("coordinator_url")
    return configured if configured else "http://localhost:7700"
|
|
|
|
|
|
def _cforch_catalog(cforch_base: str, node_id: str = "heimdall") -> list[dict]:
    """Fetch the live cf-text catalog from cf-orch.

    Args:
        cforch_base: coordinator base URL (a trailing slash is tolerated).
        node_id: node whose catalog is requested. Defaults to "heimdall", the
            value that used to be hard-coded here.

    Returns:
        One ``{"id", "vram_mb", "description"}`` dict per loadable model.

    Filters out proxy entries — any whose ``path`` contains "://" (ollama://,
    vllm://, http://, ...). Those models are served by their own services and
    should not be allocated via cf-text. Entries with no path are kept, and
    entries that are not mappings are skipped.

    Best effort: if the coordinator is unreachable, returns an HTTP error, or
    sends a payload that is not a JSON object, logs a warning and returns [].
    """
    url = f"{cforch_base.rstrip('/')}/api/services/cf-text/catalog"
    try:
        resp = httpx.get(url, params={"node_id": node_id}, timeout=5.0)
        resp.raise_for_status()
        raw = resp.json()
    except Exception as exc:
        logger.warning("Could not fetch cf-orch catalog: %s", exc)
        return []
    if not isinstance(raw, dict):
        logger.warning(
            "Could not fetch cf-orch catalog: expected a JSON object, got %s",
            type(raw).__name__,
        )
        return []

    result = []
    for model_id, entry in raw.items():
        if not isinstance(entry, dict):
            continue
        # `or ""` so a JSON null path doesn't raise (and empty the whole catalog).
        path = str(entry.get("path") or "")
        # Skip proxy entries — they're routed through other services
        if "://" in path:
            continue
        result.append({
            "id": model_id,
            "vram_mb": entry.get("vram_mb", 0),
            "description": entry.get("description", ""),
        })
    return result
|
|
|
|
|
|
def _http_get_json(url: str, timeout: int = 5) -> Any:
    """GET ``url`` with ``Accept: application/json`` and return the decoded JSON body.

    Errors are not handled here:
    - transport failures raise URLError (HTTPError for error status codes);
    - a socket timeout while reading may surface as TimeoutError;
    - a body that is not UTF-8 or not valid JSON raises ValueError
      (UnicodeDecodeError / json.JSONDecodeError).
    """
    request = Request(url, headers={"Accept": "application/json"})
    with urlopen(request, timeout=timeout) as response:
        payload = response.read()
    return json.loads(payload.decode("utf-8"))
|
|
|
|
|
|
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
    """Probe a product's health endpoint with a 2-second timeout.

    Online means the endpoint answered with a truthy JSON body. Any failure —
    network error, HTTP error status, unparseable body — or a falsy body
    (e.g. {} or false) counts as offline.
    """
    try:
        payload = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
    except Exception:
        return False
    return bool(payload)
|
|
|
|
|
|
def _extract_sample(
|
|
raw: Any,
|
|
text_fields: list[str],
|
|
sample_index: int = 0,
|
|
sample_key: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""Pull one item from a list or dict response and extract text_fields.
|
|
|
|
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
|
Falls back to a set of conventional envelope keys if sample_key is absent.
|
|
"""
|
|
item: dict[str, Any]
|
|
if isinstance(raw, list):
|
|
if not raw:
|
|
return {}
|
|
item = raw[min(sample_index, len(raw) - 1)]
|
|
elif isinstance(raw, dict):
|
|
# Use declared sample_key first, then fall back to conventional names.
|
|
_ENVELOPE_KEYS = (
|
|
"samples", "items", "results", "data", "jobs", "listings",
|
|
"pantry", "saved_searches", "entries", "calls", "records",
|
|
)
|
|
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
|
for key in search_keys:
|
|
if key in raw and isinstance(raw[key], list):
|
|
lst = raw[key]
|
|
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
|
break
|
|
else:
|
|
item = raw
|
|
else:
|
|
return {}
|
|
|
|
parts = []
|
|
for field in text_fields:
|
|
val = item.get(field)
|
|
if val and str(val).strip():
|
|
parts.append(f"**{field}**: {val}")
|
|
return {"item": item, "text": "\n\n".join(parts)}
|
|
|
|
|
|
def _candidates_file() -> Path:
    """Path of the JSONL queue that push_corrections() appends SFT candidates to."""
    return _DATA_DIR.joinpath("sft_candidates.jsonl")
|
|
|
|
|
|
def _sse(data: dict) -> str:
    """Frame ``data`` as one server-sent-events message: ``data: <json>`` plus a blank line."""
    encoded = json.dumps(data)
    return "data: " + encoded + "\n\n"
|
|
|
|
|
|
# URL schemes _fetch_image_b64 may open. image_url arrives from the /run query
# string, and urllib would otherwise also open file:// (local files) or ftp://.
_IMAGE_URL_SCHEMES = frozenset({"http", "https", "data"})
# Upper bound on a downloaded image; anything larger is dropped, not buffered.
_MAX_IMAGE_BYTES = 20 * 1024 * 1024


def _fetch_image_b64(image_url: str, max_bytes: int = _MAX_IMAGE_BYTES) -> str:
    """Download an image URL and return it as a base64 string for ollama.

    Only http, https and data URLs are fetched. At most ``max_bytes`` bytes are
    accepted; a larger body is rejected rather than read fully into memory.

    Returns empty string on any failure — a missing image is non-fatal;
    the model will still run against the text prompt alone. Every failure is
    logged as a warning: unsupported scheme, network/HTTP error, or oversize.
    """
    from urllib.parse import urlsplit

    scheme = urlsplit(image_url).scheme.lower()
    if scheme not in _IMAGE_URL_SCHEMES:
        logger.warning("Refusing to fetch image %s: unsupported URL scheme %r", image_url, scheme)
        return ""
    try:
        req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
        with urlopen(req, timeout=10) as resp:
            # Read one byte past the limit so an oversized body is detectable.
            data = resp.read(max_bytes + 1)
    except Exception as exc:
        logger.warning("Failed to fetch image %s: %s", image_url, exc)
        return ""
    if len(data) > max_bytes:
        logger.warning("Image %s exceeds %d bytes; ignoring it", image_url, max_bytes)
        return ""
    return base64.b64encode(data).decode("ascii")
|
|
|
|
|
|
def _run_ollama_streaming(
|
|
ollama_base: str,
|
|
model_id: str,
|
|
prompt: str,
|
|
temperature: float,
|
|
system: str = "",
|
|
images: list[str] | None = None,
|
|
) -> tuple[str, int]:
|
|
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
|
|
|
|
Blocks until the model finishes; yields nothing — streaming is handled by
|
|
the SSE generator in run_imitate().
|
|
|
|
system: optional system prompt passed as a separate field to ollama.
|
|
images: list of base64-encoded image strings (vision models only).
|
|
"""
|
|
url = f"{ollama_base.rstrip('/')}/api/generate"
|
|
body: dict = {
|
|
"model": model_id,
|
|
"prompt": prompt,
|
|
"stream": False,
|
|
"options": {"temperature": temperature},
|
|
}
|
|
if system:
|
|
body["system"] = system
|
|
if images:
|
|
body["images"] = images
|
|
payload = json.dumps(body).encode("utf-8")
|
|
req = Request(url, data=payload, method="POST",
|
|
headers={"Content-Type": "application/json"})
|
|
t0 = time.time()
|
|
try:
|
|
with urlopen(req, timeout=120) as resp:
|
|
body = json.loads(resp.read().decode("utf-8"))
|
|
elapsed = int((time.time() - t0) * 1000)
|
|
return body.get("response", ""), elapsed
|
|
except Exception as exc:
|
|
elapsed = int((time.time() - t0) * 1000)
|
|
raise RuntimeError(str(exc)) from exc
|
|
|
|
|
|
def _release_cftext_allocation(cforch_base: str, allocation_id: str) -> None:
    """Best-effort release of a cf-text allocation; failures are logged, never raised.

    No-op when ``allocation_id`` is empty. Swallowing errors here keeps a flaky
    release from masking the caller's real outcome.
    """
    if not allocation_id:
        return
    try:
        httpx.delete(
            f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
            timeout=5.0,
        )
    except Exception as exc:
        logger.warning("Failed to release cf-text allocation %s: %s", allocation_id, exc)


def _wait_for_cftext_ready(
    cforch_base: str,
    service_url: str,
    model_id: str,
    node_id: str,
    gpu_id: int | None,
    startup_timeout_s: float,
    poll_interval_s: float = 2.0,
    max_probe_misses: int = 6,
) -> None:
    """Block until a freshly started cf-text instance is ready, using coordinator state.

    Polls ``{cforch_base}/api/services/cf-text/status`` and looks for the
    instance whose node_id and gpu_id match the allocation:
      * state "running" means ready, so return;
      * state "stopped" means the model failed to load, so raise immediately
        (fail fast);
      * if no instance matches for ``max_probe_misses`` consecutive polls, also
        probe ``{service_url}/health`` and treat a 2xx as ready.
    Transient errors talking to the coordinator or service are ignored and retried.

    Raises:
        RuntimeError: the service stopped, or was not ready within startup_timeout_s.
    """
    status_url = f"{cforch_base}/api/services/cf-text/status"
    deadline = time.monotonic() + startup_timeout_s
    probe_misses = 0
    while time.monotonic() < deadline:
        state = None
        try:
            status = httpx.get(status_url, timeout=5.0)
            if status.is_success:
                instances = status.json().get("instances", [])
                # NOTE(review): instances are matched on (node_id, gpu_id) only. With
                # several cf-text models allocated concurrently on one GPU this may
                # match another model's instance — confirm whether the status
                # payload carries a model/allocation id that should be matched too.
                match = next(
                    (i for i in instances
                     if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
                    None,
                )
                if match:
                    probe_misses = 0
                    state = match.get("state", "")
                else:
                    probe_misses += 1
                    # Coordinator hasn't registered the instance yet — fall back to health poll.
                    if (
                        probe_misses >= max_probe_misses
                        and httpx.get(f"{service_url}/health", timeout=3.0).is_success
                    ):
                        return
        except Exception:
            pass  # coordinator/service not reachable yet — keep polling
        if state == "running":
            return
        if state == "stopped":
            raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
        time.sleep(poll_interval_s)
    raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")


def _run_cftext(
    cforch_base: str,
    model_id: str,
    prompt: str,
    system: str,
    temperature: float,
    startup_timeout_s: float = 180.0,
) -> tuple[str, int, bool]:
    """Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).

    Args:
        cforch_base: cf-orch coordinator base URL.
        model_id: cf-text model to allocate (sent as the only model candidate).
        prompt: user message.
        system: optional system message; omitted when empty.
        temperature: sampling temperature (generation is capped at 300 tokens).
        startup_timeout_s: how long to wait for a cold-started service.

    Returns:
        A tuple of three values:
        - the stripped completion text;
        - elapsed_ms, the duration of the chat-completions call only;
        - cold_started, True when the allocation reply says the service was
          started and not warm (the caller may log this).

    Raises:
        RuntimeError: allocation failure (transport/HTTP error, a non-object
            reply, or no service url), a failed or timed-out cold start, or a
            generation error.

    Once allocated, the allocation is released (best effort) exactly once on
    every path, success or failure.

    Cold-start detection uses coordinator state signals (running/stopped) rather than
    polling the service health endpoint — this fails fast on model load errors instead
    of waiting out the full timeout.
    """
    # Allocate — any failure is surfaced as RuntimeError, as documented.
    try:
        alloc_resp = httpx.post(
            f"{cforch_base}/api/services/cf-text/allocate",
            json={
                "model_candidates": [model_id],
                "caller": "avocet",
                "pipeline": "imitate",
            },
            timeout=30.0,
        )
        alloc_resp.raise_for_status()
        data = alloc_resp.json()
    except Exception as exc:
        raise RuntimeError(f"cf-text allocation failed for {model_id!r}: {exc}") from exc
    if not isinstance(data, dict):
        raise RuntimeError(f"cf-text allocation for {model_id!r} returned an unexpected payload")

    allocation_id: str = data.get("allocation_id") or ""
    # Everything after a successful allocation runs under try/finally so the
    # allocation is released on every path — including a missing url, a
    # stopped/timed-out cold start, and generation errors.
    try:
        service_url = data.get("url")
        if not service_url:
            raise RuntimeError(f"cf-text allocation for {model_id!r} returned no service url")
        node_id: str = data.get("node_id", "")
        gpu_id: int | None = data.get("gpu_id")
        cold_started = bool(data.get("started", False) and not data.get("warm", True))

        # Wait for ready using coordinator state signals
        if cold_started:
            _wait_for_cftext_ready(
                cforch_base, service_url, model_id, node_id, gpu_id, startup_timeout_s
            )

        # Generate
        messages: list[dict] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        t0 = time.monotonic()
        try:
            gen_resp = httpx.post(
                f"{service_url}/v1/chat/completions",
                json={
                    "model": model_id,
                    "messages": messages,
                    "max_tokens": 300,
                    "temperature": temperature,
                    "stream": False,
                },
                timeout=120.0,
            )
            gen_resp.raise_for_status()
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            content = gen_resp.json()["choices"][0]["message"]["content"]
            return content.strip(), elapsed_ms, cold_started
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc
    finally:
        _release_cftext_allocation(cforch_base, allocation_id)
|
|
|
|
|
|
# ── GET /products ──────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/products")
def get_products() -> dict:
    """Describe every configured CF product together with a live online flag.

    Non-mapping entries in the config's ``products`` list are ignored. A product
    without a base_url is reported offline without being probed; otherwise its
    ``health_path`` (default "/api/health") is checked via _is_online().
    """
    entries = _load_imitate_config().get("products", []) or []

    def describe(prod: dict) -> dict:
        url = prod.get("base_url", "")
        return {
            "id": prod.get("id", ""),
            "name": prod.get("name", ""),
            "icon": prod.get("icon", "📦"),
            "description": prod.get("description", ""),
            "base_url": url,
            "online": _is_online(url, prod.get("health_path", "/api/health")) if url else False,
        }

    return {"products": [describe(prod) for prod in entries if isinstance(prod, dict)]}
|
|
|
|
|
|
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
|
|
|
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
    """Fetch a real sample from the given product's API and build its prompt.

    Steps:
    1. Look up ``product_id`` in the imitate config.
    2. GET ``base_url + sample_endpoint``.
    3. Pick item ``index`` via _extract_sample().
    4. Render ``prompt_template`` (default "{text}").

    Placeholders are substituted in a single pass. ``{text}`` becomes the
    extracted text fields and ``{<field>}`` becomes the raw item's field value
    ("" for null). Unknown placeholders are left as-is. Inserted text is never
    re-scanned, so sample content that happens to contain "{field}" is not
    rewritten.

    Raises:
        HTTPException 404: unknown product, or the API returned no usable item.
        HTTPException 422: product config lacks base_url or sample_endpoint.
        HTTPException 503: product API unreachable (connection error or timeout).
        HTTPException 502: product API answered with an HTTP error status or a
            body that could not be decoded as JSON.
    """
    import re
    from urllib.error import HTTPError

    cfg = _load_imitate_config()
    product = next(
        (
            p
            for p in (cfg.get("products", []) or [])
            if isinstance(p, dict) and p.get("id") == product_id
        ),
        None,
    )
    if product is None:
        raise HTTPException(404, f"Product '{product_id}' not in config")

    # `or ""` guards against explicit nulls in the YAML.
    base_url = str(product.get("base_url") or "").rstrip("/")
    endpoint = str(product.get("sample_endpoint") or "")
    if not base_url or not endpoint:
        raise HTTPException(422, "Product missing base_url or sample_endpoint")

    url = f"{base_url}{endpoint}"
    try:
        raw = _http_get_json(url, timeout=5)
    except HTTPError as exc:
        # Reachable but answered with an error status: a bad response, not an
        # outage. HTTPError subclasses URLError, so it must be caught first.
        raise HTTPException(502, f"Bad response from product API: {exc}") from exc
    except (URLError, TimeoutError, ConnectionError) as exc:
        raise HTTPException(503, f"Product API unreachable: {exc}") from exc
    except Exception as exc:
        raise HTTPException(502, f"Bad response from product API: {exc}") from exc

    text_fields = product.get("text_fields", []) or []
    sample_key = product.get("sample_key") or None
    extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
    item = extracted.get("item") if extracted else None
    # An empty item (e.g. from an empty envelope list) also means "no sample".
    if not item or not isinstance(item, dict):
        raise HTTPException(404, "No sample items returned by product API")

    template = product.get("prompt_template", "{text}")
    if template is None:
        template = "{text}"
    values = {str(key): ("" if val is None else str(val)) for key, val in item.items()}
    values["text"] = extracted["text"]  # {text} always means the extracted text fields
    prompt = re.sub(
        r"\{([^{}]+)\}",
        lambda m: values.get(m.group(1), m.group(0)),
        str(template),
    )

    # Expose system_prompt and image_url if the product API returns them.
    # system_prompt: Peregrine, Snipe (vision analysis instructions)
    # image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
    # `or ""` keeps a JSON null from becoming the literal string "None".
    system_prompt = str(item.get("system_prompt") or "")
    image_url = str(item.get("image_url") or "")

    return {
        "product_id": product_id,
        "sample_index": index,
        "text": extracted["text"],
        "prompt": prompt,
        "system_prompt": system_prompt,
        "image_url": image_url,
        "raw_item": item,
    }
|
|
|
|
|
|
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/catalog")
def get_catalog() -> dict:
    """List the cf-text models currently loadable through the cf-orch coordinator.

    Delegates to _cforch_catalog() against the configured coordinator URL; an
    unreachable coordinator yields an empty ``models`` list.
    """
    return {"models": _cforch_catalog(_cforch_url())}
|
|
|
|
|
|
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/run")
def run_imitate(
    prompt: str = "",
    model_ids: str = "",  # comma-separated ollama model IDs
    cf_text_model_ids: str = "",  # comma-separated cf-text model IDs (via cf-orch)
    temperature: float = 0.7,
    product_id: str = "",
    system: str = "",  # optional system prompt
    image_url: str = "",  # optional image URL for vision models
) -> StreamingResponse:
    """Run a prompt through the selected models and stream the results as SSE.

    ollama models run one after another. cf-text models, allocated through
    cf-orch, run concurrently on a thread pool and are reported in completion
    order. Event sequence:
      * ``start`` (total_models, has_image)
      * per ollama model: ``model_start`` (service "ollama") then ``model_done``
      * every cf-text ``model_start`` (service "cf-text") up front, then per
        finished model an optional ``model_coldstart`` followed by ``model_done``
      * ``complete`` with all results

    If image_url is provided, the image is downloaded once before streaming and
    passed to every ollama model as a base64-encoded blob. This lets
    vision-capable local models evaluate listing photos the same way Snipe's
    background task pipeline does. cf-text models receive text only.
    """
    if not prompt.strip():
        raise HTTPException(422, "prompt is required")

    def parse_ids(raw_csv: str) -> list[str]:
        return [piece.strip() for piece in raw_csv.split(",") if piece.strip()]

    local_models = parse_ids(model_ids)
    orch_models = parse_ids(cf_text_model_ids)
    if not (local_models or orch_models):
        raise HTTPException(422, "model_ids or cf_text_model_ids is required")

    imitate_cfg = _load_imitate_config()
    ollama_base = _ollama_url(imitate_cfg)
    cforch_base = _cforch_url()
    system_ctx = system.strip()
    model_count = len(local_models) + len(orch_models)

    # Download image once before streaming — shared across ollama vision models
    cleaned_image_url = image_url.strip()
    encoded_image = _fetch_image_b64(cleaned_image_url) if cleaned_image_url else ""
    images: list[str] = [encoded_image] if encoded_image else []

    def outcome(model: str, text: str, took_ms: int, err: str | None) -> dict:
        return {"model": model, "response": text, "elapsed_ms": took_ms, "error": err}

    def event_stream():
        collected: list[dict] = []
        yield _sse({"type": "start", "total_models": model_count, "has_image": bool(images)})

        for model in local_models:
            yield _sse({"type": "model_start", "model": model, "service": "ollama"})
            try:
                text, took_ms = _run_ollama_streaming(
                    ollama_base, model, prompt, temperature,
                    system=system_ctx, images=images or None,
                )
            except Exception as exc:
                row = outcome(model, "", 0, str(exc))
            else:
                row = outcome(model, text, took_ms, None)
            collected.append(row)
            yield _sse({"type": "model_done", **row})

        if orch_models:
            from concurrent.futures import ThreadPoolExecutor, as_completed

            # Announce every cf-text model first so the UI can show loading states.
            for model in orch_models:
                yield _sse({"type": "model_start", "model": model, "service": "cf-text"})

            with ThreadPoolExecutor(max_workers=len(orch_models)) as pool:
                pending = {
                    pool.submit(_run_cftext, cforch_base, model, prompt, system_ctx, temperature): model
                    for model in orch_models
                }
                for fut in as_completed(pending):
                    model = pending[fut]
                    try:
                        text, took_ms, cold = fut.result()
                    except Exception as exc:
                        row = outcome(model, "", 0, str(exc))
                    else:
                        if cold:
                            yield _sse({"type": "model_coldstart", "model": model})
                        row = outcome(model, text, took_ms, None)
                    collected.append(row)
                    yield _sse({"type": "model_done", **row})

        yield _sse({"type": "complete", "results": collected})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
|
|
|
class ImitateResult(BaseModel):
    # One model's outcome; same keys as the result dicts run_imitate() streams
    # in its "model_done" events.
    model: str
    response: str
    elapsed_ms: int
    # Set when the model run failed; push_corrections() skips such results.
    error: str | None = None
|
|
|
|
|
|
class PushCorrectionsRequest(BaseModel):
    # Request body for POST /push-corrections.
    product_id: str
    # Stored as the single user message in each record's prompt_messages.
    prompt: str
    results: list[ImitateResult]
|
|
|
|
|
|
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
    """Queue successful imitate results as pending SFT candidates for human review.

    Every result without an error and with a non-blank response becomes one
    record appended to sft_candidates.jsonl (source "imitate", status "pending").
    The prompt is stored as a single user message. All records written by one
    call share the same UTC ``created_at`` timestamp.

    Raises:
        HTTPException 422: blank prompt, empty results, or no usable result.
    """
    if not req.prompt.strip():
        raise HTTPException(422, "prompt is required")
    if not req.results:
        raise HTTPException(422, "results list is empty")

    created_at = datetime.now(timezone.utc).isoformat()
    records = [
        {
            "id": str(uuid.uuid4()),
            "source": "imitate",
            "product_id": req.product_id,
            "prompt_messages": [{"role": "user", "content": req.prompt}],
            "model_response": res.response,
            "model_id": res.model,
            "elapsed_ms": res.elapsed_ms,
            "status": "pending",
            "created_at": created_at,
        }
        for res in req.results
        if not res.error and res.response.strip()
    ]
    if not records:
        raise HTTPException(422, "No non-error results to push")

    target = _candidates_file()
    target.parent.mkdir(parents=True, exist_ok=True)
    for rec in records:
        append_jsonl(target, rec)

    return {"pushed": len(records)}
|