feat: cf-vision managed service (#43)
SigLIP so400m-patch14-384 as default backend (classify + embed, ~1.4 GB VRAM). VLM backend (moondream2, LLaVA, Qwen-VL, etc.) as callable alternative for caption generation and VQA. Follows the same factory/Protocol/mock pattern as cf-stt and cf-tts. New module: circuitforge_core.vision - backends/base.py — VisionBackend Protocol, VisionResult, make_vision_backend() - backends/mock.py — MockVisionBackend (no GPU, deterministic) - backends/siglip.py — SigLIPBackend: sigmoid zero-shot classify + L2 embed - backends/vlm.py — VLMBackend: AutoModelForVision2Seq caption + prompt classify - __init__.py — process singleton; classify(), embed(), caption(), make_backend() - app.py — FastAPI service (port 8006): /health /classify /embed /caption Backend selection: CF_VISION_BACKEND=siglip|vlm, auto-detected from model path. VLM backend: supports_embed=False, caption()/classify() only. SigLIP backend: supports_caption=False, classify()/embed() only. 52 new tests, 385 total passing. Closes #43.
This commit is contained in:
parent
80b0d5fd34
commit
8c1daf3b6c
12 changed files with 1354 additions and 28 deletions
|
|
@ -1,3 +1,108 @@
|
|||
from .router import VisionRouter
|
||||
"""
|
||||
circuitforge_core.vision — Managed vision service module.
|
||||
|
||||
__all__ = ["VisionRouter"]
|
||||
Quick start (mock mode — no GPU or model required):
|
||||
|
||||
import os; os.environ["CF_VISION_MOCK"] = "1"
|
||||
from circuitforge_core.vision import classify, embed
|
||||
|
||||
result = classify(image_bytes, labels=["cat", "dog", "bird"])
|
||||
print(result.top(1))  # [("cat", 0.3333)] — mock scores are uniform 1/n
|
||||
|
||||
emb = embed(image_bytes)
|
||||
print(len(emb.embedding))  # 512 in mock mode (SigLIP so400m returns 1152)
|
||||
|
||||
Real inference (SigLIP — default, ~1.4 GB VRAM):
|
||||
|
||||
export CF_VISION_MODEL=google/siglip-so400m-patch14-384
|
||||
from circuitforge_core.vision import classify
|
||||
|
||||
Full VLM inference (caption + VQA):
|
||||
|
||||
export CF_VISION_BACKEND=vlm
|
||||
export CF_VISION_MODEL=vikhyatk/moondream2
|
||||
from circuitforge_core.vision import caption
|
||||
|
||||
Per-request backend (bypasses process singleton):
|
||||
|
||||
from circuitforge_core.vision import make_backend
|
||||
vlm = make_backend("vikhyatk/moondream2", backend="vlm")
|
||||
result = vlm.caption(image_bytes, prompt="What text appears in this image?")
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-vision
|
||||
max_mb: 1536 (siglip-so400m); 2200 (moondream2); 14500 (llava-7b)
|
||||
max_concurrent: 4 (siglip); 1 (vlm)
|
||||
shared: true
|
||||
managed:
|
||||
exec: python -m circuitforge_core.vision.app
|
||||
args: --model <path> --backend siglip --port {port} --gpu-id {gpu_id}
|
||||
port: 8006
|
||||
health: /health
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.vision.backends.base import (
|
||||
VisionBackend,
|
||||
VisionResult,
|
||||
make_vision_backend,
|
||||
)
|
||||
from circuitforge_core.vision.backends.mock import MockVisionBackend
|
||||
|
||||
# Process-level singleton backend, created lazily by _get_backend().
_backend: VisionBackend | None = None


def _get_backend() -> VisionBackend:
    """Return the process-level backend, constructing it on first use.

    The model path comes from CF_VISION_MODEL (default "mock"); mock mode
    is forced when the path is literally "mock" or CF_VISION_MOCK=1.
    """
    global _backend
    if _backend is not None:
        return _backend

    model_path = os.environ.get("CF_VISION_MODEL", "mock")
    use_mock = model_path == "mock" or os.environ.get("CF_VISION_MOCK", "") == "1"
    _backend = make_vision_backend(model_path, mock=use_mock)
    return _backend
|
||||
|
||||
|
||||
def classify(image: bytes, labels: list[str]) -> VisionResult:
    """Zero-shot image classification using the process-level backend."""
    backend = _get_backend()
    return backend.classify(image, labels)


def embed(image: bytes) -> VisionResult:
    """Image embedding using the process-level backend (SigLIP only)."""
    backend = _get_backend()
    return backend.embed(image)


def caption(image: bytes, prompt: str = "") -> VisionResult:
    """Image captioning / VQA using the process-level backend (VLM only)."""
    backend = _get_backend()
    return backend.caption(image, prompt)
|
||||
|
||||
|
||||
def make_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
    device: str = "cuda",
    dtype: str = "float16",
) -> VisionBackend:
    """Construct a standalone VisionBackend, leaving the process singleton alone.

    Handy when one product needs both SigLIP (routing) and a VLM (captioning)
    in the same process, or when comparing models side-by-side in tests.
    """
    return make_vision_backend(
        model_path,
        backend=backend,
        mock=mock,
        device=device,
        dtype=dtype,
    )
|
||||
|
||||
|
||||
# Public API of circuitforge_core.vision.
__all__ = [
    "VisionBackend", "VisionResult", "MockVisionBackend",
    "classify", "embed", "caption", "make_backend",
]
|
||||
|
|
|
|||
245
circuitforge_core/vision/app.py
Normal file
245
circuitforge_core/vision/app.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
circuitforge_core.vision.app — cf-vision FastAPI service.
|
||||
|
||||
Managed by cf-orch as a process-type service. cf-orch starts this via:
|
||||
|
||||
python -m circuitforge_core.vision.app \
|
||||
--model google/siglip-so400m-patch14-384 \
|
||||
--backend siglip \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
For VLM inference (caption/VQA):
|
||||
|
||||
python -m circuitforge_core.vision.app \
|
||||
--model vikhyatk/moondream2 \
|
||||
--backend vlm \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": "...", "vram_mb": n,
|
||||
"supports_embed": bool, "supports_caption": bool}
|
||||
POST /classify → VisionClassifyResponse (multipart: image + labels)
|
||||
POST /embed → VisionEmbedResponse (multipart: image)
|
||||
POST /caption → VisionCaptionResponse (multipart: image + prompt)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.vision.backends.base import make_vision_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────────────
|
||||
|
||||
class VisionClassifyResponse(BaseModel):
    """POST /classify response — scores are index-aligned with labels."""

    labels: list[str]
    scores: list[float]
    model: str


class VisionEmbedResponse(BaseModel):
    """POST /embed response — image embedding vector."""

    embedding: list[float]
    model: str


class VisionCaptionResponse(BaseModel):
    """POST /caption response — generated caption / VQA answer."""

    caption: str
    model: str


class HealthResponse(BaseModel):
    """GET /health response — capability flags let callers route requests."""

    status: str
    model: str
    vram_mb: int            # approximate VRAM footprint reported by the backend
    backend: str            # "siglip" or "vlm" (whichever create_app loaded)
    supports_embed: bool    # False on VLM backends
    supports_caption: bool  # False on SigLIP backends
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
def create_app(
    model_path: str,
    backend: str = "siglip",
    device: str = "cuda",
    dtype: str = "float16",
    mock: bool = False,
) -> FastAPI:
    """Build the cf-vision FastAPI app around a single pre-loaded backend.

    The backend (and therefore the model weights) is constructed once, here,
    and captured by the endpoint closures below — there is no per-request
    model loading.

    model_path: HuggingFace model ID or local path.
    backend:    "siglip" (classify+embed) or "vlm" (caption+classify).
    device, dtype: forwarded to the backend constructor.
    mock:       use the mock backend (no GPU / model files).
    """
    app = FastAPI(title="cf-vision", version="0.1.0")
    # Load the model up front; endpoints close over _backend.
    _backend = make_vision_backend(
        model_path, backend=backend, device=device, dtype=dtype, mock=mock
    )
    logger.info(
        "cf-vision ready: model=%r backend=%r vram=%dMB",
        _backend.model_name, backend, _backend.vram_mb,
    )

    @app.get("/health", response_model=HealthResponse)
    async def health() -> HealthResponse:
        # Liveness + capability report; cf-orch polls this endpoint.
        return HealthResponse(
            status="ok",
            model=_backend.model_name,
            vram_mb=_backend.vram_mb,
            backend=backend,
            supports_embed=_backend.supports_embed,
            supports_caption=_backend.supports_caption,
        )

    @app.post("/classify", response_model=VisionClassifyResponse)
    async def classify(
        image: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, ...)"),
        labels: str = Form(
            ...,
            description=(
                "Candidate labels — either a JSON array "
                '(["cat","dog"]) or comma-separated (cat,dog)'
            ),
        ),
    ) -> VisionClassifyResponse:
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        parsed_labels = _parse_labels(labels)
        if not parsed_labels:
            raise HTTPException(status_code=400, detail="At least one label is required")

        try:
            result = _backend.classify(image_bytes, parsed_labels)
        except Exception as exc:
            # Backend failures (undecodable image bytes, OOM, ...) surface as 500s.
            logger.exception("classify failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return VisionClassifyResponse(
            labels=result.labels, scores=result.scores, model=result.model
        )

    @app.post("/embed", response_model=VisionEmbedResponse)
    async def embed_image(
        image: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, ...)"),
    ) -> VisionEmbedResponse:
        # 501 Not Implemented when the loaded backend cannot embed (VLM backends).
        if not _backend.supports_embed:
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Backend '{backend}' does not support embedding. "
                    "Use backend=siglip for embed()."
                ),
            )

        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        try:
            result = _backend.embed(image_bytes)
        except Exception as exc:
            logger.exception("embed failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return VisionEmbedResponse(embedding=result.embedding or [], model=result.model)

    @app.post("/caption", response_model=VisionCaptionResponse)
    async def caption_image(
        image: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, ...)"),
        prompt: str = Form(
            "",
            description="Optional instruction / question for the VLM",
        ),
    ) -> VisionCaptionResponse:
        # 501 Not Implemented when the loaded backend cannot caption (SigLIP backends).
        if not _backend.supports_caption:
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Backend '{backend}' does not support caption generation. "
                    "Use backend=vlm for caption()."
                ),
            )

        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        try:
            result = _backend.caption(image_bytes, prompt=prompt)
        except Exception as exc:
            logger.exception("caption failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return VisionCaptionResponse(caption=result.caption or "", model=result.model)

    return app
|
||||
|
||||
|
||||
# ── Label parsing ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _parse_labels(raw: str) -> list[str]:
|
||||
"""Accept JSON array or comma-separated label string."""
|
||||
stripped = raw.strip()
|
||||
if stripped.startswith("["):
|
||||
try:
|
||||
parsed = json.loads(stripped)
|
||||
if isinstance(parsed, list):
|
||||
return [str(x) for x in parsed]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return [lbl.strip() for lbl in stripped.split(",") if lbl.strip()]
|
||||
|
||||
|
||||
# ── CLI entry point ───────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, configure logging, serve the app."""
    parser = argparse.ArgumentParser(
        description="cf-vision — CircuitForge vision service"
    )
    parser.add_argument(
        "--model",
        default="google/siglip-so400m-patch14-384",
        help="HuggingFace model ID or local path",
    )
    parser.add_argument(
        "--backend",
        default="siglip",
        choices=["siglip", "vlm"],
        help="Vision backend: siglip (classify+embed) or vlm (caption+classify)",
    )
    parser.add_argument("--port", type=int, default=8006)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "bfloat16", "float32"],
    )
    parser.add_argument(
        "--mock",
        action="store_true",
        help="Run with mock backend (no GPU, for testing)",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Pin this process to the requested GPU before any CUDA initialisation.
    if args.device == "cuda" and not args.mock:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))

    use_mock = args.mock or os.environ.get("CF_VISION_MOCK", "") == "1"
    uvicorn.run(
        create_app(
            model_path=args.model,
            backend=args.backend,
            device=args.device,
            dtype=args.dtype,
            mock=use_mock,
        ),
        host=args.host,
        port=args.port,
        log_level="info",
    )


