feat: move imitate API into app/data/imitate.py

This commit is contained in:
pyr0ball 2026-05-01 22:12:19 -07:00
parent 99ea39fe38
commit d74ad3f972
3 changed files with 654 additions and 651 deletions

644
app/data/imitate.py Normal file
View file

@ -0,0 +1,644 @@
"""Avocet — Imitate tab API.
Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs (ollama), and streams responses back to the UI. Results can be
pushed into the SFT corrections queue for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_imitate_config() -> dict:
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse imitate config %s: %s", f, exc)
return {}
return raw.get("imitate", {}) or {}
def _load_cforch_config() -> dict:
"""Read cforch section for ollama_url fallback."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
return {}
return raw.get("cforch", {}) or {}
def _ollama_url(cfg: dict) -> str:
cforch = _load_cforch_config()
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
def _cforch_url() -> str:
cforch = _load_cforch_config()
return cforch.get("coordinator_url") or "http://localhost:7700"
def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch.
Filters out proxy entries (ollama://, vllm://, http://) those models are
served by their own services and should not be allocated via cf-text.
Returns only models with real file-system paths that cf-text can load directly.
"""
try:
resp = httpx.get(
f"{cforch_base}/api/services/cf-text/catalog",
params={"node_id": "heimdall"},
timeout=5.0,
)
resp.raise_for_status()
raw = resp.json()
result = []
for model_id, entry in raw.items():
if not isinstance(entry, dict):
continue
path = entry.get("path", "")
# Skip proxy entries — they're routed through other services
if "://" in path:
continue
result.append({
"id": model_id,
"vram_mb": entry.get("vram_mb", 0),
"description": entry.get("description", ""),
})
return result
except Exception as exc:
logger.warning("Could not fetch cf-orch catalog: %s", exc)
return []
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
"""Return True if the product's health endpoint responds OK."""
try:
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
return bool(data)
except Exception:
return False
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
"""Download an image URL and return it as a base64 string for ollama.
Returns empty string on any failure a missing image is non-fatal;
the model will still run against the text prompt alone.
"""
try:
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
with urlopen(req, timeout=10) as resp:
return base64.b64encode(resp.read()).decode("ascii")
except Exception as exc:
logger.warning("Failed to fetch image %s: %s", image_url, exc)
return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _run_cftext(
cforch_base: str,
model_id: str,
prompt: str,
system: str,
temperature: float,
startup_timeout_s: float = 180.0,
user_id: str | None = None,
) -> tuple[str, int, bool]:
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
Raises RuntimeError on allocation failure or generation error.
cold_started=True means the service was launched from scratch (caller may log this).
Cold-start detection uses coordinator state signals (running/stopped) rather than
polling the service health endpoint this fails fast on model load errors instead
of waiting out the full timeout.
"""
# Allocate
alloc_resp = httpx.post(
f"{cforch_base}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "imitate",
**({"user_id": user_id} if user_id else {}),
},
timeout=30.0,
)
alloc_resp.raise_for_status()
data = alloc_resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
cold_started = data.get("started", False) and not data.get("warm", True)
# Wait for ready using coordinator state signals
if cold_started:
deadline = time.monotonic() + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
try:
status = httpx.get(
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
break
elif state == "stopped":
if allocation_id:
httpx.delete(
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
timeout=5.0,
)
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
else:
probe_misses += 1
if probe_misses >= 6:
# Coordinator hasn't registered instance yet — fall back to health poll
try:
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
break
except Exception:
pass
except RuntimeError:
raise
except Exception:
pass
time.sleep(2.0)
else:
if allocation_id:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
# Generate
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
t0 = time.time()
try:
gen_resp = httpx.post(
f"{service_url}/v1/chat/completions",
json={
"model": model_id,
"messages": messages,
"max_tokens": 300,
"temperature": temperature,
"stream": False,
},
timeout=120.0,
)
gen_resp.raise_for_status()
elapsed_ms = int((time.time() - t0) * 1000)
content = gen_resp.json()["choices"][0]["message"]["content"]
return content.strip(), elapsed_ms, cold_started
except Exception as exc:
elapsed_ms = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
finally:
if allocation_id:
try:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
except Exception:
pass
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
"""List configured CF products with live online status."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
products = []
for p in products_raw:
if not isinstance(p, dict):
continue
base_url = p.get("base_url", "")
products.append({
"id": p.get("id", ""),
"name": p.get("name", ""),
"icon": p.get("icon", "📦"),
"description": p.get("description", ""),
"base_url": base_url,
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
})
return {"products": products}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
"""Fetch a real sample from the given product's API."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
product: dict | None = None
for p in products_raw:
if isinstance(p, dict) and p.get("id") == product_id:
product = p
break
if product is None:
raise HTTPException(404, f"Product '{product_id}' not in config")
base_url = product.get("base_url", "").rstrip("/")
endpoint = product.get("sample_endpoint", "")
if not base_url or not endpoint:
raise HTTPException(422, "Product missing base_url or sample_endpoint")
url = f"{base_url}{endpoint}"
try:
raw = _http_get_json(url, timeout=5)
except URLError as exc:
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
except Exception as exc:
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
text_fields = product.get("text_fields", []) or []
sample_key = product.get("sample_key") or None
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
if not extracted:
raise HTTPException(404, "No sample items returned by product API")
prompt_template = product.get("prompt_template", "{text}")
prompt = prompt_template.replace("{text}", extracted["text"])
# Also substitute any {field_name} placeholders from the raw item fields.
item = extracted.get("item", {})
for field, val in item.items():
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
# Expose system_prompt and image_url if the product API returns them.
# system_prompt: Peregrine, Snipe (vision analysis instructions)
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
item = extracted.get("item", {})
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
return {
"product_id": product_id,
"sample_index": index,
"text": extracted["text"],
"prompt": prompt,
"system_prompt": system_prompt,
"image_url": image_url,
"raw_item": item,
}
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
"""Return the live cf-text model catalog from cf-orch coordinator."""
models = _cforch_catalog(_cforch_url())
return {"models": models}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
"""Optional session dependency — returns None when cloud_session is unavailable."""
try:
from app.cloud_session import get_session
return get_session(request, response)
except Exception:
return None
@router.get("/run")
def run_imitate(
prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
temperature: float = 0.7,
product_id: str = "",
system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models
session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
"""Run a prompt through selected ollama models and stream results as SSE.
If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to
evaluate listing photos the same way Snipe's background task pipeline does.
"""
if not prompt.strip():
raise HTTPException(422, "prompt is required")
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg)
cforch_base = _cforch_url()
system_ctx = system.strip() or ""
total_models = len(ollama_ids) + len(cftext_ids)
# Download image once before streaming — shared across ollama vision models
images: list[str] = []
if image_url.strip():
b64 = _fetch_image_b64(image_url.strip())
if b64:
images = [b64]
def generate():
results: list[dict] = []
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
# Ollama models
for model_id in ollama_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
try:
response, elapsed_ms = _run_ollama_streaming(
ollama_base, model_id, prompt, temperature,
system=system_ctx, images=images or None,
)
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected
if cftext_ids:
from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
_user_id: str | None = getattr(session, "user_id", None)
# Only forward real cloud user IDs — skip local/anon sessions
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
_user_id = None
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
future_to_model = {
pool.submit(
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
180.0, _user_id,
): mid
for mid in cftext_ids
}
for future in as_completed(future_to_model):
model_id = future_to_model[future]
try:
response, elapsed_ms, cold_started = future.result()
if cold_started:
yield _sse({"type": "model_coldstart", "model": model_id})
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
yield _sse({"type": "complete", "results": results})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
model: str
response: str
elapsed_ms: int
error: str | None = None
class PushCorrectionsRequest(BaseModel):
product_id: str
prompt: str
results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
"""Append imitate results to sft_candidates.jsonl for human review."""
if not req.prompt.strip():
raise HTTPException(422, "prompt is required")
if not req.results:
raise HTTPException(422, "results list is empty")
ts = datetime.now(timezone.utc).isoformat()
records = []
for r in req.results:
if r.error or not r.response.strip():
continue
records.append({
"id": str(uuid.uuid4()),
"source": "imitate",
"product_id": req.product_id,
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": r.response,
"model_id": r.model,
"elapsed_ms": r.elapsed_ms,
"status": "pending",
"created_at": ts,
})
if not records:
raise HTTPException(422, "No non-error results to push")
dest = _candidates_file()
dest.parent.mkdir(parents=True, exist_ok=True)
for record in records:
append_jsonl(dest, record)
return {"pushed": len(records)}

View file

@ -1,644 +1,3 @@
"""Avocet — Imitate tab API.
Fetches real samples from sibling CF product APIs, sends them through selected
local LLMs (ollama), and streams responses back to the UI. Results can be
pushed into the SFT corrections queue for human review.
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
Module-level globals follow the same testability pattern as cforch.py and sft.py:
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import Request, urlopen
import httpx
import yaml
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_CONFIG_DIR: Path | None = None
_DATA_DIR: Path = _ROOT / "data"
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────────
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
def set_data_dir(path: Path) -> None:
global _DATA_DIR
_DATA_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────────
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_imitate_config() -> dict:
"""Read label_tool.yaml and return the imitate sub-dict (or {} if absent)."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse imitate config %s: %s", f, exc)
return {}
return raw.get("imitate", {}) or {}
def _load_cforch_config() -> dict:
"""Read cforch section for ollama_url fallback."""
f = _config_file()
if not f.exists():
return {}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
return {}
return raw.get("cforch", {}) or {}
def _ollama_url(cfg: dict) -> str:
cforch = _load_cforch_config()
return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434"
def _cforch_url() -> str:
cforch = _load_cforch_config()
return cforch.get("coordinator_url") or "http://localhost:7700"
def _cforch_catalog(cforch_base: str) -> list[dict]:
"""Fetch the live cf-text catalog from cf-orch.
Filters out proxy entries (ollama://, vllm://, http://) those models are
served by their own services and should not be allocated via cf-text.
Returns only models with real file-system paths that cf-text can load directly.
"""
try:
resp = httpx.get(
f"{cforch_base}/api/services/cf-text/catalog",
params={"node_id": "heimdall"},
timeout=5.0,
)
resp.raise_for_status()
raw = resp.json()
result = []
for model_id, entry in raw.items():
if not isinstance(entry, dict):
continue
path = entry.get("path", "")
# Skip proxy entries — they're routed through other services
if "://" in path:
continue
result.append({
"id": model_id,
"vram_mb": entry.get("vram_mb", 0),
"description": entry.get("description", ""),
})
return result
except Exception as exc:
logger.warning("Could not fetch cf-orch catalog: %s", exc)
return []
def _http_get_json(url: str, timeout: int = 5) -> Any:
"""Fetch JSON from url; raise URLError on failure."""
req = Request(url, headers={"Accept": "application/json"})
with urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
"""Return True if the product's health endpoint responds OK."""
try:
data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2)
return bool(data)
except Exception:
return False
def _extract_sample(
raw: Any,
text_fields: list[str],
sample_index: int = 0,
sample_key: str | None = None,
) -> dict[str, Any]:
"""Pull one item from a list or dict response and extract text_fields.
sample_key: if provided, unwrap raw[sample_key] before looking for a list.
Falls back to a set of conventional envelope keys if sample_key is absent.
"""
item: dict[str, Any]
if isinstance(raw, list):
if not raw:
return {}
item = raw[min(sample_index, len(raw) - 1)]
elif isinstance(raw, dict):
# Use declared sample_key first, then fall back to conventional names.
_ENVELOPE_KEYS = (
"samples", "items", "results", "data", "jobs", "listings",
"pantry", "saved_searches", "entries", "calls", "records",
)
search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS)
for key in search_keys:
if key in raw and isinstance(raw[key], list):
lst = raw[key]
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
break
else:
item = raw
else:
return {}
parts = []
for field in text_fields:
val = item.get(field)
if val and str(val).strip():
parts.append(f"**{field}**: {val}")
return {"item": item, "text": "\n\n".join(parts)}
def _candidates_file() -> Path:
return _DATA_DIR / "sft_candidates.jsonl"
def _sse(data: dict) -> str:
return f"data: {json.dumps(data)}\n\n"
def _fetch_image_b64(image_url: str) -> str:
"""Download an image URL and return it as a base64 string for ollama.
Returns empty string on any failure a missing image is non-fatal;
the model will still run against the text prompt alone.
"""
try:
req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
with urlopen(req, timeout=10) as resp:
return base64.b64encode(resp.read()).decode("ascii")
except Exception as exc:
logger.warning("Failed to fetch image %s: %s", image_url, exc)
return ""
def _run_ollama_streaming(
ollama_base: str,
model_id: str,
prompt: str,
temperature: float,
system: str = "",
images: list[str] | None = None,
) -> tuple[str, int]:
"""Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).
Blocks until the model finishes; yields nothing streaming is handled by
the SSE generator in run_imitate().
system: optional system prompt passed as a separate field to ollama.
images: list of base64-encoded image strings (vision models only).
"""
url = f"{ollama_base.rstrip('/')}/api/generate"
body: dict = {
"model": model_id,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system:
body["system"] = system
if images:
body["images"] = images
payload = json.dumps(body).encode("utf-8")
req = Request(url, data=payload, method="POST",
headers={"Content-Type": "application/json"})
t0 = time.time()
try:
with urlopen(req, timeout=120) as resp:
body = json.loads(resp.read().decode("utf-8"))
elapsed = int((time.time() - t0) * 1000)
return body.get("response", ""), elapsed
except Exception as exc:
elapsed = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
def _run_cftext(
cforch_base: str,
model_id: str,
prompt: str,
system: str,
temperature: float,
startup_timeout_s: float = 180.0,
user_id: str | None = None,
) -> tuple[str, int, bool]:
"""Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started).
Raises RuntimeError on allocation failure or generation error.
cold_started=True means the service was launched from scratch (caller may log this).
Cold-start detection uses coordinator state signals (running/stopped) rather than
polling the service health endpoint this fails fast on model load errors instead
of waiting out the full timeout.
"""
# Allocate
alloc_resp = httpx.post(
f"{cforch_base}/api/services/cf-text/allocate",
json={
"model_candidates": [model_id],
"caller": "avocet",
"pipeline": "imitate",
**({"user_id": user_id} if user_id else {}),
},
timeout=30.0,
)
alloc_resp.raise_for_status()
data = alloc_resp.json()
service_url: str = data["url"]
allocation_id: str = data.get("allocation_id", "")
node_id: str = data.get("node_id", "")
gpu_id: int | None = data.get("gpu_id")
cold_started = data.get("started", False) and not data.get("warm", True)
# Wait for ready using coordinator state signals
if cold_started:
deadline = time.monotonic() + startup_timeout_s
probe_misses = 0
while time.monotonic() < deadline:
try:
status = httpx.get(
f"{cforch_base}/api/services/cf-text/status", timeout=5.0
)
if status.is_success:
instances = status.json().get("instances", [])
match = next(
(i for i in instances
if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id),
None,
)
if match:
probe_misses = 0
state = match.get("state", "")
if state == "running":
break
elif state == "stopped":
if allocation_id:
httpx.delete(
f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}",
timeout=5.0,
)
raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)")
else:
probe_misses += 1
if probe_misses >= 6:
# Coordinator hasn't registered instance yet — fall back to health poll
try:
if httpx.get(f"{service_url}/health", timeout=3.0).is_success:
break
except Exception:
pass
except RuntimeError:
raise
except Exception:
pass
time.sleep(2.0)
else:
if allocation_id:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s")
# Generate
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
t0 = time.time()
try:
gen_resp = httpx.post(
f"{service_url}/v1/chat/completions",
json={
"model": model_id,
"messages": messages,
"max_tokens": 300,
"temperature": temperature,
"stream": False,
},
timeout=120.0,
)
gen_resp.raise_for_status()
elapsed_ms = int((time.time() - t0) * 1000)
content = gen_resp.json()["choices"][0]["message"]["content"]
return content.strip(), elapsed_ms, cold_started
except Exception as exc:
elapsed_ms = int((time.time() - t0) * 1000)
raise RuntimeError(str(exc)) from exc
finally:
if allocation_id:
try:
httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0)
except Exception:
pass
# ── GET /products ──────────────────────────────────────────────────────────────
@router.get("/products")
def get_products() -> dict:
"""List configured CF products with live online status."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
products = []
for p in products_raw:
if not isinstance(p, dict):
continue
base_url = p.get("base_url", "")
products.append({
"id": p.get("id", ""),
"name": p.get("name", ""),
"icon": p.get("icon", "📦"),
"description": p.get("description", ""),
"base_url": base_url,
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
})
return {"products": products}
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
@router.get("/products/{product_id}/sample")
def get_sample(product_id: str, index: int = 0) -> dict:
"""Fetch a real sample from the given product's API."""
cfg = _load_imitate_config()
products_raw = cfg.get("products", []) or []
product: dict | None = None
for p in products_raw:
if isinstance(p, dict) and p.get("id") == product_id:
product = p
break
if product is None:
raise HTTPException(404, f"Product '{product_id}' not in config")
base_url = product.get("base_url", "").rstrip("/")
endpoint = product.get("sample_endpoint", "")
if not base_url or not endpoint:
raise HTTPException(422, "Product missing base_url or sample_endpoint")
url = f"{base_url}{endpoint}"
try:
raw = _http_get_json(url, timeout=5)
except URLError as exc:
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
except Exception as exc:
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
text_fields = product.get("text_fields", []) or []
sample_key = product.get("sample_key") or None
extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key)
if not extracted:
raise HTTPException(404, "No sample items returned by product API")
prompt_template = product.get("prompt_template", "{text}")
prompt = prompt_template.replace("{text}", extracted["text"])
# Also substitute any {field_name} placeholders from the raw item fields.
item = extracted.get("item", {})
for field, val in item.items():
prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "")
# Expose system_prompt and image_url if the product API returns them.
# system_prompt: Peregrine, Snipe (vision analysis instructions)
# image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time
item = extracted.get("item", {})
system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else ""
image_url = str(item.get("image_url", "")) if isinstance(item, dict) else ""
return {
"product_id": product_id,
"sample_index": index,
"text": extracted["text"],
"prompt": prompt,
"system_prompt": system_prompt,
"image_url": image_url,
"raw_item": item,
}
# ── GET /catalog ───────────────────────────────────────────────────────────────
@router.get("/catalog")
def get_catalog() -> dict:
"""Return the live cf-text model catalog from cf-orch coordinator."""
models = _cforch_catalog(_cforch_url())
return {"models": models}
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None":
"""Optional session dependency — returns None when cloud_session is unavailable."""
try:
from app.cloud_session import get_session
return get_session(request, response)
except Exception:
return None
@router.get("/run")
def run_imitate(
prompt: str = "",
model_ids: str = "", # comma-separated ollama model IDs
cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch)
temperature: float = 0.7,
product_id: str = "",
system: str = "", # optional system prompt
image_url: str = "", # optional image URL for vision models
session: "Any" = Depends(_get_imitate_session),
) -> StreamingResponse:
"""Run a prompt through selected ollama models and stream results as SSE.
If image_url is provided, the image is downloaded once and passed to every
model as a base64-encoded blob allowing vision-capable local models to
evaluate listing photos the same way Snipe's background task pipeline does.
"""
if not prompt.strip():
raise HTTPException(422, "prompt is required")
ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()]
cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()]
if not ollama_ids and not cftext_ids:
raise HTTPException(422, "model_ids or cf_text_model_ids is required")
cfg = _load_imitate_config()
ollama_base = _ollama_url(cfg)
cforch_base = _cforch_url()
system_ctx = system.strip() or ""
total_models = len(ollama_ids) + len(cftext_ids)
# Download image once before streaming — shared across ollama vision models
images: list[str] = []
if image_url.strip():
b64 = _fetch_image_b64(image_url.strip())
if b64:
images = [b64]
def generate():
results: list[dict] = []
yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)})
# Ollama models
for model_id in ollama_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "ollama"})
try:
response, elapsed_ms = _run_ollama_streaming(
ollama_base, model_id, prompt, temperature,
system=system_ctx, images=images or None,
)
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
# cf-text models via cf-orch — fan out in parallel when multiple models selected
if cftext_ids:
from concurrent.futures import ThreadPoolExecutor, as_completed
# Announce all models upfront so the UI can show loading states immediately
for model_id in cftext_ids:
yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"})
_user_id: str | None = getattr(session, "user_id", None)
# Only forward real cloud user IDs — skip local/anon sessions
if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"):
_user_id = None
with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool:
future_to_model = {
pool.submit(
_run_cftext, cforch_base, mid, prompt, system_ctx, temperature,
180.0, _user_id,
): mid
for mid in cftext_ids
}
for future in as_completed(future_to_model):
model_id = future_to_model[future]
try:
response, elapsed_ms, cold_started = future.result()
if cold_started:
yield _sse({"type": "model_coldstart", "model": model_id})
result = {
"model": model_id,
"response": response,
"elapsed_ms": elapsed_ms,
"error": None,
}
except Exception as exc:
result = {
"model": model_id,
"response": "",
"elapsed_ms": 0,
"error": str(exc),
}
results.append(result)
yield _sse({"type": "model_done", **result})
yield _sse({"type": "complete", "results": results})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
# ── POST /push-corrections ─────────────────────────────────────────────────────
class ImitateResult(BaseModel):
model: str
response: str
elapsed_ms: int
error: str | None = None
class PushCorrectionsRequest(BaseModel):
product_id: str
prompt: str
results: list[ImitateResult]
@router.post("/push-corrections")
def push_corrections(req: PushCorrectionsRequest) -> dict:
"""Append imitate results to sft_candidates.jsonl for human review."""
if not req.prompt.strip():
raise HTTPException(422, "prompt is required")
if not req.results:
raise HTTPException(422, "results list is empty")
ts = datetime.now(timezone.utc).isoformat()
records = []
for r in req.results:
if r.error or not r.response.strip():
continue
records.append({
"id": str(uuid.uuid4()),
"source": "imitate",
"product_id": req.product_id,
"prompt_messages": [{"role": "user", "content": req.prompt}],
"model_response": r.response,
"model_id": r.model,
"elapsed_ms": r.elapsed_ms,
"status": "pending",
"created_at": ts,
})
if not records:
raise HTTPException(422, "No non-error results to push")
dest = _candidates_file()
dest.parent.mkdir(parents=True, exist_ok=True)
for record in records:
append_jsonl(dest, record)
return {"pushed": len(records)}
"""Backward-compat shim -- logic moved to app/data/imitate.py."""
from app.data.imitate import router # noqa: F401
from app.data.imitate import set_config_dir, set_data_dir # noqa: F401

