"""Thin HTTP client for the cf-docuvision document vision service."""
from __future__ import annotations

import base64
from dataclasses import dataclass
from pathlib import Path

import httpx

# One knob for both the sync and async request timeouts.
_REQUEST_TIMEOUT_S = 30.0


@dataclass
class DocuvisionResult:
    """Parsed response from a docuvision ``/extract`` call.

    Attributes:
        text: Extracted plain text ("" when the service returned none).
        confidence: Service-reported confidence, when provided.
        raw: Full decoded JSON payload, for callers needing extra fields.
    """

    text: str
    confidence: float | None = None
    raw: dict | None = None


class DocuvisionClient:
    """Thin client for the cf-docuvision service.

    Speaks a minimal JSON contract: POST ``/extract`` with a base64 image,
    receiving ``{"text": ..., "confidence": ...}``.
    """

    def __init__(self, base_url: str) -> None:
        # Normalize so f"{base}/extract" never yields a double slash.
        self._base_url = base_url.rstrip("/")

    @staticmethod
    def _encode_image(image_path: str | Path) -> str:
        """Read *image_path* and return its contents base64-encoded."""
        return base64.b64encode(Path(image_path).read_bytes()).decode()

    @staticmethod
    def _parse(data: dict) -> DocuvisionResult:
        """Map the service's decoded JSON payload onto a DocuvisionResult."""
        return DocuvisionResult(
            text=data.get("text", ""),
            confidence=data.get("confidence"),
            raw=data,
        )

    def extract_text(self, image_path: str | Path) -> DocuvisionResult:
        """Send an image to docuvision and return extracted text.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        b64 = self._encode_image(image_path)
        with httpx.Client(timeout=_REQUEST_TIMEOUT_S) as client:
            resp = client.post(
                f"{self._base_url}/extract",
                json={"image": b64},
            )
            resp.raise_for_status()
            data = resp.json()
        return self._parse(data)

    async def extract_text_async(self, image_path: str | Path) -> DocuvisionResult:
        """Async counterpart of :meth:`extract_text` (same contract).

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        b64 = self._encode_image(image_path)
        async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT_S) as client:
            resp = await client.post(
                f"{self._base_url}/extract",
                json={"image": b64},
            )
            resp.raise_for_status()
            data = resp.json()
        return self._parse(data)