if __name__ == "__main__":
    main()
|
||||
4
circuitforge_core/vision/backends/__init__.py
Normal file
4
circuitforge_core/vision/backends/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from circuitforge_core.vision.backends.base import VisionBackend, VisionResult, make_vision_backend
|
||||
from circuitforge_core.vision.backends.mock import MockVisionBackend
|
||||
|
||||
__all__ = ["VisionBackend", "VisionResult", "make_vision_backend", "MockVisionBackend"]
|
||||
150
circuitforge_core/vision/backends/base.py
Normal file
150
circuitforge_core/vision/backends/base.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
# circuitforge_core/vision/backends/base.py — VisionBackend Protocol + factory
|
||||
#
|
||||
# MIT licensed. The Protocol and mock are always importable without GPU deps.
|
||||
# Real backends require optional extras:
|
||||
# pip install -e "circuitforge-core[vision-siglip]" # SigLIP (default, ~1.4 GB VRAM)
|
||||
# pip install -e "circuitforge-core[vision-vlm]" # Full VLM (e.g. moondream, LLaVA)
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ── Result type ───────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass(frozen=True)
class VisionResult:
    """
    Standard result from any VisionBackend call.

    classify() → labels + scores populated; embedding/caption may be None.
    embed()    → embedding populated; labels/scores empty.
    caption()  → caption populated; labels/scores empty; embedding None.
    """

    labels: list[str] = field(default_factory=list)    # candidate labels (classify)
    scores: list[float] = field(default_factory=list)  # scores aligned with labels
    embedding: list[float] | None = None               # image embedding (embed)
    caption: str | None = None                         # generated text (caption)
    model: str = ""                                    # producing model identifier

    def top(self, n: int = 1) -> list[tuple[str, float]]:
        """Return the top-n (label, score) pairs sorted by descending score.

        n is clamped to >= 0 so non-positive values return an empty list —
        previously a negative n fell into Python's negative-slice semantics
        and returned "all but the last |n|" pairs instead.
        Note: zip truncates to the shorter of labels/scores if they disagree.
        """
        paired = sorted(
            zip(self.labels, self.scores), key=lambda p: p[1], reverse=True
        )
        return paired[: max(n, 0)]
