kiwi/app/services/ocr/vl_model.py
pyr0ball c166e5216a chore: initial commit — kiwi Phase 2 complete
Pantry tracker app with:
- FastAPI backend + Vue 3 SPA frontend
- SQLite via circuitforge-core (migrations 001-005)
- Inventory CRUD, barcode scan, receipt OCR pipeline
- Expiry prediction (deterministic + LLM fallback)
- CF-core tier system integration
- Cloud session support (menagerie)
2026-03-30 22:20:48 -07:00

371 lines
12 KiB
Python

#!/usr/bin/env python
"""
Vision-Language Model service for receipt OCR and structured data extraction.
Uses Qwen2-VL-2B-Instruct for intelligent receipt processing that combines
OCR with understanding of receipt structure to extract structured JSON data.
"""
import json
import logging
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from datetime import datetime
from PIL import Image
import torch
from transformers import (
Qwen2VLForConditionalGeneration,
AutoProcessor,
BitsAndBytesConfig
)
from app.core.config import settings
logger = logging.getLogger(__name__)
class VisionLanguageOCR:
    """Vision-Language Model for receipt OCR and structured extraction."""

    def __init__(self, use_quantization: bool = False,
                 model_name: str = "Qwen/Qwen2-VL-2B-Instruct"):
        """
        Initialize the VLM OCR service.

        Args:
            use_quantization: Use 8-bit quantization to reduce memory usage
                (only takes effect when running on CUDA; see _load_model).
            model_name: Hugging Face model id to load. Defaults to the
                Qwen2-VL 2B instruct checkpoint this service was built for.
        """
        self.model = None
        self.processor = None
        # GPU is used only when both the hardware and the app setting allow it.
        self.device = "cuda" if torch.cuda.is_available() and settings.USE_GPU else "cpu"
        self.use_quantization = use_quantization
        self.model_name = model_name
        logger.info(f"Initializing VisionLanguageOCR with device: {self.device}")
        # Lazy loading - model will be loaded on first use
        self._model_loaded = False
def _load_model(self):
    """Load the VLM weights and processor on first use (no-op if already loaded)."""
    if self._model_loaded:
        return
    logger.info(f"Loading VLM model: {self.model_name}")
    try:
        quantize = self.use_quantization and self.device == "cuda"
        if quantize:
            # 8-bit weights roughly halve GPU memory at a small quality cost.
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            logger.info("Model loaded with 8-bit quantization")
        else:
            # Full-precision path: FP16 on GPU, FP32 on CPU.
            on_gpu = self.device == "cuda"
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                device_map="auto" if on_gpu else None,
                low_cpu_mem_usage=True
            )
            if not on_gpu:
                self.model = self.model.to("cpu")
            logger.info(f"Model loaded in {'FP16' if self.device == 'cuda' else 'FP32'} mode")
        self.processor = AutoProcessor.from_pretrained(self.model_name)
        self.model.eval()  # inference only: disables dropout etc.
        self._model_loaded = True
        logger.info("VLM model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load VLM model: {e}")
        raise RuntimeError(f"Could not load VLM model: {e}")
def extract_receipt_data(self, image_path: str) -> Dict[str, Any]:
    """
    Extract structured data from receipt image.

    Args:
        image_path: Path to the receipt image

    Returns:
        Dictionary containing extracted receipt data with structure:
        {
            "merchant": {...},
            "transaction": {...},
            "items": [...],
            "totals": {...},
            "confidence": {...},
            "raw_text": "...",
            "warnings": [...]
        }
        On failure this never raises; it returns the same skeleton with an
        "error" key and zero confidence so callers get a structured payload.
    """
    self._load_model()
    try:
        prompt = self._build_extraction_prompt()
        logger.info(f"Processing receipt image: {image_path}")
        # Context manager guarantees the file handle is released even if
        # preprocessing fails (the original leaked it on error paths).
        with Image.open(image_path) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            inputs = self.processor(
                images=image,
                text=prompt,
                return_tensors="pt"
            )
        if self.device == "cuda":
            # BUG FIX: the previous code did v.to("cuda", torch.float16) on
            # every tensor, which cast input_ids/attention_mask to fp16 and
            # broke embedding lookup. Only floating-point tensors (pixel
            # values) may be cast; integer tensors keep their dtype.
            moved = {}
            for key, value in inputs.items():
                if isinstance(value, torch.Tensor):
                    value = value.to("cuda")
                    if value.is_floating_point():
                        value = value.to(torch.float16)
                moved[key] = value
            inputs = moved
        # Greedy decoding for reproducible output. temperature is omitted:
        # it is ignored when do_sample=False and recent transformers
        # versions reject temperature=0.0 outright.
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                do_sample=False,  # Deterministic for consistency
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
            )
        output_text = self.processor.decode(
            output_ids[0],
            skip_special_tokens=True
        )
        # The decoded sequence echoes the prompt; strip it out.
        output_text = output_text.replace(prompt, "").strip()
        logger.info(f"VLM output length: {len(output_text)} characters")
        result = self._parse_json_from_text(output_text)
        # Keep the raw model output for debugging/auditing.
        result["raw_text"] = output_text
        result = self._validate_result(result)
        return result
    except Exception as e:
        logger.error(f"Error extracting receipt data: {e}", exc_info=True)
        return {
            "error": str(e),
            "merchant": {},
            "transaction": {},
            "items": [],
            "totals": {},
            "confidence": {"overall": 0.0},
            "warnings": [f"Extraction failed: {str(e)}"]
        }
