# Changelog:
# - _extract_sample: add saved_searches, entries, calls, records as recognized
#   list-wrapper keys (snipe/osprey response shapes)
# - _is_online: accept health_path param (default /api/health) so products
#   using /api/v1/health/ (kiwi) report correctly
# - products endpoint: pass health_path from config into _is_online
"""Avocet — Imitate tab API.
|
|
|
|
Fetches real samples from sibling CF product APIs, sends them through selected
|
|
local LLMs (ollama), and streams responses back to the UI. Results can be
|
|
pushed into the SFT corrections queue for human review.
|
|
|
|
All endpoints registered on `router`. api.py includes this with prefix="/api/imitate".
|
|
|
|
Module-level globals follow the same testability pattern as cforch.py and sft.py:
|
|
override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.error import URLError
|
|
from urllib.request import Request, urlopen
|
|
|
|
import yaml
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app.utils import append_jsonl
|
|
|
|
logger = logging.getLogger(__name__)

# Repo root: two levels up from this module (app/<this file> -> project root).
_ROOT = Path(__file__).parent.parent
# Test seam: overridden via set_config_dir(); None means use <root>/config.
_CONFIG_DIR: Path | None = None
# Test seam: overridden via set_data_dir().
_DATA_DIR: Path = _ROOT / "data"

# Included by api.py with prefix="/api/imitate".
router = APIRouter()
|
|
|
|
|
|
# ── Testability seams ──────────────────────────────────────────────────────────
|
|
|
|
def set_config_dir(path: Path | None) -> None:
|
|
global _CONFIG_DIR
|
|
_CONFIG_DIR = path
|
|
|
|
|
|
def set_data_dir(path: Path) -> None:
    """Override the data directory for tests."""
    global _DATA_DIR
    _DATA_DIR = path
|
|
|
|
|
|
# ── Internal helpers ───────────────────────────────────────────────────────────
|
|
|
|
def _config_file() -> Path:
    """Resolve the label_tool.yaml path, honouring the test override."""
    base = _CONFIG_DIR if _CONFIG_DIR is not None else _ROOT / "config"
    return base / "label_tool.yaml"
|
|
|
|
|
|
def _load_imitate_config() -> dict:
    """Return the `imitate` section of label_tool.yaml ({} if missing/unparseable)."""
    cfg_path = _config_file()
    if not cfg_path.exists():
        return {}
    try:
        parsed = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse imitate config %s: %s", cfg_path, exc)
        return {}
    return parsed.get("imitate", {}) or {}
|
|
|
|
|
|
def _load_cforch_config() -> dict:
    """Return the `cforch` section of label_tool.yaml (ollama_url fallback).

    Returns {} when the file is missing or cannot be parsed.
    """
    f = _config_file()
    if not f.exists():
        return {}
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Consistency fix: _load_imitate_config logs parse failures; this
        # sibling previously bound `exc` and silently swallowed it.
        logger.warning("Failed to parse cforch config %s: %s", f, exc)
        return {}
    return raw.get("cforch", {}) or {}
|
|
|
|
|
|
def _ollama_url(cfg: dict) -> str:
    """Pick the ollama base URL: imitate config, then cforch config, then default."""
    fallback = _load_cforch_config().get("ollama_url")
    return cfg.get("ollama_url") or fallback or "http://localhost:11434"
|
|
|
|
|
|
def _http_get_json(url: str, timeout: int = 5) -> Any:
    """GET *url* and decode its JSON body; propagates URLError on network failure."""
    request = Request(url, headers={"Accept": "application/json"})
    with urlopen(request, timeout=timeout) as response:
        body = response.read()
    return json.loads(body.decode("utf-8"))
|
|
|
|
|
|
def _is_online(base_url: str, health_path: str = "/api/health") -> bool:
    """Probe the product's health endpoint; any failure counts as offline."""
    url = f"{base_url.rstrip('/')}{health_path}"
    try:
        # Truthy JSON payload == healthy; timeouts, bad JSON, etc. == offline.
        return bool(_http_get_json(url, timeout=2))
    except Exception:
        return False