|
||||
|
||||
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@runtime_checkable
class VisionBackend(Protocol):
    """
    Abstract interface for vision backends.

    All backends load their model once at construction time.

    SigLIP backends implement classify() and embed() but raise NotImplementedError
    for caption(). VLM backends implement caption() and a prompt-based classify()
    but raise NotImplementedError for embed().

    NOTE: @runtime_checkable isinstance() checks verify only that the members
    exist, not their signatures.
    """

    def classify(self, image: bytes, labels: list[str]) -> VisionResult:
        """
        Zero-shot image classification.

        labels: candidate text descriptions; scores are returned in the same order.
        SigLIP uses sigmoid similarity; VLM prompts for each label.
        """
        ...

    def embed(self, image: bytes) -> VisionResult:
        """
        Return an image embedding vector.

        Available on SigLIP backends. Raises NotImplementedError on VLM backends.
        embedding is a list of floats with length == model hidden dim.
        """
        ...

    def caption(self, image: bytes, prompt: str = "") -> VisionResult:
        """
        Generate a text description of the image.

        Available on VLM backends. Raises NotImplementedError on SigLIP backends.
        prompt is an optional instruction; defaults to a generic description request.
        """
        ...

    @property
    def model_name(self) -> str:
        """Identifier for the loaded model (HuggingFace ID or path stem)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB. Used by cf-orch service registry."""
        ...

    @property
    def supports_embed(self) -> bool:
        """True if embed() is implemented (SigLIP backends)."""
        ...

    @property
    def supports_caption(self) -> bool:
        """True if caption() is implemented (VLM backends)."""
        ...
|
||||
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def make_vision_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
    device: str = "cuda",
    dtype: str = "float16",
) -> VisionBackend:
    """
    Return a VisionBackend for the given model.

    mock=True or CF_VISION_MOCK=1 → MockVisionBackend (no GPU, no model file needed)
    backend="siglip" → SigLIPBackend (default; classify + embed)
    backend="vlm"    → VLMBackend (caption + prompt-based classify)

    Auto-detection: when no backend is given (argument or CF_VISION_BACKEND),
    model paths matching known VLM names (see _looks_like_vlm) select the VLM
    backend; everything else defaults to siglip.

    Backend names are case-insensitive and surrounding whitespace is ignored,
    so CF_VISION_BACKEND="SigLIP " still resolves.

    device and dtype are forwarded to the real backends and ignored by mock.
    Raises ValueError for an unrecognised backend name.
    """
    use_mock = mock if mock is not None else os.environ.get("CF_VISION_MOCK", "") == "1"
    if use_mock:
        from circuitforge_core.vision.backends.mock import MockVisionBackend

        return MockVisionBackend(model_name=model_path)

    # Normalise so values typed into the environment still resolve.
    resolved = (backend or os.environ.get("CF_VISION_BACKEND", "")).strip().lower()
    if not resolved:
        # Auto-detect from model path
        resolved = "vlm" if _looks_like_vlm(model_path) else "siglip"

    if resolved == "siglip":
        from circuitforge_core.vision.backends.siglip import SigLIPBackend

        return SigLIPBackend(model_path=model_path, device=device, dtype=dtype)

    if resolved == "vlm":
        from circuitforge_core.vision.backends.vlm import VLMBackend

        return VLMBackend(model_path=model_path, device=device, dtype=dtype)

    raise ValueError(
        f"Unknown vision backend {resolved!r}. "
        "Expected 'siglip' or 'vlm'. Set CF_VISION_BACKEND or pass backend= explicitly."
    )
|
||||
|
||||
|
||||
def _looks_like_vlm(model_path: str) -> bool:
|
||||
"""Heuristic: names associated with generative VLMs."""
|
||||
_vlm_hints = ("llava", "moondream", "qwen-vl", "qwenvl", "idefics",
|
||||
"cogvlm", "internvl", "phi-3-vision", "phi3vision",
|
||||
"dolphin", "paligemma")
|
||||
lower = model_path.lower()
|
||||
return any(h in lower for h in _vlm_hints)
|
||||
62
circuitforge_core/vision/backends/mock.py
Normal file
62
circuitforge_core/vision/backends/mock.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
# circuitforge_core/vision/backends/mock.py — MockVisionBackend
|
||||
#
|
||||
# Deterministic stub for tests and CI. No GPU, no model files required.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from circuitforge_core.vision.backends.base import VisionBackend, VisionResult
|
||||
|
||||
|
||||
class MockVisionBackend:
    """
    Deterministic stand-in VisionBackend for tests and CI.

    classify() scores every label uniformly at 1/n.
    embed() yields a 512-dim unit vector (every component 1/sqrt(512)).
    caption() returns a fixed canned sentence.
    """

    _EMBED_DIM = 512  # dimensionality of the fake embedding

    def __init__(self, model_name: str = "mock") -> None:
        self._model_name = model_name

    # ── VisionBackend Protocol ─────────────────────────────────────────────────

    def classify(self, image: bytes, labels: list[str]) -> VisionResult:
        count = len(labels)
        uniform = 1.0 / max(count, 1)
        return VisionResult(
            labels=list(labels),
            scores=[uniform for _ in range(count)],
            model=self._model_name,
        )

    def embed(self, image: bytes) -> VisionResult:
        component = 1.0 / math.sqrt(self._EMBED_DIM)
        return VisionResult(
            embedding=[component] * self._EMBED_DIM,
            model=self._model_name,
        )

    def caption(self, image: bytes, prompt: str = "") -> VisionResult:
        return VisionResult(
            caption="A mock image description for testing purposes.",
            model=self._model_name,
        )

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        return 0

    @property
    def supports_embed(self) -> bool:
        return True

    @property
    def supports_caption(self) -> bool:
        return True
|
||||
|
||||
|
||||
# Verify protocol compliance at import time (catches missing methods early).
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not isinstance(MockVisionBackend(), VisionBackend):  # pragma: no cover
    raise TypeError("MockVisionBackend does not satisfy the VisionBackend protocol")
