FastAPI microservice wrapping ByteDance/Dolphin-v2 (Qwen2.5-VL-3B base) for structured document extraction. Exposes POST /extract and GET /health. Maps Dolphin's 21 element types to cf-core's 7-type canonical schema. Services: cf-text /extract, /health. Env vars: CF_DOCUVISION_MODEL, CF_DOCUVISION_DEVICE, CF_DOCUVISION_PORT. GPU: 8GB+ VRAM required for Dolphin-v2; CPU fallback available but very slow.
275 lines
9.3 KiB
Python
275 lines
9.3 KiB
Python
# app/dolphin.py — Dolphin-v2 model wrapper
|
|
#
|
|
# Wraps ByteDance/Dolphin-v2 (Qwen2.5-VL-3B base) for document parsing.
|
|
# This module is the only place in the codebase that touches the Dolphin model
|
|
# directly. The FastAPI service (main.py) calls parse_document() and never
|
|
# imports transformers itself.
|
|
#
|
|
# Dolphin-v2 uses a two-stage pipeline:
|
|
# Stage 1: classify each page region (21 element types)
|
|
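# Stage 2: element-wise or holistic parsing depending on region type
# NOTE: this module issues a single generate() call per image (see
# DolphinParser._run_inference); it does not drive the two stages separately.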
|
|
#
|
|
# HuggingFace: https://huggingface.co/ByteDance/Dolphin-v2
|
|
# VRAM: ~8GB minimum, 16GB+ recommended for multi-page documents
|
|
from __future__ import annotations

import base64
import json
import logging
import os
import re
from dataclasses import dataclass, field
from io import BytesIO
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)


def _env_or_default(name: str, default: str) -> str:
    """Return environment variable *name* with surrounding whitespace removed.

    Falls back to *default* when the variable is unset OR set to a blank
    value. Plain ``os.environ.get(name, default)`` returns ``""`` for a line
    such as ``CF_DOCUVISION_MODEL=`` (a common .env / compose leftover),
    which would make from_pretrained() try to load a model named "".
    """
    return os.environ.get(name, "").strip() or default


# Model repo id (or local path) handed to from_pretrained().
_MODEL_ID = _env_or_default("CF_DOCUVISION_MODEL", "ByteDance/Dolphin-v2")
# "auto" (CUDA when available, else CPU) or an explicit torch device string.
# Lower-cased so "AUTO" / "CUDA" behave like "auto" / "cuda".
_DEVICE = _env_or_default("CF_DOCUVISION_DEVICE", "auto").lower()
|
|
|
|
# Dolphin-v2 element label -> cf-core StructuredDocument element type.
# Dolphin emits 21 labels; they collapse onto cf-core's 7 canonical types
# (heading, paragraph, list, table, figure, formula, code). Labels missing
# from this table are treated as "paragraph" by dolphin_to_cf_elements().
_TYPE_MAP: dict[str, str] = dict(
    title="heading",
    section_header="heading",
    paragraph="paragraph",
    caption="paragraph",
    footnote="paragraph",
    page_header="paragraph",
    page_footer="paragraph",
    list_item="list",
    table="table",
    figure="figure",
    figure_caption="paragraph",
    formula="formula",
    code="code",
    annotation="paragraph",
    abstract="paragraph",
    toc_item="list",
    reference="paragraph",
    equation="formula",
    watermark="paragraph",
    stamp="paragraph",
    signature="paragraph",
)
|
|
|
|
|
|
@dataclass
class DolphinElement:
    """One region detected by Dolphin-v2, before mapping to cf-core types.

    Fields:
        dolphin_type  Dolphin element label, e.g. "table" or "section_header".
        text          Extracted text content of the region.
        bbox          Optional [x0, y0, x1, y1] box; intended to be normalised
                      to 0-1 (not enforced here).
        html          Optional HTML rendering, meant for table elements only.
    """

    dolphin_type: str
    text: str
    bbox: list[float] | None = field(default=None)
    html: str | None = field(default=None)
|
|
|
|
|
|
@dataclass
class DolphinResult:
    """Everything Dolphin-v2 extracted from a single document image.

    Fields:
        elements  Detected regions, in the order the model emitted them.
        raw_text  Newline-joined text of the elements (filled in by the parser).
        model     Id of the model that produced the result; defaults to the
                  module-level _MODEL_ID.
    """

    elements: list[DolphinElement] = field(default_factory=lambda: [])
    raw_text: str = field(default="")
    model: str = field(default=_MODEL_ID)
|
|
|
|
|
|
class DolphinParser:
    """
    Dolphin-v2 document parser.

    Loaded once at service startup. parse() reads the processor and model but
    never reassigns instance attributes. This class takes no lock, so
    concurrent parse() calls run generate() on the shared model in parallel.
    NOTE(review): the original docstring called this thread-safe because the
    weights are read-only after loading — confirm that concurrent generate()
    calls on one device fit the deployment's memory budget.

    Usage:
        parser = DolphinParser.from_env()
        result = parser.parse(image_bytes, hint="auto")
    """

    # Upper bound on generated tokens per image; override per instance if
    # dense pages get truncated.
    max_new_tokens: int = 4096

    # Extra system-prompt sentence per extraction hint. "auto" and unknown
    # hints add nothing.
    _HINT_INSTRUCTIONS: dict[str, str] = {
        "table": " Focus on tables and render them as HTML.",
        "form": " Focus on form fields and key-value pairs.",
        "text": " Focus on text content, preserving heading hierarchy.",
    }

    # Tag-format fallback: <element type="..." bbox="[x0,y0,x1,y1]">body</element>
    _ELEMENT_RE = re.compile(
        r'<element\s+type="([^"]+)"(?:\s+bbox="([^"]+)")?\s*>(.*?)</element>',
        re.DOTALL,
    )

    def __init__(self, model_id: str = _MODEL_ID, device: str = _DEVICE) -> None:
        """Load the processor and model for *model_id* onto *device*.

        model_id  HuggingFace repo id or local path passed to from_pretrained().
        device    "auto" resolves to "cuda" when torch.cuda.is_available(),
                  otherwise "cpu"; any other value is used as given.

        Raises ImportError if torch or transformers is not installed.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoProcessor
        except ImportError as exc:
            raise ImportError(
                "torch and transformers are required. "
                "Install with: pip install -r requirements.txt"
            ) from exc

        logger.info("Loading Dolphin-v2 model %s", model_id)

        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cpu":
            logger.warning(
                "Dolphin-v2 running on CPU — performance will be very slow. "
                "8GB+ VRAM GPU strongly recommended."
            )

        self._processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        # NOTE(review): stock Qwen2.5-VL checkpoints load through an
        # image-text-to-text auto class; AutoModelForCausalLM relies on the
        # checkpoint's remote code registering it — confirm for Dolphin-v2.
        self._model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            # Accelerators get weights placed directly via device_map; on CPU
            # the model is loaded without a map and moved explicitly below.
            device_map=device if device != "cpu" else None,
            torch_dtype="auto",
        )
        if device == "cpu":
            self._model = self._model.to("cpu")

        self._model_id = model_id
        self._device = device
        logger.info("Dolphin-v2 loaded on %s", device)

    @classmethod
    def from_env(cls) -> "DolphinParser":
        """Build a parser from CF_DOCUVISION_MODEL / CF_DOCUVISION_DEVICE.

        Both variables are read once, when this module is imported.
        """
        return cls(model_id=_MODEL_ID, device=_DEVICE)

    def parse(self, image_bytes: bytes, hint: str = "auto") -> DolphinResult:
        """
        Parse a document image into structured elements.

        image_bytes  Raw image bytes in any format PIL can open (JPEG, PNG,
                     TIFF, ...). Only the first frame of multi-frame images
                     (e.g. multi-page TIFF) is parsed.
        hint         "auto" | "table" | "text" | "form" (case-insensitive).
                     Appends one steering sentence to the system prompt:
                     "table" asks for HTML tables, "form" for key-value pairs,
                     "text" for heading hierarchy. "auto" and unknown values
                     add nothing.

        Raises whatever PIL.Image.open raises for undecodable bytes.
        """
        from PIL import Image

        # Close the decoder handle; convert() returns an independent, loaded copy.
        with Image.open(BytesIO(image_bytes)) as source:
            image = source.convert("RGB")
        raw_output = self._run_inference(image, hint)
        return self._parse_output(raw_output)

    def parse_b64(self, image_b64: str, hint: str = "auto") -> DolphinResult:
        """Decode base64 image data and parse it (see parse()).

        Accepts bare base64 or a data URL ("data:image/png;base64,...").
        Without stripping, b64decode would silently decode the "data:..."
        prefix letters into garbage bytes.
        """
        payload = image_b64
        if isinstance(payload, str):
            payload = payload.strip()
            if payload.startswith("data:"):
                payload = payload.partition(",")[2]
        return self.parse(base64.b64decode(payload), hint=hint)

    def _run_inference(self, image: Any, hint: str) -> list[dict]:
        """Run one generate() pass over *image* and return raw element dicts.

        The decoded completion is handed to _parse_dolphin_output(); see there
        for the accepted output formats.
        """
        import torch

        hint_instruction = self._HINT_INSTRUCTIONS.get((hint or "").strip().lower(), "")

        # NOTE(review): stock Qwen2.5-VL processors expand
        # "<|vision_start|><|image_pad|><|vision_end|>", not a literal "<image>".
        # Confirm Dolphin-v2's processor handles "<image>", or build the prompt
        # with self._processor.apply_chat_template().
        prompt = (
            f"<|im_start|>system\nYou are a document parsing assistant.{hint_instruction}"
            f"<|im_end|>\n<|im_start|>user\n<image>\nParse this document.<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )

        inputs = self._processor(
            text=prompt,
            images=image,
            return_tensors="pt",
        ).to(self._model.device)

        with torch.no_grad():
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,  # greedy decoding -> deterministic output
            )

        # generate() returns prompt + completion; decode only the completion.
        input_len = inputs["input_ids"].shape[1]
        generated = output_ids[0][input_len:]
        raw_text = self._processor.decode(generated, skip_special_tokens=True)

        return self._parse_dolphin_output(raw_text)

    def _parse_dolphin_output(self, raw: str) -> list[dict]:
        """
        Parse Dolphin-v2's decoded text into element dicts.

        Formats are tried in order:
          1. A JSON array, optionally wrapped in a ``` / ```json fence.
             Object entries are returned as-is, and non-blank string entries
             become "paragraph" elements. Other entries are dropped.
          2. Tag format: <element type="paragraph" bbox="[x0,y0,x1,y1]">text</element>.
             Each match becomes {"type", "text", "bbox", "html"}. "html" holds
             the body for "table" elements and is None otherwise. A bbox that
             is not comma-separated floats becomes None.
          3. Otherwise, any non-blank output becomes a single "paragraph"
             element.

        Blank output yields an empty list.
        """
        from_json = self._try_parse_json_list(raw)
        if from_json is not None:
            return from_json

        elements: list[dict] = []
        for match in self._ELEMENT_RE.finditer(raw):
            el_type, bbox_str, body = match.groups()
            el_type = el_type.strip()
            body = body.strip()
            elements.append({
                "type": el_type,
                "text": body,
                "bbox": self._parse_bbox(bbox_str),
                # Decided on the stripped type, so type="table " still yields HTML.
                "html": body if el_type == "table" else None,
            })

        if not elements and raw.strip():
            # Last resort: treat entire output as a single paragraph
            elements = [{"type": "paragraph", "text": raw.strip(), "bbox": None, "html": None}]

        return elements

    @staticmethod
    def _try_parse_json_list(raw: str) -> list[dict] | None:
        """Return element dicts if *raw* is a JSON array, else None."""
        candidate = raw.strip()
        if candidate.startswith("```"):
            # Drop the opening fence line (``` or ```json) and a closing fence.
            candidate = candidate.partition("\n")[2].rstrip()
            if candidate.endswith("```"):
                candidate = candidate[:-3]
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        if not isinstance(parsed, list):
            return None

        # Non-object entries would crash _parse_output's el.get(...) calls.
        elements: list[dict] = []
        for item in parsed:
            if isinstance(item, dict):
                elements.append(item)
            elif isinstance(item, str) and item.strip():
                elements.append(
                    {"type": "paragraph", "text": item.strip(), "bbox": None, "html": None}
                )
        return elements

    @staticmethod
    def _parse_bbox(bbox_str: str | None) -> list[float] | None:
        """Parse "[x0,y0,x1,y1]" (brackets optional) into floats; None if absent or malformed."""
        if not bbox_str:
            return None
        try:
            return [float(x) for x in bbox_str.strip().strip("[]").split(",")]
        except ValueError:
            return None

    def _parse_output(self, raw_elements: list[dict]) -> DolphinResult:
        """Convert raw element dicts into a DolphinResult.

        Non-dict entries are skipped. A missing or null "type" becomes
        "paragraph", and a null "text" becomes "" (model JSON may contain
        nulls). raw_text joins the non-empty element texts with newlines.
        """
        elements: list[DolphinElement] = []
        for el in raw_elements:
            if not isinstance(el, dict):
                continue
            text_value = el.get("text")
            elements.append(DolphinElement(
                dolphin_type=str(el.get("type") or "paragraph").strip(),
                text="" if text_value is None else str(text_value).strip(),
                bbox=el.get("bbox"),
                html=el.get("html"),
            ))

        return DolphinResult(
            elements=elements,
            raw_text="\n".join(e.text for e in elements if e.text),
            model=self._model_id,
        )
|
|
|
|
|
|
def dolphin_to_cf_elements(result: DolphinResult) -> tuple[list[dict], list[dict]]:
    """
    Convert a DolphinResult into the cf-core StructuredDocument wire format.

    Returns (elements_list, tables_list), ready to JSON-serialise in the
    /extract response. Tables are kept separate from the other elements to
    match the DocuvisionClient._parse_response() contract:

    - an element that maps to "table" and carries HTML becomes
      {"html", "bbox"} in tables_list;
    - everything else, including a table with no HTML, becomes
      {"type", "text", "bbox"} in elements_list. Here "type" is the cf-core
      canonical type from _TYPE_MAP, and unknown Dolphin types fall back to
      "paragraph".

    Dolphin type names are matched case-insensitively, ignoring surrounding
    whitespace and treating "-" or " " as "_". For example, "Section-Header"
    maps like "section_header"; previously such variants silently became
    "paragraph".
    """
    elements: list[dict] = []
    tables: list[dict] = []

    for el in result.elements:
        key = (el.dolphin_type or "").strip().lower().replace("-", "_").replace(" ", "_")
        cf_type = _TYPE_MAP.get(key, "paragraph")

        if cf_type == "table" and el.html:
            tables.append({"html": el.html, "bbox": el.bbox})
        else:
            elements.append({"type": cf_type, "text": el.text, "bbox": el.bbox})

    return elements, tables
|