diff --git a/app/data/imitate.py b/app/data/imitate.py new file mode 100644 index 0000000..354aeab --- /dev/null +++ b/app/data/imitate.py @@ -0,0 +1,644 @@ +"""Avocet — Imitate tab API. + +Fetches real samples from sibling CF product APIs, sends them through selected +local LLMs (ollama), and streams responses back to the UI. Results can be +pushed into the SFT corrections queue for human review. + +All endpoints registered on `router`. api.py includes this with prefix="/api/imitate". + +Module-level globals follow the same testability pattern as cforch.py and sft.py: +override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests. +""" +from __future__ import annotations + +import base64 +import json +import logging +import time +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from urllib.error import URLError +from urllib.request import Request, urlopen + +import httpx +import yaml +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.utils import append_jsonl + +logger = logging.getLogger(__name__) + +_ROOT = Path(__file__).parent.parent.parent +_CONFIG_DIR: Path | None = None +_DATA_DIR: Path = _ROOT / "data" + +router = APIRouter() + + +# ── Testability seams ────────────────────────────────────────────────────────── + +def set_config_dir(path: Path | None) -> None: + global _CONFIG_DIR + _CONFIG_DIR = path + + +def set_data_dir(path: Path) -> None: + global _DATA_DIR + _DATA_DIR = path + + +# ── Internal helpers ─────────────────────────────────────────────────────────── + +def _config_file() -> Path: + if _CONFIG_DIR is not None: + return _CONFIG_DIR / "label_tool.yaml" + return _ROOT / "config" / "label_tool.yaml" + + +def _load_imitate_config() -> dict: + """Read label_tool.yaml and return the imitate sub-dict (or {} if absent).""" + f = _config_file() + if not f.exists(): + return {} + try: + raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {} + except yaml.YAMLError as exc: + logger.warning("Failed to parse imitate config %s: %s", f, exc) + return {} + return raw.get("imitate", {}) or {} + + +def _load_cforch_config() -> dict: + """Read cforch section for ollama_url fallback.""" + f = _config_file() + if not f.exists(): + return {} + try: + raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {} + except yaml.YAMLError as exc: + return {} + return raw.get("cforch", {}) or {} + + +def _ollama_url(cfg: dict) -> str: + cforch = _load_cforch_config() + return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434" + + +def _cforch_url() -> str: + cforch = _load_cforch_config() + return cforch.get("coordinator_url") or "http://localhost:7700" + + +def _cforch_catalog(cforch_base: str) -> list[dict]: + """Fetch the live cf-text catalog from cf-orch. + + Filters out proxy entries (ollama://, vllm://, http://) — those models are + served by their own services and should not be allocated via cf-text. + Returns only models with real file-system paths that cf-text can load directly. + """ + try: + resp = httpx.get( + f"{cforch_base}/api/services/cf-text/catalog", + params={"node_id": "heimdall"}, + timeout=5.0, + ) + resp.raise_for_status() + raw = resp.json() + result = [] + for model_id, entry in raw.items(): + if not isinstance(entry, dict): + continue + path = entry.get("path", "") + # Skip proxy entries — they're routed through other services + if "://" in path: + continue + result.append({ + "id": model_id, + "vram_mb": entry.get("vram_mb", 0), + "description": entry.get("description", ""), + }) + return result + except Exception as exc: + logger.warning("Could not fetch cf-orch catalog: %s", exc) + return [] + + +def _http_get_json(url: str, timeout: int = 5) -> Any: + """Fetch JSON from url; raise URLError on failure.""" + req = Request(url, headers={"Accept": "application/json"}) + with urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def _is_online(base_url: str, health_path: str = "/api/health") -> bool: + """Return True if the product's health endpoint responds OK.""" + try: + data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2) + return bool(data) + except Exception: + return False + + +def _extract_sample( + raw: Any, + text_fields: list[str], + sample_index: int = 0, + sample_key: str | None = None, +) -> dict[str, Any]: + """Pull one item from a list or dict response and extract text_fields. + + sample_key: if provided, unwrap raw[sample_key] before looking for a list. + Falls back to a set of conventional envelope keys if sample_key is absent. + """ + item: dict[str, Any] + if isinstance(raw, list): + if not raw: + return {} + item = raw[min(sample_index, len(raw) - 1)] + elif isinstance(raw, dict): + # Use declared sample_key first, then fall back to conventional names. + _ENVELOPE_KEYS = ( + "samples", "items", "results", "data", "jobs", "listings", + "pantry", "saved_searches", "entries", "calls", "records", + ) + search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS) + for key in search_keys: + if key in raw and isinstance(raw[key], list): + lst = raw[key] + item = lst[min(sample_index, len(lst) - 1)] if lst else {} + break + else: + item = raw + else: + return {} + + parts = [] + for field in text_fields: + val = item.get(field) + if val and str(val).strip(): + parts.append(f"**{field}**: {val}") + return {"item": item, "text": "\n\n".join(parts)} + + +def _candidates_file() -> Path: + return _DATA_DIR / "sft_candidates.jsonl" + + +def _sse(data: dict) -> str: + return f"data: {json.dumps(data)}\n\n" + + +def _fetch_image_b64(image_url: str) -> str: + """Download an image URL and return it as a base64 string for ollama. + + Returns empty string on any failure — a missing image is non-fatal; + the model will still run against the text prompt alone. + """ + try: + req = Request(image_url, headers={"User-Agent": "Avocet/1.0"}) + with urlopen(req, timeout=10) as resp: + return base64.b64encode(resp.read()).decode("ascii") + except Exception as exc: + logger.warning("Failed to fetch image %s: %s", image_url, exc) + return "" + + +def _run_ollama_streaming( + ollama_base: str, + model_id: str, + prompt: str, + temperature: float, + system: str = "", + images: list[str] | None = None, +) -> tuple[str, int]: + """Call ollama /api/generate with stream=False; return (full_response, elapsed_ms). + + Blocks until the model finishes; yields nothing — streaming is handled by + the SSE generator in run_imitate(). + + system: optional system prompt passed as a separate field to ollama. + images: list of base64-encoded image strings (vision models only). + """ + url = f"{ollama_base.rstrip('/')}/api/generate" + body: dict = { + "model": model_id, + "prompt": prompt, + "stream": False, + "options": {"temperature": temperature}, + } + if system: + body["system"] = system + if images: + body["images"] = images + payload = json.dumps(body).encode("utf-8") + req = Request(url, data=payload, method="POST", + headers={"Content-Type": "application/json"}) + t0 = time.time() + try: + with urlopen(req, timeout=120) as resp: + body = json.loads(resp.read().decode("utf-8")) + elapsed = int((time.time() - t0) * 1000) + return body.get("response", ""), elapsed + except Exception as exc: + elapsed = int((time.time() - t0) * 1000) + raise RuntimeError(str(exc)) from exc + + +def _run_cftext( + cforch_base: str, + model_id: str, + prompt: str, + system: str, + temperature: float, + startup_timeout_s: float = 180.0, + user_id: str | None = None, +) -> tuple[str, int, bool]: + """Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started). + + Raises RuntimeError on allocation failure or generation error. + cold_started=True means the service was launched from scratch (caller may log this). + + Cold-start detection uses coordinator state signals (running/stopped) rather than + polling the service health endpoint — this fails fast on model load errors instead + of waiting out the full timeout. + """ + # Allocate + alloc_resp = httpx.post( + f"{cforch_base}/api/services/cf-text/allocate", + json={ + "model_candidates": [model_id], + "caller": "avocet", + "pipeline": "imitate", + **({"user_id": user_id} if user_id else {}), + }, + timeout=30.0, + ) + alloc_resp.raise_for_status() + data = alloc_resp.json() + service_url: str = data["url"] + allocation_id: str = data.get("allocation_id", "") + node_id: str = data.get("node_id", "") + gpu_id: int | None = data.get("gpu_id") + cold_started = data.get("started", False) and not data.get("warm", True) + + # Wait for ready using coordinator state signals + if cold_started: + deadline = time.monotonic() + startup_timeout_s + probe_misses = 0 + while time.monotonic() < deadline: + try: + status = httpx.get( + f"{cforch_base}/api/services/cf-text/status", timeout=5.0 + ) + if status.is_success: + instances = status.json().get("instances", []) + match = next( + (i for i in instances + if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id), + None, + ) + if match: + probe_misses = 0 + state = match.get("state", "") + if state == "running": + break + elif state == "stopped": + if allocation_id: + httpx.delete( + f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", + timeout=5.0, + ) + raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)") + else: + probe_misses += 1 + if probe_misses >= 6: + # Coordinator hasn't registered instance yet — fall back to health poll + try: + if httpx.get(f"{service_url}/health", timeout=3.0).is_success: + break + except Exception: + pass + except RuntimeError: + raise + except Exception: + pass + time.sleep(2.0) + else: + if allocation_id: + httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0) + raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s") + + # Generate + messages: list[dict] = [] + if system: + messages.append({"role": "system", "content": system}) + messages.append({"role": "user", "content": prompt}) + + t0 = time.time() + try: + gen_resp = httpx.post( + f"{service_url}/v1/chat/completions", + json={ + "model": model_id, + "messages": messages, + "max_tokens": 300, + "temperature": temperature, + "stream": False, + }, + timeout=120.0, + ) + gen_resp.raise_for_status() + elapsed_ms = int((time.time() - t0) * 1000) + content = gen_resp.json()["choices"][0]["message"]["content"] + return content.strip(), elapsed_ms, cold_started + except Exception as exc: + elapsed_ms = int((time.time() - t0) * 1000) + raise RuntimeError(str(exc)) from exc + finally: + if allocation_id: + try: + httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0) + except Exception: + pass + + +# ── GET /products ────────────────────────────────────────────────────────────── + +@router.get("/products") +def get_products() -> dict: + """List configured CF products with live online status.""" + cfg = _load_imitate_config() + products_raw = cfg.get("products", []) or [] + products = [] + for p in products_raw: + if not isinstance(p, dict): + continue + base_url = p.get("base_url", "") + products.append({ + "id": p.get("id", ""), + "name": p.get("name", ""), + "icon": p.get("icon", "📦"), + "description": p.get("description", ""), + "base_url": base_url, + "online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False, + }) + return {"products": products} + + +# ── GET /products/{product_id}/sample ───────────────────────────────────────── + +@router.get("/products/{product_id}/sample") +def get_sample(product_id: str, index: int = 0) -> dict: + """Fetch a real sample from the given product's API.""" + cfg = _load_imitate_config() + products_raw = cfg.get("products", []) or [] + + product: dict | None = None + for p in products_raw: + if isinstance(p, dict) and p.get("id") == product_id: + product = p + break + + if product is None: + raise HTTPException(404, f"Product '{product_id}' not in config") + + base_url = product.get("base_url", "").rstrip("/") + endpoint = product.get("sample_endpoint", "") + if not base_url or not endpoint: + raise HTTPException(422, "Product missing base_url or sample_endpoint") + + url = f"{base_url}{endpoint}" + try: + raw = _http_get_json(url, timeout=5) + except URLError as exc: + raise HTTPException(503, f"Product API unreachable: {exc}") from exc + except Exception as exc: + raise HTTPException(502, f"Bad response from product API: {exc}") from exc + + text_fields = product.get("text_fields", []) or [] + sample_key = product.get("sample_key") or None + extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key) + if not extracted: + raise HTTPException(404, "No sample items returned by product API") + + prompt_template = product.get("prompt_template", "{text}") + prompt = prompt_template.replace("{text}", extracted["text"]) + # Also substitute any {field_name} placeholders from the raw item fields. + item = extracted.get("item", {}) + for field, val in item.items(): + prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "") + + # Expose system_prompt and image_url if the product API returns them. + # system_prompt: Peregrine, Snipe (vision analysis instructions) + # image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time + item = extracted.get("item", {}) + system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else "" + image_url = str(item.get("image_url", "")) if isinstance(item, dict) else "" + + return { + "product_id": product_id, + "sample_index": index, + "text": extracted["text"], + "prompt": prompt, + "system_prompt": system_prompt, + "image_url": image_url, + "raw_item": item, + } + + +# ── GET /catalog ─────────────────────────────────────────────────────────────── + +@router.get("/catalog") +def get_catalog() -> dict: + """Return the live cf-text model catalog from cf-orch coordinator.""" + models = _cforch_catalog(_cforch_url()) + return {"models": models} + + +# ── GET /run (SSE) ───────────────────────────────────────────────────────────── + +def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None": + """Optional session dependency — returns None when cloud_session is unavailable.""" + try: + from app.cloud_session import get_session + return get_session(request, response) + except Exception: + return None + + +@router.get("/run") +def run_imitate( + prompt: str = "", + model_ids: str = "", # comma-separated ollama model IDs + cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch) + temperature: float = 0.7, + product_id: str = "", + system: str = "", # optional system prompt + image_url: str = "", # optional image URL for vision models + session: "Any" = Depends(_get_imitate_session), +) -> StreamingResponse: + """Run a prompt through selected ollama models and stream results as SSE. + + If image_url is provided, the image is downloaded once and passed to every + model as a base64-encoded blob — allowing vision-capable local models to + evaluate listing photos the same way Snipe's background task pipeline does. + """ + + if not prompt.strip(): + raise HTTPException(422, "prompt is required") + + ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()] + cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()] + if not ollama_ids and not cftext_ids: + raise HTTPException(422, "model_ids or cf_text_model_ids is required") + + cfg = _load_imitate_config() + ollama_base = _ollama_url(cfg) + cforch_base = _cforch_url() + system_ctx = system.strip() or "" + total_models = len(ollama_ids) + len(cftext_ids) + + # Download image once before streaming — shared across ollama vision models + images: list[str] = [] + if image_url.strip(): + b64 = _fetch_image_b64(image_url.strip()) + if b64: + images = [b64] + + def generate(): + results: list[dict] = [] + yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)}) + + # Ollama models + for model_id in ollama_ids: + yield _sse({"type": "model_start", "model": model_id, "service": "ollama"}) + try: + response, elapsed_ms = _run_ollama_streaming( + ollama_base, model_id, prompt, temperature, + system=system_ctx, images=images or None, + ) + result = { + "model": model_id, + "response": response, + "elapsed_ms": elapsed_ms, + "error": None, + } + except Exception as exc: + result = { + "model": model_id, + "response": "", + "elapsed_ms": 0, + "error": str(exc), + } + results.append(result) + yield _sse({"type": "model_done", **result}) + + # cf-text models via cf-orch — fan out in parallel when multiple models selected + if cftext_ids: + from concurrent.futures import ThreadPoolExecutor, as_completed + + # Announce all models upfront so the UI can show loading states immediately + for model_id in cftext_ids: + yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"}) + + _user_id: str | None = getattr(session, "user_id", None) + # Only forward real cloud user IDs — skip local/anon sessions + if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"): + _user_id = None + + with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool: + future_to_model = { + pool.submit( + _run_cftext, cforch_base, mid, prompt, system_ctx, temperature, + 180.0, _user_id, + ): mid + for mid in cftext_ids + } + for future in as_completed(future_to_model): + model_id = future_to_model[future] + try: + response, elapsed_ms, cold_started = future.result() + if cold_started: + yield _sse({"type": "model_coldstart", "model": model_id}) + result = { + "model": model_id, + "response": response, + "elapsed_ms": elapsed_ms, + "error": None, + } + except Exception as exc: + result = { + "model": model_id, + "response": "", + "elapsed_ms": 0, + "error": str(exc), + } + results.append(result) + yield _sse({"type": "model_done", **result}) + + yield _sse({"type": "complete", "results": results}) + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + +# ── POST /push-corrections ───────────────────────────────────────────────────── + +class ImitateResult(BaseModel): + model: str + response: str + elapsed_ms: int + error: str | None = None + + +class PushCorrectionsRequest(BaseModel): + product_id: str + prompt: str + results: list[ImitateResult] + + +@router.post("/push-corrections") +def push_corrections(req: PushCorrectionsRequest) -> dict: + """Append imitate results to sft_candidates.jsonl for human review.""" + if not req.prompt.strip(): + raise HTTPException(422, "prompt is required") + if not req.results: + raise HTTPException(422, "results list is empty") + + ts = datetime.now(timezone.utc).isoformat() + records = [] + for r in req.results: + if r.error or not r.response.strip(): + continue + records.append({ + "id": str(uuid.uuid4()), + "source": "imitate", + "product_id": req.product_id, + "prompt_messages": [{"role": "user", "content": req.prompt}], + "model_response": r.response, + "model_id": r.model, + "elapsed_ms": r.elapsed_ms, + "status": "pending", + "created_at": ts, + }) + + if not records: + raise HTTPException(422, "No non-error results to push") + + dest = _candidates_file() + dest.parent.mkdir(parents=True, exist_ok=True) + for record in records: + append_jsonl(dest, record) + + return {"pushed": len(records)} diff --git a/app/imitate.py b/app/imitate.py index cec5685..7f3ea93 100644 --- a/app/imitate.py +++ b/app/imitate.py @@ -1,644 +1,3 @@ -"""Avocet — Imitate tab API. - -Fetches real samples from sibling CF product APIs, sends them through selected -local LLMs (ollama), and streams responses back to the UI. Results can be -pushed into the SFT corrections queue for human review. - -All endpoints registered on `router`. api.py includes this with prefix="/api/imitate". - -Module-level globals follow the same testability pattern as cforch.py and sft.py: -override _CONFIG_DIR and _DATA_DIR via set_config_dir() / set_data_dir() in tests. -""" -from __future__ import annotations - -import base64 -import json -import logging -import time -import uuid -from datetime import datetime, timezone -from pathlib import Path -from typing import Any -from urllib.error import URLError -from urllib.request import Request, urlopen - -import httpx -import yaml -from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import StreamingResponse -from pydantic import BaseModel - -from app.utils import append_jsonl - -logger = logging.getLogger(__name__) - -_ROOT = Path(__file__).parent.parent -_CONFIG_DIR: Path | None = None -_DATA_DIR: Path = _ROOT / "data" - -router = APIRouter() - - -# ── Testability seams ────────────────────────────────────────────────────────── - -def set_config_dir(path: Path | None) -> None: - global _CONFIG_DIR - _CONFIG_DIR = path - - -def set_data_dir(path: Path) -> None: - global _DATA_DIR - _DATA_DIR = path - - -# ── Internal helpers ─────────────────────────────────────────────────────────── - -def _config_file() -> Path: - if _CONFIG_DIR is not None: - return _CONFIG_DIR / "label_tool.yaml" - return _ROOT / "config" / "label_tool.yaml" - - -def _load_imitate_config() -> dict: - """Read label_tool.yaml and return the imitate sub-dict (or {} if absent).""" - f = _config_file() - if not f.exists(): - return {} - try: - raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {} - except yaml.YAMLError as exc: - logger.warning("Failed to parse imitate config %s: %s", f, exc) - return {} - return raw.get("imitate", {}) or {} - - -def _load_cforch_config() -> dict: - """Read cforch section for ollama_url fallback.""" - f = _config_file() - if not f.exists(): - return {} - try: - raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {} - except yaml.YAMLError as exc: - return {} - return raw.get("cforch", {}) or {} - - -def _ollama_url(cfg: dict) -> str: - cforch = _load_cforch_config() - return cfg.get("ollama_url") or cforch.get("ollama_url") or "http://localhost:11434" - - -def _cforch_url() -> str: - cforch = _load_cforch_config() - return cforch.get("coordinator_url") or "http://localhost:7700" - - -def _cforch_catalog(cforch_base: str) -> list[dict]: - """Fetch the live cf-text catalog from cf-orch. - - Filters out proxy entries (ollama://, vllm://, http://) — those models are - served by their own services and should not be allocated via cf-text. - Returns only models with real file-system paths that cf-text can load directly. - """ - try: - resp = httpx.get( - f"{cforch_base}/api/services/cf-text/catalog", - params={"node_id": "heimdall"}, - timeout=5.0, - ) - resp.raise_for_status() - raw = resp.json() - result = [] - for model_id, entry in raw.items(): - if not isinstance(entry, dict): - continue - path = entry.get("path", "") - # Skip proxy entries — they're routed through other services - if "://" in path: - continue - result.append({ - "id": model_id, - "vram_mb": entry.get("vram_mb", 0), - "description": entry.get("description", ""), - }) - return result - except Exception as exc: - logger.warning("Could not fetch cf-orch catalog: %s", exc) - return [] - - -def _http_get_json(url: str, timeout: int = 5) -> Any: - """Fetch JSON from url; raise URLError on failure.""" - req = Request(url, headers={"Accept": "application/json"}) - with urlopen(req, timeout=timeout) as resp: - return json.loads(resp.read().decode("utf-8")) - - -def _is_online(base_url: str, health_path: str = "/api/health") -> bool: - """Return True if the product's health endpoint responds OK.""" - try: - data = _http_get_json(f"{base_url.rstrip('/')}{health_path}", timeout=2) - return bool(data) - except Exception: - return False - - -def _extract_sample( - raw: Any, - text_fields: list[str], - sample_index: int = 0, - sample_key: str | None = None, -) -> dict[str, Any]: - """Pull one item from a list or dict response and extract text_fields. - - sample_key: if provided, unwrap raw[sample_key] before looking for a list. - Falls back to a set of conventional envelope keys if sample_key is absent. - """ - item: dict[str, Any] - if isinstance(raw, list): - if not raw: - return {} - item = raw[min(sample_index, len(raw) - 1)] - elif isinstance(raw, dict): - # Use declared sample_key first, then fall back to conventional names. - _ENVELOPE_KEYS = ( - "samples", "items", "results", "data", "jobs", "listings", - "pantry", "saved_searches", "entries", "calls", "records", - ) - search_keys = ([sample_key] if sample_key else []) + list(_ENVELOPE_KEYS) - for key in search_keys: - if key in raw and isinstance(raw[key], list): - lst = raw[key] - item = lst[min(sample_index, len(lst) - 1)] if lst else {} - break - else: - item = raw - else: - return {} - - parts = [] - for field in text_fields: - val = item.get(field) - if val and str(val).strip(): - parts.append(f"**{field}**: {val}") - return {"item": item, "text": "\n\n".join(parts)} - - -def _candidates_file() -> Path: - return _DATA_DIR / "sft_candidates.jsonl" - - -def _sse(data: dict) -> str: - return f"data: {json.dumps(data)}\n\n" - - -def _fetch_image_b64(image_url: str) -> str: - """Download an image URL and return it as a base64 string for ollama. - - Returns empty string on any failure — a missing image is non-fatal; - the model will still run against the text prompt alone. - """ - try: - req = Request(image_url, headers={"User-Agent": "Avocet/1.0"}) - with urlopen(req, timeout=10) as resp: - return base64.b64encode(resp.read()).decode("ascii") - except Exception as exc: - logger.warning("Failed to fetch image %s: %s", image_url, exc) - return "" - - -def _run_ollama_streaming( - ollama_base: str, - model_id: str, - prompt: str, - temperature: float, - system: str = "", - images: list[str] | None = None, -) -> tuple[str, int]: - """Call ollama /api/generate with stream=False; return (full_response, elapsed_ms). - - Blocks until the model finishes; yields nothing — streaming is handled by - the SSE generator in run_imitate(). - - system: optional system prompt passed as a separate field to ollama. - images: list of base64-encoded image strings (vision models only). - """ - url = f"{ollama_base.rstrip('/')}/api/generate" - body: dict = { - "model": model_id, - "prompt": prompt, - "stream": False, - "options": {"temperature": temperature}, - } - if system: - body["system"] = system - if images: - body["images"] = images - payload = json.dumps(body).encode("utf-8") - req = Request(url, data=payload, method="POST", - headers={"Content-Type": "application/json"}) - t0 = time.time() - try: - with urlopen(req, timeout=120) as resp: - body = json.loads(resp.read().decode("utf-8")) - elapsed = int((time.time() - t0) * 1000) - return body.get("response", ""), elapsed - except Exception as exc: - elapsed = int((time.time() - t0) * 1000) - raise RuntimeError(str(exc)) from exc - - -def _run_cftext( - cforch_base: str, - model_id: str, - prompt: str, - system: str, - temperature: float, - startup_timeout_s: float = 180.0, - user_id: str | None = None, -) -> tuple[str, int, bool]: - """Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started). - - Raises RuntimeError on allocation failure or generation error. - cold_started=True means the service was launched from scratch (caller may log this). - - Cold-start detection uses coordinator state signals (running/stopped) rather than - polling the service health endpoint — this fails fast on model load errors instead - of waiting out the full timeout. - """ - # Allocate - alloc_resp = httpx.post( - f"{cforch_base}/api/services/cf-text/allocate", - json={ - "model_candidates": [model_id], - "caller": "avocet", - "pipeline": "imitate", - **({"user_id": user_id} if user_id else {}), - }, - timeout=30.0, - ) - alloc_resp.raise_for_status() - data = alloc_resp.json() - service_url: str = data["url"] - allocation_id: str = data.get("allocation_id", "") - node_id: str = data.get("node_id", "") - gpu_id: int | None = data.get("gpu_id") - cold_started = data.get("started", False) and not data.get("warm", True) - - # Wait for ready using coordinator state signals - if cold_started: - deadline = time.monotonic() + startup_timeout_s - probe_misses = 0 - while time.monotonic() < deadline: - try: - status = httpx.get( - f"{cforch_base}/api/services/cf-text/status", timeout=5.0 - ) - if status.is_success: - instances = status.json().get("instances", []) - match = next( - (i for i in instances - if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id), - None, - ) - if match: - probe_misses = 0 - state = match.get("state", "") - if state == "running": - break - elif state == "stopped": - if allocation_id: - httpx.delete( - f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", - timeout=5.0, - ) - raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)") - else: - probe_misses += 1 - if probe_misses >= 6: - # Coordinator hasn't registered instance yet — fall back to health poll - try: - if httpx.get(f"{service_url}/health", timeout=3.0).is_success: - break - except Exception: - pass - except RuntimeError: - raise - except Exception: - pass - time.sleep(2.0) - else: - if allocation_id: - httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0) - raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s") - - # Generate - messages: list[dict] = [] - if system: - messages.append({"role": "system", "content": system}) - messages.append({"role": "user", "content": prompt}) - - t0 = time.time() - try: - gen_resp = httpx.post( - f"{service_url}/v1/chat/completions", - json={ - "model": model_id, - "messages": messages, - "max_tokens": 300, - "temperature": temperature, - "stream": False, - }, - timeout=120.0, - ) - gen_resp.raise_for_status() - elapsed_ms = int((time.time() - t0) * 1000) - content = gen_resp.json()["choices"][0]["message"]["content"] - return content.strip(), elapsed_ms, cold_started - except Exception as exc: - elapsed_ms = int((time.time() - t0) * 1000) - raise RuntimeError(str(exc)) from exc - finally: - if allocation_id: - try: - httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0) - except Exception: - pass - - -# ── GET /products ────────────────────────────────────────────────────────────── - -@router.get("/products") -def get_products() -> dict: - """List configured CF products with live online status.""" - cfg = _load_imitate_config() - products_raw = cfg.get("products", []) or [] - products = [] - for p in products_raw: - if not isinstance(p, dict): - continue - base_url = p.get("base_url", "") - products.append({ - "id": p.get("id", ""), - "name": p.get("name", ""), - "icon": p.get("icon", "📦"), - "description": p.get("description", ""), - "base_url": base_url, - "online": _is_online(base_url, p.get("health_path", "/api/health")) if base_url else False, - }) - return {"products": products} - - -# ── GET /products/{product_id}/sample ───────────────────────────────────────── - -@router.get("/products/{product_id}/sample") -def get_sample(product_id: str, index: int = 0) -> dict: - """Fetch a real sample from the given product's API.""" - cfg = _load_imitate_config() - products_raw = cfg.get("products", []) or [] - - product: dict | None = None - for p in products_raw: - if isinstance(p, dict) and p.get("id") == product_id: - product = p - break - - if product is None: - raise HTTPException(404, f"Product '{product_id}' not in config") - - base_url = product.get("base_url", "").rstrip("/") - endpoint = product.get("sample_endpoint", "") - if not base_url or not endpoint: - raise HTTPException(422, "Product missing base_url or sample_endpoint") - - url = f"{base_url}{endpoint}" - try: - raw = _http_get_json(url, timeout=5) - except URLError as exc: - raise HTTPException(503, f"Product API unreachable: {exc}") from exc - except Exception as exc: - raise HTTPException(502, f"Bad response from product API: {exc}") from exc - - text_fields = product.get("text_fields", []) or [] - sample_key = product.get("sample_key") or None - extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key) - if not extracted: - raise HTTPException(404, "No sample items returned by product API") - - prompt_template = product.get("prompt_template", "{text}") - prompt = prompt_template.replace("{text}", extracted["text"]) - # Also substitute any {field_name} placeholders from the raw item fields. - item = extracted.get("item", {}) - for field, val in item.items(): - prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "") - - # Expose system_prompt and image_url if the product API returns them. - # system_prompt: Peregrine, Snipe (vision analysis instructions) - # image_url: Snipe listing photos — Avocet downloads + base64-encodes at run time - item = extracted.get("item", {}) - system_prompt = str(item.get("system_prompt", "")) if isinstance(item, dict) else "" - image_url = str(item.get("image_url", "")) if isinstance(item, dict) else "" - - return { - "product_id": product_id, - "sample_index": index, - "text": extracted["text"], - "prompt": prompt, - "system_prompt": system_prompt, - "image_url": image_url, - "raw_item": item, - } - - -# ── GET /catalog ─────────────────────────────────────────────────────────────── - -@router.get("/catalog") -def get_catalog() -> dict: - """Return the live cf-text model catalog from cf-orch coordinator.""" - models = _cforch_catalog(_cforch_url()) - return {"models": models} - - -# ── GET /run (SSE) ───────────────────────────────────────────────────────────── - -def _get_imitate_session(request: Any, response: Any) -> "CloudUser | None": - """Optional session dependency — returns None when cloud_session is unavailable.""" - try: - from app.cloud_session import get_session - return get_session(request, response) - except Exception: - return None - - -@router.get("/run") -def run_imitate( - prompt: str = "", - model_ids: str = "", # comma-separated ollama model IDs - cf_text_model_ids: str = "", # comma-separated cf-text model IDs (via cf-orch) - temperature: float = 0.7, - product_id: str = "", - system: str = "", # optional system prompt - image_url: str = "", # optional image URL for vision models - session: "Any" = Depends(_get_imitate_session), -) -> StreamingResponse: - """Run a prompt through selected ollama models and stream results as SSE. - - If image_url is provided, the image is downloaded once and passed to every - model as a base64-encoded blob — allowing vision-capable local models to - evaluate listing photos the same way Snipe's background task pipeline does. - """ - - if not prompt.strip(): - raise HTTPException(422, "prompt is required") - - ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()] - cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()] - if not ollama_ids and not cftext_ids: - raise HTTPException(422, "model_ids or cf_text_model_ids is required") - - cfg = _load_imitate_config() - ollama_base = _ollama_url(cfg) - cforch_base = _cforch_url() - system_ctx = system.strip() or "" - total_models = len(ollama_ids) + len(cftext_ids) - - # Download image once before streaming — shared across ollama vision models - images: list[str] = [] - if image_url.strip(): - b64 = _fetch_image_b64(image_url.strip()) - if b64: - images = [b64] - - def generate(): - results: list[dict] = [] - yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)}) - - # Ollama models - for model_id in ollama_ids: - yield _sse({"type": "model_start", "model": model_id, "service": "ollama"}) - try: - response, elapsed_ms = _run_ollama_streaming( - ollama_base, model_id, prompt, temperature, - system=system_ctx, images=images or None, - ) - result = { - "model": model_id, - "response": response, - "elapsed_ms": elapsed_ms, - "error": None, - } - except Exception as exc: - result = { - "model": model_id, - "response": "", - "elapsed_ms": 0, - "error": str(exc), - } - results.append(result) - yield _sse({"type": "model_done", **result}) - - # cf-text models via cf-orch — fan out in parallel when multiple models selected - if cftext_ids: - from concurrent.futures import ThreadPoolExecutor, as_completed - - # Announce all models upfront so the UI can show loading states immediately - for model_id in cftext_ids: - yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"}) - - _user_id: str | None = getattr(session, "user_id", None) - # Only forward real cloud user IDs — skip local/anon sessions - if _user_id in (None, "local", "local-dev") or (_user_id or "").startswith("anon-"): - _user_id = None - - with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool: - future_to_model = { - pool.submit( - _run_cftext, cforch_base, mid, prompt, system_ctx, temperature, - 180.0, _user_id, - ): mid - for mid in cftext_ids - } - for future in as_completed(future_to_model): - model_id = future_to_model[future] - try: - response, elapsed_ms, cold_started = future.result() - if cold_started: - yield _sse({"type": "model_coldstart", "model": model_id}) - result = { - "model": model_id, - "response": response, - "elapsed_ms": elapsed_ms, - "error": None, - } - except Exception as exc: - result = { - "model": model_id, - "response": "", - "elapsed_ms": 0, - "error": str(exc), - } - results.append(result) - yield _sse({"type": "model_done", **result}) - - yield _sse({"type": "complete", "results": results}) - - return StreamingResponse( - generate(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "X-Accel-Buffering": "no", - }, - ) - - -# ── POST /push-corrections ───────────────────────────────────────────────────── - -class ImitateResult(BaseModel): - model: str - response: str - elapsed_ms: int - error: str | None = None - - -class PushCorrectionsRequest(BaseModel): - product_id: str - prompt: str - results: list[ImitateResult] - - -@router.post("/push-corrections") -def push_corrections(req: PushCorrectionsRequest) -> dict: - """Append imitate results to sft_candidates.jsonl for human review.""" - if not req.prompt.strip(): - raise HTTPException(422, "prompt is required") - if not req.results: - raise HTTPException(422, "results list is empty") - - ts = datetime.now(timezone.utc).isoformat() - records = [] - for r in req.results: - if r.error or not r.response.strip(): - continue - records.append({ - "id": str(uuid.uuid4()), - "source": "imitate", - "product_id": req.product_id, - "prompt_messages": [{"role": "user", "content": req.prompt}], - "model_response": r.response, - "model_id": r.model, - "elapsed_ms": r.elapsed_ms, - "status": "pending", - "created_at": ts, - }) - - if not records: - raise HTTPException(422, "No non-error results to push") - - dest = _candidates_file() - dest.parent.mkdir(parents=True, exist_ok=True) - for record in records: - append_jsonl(dest, record) - - return {"pushed": len(records)} +"""Backward-compat shim -- logic moved to app/data/imitate.py.""" +from app.data.imitate import router # noqa: F401 +from app.data.imitate import set_config_dir, set_data_dir # noqa: F401 diff --git a/tests/test_imitate.py b/tests/test_imitate.py index c795b19..486fe09 100644 --- a/tests/test_imitate.py +++ b/tests/test_imitate.py @@ -1,4 +1,4 @@ -"""Tests for app/imitate.py — product registry, sample extraction, corrections push.""" +"""Tests for app/imitate.py -- product registry, sample extraction, corrections push.""" from __future__ import annotations import json @@ -9,10 +9,10 @@ import pytest from fastapi.testclient import TestClient from app.api import app -from app import imitate as _imitate_module +from app.data import imitate as _imitate_module -# ── Fixtures ─────────────────────────────────────────────────────────────────── +# -- Fixtures ------------------------------------------------------------------ @pytest.fixture(autouse=True) def reset_module_globals(tmp_path): @@ -70,7 +70,7 @@ def client() -> TestClient: return TestClient(app, raise_server_exceptions=True) -# ── GET /products ────────────────────────────────────────────────────────────── +# -- GET /products ------------------------------------------------------------- def test_products_empty_when_no_config(config_dir, client): """Returns empty list when label_tool.yaml has no imitate section.""" @@ -102,7 +102,7 @@ def test_products_offline_when_unreachable(cfg_with_products, client): assert all(not p["online"] for p in resp.json()["products"]) -# ── GET /products/{id}/sample ───────────────────────────────────────────────── +# -- GET /products/{id}/sample ------------------------------------------------- def test_sample_unknown_product(cfg_with_products, client): """Returns 404 for a product id not in config.""" @@ -149,7 +149,7 @@ def test_sample_404_on_empty_list(cfg_with_products, client): assert resp.status_code == 404 -# ── POST /push-corrections ───────────────────────────────────────────────────── +# -- POST /push-corrections ---------------------------------------------------- def test_push_corrections_appends_jsonl(cfg_with_products, data_dir, client): """Successful push writes records to sft_candidates.jsonl.""" @@ -214,7 +214,7 @@ def test_push_corrections_all_errors_422(cfg_with_products, data_dir, client): assert resp.status_code == 422 -# ── _extract_sample helper ───────────────────────────────────────────────────── +# -- _extract_sample helper ---------------------------------------------------- def test_extract_sample_list(): result = _imitate_module._extract_sample(