def _cforch_url() -> str:
    """Return the cf-orch coordinator base URL.

    Reads ``coordinator_url`` from the cf-orch config and falls back to
    ``http://localhost:7700`` when the key is missing or empty.
    """
    cforch = _load_cforch_config()
    return cforch.get("coordinator_url") or "http://localhost:7700"


def _cforch_catalog(cforch_base: str, node_id: str = "heimdall") -> list[dict]:
    """Fetch the live cf-text catalog from cf-orch.

    Entries whose ``path`` contains ``://`` (proxy entries such as ollama://,
    vllm://, http://) are skipped: those models are served by their own
    services and should not be allocated via cf-text.

    Args:
        cforch_base: Base URL of the cf-orch coordinator; a trailing slash is
            tolerated.
        node_id: Value sent as the ``node_id`` query parameter. Defaults to
            the previously hard-coded ``"heimdall"``.

    Returns:
        A list of ``{"id", "vram_mb", "description"}`` dicts, one per
        non-proxy catalog entry. An empty list is returned (and a warning
        logged) when the catalog cannot be fetched or is not a JSON object.
    """
    url = f"{cforch_base.rstrip('/')}/api/services/cf-text/catalog"
    try:
        resp = httpx.get(url, params={"node_id": node_id}, timeout=5.0)
        resp.raise_for_status()
        raw = resp.json()
    except Exception as exc:
        # Best-effort: the catalog is optional, so any transport, status or
        # decode failure degrades to "no models" rather than an error.
        logger.warning("Could not fetch cf-orch catalog: %s", exc)
        return []

    if not isinstance(raw, dict):
        logger.warning(
            "Could not fetch cf-orch catalog: %s",
            f"unexpected payload type {type(raw).__name__}",
        )
        return []

    models: list[dict] = []
    for model_id, entry in raw.items():
        if not isinstance(entry, dict):
            continue
        # ``path`` may be absent or JSON null; normalise before the substring
        # test so one malformed entry cannot discard the whole catalog.
        path = entry.get("path") or ""
        # Skip proxy entries — they're routed through other services.
        if "://" in str(path):
            continue
        models.append({
            "id": model_id,
            "vram_mb": entry.get("vram_mb", 0),
            "description": entry.get("description", ""),
        })
    return models
def _fetch_image_b64(image_url: str) -> str:
    """Download an image URL and return it as a base64 string for ollama.

    Only ``http://`` and ``https://`` URLs are fetched. ``urlopen`` would
    otherwise also accept ``file:``, ``ftp:`` and ``data:`` URLs, and the URL
    reaches this function from a request parameter, so other schemes are
    rejected instead of being read and forwarded to a model.

    Args:
        image_url: Absolute URL of the image to download.

    Returns:
        The response body encoded as ASCII base64, or an empty string on any
        failure — a missing image is non-fatal; the model will still run
        against the text prompt alone.
    """
    if not image_url.lower().startswith(("http://", "https://")):
        logger.warning(
            "Failed to fetch image %s: %s", image_url, "unsupported URL scheme"
        )
        return ""
    try:
        req = Request(image_url, headers={"User-Agent": "Avocet/1.0"})
        with urlopen(req, timeout=10) as resp:
            return base64.b64encode(resp.read()).decode("ascii")
    except Exception as exc:
        # Deliberately broad: network, HTTP and decoding errors are all
        # treated as "no image available".
        logger.warning("Failed to fetch image %s: %s", image_url, exc)
        return ""