|
|
|
|
|
|
def _extract_sample(
|
|
raw: Any, text_fields: list[str], sample_index: int = 0
|
|
) -> dict[str, Any]:
|
|
"""Pull one item from a list or dict response and extract text_fields."""
|
|
item: dict[str, Any]
|
|
if isinstance(raw, list):
|
|
if not raw:
|
|
return {}
|
|
item = raw[min(sample_index, len(raw) - 1)]
|
|
elif isinstance(raw, dict):
|
|
# may be {items: [...]} or the item itself
|
|
for key in ("items", "results", "data", "jobs", "listings", "pantry",
|
|
"saved_searches", "entries", "calls", "records"):
|
|
if key in raw and isinstance(raw[key], list):
|
|
lst = raw[key]
|
|
item = lst[min(sample_index, len(lst) - 1)] if lst else {}
|
|
break
|
|
else:
|
|
item = raw
|
|
else:
|
|
return {}
|
|
|
|
parts = []
|
|
for field in text_fields:
|
|
val = item.get(field)
|
|
if val and str(val).strip():
|
|
parts.append(f"**{field}**: {val}")
|
|
return {"item": item, "text": "\n\n".join(parts)}
|
|
|
|
|
|
def _candidates_file() -> Path:
    """Location of the SFT candidates queue (append-only JSONL)."""
    return Path(_DATA_DIR, "sft_candidates.jsonl")
|
|
|
|
|
|
def _sse(data: dict) -> str:
|
|
return f"data: {json.dumps(data)}\n\n"
|
|
|
|
|
|
def _run_ollama_streaming(
    ollama_base: str,
    model_id: str,
    prompt: str,
    temperature: float,
) -> tuple[str, int]:
    """Call ollama /api/generate with stream=False; return (full_response, elapsed_ms).

    Blocks until the model finishes. Despite the name, no token streaming
    happens here — per-model progress events are emitted by the SSE generator
    in run_imitate(). (Docstring fixed: it previously claimed stream=True.)

    Raises:
        RuntimeError: wrapping any transport or JSON-decoding error.
    """
    url = f"{ollama_base.rstrip('/')}/api/generate"
    payload = json.dumps({
        "model": model_id,
        "prompt": prompt,
        "stream": False,  # single blocking response from ollama
        "options": {"temperature": temperature},
    }).encode("utf-8")
    req = Request(url, data=payload, method="POST",
                  headers={"Content-Type": "application/json"})
    t0 = time.time()
    try:
        with urlopen(req, timeout=120) as resp:
            body = json.loads(resp.read().decode("utf-8"))
        elapsed = int((time.time() - t0) * 1000)
        return body.get("response", ""), elapsed
    except Exception as exc:
        # Dead-code fix: the old except branch computed elapsed_ms and
        # immediately discarded it; just surface the error.
        raise RuntimeError(str(exc)) from exc
