kiwi/app/services/ocr/vl_model.py
pyr0ball 22e57118df feat: add DocuvisionClient + cf-docuvision fast-path for OCR
Introduces a thin HTTP client for the cf-docuvision service and wires it
as a fast path in VisionLanguageOCR.extract_receipt_data(). When CF_ORCH_URL
is set, the pipeline attempts docuvision allocation via CFOrchClient before
loading the heavy local VLM; falls back gracefully if unavailable.
2026-04-02 12:33:05 -07:00

410 lines
13 KiB
Python

#!/usr/bin/env python
"""
Vision-Language Model service for receipt OCR and structured data extraction.
Uses Qwen2.5-VL-7B-Instruct for intelligent receipt processing that combines
OCR with understanding of receipt structure to extract structured JSON data.
"""
import json
import logging
import os
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from datetime import datetime
from PIL import Image
import torch
from transformers import (
Qwen2VLForConditionalGeneration,
AutoProcessor,
BitsAndBytesConfig
)
from app.core.config import settings
logger = logging.getLogger(__name__)
def _try_docuvision(image_path: str | Path) -> str | None:
"""Try to extract text via cf-docuvision. Returns None if unavailable."""
cf_orch_url = os.environ.get("CF_ORCH_URL")
if not cf_orch_url:
return None
try:
from circuitforge_core.resources import CFOrchClient
from app.services.ocr.docuvision_client import DocuvisionClient
client = CFOrchClient(cf_orch_url)
with client.allocate(
service="cf-docuvision",
model_candidates=[], # cf-docuvision has no model selection
ttl_s=60.0,
caller="kiwi-ocr",
) as alloc:
if alloc is None:
return None
doc_client = DocuvisionClient(alloc.url)
result = doc_client.extract_text(image_path)
return result.text if result.text else None
except Exception:
return None # graceful degradation
class VisionLanguageOCR:
    """Vision-Language Model for receipt OCR and structured extraction.

    Lazily loads a Qwen vision-language checkpoint on first use and turns a
    receipt image into a structured JSON dict (merchant, transaction, items,
    totals, confidence). When the cf-docuvision service is reachable,
    extraction short-circuits to that fast path and the local model is never
    loaded.
    """

    def __init__(self, use_quantization: bool = False):
        """
        Initialize the VLM OCR service.

        Args:
            use_quantization: Use 8-bit quantization to reduce memory usage
                (only honored when running on CUDA).
        """
        self.model = None
        self.processor = None
        # GPU is used only when both the hardware and app config allow it.
        self.device = "cuda" if torch.cuda.is_available() and settings.USE_GPU else "cpu"
        self.use_quantization = use_quantization
        # NOTE(review): checkpoint is Qwen2.5-VL but _load_model uses the
        # Qwen2VL model class — confirm the installed transformers version
        # resolves this architecture correctly.
        self.model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
        logger.info(f"Initializing VisionLanguageOCR with device: {self.device}")
        # Lazy loading - model will be loaded on first use
        self._model_loaded = False

    def _load_model(self):
        """Load the VLM model and processor (idempotent lazy loading).

        Raises:
            RuntimeError: If the model or processor cannot be loaded.
        """
        if self._model_loaded:
            return
        logger.info(f"Loading VLM model: {self.model_name}")
        try:
            if self.use_quantization and self.device == "cuda":
                # 8-bit quantization roughly halves GPU memory vs FP16.
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0
                )
                self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    low_cpu_mem_usage=True
                )
                logger.info("Model loaded with 8-bit quantization")
            else:
                # Standard loading: FP16 on GPU, FP32 on CPU.
                self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    device_map="auto" if self.device == "cuda" else None,
                    low_cpu_mem_usage=True
                )
                if self.device == "cpu":
                    self.model = self.model.to("cpu")
                logger.info(f"Model loaded in {'FP16' if self.device == 'cuda' else 'FP32'} mode")
            self.processor = AutoProcessor.from_pretrained(self.model_name)
            self.model.eval()  # Inference only; disables dropout etc.
            self._model_loaded = True
            logger.info("VLM model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load VLM model: {e}")
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Could not load VLM model: {e}") from e

    def extract_receipt_data(self, image_path: str) -> Dict[str, Any]:
        """
        Extract structured data from receipt image.

        Tries the cf-docuvision fast path first; otherwise runs the local
        VLM. Never raises — failures are reported inside the returned dict.

        Args:
            image_path: Path to the receipt image

        Returns:
            Dictionary containing extracted receipt data with structure:
            {
                "merchant": {...},
                "transaction": {...},
                "items": [...],
                "totals": {...},
                "confidence": {...},
                "raw_text": "...",
                "warnings": [...]
            }
        """
        # Try docuvision fast path first (skips heavy local VLM if available)
        docuvision_text = _try_docuvision(image_path)
        if docuvision_text is not None:
            # Fast path yields raw text only; structured fields stay empty
            # and confidence is unknown (None) rather than estimated.
            return {
                "raw_text": docuvision_text,
                "merchant": {},
                "transaction": {},
                "items": [],
                "totals": {},
                "confidence": {"overall": None},
                "warnings": [],
            }
        self._load_model()
        try:
            # Context manager releases the file handle even on failure
            # (previously leaked). convert('RGB') also forces a full load
            # and returns an independent copy, so the image stays usable
            # after the file is closed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            # Build extraction prompt
            prompt = self._build_extraction_prompt()
            logger.info(f"Processing receipt image: {image_path}")
            inputs = self.processor(
                images=image,
                text=prompt,
                return_tensors="pt"
            )
            if self.device == "cuda":
                # BUGFIX: cast only floating-point tensors to fp16. The old
                # code cast every tensor — including integer input_ids and
                # attention_mask — to float16, corrupting token ids.
                inputs = {
                    k: ((v.to("cuda", torch.float16) if v.is_floating_point()
                         else v.to("cuda"))
                        if isinstance(v, torch.Tensor) else v)
                    for k, v in inputs.items()
                }
            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    do_sample=False,  # Deterministic greedy decoding.
                    # temperature removed: transformers ignores (and warns
                    # about) sampling params when do_sample=False.
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )
            output_text = self.processor.decode(
                output_ids[0],
                skip_special_tokens=True
            )
            # generate() echoes the prompt; strip it from the completion.
            output_text = output_text.replace(prompt, "").strip()
            logger.info(f"VLM output length: {len(output_text)} characters")
            # Parse JSON from output
            result = self._parse_json_from_text(output_text)
            # Add raw text for reference
            result["raw_text"] = output_text
            # Validate and enhance result
            result = self._validate_result(result)
            return result
        except Exception as e:
            logger.error(f"Error extracting receipt data: {e}", exc_info=True)
            # Callers always receive the full schema, even on failure.
            return {
                "error": str(e),
                "merchant": {},
                "transaction": {},
                "items": [],
                "totals": {},
                "confidence": {"overall": 0.0},
                "warnings": [f"Extraction failed: {str(e)}"]
            }

    def _build_extraction_prompt(self) -> str:
        """Build the prompt for receipt data extraction."""
        return """You are a receipt OCR specialist. Extract all information from this receipt image and return it in the exact JSON format specified below.
Return a JSON object with this exact structure:
{
"merchant": {
"name": "Store Name",
"address": "123 Main St, City, State ZIP",
"phone": "555-1234"
},
"transaction": {
"date": "2025-10-30",
"time": "14:30:00",
"receipt_number": "12345",
"register": "01",
"cashier": "Jane"
},
"items": [
{
"name": "Product name",
"quantity": 2,
"unit_price": 10.99,
"total_price": 21.98,
"category": "grocery",
"tax_code": "F",
"discount": 0.00
}
],
"totals": {
"subtotal": 21.98,
"tax": 1.98,
"discount": 0.00,
"total": 23.96,
"payment_method": "Credit Card",
"amount_paid": 23.96,
"change": 0.00
},
"confidence": {
"overall": 0.95,
"merchant": 0.98,
"items": 0.92,
"totals": 0.97
}
}
Important instructions:
1. Extract ALL items from the receipt, no matter how many there are
2. Use null for fields you cannot find
3. For dates, use YYYY-MM-DD format
4. For times, use HH:MM:SS format
5. For prices, use numeric values (not strings)
6. Estimate confidence scores (0.0-1.0) based on image quality and text clarity
7. Return ONLY the JSON object, no other text or explanation"""

    def _parse_json_from_text(self, text: str) -> Dict[str, Any]:
        """
        Extract and parse JSON from model output text.

        Args:
            text: Raw text output from the model

        Returns:
            Parsed JSON dictionary, or an empty-schema dict with a low
            confidence score when no valid JSON can be recovered.
        """
        # Greedy match: content between the first { and the last }.
        json_match = re.search(r'\{.*\}', text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            try:
                return json.loads(json_str)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON: {e}")
                # Try to fix common issues
                json_str = self._fix_json(json_str)
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    logger.error("Could not parse JSON even after fixes")
        # Return empty structure if parsing fails
        logger.warning("No valid JSON found in output, returning empty structure")
        return {
            "merchant": {},
            "transaction": {},
            "items": [],
            "totals": {},
            "confidence": {"overall": 0.1}
        }

    def _fix_json(self, json_str: str) -> str:
        """Attempt to fix common JSON formatting issues.

        Best-effort heuristics only; the result may still be invalid.
        """
        # Remove trailing commas before closing braces/brackets.
        json_str = re.sub(r',\s*}', '}', json_str)
        json_str = re.sub(r',\s*]', ']', json_str)
        # Fix single quotes to double quotes.
        # NOTE(review): this also mangles apostrophes inside values
        # (e.g. "Trader Joe's") — acceptable as a last-ditch repair.
        json_str = json_str.replace("'", '"')
        return json_str

    def _validate_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate and enhance extracted data.

        Args:
            result: Extracted receipt data

        Returns:
            The same dict, with missing schema fields filled in, derived
            totals added, and a "warnings" list attached when issues are
            found.
        """
        warnings = []
        # Ensure required fields exist
        required_fields = ["merchant", "transaction", "items", "totals", "confidence"]
        for field in required_fields:
            if field not in result:
                result[field] = {} if field != "items" else []
                warnings.append(f"Missing required field: {field}")
        # Validate items
        if not result.get("items"):
            warnings.append("No items found on receipt")
        else:
            # Derive total_price where possible. Guard against JSON nulls:
            # the prompt instructs the model to emit null for unknowns, and
            # None * None would raise TypeError.
            for item in result["items"]:
                if "total_price" not in item and "unit_price" in item and "quantity" in item:
                    unit_price = item["unit_price"]
                    quantity = item["quantity"]
                    if isinstance(unit_price, (int, float)) and isinstance(quantity, (int, float)):
                        item["total_price"] = unit_price * quantity
        # Validate totals
        if result.get("items") and result.get("totals"):
            # `or 0` treats JSON nulls as zero instead of crashing sum()/abs().
            calculated_subtotal = sum(
                item.get("total_price") or 0
                for item in result["items"]
            )
            reported_subtotal = result["totals"].get("subtotal") or 0
            # Allow small variance (rounding errors)
            if abs(calculated_subtotal - reported_subtotal) > 0.10:
                warnings.append(
                    f"Total mismatch: calculated ${calculated_subtotal:.2f}, "
                    f"reported ${reported_subtotal:.2f}"
                )
            result["totals"]["calculated_subtotal"] = calculated_subtotal
        # Validate date format
        if result.get("transaction", {}).get("date"):
            try:
                datetime.strptime(result["transaction"]["date"], "%Y-%m-%d")
            except ValueError:
                warnings.append(f"Invalid date format: {result['transaction']['date']}")
        # Add warnings to result
        if warnings:
            result["warnings"] = warnings
        # Ensure confidence exists
        if "confidence" not in result or not result["confidence"]:
            result["confidence"] = {
                "overall": 0.5,
                "merchant": 0.5,
                "items": 0.5,
                "totals": 0.5
            }
        return result

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model and GPU memory usage."""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "quantization": self.use_quantization,
            "loaded": self._model_loaded,
            "gpu_available": torch.cuda.is_available(),
            "gpu_memory_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
            "gpu_memory_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
        }

    def clear_cache(self):
        """Clear GPU memory cache (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("GPU cache cleared")