feat: move imitate API into app/data/imitate.py
This commit is contained in:
parent
99ea39fe38
commit
d74ad3f972
3 changed files with 654 additions and 651 deletions
644
app/data/imitate.py
Normal file
644
app/data/imitate.py
Normal file
|
|
@ -0,0 +1,644 @@
|
||||||
|
"""Avocet — Imitate tab API.
|
||||||
|
|
||||||
|
Fetches real samples from sibling CF product APIs, sends them through selected
|
||||||
|
local LLMs (ollama), and streams responses back to the UI. Results can be
|
||||||
|
pushed into the SFT corrections queue for human review.
|
||||||
|
|
||||||
|
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
||||||
|
|
||||||
|
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
||||||
|
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from urllib.error import URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import yaml
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.utils import append_jsonl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).parent.parent.parent
|
||||||
|
_CONFIG_DIR: Path | None = None
|
||||||
|
_DATA_DIR: Path = _ROOT / "data"
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def set_config_dir(path: Path | None) -> None:
|
||||||
|
global _CONFIG_DIR
|
||||||
|
_CONFIG_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
def set_data_dir(path: Path) -> None:
|
||||||
|
global _DATA_DIR
|
||||||
|
_DATA_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _config_file() -> Path:
|
||||||
|
if _CONFIG_DIR is not None:
|
||||||
|
return _CONFIG_DIR / "label_tool.yaml"
|
||||||
|
return _ROOT / "config" / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_imitate_config() -> dict:
|
||||||
|
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
|
||||||
|
f = _config_file()
|
||||||
|
if not f.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
logger.warning("Failed to parse imitate config %s: %s", f, exc)
|
||||||
|
return {}
|
||||||
|
return raw.get("imitate", {}) or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cforch_config() -> dict:
|
||||||
|
"""Read cforch section for ollama_url fallback."""
|
||||||
|
f = _config_file()
|
||||||
|
if not f.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
return {}
|
||||||
|
return raw.get("cforch", {}) or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _ollama_url(cfg: dict) -> str:
|
||||||
|
cforch = _load_cforch_config()
|
||||||
|
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def _cforch_url() -> str:
|
||||||
|
cforch = _load_cforch_config()
|
||||||
|
return cforch.get("coordinator_url") or "http://localhost:7700"
|
||||||
|
|
||||||
|
|
||||||
|
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
||||||
|
"""Fetch the live cf-text catalog from cf-orch.
|
||||||
|
|
||||||
|
Filters out proxy entries (ollama://, vllm://, http://) — those models are
|
||||||
|
served by their own services and should not be allocated via cf-text.
|
||||||
|
Returns only models with real file-system paths that cf-text can load directly.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
resp = httpx.get(
|
||||||
|
f"{cforch_base}/api/services/cf-text/catalog",
|
||||||
|
params={"node_id": "heimdall"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
raw = resp.json()
|
||||||
|
result = []
|
||||||
|
for model_id, entry in raw.items():
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
path = entry.get("path", "")
|
||||||
|
# Skip proxy entries — they're routed through other services
|
||||||
|
if "://" in path:
|
||||||
|
continue
|
||||||
|
result.append({
|
||||||
|
"id": model_id,
|
||||||
|
"vram_mb": entry.get("vram_mb", 0),
|
||||||
|
"description": entry.get("description", ""),
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not fetch cf-orch catalog: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _http_get_json(url: str, timeout: int = 5) -> Any:
|
||||||
|
"""Fetch JSON from url; raise URLError on failure."""
|
||||||
|
req = Request(url, headers={"Accept": "application/json"})
|
||||||
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
|
return json.loads(resp.read().decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
|
||||||
|
"""Return True if the product's health endpoint responds OK."""
|
||||||
|
try:
|
||||||
|
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
|
||||||
|
return bool(data)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_sample(
|
||||||
|
raw: Any,
|
||||||
|
text_fields: list[str],
|
||||||
|
sample_index: int = 0,
|
||||||
|
sample_key: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Pull one item from a list or dict response and extract text_fields.
|
||||||
|
|
||||||
|
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
||||||
|
Falls back to a set of conventional envelope keys if sample_key is absent.
|
||||||
|
"""
|
||||||
|
item: dict[str, Any]
|
||||||
|
if isinstance(raw, list):
|
||||||
|
if not raw:
|
||||||
|
return {}
|
||||||
|
item = raw[min(sample_index, len(raw) - 1)]
|
||||||
|
elif isinstance(raw, dict):
|
||||||
|
# Use declared sample_key first, then fall back to conventional names.
|
||||||
|
_ENVELOPE_KEYS = (
|
||||||
|
"samples", "items", "results", "data", "jobs", "listings",
|
||||||
|
"pantry", "saved_searches", "entries", "calls", "records",
|
||||||
|
)
|
||||||
|
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
||||||
|
for key in search_keys:
|
||||||
|
if key in raw and isinstance(raw[key], list):
|
||||||
|
lst = raw[key]
|
||||||
|
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
item = raw
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for field in text_fields:
|
||||||
|
val = item.get(field)
|
||||||
|
if val and str(val).strip():
|
||||||
|
parts.append(f"**{field}**: {val}")
|
||||||
|
return {"item": item, "text": "\n\n".join(parts)}
|
||||||
|
|
||||||
|
|
||||||
|
def _candidates_file() -> Path:
|
||||||
|
return _DATA_DIR / "sft_candidates.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def _sse(data: dict) -> str:
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_image_b64(image_url: str) -> str:
|
||||||
|
"""Download an image URL and return it as a base64 string for ollama.
|
||||||
|
|
||||||
|
Returns empty string on any failure — a missing image is non-fatal;
|
||||||
|
the model will still run against the text prompt alone.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
|
||||||
|
with urlopen(req, timeout=10) as resp:
|
||||||
|
return base64.b64encode(resp.read()).decode("ascii")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to fetch image %s: %s", image_url, exc)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ollama_streaming(
|
||||||
|
ollama_base: str,
|
||||||
|
model_id: str,
|
||||||
|
prompt: str,
|
||||||
|
temperature: float,
|
||||||
|
system: str = "",
|
||||||
|
images: list[str] | None = None,
|
||||||
|
) -> tuple[str, int]:
|
||||||
|
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
|
||||||
|
|
||||||
|
Blocks until the model finishes; yields nothing — streaming is handled by
|
||||||
|
the SSE generator in run_imitate().
|
||||||
|
|
||||||
|
system: optional system prompt passed as a separate field to ollama.
|
||||||
|
images: list of base64-encoded image strings (vision models only).
|
||||||
|
"""
|
||||||
|
url = f"{ollama_base.rstrip('/')}/api/generate"
|
||||||
|
body: dict = {
|
||||||
|
"model": model_id,
|
||||||
|
"prompt": prompt,
|
||||||
|
"stream": False,
|
||||||
|
"options": {"temperature": temperature},
|
||||||
|
}
|
||||||
|
if system:
|
||||||
|
body["system"] = system
|
||||||
|
if images:
|
||||||
|
body["images"] = images
|
||||||
|
payload = json.dumps(body).encode("utf-8")
|
||||||
|
req = Request(url, data=payload, method="POST",
|
||||||
|
headers={"Content-Type": "application/json"})
|
||||||
|
t0 = time.time()
|
||||||
|
try:
|
||||||
|
with urlopen(req, timeout=120) as resp:
|
||||||
|
body = json.loads(resp.read().decode("utf-8"))
|
||||||
|
elapsed = int((time.time() - t0) * 1000)
|
||||||
|
return body.get("response", ""), elapsed
|
||||||
|
except Exception as exc:
|
||||||
|
elapsed = int((time.time() - t0) * 1000)
|
||||||
|
raise RuntimeError(str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _run_cftext(
|
||||||
|
cforch_base: str,
|
||||||
|
model_id: str,
|
||||||
|
prompt: str,
|
||||||
|
system: str,
|
||||||
|
temperature: float,
|
||||||
|
startup_timeout_s: float = 180.0,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> tuple[str, int, bool]:
|
||||||
|
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
|
||||||
|
|
||||||
|
Raises RuntimeError on allocation failure or generation error.
|
||||||
|
cold_started=True means the service was launched from scratch (caller may log this).
|
||||||
|
|
||||||
|
Cold-start detection uses coordinator state signals (running/stopped) rather than
|
||||||
|
polling the service health endpoint — this fails fast on model load errors instead
|
||||||
|
of waiting out the full timeout.
|
||||||
|
"""
|
||||||
|
# Allocate
|
||||||
|
alloc_resp = httpx.post(
|
||||||
|
f"{cforch_base}/api/services/cf-text/allocate",
|
||||||
|
json={
|
||||||
|
"model_candidates": [model_id],
|
||||||
|
"caller": "avocet",
|
||||||
|
"pipeline": "imitate",
|
||||||
|
**({"user_id": user_id} if user_id else {}),
|
||||||
|
},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
alloc_resp.raise_for_status()
|
||||||
|
data = alloc_resp.json()
|
||||||
|
service_url: str = data["url"]
|
||||||
|
allocation_id: str = data.get("allocation_id", "")
|
||||||
|
node_id: str = data.get("node_id", "")
|
||||||
|
gpu_id: int | None = data.get("gpu_id")
|
||||||
|
cold_started = data.get("started", False) and not data.get("warm", True)
|
||||||
|
|
||||||
|
# Wait for ready using coordinator state signals
|
||||||
|
if cold_started:
|
||||||
|
deadline = time.monotonic() + startup_timeout_s
|
||||||
|
probe_misses = 0
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
status = httpx.get(
|
||||||
|
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
|
||||||
|
)
|
||||||
|
if status.is_success:
|
||||||
|
instances = status.json().get("instances", [])
|
||||||
|
match = next(
|
||||||
|
(i for i in instances
|
||||||
|
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if match:
|
||||||
|
probe_misses = 0
|
||||||
|
state = match.get("state", "")
|
||||||
|
if state == "running":
|
||||||
|
break
|
||||||
|
elif state == "stopped":
|
||||||
|
if allocation_id:
|
||||||
|
httpx.delete(
|
||||||
|
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
|
||||||
|
else:
|
||||||
|
probe_misses += 1
|
||||||
|
if probe_misses >= 6:
|
||||||
|
# Coordinator hasn't registered instance yet — fall back to health poll
|
||||||
|
try:
|
||||||
|
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except RuntimeError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
time.sleep(2.0)
|
||||||
|
else:
|
||||||
|
if allocation_id:
|
||||||
|
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||||
|
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
messages: list[dict] = []
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
try:
|
||||||
|
gen_resp = httpx.post(
|
||||||
|
f"{service_url}/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": model_id,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": 300,
|
||||||
|
"temperature": temperature,
|
||||||
|
"stream": False,
|
||||||
|
},
|
||||||
|
timeout=120.0,
|
||||||
|
)
|
||||||
|
gen_resp.raise_for_status()
|
||||||
|
elapsed_ms = int((time.time() - t0) * 1000)
|
||||||
|
content = gen_resp.json()["choices"][0]["message"]["content"]
|
||||||
|
return content.strip(), elapsed_ms, cold_started
|
||||||
|
except Exception as exc:
|
||||||
|
elapsed_ms = int((time.time() - t0) * 1000)
|
||||||
|
raise RuntimeError(str(exc)) from exc
|
||||||
|
finally:
|
||||||
|
if allocation_id:
|
||||||
|
try:
|
||||||
|
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /products ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/products")
|
||||||
|
def get_products() -> dict:
|
||||||
|
"""List configured CF products with live online status."""
|
||||||
|
cfg = _load_imitate_config()
|
||||||
|
products_raw = cfg.get("products", []) or []
|
||||||
|
products = []
|
||||||
|
for p in products_raw:
|
||||||
|
if not isinstance(p, dict):
|
||||||
|
continue
|
||||||
|
base_url = p.get("base_url", "")
|
||||||
|
products.append({
|
||||||
|
"id": p.get("id", ""),
|
||||||
|
"name": p.get("name", ""),
|
||||||
|
"icon": p.get("icon", "📦"),
|
||||||
|
"description": p.get("description", ""),
|
||||||
|
"base_url": base_url,
|
||||||
|
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
||||||
|
})
|
||||||
|
return {"products": products}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/products/{product_id}/sample")
|
||||||
|
def get_sample(product_id: str, index: int = 0) -> dict:
|
||||||
|
"""Fetch a real sample from the given product's API."""
|
||||||
|
cfg = _load_imitate_config()
|
||||||
|
products_raw = cfg.get("products", []) or []
|
||||||
|
|
||||||
|
product: dict | None = None
|
||||||
|
for p in products_raw:
|
||||||
|
if isinstance(p, dict) and p.get("id") == product_id:
|
||||||
|
product = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if product is None:
|
||||||
|
raise HTTPException(404, f"Product '{product_id}' not in config")
|
||||||
|
|
||||||
|
base_url = product.get("base_url", "").rstrip("/")
|
||||||
|
endpoint = product.get("sample_endpoint", "")
|
||||||
|
if not base_url or not endpoint:
|
||||||
|
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
||||||
|
|
||||||
|
url = f"{base_url}{endpoint}"
|
||||||
|
try:
|
||||||
|
raw = _http_get_json(url, timeout=5)
|
||||||
|
except URLError as exc:
|
||||||
|
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
||||||
|
|
||||||
|
text_fields = product.get("text_fields", []) or []
|
||||||
|
sample_key = product.get("sample_key") or None
|
||||||
|
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
|
||||||
|
if not extracted:
|
||||||
|
raise HTTPException(404, "No sample items returned by product API")
|
||||||
|
|
||||||
|
prompt_template = product.get("prompt_template", "{text}")
|
||||||
|
prompt = prompt_template.replace("{text}", extracted["text"])
|
||||||
|
# Also substitute any {field_name} placeholders from the raw item fields.
|
||||||
|
item = extracted.get("item", {})
|
||||||
|
for field, val in item.items():
|
||||||
|
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
|
||||||
|
|
||||||
|
# Expose system_prompt and image_url if the product API returns them.
|
||||||
|
# system_prompt: Peregrine, Snipe (vision analysis instructions)
|
||||||
|
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
|
||||||
|
item = extracted.get("item", {})
|
||||||
|
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
|
||||||
|
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"product_id": product_id,
|
||||||
|
"sample_index": index,
|
||||||
|
"text": extracted["text"],
|
||||||
|
"prompt": prompt,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"image_url": image_url,
|
||||||
|
"raw_item": item,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog")
|
||||||
|
def get_catalog() -> dict:
|
||||||
|
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
||||||
|
models = _cforch_catalog(_cforch_url())
|
||||||
|
return {"models": models}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
||||||
|
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
||||||
|
try:
|
||||||
|
from app.cloud_session import get_session
|
||||||
|
return get_session(request, response)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/run")
|
||||||
|
def run_imitate(
|
||||||
|
prompt: str = "",
|
||||||
|
model_ids: str = "", # comma-separated ollama model IDs
|
||||||
|
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
||||||
|
temperature: float = 0.7,
|
||||||
|
product_id: str = "",
|
||||||
|
system: str = "", # optional system prompt
|
||||||
|
image_url: str = "", # optional image URL for vision models
|
||||||
|
session: "Any" = Depends(_get_imitate_session),
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""Run a prompt through selected ollama models and stream results as SSE.
|
||||||
|
|
||||||
|
If image_url is provided, the image is downloaded once and passed to every
|
||||||
|
model as a base64-encoded blob — allowing vision-capable local models to
|
||||||
|
evaluate listing photos the same way Snipe's background task pipeline does.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not prompt.strip():
|
||||||
|
raise HTTPException(422, "prompt is required")
|
||||||
|
|
||||||
|
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
||||||
|
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
||||||
|
if not ollama_ids and not cftext_ids:
|
||||||
|
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
|
||||||
|
|
||||||
|
cfg = _load_imitate_config()
|
||||||
|
ollama_base = _ollama_url(cfg)
|
||||||
|
cforch_base = _cforch_url()
|
||||||
|
system_ctx = system.strip() or ""
|
||||||
|
total_models = len(ollama_ids) + len(cftext_ids)
|
||||||
|
|
||||||
|
# Download image once before streaming — shared across ollama vision models
|
||||||
|
images: list[str] = []
|
||||||
|
if image_url.strip():
|
||||||
|
b64 = _fetch_image_b64(image_url.strip())
|
||||||
|
if b64:
|
||||||
|
images = [b64]
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
results: list[dict] = []
|
||||||
|
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
|
||||||
|
|
||||||
|
# Ollama models
|
||||||
|
for model_id in ollama_ids:
|
||||||
|
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
|
||||||
|
try:
|
||||||
|
response, elapsed_ms = _run_ollama_streaming(
|
||||||
|
ollama_base, model_id, prompt, temperature,
|
||||||
|
system=system_ctx, images=images or None,
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"model": model_id,
|
||||||
|
"response": response,
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
result = {
|
||||||
|
"model": model_id,
|
||||||
|
"response": "",
|
||||||
|
"elapsed_ms": 0,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
yield _sse({"type": "model_done", **result})
|
||||||
|
|
||||||
|
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
||||||
|
if cftext_ids:
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
# Announce all models upfront so the UI can show loading states immediately
|
||||||
|
for model_id in cftext_ids:
|
||||||
|
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
||||||
|
|
||||||
|
_user_id: str | None = getattr(session, "user_id", None)
|
||||||
|
# Only forward real cloud user IDs — skip local/anon sessions
|
||||||
|
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
|
||||||
|
_user_id = None
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
|
||||||
|
future_to_model = {
|
||||||
|
pool.submit(
|
||||||
|
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
|
||||||
|
180.0, _user_id,
|
||||||
|
): mid
|
||||||
|
for mid in cftext_ids
|
||||||
|
}
|
||||||
|
for future in as_completed(future_to_model):
|
||||||
|
model_id = future_to_model[future]
|
||||||
|
try:
|
||||||
|
response, elapsed_ms, cold_started = future.result()
|
||||||
|
if cold_started:
|
||||||
|
yield _sse({"type": "model_coldstart", "model": model_id})
|
||||||
|
result = {
|
||||||
|
"model": model_id,
|
||||||
|
"response": response,
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
result = {
|
||||||
|
"model": model_id,
|
||||||
|
"response": "",
|
||||||
|
"elapsed_ms": 0,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
yield _sse({"type": "model_done", **result})
|
||||||
|
|
||||||
|
yield _sse({"type": "complete", "results": results})
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ImitateResult(BaseModel):
|
||||||
|
model: str
|
||||||
|
response: str
|
||||||
|
elapsed_ms: int
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PushCorrectionsRequest(BaseModel):
|
||||||
|
product_id: str
|
||||||
|
prompt: str
|
||||||
|
results: list[ImitateResult]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/push-corrections")
|
||||||
|
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
||||||
|
"""Append imitate results to sft_candidates.jsonl for human review."""
|
||||||
|
if not req.prompt.strip():
|
||||||
|
raise HTTPException(422, "prompt is required")
|
||||||
|
if not req.results:
|
||||||
|
raise HTTPException(422, "results list is empty")
|
||||||
|
|
||||||
|
ts = datetime.now(timezone.utc).isoformat()
|
||||||
|
records = []
|
||||||
|
for r in req.results:
|
||||||
|
if r.error or not r.response.strip():
|
||||||
|
continue
|
||||||
|
records.append({
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"source": "imitate",
|
||||||
|
"product_id": req.product_id,
|
||||||
|
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
||||||
|
"model_response": r.response,
|
||||||
|
"model_id": r.model,
|
||||||
|
"elapsed_ms": r.elapsed_ms,
|
||||||
|
"status": "pending",
|
||||||
|
"created_at": ts,
|
||||||
|
})
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
raise HTTPException(422, "No non-error results to push")
|
||||||
|
|
||||||
|
dest = _candidates_file()
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for record in records:
|
||||||
|
append_jsonl(dest, record)
|
||||||
|
|
||||||
|
return {"pushed": len(records)}
|
||||||
647
app/imitate.py
647
app/imitate.py
|
|
@ -1,644 +1,3 @@
|
||||||
"""Avocet — Imitate tab API.
|
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
|
||||||
|
from app.data.imitate import router # noqa: F401
|
||||||
Fetches real samples from sibling CF product APIs, sends them through selected
|
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401
|
||||||
local LLMs (ollama), and streams responses back to the UI. Results can be
|
|
||||||
pushed into the SFT corrections queue for human review.
|
|
||||||
|
|
||||||
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
|
||||||
|
|
||||||
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
|
||||||
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from urllib.error import URLError
|
|
||||||
from urllib.request import Request, urlopen
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import yaml
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.utils import append_jsonl
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_ROOT = Path(__file__).parent.parent
|
|
||||||
_CONFIG_DIR: Path | None = None
|
|
||||||
_DATA_DIR: Path = _ROOT / "data"
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Testability seams ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def set_config_dir(path: Path | None) -> None:
|
|
||||||
global _CONFIG_DIR
|
|
||||||
_CONFIG_DIR = path
|
|
||||||
|
|
||||||
|
|
||||||
def set_data_dir(path: Path) -> None:
|
|
||||||
global _DATA_DIR
|
|
||||||
_DATA_DIR = path
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _config_file() -> Path:
|
|
||||||
if _CONFIG_DIR is not None:
|
|
||||||
return _CONFIG_DIR / "label_tool.yaml"
|
|
||||||
return _ROOT / "config" / "label_tool.yaml"
|
|
||||||
|
|
||||||
|
|
||||||
def _load_imitate_config() -> dict:
|
|
||||||
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
|
|
||||||
f = _config_file()
|
|
||||||
if not f.exists():
|
|
||||||
return {}
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
|
||||||
except yaml.YAMLError as exc:
|
|
||||||
logger.warning("Failed to parse imitate config %s: %s", f, exc)
|
|
||||||
return {}
|
|
||||||
return raw.get("imitate", {}) or {}
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cforch_config() -> dict:
|
|
||||||
"""Read cforch section for ollama_url fallback."""
|
|
||||||
f = _config_file()
|
|
||||||
if not f.exists():
|
|
||||||
return {}
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
|
||||||
except yaml.YAMLError as exc:
|
|
||||||
return {}
|
|
||||||
return raw.get("cforch", {}) or {}
|
|
||||||
|
|
||||||
|
|
||||||
def _ollama_url(cfg: dict) -> str:
|
|
||||||
cforch = _load_cforch_config()
|
|
||||||
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
|
|
||||||
|
|
||||||
|
|
||||||
def _cforch_url() -> str:
|
|
||||||
cforch = _load_cforch_config()
|
|
||||||
return cforch.get("coordinator_url") or "http://localhost:7700"
|
|
||||||
|
|
||||||
|
|
||||||
def _cforch_catalog(cforch_base: str) -> list[dict]:
|
|
||||||
"""Fetch the live cf-text catalog from cf-orch.
|
|
||||||
|
|
||||||
Filters out proxy entries (ollama://, vllm://, http://) — those models are
|
|
||||||
served by their own services and should not be allocated via cf-text.
|
|
||||||
Returns only models with real file-system paths that cf-text can load directly.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
resp = httpx.get(
|
|
||||||
f"{cforch_base}/api/services/cf-text/catalog",
|
|
||||||
params={"node_id": "heimdall"},
|
|
||||||
timeout=5.0,
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
raw = resp.json()
|
|
||||||
result = []
|
|
||||||
for model_id, entry in raw.items():
|
|
||||||
if not isinstance(entry, dict):
|
|
||||||
continue
|
|
||||||
path = entry.get("path", "")
|
|
||||||
# Skip proxy entries — they're routed through other services
|
|
||||||
if "://" in path:
|
|
||||||
continue
|
|
||||||
result.append({
|
|
||||||
"id": model_id,
|
|
||||||
"vram_mb": entry.get("vram_mb", 0),
|
|
||||||
"description": entry.get("description", ""),
|
|
||||||
})
|
|
||||||
return result
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Could not fetch cf-orch catalog: %s", exc)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def _http_get_json(url: str, timeout: int = 5) -> Any:
|
|
||||||
"""Fetch JSON from url; raise URLError on failure."""
|
|
||||||
req = Request(url, headers={"Accept": "application/json"})
|
|
||||||
with urlopen(req, timeout=timeout) as resp:
|
|
||||||
return json.loads(resp.read().decode("utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
|
|
||||||
"""Return True if the product's health endpoint responds OK."""
|
|
||||||
try:
|
|
||||||
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
|
|
||||||
return bool(data)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_sample(
|
|
||||||
raw: Any,
|
|
||||||
text_fields: list[str],
|
|
||||||
sample_index: int = 0,
|
|
||||||
sample_key: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Pull one item from a list or dict response and extract text_fields.
|
|
||||||
|
|
||||||
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
|
|
||||||
Falls back to a set of conventional envelope keys if sample_key is absent.
|
|
||||||
"""
|
|
||||||
item: dict[str, Any]
|
|
||||||
if isinstance(raw, list):
|
|
||||||
if not raw:
|
|
||||||
return {}
|
|
||||||
item = raw[min(sample_index, len(raw) - 1)]
|
|
||||||
elif isinstance(raw, dict):
|
|
||||||
# Use declared sample_key first, then fall back to conventional names.
|
|
||||||
_ENVELOPE_KEYS = (
|
|
||||||
"samples", "items", "results", "data", "jobs", "listings",
|
|
||||||
"pantry", "saved_searches", "entries", "calls", "records",
|
|
||||||
)
|
|
||||||
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
|
|
||||||
for key in search_keys:
|
|
||||||
if key in raw and isinstance(raw[key], list):
|
|
||||||
lst = raw[key]
|
|
||||||
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
item = raw
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
for field in text_fields:
|
|
||||||
val = item.get(field)
|
|
||||||
if val and str(val).strip():
|
|
||||||
parts.append(f"**{field}**: {val}")
|
|
||||||
return {"item": item, "text": "\n\n".join(parts)}
|
|
||||||
|
|
||||||
|
|
||||||
def _candidates_file() -> Path:
|
|
||||||
return _DATA_DIR / "sft_candidates.jsonl"
|
|
||||||
|
|
||||||
|
|
||||||
def _sse(data: dict) -> str:
|
|
||||||
return f"data: {json.dumps(data)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_image_b64(image_url: str) -> str:
|
|
||||||
"""Download an image URL and return it as a base64 string for ollama.
|
|
||||||
|
|
||||||
Returns empty string on any failure — a missing image is non-fatal;
|
|
||||||
the model will still run against the text prompt alone.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
|
|
||||||
with urlopen(req, timeout=10) as resp:
|
|
||||||
return base64.b64encode(resp.read()).decode("ascii")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Failed to fetch image %s: %s", image_url, exc)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _run_ollama_streaming(
|
|
||||||
ollama_base: str,
|
|
||||||
model_id: str,
|
|
||||||
prompt: str,
|
|
||||||
temperature: float,
|
|
||||||
system: str = "",
|
|
||||||
images: list[str] | None = None,
|
|
||||||
) -> tuple[str, int]:
|
|
||||||
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
|
|
||||||
|
|
||||||
Blocks until the model finishes; yields nothing — streaming is handled by
|
|
||||||
the SSE generator in run_imitate().
|
|
||||||
|
|
||||||
system: optional system prompt passed as a separate field to ollama.
|
|
||||||
images: list of base64-encoded image strings (vision models only).
|
|
||||||
"""
|
|
||||||
url = f"{ollama_base.rstrip('/')}/api/generate"
|
|
||||||
body: dict = {
|
|
||||||
"model": model_id,
|
|
||||||
"prompt": prompt,
|
|
||||||
"stream": False,
|
|
||||||
"options": {"temperature": temperature},
|
|
||||||
}
|
|
||||||
if system:
|
|
||||||
body["system"] = system
|
|
||||||
if images:
|
|
||||||
body["images"] = images
|
|
||||||
payload = json.dumps(body).encode("utf-8")
|
|
||||||
req = Request(url, data=payload, method="POST",
|
|
||||||
headers={"Content-Type": "application/json"})
|
|
||||||
t0 = time.time()
|
|
||||||
try:
|
|
||||||
with urlopen(req, timeout=120) as resp:
|
|
||||||
body = json.loads(resp.read().decode("utf-8"))
|
|
||||||
elapsed = int((time.time() - t0) * 1000)
|
|
||||||
return body.get("response", ""), elapsed
|
|
||||||
except Exception as exc:
|
|
||||||
elapsed = int((time.time() - t0) * 1000)
|
|
||||||
raise RuntimeError(str(exc)) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _run_cftext(
|
|
||||||
cforch_base: str,
|
|
||||||
model_id: str,
|
|
||||||
prompt: str,
|
|
||||||
system: str,
|
|
||||||
temperature: float,
|
|
||||||
startup_timeout_s: float = 180.0,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> tuple[str, int, bool]:
|
|
||||||
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
|
|
||||||
|
|
||||||
Raises RuntimeError on allocation failure or generation error.
|
|
||||||
cold_started=True means the service was launched from scratch (caller may log this).
|
|
||||||
|
|
||||||
Cold-start detection uses coordinator state signals (running/stopped) rather than
|
|
||||||
polling the service health endpoint — this fails fast on model load errors instead
|
|
||||||
of waiting out the full timeout.
|
|
||||||
"""
|
|
||||||
# Allocate
|
|
||||||
alloc_resp = httpx.post(
|
|
||||||
f"{cforch_base}/api/services/cf-text/allocate",
|
|
||||||
json={
|
|
||||||
"model_candidates": [model_id],
|
|
||||||
"caller": "avocet",
|
|
||||||
"pipeline": "imitate",
|
|
||||||
**({"user_id": user_id} if user_id else {}),
|
|
||||||
},
|
|
||||||
timeout=30.0,
|
|
||||||
)
|
|
||||||
alloc_resp.raise_for_status()
|
|
||||||
data = alloc_resp.json()
|
|
||||||
service_url: str = data["url"]
|
|
||||||
allocation_id: str = data.get("allocation_id", "")
|
|
||||||
node_id: str = data.get("node_id", "")
|
|
||||||
gpu_id: int | None = data.get("gpu_id")
|
|
||||||
cold_started = data.get("started", False) and not data.get("warm", True)
|
|
||||||
|
|
||||||
# Wait for ready using coordinator state signals
|
|
||||||
if cold_started:
|
|
||||||
deadline = time.monotonic() + startup_timeout_s
|
|
||||||
probe_misses = 0
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
try:
|
|
||||||
status = httpx.get(
|
|
||||||
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
|
|
||||||
)
|
|
||||||
if status.is_success:
|
|
||||||
instances = status.json().get("instances", [])
|
|
||||||
match = next(
|
|
||||||
(i for i in instances
|
|
||||||
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if match:
|
|
||||||
probe_misses = 0
|
|
||||||
state = match.get("state", "")
|
|
||||||
if state == "running":
|
|
||||||
break
|
|
||||||
elif state == "stopped":
|
|
||||||
if allocation_id:
|
|
||||||
httpx.delete(
|
|
||||||
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
|
|
||||||
timeout=5.0,
|
|
||||||
)
|
|
||||||
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
|
|
||||||
else:
|
|
||||||
probe_misses += 1
|
|
||||||
if probe_misses >= 6:
|
|
||||||
# Coordinator hasn't registered instance yet — fall back to health poll
|
|
||||||
try:
|
|
||||||
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except RuntimeError:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
time.sleep(2.0)
|
|
||||||
else:
|
|
||||||
if allocation_id:
|
|
||||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
|
||||||
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
messages: list[dict] = []
|
|
||||||
if system:
|
|
||||||
messages.append({"role": "system", "content": system})
|
|
||||||
messages.append({"role": "user", "content": prompt})
|
|
||||||
|
|
||||||
t0 = time.time()
|
|
||||||
try:
|
|
||||||
gen_resp = httpx.post(
|
|
||||||
f"{service_url}/v1/chat/completions",
|
|
||||||
json={
|
|
||||||
"model": model_id,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": 300,
|
|
||||||
"temperature": temperature,
|
|
||||||
"stream": False,
|
|
||||||
},
|
|
||||||
timeout=120.0,
|
|
||||||
)
|
|
||||||
gen_resp.raise_for_status()
|
|
||||||
elapsed_ms = int((time.time() - t0) * 1000)
|
|
||||||
content = gen_resp.json()["choices"][0]["message"]["content"]
|
|
||||||
return content.strip(), elapsed_ms, cold_started
|
|
||||||
except Exception as exc:
|
|
||||||
elapsed_ms = int((time.time() - t0) * 1000)
|
|
||||||
raise RuntimeError(str(exc)) from exc
|
|
||||||
finally:
|
|
||||||
if allocation_id:
|
|
||||||
try:
|
|
||||||
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/products")
|
|
||||||
def get_products() -> dict:
|
|
||||||
"""List configured CF products with live online status."""
|
|
||||||
cfg = _load_imitate_config()
|
|
||||||
products_raw = cfg.get("products", []) or []
|
|
||||||
products = []
|
|
||||||
for p in products_raw:
|
|
||||||
if not isinstance(p, dict):
|
|
||||||
continue
|
|
||||||
base_url = p.get("base_url", "")
|
|
||||||
products.append({
|
|
||||||
"id": p.get("id", ""),
|
|
||||||
"name": p.get("name", ""),
|
|
||||||
"icon": p.get("icon", "📦"),
|
|
||||||
"description": p.get("description", ""),
|
|
||||||
"base_url": base_url,
|
|
||||||
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
|
||||||
})
|
|
||||||
return {"products": products}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/products/{product_id}/sample")
|
|
||||||
def get_sample(product_id: str, index: int = 0) -> dict:
|
|
||||||
"""Fetch a real sample from the given product's API."""
|
|
||||||
cfg = _load_imitate_config()
|
|
||||||
products_raw = cfg.get("products", []) or []
|
|
||||||
|
|
||||||
product: dict | None = None
|
|
||||||
for p in products_raw:
|
|
||||||
if isinstance(p, dict) and p.get("id") == product_id:
|
|
||||||
product = p
|
|
||||||
break
|
|
||||||
|
|
||||||
if product is None:
|
|
||||||
raise HTTPException(404, f"Product '{product_id}' not in config")
|
|
||||||
|
|
||||||
base_url = product.get("base_url", "").rstrip("/")
|
|
||||||
endpoint = product.get("sample_endpoint", "")
|
|
||||||
if not base_url or not endpoint:
|
|
||||||
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
|
||||||
|
|
||||||
url = f"{base_url}{endpoint}"
|
|
||||||
try:
|
|
||||||
raw = _http_get_json(url, timeout=5)
|
|
||||||
except URLError as exc:
|
|
||||||
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
|
||||||
except Exception as exc:
|
|
||||||
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
|
||||||
|
|
||||||
text_fields = product.get("text_fields", []) or []
|
|
||||||
sample_key = product.get("sample_key") or None
|
|
||||||
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
|
|
||||||
if not extracted:
|
|
||||||
raise HTTPException(404, "No sample items returned by product API")
|
|
||||||
|
|
||||||
prompt_template = product.get("prompt_template", "{text}")
|
|
||||||
prompt = prompt_template.replace("{text}", extracted["text"])
|
|
||||||
# Also substitute any {field_name} placeholders from the raw item fields.
|
|
||||||
item = extracted.get("item", {})
|
|
||||||
for field, val in item.items():
|
|
||||||
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
|
|
||||||
|
|
||||||
# Expose system_prompt and image_url if the product API returns them.
|
|
||||||
# system_prompt: Peregrine, Snipe (vision analysis instructions)
|
|
||||||
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
|
|
||||||
item = extracted.get("item", {})
|
|
||||||
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
|
|
||||||
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
|
|
||||||
|
|
||||||
return {
|
|
||||||
"product_id": product_id,
|
|
||||||
"sample_index": index,
|
|
||||||
"text": extracted["text"],
|
|
||||||
"prompt": prompt,
|
|
||||||
"system_prompt": system_prompt,
|
|
||||||
"image_url": image_url,
|
|
||||||
"raw_item": item,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /catalog ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/catalog")
|
|
||||||
def get_catalog() -> dict:
|
|
||||||
"""Return the live cf-text model catalog from cf-orch coordinator."""
|
|
||||||
models = _cforch_catalog(_cforch_url())
|
|
||||||
return {"models": models}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
|
|
||||||
"""Optional session dependency — returns None when cloud_session is unavailable."""
|
|
||||||
try:
|
|
||||||
from app.cloud_session import get_session
|
|
||||||
return get_session(request, response)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/run")
|
|
||||||
def run_imitate(
|
|
||||||
prompt: str = "",
|
|
||||||
model_ids: str = "", # comma-separated ollama model IDs
|
|
||||||
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
|
|
||||||
temperature: float = 0.7,
|
|
||||||
product_id: str = "",
|
|
||||||
system: str = "", # optional system prompt
|
|
||||||
image_url: str = "", # optional image URL for vision models
|
|
||||||
session: "Any" = Depends(_get_imitate_session),
|
|
||||||
) -> StreamingResponse:
|
|
||||||
"""Run a prompt through selected ollama models and stream results as SSE.
|
|
||||||
|
|
||||||
If image_url is provided, the image is downloaded once and passed to every
|
|
||||||
model as a base64-encoded blob — allowing vision-capable local models to
|
|
||||||
evaluate listing photos the same way Snipe's background task pipeline does.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not prompt.strip():
|
|
||||||
raise HTTPException(422, "prompt is required")
|
|
||||||
|
|
||||||
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
|
||||||
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
|
|
||||||
if not ollama_ids and not cftext_ids:
|
|
||||||
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
|
|
||||||
|
|
||||||
cfg = _load_imitate_config()
|
|
||||||
ollama_base = _ollama_url(cfg)
|
|
||||||
cforch_base = _cforch_url()
|
|
||||||
system_ctx = system.strip() or ""
|
|
||||||
total_models = len(ollama_ids) + len(cftext_ids)
|
|
||||||
|
|
||||||
# Download image once before streaming — shared across ollama vision models
|
|
||||||
images: list[str] = []
|
|
||||||
if image_url.strip():
|
|
||||||
b64 = _fetch_image_b64(image_url.strip())
|
|
||||||
if b64:
|
|
||||||
images = [b64]
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
results: list[dict] = []
|
|
||||||
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
|
|
||||||
|
|
||||||
# Ollama models
|
|
||||||
for model_id in ollama_ids:
|
|
||||||
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
|
|
||||||
try:
|
|
||||||
response, elapsed_ms = _run_ollama_streaming(
|
|
||||||
ollama_base, model_id, prompt, temperature,
|
|
||||||
system=system_ctx, images=images or None,
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"model": model_id,
|
|
||||||
"response": response,
|
|
||||||
"elapsed_ms": elapsed_ms,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
except Exception as exc:
|
|
||||||
result = {
|
|
||||||
"model": model_id,
|
|
||||||
"response": "",
|
|
||||||
"elapsed_ms": 0,
|
|
||||||
"error": str(exc),
|
|
||||||
}
|
|
||||||
results.append(result)
|
|
||||||
yield _sse({"type": "model_done", **result})
|
|
||||||
|
|
||||||
# cf-text models via cf-orch — fan out in parallel when multiple models selected
|
|
||||||
if cftext_ids:
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
|
|
||||||
# Announce all models upfront so the UI can show loading states immediately
|
|
||||||
for model_id in cftext_ids:
|
|
||||||
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
|
|
||||||
|
|
||||||
_user_id: str | None = getattr(session, "user_id", None)
|
|
||||||
# Only forward real cloud user IDs — skip local/anon sessions
|
|
||||||
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
|
|
||||||
_user_id = None
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
|
|
||||||
future_to_model = {
|
|
||||||
pool.submit(
|
|
||||||
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
|
|
||||||
180.0, _user_id,
|
|
||||||
): mid
|
|
||||||
for mid in cftext_ids
|
|
||||||
}
|
|
||||||
for future in as_completed(future_to_model):
|
|
||||||
model_id = future_to_model[future]
|
|
||||||
try:
|
|
||||||
response, elapsed_ms, cold_started = future.result()
|
|
||||||
if cold_started:
|
|
||||||
yield _sse({"type": "model_coldstart", "model": model_id})
|
|
||||||
result = {
|
|
||||||
"model": model_id,
|
|
||||||
"response": response,
|
|
||||||
"elapsed_ms": elapsed_ms,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
except Exception as exc:
|
|
||||||
result = {
|
|
||||||
"model": model_id,
|
|
||||||
"response": "",
|
|
||||||
"elapsed_ms": 0,
|
|
||||||
"error": str(exc),
|
|
||||||
}
|
|
||||||
results.append(result)
|
|
||||||
yield _sse({"type": "model_done", **result})
|
|
||||||
|
|
||||||
yield _sse({"type": "complete", "results": results})
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class ImitateResult(BaseModel):
|
|
||||||
model: str
|
|
||||||
response: str
|
|
||||||
elapsed_ms: int
|
|
||||||
error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class PushCorrectionsRequest(BaseModel):
|
|
||||||
product_id: str
|
|
||||||
prompt: str
|
|
||||||
results: list[ImitateResult]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/push-corrections")
|
|
||||||
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
|
||||||
"""Append imitate results to sft_candidates.jsonl for human review."""
|
|
||||||
if not req.prompt.strip():
|
|
||||||
raise HTTPException(422, "prompt is required")
|
|
||||||
if not req.results:
|
|
||||||
raise HTTPException(422, "results list is empty")
|
|
||||||
|
|
||||||
ts = datetime.now(timezone.utc).isoformat()
|
|
||||||
records = []
|
|
||||||
for r in req.results:
|
|
||||||
if r.error or not r.response.strip():
|
|
||||||
continue
|
|
||||||
records.append({
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source": "imitate",
|
|
||||||
"product_id": req.product_id,
|
|
||||||
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
|
||||||
"model_response": r.response,
|
|
||||||
"model_id": r.model,
|
|
||||||
"elapsed_ms": r.elapsed_ms,
|
|
||||||
"status": "pending",
|
|
||||||
"created_at": ts,
|
|
||||||
})
|
|
||||||
|
|
||||||
if not records:
|
|
||||||
raise HTTPException(422, "No non-error results to push")
|
|
||||||
|
|
||||||
dest = _candidates_file()
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
for record in records:
|
|
||||||
append_jsonl(dest, record)
|
|
||||||
|
|
||||||
return {"pushed": len(records)}
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Tests for app/imitate.py — product registry, sample extraction, corrections push."""
|
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -9,10 +9,10 @@ import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from app.api import app
|
from app.api import app
|
||||||
from app import imitate as _imitate_module
|
from app.data import imitate as _imitate_module
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ───────────────────────────────────────────────────────────────────
|
# -- Fixtures ------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_module_globals(tmp_path):
|
def reset_module_globals(tmp_path):
|
||||||
|
|
@ -70,7 +70,7 @@ def client() -> TestClient:
|
||||||
return TestClient(app, raise_server_exceptions=True)
|
return TestClient(app, raise_server_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
# ── GET /products ──────────────────────────────────────────────────────────────
|
# -- GET /products -------------------------------------------------------------
|
||||||
|
|
||||||
def test_products_empty_when_no_config(config_dir, client):
|
def test_products_empty_when_no_config(config_dir, client):
|
||||||
"""Returns empty list when label_tool.yaml has no imitate section."""
|
"""Returns empty list when label_tool.yaml has no imitate section."""
|
||||||
|
|
@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client):
|
||||||
assert all(not p["online"] for p in resp.json()["products"])
|
assert all(not p["online"] for p in resp.json()["products"])
|
||||||
|
|
||||||
|
|
||||||
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
|
# -- GET /products/{id}/sample -------------------------------------------------
|
||||||
|
|
||||||
def test_sample_unknown_product(cfg_with_products, client):
|
def test_sample_unknown_product(cfg_with_products, client):
|
||||||
"""Returns 404 for a product id not in config."""
|
"""Returns 404 for a product id not in config."""
|
||||||
|
|
@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client):
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
# -- POST /push-corrections ----------------------------------------------------
|
||||||
|
|
||||||
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
|
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
|
||||||
"""Successful push writes records to sft_candidates.jsonl."""
|
"""Successful push writes records to sft_candidates.jsonl."""
|
||||||
|
|
@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
# ── _extract_sample helper ─────────────────────────────────────────────────────
|
# -- _extract_sample helper ----------------------------------------------------
|
||||||
|
|
||||||
def test_extract_sample_list():
|
def test_extract_sample_list():
|
||||||
result = _imitate_module._extract_sample(
|
result = _imitate_module._extract_sample(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue