feat: initial cf-docuvision service — Dolphin-v2 document parsing
FastAPI microservice wrapping ByteDance/Dolphin-v2 (Qwen2.5-VL-3B base) for structured document extraction. Exposes POST /extract and GET /health. Maps Dolphin's 21 element types to cf-core's 7-type canonical schema. Services: cf-text /extract, /health. Env vars: CF_DOCUVISION_MODEL, CF_DOCUVISION_DEVICE, CF_DOCUVISION_PORT. GPU: 8GB+ VRAM required for Dolphin-v2; CPU fallback available but very slow.
This commit is contained in:
commit
47d4dfc786
9 changed files with 661 additions and 0 deletions
15
.env.example
Normal file
15
.env.example
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
# cf-docuvision environment — copy to .env and fill in values
|
||||||
|
|
||||||
|
# Model to load. Default: ByteDance/Dolphin-v2 (downloaded from HuggingFace on first run).
|
||||||
|
# Set to a local path to skip the download: /Library/Assets/LLM/dolphin-v2/
|
||||||
|
CF_DOCUVISION_MODEL=ByteDance/Dolphin-v2
|
||||||
|
|
||||||
|
# Compute device. "auto" detects CUDA if available, falls back to CPU.
|
||||||
|
# CPU is very slow for Dolphin-v2 — 8GB+ VRAM GPU strongly recommended.
|
||||||
|
CF_DOCUVISION_DEVICE=auto
|
||||||
|
|
||||||
|
# Service port (default matches CF_DOCUVISION_URL default in cf-core)
|
||||||
|
CF_DOCUVISION_PORT=8003
|
||||||
|
|
||||||
|
# Log level
|
||||||
|
LOG_LEVEL=INFO
|
||||||
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.egg-info/
|
||||||
|
.env
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
*.egg
|
||||||
23
Dockerfile
Normal file
23
Dockerfile
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
# cf-docuvision image: CUDA 12.1 runtime base with Python 3.11 and the
# FastAPI service installed under /app.
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1