def _build_extraction_prompt(self) -> str:
    """Build the prompt for receipt data extraction."""
    # Static instruction prompt sent alongside the receipt image. The JSON
    # schema shown here is what _parse_json_from_text / _validate_result
    # expect back, so keep the two in sync if fields change. Note the prompt
    # explicitly tells the model to use null for missing fields — downstream
    # consumers must tolerate None in any value.
    return """You are a receipt OCR specialist. Extract all information from this receipt image and return it in the exact JSON format specified below.
Return a JSON object with this exact structure:
{
"merchant": {
"name": "Store Name",
"address": "123 Main St, City, State ZIP",
"phone": "555-1234"
},
"transaction": {
"date": "2025-10-30",
"time": "14:30:00",
"receipt_number": "12345",
"register": "01",
"cashier": "Jane"
},
"items": [
{
"name": "Product name",
"quantity": 2,
"unit_price": 10.99,
"total_price": 21.98,
"category": "grocery",
"tax_code": "F",
"discount": 0.00
}
],
"totals": {
"subtotal": 21.98,
"tax": 1.98,
"discount": 0.00,
"total": 23.96,
"payment_method": "Credit Card",
"amount_paid": 23.96,
"change": 0.00
},
"confidence": {
"overall": 0.95,
"merchant": 0.98,
"items": 0.92,
"totals": 0.97
}
}
Important instructions:
1. Extract ALL items from the receipt, no matter how many there are
2. Use null for fields you cannot find
3. For dates, use YYYY-MM-DD format
4. For times, use HH:MM:SS format
5. For prices, use numeric values (not strings)
6. Estimate confidence scores (0.0-1.0) based on image quality and text clarity
7. Return ONLY the JSON object, no other text or explanation"""
def _parse_json_from_text(self, text: str) -> Dict[str, Any]:
"""
Extract and parse JSON from model output text.
Args:
text: Raw text output from the model
Returns:
Parsed JSON dictionary
"""
# Try to find JSON object in text
# Look for content between first { and last }
json_match = re.search(r'\{.*\}', text, re.DOTALL)
if json_match:
json_str = json_match.group(0)
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON: {e}")
# Try to fix common issues
json_str = self._fix_json(json_str)
try:
return json.loads(json_str)
except json.JSONDecodeError:
logger.error("Could not parse JSON even after fixes")
# Return empty structure if parsing fails
logger.warning("No valid JSON found in output, returning empty structure")
return {
"merchant": {},
"transaction": {},
"items": [],
"totals": {},
"confidence": {"overall": 0.1}
}
def _fix_json(self, json_str: str) -> str:
"""Attempt to fix common JSON formatting issues."""
# Remove trailing commas
json_str = re.sub(r',\s*}', '}', json_str)
json_str = re.sub(r',\s*]', ']', json_str)
# Fix single quotes to double quotes
json_str = json_str.replace("'", '"')
return json_str
def _validate_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate and enhance extracted data.
Args:
result: Extracted receipt data
Returns:
Validated and enhanced result with warnings
"""
warnings = []
# Ensure required fields exist
required_fields = ["merchant", "transaction", "items", "totals", "confidence"]
for field in required_fields:
if field not in result:
result[field] = {} if field != "items" else []
warnings.append(f"Missing required field: {field}")
# Validate items
if not result.get("items"):
warnings.append("No items found on receipt")
else:
# Validate item structure
for i, item in enumerate(result["items"]):
if "total_price" not in item and "unit_price" in item and "quantity" in item:
item["total_price"] = item["unit_price"] * item["quantity"]
# Validate totals
if result.get("items") and result.get("totals"):
calculated_subtotal = sum(
item.get("total_price", 0)
for item in result["items"]
)
reported_subtotal = result["totals"].get("subtotal", 0)
# Allow small variance (rounding errors)
if abs(calculated_subtotal - reported_subtotal) > 0.10:
warnings.append(
f"Total mismatch: calculated ${calculated_subtotal:.2f}, "
f"reported ${reported_subtotal:.2f}"
)
result["totals"]["calculated_subtotal"] = calculated_subtotal
# Validate date format
if result.get("transaction", {}).get("date"):
try:
datetime.strptime(result["transaction"]["date"], "%Y-%m-%d")
except ValueError:
warnings.append(f"Invalid date format: {result['transaction']['date']}")
# Add warnings to result
if warnings:
result["warnings"] = warnings
# Ensure confidence exists
if "confidence" not in result or not result["confidence"]:
result["confidence"] = {
"overall": 0.5,
"merchant": 0.5,
"items": 0.5,
"totals": 0.5
}
return result
def get_model_info(self) -> Dict[str, Any]:
"""Get information about the loaded model."""
return {
"model_name": self.model_name,
"device": self.device,
"quantization": self.use_quantization,
"loaded": self._model_loaded,
"gpu_available": torch.cuda.is_available(),
"gpu_memory_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
"gpu_memory_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
}
def clear_cache(self):
    """Release cached GPU memory back to the driver, if a GPU is present."""
    if not torch.cuda.is_available():
        return  # nothing to do on CPU-only hosts
    torch.cuda.empty_cache()
    logger.info("GPU cache cleared")