import json import logging +import os import re from pathlib import Path from typing import Dict, Any, Optional, List @@ -26,6 +27,31 @@ from app.core.config import settings logger = logging.getLogger(__name__) +def _try_docuvision(image_path: str | Path) -> str | None: + """Try to extract text via cf-docuvision. Returns None if unavailable.""" + cf_orch_url = os.environ.get("CF_ORCH_URL") + if not cf_orch_url: + return None + try: + from circuitforge_core.resources import CFOrchClient + from app.services.ocr.docuvision_client import DocuvisionClient + + client = CFOrchClient(cf_orch_url) + with client.allocate( + service="cf-docuvision", + model_candidates=[], # cf-docuvision has no model selection + ttl_s=60.0, + caller="kiwi-ocr", + ) as alloc: + if alloc is None: + return None + doc_client = DocuvisionClient(alloc.url) + result = doc_client.extract_text(image_path) + return result.text if result.text else None + except Exception: + return None # graceful degradation + + class VisionLanguageOCR: """Vision-Language Model for receipt OCR and structured extraction.""" @@ -40,7 +66,7 @@ class VisionLanguageOCR: self.processor = None self.device = "cuda" if torch.cuda.is_available() and settings.USE_GPU else "cpu" self.use_quantization = use_quantization - self.model_name = "Qwen/Qwen2-VL-2B-Instruct" + self.model_name = "Qwen/Qwen2.5-VL-7B-Instruct" logger.info(f"Initializing VisionLanguageOCR with device: {self.device}") @@ -112,6 +138,19 @@ class VisionLanguageOCR: "warnings": [...] 
} """ + # Try docuvision fast path first (skips heavy local VLM if available) + docuvision_text = _try_docuvision(image_path) + if docuvision_text is not None: + return { + "raw_text": docuvision_text, + "merchant": {}, + "transaction": {}, + "items": [], + "totals": {}, + "confidence": {"overall": None}, + "warnings": [], + } + self._load_model() try: diff --git a/app/services/recipe/llm_recipe.py b/app/services/recipe/llm_recipe.py index a01ca72..5f8ff33 100644 --- a/app/services/recipe/llm_recipe.py +++ b/app/services/recipe/llm_recipe.py @@ -2,8 +2,12 @@ from __future__ import annotations import logging +import os +from contextlib import nullcontext from typing import TYPE_CHECKING +from openai import OpenAI + if TYPE_CHECKING: from app.db.store import Store @@ -113,76 +117,57 @@ class LLMRecipeGenerator: return "\n".join(lines) - def _acquire_vram_lease(self) -> str | None: - """Request a VRAM lease from the CF-core coordinator. Best-effort — returns None if unavailable.""" - try: - import httpx - from app.core.config import settings - from app.tasks.runner import VRAM_BUDGETS + _MODEL_CANDIDATES: list[str] = ["Ouro-2.6B-Thinking", "Ouro-1.4B"] - budget_mb = int(VRAM_BUDGETS.get("recipe_llm", 4.0) * 1024) - coordinator = settings.COORDINATOR_URL + def _get_llm_context(self): + """Return a sync context manager that yields an Allocation or None. 
- nodes_resp = httpx.get(f"{coordinator}/api/nodes", timeout=2.0) - if nodes_resp.status_code != 200: - return None - nodes = nodes_resp.json().get("nodes", []) - if not nodes: - return None - - best_node = best_gpu = best_free = None - for node in nodes: - for gpu in node.get("gpus", []): - free = gpu.get("vram_free_mb", 0) - if best_free is None or free > best_free: - best_node = node["node_id"] - best_gpu = gpu["gpu_id"] - best_free = free - if best_node is None: - return None - - lease_resp = httpx.post( - f"{coordinator}/api/leases", - json={ - "node_id": best_node, - "gpu_id": best_gpu, - "mb": budget_mb, - "service": "kiwi", - "priority": 5, - }, - timeout=3.0, - ) - if lease_resp.status_code == 200: - lease_id = lease_resp.json()["lease"]["lease_id"] - logger.debug("Acquired VRAM lease %s for recipe_llm (%d MB)", lease_id, budget_mb) - return lease_id - except Exception as exc: - logger.debug("VRAM lease acquire failed (non-fatal): %s", exc) - return None - - def _release_vram_lease(self, lease_id: str) -> None: - """Release a VRAM lease. Best-effort.""" - try: - import httpx - from app.core.config import settings - httpx.delete(f"{settings.COORDINATOR_URL}/api/leases/{lease_id}", timeout=3.0) - logger.debug("Released VRAM lease %s", lease_id) - except Exception as exc: - logger.debug("VRAM lease release failed (non-fatal): %s", exc) + When CF_ORCH_URL is set, uses CFOrchClient to acquire a vLLM allocation + (which handles service lifecycle and VRAM). Falls back to nullcontext(None) + when the env var is absent or CFOrchClient raises on construction. 
+ """ + cf_orch_url = os.environ.get("CF_ORCH_URL") + if cf_orch_url: + try: + from circuitforge_core.resources import CFOrchClient + client = CFOrchClient(cf_orch_url) + return client.allocate( + service="vllm", + model_candidates=self._MODEL_CANDIDATES, + ttl_s=300.0, + caller="kiwi-recipe", + ) + except Exception as exc: + logger.debug("CFOrchClient init failed, falling back to direct URL: %s", exc) + return nullcontext(None) def _call_llm(self, prompt: str) -> str: - """Call the LLM router with a VRAM lease held for the duration.""" - lease_id = self._acquire_vram_lease() + """Call the LLM, using CFOrchClient allocation when CF_ORCH_URL is set. + + With CF_ORCH_URL set: acquires a vLLM allocation via CFOrchClient and + calls the OpenAI-compatible API directly against the allocated service URL. + Without CF_ORCH_URL: falls back to LLMRouter using its configured backends. + """ try: - from circuitforge_core.llm.router import LLMRouter - router = LLMRouter() - return router.complete(prompt) + with self._get_llm_context() as alloc: + if alloc is not None: + base_url = alloc.url.rstrip("/") + "/v1" + client = OpenAI(base_url=base_url, api_key="any") + model = alloc.model or "__auto__" + if model == "__auto__": + model = client.models.list().data[0].id + resp = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + ) + return resp.choices[0].message.content or "" + else: + from circuitforge_core.llm.router import LLMRouter + router = LLMRouter() + return router.complete(prompt) except Exception as exc: logger.error("LLM call failed: %s", exc) return "" - finally: - if lease_id: - self._release_vram_lease(lease_id) def _parse_response(self, response: str) -> dict[str, str | list[str]]: """Parse LLM response text into structured recipe fields.""" diff --git a/tests/services/recipe/test_llm_recipe.py b/tests/services/recipe/test_llm_recipe.py index 0588722..06744e5 100644 --- a/tests/services/recipe/test_llm_recipe.py +++ 
b/tests/services/recipe/test_llm_recipe.py @@ -1,6 +1,11 @@ """Tests for LLMRecipeGenerator — prompt builders and allergy filtering.""" from __future__ import annotations +import os +from contextlib import contextmanager +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + import pytest from app.models.schemas.recipe import RecipeRequest @@ -139,3 +144,82 @@ def test_generate_returns_result_when_llm_responds(monkeypatch): assert len(suggestion.directions) > 0 assert "parmesan" in suggestion.notes.lower() assert result.element_gaps == ["Brightness"] + + +# --------------------------------------------------------------------------- +# CFOrchClient integration tests +# --------------------------------------------------------------------------- + +@dataclass +class _FakeAllocation: + allocation_id: str = "alloc-test-1" + service: str = "vllm" + node_id: str = "node-1" + gpu_id: int = 0 + model: str | None = "Ouro-2.6B-Thinking" + url: str = "http://test:8000" + started: bool = True + warm: bool = True + + +def test_recipe_gen_uses_cf_orch_when_env_set(monkeypatch): + """When CF_ORCH_URL is set, _call_llm uses alloc.url+/v1 as the OpenAI base_url.""" + from app.services.recipe.llm_recipe import LLMRecipeGenerator + + store = _make_store() + gen = LLMRecipeGenerator(store) + + fake_alloc = _FakeAllocation() + + @contextmanager + def _fake_llm_context(): + yield fake_alloc + + captured = {} + + # Fake OpenAI that records the base_url it was constructed with + class _FakeOpenAI: + def __init__(self, *, base_url, api_key): + captured["base_url"] = base_url + msg = MagicMock() + msg.content = "Title: Test\nIngredients: a\nDirections: do it.\nNotes: none." 
+ choice = MagicMock() + choice.message = msg + completion = MagicMock() + completion.choices = [choice] + self.chat = MagicMock() + self.chat.completions = MagicMock() + self.chat.completions.create = MagicMock(return_value=completion) + + # Patch _get_llm_context directly so no real HTTP call is made + monkeypatch.setattr(gen, "_get_llm_context", _fake_llm_context) + + with patch("app.services.recipe.llm_recipe.OpenAI", _FakeOpenAI): + gen._call_llm("make me a recipe") + + assert captured.get("base_url") == "http://test:8000/v1" + + +def test_recipe_gen_falls_back_without_cf_orch(monkeypatch): + """When CF_ORCH_URL is not set, _call_llm falls back to LLMRouter.""" + from app.services.recipe.llm_recipe import LLMRecipeGenerator + + store = _make_store() + gen = LLMRecipeGenerator(store) + + monkeypatch.delenv("CF_ORCH_URL", raising=False) + + router_called = {} + + def _fake_complete(prompt, **_kwargs): + router_called["prompt"] = prompt + return "Title: Direct\nIngredients: x\nDirections: go.\nNotes: ok." 
+ + fake_router = MagicMock() + fake_router.complete.side_effect = _fake_complete + + # Patch where LLMRouter is imported inside _call_llm + with patch("circuitforge_core.llm.router.LLMRouter", return_value=fake_router): + gen._call_llm("direct path prompt") + + assert router_called.get("prompt") == "direct path prompt" diff --git a/tests/test_services/__init__.py b/tests/test_services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_services/test_docuvision_client.py b/tests/test_services/test_docuvision_client.py new file mode 100644 index 0000000..bba0b70 --- /dev/null +++ b/tests/test_services/test_docuvision_client.py @@ -0,0 +1,87 @@ +"""Tests for DocuvisionClient and the _try_docuvision fast path.""" +from __future__ import annotations + +import base64 +from pathlib import Path +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from app.services.ocr.docuvision_client import DocuvisionClient, DocuvisionResult + + +# --------------------------------------------------------------------------- +# DocuvisionClient unit tests +# --------------------------------------------------------------------------- + + +def test_extract_text_sends_base64_image(tmp_path: Path) -> None: + """extract_text() POSTs a base64-encoded image and returns parsed text.""" + image_file = tmp_path / "test.jpg" + image_file.write_bytes(b"fake-image-bytes") + + mock_response = MagicMock() + mock_response.json.return_value = {"text": "Cheerios", "confidence": 0.95} + mock_response.raise_for_status.return_value = None + + with patch("httpx.Client") as mock_client_cls: + mock_client = MagicMock() + mock_client_cls.return_value.__enter__.return_value = mock_client + mock_client.post.return_value = mock_response + + client = DocuvisionClient("http://docuvision:8080") + result = client.extract_text(image_file) + + assert result.text == "Cheerios" + assert result.confidence == 0.95 + + mock_client.post.assert_called_once() + call_kwargs = 
mock_client.post.call_args + assert call_kwargs[0][0] == "http://docuvision:8080/extract" + posted_json = call_kwargs[1]["json"] + expected_b64 = base64.b64encode(b"fake-image-bytes").decode() + assert posted_json["image"] == expected_b64 + + +def test_extract_text_raises_on_http_error(tmp_path: Path) -> None: + """extract_text() propagates HTTP errors from the server.""" + image_file = tmp_path / "test.jpg" + image_file.write_bytes(b"fake-image-bytes") + + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "500 Internal Server Error", + request=MagicMock(), + response=MagicMock(), + ) + + with patch("httpx.Client") as mock_client_cls: + mock_client = MagicMock() + mock_client_cls.return_value.__enter__.return_value = mock_client + mock_client.post.return_value = mock_response + + client = DocuvisionClient("http://docuvision:8080") + with pytest.raises(httpx.HTTPStatusError): + client.extract_text(image_file) + + +# --------------------------------------------------------------------------- +# _try_docuvision fast-path tests +# --------------------------------------------------------------------------- + + +def test_try_docuvision_returns_none_without_cf_orch_url( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """_try_docuvision() returns None immediately when CF_ORCH_URL is not set.""" + monkeypatch.delenv("CF_ORCH_URL", raising=False) + + # Import after env manipulation so the function sees the unset var + from app.services.ocr.vl_model import _try_docuvision + + with patch("httpx.Client") as mock_client_cls: + result = _try_docuvision(tmp_path / "test.jpg") + + assert result is None + mock_client_cls.assert_not_called()