# NOTE(review): the libglib/libsm/libxext/libxrender packages are presumably
# runtime dependencies of the imaging stack — confirm they are still needed.
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.11 python3.11-dev python3-pip git \
    libglib2.0-0 libsm6 libxext6 libxrender-dev \
    && rm -rf /var/lib/apt/lists/*

# Make `python3` and `python` resolve to Python 3.11.
RUN ln -sf python3.11 /usr/bin/python3 && ln -sf python3 /usr/bin/python

WORKDIR /app

# Install dependencies before copying the source so this layer is cached
# across code-only changes.
COPY requirements.txt .
RUN pip install --upgrade pip && pip install -r requirements.txt

COPY . .

# Default service port; override at runtime with CF_DOCUVISION_PORT.
EXPOSE 8003

# Run through `sh -c` so CF_DOCUVISION_PORT is expanded at container start
# (falling back to 8003); `exec` keeps uvicorn as PID 1 so it receives signals.
CMD ["sh", "-c", "exec python -m uvicorn app.main:app --host 0.0.0.0 --port \"${CF_DOCUVISION_PORT:-8003}\""]
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
275
app/dolphin.py
Normal file
275
app/dolphin.py
Normal file
|
|
@ -0,0 +1,275 @@
|
||||||
|
# app/dolphin.py — Dolphin-v2 model wrapper
|
||||||
|
#
|
||||||
|
# Wraps ByteDance/Dolphin-v2 (Qwen2.5-VL-3B base) for document parsing.
|
||||||
|
# This module is the only place in the codebase that touches the Dolphin model
|
||||||
|
# directly. The FastAPI service (main.py) calls parse_document() and never
|
||||||
|
# imports transformers itself.
|
||||||
|
#
|
||||||
|
# Dolphin-v2 uses a two-stage pipeline:
|
||||||
|
# Stage 1: classify each page region (21 element types)
|
||||||
|
# Stage 2: element-wise or holistic parsing depending on region type
|
||||||
|
#
|
||||||
|
# HuggingFace: https://huggingface.co/ByteDance/Dolphin-v2
|
||||||
|
# VRAM: ~8GB minimum, 16GB+ recommended for multi-page documents
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MODEL_ID = os.environ.get("CF_DOCUVISION_MODEL", "ByteDance/Dolphin-v2")
|
||||||
|
_DEVICE = os.environ.get("CF_DOCUVISION_DEVICE", "auto")
|
||||||
|
|
||||||
|
# Dolphin-v2 element type → StructuredDocument element type mapping
|
||||||
|
# Dolphin outputs 21 types; we map to cf-core's canonical 7 (+passthrough)
|
||||||
|
_TYPE_MAP: dict[str, str] = {
|
||||||
|
"title": "heading",
|
||||||
|
"section_header": "heading",
|
||||||
|
"paragraph": "paragraph",
|
||||||
|
"caption": "paragraph",
|
||||||
|
"footnote": "paragraph",
|
||||||
|
"page_header": "paragraph",
|
||||||
|
"page_footer": "paragraph",
|
||||||
|
"list_item": "list",
|
||||||
|
"table": "table",
|
||||||
|
"figure": "figure",
|
||||||
|
"figure_caption": "paragraph",
|
||||||
|
"formula": "formula",
|
||||||
|
"code": "code",
|
||||||
|
"annotation": "paragraph",
|
||||||
|
"abstract": "paragraph",
|
||||||
|
"toc_item": "list",
|
||||||
|
"reference": "paragraph",
|
||||||
|
"equation": "formula",
|
||||||
|
"watermark": "paragraph",
|
||||||
|
"stamp": "paragraph",
|
||||||
|
"signature": "paragraph",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DolphinElement:
    """A single region reported by Dolphin-v2, before mapping to cf-core types."""

    # Element label as emitted by Dolphin (for example "title" or "table").
    dolphin_type: str
    # Text content of the region.
    text: str
    # Bounding box as [x0, y0, x1, y1], normalised to 0-1; None when absent.
    bbox: list[float] | None = None
    # HTML rendering of the region; used for table elements only.
    html: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DolphinResult:
    """Everything extracted from a single document image."""

    # Detected regions; a fresh empty list per instance by default.
    elements: list[DolphinElement] = field(default_factory=list)
    # Plain-text rendering of the document; empty when nothing was extracted.
    raw_text: str = ""
    # Identifier of the model that produced the result (defaults to _MODEL_ID).
    model: str = _MODEL_ID
|
||||||
|
|
||||||
|
|
||||||
|
class DolphinParser:
    """
    Dolphin-v2 document parser.

    The processor and model are loaded once in ``__init__`` (intended to run
    at service startup); ``parse`` / ``parse_b64`` then run inference on one
    image at a time.

    NOTE(review): the original author describes this class as thread-safe for
    concurrent requests because the weights are read-only after loading.
    Confirm that concurrent ``generate()`` calls are acceptable on the
    target device before relying on that.

    Usage:
        parser = DolphinParser.from_env()
        result = parser.parse(image_bytes, hint="auto")
    """

    def __init__(self, model_id: str = _MODEL_ID, device: str = _DEVICE) -> None:
        """
        Load the processor and model.

        model_id    HuggingFace model id or local path.
        device      "auto" picks "cuda" when available, otherwise "cpu";
                    any other value is passed through as the device map.

        Raises ImportError when torch or transformers is not installed.
        """
        # Imported lazily so the module can be imported (e.g. by tests)
        # without the heavy ML dependencies installed.
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoProcessor
        except ImportError as exc:
            raise ImportError(
                "torch and transformers are required. "
                "Install with: pip install -r requirements.txt"
            ) from exc

        logger.info("Loading Dolphin-v2 model %s", model_id)

        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cpu":
            logger.warning(
                "Dolphin-v2 running on CPU — performance will be very slow. "
                "8GB+ VRAM GPU strongly recommended."
            )

        self._processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self._model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            # No device map on CPU; the model is moved explicitly below.
            device_map=device if device != "cpu" else None,
            torch_dtype="auto",
        )
        if device == "cpu":
            self._model = self._model.to("cpu")

        self._model_id = model_id
        self._device = device
        logger.info("Dolphin-v2 loaded on %s", device)

    @classmethod
    def from_env(cls) -> "DolphinParser":
        """Build a parser from CF_DOCUVISION_MODEL / CF_DOCUVISION_DEVICE."""
        return cls(model_id=_MODEL_ID, device=_DEVICE)

    def parse(self, image_bytes: bytes, hint: str = "auto") -> DolphinResult:
        """
        Parse a document image into structured elements.

        image_bytes  Raw image bytes in any format Pillow can open.
        hint         Extraction focus: "auto" | "table" | "text" | "form"
                     Passed as context in the Dolphin prompt. "table" prioritises
                     HTML table rendering; "form" prioritises key-value pairs.
        """
        from PIL import Image

        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        raw_output = self._run_inference(image, hint)
        return self._parse_output(raw_output)

    def parse_b64(self, image_b64: str, hint: str = "auto") -> DolphinResult:
        """Convenience wrapper for base64-encoded image bytes."""
        return self.parse(base64.b64decode(image_b64), hint=hint)

    def _run_inference(self, image: Any, hint: str) -> list[dict]:
        """Run the model on one image and return raw element dicts."""
        import torch

        # Unknown hints (including "auto") add no extra instruction.
        hint_instruction = {
            "table": " Focus on tables and render them as HTML.",
            "form": " Focus on form fields and key-value pairs.",
            "text": " Focus on text content, preserving heading hierarchy.",
        }.get(hint, "")

        prompt = (
            f"<|im_start|>system\nYou are a document parsing assistant.{hint_instruction}"
            "<|im_end|>\n<|im_start|>user\n<image>\nParse this document.<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        inputs = self._processor(
            text=prompt,
            images=image,
            return_tensors="pt",
        ).to(self._model.device)

        with torch.no_grad():
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=4096,
                do_sample=False,
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        input_len = inputs["input_ids"].shape[1]
        generated = output_ids[0][input_len:]
        raw_text = self._processor.decode(generated, skip_special_tokens=True)

        return self._parse_dolphin_output(raw_text)

    def _parse_dolphin_output(self, raw: str) -> list[dict]:
        """
        Parse the model's text output into element dicts.

        Three formats are tried in order:
          1. a JSON array of objects (returned as-is, non-object items dropped);
          2. ``<element type="..." bbox="[...]">text</element>`` tags, yielding
             dicts with keys: type, text, bbox, html;
          3. any other non-blank text becomes a single paragraph element.

        Returns an empty list for blank output.
        """
        import json
        import re

        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, ValueError):
            parsed = None

        if isinstance(parsed, list):
            # Downstream code calls .get() on every element, so keep objects only.
            dict_items = [item for item in parsed if isinstance(item, dict)]
            # An empty array is a valid "nothing found" answer; an array holding
            # only scalars is not element data and falls through to the fallbacks.
            if dict_items or not parsed:
                return dict_items

        elements: list[dict] = []

        # Fallback: regex extraction of element tags.
        pattern = re.compile(
            r'<element\s+type="([^"]+)"(?:\s+bbox="([^"]+)")?>(.*?)</element>',
            re.DOTALL,
        )
        for match in pattern.finditer(raw):
            el_type, bbox_str, text = match.groups()
            # Normalise once so the stored type and the table check agree.
            el_type = el_type.strip()
            text = text.strip()
            bbox = None
            if bbox_str:
                try:
                    bbox = [float(x) for x in bbox_str.strip("[]").split(",")]
                except ValueError:
                    # Malformed coordinates: keep the element, drop the bbox.
                    bbox = None
            elements.append({
                "type": el_type,
                "text": text,
                "bbox": bbox,
                "html": text if el_type == "table" else None,
            })

        if not elements and raw.strip():
            # Last resort: treat entire output as a single paragraph.
            elements = [{"type": "paragraph", "text": raw.strip(), "bbox": None, "html": None}]

        return elements

    def _parse_output(self, raw_elements: list[dict]) -> DolphinResult:
        """
        Convert raw element dicts into a DolphinResult.

        Missing or null "type" falls back to "paragraph"; missing or null
        "text" becomes an empty string. Items that are not dicts are skipped.
        ``raw_text`` is the newline-joined text of all non-empty elements.
        """
        elements: list[DolphinElement] = []
        texts: list[str] = []

        for el in raw_elements:
            if not isinstance(el, dict):
                continue
            # JSON output may carry explicit nulls, so `or` rather than a
            # .get() default, which only covers absent keys.
            dolphin_type = str(el.get("type") or "").strip() or "paragraph"
            text = str(el.get("text") or "").strip()
            elements.append(DolphinElement(
                dolphin_type=dolphin_type,
                text=text,
                bbox=el.get("bbox"),
                html=el.get("html"),
            ))
            if text:
                texts.append(text)

        return DolphinResult(
            elements=elements,
            raw_text="\n".join(texts),
            model=self._model_id,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def dolphin_to_cf_elements(result: DolphinResult) -> tuple[list[dict], list[dict]]:
    """
    Split a DolphinResult into cf-core StructuredDocument wire-format lists.

    Returns (elements_list, tables_list). Each Dolphin type is translated via
    _TYPE_MAP, with unknown types treated as "paragraph". An element goes to
    tables_list only when it maps to "table" AND carries HTML; a table without
    HTML stays in elements_list as a "table"-typed text element. The two-list
    shape follows the DocuvisionClient._parse_response() contract.
    """
    cf_elements: list[dict] = []
    cf_tables: list[dict] = []

    for item in result.elements:
        mapped_type = _TYPE_MAP.get(item.dolphin_type, "paragraph")
        is_html_table = mapped_type == "table" and bool(item.html)

        if is_html_table:
            cf_tables.append({"html": item.html, "bbox": item.bbox})
            continue

        cf_elements.append({
            "type": mapped_type,
            "text": item.text,
            "bbox": item.bbox,
        })

    return cf_elements, cf_tables
|
||||||
113
app/main.py
Normal file
113
app/main.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
# app/main.py — cf-docuvision FastAPI service
|
||||||
|
#
|
||||||
|
# Exposes POST /extract and GET /health.
|
||||||
|
# Response schema matches DocuvisionClient._parse_response() in cf-core.
|
||||||
|
#
|
||||||
|
# Start:
|
||||||
|
# uvicorn app.main:app --host 0.0.0.0 --port 8003
|
||||||
|
# CF_DOCUVISION_DEVICE=cuda uvicorn app.main:app ...
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.dolphin import DolphinParser, dolphin_to_cf_elements
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
|
||||||
|
|
||||||
|
# ── Model lifecycle ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_parser: DolphinParser | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan: load the model before serving, drop it afterwards.

    Sets the module-level ``_parser`` from the environment on startup. The
    reference is cleared in a ``finally`` so it is released even when an
    exception is thrown into the context manager.
    """
    global _parser
    logger.info("cf-docuvision: loading Dolphin-v2...")
    _parser = DolphinParser.from_env()
    logger.info("cf-docuvision: ready")
    try:
        yield
    finally:
        _parser = None
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="cf-docuvision",
|
||||||
|
description="Dolphin-v2 document parsing service for CircuitForge products.",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request / Response schemas ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ExtractRequest(BaseModel):
    """Request body for POST /extract."""

    # Base64-encoded image bytes.
    image_b64: str
    # Extraction focus; the endpoint accepts "auto", "table", "text" or "form".
    hint: str = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractResponse(BaseModel):
    """Response body for POST /extract."""

    # Non-table elements, each a dict with "type", "text" and "bbox".
    elements: list[dict[str, Any]]
    # Table elements, each a dict with "html" and "bbox".
    tables: list[dict[str, Any]]
    # Plain-text rendering of the extracted document.
    raw_text: str
    # Request bookkeeping: source, model, hint and elapsed_ms.
    metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@app.get("/health")
def health():
    """Report readiness: 200 with the model id once loaded, 503 before that."""
    parser = _parser
    if parser is not None:
        return {"status": "ok", "model": parser._model_id}
    raise HTTPException(status_code=503, detail="Model not loaded")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/extract", response_model=ExtractResponse)
def extract(req: ExtractRequest):
    """
    Parse a document image into structured elements.

    Request body:
      image_b64   Base64-encoded image bytes (e.g. JPEG, PNG, TIFF)
      hint        Extraction focus: "auto" | "table" | "text" | "form"

    Responses:
      200   ExtractResponse (elements, tables, raw_text, metadata)
      400   image_b64 is not valid base64
      422   hint is not one of the accepted values
      500   any other failure while parsing
      503   the model has not been loaded yet

    Response matches the DocuvisionClient._parse_response() contract in cf-core.
    """
    # Local import: only this handler needs binascii.
    import binascii

    if _parser is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if req.hint not in ("auto", "table", "text", "form"):
        raise HTTPException(status_code=422, detail=f"Invalid hint {req.hint!r}")

    t0 = time.monotonic()
    try:
        result = _parser.parse_b64(req.image_b64, hint=req.hint)
    except binascii.Error as exc:
        # Malformed base64 is the caller's mistake, not a server fault.
        raise HTTPException(
            status_code=400,
            detail=f"image_b64 is not valid base64: {exc}",
        ) from exc
    except Exception as exc:
        # Top-level boundary: log the traceback and surface a 500.
        logger.exception("cf-docuvision: parse failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    elements, tables = dolphin_to_cf_elements(result)
    elapsed_ms = round((time.monotonic() - t0) * 1000)
    logger.info(
        "cf-docuvision: extracted %d elements, %d tables in %dms",
        len(elements), len(tables), elapsed_ms,
    )

    return ExtractResponse(
        elements=elements,
        tables=tables,
        raw_text=result.raw_text,
        metadata={
            "source": "cf-docuvision",
            "model": result.model,
            "hint": req.hint,
            "elapsed_ms": elapsed_ms,
        },
    )
|
||||||
26
compose.yml
Normal file
26
compose.yml
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
services:
|
||||||
|
cf-docuvision:
|
||||||
|
build: .
|
||||||
|
network_mode: host
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
CF_DOCUVISION_PORT: "8003"
|
||||||
|
volumes:
|
||||||
|
# Cache HuggingFace model weights across rebuilds
|
||||||
|
- ${HOME}/.cache/huggingface:/root/.cache/huggingface
|
||||||
|
# Optional: mount a local model path to skip HF download
|
||||||
|
# - /Library/Assets/LLM/dolphin-v2:/models/dolphin-v2:ro
|
||||||
|
restart: unless-stopped
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8003/health')"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 120s
|
||||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
fastapi>=0.110
|
||||||
|
uvicorn[standard]>=0.29
|
||||||
|
pydantic>=2.0
|
||||||
|
torch>=2.0
|
||||||
|
transformers>=4.40
|
||||||
|
accelerate>=0.27
|
||||||
|
Pillow>=10.0
|
||||||
|
python-multipart>=0.0.9
|
||||||
190
tests/test_dolphin.py
Normal file
190
tests/test_dolphin.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
"""Tests for cf-docuvision — mock inference path only (no GPU required)."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from app.dolphin import (
|
||||||
|
DolphinElement,
|
||||||
|
DolphinParser,
|
||||||
|
DolphinResult,
|
||||||
|
dolphin_to_cf_elements,
|
||||||
|
_TYPE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTypeMap:
    """Checks on the Dolphin -> cf-core element type table."""

    def test_all_21_dolphin_types_mapped(self):
        # The map must cover exactly these 21 Dolphin-v2 labels.
        dolphin_types = [
            "title", "section_header", "paragraph", "caption", "footnote",
            "page_header", "page_footer", "list_item", "table", "figure",
            "figure_caption", "formula", "code", "annotation", "abstract",
            "toc_item", "reference", "equation", "watermark", "stamp", "signature",
        ]
        assert set(_TYPE_MAP.keys()) == set(dolphin_types)

    def test_table_maps_to_table(self):
        mapped = _TYPE_MAP["table"]
        assert mapped == "table"

    def test_title_maps_to_heading(self):
        mapped = _TYPE_MAP["title"]
        assert mapped == "heading"

    def test_formula_maps_to_formula(self):
        mapped = _TYPE_MAP["formula"]
        assert mapped == "formula"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDolphinToCfElements:
    """Checks on how dolphin_to_cf_elements splits elements and tables."""

    def _make_result(self, elements: list[DolphinElement]) -> DolphinResult:
        """Wrap elements in a DolphinResult with newline-joined raw text."""
        non_empty_texts = [item.text for item in elements if item.text]
        return DolphinResult(elements=elements, raw_text="\n".join(non_empty_texts))

    def test_paragraph_goes_to_elements(self):
        paragraph = DolphinElement("paragraph", "Hello world")
        elements, tables = dolphin_to_cf_elements(self._make_result([paragraph]))
        assert len(elements) == 1
        only = elements[0]
        assert only["type"] == "paragraph"
        assert only["text"] == "Hello world"
        assert tables == []

    def test_table_with_html_goes_to_tables(self):
        table = DolphinElement(
            "table", "col1 col2", html="<table><tr><td>A</td></tr></table>"
        )
        elements, tables = dolphin_to_cf_elements(self._make_result([table]))
        assert len(tables) == 1
        assert "<table>" in tables[0]["html"]
        assert elements == []

    def test_table_without_html_goes_to_elements(self):
        table = DolphinElement("table", "some table text", html=None)
        elements, tables = dolphin_to_cf_elements(self._make_result([table]))
        assert len(elements) == 1
        assert tables == []

    def test_bbox_preserved(self):
        boxed = DolphinElement("paragraph", "text", bbox=[0.1, 0.2, 0.8, 0.3])
        elements, _ = dolphin_to_cf_elements(self._make_result([boxed]))
        assert elements[0]["bbox"] == [0.1, 0.2, 0.8, 0.3]

    def test_mixed_elements_and_tables(self):
        mixed = [
            DolphinElement("title", "Document Title"),
            DolphinElement("table", "data", html="<table/>"),
            DolphinElement("paragraph", "Body text"),
        ]
        elements, tables = dolphin_to_cf_elements(self._make_result(mixed))
        assert len(elements) == 2
        assert len(tables) == 1
        assert elements[0]["type"] == "heading"

    def test_empty_result(self):
        elements, tables = dolphin_to_cf_elements(DolphinResult())
        assert elements == []
        assert tables == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseOutputFallbacks:
    """Exercise _parse_dolphin_output without loading the real model."""

    def _make_parser(self) -> DolphinParser:
        """Build a DolphinParser while bypassing __init__ (no model load)."""
        bare = object.__new__(DolphinParser)
        bare._model_id = "ByteDance/Dolphin-v2"
        bare._device = "cpu"
        return bare

    def test_json_array_output(self):
        raw = '[{"type": "paragraph", "text": "Hello", "bbox": null, "html": null}]'
        parsed = self._make_parser()._parse_dolphin_output(raw)
        assert len(parsed) == 1
        assert parsed[0]["type"] == "paragraph"

    def test_element_tag_output(self):
        raw = '<element type="title" bbox="[0.1,0.2,0.8,0.3]">My Title</element>'
        parsed = self._make_parser()._parse_dolphin_output(raw)
        assert len(parsed) == 1
        first = parsed[0]
        assert first["type"] == "title"
        assert first["text"] == "My Title"
        assert first["bbox"] == [0.1, 0.2, 0.8, 0.3]

    def test_element_tag_without_bbox(self):
        raw = '<element type="paragraph">Plain text</element>'
        parsed = self._make_parser()._parse_dolphin_output(raw)
        assert parsed[0]["bbox"] is None

    def test_fallback_to_single_paragraph(self):
        raw = "This is some unstructured text output."
        parsed = self._make_parser()._parse_dolphin_output(raw)
        assert len(parsed) == 1
        assert parsed[0]["type"] == "paragraph"
        assert "unstructured text" in parsed[0]["text"]

    def test_empty_output(self):
        parsed = self._make_parser()._parse_dolphin_output("")
        assert parsed == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestFastAPIRoutes:
    """Endpoint tests that drive the FastAPI app through TestClient."""

    def _make_app_with_mock_parser(self):
        """Install a mock parser on app.main and return a TestClient for it."""
        from fastapi.testclient import TestClient
        import app.main as main_module
        from app.dolphin import DolphinResult, DolphinElement

        canned_result = DolphinResult(
            elements=[DolphinElement("paragraph", "Extracted text")],
            raw_text="Extracted text",
        )

        fake_parser = MagicMock()
        fake_parser._model_id = "ByteDance/Dolphin-v2"
        fake_parser.parse_b64.return_value = canned_result

        # The client is not used as a context manager, so the lifespan hook
        # (and therefore the real model load) never runs.
        main_module._parser = fake_parser
        return TestClient(main_module.app)

    def test_health_with_loaded_model(self):
        response = self._make_app_with_mock_parser().get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "ok"

    def test_health_without_model_returns_503(self):
        from fastapi.testclient import TestClient
        import app.main as main_module

        main_module._parser = None
        client = TestClient(main_module.app, raise_server_exceptions=False)
        response = client.get("/health")
        assert response.status_code == 503

    def test_extract_returns_structured_response(self):
        import base64

        client = self._make_app_with_mock_parser()
        encoded = base64.b64encode(b"fake-image-bytes").decode()
        response = client.post("/extract", json={"image_b64": encoded, "hint": "auto"})
        assert response.status_code == 200
        body = response.json()
        assert "elements" in body
        assert "tables" in body
        assert "raw_text" in body
        assert body["metadata"]["source"] == "cf-docuvision"

    def test_extract_invalid_hint_returns_422(self):
        import base64

        client = self._make_app_with_mock_parser()
        encoded = base64.b64encode(b"fake").decode()
        response = client.post("/extract", json={"image_b64": encoded, "hint": "invalid"})
        assert response.status_code == 422
|
||||||
Loading…
Reference in a new issue