View file

@ -1,4 +1,4 @@
"""Tests for app/imitate.py product registry, sample extraction, corrections push."""
"""Tests for app/imitate.py -- product registry, sample extraction, corrections push."""
from __future__ import annotations
import json
@ -9,10 +9,10 @@ import pytest
from fastapi.testclient import TestClient
from app.api import app
from app import imitate as _imitate_module
from app.data import imitate as _imitate_module
# ── Fixtures ───────────────────────────────────────────────────────────────────
# -- Fixtures ------------------------------------------------------------------
@pytest.fixture(autouse=True)
def reset_module_globals(tmp_path):
@ -70,7 +70,7 @@ def client() -> TestClient:
return TestClient(app, raise_server_exceptions=True)
# ── GET /products ──────────────────────────────────────────────────────────────
# -- GET /products -------------------------------------------------------------
def test_products_empty_when_no_config(config_dir, client):
"""Returns empty list when label_tool.yaml has no imitate section."""
@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client):
assert all(not p["online"] for p in resp.json()["products"])
# ── GET /products/{id}/sample ─────────────────────────────────────────────────
# -- GET /products/{id}/sample -------------------------------------------------
def test_sample_unknown_product(cfg_with_products, client):
"""Returns 404 for a product id not in config."""
@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client):
assert resp.status_code == 404
# ── POST /push-corrections ─────────────────────────────────────────────────────
# -- POST /push-corrections ----------------------------------------------------
def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client):
"""Successful push writes records to sft_candidates.jsonl."""
@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client):
assert resp.status_code == 422
# ── _extract_sample helper ─────────────────────────────────────────────────────
# -- _extract_sample helper ----------------------------------------------------
def test_extract_sample_list():
result = _imitate_module._extract_sample(