|
|
|
|
|
|
# ── GET /products ──────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/products")
|
|
def get_products() -> dict:
|
|
"""List configured CF products with live online status."""
|
|
cfg = _load_imitate_config()
|
|
products_raw = cfg.get("products", []) or []
|
|
products = []
|
|
for p in products_raw:
|
|
if not isinstance(p, dict):
|
|
continue
|
|
base_url = p.get("base_url", "")
|
|
products.append({
|
|
"id": p.get("id", ""),
|
|
"name": p.get("name", ""),
|
|
"icon": p.get("icon", "📦"),
|
|
"description": p.get("description", ""),
|
|
"base_url": base_url,
|
|
"online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False,
|
|
})
|
|
return {"products": products}
|
|
|
|
|
|
# ── GET /products/{product_id}/sample ─────────────────────────────────────────
|
|
|
|
@router.get("/products/{product_id}/sample")
|
|
def get_sample(product_id: str, index: int = 0) -> dict:
|
|
"""Fetch a real sample from the given product's API."""
|
|
cfg = _load_imitate_config()
|
|
products_raw = cfg.get("products", []) or []
|
|
|
|
product: dict | None = None
|
|
for p in products_raw:
|
|
if isinstance(p, dict) and p.get("id") == product_id:
|
|
product = p
|
|
break
|
|
|
|
if product is None:
|
|
raise HTTPException(404, f"Product '{product_id}' not in config")
|
|
|
|
base_url = product.get("base_url", "").rstrip("/")
|
|
endpoint = product.get("sample_endpoint", "")
|
|
if not base_url or not endpoint:
|
|
raise HTTPException(422, "Product missing base_url or sample_endpoint")
|
|
|
|
url = f"{base_url}{endpoint}"
|
|
try:
|
|
raw = _http_get_json(url, timeout=5)
|
|
except URLError as exc:
|
|
raise HTTPException(503, f"Product API unreachable: {exc}") from exc
|
|
except Exception as exc:
|
|
raise HTTPException(502, f"Bad response from product API: {exc}") from exc
|
|
|
|
text_fields = product.get("text_fields", []) or []
|
|
extracted = _extract_sample(raw, text_fields, index)
|
|
if not extracted:
|
|
raise HTTPException(404, "No sample items returned by product API")
|
|
|
|
prompt_template = product.get("prompt_template", "{text}")
|
|
prompt = prompt_template.replace("{text}", extracted["text"])
|
|
|
|
return {
|
|
"product_id": product_id,
|
|
"sample_index": index,
|
|
"text": extracted["text"],
|
|
"prompt": prompt,
|
|
"raw_item": extracted.get("item", {}),
|
|
}
|
|
|
|
|
|
# ── GET /run (SSE) ─────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/run")
|
|
def run_imitate(
|
|
prompt: str = "",
|
|
model_ids: str = "", # comma-separated ollama model IDs
|
|
temperature: float = 0.7,
|
|
product_id: str = "",
|
|
) -> StreamingResponse:
|
|
"""Run a prompt through selected ollama models and stream results as SSE."""
|
|
|
|
if not prompt.strip():
|
|
raise HTTPException(422, "prompt is required")
|
|
|
|
ids = [m.strip() for m in model_ids.split(",") if m.strip()]
|
|
if not ids:
|
|
raise HTTPException(422, "model_ids is required")
|
|
|
|
cfg = _load_imitate_config()
|
|
ollama_base = _ollama_url(cfg)
|
|
|
|
def generate():
|
|
results: list[dict] = []
|
|
yield _sse({"type": "start", "total_models": len(ids)})
|
|
|
|
for model_id in ids:
|
|
yield _sse({"type": "model_start", "model": model_id})
|
|
try:
|
|
response, elapsed_ms = _run_ollama_streaming(
|
|
ollama_base, model_id, prompt, temperature
|
|
)
|
|
result = {
|
|
"model": model_id,
|
|
"response": response,
|
|
"elapsed_ms": elapsed_ms,
|
|
"error": None,
|
|
}
|
|
except Exception as exc:
|
|
result = {
|
|
"model": model_id,
|
|
"response": "",
|
|
"elapsed_ms": 0,
|
|
"error": str(exc),
|
|
}
|
|
results.append(result)
|
|
yield _sse({"type": "model_done", **result})
|
|
|
|
yield _sse({"type": "complete", "results": results})
|
|
|
|
return StreamingResponse(
|
|
generate(),
|
|
media_type="text/event-stream",
|
|
headers={
|
|
"Cache-Control": "no-cache",
|
|
"X-Accel-Buffering": "no",
|
|
},
|
|
)
|
|
|
|
|
|
# ── POST /push-corrections ─────────────────────────────────────────────────────
|
|
|
|
# One model's output from an imitate run, echoed back by the UI when pushing
# corrections. `error` is set (and response empty) when the model call failed.
class ImitateResult(BaseModel):
    model: str
    response: str
    elapsed_ms: int
    error: str | None = None
|
|
|
|
|
|
# Body of POST /push-corrections: the original prompt plus every model result
# the user chose to queue for SFT review.
class PushCorrectionsRequest(BaseModel):
    product_id: str
    prompt: str
    results: list[ImitateResult]
|
|
|
|
|
|
@router.post("/push-corrections")
|
|
def push_corrections(req: PushCorrectionsRequest) -> dict:
|
|
"""Append imitate results to sft_candidates.jsonl for human review."""
|
|
if not req.prompt.strip():
|
|
raise HTTPException(422, "prompt is required")
|
|
if not req.results:
|
|
raise HTTPException(422, "results list is empty")
|
|
|
|
ts = datetime.now(timezone.utc).isoformat()
|
|
records = []
|
|
for r in req.results:
|
|
if r.error or not r.response.strip():
|
|
continue
|
|
records.append({
|
|
"id": str(uuid.uuid4()),
|
|
"source": "imitate",
|
|
"product_id": req.product_id,
|
|
"prompt_messages": [{"role": "user", "content": req.prompt}],
|
|
"model_response": r.response,
|
|
"model_id": r.model,
|
|
"elapsed_ms": r.elapsed_ms,
|
|
"status": "pending",
|
|
"created_at": ts,
|
|
})
|
|
|
|
if not records:
|
|
raise HTTPException(422, "No non-error results to push")
|
|
|
|
dest = _candidates_file()
|
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
for record in records:
|
|
append_jsonl(dest, record)
|
|
|
|
return {"pushed": len(records)}
|