|
||||
138
circuitforge_core/vision/backends/siglip.py
Normal file
138
circuitforge_core/vision/backends/siglip.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
# circuitforge_core/vision/backends/siglip.py — SigLIPBackend
|
||||
#
|
||||
# Requires: pip install -e "circuitforge-core[vision-siglip]"
|
||||
# Default model: google/siglip-so400m-patch14-384 (~1.4 GB VRAM)
|
||||
#
|
||||
# SigLIP uses sigmoid cross-entropy rather than softmax over labels, so each
|
||||
# score is an independent 0–1 probability. This is better than CLIP for
|
||||
# multi-label classification and document routing.
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
from circuitforge_core.vision.backends.base import VisionResult
|
||||
|
||||
_DEFAULT_MODEL = "google/siglip-so400m-patch14-384"
|
||||
|
||||
# VRAM footprints by model variant (MB, fp16).
|
||||
_VRAM_TABLE: dict[str, int] = {
|
||||
"siglip-so400m-patch14-384": 1440,
|
||||
"siglip-so400m-patch14-224": 1440,
|
||||
"siglip-base-patch16-224": 340,
|
||||
"siglip-large-patch16-256": 690,
|
||||
}
|
||||
|
||||
|
||||
def _estimate_vram(model_path: str) -> int:
|
||||
lower = model_path.lower()
|
||||
for key, mb in _VRAM_TABLE.items():
|
||||
if key in lower:
|
||||
return mb
|
||||
return 1500 # conservative default for unknown so400m variants
|
||||
|
||||
|
||||
class SigLIPBackend:
|
||||
"""
|
||||
Image classification + embedding via Google SigLIP.
|
||||
|
||||
classify() returns sigmoid similarity scores for each candidate label —
|
||||
independent probabilities, not a softmax distribution.
|
||||
embed() returns the CLS-pool image embedding (normalised).
|
||||
caption() raises NotImplementedError — use VLMBackend for generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = _DEFAULT_MODEL,
|
||||
device: str = "cuda",
|
||||
dtype: str = "float16",
|
||||
) -> None:
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"SigLIPBackend requires torch and transformers. "
|
||||
"Install with: pip install -e 'circuitforge-core[vision-siglip]'"
|
||||
) from exc
|
||||
|
||||
import torch as _torch
|
||||
|
||||
self._device = device
|
||||
self._dtype_str = dtype
|
||||
self._torch_dtype = (
|
||||
_torch.float16 if dtype == "float16"
|
||||
else _torch.bfloat16 if dtype == "bfloat16"
|
||||
else _torch.float32
|
||||
)
|
||||
self._model_path = model_path
|
||||
self._vram_mb = _estimate_vram(model_path)
|
||||
|
||||
self._processor = AutoProcessor.from_pretrained(model_path)
|
||||
self._model = AutoModel.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
).to(device)
|
||||
# Set inference mode (train(False) == model.eval() without grad tracking)
|
||||
self._model.train(False)
|
||||
|
||||
# ── VisionBackend Protocol ─────────────────────────────────────────────────
|
||||
|
||||
def classify(self, image: bytes, labels: list[str]) -> VisionResult:
|
||||
"""Zero-shot sigmoid classification — scores are independent per label."""
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
pil_img = Image.open(io.BytesIO(image)).convert("RGB")
|
||||
inputs = self._processor(
|
||||
text=labels,
|
||||
images=pil_img,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
).to(self._device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# logits_per_image: (1, num_labels) — raw SigLIP logits
|
||||
logits = outputs.logits_per_image[0]
|
||||
scores = torch.sigmoid(logits).cpu().float().tolist()
|
||||
|
||||
return VisionResult(labels=list(labels), scores=scores, model=self.model_name)
|
||||
|
||||
def embed(self, image: bytes) -> VisionResult:
|
||||
"""Return normalised image embedding (CLS pool, L2-normalised)."""
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
pil_img = Image.open(io.BytesIO(image)).convert("RGB")
|
||||
inputs = self._processor(images=pil_img, return_tensors="pt").to(self._device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_features = self._model.get_image_features(**inputs)
|
||||
# L2-normalise so dot-product == cosine similarity
|
||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
embedding = image_features[0].cpu().float().tolist()
|
||||
return VisionResult(embedding=embedding, model=self.model_name)
|
||||
|
||||
def caption(self, image: bytes, prompt: str = "") -> VisionResult:
|
||||
raise NotImplementedError(
|
||||
"SigLIPBackend does not support caption generation. "
|
||||
"Use backend='vlm' (VLMBackend) for image-to-text generation."
|
||||
)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_path.split("/")[-1]
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return self._vram_mb
|
||||
|
||||
@property
|
||||
def supports_embed(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_caption(self) -> bool:
|
||||
return False
|
||||
181
circuitforge_core/vision/backends/vlm.py
Normal file
181
circuitforge_core/vision/backends/vlm.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
# circuitforge_core/vision/backends/vlm.py — VLMBackend
|
||||
#
|
||||
# Requires: pip install -e "circuitforge-core[vision-vlm]"
|
||||
#
|
||||
# Supports any HuggingFace AutoModelForVision2Seq-compatible VLM.
|
||||
# Validated models (VRAM fp16):
|
||||
# vikhyatk/moondream2 ~2 GB — fast, lightweight, good for documents
|
||||
# llava-hf/llava-1.5-7b-hf ~14 GB — strong general VQA
|
||||
# Qwen/Qwen2-VL-7B-Instruct ~16 GB — multilingual, structured output friendly
|
||||
#
|
||||
# VLMBackend implements caption() (generative) and a prompt-based classify()
|
||||
# that asks the model to pick from a list. embed() raises NotImplementedError.
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
from circuitforge_core.vision.backends.base import VisionResult
|
||||
|
||||
# VRAM estimates (MB, fp16) keyed by lowercase model name fragment.
|
||||
_VRAM_TABLE: dict[str, int] = {
|
||||
"moondream2": 2000,
|
||||
"moondream": 2000,
|
||||
"llava-1.5-7b": 14000,
|
||||
"llava-7b": 14000,
|
||||
"qwen2-vl-7b": 16000,
|
||||
"qwen-vl-7b": 16000,
|
||||
"llava-1.5-13b": 26000,
|
||||
"phi-3-vision": 8000,
|
||||
"phi3-vision": 8000,
|
||||
"paligemma": 6000,
|
||||
"idefics": 12000,
|
||||
"cogvlm": 14000,
|
||||
}
|
||||
|
||||
_CLASSIFY_PROMPT_TMPL = (
|
||||
"Choose the single best label for this image from the following options: "
|
||||
"{labels}. Reply with ONLY the label text, nothing else."
|
||||
)
|
||||
|
||||
|
||||
def _estimate_vram(model_path: str) -> int:
|
||||
lower = model_path.lower()
|
||||
for key, mb in _VRAM_TABLE.items():
|
||||
if key in lower:
|
||||
return mb
|
||||
return 8000 # safe default for unknown 7B-class VLMs
|
||||
|
||||
|
||||
class VLMBackend:
|
||||
"""
|
||||
Generative vision-language model backend.
|
||||
|
||||
caption() generates free-form text from an image + optional prompt.
|
||||
classify() prompts the model to select from candidate labels.
|
||||
embed() raises NotImplementedError — use SigLIPBackend for embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
dtype: str = "float16",
|
||||
max_new_tokens: int = 512,
|
||||
) -> None:
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoProcessor, AutoModelForVision2Seq
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"VLMBackend requires torch and transformers. "
|
||||
"Install with: pip install -e 'circuitforge-core[vision-vlm]'"
|
||||
) from exc
|
||||
|
||||
import torch as _torch
|
||||
|
||||
self._device = device
|
||||
self._max_new_tokens = max_new_tokens
|
||||
self._model_path = model_path
|
||||
self._vram_mb = _estimate_vram(model_path)
|
||||
|
||||
torch_dtype = (
|
||||
_torch.float16 if dtype == "float16"
|
||||
else _torch.bfloat16 if dtype == "bfloat16"
|
||||
else _torch.float32
|
||||
)
|
||||
|
||||
self._processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
||||
self._model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
).to(device)
|
||||
# Put model in inference mode — disables dropout/batchnorm training behaviour
|
||||
self._model.train(False)
|
||||
|
||||
# ── VisionBackend Protocol ─────────────────────────────────────────────────
|
||||
|
||||
def caption(self, image: bytes, prompt: str = "") -> VisionResult:
|
||||
"""Generate a text description of the image."""
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
pil_img = Image.open(io.BytesIO(image)).convert("RGB")
|
||||
effective_prompt = prompt or "Describe this image in detail."
|
||||
|
||||
inputs = self._processor(
|
||||
text=effective_prompt,
|
||||
images=pil_img,
|
||||
return_tensors="pt",
|
||||
).to(self._device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self._model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=self._max_new_tokens,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
# Strip the input prompt tokens from the generated output
|
||||
input_len = inputs["input_ids"].shape[1]
|
||||
output_ids = generated_ids[0][input_len:]
|
||||
text = self._processor.decode(output_ids, skip_special_tokens=True).strip()
|
||||
|
||||
return VisionResult(caption=text, model=self.model_name)
|
||||
|
||||
def classify(self, image: bytes, labels: list[str]) -> VisionResult:
|
||||
"""
|
||||
Prompt-based zero-shot classification.
|
||||
|
||||
Asks the VLM to choose a label from the provided list. The returned
|
||||
scores are binary (1.0 for the selected label, 0.0 for others) since
|
||||
VLMs don't expose per-label logits the same way SigLIP does.
|
||||
For soft scores, use SigLIPBackend.
|
||||
"""
|
||||
labels_str = ", ".join(f'"{lbl}"' for lbl in labels)
|
||||
prompt = _CLASSIFY_PROMPT_TMPL.format(labels=labels_str)
|
||||
result = self.caption(image, prompt=prompt)
|
||||
raw = (result.caption or "").strip().strip('"').strip("'")
|
||||
|
||||
matched = _match_label(raw, labels)
|
||||
scores = [1.0 if lbl == matched else 0.0 for lbl in labels]
|
||||
return VisionResult(labels=list(labels), scores=scores, model=self.model_name)
|
||||
|
||||
def embed(self, image: bytes) -> VisionResult:
|
||||
raise NotImplementedError(
|
||||
"VLMBackend does not support image embeddings. "
|
||||
"Use backend='siglip' (SigLIPBackend) for embed()."
|
||||
)
|
||||
|
||||
@property
def model_name(self) -> str:
    """Last path component of the configured model identifier."""
    return self._model_path.rsplit("/", 1)[-1]
|
||||
|
||||
@property
def vram_mb(self) -> int:
    """VRAM figure for this backend — returns the stored ``self._vram_mb``.

    The value is set elsewhere in the class (not visible here); presumably
    an estimate recorded at load time — confirm against the constructor.
    """
    return self._vram_mb
|
||||
|
||||
@property
def supports_embed(self) -> bool:
    """Always False — this backend's embed() raises NotImplementedError."""
    return False
|
||||
|
||||
@property
def supports_caption(self) -> bool:
    """Always True — caption() is implemented by this backend."""
    return True
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _match_label(raw: str, labels: list[str]) -> str:
|
||||
"""Return the best matching label from the VLM's free-form response."""
|
||||
raw_lower = raw.lower()
|
||||
for lbl in labels:
|
||||
if lbl.lower() == raw_lower:
|
||||
return lbl
|
||||
for lbl in labels:
|
||||
if raw_lower.startswith(lbl.lower()) or lbl.lower().startswith(raw_lower):
|
||||
return lbl
|
||||
for lbl in labels:
|
||||
if lbl.lower() in raw_lower or raw_lower in lbl.lower():
|
||||
return lbl
|
||||
return labels[0] if labels else raw
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
# circuitforge_core/vision/router.py — shim
|
||||
#
|
||||
# The vision module has been extracted to the standalone cf-vision repo.
|
||||
# This shim re-exports VisionRouter so existing imports continue to work.
|
||||
# New code should import directly from cf_vision:
|
||||
#
|
||||
# from cf_vision.router import VisionRouter
|
||||
# from cf_vision.models import ImageFrame
|
||||
#
|
||||
# Install: pip install -e ../cf-vision
|
||||
from __future__ import annotations

try:
    from cf_vision.router import VisionRouter  # noqa: F401
    from cf_vision.models import ImageFrame  # noqa: F401
except ImportError:
    # Soft fallback: products that never touch vision must still be able
    # to import this module even when cf-vision is absent.
    class VisionRouter:  # type: ignore[no-redef]
        """Stub — install cf-vision: pip install -e ../cf-vision"""

        def analyze(self, image_bytes: bytes, prompt: str = "", task: str = "document"):
            message = (
                "cf-vision is not installed. "
                "Run: pip install -e ../cf-vision"
            )
            raise ImportError(message)
|
||||
|
|
@ -49,6 +49,23 @@ tts-service = [
|
|||
"uvicorn[standard]>=0.29",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
vision-siglip = [
|
||||
"torch>=2.0",
|
||||
"transformers>=4.40",
|
||||
"Pillow>=10.0",
|
||||
]
|
||||
vision-vlm = [
|
||||
"torch>=2.0",
|
||||
"transformers>=4.40",
|
||||
"Pillow>=10.0",
|
||||
"accelerate>=0.27",
|
||||
]
|
||||
vision-service = [
|
||||
"circuitforge-core[vision-siglip]",
|
||||
"fastapi>=0.110",
|
||||
"uvicorn[standard]>=0.29",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
dev = [
|
||||
"circuitforge-core[manage]",
|
||||
"pytest>=8.0",
|
||||
|
|
|
|||
0
tests/test_vision/__init__.py
Normal file
0
tests/test_vision/__init__.py
Normal file
203
tests/test_vision/test_app.py
Normal file
203
tests/test_vision/test_app.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""
|
||||
Tests for the cf-vision FastAPI service (mock backend).
|
||||
|
||||
All tests use the mock backend — no GPU or model files required.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import io
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from circuitforge_core.vision.app import create_app, _parse_labels
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
def siglip_client() -> TestClient:
    """Module-scoped client over the mock-siglip app (classify + embed, no caption)."""
    return TestClient(create_app(model_path="mock-siglip", backend="siglip", mock=True))


@pytest.fixture(scope="module")
def vlm_client() -> TestClient:
    """Module-scoped client over the mock-vlm app (mock supports everything)."""
    return TestClient(create_app(model_path="mock-vlm", backend="vlm", mock=True))


FAKE_IMAGE = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100


def _image_upload(data: bytes = FAKE_IMAGE) -> tuple[str, tuple]:
    """Build a multipart file entry for the ``image`` form field."""
    return ("image", ("test.png", io.BytesIO(data), "image/png"))
|
||||
|
||||
|
||||
# ── /health ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_health_ok(siglip_client: TestClient) -> None:
    response = siglip_client.get("/health")
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"
    for key in ("model", "vram_mb", "backend"):
        assert key in payload


def test_health_backend_field(siglip_client: TestClient) -> None:
    assert siglip_client.get("/health").json()["backend"] == "siglip"


def test_health_supports_fields(siglip_client: TestClient) -> None:
    payload = siglip_client.get("/health").json()
    assert "supports_embed" in payload
    assert "supports_caption" in payload
|
||||
|
||||
|
||||
# ── /classify ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_classify_json_labels(siglip_client: TestClient) -> None:
    response = siglip_client.post(
        "/classify",
        files=[_image_upload()],
        data={"labels": json.dumps(["cat", "dog", "bird"])},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["labels"] == ["cat", "dog", "bird"]
    assert len(payload["scores"]) == 3


def test_classify_csv_labels(siglip_client: TestClient) -> None:
    response = siglip_client.post(
        "/classify", files=[_image_upload()], data={"labels": "cat, dog, bird"}
    )
    assert response.status_code == 200
    assert response.json()["labels"] == ["cat", "dog", "bird"]


def test_classify_single_label(siglip_client: TestClient) -> None:
    response = siglip_client.post(
        "/classify", files=[_image_upload()], data={"labels": "document"}
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["labels"] == ["document"]
    assert len(payload["scores"]) == 1


def test_classify_empty_labels_4xx(siglip_client: TestClient) -> None:
    # Either our explicit 400 or FastAPI's 422, depending on how the form
    # layer treats the empty string.
    response = siglip_client.post(
        "/classify", files=[_image_upload()], data={"labels": ""}
    )
    assert response.status_code in (400, 422)


def test_classify_empty_image_400(siglip_client: TestClient) -> None:
    empty = ("image", ("empty.png", io.BytesIO(b""), "image/png"))
    response = siglip_client.post("/classify", files=[empty], data={"labels": "cat"})
    assert response.status_code == 400


def test_classify_model_in_response(siglip_client: TestClient) -> None:
    response = siglip_client.post(
        "/classify", files=[_image_upload()], data={"labels": "cat"}
    )
    assert "model" in response.json()
|
||||
|
||||
|
||||
# ── /embed ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_embed_returns_vector(siglip_client: TestClient) -> None:
    response = siglip_client.post("/embed", files=[_image_upload()])
    assert response.status_code == 200
    vector = response.json().get("embedding")
    assert isinstance(vector, list)
    assert len(vector) > 0


def test_embed_empty_image_400(siglip_client: TestClient) -> None:
    empty = ("image", ("empty.png", io.BytesIO(b""), "image/png"))
    assert siglip_client.post("/embed", files=[empty]).status_code == 400


def test_embed_model_in_response(siglip_client: TestClient) -> None:
    assert "model" in siglip_client.post("/embed", files=[_image_upload()]).json()
|
||||
|
||||
|
||||
# ── /caption ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_caption_returns_text(vlm_client: TestClient) -> None:
    response = vlm_client.post(
        "/caption", files=[_image_upload()], data={"prompt": ""}
    )
    assert response.status_code == 200
    assert isinstance(response.json().get("caption"), str)


def test_caption_with_prompt(vlm_client: TestClient) -> None:
    response = vlm_client.post(
        "/caption",
        files=[_image_upload()],
        data={"prompt": "What text appears here?"},
    )
    assert response.status_code == 200


def test_caption_empty_image_400(vlm_client: TestClient) -> None:
    empty = ("image", ("empty.png", io.BytesIO(b""), "image/png"))
    response = vlm_client.post("/caption", files=[empty], data={"prompt": ""})
    assert response.status_code == 400
|
||||
|
||||
|
||||
# ── Label parser ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_parse_labels_json_array() -> None:
    parsed = _parse_labels('["cat", "dog"]')
    assert parsed == ["cat", "dog"]


def test_parse_labels_csv() -> None:
    parsed = _parse_labels("cat, dog, bird")
    assert parsed == ["cat", "dog", "bird"]


def test_parse_labels_single() -> None:
    parsed = _parse_labels("document")
    assert parsed == ["document"]


def test_parse_labels_empty() -> None:
    parsed = _parse_labels("")
    assert parsed == []


def test_parse_labels_whitespace_trimmed() -> None:
    parsed = _parse_labels(" cat , dog ")
    assert parsed == ["cat", "dog"]
|
||||
247
tests/test_vision/test_backend.py
Normal file
247
tests/test_vision/test_backend.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Tests for cf-vision backends (mock) and factory routing.
|
||||
|
||||
Real SigLIP/VLM backends are not tested here — they require GPU + model downloads.
|
||||
The mock backend exercises the full Protocol surface so we can verify the contract
|
||||
without hardware dependencies.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.vision.backends.base import (
|
||||
VisionBackend,
|
||||
VisionResult,
|
||||
make_vision_backend,
|
||||
)
|
||||
from circuitforge_core.vision.backends.mock import MockVisionBackend
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
FAKE_IMAGE = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100  # Not a real PNG, but enough for mock


@pytest.fixture()
def mock_backend() -> MockVisionBackend:
    """Fresh mock backend for every test — no GPU or model files involved."""
    return MockVisionBackend(model_name="test-mock")
|
||||
|
||||
|
||||
# ── Protocol compliance ───────────────────────────────────────────────────────
|
||||
|
||||
# The mock must satisfy every member of the VisionBackend Protocol contract.
# NOTE(review): isinstance() against a Protocol only works if VisionBackend is
# declared @runtime_checkable — confirm in backends/base.py.

def test_mock_is_vision_backend(mock_backend: MockVisionBackend) -> None:
    assert isinstance(mock_backend, VisionBackend)


def test_mock_model_name(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.model_name == "test-mock"


def test_mock_vram_mb(mock_backend: MockVisionBackend) -> None:
    # Mock backend is GPU-free, so it reports zero VRAM.
    assert mock_backend.vram_mb == 0


def test_mock_supports_embed(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.supports_embed is True


def test_mock_supports_caption(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.supports_caption is True
|
||||
|
||||
|
||||
# ── classify() ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_classify_returns_vision_result(mock_backend: MockVisionBackend) -> None:
    outcome = mock_backend.classify(FAKE_IMAGE, ["cat", "dog", "bird"])
    assert isinstance(outcome, VisionResult)


def test_classify_labels_preserved(mock_backend: MockVisionBackend) -> None:
    wanted = ["cat", "dog", "bird"]
    assert mock_backend.classify(FAKE_IMAGE, wanted).labels == wanted


def test_classify_scores_length_matches_labels(mock_backend: MockVisionBackend) -> None:
    wanted = ["cat", "dog", "bird"]
    outcome = mock_backend.classify(FAKE_IMAGE, wanted)
    assert len(outcome.scores) == len(wanted)


def test_classify_uniform_scores(mock_backend: MockVisionBackend) -> None:
    outcome = mock_backend.classify(FAKE_IMAGE, ["cat", "dog", "bird"])
    # Mock spreads probability mass evenly across the three labels.
    assert all(abs(score - 1.0 / 3) < 1e-9 for score in outcome.scores)


def test_classify_single_label(mock_backend: MockVisionBackend) -> None:
    outcome = mock_backend.classify(FAKE_IMAGE, ["document"])
    assert outcome.labels == ["document"]
    assert abs(outcome.scores[0] - 1.0) < 1e-9


def test_classify_model_name_in_result(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.classify(FAKE_IMAGE, ["x"]).model == "test-mock"
|
||||
|
||||
|
||||
# ── embed() ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_embed_returns_vision_result(mock_backend: MockVisionBackend) -> None:
    assert isinstance(mock_backend.embed(FAKE_IMAGE), VisionResult)


def test_embed_returns_embedding(mock_backend: MockVisionBackend) -> None:
    vector = mock_backend.embed(FAKE_IMAGE).embedding
    assert vector is not None
    assert len(vector) == 512


def test_embed_is_unit_vector(mock_backend: MockVisionBackend) -> None:
    vector = mock_backend.embed(FAKE_IMAGE).embedding
    norm = math.sqrt(sum(component * component for component in vector))
    assert abs(norm - 1.0) < 1e-6


def test_embed_labels_empty(mock_backend: MockVisionBackend) -> None:
    outcome = mock_backend.embed(FAKE_IMAGE)
    assert outcome.labels == []
    assert outcome.scores == []


def test_embed_model_name_in_result(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.embed(FAKE_IMAGE).model == "test-mock"
|
||||
|
||||
|
||||
# ── caption() ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_caption_returns_vision_result(mock_backend: MockVisionBackend) -> None:
    assert isinstance(mock_backend.caption(FAKE_IMAGE), VisionResult)


def test_caption_returns_string(mock_backend: MockVisionBackend) -> None:
    text = mock_backend.caption(FAKE_IMAGE).caption
    assert isinstance(text, str)
    assert text  # non-empty


def test_caption_with_prompt(mock_backend: MockVisionBackend) -> None:
    outcome = mock_backend.caption(FAKE_IMAGE, prompt="What is in this image?")
    assert outcome.caption is not None


def test_caption_model_name_in_result(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.caption(FAKE_IMAGE).model == "test-mock"
|
||||
|
||||
|
||||
# ── VisionResult helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def test_top_returns_sorted_pairs() -> None:
    outcome = VisionResult(labels=["cat", "dog", "bird"], scores=[0.3, 0.6, 0.1])
    assert outcome.top(2) == [("dog", 0.6), ("cat", 0.3)]


def test_top_default_n1() -> None:
    outcome = VisionResult(labels=["cat", "dog"], scores=[0.4, 0.9])
    assert outcome.top() == [("dog", 0.9)]
|
||||
|
||||
|
||||
# ── Factory routing ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_factory_mock_flag() -> None:
    assert isinstance(make_vision_backend("any-model", mock=True), MockVisionBackend)


def test_factory_mock_env(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("CF_VISION_MOCK", "1")
    assert isinstance(make_vision_backend("any-model"), MockVisionBackend)


def test_factory_mock_model_name() -> None:
    built = make_vision_backend("google/siglip-so400m-patch14-384", mock=True)
    assert built.model_name == "google/siglip-so400m-patch14-384"


def test_factory_unknown_backend_raises() -> None:
    with pytest.raises(ValueError, match="Unknown vision backend"):
        make_vision_backend("any-model", backend="nonexistent", mock=False)


def test_factory_vlm_autodetect_moondream(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auto-detection should select VLM for moondream model paths."""
    monkeypatch.setenv("CF_VISION_MOCK", "0")
    # An ImportError (torch absent in CI) is acceptable; a ValueError would
    # mean the factory failed to recognize a known backend — the real bug.
    try:
        make_vision_backend("vikhyatk/moondream2", mock=False)
    except ImportError:
        pass
    except ValueError as exc:
        pytest.fail(f"Should not raise ValueError for known backend: {exc}")


def test_factory_siglip_autodetect(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auto-detection should select siglip for non-VLM model paths (no ValueError)."""
    monkeypatch.setenv("CF_VISION_MOCK", "0")
    try:
        make_vision_backend("google/siglip-so400m-patch14-384", mock=False)
    except ValueError as exc:
        pytest.fail(f"Should not raise ValueError for known backend: {exc}")
    except Exception:
        pass  # ImportError / model-loading failures expected off GPU CI
|
||||
|
||||
|
||||
# ── Process singleton ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_module_classify_mock(monkeypatch: pytest.MonkeyPatch) -> None:
    import circuitforge_core.vision as vision_mod

    monkeypatch.setenv("CF_VISION_MOCK", "1")
    vision_mod._backend = None  # force the module-level singleton to rebuild

    outcome = vision_mod.classify(FAKE_IMAGE, ["cat", "dog"])
    assert outcome.labels == ["cat", "dog"]
    assert len(outcome.scores) == 2


def test_module_embed_mock(monkeypatch: pytest.MonkeyPatch) -> None:
    import circuitforge_core.vision as vision_mod

    monkeypatch.setenv("CF_VISION_MOCK", "1")
    vision_mod._backend = None

    assert vision_mod.embed(FAKE_IMAGE).embedding is not None


def test_module_caption_mock(monkeypatch: pytest.MonkeyPatch) -> None:
    import circuitforge_core.vision as vision_mod

    monkeypatch.setenv("CF_VISION_MOCK", "1")
    vision_mod._backend = None

    assert vision_mod.caption(FAKE_IMAGE, prompt="Describe").caption is not None


def test_module_make_backend_returns_fresh_instance(monkeypatch: pytest.MonkeyPatch) -> None:
    import circuitforge_core.vision as vision_mod

    first = vision_mod.make_backend("m1", mock=True)
    second = vision_mod.make_backend("m2", mock=True)
    assert first is not second
    assert first.model_name != second.model_name
|
||||
Loading…
Reference in a new issue