""" url = f"{ollama_base.rstrip('/')}/api/generate" - payload = json.dumps({ + body: dict = { "model": model_id, "prompt": prompt, "stream": False, "options": {"temperature": temperature}, - }).encode("utf-8") + } + if system: + body["system"] = system + if images: + body["images"] = images + payload = json.dumps(body).encode("utf-8") req = Request(url, data=payload, method="POST", headers={"Content-Type": "application/json"}) t0 = time.time() @@ -172,6 +249,122 @@ def _run_ollama_streaming( raise RuntimeError(str(exc)) from exc +def _run_cftext( + cforch_base: str, + model_id: str, + prompt: str, + system: str, + temperature: float, + startup_timeout_s: float = 180.0, +) -> tuple[str, int, bool]: + """Allocate cf-text via cf-orch, generate, release. Returns (response, elapsed_ms, cold_started). + + Raises RuntimeError on allocation failure or generation error. + cold_started=True means the service was launched from scratch (caller may log this). + + Cold-start detection uses coordinator state signals (running/stopped) rather than + polling the service health endpoint — this fails fast on model load errors instead + of waiting out the full timeout. 
+ """ + # Allocate + alloc_resp = httpx.post( + f"{cforch_base}/api/services/cf-text/allocate", + json={ + "model_candidates": [model_id], + "caller": "avocet", + "pipeline": "imitate", + }, + timeout=30.0, + ) + alloc_resp.raise_for_status() + data = alloc_resp.json() + service_url: str = data["url"] + allocation_id: str = data.get("allocation_id", "") + node_id: str = data.get("node_id", "") + gpu_id: int | None = data.get("gpu_id") + cold_started = data.get("started", False) and not data.get("warm", True) + + # Wait for ready using coordinator state signals + if cold_started: + deadline = time.monotonic() + startup_timeout_s + probe_misses = 0 + while time.monotonic() < deadline: + try: + status = httpx.get( + f"{cforch_base}/api/services/cf-text/status", timeout=5.0 + ) + if status.is_success: + instances = status.json().get("instances", []) + match = next( + (i for i in instances + if i.get("node_id") == node_id and i.get("gpu_id") == gpu_id), + None, + ) + if match: + probe_misses = 0 + state = match.get("state", "") + if state == "running": + break + elif state == "stopped": + if allocation_id: + httpx.delete( + f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", + timeout=5.0, + ) + raise RuntimeError(f"cf-text failed to load {model_id!r} (service stopped)") + else: + probe_misses += 1 + if probe_misses >= 6: + # Coordinator hasn't registered instance yet — fall back to health poll + try: + if httpx.get(f"{service_url}/health", timeout=3.0).is_success: + break + except Exception: + pass + except RuntimeError: + raise + except Exception: + pass + time.sleep(2.0) + else: + if allocation_id: + httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0) + raise RuntimeError(f"cf-text cold start timed out after {startup_timeout_s:.0f}s") + + # Generate + messages: list[dict] = [] + if system: + messages.append({"role": "system", "content": system}) + messages.append({"role": "user", "content": prompt}) + + t0 = 
time.time() + try: + gen_resp = httpx.post( + f"{service_url}/v1/chat/completions", + json={ + "model": model_id, + "messages": messages, + "max_tokens": 300, + "temperature": temperature, + "stream": False, + }, + timeout=120.0, + ) + gen_resp.raise_for_status() + elapsed_ms = int((time.time() - t0) * 1000) + content = gen_resp.json()["choices"][0]["message"]["content"] + return content.strip(), elapsed_ms, cold_started + except Exception as exc: + elapsed_ms = int((time.time() - t0) * 1000) + raise RuntimeError(str(exc)) from exc + finally: + if allocation_id: + try: + httpx.delete(f"{cforch_base}/api/services/cf-text/allocations/{allocation_id}", timeout=5.0) + except Exception: + pass + + # ── GET /products ────────────────────────────────────────────────────────────── @router.get("/products") @@ -226,52 +419,96 @@ def get_sample(product_id: str, index: int = 0) -> dict: raise HTTPException(502, f"Bad response from product API: {exc}") from exc text_fields = product.get("text_fields", []) or [] - extracted = _extract_sample(raw, text_fields, index) + sample_key = product.get("sample_key") or None + extracted = _extract_sample(raw, text_fields, index, sample_key=sample_key) if not extracted: raise HTTPException(404, "No sample items returned by product API") prompt_template = product.get("prompt_template", "{text}") prompt = prompt_template.replace("{text}", extracted["text"]) + # Also substitute any {field_name} placeholders from the raw item fields. + item = extracted.get("item", {}) + for field, val in item.items(): + prompt = prompt.replace(f"{{{field}}}", str(val) if val is not None else "") + + # Expose system_prompt and image_url if the product API returns them. 
@router.get("/catalog")
def get_catalog() -> dict:
    """Expose the cf-text models currently listed by the cf-orch coordinator.

    Returns a dict with a single ``"models"`` key holding the list produced
    by ``_cforch_catalog()`` for the configured coordinator URL.
    """
    coordinator_base = _cforch_url()
    return {"models": _cforch_catalog(coordinator_base)}
+ """ if not prompt.strip(): raise HTTPException(422, "prompt is required") - ids = [m.strip() for m in model_ids.split(",") if m.strip()] - if not ids: - raise HTTPException(422, "model_ids is required") + ollama_ids = [m.strip() for m in model_ids.split(",") if m.strip()] + cftext_ids = [m.strip() for m in cf_text_model_ids.split(",") if m.strip()] + if not ollama_ids and not cftext_ids: + raise HTTPException(422, "model_ids or cf_text_model_ids is required") cfg = _load_imitate_config() ollama_base = _ollama_url(cfg) + cforch_base = _cforch_url() + system_ctx = system.strip() or "" + total_models = len(ollama_ids) + len(cftext_ids) + + # Download image once before streaming — shared across ollama vision models + images: list[str] = [] + if image_url.strip(): + b64 = _fetch_image_b64(image_url.strip()) + if b64: + images = [b64] def generate(): results: list[dict] = [] - yield _sse({"type": "start", "total_models": len(ids)}) + yield _sse({"type": "start", "total_models": total_models, "has_image": bool(images)}) - for model_id in ids: - yield _sse({"type": "model_start", "model": model_id}) + # Ollama models + for model_id in ollama_ids: + yield _sse({"type": "model_start", "model": model_id, "service": "ollama"}) try: response, elapsed_ms = _run_ollama_streaming( - ollama_base, model_id, prompt, temperature + ollama_base, model_id, prompt, temperature, + system=system_ctx, images=images or None, ) result = { "model": model_id, @@ -289,6 +526,41 @@ def run_imitate( results.append(result) yield _sse({"type": "model_done", **result}) + # cf-text models via cf-orch — fan out in parallel when multiple models selected + if cftext_ids: + from concurrent.futures import ThreadPoolExecutor, as_completed + + # Announce all models upfront so the UI can show loading states immediately + for model_id in cftext_ids: + yield _sse({"type": "model_start", "model": model_id, "service": "cf-text"}) + + with ThreadPoolExecutor(max_workers=len(cftext_ids)) as pool: + future_to_model 
= { + pool.submit(_run_cftext, cforch_base, mid, prompt, system_ctx, temperature): mid + for mid in cftext_ids + } + for future in as_completed(future_to_model): + model_id = future_to_model[future] + try: + response, elapsed_ms, cold_started = future.result() + if cold_started: + yield _sse({"type": "model_coldstart", "model": model_id}) + result = { + "model": model_id, + "response": response, + "elapsed_ms": elapsed_ms, + "error": None, + } + except Exception as exc: + result = { + "model": model_id, + "response": "", + "elapsed_ms": 0, + "error": str(exc), + } + results.append(result) + yield _sse({"type": "model_done", **result}) + yield _sse({"type": "complete", "results": results}) return StreamingResponse( diff --git a/web/src/views/ImitateView.vue b/web/src/views/ImitateView.vue index fc2a92b..948859f 100644 --- a/web/src/views/ImitateView.vue +++ b/web/src/views/ImitateView.vue @@ -49,12 +49,30 @@
Fetching sample from API…