Compare commits
28 commits
feature/ap
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 01ed48808b | |||
| a2c768c635 | |||
| f7bf121aef | |||
| 8fa8216161 | |||
| b9b601aa23 | |||
| 433207d3c5 | |||
| 56fb6be4b1 | |||
| 0598801aaa | |||
| ffb95a5a30 | |||
| f74457d11f | |||
| d78310d4fd | |||
| a189511760 | |||
| 2e9e3fdc4b | |||
| 3082318e0d | |||
| 69a338bd98 | |||
| fc52d32574 | |||
| 7623c3edaf | |||
| 8c1daf3b6c | |||
| 80b0d5fd34 | |||
| 3075e5d3da | |||
| 67493048e2 | |||
| 5766fa82ab | |||
| 48d33a78ef | |||
| c9c4828387 | |||
| 19a26e02a0 | |||
| e5c26f0e67 | |||
| bb2ed3e992 | |||
| f3bc4ac605 |
79 changed files with 7422 additions and 32 deletions
48
CHANGELOG.md
48
CHANGELOG.md
|
|
@ -6,6 +6,54 @@ Versions follow [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|||
|
||||
---
|
||||
|
||||
## [0.10.0] — 2026-04-12
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.community`** — shared community signal module (BSL 1.1, closes #44)
|
||||
|
||||
Provides the PostgreSQL-backed infrastructure for the cross-product community fine-tuning signal pipeline. Products write signals; the training pipeline reads them.
|
||||
|
||||
- `CommunityDB` — psycopg2 connection pool with `run_migrations()`. Picks up all `.sql` files from `circuitforge_core/community/migrations/` in filename order. Safe to call on every startup (idempotent `CREATE TABLE IF NOT EXISTS`).
|
||||
- `CommunityPost` — frozen dataclass capturing a user-authored community post with a snapshot of the originating product item (`element_snapshot` as a tuple of key-value pairs for immutability).
|
||||
- `SharedStore` — base class for product-specific community stores. Provides typed `pg_read()` and `pg_write()` helpers that products subclass without re-implementing connection management.
|
||||
- Migration 001: `community_posts` schema (id, product, item_id, pseudonym, title, body, element_snapshot JSONB, created_at).
|
||||
- Migration 002: `community_reactions` stub (post_id FK, pseudonym, reaction_type, created_at).
|
||||
- `psycopg2-binary` added to `[community]` optional extras in `pyproject.toml`.
|
||||
- All community classes exported from `circuitforge_core.community`.
|
||||
|
||||
---
|
||||
|
||||
## [0.9.0] — 2026-04-10
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.text`** — OpenAI-compatible `/v1/chat/completions` endpoint and pipeline crystallization engine.
|
||||
|
||||
**`circuitforge_core.pipeline`** — multimodal pipeline with staged output crystallization. Products queue draft outputs for human review before committing.
|
||||
|
||||
**`circuitforge_core.stt`** — speech-to-text module. `FasterWhisperBackend` for local transcription via `faster-whisper`. Managed FastAPI app mountable in any product.
|
||||
|
||||
**`circuitforge_core.tts`** — text-to-speech module. `ChatterboxTurbo` backend for local synthesis. Managed FastAPI app.
|
||||
|
||||
**Accessibility preferences** — `preferences` module extended with structured accessibility fields (motion reduction, high contrast, font size, focus highlight) under `accessibility.*` key path.
|
||||
|
||||
**LLM output corrections router** — `make_corrections_router()` for collecting LLM output corrections in any product. Stores corrections in product SQLite for future fine-tuning.
|
||||
|
||||
---
|
||||
|
||||
## [0.8.0] — 2026-04-08
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.vision`** — cf-vision managed service shim. Routes vision inference requests to a local cf-vision worker (moondream2 / SigLIP). Closes #43.
|
||||
|
||||
**`circuitforge_core.api.feedback`** — `make_feedback_router()` shared Forgejo issue-filing router. Products mount it under `/api/feedback`; requires `FORGEJO_API_TOKEN`. Closes #30.
|
||||
|
||||
**License validation** — `CF_LICENSE_KEY` validation via Heimdall REST API. Products call `validate_license(key, product)` to gate premium features. Closes #26.
|
||||
|
||||
---
|
||||
|
||||
## [0.7.0] — 2026-04-04
|
||||
|
||||
### Added
|
||||
|
|
|
|||
21
LICENSE
Normal file
21
LICENSE
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2026 CircuitForge LLC
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1 +1,8 @@
|
|||
__version__ = "0.8.0"
|
||||
__version__ = "0.10.0"
|
||||
|
||||
try:
|
||||
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
|
||||
__all__ = ["CommunityDB", "CommunityPost", "SharedStore"]
|
||||
except ImportError:
|
||||
# psycopg2 not installed — install with: pip install circuitforge-core[community]
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from circuitforge_core.api.feedback import make_feedback_router
|
||||
from circuitforge_core.api.corrections import make_corrections_router, CORRECTIONS_MIGRATION_SQL
|
||||
|
||||
__all__ = ["make_feedback_router"]
|
||||
__all__ = ["make_feedback_router", "make_corrections_router", "CORRECTIONS_MIGRATION_SQL"]
|
||||
|
|
|
|||
199
circuitforge_core/api/corrections.py
Normal file
199
circuitforge_core/api/corrections.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
"""
|
||||
Shared corrections router — stores user corrections to LLM output for SFT training.
|
||||
|
||||
Products include this with make_corrections_router(get_db=..., product=...).
|
||||
Corrections are stored locally in each product's SQLite DB and exported as JSONL
|
||||
for the Avocet SFT pipeline. Separate from the bug-feedback→Forgejo-issue path.
|
||||
|
||||
Required DB migration (add to product migrations dir):
|
||||
-- From circuitforge_core.api.corrections import CORRECTIONS_MIGRATION_SQL
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterator, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Drop this SQL into a product's migrations directory (e.g. 020_corrections.sql).
# Idempotent (IF NOT EXISTS throughout), so re-running the migration is safe.
# Columns mirror CorrectionRequest/CorrectionRecord below; `context` holds a
# JSON-encoded dict and `opted_in` is a 0/1 consent flag read by /export.
CORRECTIONS_MIGRATION_SQL = """\
CREATE TABLE IF NOT EXISTS corrections (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    item_id TEXT NOT NULL DEFAULT '',
    product TEXT NOT NULL,
    correction_type TEXT NOT NULL,
    input_text TEXT NOT NULL,
    original_output TEXT NOT NULL,
    corrected_output TEXT NOT NULL DEFAULT '',
    rating TEXT NOT NULL DEFAULT 'down',
    context TEXT NOT NULL DEFAULT '{}',
    opted_in INTEGER NOT NULL DEFAULT 0,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_corrections_product
    ON corrections (product);

CREATE INDEX IF NOT EXISTS idx_corrections_opted_in
    ON corrections (opted_in);
"""
|
||||
|
||||
|
||||
class CorrectionRequest(BaseModel):
    """Incoming POST body: one user correction to an LLM output.

    A thumbs-up ("up") with empty corrected_output is a valid positive
    signal; the route handler rejects a thumbs-down without corrected text.
    """

    item_id: str = ""  # product item the output belonged to ("" if none)
    # NOTE(review): the INSERT in make_corrections_router stores its own
    # `product` argument, not this field — this value is effectively ignored.
    product: str
    correction_type: str  # product-defined category of the correction
    input_text: str  # prompt/input the LLM saw
    original_output: str  # what the LLM produced
    corrected_output: str = ""  # user's fixed version ("" allowed for rating="up")
    rating: Literal["up", "down"] = "down"
    context: dict = Field(default_factory=dict)  # stored JSON-encoded in the `context` column
    opted_in: bool = False  # user consent to share for fine-tuning (gates /export)
|
||||
|
||||
|
||||
class CorrectionResponse(BaseModel):
    """Acknowledgement returned by the POST route."""

    id: int  # corrections.id of the inserted row (sqlite lastrowid)
    saved: bool  # always True on success; failures surface as HTTP errors
|
||||
|
||||
|
||||
class CorrectionRecord(BaseModel):
    """One `corrections` table row, as returned by the GET listing route."""

    id: int
    item_id: str
    product: str
    correction_type: str
    input_text: str
    original_output: str
    corrected_output: str
    rating: str  # "up" | "down" (stored as TEXT)
    context: dict  # JSON-decoded from the TEXT `context` column
    opted_in: bool  # decoded from the 0/1 INTEGER column
    created_at: str  # SQLite datetime('now') text timestamp
|
||||
|
||||
|
||||
def make_corrections_router(
    get_db: Callable[[], Iterator[sqlite3.Connection]],
    product: str,
) -> APIRouter:
    """Return a configured corrections APIRouter.

    Routes (mount under e.g. /api/corrections):
        POST ""        — store one correction (CorrectionRequest body).
        GET  ""        — list stored corrections (CorrectionRecord items).
        GET  "/export" — stream corrections as JSONL for the Avocet SFT pipeline.

    Args:
        get_db: FastAPI dependency that yields a sqlite3.Connection.
        product: Product slug injected into every correction row (e.g. "linnet").
            This value — not the payload's `product` field — is what gets stored.
    """
    router = APIRouter()

    @router.post("", response_model=CorrectionResponse)
    def submit_correction(
        payload: CorrectionRequest,
        conn: sqlite3.Connection = Depends(get_db),
    ) -> CorrectionResponse:
        """Store a user correction to an LLM output."""
        # Thumbs-up with no corrected text is a valid positive signal;
        # a thumbs-down must say what the correct output was.
        if payload.rating == "down" and not payload.corrected_output.strip():
            raise HTTPException(
                status_code=422,
                detail="corrected_output is required when rating is 'down'.",
            )

        row_id = conn.execute(
            """
            INSERT INTO corrections
                (item_id, product, correction_type, input_text, original_output,
                 corrected_output, rating, context, opted_in)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                payload.item_id,
                product,  # router-level slug wins over payload.product
                payload.correction_type,
                payload.input_text,
                payload.original_output,
                payload.corrected_output,
                payload.rating,
                json.dumps(payload.context),
                int(payload.opted_in),
            ),
        ).lastrowid
        conn.commit()
        return CorrectionResponse(id=row_id, saved=True)

    @router.get("", response_model=list[CorrectionRecord])
    def list_corrections(
        opted_in_only: bool = False,
        limit: int = 200,
        conn: sqlite3.Connection = Depends(get_db),
    ) -> list[CorrectionRecord]:
        """List stored corrections, optionally filtered to opted-in rows only."""
        conn.row_factory = sqlite3.Row  # name-based access for the mapping below
        query = "SELECT * FROM corrections"
        params: list = []
        if opted_in_only:
            query += " WHERE opted_in = 1"
        query += " ORDER BY created_at DESC LIMIT ?"
        params.append(max(1, min(limit, 1000)))  # clamp limit to [1, 1000]
        rows = conn.execute(query, params).fetchall()
        return [
            CorrectionRecord(
                id=r["id"],
                item_id=r["item_id"],
                product=r["product"],
                correction_type=r["correction_type"],
                input_text=r["input_text"],
                original_output=r["original_output"],
                corrected_output=r["corrected_output"],
                rating=r["rating"],
                context=json.loads(r["context"] or "{}"),
                opted_in=bool(r["opted_in"]),
                created_at=r["created_at"],
            )
            for r in rows
        ]

    @router.get("/export")
    def export_corrections(
        opted_in_only: bool = True,
        conn: sqlite3.Connection = Depends(get_db),
    ) -> StreamingResponse:
        """Stream corrections as JSONL for the Avocet SFT pipeline.

        Each line is a JSON object with the fields expected by avocet's
        SFT candidate importer. opted_in_only=True (default) — only rows
        where the user consented to share are exported.
        """
        conn.row_factory = sqlite3.Row
        query = "SELECT * FROM corrections"
        if opted_in_only:
            query += " WHERE opted_in = 1"
        query += " ORDER BY created_at ASC"
        rows = conn.execute(query).fetchall()

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        filename = f"corrections_{product}_{timestamp}.jsonl"

        def generate() -> Iterator[str]:
            for r in rows:
                record = {
                    "input": r["input_text"],
                    "output": r["original_output"],
                    "correction": r["corrected_output"],
                    "rating": r["rating"],
                    "correction_type": r["correction_type"],
                    "product": r["product"],
                    "item_id": r["item_id"],
                    "context": json.loads(r["context"] or "{}"),
                    "created_at": r["created_at"],
                }
                yield json.dumps(record, ensure_ascii=False) + "\n"

        return StreamingResponse(
            generate(),
            media_type="application/x-ndjson",
            # FIX: use the computed filename — it was built above but the header
            # previously hardcoded a placeholder string, so downloads never
            # carried the product/timestamp attachment name.
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )

    return router
|
||||
9
circuitforge_core/community/__init__.py
Normal file
9
circuitforge_core/community/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# circuitforge_core/community/__init__.py
|
||||
# MIT License
|
||||
|
||||
from .models import CommunityPost
|
||||
from .db import CommunityDB
|
||||
from .store import SharedStore
|
||||
from .snipe_store import SellerTrustSignal, SnipeCommunityStore
|
||||
|
||||
__all__ = ["CommunityDB", "CommunityPost", "SharedStore", "SellerTrustSignal", "SnipeCommunityStore"]
|
||||
117
circuitforge_core/community/db.py
Normal file
117
circuitforge_core/community/db.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
# circuitforge_core/community/db.py
|
||||
# MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MIN_CONN = 1
|
||||
_MAX_CONN = 10
|
||||
|
||||
|
||||
class CommunityDB:
    """Shared PostgreSQL connection pool + migration runner for the community module.

    Products instantiate one CommunityDB at startup and pass it to SharedStore
    subclasses. The pool is thread-safe (ThreadedConnectionPool).

    Usage:
        db = CommunityDB.from_env()  # reads COMMUNITY_DB_URL
        db.run_migrations()
        store = MyProductStore(db)
        db.close()  # at shutdown
    """

    def __init__(self, dsn: str | None) -> None:
        # Fail fast with an actionable message instead of letting the pool
        # constructor raise a less obvious error on an empty DSN.
        if not dsn:
            raise ValueError(
                "CommunityDB requires a DSN. "
                "Set COMMUNITY_DB_URL or pass dsn= explicitly."
            )
        self._pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, dsn=dsn)
        logger.debug("CommunityDB pool created (min=%d, max=%d)", _MIN_CONN, _MAX_CONN)

    @classmethod
    def from_env(cls) -> "CommunityDB":
        """Construct from the COMMUNITY_DB_URL environment variable.

        Raises:
            ValueError: via __init__ when COMMUNITY_DB_URL is unset/empty.
        """
        import os
        dsn = os.environ.get("COMMUNITY_DB_URL")
        return cls(dsn=dsn)

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def getconn(self):
        """Borrow a connection from the pool. Must be returned via putconn()."""
        return self._pool.getconn()

    def putconn(self, conn) -> None:
        """Return a borrowed connection to the pool."""
        self._pool.putconn(conn)

    def close(self) -> None:
        """Close all pool connections. Call at application shutdown."""
        self._pool.closeall()
        logger.debug("CommunityDB pool closed")

    # ------------------------------------------------------------------
    # Migration runner
    # ------------------------------------------------------------------

    def _discover_migrations(self) -> list[Path]:
        """Return sorted list of .sql migration files from the community migrations dir."""
        # importlib.resources locates the files regardless of install layout;
        # sort by filename so 001_, 002_, ... apply in sequence order.
        pkg = importlib.resources.files("circuitforge_core.community.migrations")
        files = sorted(
            [Path(str(p)) for p in pkg.iterdir() if str(p).endswith(".sql")],
            key=lambda p: p.name,
        )
        return files

    def run_migrations(self) -> None:
        """Apply all community migration SQL files in numeric order.

        Uses a simple applied-migrations table to avoid re-running already
        applied migrations. Idempotent.

        Raises:
            Exception: re-raises any DB error after rolling back the in-flight
                transaction; the connection is always returned to the pool.
        """
        conn = self.getconn()
        try:
            with conn.cursor() as cur:
                # Bookkeeping table: one row per applied migration filename.
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS _community_migrations (
                        filename TEXT PRIMARY KEY,
                        applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                    )
                """)
                conn.commit()

                for migration_file in self._discover_migrations():
                    name = migration_file.name
                    cur.execute(
                        "SELECT 1 FROM _community_migrations WHERE filename = %s",
                        (name,),
                    )
                    if cur.fetchone():
                        logger.debug("Migration %s already applied, skipping", name)
                        continue

                    sql = migration_file.read_text()
                    logger.info("Applying community migration: %s", name)
                    cur.execute(sql)
                    cur.execute(
                        "INSERT INTO _community_migrations (filename) VALUES (%s)",
                        (name,),
                    )
                    # Commit per migration so an earlier success is not rolled
                    # back by a later migration's failure.
                    conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.putconn(conn)
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
-- 001_community_posts.sql
-- Community posts table: published meal plans, recipe successes, and bloopers.
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
-- BSL boundary: this schema is MIT (data layer, no inference).

CREATE TABLE IF NOT EXISTS community_posts (
    id BIGSERIAL PRIMARY KEY,
    slug TEXT NOT NULL UNIQUE,
    pseudonym TEXT NOT NULL,
    post_type TEXT NOT NULL CHECK (post_type IN ('plan', 'recipe_success', 'recipe_blooper')),
    published TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    title TEXT NOT NULL,
    description TEXT,
    photo_url TEXT,

    -- Plan slots (JSON array: [{day, meal_type, recipe_id, recipe_name}])
    slots JSONB NOT NULL DEFAULT '[]',

    -- Recipe result fields
    recipe_id BIGINT,
    recipe_name TEXT,
    -- Creativity level 1-4; NULL allowed (e.g. for plan posts).
    level SMALLINT CHECK (level IS NULL OR level BETWEEN 1 AND 4),
    outcome_notes TEXT,

    -- Element snapshot (denormalized from corpus at publish time)
    seasoning_score REAL,
    richness_score REAL,
    brightness_score REAL,
    depth_score REAL,
    aroma_score REAL,
    structure_score REAL,
    texture_profile TEXT,

    -- Dietary / allergen / flavor
    dietary_tags JSONB NOT NULL DEFAULT '[]',
    allergen_flags JSONB NOT NULL DEFAULT '[]',
    flavor_molecules JSONB NOT NULL DEFAULT '[]',

    -- USDA FDC macros
    fat_pct REAL,
    protein_pct REAL,
    moisture_pct REAL,

    -- Source product identifier
    source_product TEXT NOT NULL DEFAULT 'kiwi'
);

-- Indexes for common filter patterns
CREATE INDEX IF NOT EXISTS idx_community_posts_published ON community_posts (published DESC);
CREATE INDEX IF NOT EXISTS idx_community_posts_post_type ON community_posts (post_type);
CREATE INDEX IF NOT EXISTS idx_community_posts_source ON community_posts (source_product);

-- GIN index for dietary/allergen JSONB array containment queries
CREATE INDEX IF NOT EXISTS idx_community_posts_dietary_tags ON community_posts USING GIN (dietary_tags);
CREATE INDEX IF NOT EXISTS idx_community_posts_allergen_flags ON community_posts USING GIN (allergen_flags);
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
-- 002_community_post_reactions.sql
-- Reserved: community post reactions (thumbs-up, saves count).
-- Not yet implemented -- this migration is a stub to reserve the sequence number.
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).

-- Placeholder: no-op. Will be replaced when reactions feature is designed.
-- SELECT 1 gives the migration runner one valid statement to execute, so the
-- file is recorded in _community_migrations like any other.
SELECT 1;
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
-- NOTE(review): the sibling migrations open with a filename header comment
-- (e.g. "-- 001_community_posts.sql"); this one is missing — presumably
-- 003_seller_trust_signals.sql. Confirm against the migrations directory.
-- Seller trust signals: confirmed scammer / confirmed legitimate outcomes from Snipe.
-- Separate table from community_posts (Kiwi-specific) — seller signals are a
-- structurally different domain and should not overload the recipe post schema.
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
-- BSL boundary: table schema is MIT; signal ingestion route in cf-orch is BSL 1.1.

CREATE TABLE IF NOT EXISTS seller_trust_signals (
    id BIGSERIAL PRIMARY KEY,
    platform TEXT NOT NULL DEFAULT 'ebay',
    platform_seller_id TEXT NOT NULL,
    confirmed_scam BOOLEAN NOT NULL,
    signal_source TEXT NOT NULL,             -- 'blocklist_add' | 'community_vote' | 'resolved'
    flags JSONB NOT NULL DEFAULT '[]',       -- red flag keys at time of signal
    source_product TEXT NOT NULL DEFAULT 'snipe',
    recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- No PII: platform_seller_id is the public eBay username or platform ID only.
CREATE INDEX IF NOT EXISTS idx_seller_trust_platform_id
    ON seller_trust_signals (platform, platform_seller_id);

CREATE INDEX IF NOT EXISTS idx_seller_trust_confirmed
    ON seller_trust_signals (confirmed_scam);

CREATE INDEX IF NOT EXISTS idx_seller_trust_recorded
    ON seller_trust_signals (recorded_at DESC);
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
-- 004_community_categories.sql
-- MIT License
-- Shared eBay category tree published by credentialed Snipe instances.
-- Credentialless instances pull from this table during refresh().
-- Privacy: only public eBay category metadata (IDs, names, paths) — no user data.

CREATE TABLE IF NOT EXISTS community_categories (
    id SERIAL PRIMARY KEY,
    platform TEXT NOT NULL DEFAULT 'ebay',
    category_id TEXT NOT NULL,
    name TEXT NOT NULL,
    full_path TEXT NOT NULL,
    source_product TEXT NOT NULL DEFAULT 'snipe',
    published_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    -- Conflict target for ON CONFLICT upserts from publishing instances.
    UNIQUE (platform, category_id)
);

CREATE INDEX IF NOT EXISTS idx_community_cat_name
    ON community_categories (platform, name);
|
||||
2
circuitforge_core/community/migrations/__init__.py
Normal file
2
circuitforge_core/community/migrations/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Community module migrations
|
||||
# These SQL files are shipped with circuitforge-core so cf-orch can locate them via importlib.resources.
|
||||
87
circuitforge_core/community/models.py
Normal file
87
circuitforge_core/community/models.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
# circuitforge_core/community/models.py
|
||||
# MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
PostType = Literal["plan", "recipe_success", "recipe_blooper"]
CreativityLevel = Literal[1, 2, 3, 4]

# Runtime mirror of PostType, used for validation in __post_init__.
_VALID_POST_TYPES: frozenset[str] = frozenset(["plan", "recipe_success", "recipe_blooper"])


def _validate_score(name: str, value: float) -> float:
    """Return *value* unchanged after checking it lies within [0.0, 1.0]."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"{name} must be between 0.0 and 1.0, got {value!r}")
    return value


@dataclass(frozen=True)
class CommunityPost:
    """Immutable snapshot of a published community post.

    The list-like fields (slots, dietary_tags, allergen_flags,
    flavor_molecules) may be supplied as plain lists; __post_init__
    converts them to tuples so the frozen instance stays immutable at the
    container level. post_type and the six element scores are validated
    on construction.
    """

    # Identity
    slug: str
    pseudonym: str
    post_type: PostType
    published: datetime
    title: str

    # Optional content
    description: str | None
    photo_url: str | None

    # Plan slots -- list[dict] for post_type="plan"
    slots: tuple

    # Recipe result fields -- for post_type="recipe_success" | "recipe_blooper"
    recipe_id: int | None
    recipe_name: str | None
    level: CreativityLevel | None
    outcome_notes: str | None

    # Element snapshot
    seasoning_score: float
    richness_score: float
    brightness_score: float
    depth_score: float
    aroma_score: float
    structure_score: float
    texture_profile: str

    # Dietary/allergen/flavor
    dietary_tags: tuple
    allergen_flags: tuple
    flavor_molecules: tuple

    # USDA FDC macros (optional -- may not be available for all recipes)
    fat_pct: float | None
    protein_pct: float | None
    moisture_pct: float | None

    def __post_init__(self) -> None:
        """Coerce list fields to tuples, then validate post_type and scores."""
        # Frozen dataclass: writes must go through object.__setattr__.
        for attr in ("slots", "dietary_tags", "allergen_flags", "flavor_molecules"):
            value = getattr(self, attr)
            if isinstance(value, list):
                object.__setattr__(self, attr, tuple(value))

        if self.post_type not in _VALID_POST_TYPES:
            raise ValueError(
                f"post_type must be one of {sorted(_VALID_POST_TYPES)}, got {self.post_type!r}"
            )

        score_attrs = (
            "seasoning_score", "richness_score", "brightness_score",
            "depth_score", "aroma_score", "structure_score",
        )
        for score_attr in score_attrs:
            _validate_score(score_attr, getattr(self, score_attr))
||||
253
circuitforge_core/community/snipe_store.py
Normal file
253
circuitforge_core/community/snipe_store.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
# circuitforge_core/community/snipe_store.py
|
||||
# MIT License
|
||||
"""Snipe community store — publishes seller trust signals to the shared community DB.
|
||||
|
||||
Snipe products subclass SharedStore here to write seller trust signals
|
||||
(confirmed scammer / confirmed legitimate) to the cf_community PostgreSQL.
|
||||
These signals aggregate across all Snipe users to power the cross-user
|
||||
seller trust classifier fine-tuning corpus.
|
||||
|
||||
Privacy: only platform_seller_id (public eBay username/ID) and flag keys
|
||||
are written. No PII is stored.
|
||||
|
||||
Usage:
|
||||
from circuitforge_core.community import CommunityDB
|
||||
from circuitforge_core.community.snipe_store import SnipeCommunityStore
|
||||
|
||||
db = CommunityDB.from_env()
|
||||
store = SnipeCommunityStore(db, source_product="snipe")
|
||||
store.publish_seller_signal(
|
||||
platform_seller_id="ebay-username",
|
||||
confirmed_scam=True,
|
||||
signal_source="blocklist_add",
|
||||
flags=["new_account", "suspicious_price"],
|
||||
)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .store import SharedStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SellerTrustSignal:
    """Immutable snapshot of a recorded seller trust signal."""
    id: int                   # seller_trust_signals.id (server-assigned)
    platform: str             # auction platform, e.g. "ebay"
    platform_seller_id: str   # public username/platform ID only — no PII
    confirmed_scam: bool      # True = confirmed bad actor, False = confirmed legitimate
    signal_source: str        # 'blocklist_add' | 'community_vote' | 'resolved'
    flags: tuple              # red-flag keys active at signal time
    source_product: str       # publishing product slug (e.g. "snipe")
    recorded_at: datetime     # server-side NOW() at insert time
|
||||
|
||||
|
||||
class SnipeCommunityStore(SharedStore):
|
||||
"""Community store for Snipe — seller trust signal publishing and querying."""
|
||||
|
||||
    def __init__(self, db, source_product: str = "snipe") -> None:
        """Wire the store to a shared community DB pool.

        Args:
            db: CommunityDB (or compatible pool) — the methods below call
                its getconn()/putconn().
            source_product: Slug written into every published row.
        """
        # SharedStore is expected to expose self._db and self._source_product
        # for the query helpers below — confirm against store.py.
        super().__init__(db, source_product=source_product)
|
||||
|
||||
    def publish_seller_signal(
        self,
        platform_seller_id: str,
        confirmed_scam: bool,
        signal_source: str,
        flags: list[str] | None = None,
        platform: str = "ebay",
    ) -> SellerTrustSignal:
        """Record a seller trust outcome in the shared community DB.

        Args:
            platform_seller_id: Public eBay username or platform ID (no PII).
            confirmed_scam: True = confirmed bad actor; False = confirmed legitimate.
            signal_source: Origin of the signal.
                'blocklist_add' — user explicitly added to local blocklist
                'community_vote' — consensus threshold reached from multiple reports
                'resolved' — seller resolved as legitimate over time
            flags: List of red-flag keys active at signal time (e.g. ["new_account"]).
            platform: Source auction platform (default "ebay").

        Returns the inserted SellerTrustSignal.

        Raises:
            Exception: re-raises any DB error after rolling back; the
                connection is always returned to the pool.
        """
        flags = flags or []
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO seller_trust_signals
                        (platform, platform_seller_id, confirmed_scam,
                         signal_source, flags, source_product)
                    VALUES (%s, %s, %s, %s, %s::jsonb, %s)
                    RETURNING id, recorded_at
                    """,
                    (
                        platform,
                        platform_seller_id,
                        confirmed_scam,
                        signal_source,
                        json.dumps(flags),  # flags column is JSONB; cast happens in SQL
                        self._source_product,
                    ),
                )
                # RETURNING supplies the server-assigned id and timestamp.
                row = cur.fetchone()
                conn.commit()
                return SellerTrustSignal(
                    id=row[0],
                    platform=platform,
                    platform_seller_id=platform_seller_id,
                    confirmed_scam=confirmed_scam,
                    signal_source=signal_source,
                    flags=tuple(flags),
                    source_product=self._source_product,
                    recorded_at=row[1],
                )
        except Exception:
            conn.rollback()
            log.warning(
                "Failed to publish seller signal for %s (%s)",
                platform_seller_id, signal_source, exc_info=True,
            )
            raise
        finally:
            self._db.putconn(conn)
|
||||
|
||||
def list_signals_for_seller(
|
||||
self,
|
||||
platform_seller_id: str,
|
||||
platform: str = "ebay",
|
||||
limit: int = 50,
|
||||
) -> list[SellerTrustSignal]:
|
||||
"""Return recent trust signals for a specific seller."""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, platform, platform_seller_id, confirmed_scam,
|
||||
signal_source, flags, source_product, recorded_at
|
||||
FROM seller_trust_signals
|
||||
WHERE platform = %s AND platform_seller_id = %s
|
||||
ORDER BY recorded_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(platform, platform_seller_id, limit),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [
|
||||
SellerTrustSignal(
|
||||
id=r[0], platform=r[1], platform_seller_id=r[2],
|
||||
confirmed_scam=r[3], signal_source=r[4],
|
||||
flags=tuple(json.loads(r[5]) if isinstance(r[5], str) else r[5] or []),
|
||||
source_product=r[6], recorded_at=r[7],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def scam_signal_count(self, platform_seller_id: str, platform: str = "ebay") -> int:
|
||||
"""Return the number of confirmed_scam=True signals for a seller.
|
||||
|
||||
Used to determine if a seller has crossed the community consensus threshold
|
||||
for appearing in the shared blocklist.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM seller_trust_signals
|
||||
WHERE platform = %s AND platform_seller_id = %s AND confirmed_scam = TRUE
|
||||
""",
|
||||
(platform, platform_seller_id),
|
||||
)
|
||||
return cur.fetchone()[0]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def publish_categories(
|
||||
self,
|
||||
categories: list[tuple[str, str, str]],
|
||||
platform: str = "ebay",
|
||||
) -> int:
|
||||
"""Upsert a batch of eBay leaf categories into the shared community table.
|
||||
|
||||
Args:
|
||||
categories: List of (category_id, name, full_path) tuples.
|
||||
platform: Source auction platform (default "ebay").
|
||||
|
||||
Returns:
|
||||
Number of rows upserted.
|
||||
"""
|
||||
if not categories:
|
||||
return 0
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.executemany(
|
||||
"""
|
||||
INSERT INTO community_categories
|
||||
(platform, category_id, name, full_path, source_product)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
ON CONFLICT (platform, category_id)
|
||||
DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
full_path = EXCLUDED.full_path,
|
||||
source_product = EXCLUDED.source_product,
|
||||
published_at = NOW()
|
||||
""",
|
||||
[
|
||||
(platform, cid, name, path, self._source_product)
|
||||
for cid, name, path in categories
|
||||
],
|
||||
)
|
||||
conn.commit()
|
||||
return len(categories)
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
log.warning(
|
||||
"Failed to publish %d categories to community store",
|
||||
len(categories), exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def fetch_categories(
|
||||
self,
|
||||
platform: str = "ebay",
|
||||
limit: int = 500,
|
||||
) -> list[tuple[str, str, str]]:
|
||||
"""Fetch community-contributed eBay categories.
|
||||
|
||||
Args:
|
||||
platform: Source auction platform (default "ebay").
|
||||
limit: Maximum rows to return.
|
||||
|
||||
Returns:
|
||||
List of (category_id, name, full_path) tuples ordered by name.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT category_id, name, full_path
|
||||
FROM community_categories
|
||||
WHERE platform = %s
|
||||
ORDER BY name
|
||||
LIMIT %s
|
||||
""",
|
||||
(platform, limit),
|
||||
)
|
||||
return [(row[0], row[1], row[2]) for row in cur.fetchall()]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
209
circuitforge_core/community/store.py
Normal file
209
circuitforge_core/community/store.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
# circuitforge_core/community/store.py
|
||||
# MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import CommunityPost
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import CommunityDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _row_to_post(row: dict) -> CommunityPost:
    """Build a CommunityPost from a psycopg2 row dict.

    JSONB columns (slots, dietary_tags, allergen_flags, flavor_molecules)
    arrive from psycopg2 as Python lists already — no json.loads() needed.
    Required columns are indexed directly (KeyError if absent); optional
    ones go through .get() with list/str fallbacks.
    """
    # The six taste-axis scores share the same NULL -> 0.0 normalization.
    score_fields = (
        "seasoning_score", "richness_score", "brightness_score",
        "depth_score", "aroma_score", "structure_score",
    )
    scores = {name: row[name] or 0.0 for name in score_fields}
    return CommunityPost(
        slug=row["slug"],
        pseudonym=row["pseudonym"],
        post_type=row["post_type"],
        published=row["published"],
        title=row["title"],
        description=row.get("description"),
        photo_url=row.get("photo_url"),
        slots=row.get("slots") or [],
        recipe_id=row.get("recipe_id"),
        recipe_name=row.get("recipe_name"),
        level=row.get("level"),
        outcome_notes=row.get("outcome_notes"),
        texture_profile=row.get("texture_profile") or "",
        dietary_tags=row.get("dietary_tags") or [],
        allergen_flags=row.get("allergen_flags") or [],
        flavor_molecules=row.get("flavor_molecules") or [],
        fat_pct=row.get("fat_pct"),
        protein_pct=row.get("protein_pct"),
        moisture_pct=row.get("moisture_pct"),
        **scores,
    )
|
||||
|
||||
|
||||
def _cursor_to_dict(cur, row) -> dict:
    """Map a psycopg2 row tuple to a dict keyed by column name.

    Passes dict rows straight through (RealDictCursor etc.); otherwise
    pairs cursor.description column names with the tuple's values.
    """
    if isinstance(row, dict):
        return row
    column_names = (description[0] for description in cur.description)
    return dict(zip(column_names, row))
|
||||
|
||||
|
||||
class SharedStore:
    """Base class for product community stores.

    Subclass this in each product and add product-specific queries:

        class KiwiCommunityStore(SharedStore):
            def list_posts_for_week(self, week_start: str) -> list[CommunityPost]: ...

    All methods return new objects (immutable pattern). Never mutate rows in-place.
    """

    def __init__(self, db: "CommunityDB", source_product: str = "kiwi") -> None:
        # db: connection pool providing getconn()/putconn().
        # source_product: stamped on every row this store writes.
        self._db = db
        self._source_product = source_product

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get_post_by_slug(self, slug: str) -> CommunityPost | None:
        """Return the post identified by *slug*, or None when absent."""
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM community_posts WHERE slug = %s LIMIT 1",
                    (slug,),
                )
                row = cur.fetchone()
                if row is None:
                    return None
                return _row_to_post(_cursor_to_dict(cur, row))
        finally:
            self._db.putconn(conn)

    def list_posts(
        self,
        limit: int = 20,
        offset: int = 0,
        post_type: str | None = None,
        dietary_tags: list[str] | None = None,
        allergen_exclude: list[str] | None = None,
        source_product: str | None = None,
    ) -> list[CommunityPost]:
        """Paginated post list with optional filters.

        Args:
            limit / offset: Pagination window, newest (published DESC) first.
            post_type: Exact-match filter on post_type.
            dietary_tags: JSONB containment — posts must include ALL listed tags.
            allergen_exclude: Posts must NOT include ANY listed allergen flag.
            source_product: Exact-match filter on the contributing product.
        """
        import json

        conn = self._db.getconn()
        try:
            conditions: list[str] = []
            params: list = []

            if post_type:
                conditions.append("post_type = %s")
                params.append(post_type)
            if dietary_tags:
                conditions.append("dietary_tags @> %s::jsonb")
                params.append(json.dumps(dietary_tags))
            if allergen_exclude:
                # BUG FIX: PostgreSQL has no `&&` (overlap) operator for jsonb,
                # so the previous `NOT (allergen_flags && %s::jsonb)` raised
                # "operator does not exist: jsonb && jsonb" at runtime.
                # `?|` tests whether the jsonb array contains ANY of the given
                # text keys — exactly the intended overlap-exclusion check.
                # psycopg2 adapts the Python list to a PostgreSQL text array.
                conditions.append("NOT (allergen_flags ?| %s)")
                params.append(allergen_exclude)
            if source_product:
                conditions.append("source_product = %s")
                params.append(source_product)

            where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
            params.extend([limit, offset])

            with conn.cursor() as cur:
                cur.execute(
                    f"SELECT * FROM community_posts {where} "
                    "ORDER BY published DESC LIMIT %s OFFSET %s",
                    params,
                )
                rows = cur.fetchall()
                return [_row_to_post(_cursor_to_dict(cur, r)) for r in rows]
        finally:
            self._db.putconn(conn)

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def insert_post(self, post: CommunityPost) -> CommunityPost:
        """Insert a new community post. Returns the inserted post (unchanged — slug is the key)."""
        import json

        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO community_posts (
                        slug, pseudonym, post_type, published, title, description, photo_url,
                        slots, recipe_id, recipe_name, level, outcome_notes,
                        seasoning_score, richness_score, brightness_score,
                        depth_score, aroma_score, structure_score, texture_profile,
                        dietary_tags, allergen_flags, flavor_molecules,
                        fat_pct, protein_pct, moisture_pct, source_product
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s,
                        %s::jsonb, %s, %s, %s, %s,
                        %s, %s, %s, %s, %s, %s, %s,
                        %s::jsonb, %s::jsonb, %s::jsonb,
                        %s, %s, %s, %s
                    )
                    """,
                    (
                        post.slug, post.pseudonym, post.post_type,
                        post.published, post.title, post.description, post.photo_url,
                        json.dumps(list(post.slots)),
                        post.recipe_id, post.recipe_name, post.level, post.outcome_notes,
                        post.seasoning_score, post.richness_score, post.brightness_score,
                        post.depth_score, post.aroma_score, post.structure_score,
                        post.texture_profile,
                        json.dumps(list(post.dietary_tags)),
                        json.dumps(list(post.allergen_flags)),
                        json.dumps(list(post.flavor_molecules)),
                        post.fat_pct, post.protein_pct, post.moisture_pct,
                        self._source_product,
                    ),
                )
                conn.commit()
                return post
        except Exception:
            conn.rollback()
            raise
        finally:
            self._db.putconn(conn)

    def delete_post(self, slug: str, pseudonym: str) -> bool:
        """Hard-delete a post. Only succeeds if pseudonym matches the author.

        Returns True if a row was deleted, False if no matching row found.
        """
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM community_posts WHERE slug = %s AND pseudonym = %s",
                    (slug, pseudonym),
                )
                conn.commit()
                return cur.rowcount > 0
        except Exception:
            conn.rollback()
            raise
        finally:
            self._db.putconn(conn)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from .settings import require_env, load_env
|
||||
from .license import validate_license, get_license_tier
|
||||
|
||||
__all__ = ["require_env", "load_env"]
|
||||
__all__ = ["require_env", "load_env", "validate_license", "get_license_tier"]
|
||||
|
|
|
|||
104
circuitforge_core/config/license.py
Normal file
104
circuitforge_core/config/license.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""
|
||||
License validation via Heimdall.
|
||||
|
||||
Products call validate_license() or get_license_tier() at startup to check
|
||||
the CF_LICENSE_KEY environment variable against Heimdall.
|
||||
|
||||
Both functions are safe to call when CF_LICENSE_KEY is absent — they return
|
||||
"free" tier gracefully rather than raising.
|
||||
|
||||
Environment variables:
|
||||
CF_LICENSE_KEY — Raw license key (e.g. CFG-PRNG-XXXX-XXXX-XXXX).
|
||||
If absent, product runs as free tier.
|
||||
CF_LICENSE_URL — Heimdall base URL override.
|
||||
Default: https://license.circuitforge.tech
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_HEIMDALL_URL = "https://license.circuitforge.tech"
|
||||
_CACHE_TTL_SECONDS = 1800 # 30 minutes
|
||||
|
||||
# Cache: (key, product) -> (result_dict, expires_at)
|
||||
_cache: dict[tuple[str, str], tuple[dict[str, bool | str], float]] = {}
|
||||
|
||||
_INVALID: dict[str, bool | str] = {"valid": False, "tier": "free", "user_id": ""}
|
||||
|
||||
|
||||
def _heimdall_url(override: str | None) -> str:
    """Resolve the Heimdall base URL: explicit override > env var > default."""
    if override:
        return override
    return os.environ.get("CF_LICENSE_URL", _DEFAULT_HEIMDALL_URL)
|
||||
|
||||
|
||||
def validate_license(
    product: str,
    min_tier: str = "free",
    heimdall_url: str | None = None,
) -> dict[str, bool | str]:
    """
    Validate CF_LICENSE_KEY against Heimdall for the given product.

    Returns a dict with keys: valid (bool), tier (str), user_id (str).
    Returns {"valid": False, "tier": "free", "user_id": ""} when:
      - CF_LICENSE_KEY is not set
      - Heimdall is unreachable
      - The key is invalid/expired/revoked

    Successful validations are cached for 30 minutes per (key, product)
    pair. Failure results are cached for only 60 seconds: previously a
    transient Heimdall outage or non-2xx response was cached for the full
    TTL, pinning valid license holders to the free tier for half an hour.
    The short failure TTL still rate-limits retries against a down server
    while letting recovery happen quickly.
    """
    key = os.environ.get("CF_LICENSE_KEY", "").strip()
    if not key:
        # No key configured — free tier, nothing to cache.
        return dict(_INVALID)

    cache_key = (key, product)
    now = time.monotonic()
    cached = _cache.get(cache_key)
    if cached is not None:
        cached_result, expires_at = cached
        if now < expires_at:
            return dict(cached_result)

    base = _heimdall_url(heimdall_url)
    try:
        resp = requests.post(
            f"{base}/licenses/verify",
            json={"key": key, "min_tier": min_tier},
            timeout=5,
        )
        if not resp.ok:
            logger.warning("[license] Heimdall returned %s for key validation", resp.status_code)
            result = dict(_INVALID)
        else:
            data = resp.json()
            result = {
                "valid": bool(data.get("valid", False)),
                "tier": data.get("tier", "free") or "free",
                "user_id": data.get("user_id", "") or "",
            }
    except Exception as exc:
        # Network error / timeout / bad JSON — treat as unlicensed for now.
        logger.warning("[license] License validation failed: %s", exc)
        result = dict(_INVALID)

    # Positive results keep the full TTL; negative/failed results expire fast
    # so an outage does not lock users out of paid features for 30 minutes.
    ttl = _CACHE_TTL_SECONDS if result["valid"] else 60
    _cache[cache_key] = (result, now + ttl)
    return result
|
||||
|
||||
|
||||
def get_license_tier(
    product: str,
    heimdall_url: str | None = None,
) -> str:
    """
    Return the active tier for CF_LICENSE_KEY, or "free" if absent/invalid.

    Convenience wrapper around validate_license() for the common case
    where only the tier string is needed.
    """
    result = validate_license(product, min_tier="free", heimdall_url=heimdall_url)
    return result["tier"] if result["valid"] else "free"
|
||||
|
|
@ -4,12 +4,22 @@ Applies *.sql files from migrations_dir in filename order.
|
|||
Tracks applied migrations in a _migrations table — safe to call multiple times.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
|
||||
"""Apply any unapplied *.sql migrations from migrations_dir."""
|
||||
"""Apply any unapplied *.sql migrations from migrations_dir.
|
||||
|
||||
Resilient to partial-failure recovery: if a migration previously crashed
|
||||
mid-run (e.g. a process killed after some ALTER TABLE statements
|
||||
auto-committed via executescript), the next startup re-runs that migration.
|
||||
Any "duplicate column name" errors are silently skipped so the migration
|
||||
can complete and be marked as applied. All other errors still propagate.
|
||||
"""
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS _migrations "
|
||||
"(name TEXT PRIMARY KEY, applied_at TEXT DEFAULT CURRENT_TIMESTAMP)"
|
||||
|
|
@ -22,8 +32,92 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
|
|||
for sql_file in sql_files:
|
||||
if sql_file.name in applied:
|
||||
continue
|
||||
conn.executescript(sql_file.read_text())
|
||||
|
||||
try:
|
||||
conn.executescript(sql_file.read_text())
|
||||
except sqlite3.OperationalError as exc:
|
||||
if "duplicate column name" not in str(exc).lower():
|
||||
raise
|
||||
# A previous run partially applied this migration (some ALTER TABLE
|
||||
# statements auto-committed before the failure). Re-run with
|
||||
# per-statement recovery to skip already-applied columns.
|
||||
_log.warning(
|
||||
"Migration %s: partial-failure detected (%s) — "
|
||||
"retrying with per-statement recovery",
|
||||
sql_file.name,
|
||||
exc,
|
||||
)
|
||||
_run_script_with_recovery(conn, sql_file)
|
||||
|
||||
# OR IGNORE: safe if two Store() calls race on the same DB — second writer
|
||||
# just skips the insert rather than raising UNIQUE constraint failed.
|
||||
conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _run_script_with_recovery(conn: sqlite3.Connection, sql_file: Path) -> None:
    """Re-run a migration script, skipping known-safe duplicate-column errors.

    Called only after a first executescript() attempt raised
    "duplicate column name", which indicates a previous partial run:
    executescript auto-commits DDL statement by statement, so some
    ALTER TABLE ... ADD COLUMN statements may already be applied. Each
    retry removes the offending ALTER statement from the script text and
    re-issues the whole script; the remaining statements are idempotent
    (CREATE TABLE IF NOT EXISTS, etc.) or will surface their own errors.
    Any error class other than duplicate-column propagates unchanged.

    Raises RuntimeError after 20 attempts to guard against genuinely
    broken SQL causing an infinite retry loop.
    """
    script = sql_file.read_text()
    for attempt in range(1, 21):
        try:
            conn.executescript(script)
        except sqlite3.OperationalError as exc:
            message = str(exc)
            if "duplicate column name" not in message.lower():
                raise
            # sqlite error text is "duplicate column name: <col>".
            column = message.split(":")[-1].strip() if ":" in message else "?"
            _log.warning(
                "Migration %s (attempt %d): skipping duplicate column '%s'",
                sql_file.name,
                attempt,
                column,
            )
            # Dropping the statement is safe: SQLite already committed this
            # column addition during the earlier partial run.
            script = _remove_column_add(script, column)
        else:
            return  # script completed cleanly
    raise RuntimeError(
        f"Migration {sql_file.name}: could not complete after 20 recovery attempts"
    )
|
||||
|
||||
|
||||
def _remove_column_add(script: str, column: str) -> str:
    """Strip the ``ALTER TABLE <tbl> ADD COLUMN <column> ...`` line from *script*."""
    import re

    # Consume the statement through its end of line (including the newline).
    return re.sub(
        r"ALTER\s+TABLE\s+\w+\s+ADD\s+COLUMN\s+" + re.escape(column) + r"[^\n]*\n?",
        "",
        script,
        flags=re.IGNORECASE,
    )
|
||||
|
|
|
|||
|
|
@ -199,15 +199,18 @@ class LLMRouter:
|
|||
continue
|
||||
|
||||
elif backend["type"] == "openai_compat":
|
||||
if not self._is_reachable(backend["base_url"]):
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
# --- cf_orch: optionally override base_url with coordinator-allocated URL ---
|
||||
# cf_orch: try allocation first — this may start the service on-demand.
|
||||
# Do NOT reachability-check before allocating; the service may be stopped
|
||||
# and the allocation is what starts it.
|
||||
orch_ctx = orch_alloc = None
|
||||
orch_result = self._try_cf_orch_alloc(backend)
|
||||
if orch_result is not None:
|
||||
orch_ctx, orch_alloc = orch_result
|
||||
backend = {**backend, "base_url": orch_alloc.url + "/v1"}
|
||||
elif not self._is_reachable(backend["base_url"]):
|
||||
# Static backend (no cf-orch) — skip if not reachable.
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=backend["base_url"],
|
||||
|
|
|
|||
|
|
@ -1,3 +1,43 @@
|
|||
# circuitforge_core/pipeline — FPGA→ASIC crystallization engine
|
||||
#
|
||||
# Public API: call pipeline.run() from product code instead of llm.router directly.
|
||||
# The module transparently checks for crystallized workflows first, falls back
|
||||
# to LLM when none match, and records each run for future crystallization.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from .crystallizer import CrystallizerConfig, crystallize, evaluate_new_run, should_crystallize
|
||||
from .executor import ExecutionResult, Executor, StepResult
|
||||
from .models import CrystallizedWorkflow, PipelineRun, Step, hash_input
|
||||
from .multimodal import MultimodalConfig, MultimodalPipeline, PageResult
|
||||
from .recorder import Recorder
|
||||
from .registry import Registry
|
||||
from .staging import StagingDB
|
||||
|
||||
__all__ = ["StagingDB"]
|
||||
__all__ = [
|
||||
# models
|
||||
"PipelineRun",
|
||||
"CrystallizedWorkflow",
|
||||
"Step",
|
||||
"hash_input",
|
||||
# recorder
|
||||
"Recorder",
|
||||
# crystallizer
|
||||
"CrystallizerConfig",
|
||||
"crystallize",
|
||||
"evaluate_new_run",
|
||||
"should_crystallize",
|
||||
# registry
|
||||
"Registry",
|
||||
# executor
|
||||
"Executor",
|
||||
"ExecutionResult",
|
||||
"StepResult",
|
||||
# multimodal
|
||||
"MultimodalPipeline",
|
||||
"MultimodalConfig",
|
||||
"PageResult",
|
||||
# legacy stub
|
||||
"StagingDB",
|
||||
]
|
||||
|
|
|
|||
177
circuitforge_core/pipeline/crystallizer.py
Normal file
177
circuitforge_core/pipeline/crystallizer.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
# circuitforge_core/pipeline/crystallizer.py — promote approved runs → workflows
|
||||
#
|
||||
# MIT — pure logic, no inference backends.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
|
||||
from .models import CrystallizedWorkflow, PipelineRun, Step
|
||||
from .recorder import Recorder
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Minimum milliseconds of review that counts as "genuine".
|
||||
# Runs shorter than this are accepted but trigger a warning.
|
||||
_RUBBER_STAMP_THRESHOLD_MS = 5_000
|
||||
|
||||
|
||||
@dataclass
class CrystallizerConfig:
    """Crystallization tuning for one product/task-type pair.

    threshold:
        How many approved runs must accumulate before crystallization fires.
        Osprey sets this to 1 (first successful IVR navigation is enough);
        Peregrine uses 3+ for cover-letter templates.
    min_review_ms:
        Approved runs reviewed faster than this (in ms) trigger a
        rubber-stamp warning. Set to 0 to disable the check (tests,
        automated approvals).
    strategy:
        ``"most_recent"`` — copy the latest approved run's steps verbatim.
        ``"majority"`` — majority-vote each step across runs (requires equal
        step counts; falls back to most_recent otherwise).
    """
    threshold: int = 3
    min_review_ms: int = _RUBBER_STAMP_THRESHOLD_MS
    strategy: Literal["most_recent", "majority"] = "most_recent"
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _majority_steps(runs: list[PipelineRun]) -> list[Step] | None:
    """Majority-vote each step position across *runs*.

    Returns None when the runs disagree on step count (no sensible
    position-wise vote exists). On a vote tie the first-seen action wins;
    the concrete Step object comes from the last run with that action.
    """
    step_counts = {len(run.steps) for run in runs}
    if len(step_counts) != 1:
        return None
    voted: list[Step] = []
    for position in range(step_counts.pop()):
        candidates = [run.steps[position] for run in runs]
        tally = Counter(step.action for step in candidates)
        winning_action = tally.most_common(1)[0][0]
        by_action = {step.action: step for step in candidates}  # last wins
        voted.append(by_action[winning_action])
    return voted
|
||||
|
||||
|
||||
def _check_review_quality(runs: list[PipelineRun],
                          min_review_ms: int) -> None:
    """Emit a UserWarning when any run looks rubber-stamped (reviewed too fast)."""
    if min_review_ms <= 0:
        return  # check disabled (tests, automated approvals)
    suspicious = [run for run in runs if run.review_duration_ms < min_review_ms]
    if not suspicious:
        return
    id_list = ", ".join(run.run_id for run in suspicious)
    warnings.warn(
        f"Crystallizing from {len(suspicious)} run(s) with review_duration_ms "
        f"< {min_review_ms} ms — possible rubber-stamp approval: [{id_list}]. "
        "Verify these were genuinely human-reviewed before deployment.",
        stacklevel=3,
    )
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def should_crystallize(runs: list[PipelineRun],
                       config: CrystallizerConfig) -> bool:
    """True when *runs* contain enough approved entries to crystallize."""
    approved_count = sum(1 for run in runs if run.approved)
    return approved_count >= config.threshold
|
||||
|
||||
|
||||
def crystallize(runs: list[PipelineRun],
                config: CrystallizerConfig,
                existing_version: int = 0) -> CrystallizedWorkflow:
    """Promote *runs* into a CrystallizedWorkflow.

    Raises
    ------
    ValueError
        If fewer approved runs than ``config.threshold``, or if the runs
        span more than one (product, task_type, input_hash) triple.
    """
    approved = [run for run in runs if run.approved]
    if len(approved) < config.threshold:
        raise ValueError(
            f"Need {config.threshold} approved runs, got {len(approved)}."
        )

    # All runs must describe the same task on the same input.
    products = {run.product for run in approved}
    task_types = {run.task_type for run in approved}
    hashes = {run.input_hash for run in approved}
    if not (len(products) == len(task_types) == len(hashes) == 1):
        raise ValueError(
            "All runs must share the same product, task_type, and input_hash. "
            f"Got products={products}, task_types={task_types}, hashes={hashes}."
        )

    (product,) = products
    (task_type,) = task_types
    (input_hash,) = hashes

    _check_review_quality(approved, config.min_review_ms)

    # Choose the canonical step list per the configured strategy.
    if config.strategy == "majority":
        canonical_steps = _majority_steps(approved) or approved[-1].steps
    else:
        # Stable sort keeps the last-listed run on timestamp ties, matching
        # sorted(...)[-1] semantics (max() would pick the first instead).
        canonical_steps = sorted(approved, key=lambda run: run.timestamp)[-1].steps

    avg_review_ms = sum(run.review_duration_ms for run in approved) // len(approved)
    unmodified = all(not run.output_modified for run in approved)

    return CrystallizedWorkflow(
        workflow_id=f"{product}:{task_type}:{input_hash[:12]}",
        product=product,
        task_type=task_type,
        input_hash=input_hash,
        steps=canonical_steps,
        crystallized_at=datetime.now(timezone.utc).isoformat(),
        run_ids=[run.run_id for run in approved],
        approval_count=len(approved),
        avg_review_duration_ms=avg_review_ms,
        all_output_unmodified=unmodified,
        version=existing_version + 1,
    )
|
||||
|
||||
|
||||
def evaluate_new_run(
    run: PipelineRun,
    recorder: Recorder,
    config: CrystallizerConfig,
    existing_version: int = 0,
) -> CrystallizedWorkflow | None:
    """Record *run*; crystallize and return a workflow once the threshold is met.

    Products call this after each human-approved execution. Returns a
    ``CrystallizedWorkflow`` if crystallization was triggered, ``None``
    otherwise (run rejected, or not enough approvals yet).
    """
    recorder.record(run)
    if not run.approved:
        return None

    approved_runs = recorder.load_approved(run.product, run.task_type, run.input_hash)
    if should_crystallize(approved_runs, config):
        workflow = crystallize(approved_runs, config, existing_version=existing_version)
        log.info(
            "pipeline: crystallized %s after %d approvals",
            workflow.workflow_id, workflow.approval_count,
        )
        return workflow

    log.debug(
        "pipeline: %d/%d approved runs for %s:%s — not yet crystallizing",
        len(approved_runs), config.threshold, run.product, run.task_type,
    )
    return None
|
||||
157
circuitforge_core/pipeline/executor.py
Normal file
157
circuitforge_core/pipeline/executor.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
# circuitforge_core/pipeline/executor.py — deterministic execution with LLM fallback
|
||||
#
|
||||
# MIT — orchestration logic only; calls product-supplied callables.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .models import CrystallizedWorkflow, Step
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class StepResult:
    """Outcome of executing a single workflow step."""

    # The step that was attempted.
    step: Step
    # True when step_fn reported success without raising.
    success: bool
    # Whatever step_fn produced (None when the step raised).
    output: Any = None
    # Human-readable failure description; None on success.
    error: str | None = None
|
||||
|
||||
@dataclass
class ExecutionResult:
    """Outcome of one workflow run — deterministic or LLM-assisted.

    Attributes
    ----------
    success:
        True when every step completed (or the LLM fallback succeeded).
    used_deterministic:
        True when a crystallized workflow ran; False when the LLM was called.
    step_results:
        Per-step outcomes from the deterministic path (may be partial when
        the run fell back to the LLM mid-workflow).
    llm_output:
        Raw output from the LLM fallback path, if used.
    workflow_id:
        ID of the workflow used; None for the pure-LLM path.
    error:
        Error message when the run failed entirely; None otherwise.
    """
    success: bool
    used_deterministic: bool
    step_results: list[StepResult] = field(default_factory=list)
    llm_output: Any = None
    workflow_id: str | None = None
    error: str | None = None
|
||||
|
||||
|
||||
# ── Executor ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class Executor:
|
||||
"""Runs crystallized workflows with transparent LLM fallback.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
step_fn:
|
||||
Called for each Step: ``step_fn(step) -> (success, output)``.
|
||||
The product supplies this — it knows how to turn a Step into a real
|
||||
action (DTMF dial, HTTP call, form field write, etc.).
|
||||
llm_fn:
|
||||
Called when no workflow matches or a step fails: ``llm_fn() -> output``.
|
||||
Products wire this to ``cf_core.llm.router`` or equivalent.
|
||||
llm_fallback:
|
||||
If False, raise RuntimeError instead of calling llm_fn on miss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_fn: Callable[[Step], tuple[bool, Any]],
|
||||
llm_fn: Callable[[], Any],
|
||||
llm_fallback: bool = True,
|
||||
) -> None:
|
||||
self._step_fn = step_fn
|
||||
self._llm_fn = llm_fn
|
||||
self._llm_fallback = llm_fallback
|
||||
|
||||
def execute(
|
||||
self,
|
||||
workflow: CrystallizedWorkflow,
|
||||
) -> ExecutionResult:
|
||||
"""Run *workflow* deterministically.
|
||||
|
||||
If a step fails, falls back to LLM (if ``llm_fallback`` is enabled).
|
||||
"""
|
||||
step_results: list[StepResult] = []
|
||||
for step in workflow.steps:
|
||||
try:
|
||||
success, output = self._step_fn(step)
|
||||
except Exception as exc:
|
||||
log.warning("step %s raised: %s", step.action, exc)
|
||||
success, output = False, None
|
||||
error_str = str(exc)
|
||||
else:
|
||||
error_str = None if success else "step_fn returned success=False"
|
||||
|
||||
step_results.append(StepResult(step=step, success=success,
|
||||
output=output, error=error_str))
|
||||
if not success:
|
||||
log.info(
|
||||
"workflow %s: step %s failed — triggering LLM fallback",
|
||||
workflow.workflow_id, step.action,
|
||||
)
|
||||
return self._llm_fallback_result(
|
||||
step_results, workflow.workflow_id
|
||||
)
|
||||
|
||||
log.info("workflow %s: all %d steps succeeded",
|
||||
workflow.workflow_id, len(workflow.steps))
|
||||
return ExecutionResult(
|
||||
success=True,
|
||||
used_deterministic=True,
|
||||
step_results=step_results,
|
||||
workflow_id=workflow.workflow_id,
|
||||
)
|
||||
|
||||
def run_with_fallback(
|
||||
self,
|
||||
workflow: CrystallizedWorkflow | None,
|
||||
) -> ExecutionResult:
|
||||
"""Run *workflow* if provided; otherwise call the LLM directly."""
|
||||
if workflow is None:
|
||||
return self._llm_fallback_result([], workflow_id=None)
|
||||
return self.execute(workflow)
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
    def _llm_fallback_result(
        self,
        partial_steps: list[StepResult],
        workflow_id: str | None,
    ) -> ExecutionResult:
        """Build the ExecutionResult for a deterministic miss or failure.

        If fallback is disabled, returns a failed result carrying the
        partial step trace; otherwise invokes ``llm_fn`` and wraps either
        its output or the exception it raised.
        """
        if not self._llm_fallback:
            # Fallback explicitly disabled — surface the failure as-is,
            # keeping whatever deterministic steps already ran.
            return ExecutionResult(
                success=False,
                used_deterministic=True,
                step_results=partial_steps,
                workflow_id=workflow_id,
                error="LLM fallback disabled and deterministic path failed.",
            )
        try:
            llm_output = self._llm_fn()
        except Exception as exc:
            # The LLM path itself failed; report it rather than re-raise so
            # callers always get an ExecutionResult.
            return ExecutionResult(
                success=False,
                used_deterministic=False,
                step_results=partial_steps,
                workflow_id=workflow_id,
                error=f"LLM fallback raised: {exc}",
            )
        return ExecutionResult(
            success=True,
            used_deterministic=False,
            step_results=partial_steps,
            llm_output=llm_output,
            workflow_id=workflow_id,
        )
|
||||
216
circuitforge_core/pipeline/models.py
Normal file
216
circuitforge_core/pipeline/models.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
# circuitforge_core/pipeline/models.py — crystallization data models
|
||||
#
|
||||
# MIT — protocol and model types only; no inference backends.
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ── Utilities ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def hash_input(features: dict[str, Any]) -> str:
    """Return a stable SHA-256 hex digest of *features*.

    Keys are sorted before serialising (recursively, via ``sort_keys``) so
    dict insertion order never affects the digest.  Only call this on
    already-normalised, PII-free feature dicts — the hash is opaque but the
    source dict should never contain raw user data.
    """
    payload = json.dumps(features, sort_keys=True, ensure_ascii=True)
    return hashlib.sha256(payload.encode()).hexdigest()
|
||||
|
||||
|
||||
# ── Step ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class Step:
    """One atomic action in a deterministic workflow.

    The ``action`` string is product-defined (e.g. ``"dtmf"``,
    ``"field_fill"``, ``"api_call"``).  ``params`` carries action-specific
    values; ``description`` is a plain-English summary for the approval UI.
    """
    action: str
    params: dict[str, Any]
    description: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain JSON-compatible dict."""
        return {
            "action": self.action,
            "params": self.params,
            "description": self.description,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Step":
        """Rebuild a Step from :meth:`to_dict` output; missing optional keys
        fall back to their defaults."""
        return cls(
            action=d["action"],
            params=d.get("params", {}),
            description=d.get("description", ""),
        )
|
||||
|
||||
|
||||
# ── PipelineRun ───────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class PipelineRun:
    """Record of one LLM-assisted execution — the raw material for crystallization.

    Fields
    ------
    run_id:
        UUID or unique string identifying this run.
    product:
        CF product code (``"osprey"``, ``"falcon"``, ``"peregrine"`` …).
    task_type:
        Product-defined task category (``"ivr_navigate"``, ``"form_fill"`` …).
    input_hash:
        SHA-256 of normalised, PII-free input features. Never store raw input.
    steps:
        Ordered list of Steps the LLM proposed.
    approved:
        True if a human approved this run before execution.
    review_duration_ms:
        Wall-clock milliseconds between displaying the proposal and the approval
        click. Values under ~5 000 ms indicate a rubber-stamp — the
        crystallizer may reject runs with suspiciously short reviews.
    output_modified:
        True if the user edited any step before approving. Modifications suggest
        the LLM proposal was imperfect; too-easy crystallization from unmodified
        runs may mean the task is already deterministic and the LLM is just
        echoing a fixed pattern.
    timestamp:
        ISO 8601 UTC creation time.
    llm_model:
        Model ID that generated the steps, e.g. ``"llama3:8b-instruct"``.
    metadata:
        Freeform dict for product-specific extra fields.
    """

    run_id: str
    product: str
    task_type: str
    input_hash: str
    steps: list[Step]
    approved: bool
    review_duration_ms: int
    output_modified: bool
    # Default timestamp is generated at construction time, in UTC.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    llm_model: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-compatible dict (inverse of :meth:`from_dict`).

        The key set here is the on-disk record schema used by the recorder —
        do not rename keys without a migration.
        """
        return {
            "run_id": self.run_id,
            "product": self.product,
            "task_type": self.task_type,
            "input_hash": self.input_hash,
            "steps": [s.to_dict() for s in self.steps],
            "approved": self.approved,
            "review_duration_ms": self.review_duration_ms,
            "output_modified": self.output_modified,
            "timestamp": self.timestamp,
            "llm_model": self.llm_model,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "PipelineRun":
        """Rebuild a PipelineRun from :meth:`to_dict` output.

        Required keys raise ``KeyError`` if absent; optional keys fall back
        to safe defaults so older records remain loadable.
        """
        return cls(
            run_id=d["run_id"],
            product=d["product"],
            task_type=d["task_type"],
            input_hash=d["input_hash"],
            steps=[Step.from_dict(s) for s in d.get("steps", [])],
            approved=d["approved"],
            review_duration_ms=d["review_duration_ms"],
            output_modified=d.get("output_modified", False),
            timestamp=d.get("timestamp", ""),
            llm_model=d.get("llm_model"),
            metadata=d.get("metadata", {}),
        )
|
||||
|
||||
|
||||
# ── CrystallizedWorkflow ──────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class CrystallizedWorkflow:
    """A deterministic workflow promoted from N approved PipelineRuns.

    Once crystallized, the executor runs ``steps`` directly — no LLM required
    unless an edge case is encountered.

    Fields
    ------
    workflow_id:
        Unique identifier (typically ``{product}:{task_type}:{input_hash[:12]}``).
    product / task_type / input_hash:
        Same semantics as PipelineRun; the hash is the lookup key.
    steps:
        Canonical deterministic step sequence (majority-voted or most-recent,
        per CrystallizerConfig.strategy).
    crystallized_at:
        ISO 8601 UTC timestamp.
    run_ids:
        IDs of the source PipelineRuns that contributed to this workflow.
    approval_count:
        Number of approved runs that went into crystallization.
    avg_review_duration_ms:
        Mean review_duration_ms across all source runs — low values are a
        warning sign that approvals may not have been genuine.
    all_output_unmodified:
        True if every contributing run had output_modified=False. Combined with
        a very short avg_review_duration_ms this can flag workflows that may
        have crystallized from rubber-stamp approvals.
    active:
        Whether this workflow is in use. Set to False to disable without
        deleting the record.
    version:
        Increments each time the workflow is re-crystallized from new runs.
    """

    workflow_id: str
    product: str
    task_type: str
    input_hash: str
    steps: list[Step]
    crystallized_at: str
    run_ids: list[str]
    approval_count: int
    avg_review_duration_ms: int
    all_output_unmodified: bool
    active: bool = True
    version: int = 1
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-compatible dict (inverse of :meth:`from_dict`).

        This is the on-disk schema the registry persists — do not rename
        keys without a migration.
        """
        return {
            "workflow_id": self.workflow_id,
            "product": self.product,
            "task_type": self.task_type,
            "input_hash": self.input_hash,
            "steps": [s.to_dict() for s in self.steps],
            "crystallized_at": self.crystallized_at,
            "run_ids": self.run_ids,
            "approval_count": self.approval_count,
            "avg_review_duration_ms": self.avg_review_duration_ms,
            "all_output_unmodified": self.all_output_unmodified,
            "active": self.active,
            "version": self.version,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "CrystallizedWorkflow":
        """Rebuild a CrystallizedWorkflow from :meth:`to_dict` output.

        Required keys raise ``KeyError`` if absent; optional keys fall back
        to defaults so older stored workflows remain loadable.
        """
        return cls(
            workflow_id=d["workflow_id"],
            product=d["product"],
            task_type=d["task_type"],
            input_hash=d["input_hash"],
            steps=[Step.from_dict(s) for s in d.get("steps", [])],
            crystallized_at=d["crystallized_at"],
            run_ids=d.get("run_ids", []),
            approval_count=d["approval_count"],
            avg_review_duration_ms=d["avg_review_duration_ms"],
            all_output_unmodified=d.get("all_output_unmodified", True),
            active=d.get("active", True),
            version=d.get("version", 1),
            metadata=d.get("metadata", {}),
        )
|
||||
234
circuitforge_core/pipeline/multimodal.py
Normal file
234
circuitforge_core/pipeline/multimodal.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
# circuitforge_core/pipeline/multimodal.py — cf-docuvision + cf-text pipeline
|
||||
#
|
||||
# MIT — orchestration only; vision and text inference stay in their own modules.
|
||||
#
|
||||
# Usage (minimal):
|
||||
#
|
||||
# from circuitforge_core.pipeline.multimodal import MultimodalPipeline, MultimodalConfig
|
||||
#
|
||||
# pipe = MultimodalPipeline(MultimodalConfig())
|
||||
# for result in pipe.run(page_bytes_list):
|
||||
# print(f"Page {result.page_idx}: {result.generated[:80]}")
|
||||
#
|
||||
# Streaming (token-by-token):
|
||||
#
|
||||
# for page_idx, token in pipe.stream(page_bytes_list):
|
||||
# ui.append(page_idx, token)
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from circuitforge_core.documents.client import DocuvisionClient
|
||||
from circuitforge_core.documents.models import StructuredDocument
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _default_prompt(page_idx: int, doc: StructuredDocument) -> str:
|
||||
"""Build a generation prompt from a StructuredDocument."""
|
||||
header = f"[Page {page_idx + 1}]\n" if page_idx > 0 else ""
|
||||
return header + doc.raw_text
|
||||
|
||||
|
||||
@dataclass
class MultimodalConfig:
    """Configuration for MultimodalPipeline.

    vision_url:
        Base URL of the cf-docuvision service.
    hint:
        Docuvision extraction hint — ``"auto"`` | ``"document"`` | ``"form"``
        | ``"table"`` | ``"figure"``.
    max_tokens:
        Passed to cf-text generate per page.
    temperature:
        Sampling temperature for text generation.
    vram_serialise:
        When True, ``swap_fn`` is called between the vision and text steps
        on each page. Use this on 8GB GPUs where Dolphin-v2 and the text
        model cannot be resident simultaneously.
    prompt_fn:
        Callable ``(page_idx, StructuredDocument) -> str`` that builds the
        generation prompt. Defaults to using ``doc.raw_text`` directly.
        Products override this to add system context, few-shot examples, etc.
    vision_timeout:
        HTTP timeout in seconds for each cf-docuvision request.
    """
    vision_url: str = "http://localhost:8003"
    hint: str = "auto"
    max_tokens: int = 512
    temperature: float = 0.7
    vram_serialise: bool = False
    # The factory lambda defers the reference to _default_prompt to
    # construction time instead of evaluating it in the field declaration.
    prompt_fn: Callable[[int, StructuredDocument], str] = field(
        default_factory=lambda: _default_prompt
    )
    vision_timeout: int = 60
|
||||
|
||||
|
||||
# ── Results ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class PageResult:
    """Result of processing one page through the vision + text pipeline.

    page_idx:
        Zero-based page index.
    doc:
        StructuredDocument from cf-docuvision.
    generated:
        Full text output from cf-text for this page.
    error:
        Non-None if extraction or generation failed for this page.
    """
    page_idx: int
    doc: StructuredDocument | None  # None when vision extraction failed
    generated: str  # empty string when generation failed or was skipped
    error: str | None = None
|
||||
|
||||
|
||||
# ── Pipeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class MultimodalPipeline:
    """Chunk a multi-page document through vision extraction + text generation.

    Parameters
    ----------
    config:
        Pipeline configuration.
    swap_fn:
        Optional callable with no arguments, called between the vision and text
        steps on each page when ``config.vram_serialise=True``. Products using
        cf-orch wire this to the VRAM budget API so Dolphin-v2 can offload
        before the text model loads. A no-op lambda works for testing.
    generate_fn:
        Text generation callable: ``(prompt, max_tokens, temperature) -> str``.
        Defaults to ``circuitforge_core.text.generate``. Override in tests or
        when the product manages its own text backend.
    stream_fn:
        Streaming text callable: ``(prompt, max_tokens, temperature) -> Iterator[str]``.
        Defaults to ``circuitforge_core.text.generate`` with ``stream=True``.
    """

    def __init__(
        self,
        config: MultimodalConfig | None = None,
        *,
        swap_fn: Callable[[], None] | None = None,
        generate_fn: Callable[..., str] | None = None,
        stream_fn: Callable[..., Iterator[str]] | None = None,
    ) -> None:
        self._cfg = config or MultimodalConfig()
        # One HTTP client per pipeline, configured from the pipeline config.
        self._vision = DocuvisionClient(
            base_url=self._cfg.vision_url,
            timeout=self._cfg.vision_timeout,
        )
        self._swap_fn = swap_fn
        self._generate_fn = generate_fn
        self._stream_fn = stream_fn

    # ── Public ────────────────────────────────────────────────────────────────

    def run(self, pages: Iterable[bytes]) -> Iterator[PageResult]:
        """Process each page and yield a PageResult as soon as it is ready.

        Callers receive pages one at a time — the UI can begin rendering
        page 0 while pages 1..N are still being extracted and generated.
        """
        for page_idx, page_bytes in enumerate(pages):
            yield self._process_page(page_idx, page_bytes)

    def stream(self, pages: Iterable[bytes]) -> Iterator[tuple[int, str]]:
        """Yield ``(page_idx, token)`` tuples for token-level progressive rendering.

        Each page is fully extracted before text generation begins, but tokens
        are yielded as the text model produces them rather than waiting for the
        full page output.

        Failures are reported in-band as bracketed sentinel tokens
        (``[extraction error: …]`` / ``[generation error: …]``) so the
        stream never raises mid-render.
        """
        for page_idx, page_bytes in enumerate(pages):
            doc, err = self._extract(page_idx, page_bytes)
            if err:
                yield (page_idx, f"[extraction error: {err}]")
                continue

            # Give the VRAM budget a chance to offload the vision model
            # before the text model runs (no-op unless configured).
            self._maybe_swap()

            prompt = self._cfg.prompt_fn(page_idx, doc)
            try:
                for token in self._stream_tokens(prompt):
                    yield (page_idx, token)
            except Exception as exc:
                log.error("page %d text streaming failed: %s", page_idx, exc)
                yield (page_idx, f"[generation error: {exc}]")

    # ── Internal ──────────────────────────────────────────────────────────────

    def _process_page(self, page_idx: int, page_bytes: bytes) -> PageResult:
        """Run one page through extract → (swap) → generate, capturing errors."""
        doc, err = self._extract(page_idx, page_bytes)
        if err:
            return PageResult(page_idx=page_idx, doc=None, generated="", error=err)

        self._maybe_swap()

        prompt = self._cfg.prompt_fn(page_idx, doc)
        try:
            text = self._generate(prompt)
        except Exception as exc:
            log.error("page %d generation failed: %s", page_idx, exc)
            # Keep the extracted doc even though generation failed.
            return PageResult(page_idx=page_idx, doc=doc, generated="",
                              error=str(exc))

        return PageResult(page_idx=page_idx, doc=doc, generated=text)

    def _extract(
        self, page_idx: int, page_bytes: bytes
    ) -> tuple[StructuredDocument | None, str | None]:
        """Call cf-docuvision; return ``(doc, None)`` or ``(None, error)``."""
        try:
            doc = self._vision.extract(page_bytes, hint=self._cfg.hint)
            log.debug("page %d extracted: %d chars", page_idx, len(doc.raw_text))
            return doc, None
        except Exception as exc:
            log.error("page %d vision extraction failed: %s", page_idx, exc)
            return None, str(exc)

    def _maybe_swap(self) -> None:
        """Invoke swap_fn between vision and text steps when serialising VRAM."""
        if self._cfg.vram_serialise and self._swap_fn is not None:
            log.debug("vram_serialise: calling swap_fn")
            self._swap_fn()

    def _generate(self, prompt: str) -> str:
        """Generate text for *prompt* via the injected or default backend."""
        if self._generate_fn is not None:
            return self._generate_fn(
                prompt,
                max_tokens=self._cfg.max_tokens,
                temperature=self._cfg.temperature,
            )
        # Imported lazily so the default text backend is only loaded when
        # no generate_fn override was supplied.
        from circuitforge_core.text import generate
        result = generate(
            prompt,
            max_tokens=self._cfg.max_tokens,
            temperature=self._cfg.temperature,
        )
        return result.text

    def _stream_tokens(self, prompt: str) -> Iterator[str]:
        """Stream tokens for *prompt* via the injected or default backend."""
        if self._stream_fn is not None:
            yield from self._stream_fn(
                prompt,
                max_tokens=self._cfg.max_tokens,
                temperature=self._cfg.temperature,
            )
            return
        # Lazy import — same rationale as _generate().
        from circuitforge_core.text import generate
        tokens = generate(
            prompt,
            max_tokens=self._cfg.max_tokens,
            temperature=self._cfg.temperature,
            stream=True,
        )
        yield from tokens
|
||||
70
circuitforge_core/pipeline/recorder.py
Normal file
70
circuitforge_core/pipeline/recorder.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# circuitforge_core/pipeline/recorder.py — write and load PipelineRun records
|
||||
#
|
||||
# MIT — local file I/O only; no inference.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from .models import PipelineRun
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ROOT = Path.home() / ".config" / "circuitforge" / "pipeline" / "runs"
|
||||
|
||||
|
||||
class Recorder:
    """Writes PipelineRun JSON records to a local directory tree.

    Layout::

        {root}/{product}/{task_type}/{run_id}.json

    The recorder is intentionally append-only — it never deletes or modifies
    existing records. Old runs accumulate as an audit trail; products that
    want retention limits should prune the directory themselves.
    """

    def __init__(self, root: Path | None = None) -> None:
        # Fall back to the per-user config tree when no root is supplied.
        self._root = Path(root) if root else _DEFAULT_ROOT

    # ── Write ─────────────────────────────────────────────────────────────────

    def record(self, run: PipelineRun) -> Path:
        """Persist *run* to disk and return the file path written."""
        dest = self._path_for(run.product, run.task_type, run.run_id)
        dest.parent.mkdir(parents=True, exist_ok=True)
        dest.write_text(json.dumps(run.to_dict(), indent=2), encoding="utf-8")
        log.debug("recorded pipeline run %s → %s", run.run_id, dest)
        return dest

    # ── Read ──────────────────────────────────────────────────────────────────

    def load_runs(self, product: str, task_type: str) -> list[PipelineRun]:
        """Return all runs for *(product, task_type)*, newest-first."""
        directory = self._root / product / task_type
        if not directory.is_dir():
            return []
        runs: list[PipelineRun] = []
        for p in directory.glob("*.json"):
            try:
                # Read with the same encoding record() writes, so non-ASCII
                # payloads round-trip regardless of the platform default.
                runs.append(
                    PipelineRun.from_dict(json.loads(p.read_text(encoding="utf-8")))
                )
            except Exception:
                # Best-effort scan: one corrupt file must not break loading.
                log.warning("skipping unreadable run file %s", p)
        runs.sort(key=lambda r: r.timestamp, reverse=True)
        return runs

    def load_approved(self, product: str, task_type: str,
                      input_hash: str) -> list[PipelineRun]:
        """Return approved runs that match *input_hash*, newest-first."""
        return [
            r for r in self.load_runs(product, task_type)
            if r.approved and r.input_hash == input_hash
        ]

    # ── Internal ──────────────────────────────────────────────────────────────

    def _path_for(self, product: str, task_type: str, run_id: str) -> Path:
        """Build the canonical record path for one run."""
        return self._root / product / task_type / f"{run_id}.json"
|
||||
134
circuitforge_core/pipeline/registry.py
Normal file
134
circuitforge_core/pipeline/registry.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
# circuitforge_core/pipeline/registry.py — workflow lookup
|
||||
#
|
||||
# MIT — file I/O and matching logic only.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from .models import CrystallizedWorkflow
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ROOT = Path.home() / ".config" / "circuitforge" / "pipeline" / "workflows"
|
||||
|
||||
|
||||
class Registry:
    """Loads and matches CrystallizedWorkflows from the local filesystem.

    Layout::

        {root}/{product}/{task_type}/{workflow_id}.json

    Exact matching is always available. Products that need fuzzy/semantic
    matching can supply a ``similarity_fn`` — a callable that takes two input
    hashes and returns a float in [0, 1]. The registry returns the
    best-scoring active workflow whose similarity meets ``fuzzy_threshold``.
    """

    def __init__(
        self,
        root: Path | None = None,
        similarity_fn: Callable[[str, str], float] | None = None,
        fuzzy_threshold: float = 0.8,
    ) -> None:
        self._root = Path(root) if root else _DEFAULT_ROOT
        self._similarity_fn = similarity_fn
        self._fuzzy_threshold = fuzzy_threshold

    # ── Write ─────────────────────────────────────────────────────────────────

    def register(self, workflow: CrystallizedWorkflow) -> Path:
        """Persist *workflow* and return the path written."""
        dest = self._path_for(workflow.product, workflow.task_type,
                              workflow.workflow_id)
        dest.parent.mkdir(parents=True, exist_ok=True)
        dest.write_text(json.dumps(workflow.to_dict(), indent=2), encoding="utf-8")
        log.info("registered workflow %s (v%d)", workflow.workflow_id,
                 workflow.version)
        return dest

    def deactivate(self, workflow_id: str, product: str,
                   task_type: str) -> bool:
        """Set ``active=False`` on a stored workflow. Returns True if found."""
        path = self._path_for(product, task_type, workflow_id)
        if not path.exists():
            return False
        # Read with the same encoding register() writes, so non-ASCII
        # metadata survives the rewrite regardless of platform default.
        data = json.loads(path.read_text(encoding="utf-8"))
        data["active"] = False
        path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        log.info("deactivated workflow %s", workflow_id)
        return True

    # ── Read ──────────────────────────────────────────────────────────────────

    def load_all(self, product: str, task_type: str) -> list[CrystallizedWorkflow]:
        """Return all (including inactive) workflows for *(product, task_type)*."""
        directory = self._root / product / task_type
        if not directory.is_dir():
            return []
        workflows: list[CrystallizedWorkflow] = []
        for p in directory.glob("*.json"):
            try:
                workflows.append(
                    CrystallizedWorkflow.from_dict(
                        json.loads(p.read_text(encoding="utf-8"))
                    )
                )
            except Exception:
                # Best-effort scan: one corrupt file must not break loading.
                log.warning("skipping unreadable workflow file %s", p)
        return workflows

    # ── Match ─────────────────────────────────────────────────────────────────

    def match(self, product: str, task_type: str,
              input_hash: str) -> CrystallizedWorkflow | None:
        """Return the active workflow for an exact input_hash match, or None."""
        for wf in self.load_all(product, task_type):
            if wf.active and wf.input_hash == input_hash:
                log.debug("registry exact match: %s", wf.workflow_id)
                return wf
        return None

    def fuzzy_match(self, product: str, task_type: str,
                    input_hash: str) -> CrystallizedWorkflow | None:
        """Return the best-scoring workflow above the similarity threshold, or None.

        Requires a ``similarity_fn`` to have been supplied at construction.
        If none was provided, raises ``RuntimeError``.
        """
        if self._similarity_fn is None:
            raise RuntimeError(
                "fuzzy_match() requires a similarity_fn — none was supplied "
                "to Registry.__init__()."
            )
        best: CrystallizedWorkflow | None = None
        best_score = 0.0
        for wf in self.load_all(product, task_type):
            if not wf.active:
                continue
            score = self._similarity_fn(wf.input_hash, input_hash)
            # Keep the highest score at or above the threshold.
            if score >= self._fuzzy_threshold and score > best_score:
                best = wf
                best_score = score
        if best:
            log.debug("registry fuzzy match: %s (score=%.2f)", best.workflow_id,
                      best_score)
        return best

    def find(self, product: str, task_type: str,
             input_hash: str) -> CrystallizedWorkflow | None:
        """Exact match first; fuzzy match second (if similarity_fn is set)."""
        exact = self.match(product, task_type, input_hash)
        if exact:
            return exact
        if self._similarity_fn is not None:
            return self.fuzzy_match(product, task_type, input_hash)
        return None

    # ── Internal ──────────────────────────────────────────────────────────────

    def _path_for(self, product: str, task_type: str,
                  workflow_id: str) -> Path:
        # Workflow IDs contain ":" separators, which are illegal in Windows
        # filenames — store with "_" instead.
        safe_id = workflow_id.replace(":", "_")
        return self._root / product / task_type / f"{safe_id}.json"
|
||||
|
|
@ -40,8 +40,11 @@ def set_user_preference(
|
|||
s.set(user_id=user_id, path=path, value=value)
|
||||
|
||||
|
||||
from . import accessibility as accessibility
|
||||
|
||||
__all__ = [
|
||||
"get_path", "set_path",
|
||||
"get_user_preference", "set_user_preference",
|
||||
"LocalFileStore", "PreferenceStore",
|
||||
"accessibility",
|
||||
]
|
||||
|
|
|
|||
73
circuitforge_core/preferences/accessibility.py
Normal file
73
circuitforge_core/preferences/accessibility.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
# circuitforge_core/preferences/accessibility.py — a11y preference keys
|
||||
#
|
||||
# First-class accessibility preferences so every product UI reads from
|
||||
# the same store path without each implementing it separately.
|
||||
#
|
||||
# All keys use the "accessibility.*" namespace in the preference store.
|
||||
# Products read these via get_user_preference() or the convenience helpers here.
|
||||
from __future__ import annotations
|
||||
|
||||
from circuitforge_core.preferences import get_user_preference, set_user_preference
|
||||
|
||||
# ── Preference key constants ──────────────────────────────────────────────────
|
||||
|
||||
# Store paths for the shared "accessibility.*" preference namespace.
PREF_REDUCED_MOTION = "accessibility.prefers_reduced_motion"
PREF_HIGH_CONTRAST = "accessibility.high_contrast"
PREF_FONT_SIZE = "accessibility.font_size"  # "default" | "large" | "xlarge"
PREF_SCREEN_READER = "accessibility.screen_reader_mode"  # reduces decorative content

# Canonical default for every a11y key.
# NOTE(review): _DEFAULTS is not read by the helpers in this module — each
# helper passes its own default literal. Presumably consumed elsewhere
# (e.g. bulk read/reset); verify before removing.
_DEFAULTS: dict[str, object] = {
    PREF_REDUCED_MOTION: False,
    PREF_HIGH_CONTRAST: False,
    PREF_FONT_SIZE: "default",
    PREF_SCREEN_READER: False,
}
|
||||
|
||||
|
||||
# ── Convenience helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def is_reduced_motion_preferred(
    user_id: str | None = None,
    store=None,
) -> bool:
    """
    Return True if the user has requested reduced motion.

    Products must honour this in all animated UI elements: transitions,
    auto-playing content, parallax, loaders. This maps to the CSS
    `prefers-reduced-motion: reduce` media query and is the canonical
    source of truth across all CF product UIs.

    Default: False.
    """
    return bool(
        get_user_preference(user_id, PREF_REDUCED_MOTION, default=False, store=store)
    )
|
||||
|
||||
|
||||
def is_high_contrast(user_id: str | None = None, store=None) -> bool:
    """Return True if the user has requested high-contrast mode."""
    val = get_user_preference(user_id, PREF_HIGH_CONTRAST, default=False, store=store)
    return bool(val)
|
||||
|
||||
|
||||
def get_font_size(user_id: str | None = None, store=None) -> str:
    """Return the user's preferred font size: 'default' | 'large' | 'xlarge'.

    Unknown stored values fall back to 'default' so UIs never see an
    out-of-range size.
    """
    val = get_user_preference(user_id, PREF_FONT_SIZE, default="default", store=store)
    return str(val) if val in ("default", "large", "xlarge") else "default"
|
||||
|
||||
|
||||
def is_screen_reader_mode(user_id: str | None = None, store=None) -> bool:
    """Return True if the user has requested screen reader optimised output."""
    val = get_user_preference(user_id, PREF_SCREEN_READER, default=False, store=store)
    return bool(val)
|
||||
|
||||
|
||||
def set_reduced_motion(
    value: bool,
    user_id: str | None = None,
    store=None,
) -> None:
    """Persist the user's reduced-motion preference.

    Writes ``value`` under the PREF_REDUCED_MOTION key; readers pick it up
    via is_reduced_motion_preferred().
    """
    set_user_preference(user_id, PREF_REDUCED_MOTION, value, store=store)
|
||||
79
circuitforge_core/stt/__init__.py
Normal file
79
circuitforge_core/stt/__init__.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""
|
||||
circuitforge_core.stt — Speech-to-text service module.
|
||||
|
||||
Quick start (mock mode — no GPU or model required):
|
||||
|
||||
import os; os.environ["CF_STT_MOCK"] = "1"
|
||||
from circuitforge_core.stt import transcribe
|
||||
|
||||
result = transcribe(open("audio.wav", "rb").read())
|
||||
print(result.text, result.confidence)
|
||||
|
||||
Real inference (faster-whisper):
|
||||
|
||||
export CF_STT_MODEL=/Library/Assets/LLM/whisper/models/Whisper/faster-whisper/models--Systran--faster-whisper-medium/snapshots/<hash>
|
||||
from circuitforge_core.stt import transcribe
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-stt
|
||||
max_mb: 1024 (medium); 600 (base/small)
|
||||
max_concurrent: 3
|
||||
shared: true
|
||||
managed:
|
||||
exec: python -m circuitforge_core.stt.app
|
||||
args: --model <path> --port {port} --gpu-id {gpu_id}
|
||||
port: 8004
|
||||
health: /health
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.stt.backends.base import (
|
||||
STTBackend,
|
||||
STTResult,
|
||||
STTSegment,
|
||||
make_stt_backend,
|
||||
)
|
||||
from circuitforge_core.stt.backends.mock import MockSTTBackend
|
||||
|
||||
_backend: STTBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> STTBackend:
    """Return the lazily-created, process-level STT backend singleton."""
    global _backend
    if _backend is None:
        model_path = os.environ.get("CF_STT_MODEL", "mock")
        # Mock mode when no real model path is configured, or when forced
        # via CF_STT_MOCK=1.
        use_mock = model_path == "mock" or os.environ.get("CF_STT_MOCK", "") == "1"
        _backend = make_stt_backend(model_path, mock=use_mock)
    return _backend
|
||||
|
||||
|
||||
def transcribe(
    audio: bytes,
    *,
    language: str | None = None,
    confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
) -> STTResult:
    """Transcribe audio bytes using the process-level backend."""
    backend = _get_backend()
    return backend.transcribe(
        audio,
        language=language,
        confidence_threshold=confidence_threshold,
    )
|
||||
|
||||
|
||||
def reset_backend() -> None:
    """Drop the cached process-level backend (intended for test teardown)."""
    global _backend
    _backend = None
|
||||
|
||||
|
||||
# Public API of circuitforge_core.stt (the star-import surface).
__all__ = [
    "STTBackend",
    "STTResult",
    "STTSegment",
    "MockSTTBackend",
    "make_stt_backend",
    "transcribe",
    "reset_backend",
]
|
||||
150
circuitforge_core/stt/app.py
Normal file
150
circuitforge_core/stt/app.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
circuitforge_core.stt.app — cf-stt FastAPI service.
|
||||
|
||||
Managed by cf-orch as a process-type service. cf-orch starts this via:
|
||||
|
||||
python -m circuitforge_core.stt.app \
|
||||
--model /Library/Assets/LLM/whisper/models/Whisper/faster-whisper/models--Systran--faster-whisper-medium/snapshots/<hash> \
|
||||
--port 8004 \
|
||||
--gpu-id 0
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": "<name>", "vram_mb": <n>}
|
||||
POST /transcribe → STTTranscribeResponse (multipart: audio file)
|
||||
|
||||
Audio format: any format ffmpeg understands (WAV, MP3, OGG, FLAC).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.stt.backends.base import STTResult, make_stt_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Response model (mirrors circuitforge_orch.contracts.stt.STTTranscribeResponse) ──
|
||||
|
||||
class TranscribeResponse(BaseModel):
    """Wire format for POST /transcribe.

    Mirrors circuitforge_orch.contracts.stt.STTTranscribeResponse so clients
    can deserialize against either model.
    """

    text: str
    confidence: float            # normalised 0.0-1.0
    below_threshold: bool        # True when confidence < the effective threshold
    language: str | None = None
    duration_s: float | None = None
    segments: list[dict] = []    # dicts with start_s / end_s / text / confidence
    model: str = ""
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
def create_app(
    model_path: str,
    device: str = "cuda",
    compute_type: str = "float16",
    confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    mock: bool = False,
) -> FastAPI:
    """Build the cf-stt FastAPI application.

    Args:
        model_path: faster-whisper size name or local model directory.
        device: "cuda" or "cpu"; passed through to the backend.
        compute_type: quantisation mode passed to faster-whisper.
        confidence_threshold: default threshold controlling below_threshold.
        mock: force MockSTTBackend (no GPU or model file required).
    """
    app = FastAPI(title="cf-stt", version="0.1.0")
    backend = make_stt_backend(
        model_path, device=device, compute_type=compute_type, mock=mock
    )
    logger.info("cf-stt ready: model=%r vram=%dMB", backend.model_name, backend.vram_mb)

    @app.get("/health")
    async def health() -> dict:
        """Liveness probe consumed by cf-orch."""
        return {"status": "ok", "model": backend.model_name, "vram_mb": backend.vram_mb}

    @app.post("/transcribe", response_model=TranscribeResponse)
    async def transcribe(
        audio: UploadFile = File(..., description="Audio file (WAV, MP3, OGG, FLAC, ...)"),
        language: str | None = Form(None, description="BCP-47 language code hint, e.g. 'en'"),
        confidence_threshold_override: float | None = Form(
            None,
            description="Override default confidence threshold for this request.",
        ),
    ) -> TranscribeResponse:
        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Empty audio file")

        # BUG FIX: this was `confidence_threshold_override or confidence_threshold`,
        # which silently discarded a legitimate 0.0 override because 0.0 is
        # falsy. Compare against None so any explicit float value is honoured.
        if confidence_threshold_override is not None:
            threshold = confidence_threshold_override
        else:
            threshold = confidence_threshold

        try:
            result = backend.transcribe(
                audio_bytes, language=language, confidence_threshold=threshold
            )
        except Exception as exc:
            logger.exception("Transcription failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return TranscribeResponse(
            text=result.text,
            confidence=result.confidence,
            below_threshold=result.below_threshold,
            language=result.language,
            duration_s=result.duration_s,
            segments=[
                {
                    "start_s": s.start_s,
                    "end_s": s.end_s,
                    "text": s.text,
                    "confidence": s.confidence,
                }
                for s in result.segments
            ],
            model=result.model,
        )

    return app
|
||||
|
||||
|
||||
# ── CLI entry point ───────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse flags, configure logging, and serve cf-stt."""
    ap = argparse.ArgumentParser(description="cf-stt — CircuitForge STT service")
    ap.add_argument("--model", required=True,
                    help="Model path or size name (e.g. 'medium', or full local path)")
    ap.add_argument("--port", type=int, default=8004)
    ap.add_argument("--host", default="0.0.0.0")
    ap.add_argument("--gpu-id", type=int, default=0,
                    help="CUDA device index (sets CUDA_VISIBLE_DEVICES)")
    ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    ap.add_argument("--compute-type", default="float16",
                    choices=["float16", "int8", "int8_float16", "float32"],
                    help="Quantisation / compute type passed to faster-whisper")
    ap.add_argument("--confidence-threshold", type=float,
                    default=STTResult.CONFIDENCE_DEFAULT_THRESHOLD)
    ap.add_argument("--mock", action="store_true",
                    help="Run with mock backend (no GPU, for testing)")
    opts = ap.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Let cf-orch pass --gpu-id; map it to CUDA_VISIBLE_DEVICES so the process
    # only sees its assigned GPU. This prevents accidental multi-GPU usage.
    if opts.device == "cuda" and not opts.mock:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(opts.gpu_id))

    use_mock = opts.mock or os.environ.get("CF_STT_MOCK", "") == "1"
    uvicorn.run(
        create_app(
            model_path=opts.model,
            device=opts.device,
            compute_type=opts.compute_type,
            confidence_threshold=opts.confidence_threshold,
            mock=use_mock,
        ),
        host=opts.host,
        port=opts.port,
        log_level="info",
    )
|
||||
|
||||
|
||||
# Script entry point: executed via `python -m circuitforge_core.stt.app`.
if __name__ == "__main__":
    main()
|
||||
4
circuitforge_core/stt/backends/__init__.py
Normal file
4
circuitforge_core/stt/backends/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# Re-export the backend surface so callers can import directly from
# circuitforge_core.stt.backends.
from .base import STTBackend, STTResult, STTSegment, make_stt_backend
from .mock import MockSTTBackend

__all__ = ["STTBackend", "STTResult", "STTSegment", "make_stt_backend", "MockSTTBackend"]
|
||||
109
circuitforge_core/stt/backends/base.py
Normal file
109
circuitforge_core/stt/backends/base.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
# circuitforge_core/stt/backends/base.py — STTBackend Protocol + factory
|
||||
#
|
||||
# MIT licensed. The Protocol and mock are always importable without GPU deps.
|
||||
# Real backends require optional extras:
|
||||
# pip install -e "circuitforge-core[stt-faster-whisper]"
|
||||
from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import ClassVar, Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ── Result types ──────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass(frozen=True)
class STTSegment:
    """Word- or phrase-level segment (included when the backend supports it)."""

    start_s: float     # segment start, seconds from the beginning of the audio
    end_s: float       # segment end, seconds
    text: str          # transcribed text for this span
    confidence: float  # backend-supplied confidence proxy (see backend module)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class STTResult:
    """
    Standard result from any STTBackend.transcribe() call.

    confidence is normalised to 0.0–1.0 regardless of the backend's native metric.
    below_threshold is True when confidence < the configured threshold (default 0.75).
    This flag is safety-critical for products like Osprey: DTMF must NOT be sent
    when below_threshold is True.
    """

    text: str
    confidence: float  # 0.0–1.0
    below_threshold: bool
    language: str | None = None
    duration_s: float | None = None
    segments: list[STTSegment] = field(default_factory=list)
    model: str = ""

    # BUG FIX: this was a plain annotated class attribute, which @dataclass
    # treats as a field — it leaked into __init__, __eq__ and __repr__ as an
    # extra per-instance slot. ClassVar keeps it the shared constant it was
    # always intended to be (STTResult.CONFIDENCE_DEFAULT_THRESHOLD still works).
    CONFIDENCE_DEFAULT_THRESHOLD: ClassVar[float] = 0.75
|
||||
|
||||
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@runtime_checkable
class STTBackend(Protocol):
    """
    Abstract interface for speech-to-text backends.

    All backends load their model once at construction time and are safe to
    call concurrently (the model weights are read-only after load).

    runtime_checkable permits isinstance() structural checks; note these only
    verify member presence, not signatures.
    """

    def transcribe(
        self,
        audio: bytes,
        *,
        language: str | None = None,
        confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    ) -> STTResult:
        """Synchronous transcription. audio is raw PCM or any format ffmpeg understands."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier for the loaded model (path stem or size name)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB. Used by cf-orch service registry."""
        ...
|
||||
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def make_stt_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
    device: str = "cuda",
    compute_type: str = "float16",
) -> STTBackend:
    """
    Build an STTBackend for *model_path*.

    mock=True (or CF_STT_MOCK=1 when mock is None) selects MockSTTBackend,
    which needs no GPU or model file.  Otherwise the backend name comes from
    *backend*, falling back to CF_STT_BACKEND, defaulting to "faster-whisper".
    device and compute_type are forwarded to the real backend and ignored by
    the mock.
    """
    if mock is None:
        mock = os.environ.get("CF_STT_MOCK", "") == "1"
    if mock:
        from circuitforge_core.stt.backends.mock import MockSTTBackend

        return MockSTTBackend(model_name=model_path)

    resolved = backend or os.environ.get("CF_STT_BACKEND", "faster-whisper")
    if resolved != "faster-whisper":
        raise ValueError(
            f"Unknown STT backend {resolved!r}. "
            "Expected 'faster-whisper'. Set CF_STT_BACKEND or pass backend= explicitly."
        )

    from circuitforge_core.stt.backends.faster_whisper import FasterWhisperBackend

    return FasterWhisperBackend(
        model_path=model_path, device=device, compute_type=compute_type
    )
|
||||
139
circuitforge_core/stt/backends/faster_whisper.py
Normal file
139
circuitforge_core/stt/backends/faster_whisper.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
# circuitforge_core/stt/backends/faster_whisper.py — FasterWhisperBackend
|
||||
#
|
||||
# MIT licensed. Requires: pip install -e "circuitforge-core[stt-faster-whisper]"
|
||||
#
|
||||
# Model path can be:
|
||||
# - A size name: "base", "small", "medium", "large-v3"
|
||||
# (faster-whisper downloads and caches it on first use)
|
||||
# - A local path: "/Library/Assets/LLM/whisper/models/Whisper/faster-whisper/..."
|
||||
# (preferred for air-gapped nodes — no download needed)
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from circuitforge_core.stt.backends.base import STTResult, STTSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# VRAM estimates by model size. Used by cf-orch for VRAM budgeting.
|
||||
_VRAM_MB_BY_SIZE: dict[str, int] = {
|
||||
"tiny": 200,
|
||||
"base": 350,
|
||||
"small": 600,
|
||||
"medium": 1024,
|
||||
"large": 2048,
|
||||
"large-v2": 2048,
|
||||
"large-v3": 2048,
|
||||
"distil-large-v3": 1500,
|
||||
}
|
||||
|
||||
# Aggregate confidence from per-segment no_speech_prob values.
|
||||
# faster-whisper doesn't expose a direct confidence score, so we invert the
|
||||
# mean no_speech_prob as a proxy. This is conservative but directionally correct.
|
||||
def _aggregate_confidence(segments: list) -> float:
|
||||
if not segments:
|
||||
return 0.0
|
||||
probs = [max(0.0, 1.0 - getattr(s, "no_speech_prob", 0.0)) for s in segments]
|
||||
return sum(probs) / len(probs)
|
||||
|
||||
|
||||
class FasterWhisperBackend:
    """
    faster-whisper STT backend.

    Thread-safe after construction: WhisperModel internally manages its own
    CUDA context and is safe to call from multiple threads.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        compute_type: str = "float16",
    ) -> None:
        """Load the model eagerly; raises ImportError when the extra is missing."""
        # Imported lazily so this module stays importable without GPU deps.
        try:
            from faster_whisper import WhisperModel
        except ImportError as exc:
            raise ImportError(
                "faster-whisper is not installed. "
                "Run: pip install -e 'circuitforge-core[stt-faster-whisper]'"
            ) from exc

        logger.info("Loading faster-whisper model from %r (device=%s)", model_path, device)
        self._model_path = model_path
        self._device = device
        self._compute_type = compute_type
        self._model = WhisperModel(model_path, device=device, compute_type=compute_type)
        logger.info("faster-whisper model ready")

        # Determine VRAM footprint from model name/path stem.
        # Substring match so both bare size names ("medium") and snapshot
        # directory names that contain the size are recognised.
        stem = os.path.basename(model_path.rstrip("/")).lower()
        self._vram_mb = next(
            (v for k, v in _VRAM_MB_BY_SIZE.items() if k in stem),
            1024,  # conservative default if size can't be inferred
        )

    def transcribe(
        self,
        audio: bytes,
        *,
        language: str | None = None,
        confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    ) -> STTResult:
        """
        Transcribe raw audio bytes.

        audio can be any format ffmpeg understands (WAV, MP3, OGG, FLAC, etc.).
        faster-whisper writes audio to a temp file internally; we follow the
        same pattern to avoid holding the bytes in memory longer than needed.
        """
        # delete=False so the file survives past the `with`; it is removed in
        # the finally block once the model has consumed it.
        with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp:
            tmp.write(audio)
            tmp_path = tmp.name

        try:
            segments_gen, info = self._model.transcribe(
                tmp_path,
                language=language,
                word_timestamps=True,
                vad_filter=True,
            )
            # Materialise the generator before the temp file is unlinked:
            # decoding happens as the generator is consumed.
            segments = list(segments_gen)
        finally:
            os.unlink(tmp_path)

        text = " ".join(s.text.strip() for s in segments).strip()
        confidence = _aggregate_confidence(segments)
        duration_s = info.duration if hasattr(info, "duration") else None
        # Prefer the model-detected language; fall back to the caller's hint.
        detected_language = getattr(info, "language", language)

        stt_segments = [
            STTSegment(
                start_s=s.start,
                end_s=s.end,
                text=s.text.strip(),
                # Same (1 - no_speech_prob) proxy used by _aggregate_confidence.
                confidence=max(0.0, 1.0 - getattr(s, "no_speech_prob", 0.0)),
            )
            for s in segments
        ]

        return STTResult(
            text=text,
            confidence=confidence,
            below_threshold=confidence < confidence_threshold,
            language=detected_language,
            duration_s=duration_s,
            segments=stt_segments,
            model=self._model_path,
        )

    @property
    def model_name(self) -> str:
        """Model path or size name exactly as given at construction."""
        return self._model_path

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint (MB), inferred from the model size."""
        return self._vram_mb
|
||||
54
circuitforge_core/stt/backends/mock.py
Normal file
54
circuitforge_core/stt/backends/mock.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# circuitforge_core/stt/backends/mock.py — MockSTTBackend
|
||||
#
|
||||
# MIT licensed. No GPU, no model file required.
|
||||
# Used in tests and CI, and when CF_STT_MOCK=1.
|
||||
from __future__ import annotations
|
||||
|
||||
from circuitforge_core.stt.backends.base import STTBackend, STTResult
|
||||
|
||||
|
||||
class MockSTTBackend:
    """
    Deterministic mock STT backend for testing.

    Always returns the configured canned transcript and confidence, so tests
    can assert on the response shape without needing a GPU or a model file.
    """

    def __init__(
        self,
        model_name: str = "mock",
        fixed_text: str = "mock transcription",
        fixed_confidence: float = 0.95,
    ) -> None:
        self._model_name = model_name
        self._fixed_text = fixed_text
        self._fixed_confidence = fixed_confidence

    def transcribe(
        self,
        audio: bytes,
        *,
        language: str | None = None,
        confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    ) -> STTResult:
        """Return the canned result; duration is estimated from byte length."""
        # 16 kHz * 16-bit mono => 32000 bytes per second of audio.
        estimated_duration = float(len(audio)) / 32000
        flagged = self._fixed_confidence < confidence_threshold
        return STTResult(
            text=self._fixed_text,
            confidence=self._fixed_confidence,
            below_threshold=flagged,
            language=language or "en",
            duration_s=estimated_duration,
            model=self._model_name,
        )

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        # The mock holds no model, so it occupies no VRAM.
        return 0
|
||||
|
||||
|
||||
# Structural sanity check at import time: MockSTTBackend must satisfy the
# STTBackend Protocol (no GPU needed).  NOTE(review): `assert` statements are
# stripped under `python -O`, so this guard only fires in normal dev/test runs.
assert isinstance(MockSTTBackend(), STTBackend)
|
||||
144
circuitforge_core/text/__init__.py
Normal file
144
circuitforge_core/text/__init__.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
circuitforge_core.text — direct text generation service module.
|
||||
|
||||
Provides lightweight, low-overhead text generation that bypasses ollama/vllm
|
||||
for products that need fast, frequent inference from small local models.
|
||||
|
||||
Quick start (mock mode — no model required):
|
||||
|
||||
import os; os.environ["CF_TEXT_MOCK"] = "1"
|
||||
from circuitforge_core.text import generate, chat, ChatMessage
|
||||
|
||||
result = generate("Write a short cover letter intro.")
|
||||
print(result.text)
|
||||
|
||||
reply = chat([
|
||||
ChatMessage("system", "You are a helpful recipe assistant."),
|
||||
ChatMessage("user", "What can I make with eggs, spinach, and feta?"),
|
||||
])
|
||||
print(reply.text)
|
||||
|
||||
Real inference (GGUF model):
|
||||
|
||||
export CF_TEXT_MODEL=/Library/Assets/LLM/qwen2.5-3b-instruct-q4_k_m.gguf
|
||||
from circuitforge_core.text import generate
|
||||
result = generate("Summarise this job posting in 2 sentences: ...")
|
||||
|
||||
Backend selection (CF_TEXT_BACKEND env or explicit):
|
||||
|
||||
from circuitforge_core.text import make_backend
|
||||
backend = make_backend("/path/to/model.gguf", backend="llamacpp")
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-text
|
||||
max_mb: per-model (3B Q4 ≈ 2048, 7B Q4 ≈ 4096)
|
||||
preferred_compute: 7.5 minimum (INT8 tensor cores)
|
||||
max_concurrent: 2
|
||||
shared: true
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.text.backends.base import (
|
||||
ChatMessage,
|
||||
GenerateResult,
|
||||
TextBackend,
|
||||
make_text_backend,
|
||||
)
|
||||
from circuitforge_core.text.backends.mock import MockTextBackend
|
||||
|
||||
# ── Process-level singleton backend ──────────────────────────────────────────
|
||||
# Lazily initialised on first call to generate() or chat().
|
||||
# Products that need per-user or per-request backends should use make_backend().
|
||||
|
||||
_backend: TextBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> TextBackend:
    """Return the lazily-created process-wide text backend."""
    global _backend
    if _backend is not None:
        return _backend
    model_path = os.environ.get("CF_TEXT_MODEL", "mock")
    # Mock mode when no real model is configured or explicitly requested.
    use_mock = model_path == "mock" or os.environ.get("CF_TEXT_MOCK", "") == "1"
    _backend = make_text_backend(model_path, mock=use_mock)
    return _backend
|
||||
|
||||
|
||||
def make_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
) -> TextBackend:
    """
    Build a dedicated TextBackend for *model_path*.

    Prefer this over the module-level generate()/chat() singleton when a
    per-request or per-user backend is required.
    """
    return make_text_backend(model_path, backend=backend, mock=mock)
|
||||
|
||||
|
||||
# ── Convenience functions (singleton path) ────────────────────────────────────
|
||||
|
||||
|
||||
def generate(
    prompt: str,
    *,
    model: str | None = None,
    max_tokens: int = 512,
    temperature: float = 0.7,
    stream: bool = False,
    stop: list[str] | None = None,
):
    """
    Generate text from *prompt* via the process-level backend.

    With stream=True an Iterator[str] of tokens is returned instead of a
    GenerateResult.  *model* exists only for API symmetry with LLMRouter and
    is ignored on this path — set CF_TEXT_MODEL to change the loaded model.
    """
    backend = _get_backend()
    kwargs = {"max_tokens": max_tokens, "temperature": temperature, "stop": stop}
    if stream:
        return backend.generate_stream(prompt, **kwargs)
    return backend.generate(prompt, **kwargs)
|
||||
|
||||
|
||||
def chat(
    messages: list[ChatMessage],
    *,
    model: str | None = None,
    max_tokens: int = 512,
    temperature: float = 0.7,
    stream: bool = False,
) -> GenerateResult:
    """
    Chat completion via the process-level backend.

    *messages* is a list of ChatMessage(role, content) objects.  Streaming is
    not yet implemented on this path, so stream must remain False.
    """
    if stream:
        raise NotImplementedError(
            "stream=True is not yet supported for chat(). "
            "Use generate_stream() directly on a backend instance."
        )
    backend = _get_backend()
    return backend.chat(messages, max_tokens=max_tokens, temperature=temperature)
|
||||
|
||||
|
||||
def reset_backend() -> None:
    """Drop the cached process-level backend (intended for test teardown)."""
    global _backend
    _backend = None
|
||||
|
||||
|
||||
# Public API of circuitforge_core.text (the star-import surface).
__all__ = [
    "ChatMessage",
    "GenerateResult",
    "TextBackend",
    "MockTextBackend",
    "make_backend",
    "generate",
    "chat",
    "reset_backend",
]
|
||||
226
circuitforge_core/text/app.py
Normal file
226
circuitforge_core/text/app.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
cf-text FastAPI service — managed by cf-orch.
|
||||
|
||||
Lightweight local text generation. Supports GGUF models via llama.cpp and
|
||||
HuggingFace transformers. Sits alongside vllm/ollama for products that need
|
||||
fast, frequent inference from small local models (3B–7B Q4).
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": str, "vram_mb": int, "backend": str}
|
||||
POST /generate → GenerateResponse
|
||||
POST /chat → GenerateResponse
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.text.app \
|
||||
--model /Library/Assets/LLM/qwen2.5-3b-instruct-q4_k_m.gguf \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
Mock mode (no model or GPU required):
|
||||
CF_TEXT_MOCK=1 python -m circuitforge_core.text.app --port 8006
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage
|
||||
from circuitforge_core.text.backends.base import make_text_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_backend = None
|
||||
|
||||
|
||||
# ── Request / response models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Body for POST /generate."""

    prompt: str
    max_tokens: int = 512
    temperature: float = 0.7
    stop: list[str] | None = None  # optional stop sequences


class ChatMessageModel(BaseModel):
    """One chat turn; role is expected to be system/user/assistant."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Body for POST /chat."""

    messages: list[ChatMessageModel]
    max_tokens: int = 512
    temperature: float = 0.7


class GenerateResponse(BaseModel):
    """Response shape shared by /generate and /chat."""

    text: str
    tokens_used: int = 0
    model: str = ""
|
||||
|
||||
|
||||
# ── OpenAI-compat request / response (for LLMRouter openai_compat path) ──────
|
||||
|
||||
|
||||
class OAIMessageModel(BaseModel):
    """OpenAI-style chat message."""

    role: str
    content: str


class OAIChatRequest(BaseModel):
    """Subset of the OpenAI chat.completions request body honoured by cf-text."""

    model: str = "cf-text"
    messages: list[OAIMessageModel]
    max_tokens: int | None = None
    temperature: float = 0.7
    stream: bool = False  # accepted for compatibility; not acted on here


class OAIChoice(BaseModel):
    """A single completion choice."""

    index: int = 0
    message: OAIMessageModel
    finish_reason: str = "stop"


class OAIUsage(BaseModel):
    """Token accounting in the OpenAI response shape."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class OAIChatResponse(BaseModel):
    """OpenAI-compatible chat.completion response envelope."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[OAIChoice]
    usage: OAIUsage
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_app(
    model_path: str,
    gpu_id: int = 0,
    backend: str | None = None,
    mock: bool = False,
) -> FastAPI:
    """Build the cf-text FastAPI application.

    Args:
        model_path: GGUF path or HF model ID ("mock" in mock mode).
        gpu_id: CUDA device index, exported via CUDA_VISIBLE_DEVICES.
        backend: "llamacpp" or "transformers"; None lets the factory choose.
        mock: use the mock backend (no model or GPU needed).

    NOTE(review): this assigns the module-level _backend singleton, so a
    second create_app() call in the same process replaces the first app's
    backend — fine for the one-app-per-process service model.
    """
    global _backend

    os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(gpu_id))

    _backend = make_text_backend(model_path, backend=backend, mock=mock)
    logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)

    app = FastAPI(title="cf-text", version="0.1.0")

    @app.get("/health")
    def health() -> dict:
        if _backend is None:
            raise HTTPException(503, detail="backend not initialised")
        return {
            "status": "ok",
            "model": _backend.model_name,
            "vram_mb": _backend.vram_mb,
        }

    @app.post("/generate")
    async def generate(req: GenerateRequest) -> GenerateResponse:
        if _backend is None:
            raise HTTPException(503, detail="backend not initialised")
        result = await _backend.generate_async(
            req.prompt,
            max_tokens=req.max_tokens,
            temperature=req.temperature,
            stop=req.stop,
        )
        return GenerateResponse(
            text=result.text,
            tokens_used=result.tokens_used,
            model=result.model,
        )

    @app.post("/chat")
    async def chat(req: ChatRequest) -> GenerateResponse:
        if _backend is None:
            raise HTTPException(503, detail="backend not initialised")
        messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
        # chat() is sync-only in the Protocol; run in thread pool to avoid blocking.
        # BUG FIX: asyncio.get_event_loop() is deprecated inside coroutines
        # (Python 3.10+); get_running_loop() is the correct call here.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            partial(_backend.chat, messages,
                    max_tokens=req.max_tokens, temperature=req.temperature),
        )
        return GenerateResponse(
            text=result.text,
            tokens_used=result.tokens_used,
            model=result.model,
        )

    @app.post("/v1/chat/completions")
    async def oai_chat_completions(req: OAIChatRequest) -> OAIChatResponse:
        """OpenAI-compatible chat completions endpoint.

        Allows LLMRouter (and any openai_compat client) to use cf-text
        without a custom backend type — just set base_url to this service's
        /v1 prefix.
        """
        if _backend is None:
            raise HTTPException(503, detail="backend not initialised")
        messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
        max_tok = req.max_tokens or 512
        # BUG FIX: same get_event_loop() -> get_running_loop() change as /chat.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature),
        )
        return OAIChatResponse(
            id=f"cftext-{uuid.uuid4().hex[:12]}",
            created=int(time.time()),
            model=result.model or req.model,
            choices=[OAIChoice(message=OAIMessageModel(role="assistant", content=result.text))],
            usage=OAIUsage(completion_tokens=result.tokens_used, total_tokens=result.tokens_used),
        )

    return app
|
||||
|
||||
|
||||
# ── CLI entrypoint ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="cf-text inference server")
|
||||
parser.add_argument("--model", default=os.environ.get("CF_TEXT_MODEL", "mock"),
|
||||
help="Path to GGUF file or HF model ID")
|
||||
parser.add_argument("--port", type=int, default=8006)
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--gpu-id", type=int, default=0,
|
||||
help="CUDA device index to use")
|
||||
parser.add_argument("--backend", choices=["llamacpp", "transformers"], default=None)
|
||||
parser.add_argument("--mock", action="store_true",
|
||||
help="Run in mock mode (no model or GPU needed)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# Script entry point: `python -m circuitforge_core.text.app`.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(name)s — %(message)s")
    args = _parse_args()
    # Mock mode wins when requested by flag, env var, or the sentinel model name.
    mock = args.mock or os.environ.get("CF_TEXT_MOCK", "") == "1" or args.model == "mock"
    app = create_app(
        model_path=args.model,
        gpu_id=args.gpu_id,
        backend=args.backend,
        mock=mock,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
10
circuitforge_core/text/backends/__init__.py
Normal file
10
circuitforge_core/text/backends/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# Re-export the text backend surface so callers can import directly from
# circuitforge_core.text.backends.
from .base import ChatMessage, GenerateResult, TextBackend, make_text_backend
from .mock import MockTextBackend

__all__ = [
    "ChatMessage",
    "GenerateResult",
    "TextBackend",
    "MockTextBackend",
    "make_text_backend",
]
|
||||
182
circuitforge_core/text/backends/base.py
Normal file
182
circuitforge_core/text/backends/base.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
# circuitforge_core/text/backends/base.py — TextBackend Protocol + factory
|
||||
#
|
||||
# MIT licensed. The Protocol and mock backend are always importable.
|
||||
# Real backends (LlamaCppBackend, TransformersBackend) require optional extras.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import AsyncIterator, Iterator, Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ── Shared result types ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GenerateResult:
    """Result from a single non-streaming generate() call.

    Attributes: text (the generated string), tokens_used (completion token
    count, 0 when unknown), model (identifier of the producing model).
    """

    def __init__(self, text: str, tokens_used: int = 0, model: str = "") -> None:
        self.model = model
        self.tokens_used = tokens_used
        self.text = text

    def __repr__(self) -> str:
        # `!r:.40` truncates the repr of long generations to 40 characters.
        return f"GenerateResult(text={self.text!r:.40}, tokens={self.tokens_used})"
|
||||
|
||||
|
||||
class ChatMessage:
    """A single message in a chat conversation (role + content)."""

    def __init__(self, role: str, content: str) -> None:
        # Reject anything outside the OpenAI-style role vocabulary up front.
        if role in ("system", "user", "assistant"):
            self.role = role
            self.content = content
        else:
            raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.")

    def to_dict(self) -> dict:
        """Return the OpenAI-style {'role', 'content'} mapping."""
        return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
# ── TextBackend Protocol ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@runtime_checkable
class TextBackend(Protocol):
    """
    Abstract interface for direct text generation backends.

    All generate/chat methods have both sync and async variants.
    Streaming variants yield str tokens rather than a complete result.

    Implementations must be safe to construct once and call concurrently
    (the model is loaded at construction time and reused across calls).
    """

    # NOTE: @runtime_checkable isinstance() checks only verify that the
    # member names exist — signatures are not validated at runtime.

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Synchronous generate — blocks until the full response is produced."""
        ...

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        """Synchronous streaming — yields tokens as they are produced."""
        ...

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Async generate — runs in thread pool, never blocks the event loop."""
        ...

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Async streaming — yields tokens without blocking the event loop."""
        ...

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        """Chat completion — formats messages into a prompt and generates."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier for the loaded model (path stem or HF repo ID)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB. Used by cf-orch service registry."""
        ...
|
||||
|
||||
|
||||
# ── Backend selection ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _select_backend(model_path: str, backend: str | None) -> str:
    """
    Return "llamacpp" or "transformers" for the given model path.

    Parameters
    ----------
    model_path   Path to the model file or HuggingFace repo ID (e.g. "Qwen/Qwen2.5-3B").
    backend      Explicit override from the caller ("llamacpp" | "transformers" | None).
                 When provided, trust it without inspection.

    Return "llamacpp" or "transformers". Raise ValueError for unrecognised values.
    """
    _VALID = ("llamacpp", "transformers")

    # 1. Caller-supplied override — highest trust, no inspection needed.
    #    Fall back to the CF_TEXT_BACKEND environment variable.
    source = "backend argument" if backend else "CF_TEXT_BACKEND"
    resolved = backend or os.environ.get("CF_TEXT_BACKEND")
    if resolved:
        if resolved not in _VALID:
            # BUGFIX: the old message always blamed CF_TEXT_BACKEND even when
            # the invalid value came from the explicit `backend` argument,
            # sending users to the wrong knob.
            raise ValueError(
                f"{source}={resolved!r} is not valid. Choose: {', '.join(_VALID)}"
            )
        return resolved

    # 2. Format detection — GGUF files are unambiguously llama-cpp territory.
    if model_path.lower().endswith(".gguf"):
        return "llamacpp"

    # 3. Safe default — transformers covers HF repo IDs and safetensors dirs.
    return "transformers"
|
||||
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_text_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
) -> "TextBackend":
    """
    Return a TextBackend for the given model.

    mock=True or CF_TEXT_MOCK=1 → MockTextBackend (no GPU, no model file needed)
    Otherwise → backend resolved via _select_backend()
    """
    # Explicit argument wins; otherwise consult the environment flag.
    if mock is None:
        use_mock = os.environ.get("CF_TEXT_MOCK", "") == "1"
    else:
        use_mock = mock

    if use_mock:
        from circuitforge_core.text.backends.mock import MockTextBackend
        return MockTextBackend(model_name=model_path)

    name = _select_backend(model_path, backend)
    if name == "llamacpp":
        from circuitforge_core.text.backends.llamacpp import LlamaCppBackend
        return LlamaCppBackend(model_path=model_path)
    if name == "transformers":
        from circuitforge_core.text.backends.transformers import TransformersBackend
        return TransformersBackend(model_path=model_path)

    raise ValueError(f"Unknown backend {name!r}. Expected 'llamacpp' or 'transformers'.")
|
||||
192
circuitforge_core/text/backends/llamacpp.py
Normal file
192
circuitforge_core/text/backends/llamacpp.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
# circuitforge_core/text/backends/llamacpp.py — llama-cpp-python backend
|
||||
#
|
||||
# BSL 1.1: real inference. Requires llama-cpp-python + a GGUF model file.
|
||||
# Install: pip install circuitforge-core[text-llamacpp]
|
||||
#
|
||||
# VRAM estimates (Q4_K_M quant):
|
||||
# 1B → ~700MB 3B → ~2048MB 7B → ~4096MB
|
||||
# 13B → ~7500MB 70B → ~40000MB
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Q4_K_M is the recommended default — best accuracy/size tradeoff for local use.
|
||||
_DEFAULT_N_CTX = int(os.environ.get("CF_TEXT_CTX", "4096"))
|
||||
_DEFAULT_N_GPU_LAYERS = int(os.environ.get("CF_TEXT_GPU_LAYERS", "-1")) # -1 = all layers
|
||||
|
||||
|
||||
def _estimate_vram_mb(model_path: str) -> int:
    """Rough VRAM estimate from file size. Accurate enough for cf-orch budgeting."""
    try:
        file_bytes = Path(model_path).stat().st_size
    except OSError:
        # File missing/unreadable — fall back to a conservative default.
        return 4096
    # GGUF models typically need ~1.1× file size in VRAM (KV cache overhead)
    return int((file_bytes // (1024 * 1024)) * 1.1)
|
||||
|
||||
|
||||
class LlamaCppBackend:
    """
    Direct llama-cpp-python inference backend for GGUF models.

    The model is loaded once at construction. All inference runs in a thread
    pool executor so async callers never block the event loop.

    Context window, GPU layers, and thread count are configurable via env:
        CF_TEXT_CTX         token context window (default 4096)
        CF_TEXT_GPU_LAYERS  GPU layers to offload, -1 = all (default -1)
        CF_TEXT_THREADS     CPU thread count (default: auto)

    Requires: pip install circuitforge-core[text-llamacpp]
    """

    def __init__(self, model_path: str) -> None:
        """Load the GGUF model at *model_path*.

        Raises ImportError when llama-cpp-python is not installed and
        FileNotFoundError when the model file does not exist.
        """
        try:
            from llama_cpp import Llama  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "llama-cpp-python is required for LlamaCppBackend. "
                "Install with: pip install circuitforge-core[text-llamacpp]"
            ) from exc

        if not Path(model_path).exists():
            raise FileNotFoundError(
                f"GGUF model not found: {model_path}\n"
                "Download a GGUF model and set CF_TEXT_MODEL to its path."
            )

        # CF_TEXT_THREADS=0 (or unset) means "let llama.cpp pick".
        n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None
        logger.info(
            "Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
            model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
        )
        self._llm = Llama(
            model_path=model_path,
            n_ctx=_DEFAULT_N_CTX,
            n_gpu_layers=_DEFAULT_N_GPU_LAYERS,
            n_threads=n_threads,
            verbose=False,
        )
        self._model_path = model_path
        self._vram_mb = _estimate_vram_mb(model_path)

    @property
    def model_name(self) -> str:
        """Model identifier: the GGUF file name without its extension."""
        return Path(self._model_path).stem

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB (estimated from file size)."""
        return self._vram_mb

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Blocking single-shot completion for *prompt*."""
        output = self._llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop or [],
            stream=False,
        )
        text = output["choices"][0]["text"]
        tokens_used = output["usage"]["completion_tokens"]
        return GenerateResult(text=text, tokens_used=tokens_used, model=self.model_name)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        """Blocking streaming completion — yields text chunks as produced."""
        for chunk in self._llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop or [],
            stream=True,
        ):
            yield chunk["choices"][0]["text"]

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Async generate — runs generate() in the default executor."""
        # BUGFIX: asyncio.get_event_loop() is deprecated inside coroutines
        # (Python 3.10+); get_running_loop() is the supported call here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop),
        )

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Async streaming — bridges the synchronous llama_cpp stream via a queue."""
        # llama_cpp streaming is synchronous — run in a producer thread and
        # re-emit tokens on the event loop.
        import queue
        import threading

        token_queue: queue.Queue = queue.Queue()
        _DONE = object()  # sentinel marking end-of-stream (also on error)

        def _produce() -> None:
            try:
                for chunk in self._llm(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop or [],
                    stream=True,
                ):
                    token_queue.put(chunk["choices"][0]["text"])
            finally:
                token_queue.put(_DONE)

        thread = threading.Thread(target=_produce, daemon=True)
        thread.start()

        # BUGFIX: get_running_loop() instead of deprecated get_event_loop().
        loop = asyncio.get_running_loop()
        while True:
            token = await loop.run_in_executor(None, token_queue.get)
            if token is _DONE:
                break
            yield token

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        """Chat completion via llama-cpp-python's native chat API."""
        # llama-cpp-python has native chat_completion for instruct models
        output = self._llm.create_chat_completion(
            messages=[m.to_dict() for m in messages],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        text = output["choices"][0]["message"]["content"]
        tokens_used = output["usage"]["completion_tokens"]
        return GenerateResult(text=text, tokens_used=tokens_used, model=self.model_name)
|
||||
104
circuitforge_core/text/backends/mock.py
Normal file
104
circuitforge_core/text/backends/mock.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
# circuitforge_core/text/backends/mock.py — synthetic text backend
|
||||
#
|
||||
# MIT licensed. No model file, no GPU, no extras required.
|
||||
# Used in dev, CI, and free-tier nodes below the minimum VRAM threshold.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
_MOCK_RESPONSE = (
    "This is a synthetic response from MockTextBackend. "
    "Install a real backend (llama-cpp-python or transformers) and provide a model path "
    "to generate real text."
)


class MockTextBackend:
    """
    Fixed-output text backend for development and CI.

    Every call produces the same canned response, which keeps tests
    reproducible without a GPU or model file. The streaming variants emit
    the response one word at a time; set ``token_delay_s`` to pace them so
    UI streaming paths can be exercised.
    """

    def __init__(
        self,
        model_name: str = "mock",
        token_delay_s: float = 0.0,
    ) -> None:
        self._model_name = model_name
        self._token_delay_s = token_delay_s

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        return 0

    def _response_for(self, prompt_or_messages: str) -> str:
        # Input is deliberately ignored — determinism is the whole point.
        return _MOCK_RESPONSE

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        reply = self._response_for(prompt)
        return GenerateResult(text=reply, tokens_used=len(reply.split()), model=self._model_name)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        import time

        delay = self._token_delay_s
        for piece in self._response_for(prompt).split():
            yield piece + " "
            if delay:
                time.sleep(delay)

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        return self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        delay = self._token_delay_s
        for piece in self._response_for(prompt).split():
            yield piece + " "
            if delay:
                await asyncio.sleep(delay)

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        # Flatten the conversation into a single prompt; the mock ignores
        # content anyway, but this exercises the same code path as generate().
        transcript = "\n".join(f"{m.role}: {m.content}" for m in messages)
        return self.generate(transcript, max_tokens=max_tokens, temperature=temperature)
|
||||
197
circuitforge_core/text/backends/transformers.py
Normal file
197
circuitforge_core/text/backends/transformers.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
# circuitforge_core/text/backends/transformers.py — HuggingFace transformers backend
|
||||
#
|
||||
# BSL 1.1: real inference. Requires torch + transformers + a model checkpoint.
|
||||
# Install: pip install circuitforge-core[text-transformers]
|
||||
#
|
||||
# Best for: HF repo IDs, safetensors checkpoints, models without GGUF versions.
|
||||
# For GGUF models prefer LlamaCppBackend — lower overhead, smaller install.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MAX_NEW_TOKENS = 512
|
||||
_LOAD_IN_4BIT = os.environ.get("CF_TEXT_4BIT", "0") == "1"
|
||||
_LOAD_IN_8BIT = os.environ.get("CF_TEXT_8BIT", "0") == "1"
|
||||
|
||||
|
||||
class TransformersBackend:
    """
    HuggingFace transformers inference backend.

    Loads any causal LM available on HuggingFace Hub or a local checkpoint dir.
    Supports 4-bit and 8-bit quantization via bitsandbytes when VRAM is limited:
        CF_TEXT_4BIT=1 — load_in_4bit (requires bitsandbytes)
        CF_TEXT_8BIT=1 — load_in_8bit (requires bitsandbytes)

    Chat completion uses the tokenizer's apply_chat_template() when available,
    falling back to a simple "User: / Assistant:" prompt format.

    NOTE(review): the `stop` parameter is accepted for interface parity but is
    not currently applied by this backend — confirm whether StoppingCriteria
    support is required.

    Requires: pip install circuitforge-core[text-transformers]
    """

    def __init__(self, model_path: str) -> None:
        """Load tokenizer and model; device and quantization come from env/CUDA."""
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        except ImportError as exc:
            raise ImportError(
                "torch and transformers are required for TransformersBackend. "
                "Install with: pip install circuitforge-core[text-transformers]"
            ) from exc

        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Loading transformers model %s on %s", model_path, self._device)

        load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
        # Quantized loads are mutually exclusive; 4-bit wins when both are set.
        if _LOAD_IN_4BIT:
            load_kwargs["load_in_4bit"] = True
        elif _LOAD_IN_8BIT:
            load_kwargs["load_in_8bit"] = True

        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
        if self._device == "cpu":
            self._model = self._model.to("cpu")

        self._model_path = model_path
        # Stash the streamer class so generate_stream() needs no re-import.
        self._TextIteratorStreamer = TextIteratorStreamer

    @property
    def model_name(self) -> str:
        # HF repo IDs contain "/" — use the part after the slash as a short name
        return self._model_path.split("/")[-1]

    @property
    def vram_mb(self) -> int:
        """Currently-allocated CUDA memory in MB, or 0 when CUDA is unavailable."""
        try:
            import torch
            if torch.cuda.is_available():
                return torch.cuda.memory_allocated() // (1024 * 1024)
        except Exception:
            pass
        return 0

    def _build_inputs(self, prompt: str):
        """Tokenize *prompt* and move the tensors to the active device."""
        return self._tokenizer(prompt, return_tensors="pt").to(self._device)

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Blocking single-shot completion; returns only the newly generated text."""
        inputs = self._build_inputs(prompt)
        input_len = inputs["input_ids"].shape[1]
        outputs = self._model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=self._tokenizer.eos_token_id,
        )
        # Slice off the prompt tokens so only the completion is decoded.
        new_tokens = outputs[0][input_len:]
        text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)
        return GenerateResult(text=text, tokens_used=len(new_tokens), model=self.model_name)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        """Blocking streaming via TextIteratorStreamer + a generator thread."""
        import threading

        inputs = self._build_inputs(prompt)
        streamer = self._TextIteratorStreamer(
            self._tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            streamer=streamer,
            pad_token_id=self._tokenizer.eos_token_id,
        )
        thread = threading.Thread(target=self._model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        yield from streamer

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Async generate — runs generate() in the default executor."""
        # BUGFIX: asyncio.get_event_loop() is deprecated inside coroutines
        # (Python 3.10+); get_running_loop() is the supported call here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop),
        )

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Async streaming — bridges the sync stream via a queue and executor."""
        import queue
        import threading

        token_queue: queue.Queue = queue.Queue()
        _DONE = object()  # end-of-stream sentinel (also enqueued on error)

        def _produce() -> None:
            try:
                for token in self.generate_stream(
                    prompt, max_tokens=max_tokens, temperature=temperature
                ):
                    token_queue.put(token)
            finally:
                token_queue.put(_DONE)

        threading.Thread(target=_produce, daemon=True).start()
        # BUGFIX: get_running_loop() instead of deprecated get_event_loop().
        loop = asyncio.get_running_loop()
        while True:
            token = await loop.run_in_executor(None, token_queue.get)
            if token is _DONE:
                break
            yield token

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        """Chat completion built on generate(); prompt shape depends on the tokenizer."""
        # Use the tokenizer's chat template when available (instruct models)
        if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template:
            prompt = self._tokenizer.apply_chat_template(
                [m.to_dict() for m in messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            # Fallback: drop system messages and use a plain two-role script.
            prompt = "\n".join(
                f"{'User' if m.role == 'user' else 'Assistant'}: {m.content}"
                for m in messages
                if m.role != "system"
            ) + "\nAssistant:"

        return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
||||
87
circuitforge_core/tts/__init__.py
Normal file
87
circuitforge_core/tts/__init__.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""
|
||||
circuitforge_core.tts — Text-to-speech service module.
|
||||
|
||||
Quick start (mock mode — no GPU or model required):
|
||||
|
||||
import os; os.environ["CF_TTS_MOCK"] = "1"
|
||||
from circuitforge_core.tts import synthesize
|
||||
|
||||
result = synthesize("Hello world")
|
||||
open("out.ogg", "wb").write(result.audio_bytes)
|
||||
|
||||
Real inference (chatterbox-turbo):
|
||||
|
||||
export CF_TTS_MODEL=/Library/Assets/LLM/chatterbox/hub/models--ResembleAI--chatterbox-turbo/snapshots/<hash>
|
||||
from circuitforge_core.tts import synthesize
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-tts
|
||||
max_mb: 768
|
||||
max_concurrent: 1
|
||||
shared: true
|
||||
managed:
|
||||
exec: python -m circuitforge_core.tts.app
|
||||
args: --model <path> --port {port} --gpu-id {gpu_id}
|
||||
port: 8005
|
||||
health: /health
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.tts.backends.base import (
|
||||
AudioFormat,
|
||||
TTSBackend,
|
||||
TTSResult,
|
||||
make_tts_backend,
|
||||
)
|
||||
from circuitforge_core.tts.backends.mock import MockTTSBackend
|
||||
|
||||
_backend: TTSBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> TTSBackend:
    """Return the process-level backend, constructing it on first use."""
    global _backend
    if _backend is not None:
        return _backend
    # "mock" as the model path (or CF_TTS_MOCK=1) selects the GPU-free backend.
    model_path = os.environ.get("CF_TTS_MODEL", "mock")
    use_mock = model_path == "mock" or os.environ.get("CF_TTS_MOCK", "") == "1"
    _backend = make_tts_backend(model_path, mock=use_mock)
    return _backend
|
||||
|
||||
|
||||
def synthesize(
    text: str,
    *,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    audio_prompt: bytes | None = None,
    format: AudioFormat = "ogg",
) -> TTSResult:
    """Synthesize speech from text using the process-level backend."""
    options = {
        "exaggeration": exaggeration,
        "cfg_weight": cfg_weight,
        "temperature": temperature,
        "audio_prompt": audio_prompt,
        "format": format,
    }
    return _get_backend().synthesize(text, **options)
|
||||
|
||||
|
||||
def reset_backend() -> None:
    """Drop the cached process-level backend so the next call rebuilds it.

    Intended for test teardown only.
    """
    global _backend
    _backend = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AudioFormat",
|
||||
"TTSBackend",
|
||||
"TTSResult",
|
||||
"MockTTSBackend",
|
||||
"make_tts_backend",
|
||||
"synthesize",
|
||||
"reset_backend",
|
||||
]
|
||||
103
circuitforge_core/tts/app.py
Normal file
103
circuitforge_core/tts/app.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""
|
||||
cf-tts FastAPI service — managed by cf-orch.
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": str, "vram_mb": int}
|
||||
POST /synthesize → audio bytes (Content-Type: audio/ogg or audio/wav or audio/mpeg)
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.tts.app \
|
||||
--model /Library/Assets/LLM/chatterbox/hub/models--ResembleAI--chatterbox-turbo/snapshots/<hash> \
|
||||
--port 8005 \
|
||||
--gpu-id 0
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from fastapi import FastAPI, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from circuitforge_core.tts.backends.base import AudioFormat, TTSBackend, make_tts_backend
|
||||
|
||||
_CONTENT_TYPES: dict[str, str] = {
|
||||
"ogg": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
}
|
||||
|
||||
app = FastAPI(title="cf-tts")
|
||||
_backend: TTSBackend | None = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
return {"status": "ok", "model": _backend.model_name, "vram_mb": _backend.vram_mb}
|
||||
|
||||
|
||||
@app.post("/synthesize")
|
||||
async def synthesize(
|
||||
text: Annotated[str, Form()],
|
||||
format: Annotated[AudioFormat, Form()] = "ogg",
|
||||
exaggeration: Annotated[float, Form()] = 0.5,
|
||||
cfg_weight: Annotated[float, Form()] = 0.5,
|
||||
temperature: Annotated[float, Form()] = 0.8,
|
||||
audio_prompt: UploadFile | None = None,
|
||||
) -> Response:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
if not text.strip():
|
||||
raise HTTPException(422, detail="text must not be empty")
|
||||
|
||||
prompt_bytes: bytes | None = None
|
||||
if audio_prompt is not None:
|
||||
prompt_bytes = await audio_prompt.read()
|
||||
|
||||
result = _backend.synthesize(
|
||||
text,
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature,
|
||||
audio_prompt=prompt_bytes,
|
||||
format=format,
|
||||
)
|
||||
return Response(
|
||||
content=result.audio_bytes,
|
||||
media_type=_CONTENT_TYPES.get(result.format, "audio/ogg"),
|
||||
headers={
|
||||
"X-Duration-S": str(round(result.duration_s, 3)),
|
||||
"X-Model": result.model,
|
||||
"X-Sample-Rate": str(result.sample_rate),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse the CLI flags cf-orch passes when launching this service."""
    parser = argparse.ArgumentParser(description="cf-tts service")
    parser.add_argument("--model", required=True)
    parser.add_argument("--port", type=int, default=8005)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    parser.add_argument("--mock", action="store_true")
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
args = _parse_args()
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
|
||||
|
||||
mock = args.mock or args.model == "mock"
|
||||
device = "cpu" if mock else "cuda"
|
||||
|
||||
global _backend
|
||||
_backend = make_tts_backend(args.model, mock=mock, device=device)
|
||||
print(f"cf-tts backend ready: {_backend.model_name} ({_backend.vram_mb} MB)")
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
4
circuitforge_core/tts/backends/__init__.py
Normal file
4
circuitforge_core/tts/backends/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""Re-exports for ``circuitforge_core.tts.backends``.

ChatterboxTurboBackend is deliberately not imported here — it requires the
chatterbox-tts extra and is loaded lazily by ``make_tts_backend()``.
"""
from .base import AudioFormat, TTSBackend, TTSResult, make_tts_backend
from .mock import MockTTSBackend

__all__ = ["AudioFormat", "TTSBackend", "TTSResult", "make_tts_backend", "MockTTSBackend"]
|
||||
84
circuitforge_core/tts/backends/base.py
Normal file
84
circuitforge_core/tts/backends/base.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""
|
||||
TTSBackend Protocol — backend-agnostic TTS interface.
|
||||
|
||||
All backends return TTSResult with audio bytes in the requested format.
|
||||
Supported formats: ogg (default, smallest), wav (uncompressed, always works), mp3.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
# Encodings supported by _encode_audio(); "ogg" is the default everywhere.
AudioFormat = Literal["ogg", "wav", "mp3"]


@dataclass(frozen=True)
class TTSResult:
    """Immutable result of one synthesize() call."""

    audio_bytes: bytes          # encoded audio in `format`
    sample_rate: int            # sample rate in Hz
    duration_s: float           # clip length in seconds
    format: AudioFormat = "ogg"
    model: str = ""             # identifier of the backend/model that produced the clip
||||
|
||||
|
||||
@runtime_checkable
class TTSBackend(Protocol):
    """Structural interface implemented by all TTS backends.

    NOTE: @runtime_checkable isinstance() checks only verify that the member
    names exist — signatures are not validated.
    """

    def synthesize(
        self,
        text: str,
        *,
        exaggeration: float = 0.5,               # expressiveness knob forwarded to the model
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        audio_prompt: bytes | None = None,       # optional reference audio (WAV bytes) — presumably voice prompt; confirm per backend
        format: AudioFormat = "ogg",
    ) -> TTSResult: ...

    @property
    def model_name(self) -> str: ...             # human-readable model identifier

    @property
    def vram_mb(self) -> int: ...                # approximate VRAM footprint for cf-orch budgeting
|
||||
|
||||
|
||||
def _encode_audio(
    wav_tensor,  # torch.Tensor shape [1, T] or [T]
    sample_rate: int,
    format: AudioFormat,
) -> bytes:
    """Convert a torch tensor to audio bytes in the requested format.

    The tensor is coerced to float32 on CPU with a leading channel dim, as
    torchaudio.save() expects [channels, T].
    """
    import torch
    import torchaudio

    wav = wav_tensor
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    wav = wav.to(torch.float32).cpu()

    buf = io.BytesIO()
    if format == "wav":
        torchaudio.save(buf, wav, sample_rate, format="wav")
    elif format == "ogg":
        torchaudio.save(buf, wav, sample_rate, format="ogg", encoding="vorbis")
    elif format == "mp3":
        # torchaudio MP3 encode requires ffmpeg backend; fall back to wav on failure
        try:
            torchaudio.save(buf, wav, sample_rate, format="mp3")
        except Exception:
            # BUGFIX: warn on fallback — callers label the bytes with the
            # *requested* format (see TTSResult/app response headers), so
            # silently returning WAV produced payloads whose container did
            # not match their advertised content type.
            import warnings

            warnings.warn(
                "MP3 encoding unavailable (ffmpeg backend missing?); "
                "falling back to WAV bytes",
                RuntimeWarning,
                stacklevel=2,
            )
            buf = io.BytesIO()
            torchaudio.save(buf, wav, sample_rate, format="wav")
    return buf.getvalue()
|
||||
|
||||
|
||||
def make_tts_backend(
    model_path: str,
    *,
    mock: bool = False,
    device: str = "cuda",
) -> TTSBackend:
    """Build a TTS backend: the mock when requested, chatterbox-turbo otherwise.

    Real backends are imported lazily so the mock path never touches the
    chatterbox extra.
    """
    if not mock:
        from circuitforge_core.tts.backends.chatterbox import ChatterboxTurboBackend

        return ChatterboxTurboBackend(model_path=model_path, device=device)

    from circuitforge_core.tts.backends.mock import MockTTSBackend

    return MockTTSBackend()
|
||||
82
circuitforge_core/tts/backends/chatterbox.py
Normal file
82
circuitforge_core/tts/backends/chatterbox.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""ChatterboxTurboBackend — ResembleAI chatterbox-turbo TTS via chatterbox-tts package."""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from circuitforge_core.tts.backends.base import (
|
||||
AudioFormat,
|
||||
TTSBackend,
|
||||
TTSResult,
|
||||
_encode_audio,
|
||||
)
|
||||
|
||||
_VRAM_MB = 768 # conservative estimate for chatterbox-turbo weights
|
||||
|
||||
|
||||
class ChatterboxTurboBackend:
    """ResembleAI chatterbox-turbo TTS backend.

    Loads the model once at construction via ChatterboxTTS.from_local() and
    reuses it across synthesize() calls.
    """

    def __init__(self, model_path: str, device: str = "cuda") -> None:
        # Default to GPU 0 unless the launcher already pinned a device.
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
        from chatterbox.models.s3gen import S3GEN_SR
        from chatterbox.tts import ChatterboxTTS

        # S3GEN_SR — output sample rate constant exported by the chatterbox package.
        self._sr = S3GEN_SR
        self._device = device
        self._model = ChatterboxTTS.from_local(model_path, device=device)
        self._model_path = model_path

    @property
    def model_name(self) -> str:
        # e.g. "chatterbox-turbo@<snapshot-dir-name>" for the local checkpoint.
        return f"chatterbox-turbo@{os.path.basename(self._model_path)}"

    @property
    def vram_mb(self) -> int:
        # Static estimate (module constant); not measured at runtime.
        return _VRAM_MB

    def synthesize(
        self,
        text: str,
        *,
        exaggeration: float = 0.5,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        audio_prompt: bytes | None = None,
        format: AudioFormat = "ogg",
    ) -> TTSResult:
        """Generate speech for *text* and return it encoded as *format*.

        When *audio_prompt* is given, the bytes are written to a temporary
        WAV file whose path is handed to the model — the chatterbox generate
        API takes a file path (audio_prompt_path), not raw bytes.
        """
        audio_prompt_path: str | None = None
        _tmp = None

        if audio_prompt is not None:
            # delete=False so the file survives until the model has read it;
            # it is removed manually in the finally block below.
            # NOTE(review): on Windows an open NamedTemporaryFile cannot be
            # reopened by another reader — confirm target platforms.
            _tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            _tmp.write(audio_prompt)
            _tmp.flush()
            audio_prompt_path = _tmp.name

        try:
            wav = self._model.generate(
                text,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
                temperature=temperature,
                audio_prompt_path=audio_prompt_path,
            )
        finally:
            # Clean up the prompt file even when generation raises.
            if _tmp is not None:
                _tmp.close()
                os.unlink(_tmp.name)

        # wav's last dimension is samples — assumes a [channels?, T] tensor
        # at the fixed s3gen sample rate; duration follows directly.
        duration_s = wav.shape[-1] / self._sr
        audio_bytes = _encode_audio(wav, self._sr, format)
        return TTSResult(
            audio_bytes=audio_bytes,
            sample_rate=self._sr,
            duration_s=duration_s,
            format=format,
            model=self.model_name,
        )
|
||||
|
||||
|
||||
assert isinstance(
|
||||
ChatterboxTurboBackend.__new__(ChatterboxTurboBackend), TTSBackend
|
||||
), "ChatterboxTurboBackend must satisfy TTSBackend Protocol"
|
||||
56
circuitforge_core/tts/backends/mock.py
Normal file
56
circuitforge_core/tts/backends/mock.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""MockTTSBackend — no GPU, no model required. Returns a silent WAV clip."""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from circuitforge_core.tts.backends.base import AudioFormat, TTSBackend, TTSResult
|
||||
|
||||
_SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
def _silent_wav(duration_s: float = 0.5, sample_rate: int = _SAMPLE_RATE) -> bytes:
|
||||
num_samples = int(duration_s * sample_rate)
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as w:
|
||||
w.setnchannels(1)
|
||||
w.setsampwidth(2)
|
||||
w.setframerate(sample_rate)
|
||||
w.writeframes(struct.pack(f"<{num_samples}h", *([0] * num_samples)))
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
class MockTTSBackend:
    """Stand-in TTSBackend for tests and CI — emits silence, needs no GPU."""

    @property
    def model_name(self) -> str:
        return "mock-tts"

    @property
    def vram_mb(self) -> int:
        return 0

    def synthesize(
        self,
        text: str,
        *,
        exaggeration: float = 0.5,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        audio_prompt: bytes | None = None,
        format: AudioFormat = "ogg",
    ) -> TTSResult:
        """Return a silent clip whose length scales with the word count.

        Generation knobs and *audio_prompt* are accepted for interface
        compatibility but ignored. The result is always WAV regardless of the
        requested *format*, and the TTSResult.format field says so.
        """
        word_count = len(text.split())
        duration_s = max(0.1, 0.3 * word_count)
        return TTSResult(
            audio_bytes=_silent_wav(duration_s),
            sample_rate=_SAMPLE_RATE,
            duration_s=duration_s,
            format="wav",
            model=self.model_name,
        )


assert isinstance(MockTTSBackend(), TTSBackend), "MockTTSBackend must satisfy TTSBackend Protocol"
|
||||
|
|
@ -1,3 +1,108 @@
|
|||
from .router import VisionRouter
|
||||
"""
|
||||
circuitforge_core.vision — Managed vision service module.
|
||||
|
||||
__all__ = ["VisionRouter"]
|
||||
Quick start (mock mode — no GPU or model required):
|
||||
|
||||
import os; os.environ["CF_VISION_MOCK"] = "1"
|
||||
from circuitforge_core.vision import classify, embed
|
||||
|
||||
result = classify(image_bytes, labels=["cat", "dog", "bird"])
|
||||
print(result.top(1)) # [("cat", 0.82)]
|
||||
|
||||
emb = embed(image_bytes)
|
||||
print(len(emb.embedding)) # 1152 (so400m hidden dim)
|
||||
|
||||
Real inference (SigLIP — default, ~1.4 GB VRAM):
|
||||
|
||||
export CF_VISION_MODEL=google/siglip-so400m-patch14-384
|
||||
from circuitforge_core.vision import classify
|
||||
|
||||
Full VLM inference (caption + VQA):
|
||||
|
||||
export CF_VISION_BACKEND=vlm
|
||||
export CF_VISION_MODEL=vikhyatk/moondream2
|
||||
from circuitforge_core.vision import caption
|
||||
|
||||
Per-request backend (bypasses process singleton):
|
||||
|
||||
from circuitforge_core.vision import make_backend
|
||||
vlm = make_backend("vikhyatk/moondream2", backend="vlm")
|
||||
result = vlm.caption(image_bytes, prompt="What text appears in this image?")
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-vision
|
||||
max_mb: 1536 (siglip-so400m); 2200 (moondream2); 14500 (llava-7b)
|
||||
max_concurrent: 4 (siglip); 1 (vlm)
|
||||
shared: true
|
||||
managed:
|
||||
exec: python -m circuitforge_core.vision.app
|
||||
args: --model <path> --backend siglip --port {port} --gpu-id {gpu_id}
|
||||
port: 8006
|
||||
health: /health
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.vision.backends.base import (
|
||||
VisionBackend,
|
||||
VisionResult,
|
||||
make_vision_backend,
|
||||
)
|
||||
from circuitforge_core.vision.backends.mock import MockVisionBackend
|
||||
|
||||
_backend: VisionBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> VisionBackend:
    """Return the lazily-constructed process-wide VisionBackend singleton.

    Model selection comes from CF_VISION_MODEL (default "mock"); the mock
    backend is used when that value is "mock" or CF_VISION_MOCK=1.
    """
    global _backend
    if _backend is not None:
        return _backend
    model_path = os.environ.get("CF_VISION_MODEL", "mock")
    use_mock = model_path == "mock" or os.environ.get("CF_VISION_MOCK", "") == "1"
    _backend = make_vision_backend(model_path, mock=use_mock)
    return _backend
|
||||
|
||||
|
||||
def classify(image: bytes, labels: list[str]) -> VisionResult:
    """Zero-shot classification of *image* against *labels* via the shared backend."""
    backend = _get_backend()
    return backend.classify(image, labels)
|
||||
|
||||
|
||||
def embed(image: bytes) -> VisionResult:
    """Compute an image embedding via the shared backend (SigLIP backends only)."""
    backend = _get_backend()
    return backend.embed(image)
|
||||
|
||||
|
||||
def caption(image: bytes, prompt: str = "") -> VisionResult:
    """Caption / answer *prompt* about *image* via the shared backend (VLM backends only)."""
    backend = _get_backend()
    return backend.caption(image, prompt)
|
||||
|
||||
|
||||
def make_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
    device: str = "cuda",
    dtype: str = "float16",
) -> VisionBackend:
    """Construct an independent VisionBackend, leaving the process singleton untouched.

    Handy when a single process needs both SigLIP (routing) and a VLM
    (captioning), or when comparing models side-by-side in tests.
    """
    return make_vision_backend(
        model_path,
        backend=backend,
        mock=mock,
        device=device,
        dtype=dtype,
    )
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VisionBackend",
|
||||
"VisionResult",
|
||||
"MockVisionBackend",
|
||||
"classify",
|
||||
"embed",
|
||||
"caption",
|
||||
"make_backend",
|
||||
]
|
||||
|
|
|
|||
245
circuitforge_core/vision/app.py
Normal file
245
circuitforge_core/vision/app.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
circuitforge_core.vision.app — cf-vision FastAPI service.
|
||||
|
||||
Managed by cf-orch as a process-type service. cf-orch starts this via:
|
||||
|
||||
python -m circuitforge_core.vision.app \
|
||||
--model google/siglip-so400m-patch14-384 \
|
||||
--backend siglip \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
For VLM inference (caption/VQA):
|
||||
|
||||
python -m circuitforge_core.vision.app \
|
||||
--model vikhyatk/moondream2 \
|
||||
--backend vlm \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": "...", "vram_mb": n,
|
||||
"supports_embed": bool, "supports_caption": bool}
|
||||
POST /classify → VisionClassifyResponse (multipart: image + labels)
|
||||
POST /embed → VisionEmbedResponse (multipart: image)
|
||||
POST /caption → VisionCaptionResponse (multipart: image + prompt)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.vision.backends.base import make_vision_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────────────
|
||||
|
||||
class VisionClassifyResponse(BaseModel):
    """Response body for POST /classify."""
    labels: list[str]  # candidate labels, same order as submitted
    scores: list[float]  # per-label scores aligned with `labels`
    model: str  # identifier of the model that produced the result
|
||||
|
||||
|
||||
class VisionEmbedResponse(BaseModel):
    """Response body for POST /embed."""
    embedding: list[float]  # image embedding; length == model hidden dim
    model: str  # identifier of the model that produced the embedding
|
||||
|
||||
|
||||
class VisionCaptionResponse(BaseModel):
    """Response body for POST /caption."""
    caption: str  # generated description / answer text
    model: str  # identifier of the model that produced the caption
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response body for GET /health — liveness plus backend capability flags."""
    status: str  # "ok" when the service is able to respond
    model: str  # identifier of the loaded model
    vram_mb: int  # approximate VRAM footprint reported by the backend
    backend: str  # backend kind the service was started with ("siglip" or "vlm")
    supports_embed: bool  # True when POST /embed is available
    supports_caption: bool  # True when POST /caption is available
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
def create_app(
    model_path: str,
    backend: str = "siglip",
    device: str = "cuda",
    dtype: str = "float16",
    mock: bool = False,
) -> FastAPI:
    """Build the cf-vision FastAPI app around a single backend instance.

    The backend is constructed eagerly here, so the model loads at startup
    rather than on the first request. Route handlers close over `_backend`
    (the loaded backend) and over *backend* (the kind string, used in
    /health and in 501 error messages).

    Routes:
        GET  /health   — liveness + capability flags
        POST /classify — multipart image + labels form field
        POST /embed    — multipart image (501 if backend lacks embed)
        POST /caption  — multipart image + optional prompt (501 if unsupported)
    """
    app = FastAPI(title="cf-vision", version="0.1.0")
    _backend = make_vision_backend(
        model_path, backend=backend, device=device, dtype=dtype, mock=mock
    )
    logger.info(
        "cf-vision ready: model=%r backend=%r vram=%dMB",
        _backend.model_name, backend, _backend.vram_mb,
    )

    @app.get("/health", response_model=HealthResponse)
    async def health() -> HealthResponse:
        # Capability flags let clients discover which POST routes will work.
        return HealthResponse(
            status="ok",
            model=_backend.model_name,
            vram_mb=_backend.vram_mb,
            backend=backend,
            supports_embed=_backend.supports_embed,
            supports_caption=_backend.supports_caption,
        )

    @app.post("/classify", response_model=VisionClassifyResponse)
    async def classify(
        image: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, ...)"),
        labels: str = Form(
            ...,
            description=(
                "Candidate labels — either a JSON array "
                '(["cat","dog"]) or comma-separated (cat,dog)'
            ),
        ),
    ) -> VisionClassifyResponse:
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        parsed_labels = _parse_labels(labels)
        if not parsed_labels:
            raise HTTPException(status_code=400, detail="At least one label is required")

        # Backend exceptions become 500s with the message surfaced to the caller.
        try:
            result = _backend.classify(image_bytes, parsed_labels)
        except Exception as exc:
            logger.exception("classify failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return VisionClassifyResponse(
            labels=result.labels, scores=result.scores, model=result.model
        )

    @app.post("/embed", response_model=VisionEmbedResponse)
    async def embed_image(
        image: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, ...)"),
    ) -> VisionEmbedResponse:
        # 501 (not implemented) rather than 500: the route exists, this
        # backend kind just doesn't provide the capability.
        if not _backend.supports_embed:
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Backend '{backend}' does not support embedding. "
                    "Use backend=siglip for embed()."
                ),
            )

        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        try:
            result = _backend.embed(image_bytes)
        except Exception as exc:
            logger.exception("embed failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        # result.embedding may be None per VisionResult's contract; coerce to [].
        return VisionEmbedResponse(embedding=result.embedding or [], model=result.model)

    @app.post("/caption", response_model=VisionCaptionResponse)
    async def caption_image(
        image: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, ...)"),
        prompt: str = Form(
            "",
            description="Optional instruction / question for the VLM",
        ),
    ) -> VisionCaptionResponse:
        if not _backend.supports_caption:
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Backend '{backend}' does not support caption generation. "
                    "Use backend=vlm for caption()."
                ),
            )

        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        try:
            result = _backend.caption(image_bytes, prompt=prompt)
        except Exception as exc:
            logger.exception("caption failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return VisionCaptionResponse(caption=result.caption or "", model=result.model)

    return app
|
||||
|
||||
|
||||
# ── Label parsing ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _parse_labels(raw: str) -> list[str]:
|
||||
"""Accept JSON array or comma-separated label string."""
|
||||
stripped = raw.strip()
|
||||
if stripped.startswith("["):
|
||||
try:
|
||||
parsed = json.loads(stripped)
|
||||
if isinstance(parsed, list):
|
||||
return [str(x) for x in parsed]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return [lbl.strip() for lbl in stripped.split(",") if lbl.strip()]
|
||||
|
||||
|
||||
# ── CLI entry point ───────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """Parse CLI arguments, configure logging, and launch the uvicorn server."""
    parser = argparse.ArgumentParser(description="cf-vision — CircuitForge vision service")
    parser.add_argument(
        "--model",
        default="google/siglip-so400m-patch14-384",
        help="HuggingFace model ID or local path",
    )
    parser.add_argument(
        "--backend",
        default="siglip",
        choices=["siglip", "vlm"],
        help="Vision backend: siglip (classify+embed) or vlm (caption+classify)",
    )
    parser.add_argument("--port", type=int, default=8006)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "bfloat16", "float32"],
    )
    parser.add_argument(
        "--mock",
        action="store_true",
        help="Run with mock backend (no GPU, for testing)",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Pin the visible GPU before any CUDA library initialises; an existing
    # CUDA_VISIBLE_DEVICES setting takes precedence (setdefault).
    if args.device == "cuda" and not args.mock:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))

    use_mock = args.mock or os.environ.get("CF_VISION_MOCK", "") == "1"
    app = create_app(
        model_path=args.model,
        backend=args.backend,
        device=args.device,
        dtype=args.dtype,
        mock=use_mock,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
4
circuitforge_core/vision/backends/__init__.py
Normal file
4
circuitforge_core/vision/backends/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from circuitforge_core.vision.backends.base import VisionBackend, VisionResult, make_vision_backend
|
||||
from circuitforge_core.vision.backends.mock import MockVisionBackend
|
||||
|
||||
__all__ = ["VisionBackend", "VisionResult", "make_vision_backend", "MockVisionBackend"]
|
||||
150
circuitforge_core/vision/backends/base.py
Normal file
150
circuitforge_core/vision/backends/base.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
# circuitforge_core/vision/backends/base.py — VisionBackend Protocol + factory
|
||||
#
|
||||
# MIT licensed. The Protocol and mock are always importable without GPU deps.
|
||||
# Real backends require optional extras:
|
||||
# pip install -e "circuitforge-core[vision-siglip]" # SigLIP (default, ~1.4 GB VRAM)
|
||||
# pip install -e "circuitforge-core[vision-vlm]" # Full VLM (e.g. moondream, LLaVA)
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ── Result type ───────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VisionResult:
|
||||
"""
|
||||
Standard result from any VisionBackend call.
|
||||
|
||||
classify() → labels + scores populated; embedding/caption may be None.
|
||||
embed() → embedding populated; labels/scores empty.
|
||||
caption() → caption populated; labels/scores empty; embedding None.
|
||||
"""
|
||||
labels: list[str] = field(default_factory=list)
|
||||
scores: list[float] = field(default_factory=list)
|
||||
embedding: list[float] | None = None
|
||||
caption: str | None = None
|
||||
model: str = ""
|
||||
|
||||
def top(self, n: int = 1) -> list[tuple[str, float]]:
|
||||
"""Return the top-n (label, score) pairs sorted by descending score."""
|
||||
paired = sorted(zip(self.labels, self.scores), key=lambda x: x[1], reverse=True)
|
||||
return paired[:n]
|
||||
|
||||
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@runtime_checkable
class VisionBackend(Protocol):
    """
    Abstract interface for vision backends.

    All backends load their model once at construction time.

    SigLIP backends implement classify() and embed() but raise NotImplementedError
    for caption(). VLM backends implement caption() and a prompt-based classify()
    but raise NotImplementedError for embed().
    """

    # NOTE: @runtime_checkable isinstance() checks only verify that the member
    # names exist on the instance — signatures and return types are not checked.

    def classify(self, image: bytes, labels: list[str]) -> VisionResult:
        """
        Zero-shot image classification.

        labels: candidate text descriptions; scores are returned in the same order.
        SigLIP uses sigmoid similarity; VLM prompts for each label.
        """
        ...

    def embed(self, image: bytes) -> VisionResult:
        """
        Return an image embedding vector.

        Available on SigLIP backends. Raises NotImplementedError on VLM backends.
        embedding is a list of floats with length == model hidden dim.
        """
        ...

    def caption(self, image: bytes, prompt: str = "") -> VisionResult:
        """
        Generate a text description of the image.

        Available on VLM backends. Raises NotImplementedError on SigLIP backends.
        prompt is an optional instruction; defaults to a generic description request.
        """
        ...

    @property
    def model_name(self) -> str:
        """Identifier for the loaded model (HuggingFace ID or path stem)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB. Used by cf-orch service registry."""
        ...

    @property
    def supports_embed(self) -> bool:
        """True if embed() is implemented (SigLIP backends)."""
        ...

    @property
    def supports_caption(self) -> bool:
        """True if caption() is implemented (VLM backends)."""
        ...
|
||||
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def make_vision_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
    device: str = "cuda",
    dtype: str = "float16",
) -> VisionBackend:
    """
    Build a VisionBackend for *model_path*.

    mock=True or CF_VISION_MOCK=1 selects MockVisionBackend (no GPU or model
    files needed). backend="siglip" gives classify+embed; backend="vlm" gives
    caption plus a prompt-based classify. With no explicit backend (argument
    or CF_VISION_BACKEND), the model name is inspected: VLM-looking names map
    to "vlm", everything else to "siglip". device and dtype are forwarded to
    the real backends and ignored by the mock.

    Raises ValueError for an unrecognised backend name.
    """
    if mock is None:
        use_mock = os.environ.get("CF_VISION_MOCK", "") == "1"
    else:
        use_mock = mock
    if use_mock:
        from circuitforge_core.vision.backends.mock import MockVisionBackend
        return MockVisionBackend(model_name=model_path)

    chosen = backend or os.environ.get("CF_VISION_BACKEND", "")
    if not chosen:
        # No explicit choice anywhere — fall back to the model-name heuristic.
        chosen = "vlm" if _looks_like_vlm(model_path) else "siglip"

    if chosen == "siglip":
        from circuitforge_core.vision.backends.siglip import SigLIPBackend
        return SigLIPBackend(model_path=model_path, device=device, dtype=dtype)

    if chosen == "vlm":
        from circuitforge_core.vision.backends.vlm import VLMBackend
        return VLMBackend(model_path=model_path, device=device, dtype=dtype)

    raise ValueError(
        f"Unknown vision backend {chosen!r}. "
        "Expected 'siglip' or 'vlm'. Set CF_VISION_BACKEND or pass backend= explicitly."
    )
|
||||
|
||||
|
||||
def _looks_like_vlm(model_path: str) -> bool:
|
||||
"""Heuristic: names associated with generative VLMs."""
|
||||
_vlm_hints = ("llava", "moondream", "qwen-vl", "qwenvl", "idefics",
|
||||
"cogvlm", "internvl", "phi-3-vision", "phi3vision",
|
||||
"dolphin", "paligemma")
|
||||
lower = model_path.lower()
|
||||
return any(h in lower for h in _vlm_hints)
|
||||
62
circuitforge_core/vision/backends/mock.py
Normal file
62
circuitforge_core/vision/backends/mock.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
# circuitforge_core/vision/backends/mock.py — MockVisionBackend
|
||||
#
|
||||
# Deterministic stub for tests and CI. No GPU, no model files required.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from circuitforge_core.vision.backends.base import VisionBackend, VisionResult
|
||||
|
||||
|
||||
class MockVisionBackend:
    """
    Deterministic VisionBackend stub for tests — no GPU or model files needed.

    classify() assigns every label the uniform score 1/n.
    embed() yields a 512-dim unit vector (every entry 1/sqrt(512)).
    caption() yields a fixed canned sentence.
    """

    def __init__(self, model_name: str = "mock") -> None:
        self._model_name = model_name

    # ── VisionBackend Protocol ─────────────────────────────────────────────────

    def classify(self, image: bytes, labels: list[str]) -> VisionResult:
        """Uniform scores over the given labels; empty labels → empty result."""
        uniform = 1.0 / max(len(labels), 1)
        return VisionResult(
            labels=list(labels),
            scores=[uniform for _ in labels],
            model=self._model_name,
        )

    def embed(self, image: bytes) -> VisionResult:
        """L2-normalised constant vector of dimension 512."""
        dim = 512
        component = 1.0 / math.sqrt(dim)
        return VisionResult(embedding=[component] * dim, model=self._model_name)

    def caption(self, image: bytes, prompt: str = "") -> VisionResult:
        """Canned caption; the prompt is accepted but ignored."""
        return VisionResult(
            caption="A mock image description for testing purposes.",
            model=self._model_name,
        )

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        return 0

    @property
    def supports_embed(self) -> bool:
        return True

    @property
    def supports_caption(self) -> bool:
        return True


# Verify protocol compliance at import time (catches missing methods early).
assert isinstance(MockVisionBackend(), VisionBackend)
|
||||
151
circuitforge_core/vision/backends/siglip.py
Normal file
151
circuitforge_core/vision/backends/siglip.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
# circuitforge_core/vision/backends/siglip.py — SigLIPBackend
|
||||
#
|
||||
# Requires: pip install -e "circuitforge-core[vision-siglip]"
|
||||
# Default model: google/siglip-so400m-patch14-384 (~1.4 GB VRAM)
|
||||
#
|
||||
# SigLIP uses sigmoid cross-entropy rather than softmax over labels, so each
|
||||
# score is an independent 0–1 probability. This is better than CLIP for
|
||||
# multi-label classification and document routing.
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
from circuitforge_core.vision.backends.base import VisionResult
|
||||
|
||||
_DEFAULT_MODEL = "google/siglip-so400m-patch14-384"
|
||||
|
||||
# VRAM footprints by model variant (MB, fp16).
|
||||
_VRAM_TABLE: dict[str, int] = {
|
||||
"siglip-so400m-patch14-384": 1440,
|
||||
"siglip-so400m-patch14-224": 1440,
|
||||
"siglip-base-patch16-224": 340,
|
||||
"siglip-large-patch16-256": 690,
|
||||
}
|
||||
|
||||
|
||||
def _estimate_vram(model_path: str) -> int:
    """Look up an approximate fp16 VRAM footprint (MB) for *model_path*.

    Matches known SigLIP variant name fragments; unknown models get a
    conservative so400m-class estimate.
    """
    name = model_path.lower()
    for fragment, footprint_mb in _VRAM_TABLE.items():
        if fragment in name:
            return footprint_mb
    return 1500  # conservative default for unknown so400m variants
|
||||
|
||||
|
||||
class SigLIPBackend:
    """
    Image classification + embedding via Google SigLIP.

    classify() returns sigmoid similarity scores for each candidate label —
    independent probabilities, not a softmax distribution.
    embed() returns the CLS-pool image embedding (normalised).
    caption() raises NotImplementedError — use VLMBackend for generation.
    """

    def __init__(
        self,
        model_path: str = _DEFAULT_MODEL,
        device: str = "cuda",
        dtype: str = "float16",
    ) -> None:
        # Import lazily so this module stays importable without GPU extras;
        # fail with an actionable install hint otherwise.
        try:
            import torch
            from transformers import SiglipProcessor, SiglipModel
        except ImportError as exc:
            raise ImportError(
                "SigLIPBackend requires torch and transformers. "
                "Install with: pip install -e 'circuitforge-core[vision-siglip]'"
            ) from exc

        import torch as _torch

        self._device = device
        self._dtype_str = dtype
        # Map the dtype string onto the torch dtype used for the weights.
        self._torch_dtype = (
            _torch.float16 if dtype == "float16"
            else _torch.bfloat16 if dtype == "bfloat16"
            else _torch.float32
        )
        self._model_path = model_path
        self._vram_mb = _estimate_vram(model_path)

        # transformers 5.2.0 broke SiglipProcessor.from_pretrained() via the
        # auto-detection path (TOKENIZER_MAPPING_NAMES.get() returns None for
        # 'siglip', causing AttributeError on .replace()). Load components
        # directly and compose manually to bypass that code path.
        try:
            from transformers import SiglipTokenizer, SiglipImageProcessor
            _tokenizer = SiglipTokenizer.from_pretrained(model_path)
            _image_proc = SiglipImageProcessor.from_pretrained(model_path)
            self._processor = SiglipProcessor(
                image_processor=_image_proc, tokenizer=_tokenizer
            )
        except Exception:
            # Fallback: try the standard path (may work on older transformers builds)
            self._processor = SiglipProcessor.from_pretrained(model_path)
        self._model = SiglipModel.from_pretrained(
            model_path,
            torch_dtype=self._torch_dtype,
        ).to(device)
        # train(False) is equivalent to model.eval(); gradients are disabled
        # separately per-call with torch.no_grad() in classify()/embed().
        self._model.train(False)

    # ── VisionBackend Protocol ─────────────────────────────────────────────────

    def classify(self, image: bytes, labels: list[str]) -> VisionResult:
        """Zero-shot sigmoid classification — scores are independent per label."""
        import torch
        from PIL import Image

        pil_img = Image.open(io.BytesIO(image)).convert("RGB")
        # padding="max_length" matches SigLIP's training-time text padding.
        inputs = self._processor(
            text=labels,
            images=pil_img,
            return_tensors="pt",
            padding="max_length",
        ).to(self._device)

        with torch.no_grad():
            outputs = self._model(**inputs)
            # logits_per_image: (1, num_labels) — raw SigLIP logits
            logits = outputs.logits_per_image[0]
            scores = torch.sigmoid(logits).cpu().float().tolist()

        return VisionResult(labels=list(labels), scores=scores, model=self.model_name)

    def embed(self, image: bytes) -> VisionResult:
        """Return normalised image embedding (CLS pool, L2-normalised)."""
        import torch
        from PIL import Image

        pil_img = Image.open(io.BytesIO(image)).convert("RGB")
        inputs = self._processor(images=pil_img, return_tensors="pt").to(self._device)

        with torch.no_grad():
            image_features = self._model.get_image_features(**inputs)
            # L2-normalise so dot-product == cosine similarity
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        embedding = image_features[0].cpu().float().tolist()
        return VisionResult(embedding=embedding, model=self.model_name)

    def caption(self, image: bytes, prompt: str = "") -> VisionResult:
        # SigLIP is an encoder — it has no text-generation head.
        raise NotImplementedError(
            "SigLIPBackend does not support caption generation. "
            "Use backend='vlm' (VLMBackend) for image-to-text generation."
        )

    @property
    def model_name(self) -> str:
        # Path stem only — e.g. "siglip-so400m-patch14-384".
        return self._model_path.split("/")[-1]

    @property
    def vram_mb(self) -> int:
        return self._vram_mb

    @property
    def supports_embed(self) -> bool:
        return True

    @property
    def supports_caption(self) -> bool:
        return False
|
||||
181
circuitforge_core/vision/backends/vlm.py
Normal file
181
circuitforge_core/vision/backends/vlm.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
# circuitforge_core/vision/backends/vlm.py — VLMBackend
|
||||
#
|
||||
# Requires: pip install -e "circuitforge-core[vision-vlm]"
|
||||
#
|
||||
# Supports any HuggingFace AutoModelForVision2Seq-compatible VLM.
|
||||
# Validated models (VRAM fp16):
|
||||
# vikhyatk/moondream2 ~2 GB — fast, lightweight, good for documents
|
||||
# llava-hf/llava-1.5-7b-hf ~14 GB — strong general VQA
|
||||
# Qwen/Qwen2-VL-7B-Instruct ~16 GB — multilingual, structured output friendly
|
||||
#
|
||||
# VLMBackend implements caption() (generative) and a prompt-based classify()
|
||||
# that asks the model to pick from a list. embed() raises NotImplementedError.
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
from circuitforge_core.vision.backends.base import VisionResult
|
||||
|
||||
# VRAM estimates (MB, fp16) keyed by lowercase model name fragment.
|
||||
_VRAM_TABLE: dict[str, int] = {
|
||||
"moondream2": 2000,
|
||||
"moondream": 2000,
|
||||
"llava-1.5-7b": 14000,
|
||||
"llava-7b": 14000,
|
||||
"qwen2-vl-7b": 16000,
|
||||
"qwen-vl-7b": 16000,
|
||||
"llava-1.5-13b": 26000,
|
||||
"phi-3-vision": 8000,
|
||||
"phi3-vision": 8000,
|
||||
"paligemma": 6000,
|
||||
"idefics": 12000,
|
||||
"cogvlm": 14000,
|
||||
}
|
||||
|
||||
_CLASSIFY_PROMPT_TMPL = (
|
||||
"Choose the single best label for this image from the following options: "
|
||||
"{labels}. Reply with ONLY the label text, nothing else."
|
||||
)
|
||||
|
||||
|
||||
def _estimate_vram(model_path: str) -> int:
    """Look up an approximate fp16 VRAM footprint (MB) for *model_path*.

    Matches known VLM family name fragments; unknown models get a default
    sized for a typical 7B-class VLM.
    """
    name = model_path.lower()
    for fragment, footprint_mb in _VRAM_TABLE.items():
        if fragment in name:
            return footprint_mb
    return 8000  # safe default for unknown 7B-class VLMs
|
||||
|
||||
|
||||
class VLMBackend:
    """
    Generative vision-language model backend.

    caption() generates free-form text from an image + optional prompt.
    classify() prompts the model to select from candidate labels.
    embed() raises NotImplementedError — use SigLIPBackend for embeddings.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16",
        max_new_tokens: int = 512,
    ) -> None:
        """Load processor + model from *model_path* and move the model to *device*.

        Args:
            model_path: Hugging Face repo id or local path of a Vision2Seq model.
            device: torch device string ("cuda", "cpu", ...).
            dtype: "float16", "bfloat16"; anything else falls back to float32.
            max_new_tokens: generation cap used by caption()/classify().

        Raises:
            ImportError: if torch/transformers are not installed
                (install the ``vision-vlm`` extra).
        """
        try:
            import torch
            from transformers import AutoProcessor, AutoModelForVision2Seq
        except ImportError as exc:
            raise ImportError(
                "VLMBackend requires torch and transformers. "
                "Install with: pip install -e 'circuitforge-core[vision-vlm]'"
            ) from exc

        self._device = device
        self._max_new_tokens = max_new_tokens
        self._model_path = model_path
        self._vram_mb = _estimate_vram(model_path)

        # Fix: reuse the `torch` imported above — the original performed a
        # redundant second `import torch as _torch` here.
        torch_dtype = (
            torch.float16 if dtype == "float16"
            else torch.bfloat16 if dtype == "bfloat16"
            else torch.float32
        )

        # trust_remote_code: some hub models (e.g. moondream2) ship custom
        # processor/model code that must be allowed to load.
        self._processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self._model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).to(device)
        # Idiomatic inference mode (equivalent to train(False)) — disables
        # dropout/batchnorm training behaviour.
        self._model.eval()

    # ── VisionBackend Protocol ─────────────────────────────────────────────────

    def caption(self, image: bytes, prompt: str = "") -> VisionResult:
        """Generate a text description of the image.

        Args:
            image: raw encoded image bytes (any format PIL can open).
            prompt: optional instruction; defaults to a generic describe prompt.

        Returns:
            VisionResult with ``caption`` set to the generated text.
        """
        import torch
        from PIL import Image

        pil_img = Image.open(io.BytesIO(image)).convert("RGB")
        effective_prompt = prompt or "Describe this image in detail."

        inputs = self._processor(
            text=effective_prompt,
            images=pil_img,
            return_tensors="pt",
        ).to(self._device)

        # Greedy decoding (do_sample=False) keeps the output deterministic.
        with torch.no_grad():
            generated_ids = self._model.generate(
                **inputs,
                max_new_tokens=self._max_new_tokens,
                do_sample=False,
            )

        # Strip the input prompt tokens from the generated output
        input_len = inputs["input_ids"].shape[1]
        output_ids = generated_ids[0][input_len:]
        text = self._processor.decode(output_ids, skip_special_tokens=True).strip()

        return VisionResult(caption=text, model=self.model_name)

    def classify(self, image: bytes, labels: list[str]) -> VisionResult:
        """
        Prompt-based zero-shot classification.

        Asks the VLM to choose a label from the provided list. The returned
        scores are binary (1.0 for the selected label, 0.0 for others) since
        VLMs don't expose per-label logits the same way SigLIP does.
        For soft scores, use SigLIPBackend.
        """
        labels_str = ", ".join(f'"{lbl}"' for lbl in labels)
        prompt = _CLASSIFY_PROMPT_TMPL.format(labels=labels_str)
        result = self.caption(image, prompt=prompt)
        # Models often echo the label wrapped in quotes — strip both kinds.
        raw = (result.caption or "").strip().strip('"').strip("'")

        matched = _match_label(raw, labels)
        scores = [1.0 if lbl == matched else 0.0 for lbl in labels]
        return VisionResult(labels=list(labels), scores=scores, model=self.model_name)

    def embed(self, image: bytes) -> VisionResult:
        """Not supported — this backend is generative only."""
        raise NotImplementedError(
            "VLMBackend does not support image embeddings. "
            "Use backend='siglip' (SigLIPBackend) for embed()."
        )

    @property
    def model_name(self) -> str:
        # Last path component, e.g. "vikhyatk/moondream2" -> "moondream2".
        return self._model_path.split("/")[-1]

    @property
    def vram_mb(self) -> int:
        # Estimate from _VRAM_TABLE (8000 MB default for unknown models).
        return self._vram_mb

    @property
    def supports_embed(self) -> bool:
        return False

    @property
    def supports_caption(self) -> bool:
        return True
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _match_label(raw: str, labels: list[str]) -> str:
|
||||
"""Return the best matching label from the VLM's free-form response."""
|
||||
raw_lower = raw.lower()
|
||||
for lbl in labels:
|
||||
if lbl.lower() == raw_lower:
|
||||
return lbl
|
||||
for lbl in labels:
|
||||
if raw_lower.startswith(lbl.lower()) or lbl.lower().startswith(raw_lower):
|
||||
return lbl
|
||||
for lbl in labels:
|
||||
if lbl.lower() in raw_lower or raw_lower in lbl.lower():
|
||||
return lbl
|
||||
return labels[0] if labels else raw
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
"""
Vision model router — stub until v0.2.
Supports: moondream2 (local) and Claude vision API (cloud).
"""
from __future__ import annotations


class VisionRouter:
    """Routes image analysis requests to local or cloud vision models."""

    def analyze(self, image_bytes: bytes, prompt: str) -> str:
        """
        Analyze image_bytes with the given prompt.
        Raises NotImplementedError until vision backends are wired up.
        """
        message = (
            "VisionRouter is not yet implemented. "
            "Photo analysis requires a Paid tier or local vision model (v0.2+)."
        )
        raise NotImplementedError(message)
|
||||
1
docs/plausible.js
Normal file
1
docs/plausible.js
Normal file
|
|
@ -0,0 +1 @@
|
|||
// Inject the self-hosted Plausible analytics loader as a deferred <script>.
(function () {
  var tag = document.createElement("script");
  tag.defer = true;
  tag.dataset.domain = "docs.circuitforge.tech,circuitforge.tech";
  tag.dataset.api = "https://analytics.circuitforge.tech/api/event";
  tag.src = "https://analytics.circuitforge.tech/js/script.js";
  document.head.appendChild(tag);
})();
|
||||
82
mkdocs.yml
Normal file
82
mkdocs.yml
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
site_name: circuitforge-core
|
||||
site_description: Shared scaffold for CircuitForge products — modules, conventions, and developer reference.
|
||||
site_author: Circuit Forge LLC
|
||||
site_url: https://docs.circuitforge.tech/cf-core
|
||||
repo_url: https://git.opensourcesolarpunk.com/Circuit-Forge/circuitforge-core
|
||||
repo_name: Circuit-Forge/circuitforge-core
|
||||
|
||||
theme:
|
||||
name: material
|
||||
palette:
|
||||
- scheme: default
|
||||
primary: deep purple
|
||||
accent: purple
|
||||
toggle:
|
||||
icon: material/brightness-7
|
||||
name: Switch to dark mode
|
||||
- scheme: slate
|
||||
primary: deep purple
|
||||
accent: purple
|
||||
toggle:
|
||||
icon: material/brightness-4
|
||||
name: Switch to light mode
|
||||
features:
|
||||
- navigation.top
|
||||
- navigation.sections
|
||||
- search.suggest
|
||||
- search.highlight
|
||||
- content.code.copy
|
||||
- content.code.annotate
|
||||
|
||||
markdown_extensions:
|
||||
- admonition
|
||||
- attr_list
|
||||
- md_in_html
|
||||
- pymdownx.details
|
||||
- pymdownx.superfences:
|
||||
custom_fences:
|
||||
- name: mermaid
|
||||
class: mermaid
|
||||
format: !!python/name:pymdownx.superfences.fence_code_format
|
||||
- pymdownx.emoji:
|
||||
emoji_index: !!python/name:material.extensions.emoji.twemoji
|
||||
emoji_generator: !!python/name:material.extensions.emoji.to_svg
|
||||
- pymdownx.highlight:
|
||||
anchor_linenums: true
|
||||
- pymdownx.inlinehilite
|
||||
- pymdownx.tabbed:
|
||||
alternate_style: true
|
||||
- toc:
|
||||
permalink: true
|
||||
|
||||
nav:
|
||||
- Home: index.md
|
||||
- Getting Started:
|
||||
- Installation: getting-started/installation.md
|
||||
- Using in a Product: getting-started/using-in-product.md
|
||||
- Module Reference:
|
||||
- Overview: modules/index.md
|
||||
- db: modules/db.md
|
||||
- llm: modules/llm.md
|
||||
- tiers: modules/tiers.md
|
||||
- config: modules/config.md
|
||||
- hardware: modules/hardware.md
|
||||
- documents: modules/documents.md
|
||||
- affiliates: modules/affiliates.md
|
||||
- preferences: modules/preferences.md
|
||||
- tasks: modules/tasks.md
|
||||
- manage: modules/manage.md
|
||||
- resources: modules/resources.md
|
||||
- text: modules/text.md
|
||||
- stt: modules/stt.md
|
||||
- tts: modules/tts.md
|
||||
- pipeline: modules/pipeline.md
|
||||
- vision: modules/vision.md
|
||||
- wizard: modules/wizard.md
|
||||
- Developer Guide:
|
||||
- Adding a Module: developer/adding-module.md
|
||||
- Editable Install Pattern: developer/editable-install.md
|
||||
- BSL vs MIT Boundaries: developer/licensing.md
|
||||
|
||||
extra_javascript:
|
||||
- plausible.js
|
||||
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[project]
|
||||
name = "circuitforge-core"
|
||||
version = "0.8.0"
|
||||
version = "0.10.0"
|
||||
description = "Shared scaffold for CircuitForge products (MIT)"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
|
|
@ -14,10 +14,61 @@ dependencies = [
|
|||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
community = [
|
||||
"psycopg2>=2.9",
|
||||
]
|
||||
manage = [
|
||||
"platformdirs>=4.0",
|
||||
"typer[all]>=0.12",
|
||||
]
|
||||
text-llamacpp = [
|
||||
"llama-cpp-python>=0.2.0",
|
||||
]
|
||||
text-transformers = [
|
||||
"torch>=2.0",
|
||||
"transformers>=4.40",
|
||||
"accelerate>=0.27",
|
||||
]
|
||||
text-transformers-4bit = [
|
||||
"circuitforge-core[text-transformers]",
|
||||
"bitsandbytes>=0.43",
|
||||
]
|
||||
stt-faster-whisper = [
|
||||
"faster-whisper>=1.0",
|
||||
]
|
||||
stt-service = [
|
||||
"circuitforge-core[stt-faster-whisper]",
|
||||
"fastapi>=0.110",
|
||||
"uvicorn[standard]>=0.29",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
tts-chatterbox = [
|
||||
"chatterbox-tts>=0.1",
|
||||
"torchaudio>=2.0",
|
||||
]
|
||||
tts-service = [
|
||||
"circuitforge-core[tts-chatterbox]",
|
||||
"fastapi>=0.110",
|
||||
"uvicorn[standard]>=0.29",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
vision-siglip = [
|
||||
"torch>=2.0",
|
||||
"transformers>=4.40",
|
||||
"Pillow>=10.0",
|
||||
]
|
||||
vision-vlm = [
|
||||
"torch>=2.0",
|
||||
"transformers>=4.40",
|
||||
"Pillow>=10.0",
|
||||
"accelerate>=0.27",
|
||||
]
|
||||
vision-service = [
|
||||
"circuitforge-core[vision-siglip]",
|
||||
"fastapi>=0.110",
|
||||
"uvicorn[standard]>=0.29",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
dev = [
|
||||
"circuitforge-core[manage]",
|
||||
"pytest>=8.0",
|
||||
|
|
@ -35,6 +86,9 @@ cf-manage = "circuitforge_core.manage.cli:app"
|
|||
where = ["."]
|
||||
include = ["circuitforge_core*"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"circuitforge_core.community.migrations" = ["*.sql"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
|
|
|
|||
0
tests/community/__init__.py
Normal file
0
tests/community/__init__.py
Normal file
63
tests/community/test_db.py
Normal file
63
tests/community/test_db.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
# tests/community/test_db.py
"""Unit tests for circuitforge_core.community.db.CommunityDB.

All pool-touching tests patch ThreadedConnectionPool, so no real PostgreSQL
instance is needed. Only the migration-discovery test touches disk (the
packaged migrations directory).
"""
# NOTE(review): `os` and `call` appear unused below — candidates for removal.
import os
import pytest
from unittest.mock import MagicMock, patch, call
from circuitforge_core.community.db import CommunityDB


@pytest.fixture
def mock_pool():
    """Patch psycopg2.pool.ThreadedConnectionPool to avoid needing a real PG instance."""
    with patch("circuitforge_core.community.db.ThreadedConnectionPool") as mock_cls:
        mock_instance = MagicMock()
        mock_cls.return_value = mock_instance
        yield mock_cls, mock_instance


def test_community_db_requires_url():
    """Constructing without a DSN must fail fast with a helpful message."""
    with pytest.raises(ValueError, match="COMMUNITY_DB_URL"):
        CommunityDB(dsn=None)


def test_community_db_init_creates_pool(mock_pool):
    """A valid DSN causes exactly one pool to be created."""
    mock_cls, _ = mock_pool
    CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
    mock_cls.assert_called_once()


def test_community_db_close_puts_pool(mock_pool):
    """close() must tear down every pooled connection."""
    _, mock_instance = mock_pool
    db = CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
    db.close()
    mock_instance.closeall.assert_called_once()


def test_community_db_migration_files_discovered():
    """Migration runner must find at least 001 and 002 SQL files."""
    # __new__ bypasses __init__ so no DSN/pool is needed for discovery.
    db = CommunityDB.__new__(CommunityDB)
    files = db._discover_migrations()
    names = [f.name for f in files]
    assert any("001" in n for n in names)
    assert any("002" in n for n in names)
    # Must be sorted numerically
    assert files == sorted(files, key=lambda p: p.name)


def test_community_db_run_migrations_executes_sql(mock_pool):
    """run_migrations() executes SQL through a pooled connection's cursor."""
    _, mock_instance = mock_pool
    mock_conn = MagicMock()
    mock_cur = MagicMock()
    mock_instance.getconn.return_value = mock_conn
    # Cursor is used as a context manager inside run_migrations.
    mock_conn.cursor.return_value.__enter__.return_value = mock_cur

    db = CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
    db.run_migrations()

    # At least one execute call must have happened
    assert mock_cur.execute.called


def test_community_db_from_env(monkeypatch, mock_pool):
    """from_env() reads COMMUNITY_DB_URL and builds a working instance."""
    monkeypatch.setenv("COMMUNITY_DB_URL", "postgresql://u:p@host/db")
    db = CommunityDB.from_env()
    assert db is not None
|
||||
94
tests/community/test_models.py
Normal file
94
tests/community/test_models.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
# tests/community/test_models.py
"""Unit tests for the frozen CommunityPost dataclass (community signal model)."""
import pytest
from datetime import datetime, timezone
from circuitforge_core.community.models import CommunityPost


def make_post(**kwargs) -> CommunityPost:
    """Build a valid CommunityPost, overriding any field via keyword args.

    The defaults form a complete, internally consistent "plan" post so each
    test only needs to specify the field(s) under test.
    """
    defaults = dict(
        slug="kiwi-plan-test-2026-04-12-pasta-week",
        pseudonym="PastaWitch",
        post_type="plan",
        published=datetime(2026, 4, 12, 12, 0, 0, tzinfo=timezone.utc),
        title="Pasta Week",
        description="Seven days of carbs",
        photo_url=None,
        slots=[{"day": 0, "meal_type": "dinner", "recipe_id": 1, "recipe_name": "Spaghetti"}],
        recipe_id=None,
        recipe_name=None,
        level=None,
        outcome_notes=None,
        # Flavor-profile scores — the model is expected to validate 0.0..1.0.
        seasoning_score=0.7,
        richness_score=0.6,
        brightness_score=0.3,
        depth_score=0.5,
        aroma_score=0.4,
        structure_score=0.8,
        texture_profile="chewy",
        dietary_tags=["vegetarian"],
        allergen_flags=["gluten"],
        flavor_molecules=[1234, 5678],
        fat_pct=12.5,
        protein_pct=10.0,
        moisture_pct=55.0,
    )
    defaults.update(kwargs)
    return CommunityPost(**defaults)


def test_community_post_immutable():
    """Frozen dataclass: attribute assignment must raise."""
    post = make_post()
    with pytest.raises((AttributeError, TypeError)):
        post.title = "changed"  # type: ignore


def test_community_post_slug_uri_compatible():
    """Slugs must be lowercase and space-free (URL-safe)."""
    post = make_post(slug="kiwi-plan-test-2026-04-12-pasta-week")
    assert " " not in post.slug
    assert post.slug == post.slug.lower()


def test_community_post_type_valid():
    """All three supported post types construct without error."""
    make_post(post_type="plan")
    make_post(post_type="recipe_success")
    make_post(post_type="recipe_blooper")


def test_community_post_type_invalid():
    """Unknown post types are rejected at construction time."""
    with pytest.raises(ValueError):
        make_post(post_type="garbage")


def test_community_post_scores_range():
    """Boundary values 0.0 and 1.0 are accepted."""
    post = make_post(seasoning_score=1.0, richness_score=0.0)
    assert 0.0 <= post.seasoning_score <= 1.0
    assert 0.0 <= post.richness_score <= 1.0


def test_community_post_scores_out_of_range():
    """Scores outside [0, 1] on either side are rejected."""
    with pytest.raises(ValueError):
        make_post(seasoning_score=1.5)
    with pytest.raises(ValueError):
        make_post(richness_score=-0.1)


def test_community_post_dietary_tags_immutable():
    """List inputs are coerced to tuples so the frozen post is deeply immutable."""
    post = make_post(dietary_tags=["vegan"])
    assert isinstance(post.dietary_tags, tuple)


def test_community_post_allergen_flags_immutable():
    """Same tuple coercion as dietary_tags."""
    post = make_post(allergen_flags=["nuts", "dairy"])
    assert isinstance(post.allergen_flags, tuple)


def test_community_post_flavor_molecules_immutable():
    """Same tuple coercion for molecule id lists."""
    post = make_post(flavor_molecules=[1, 2, 3])
    assert isinstance(post.flavor_molecules, tuple)


def test_community_post_optional_fields_none():
    """Optional fields accept and round-trip None."""
    post = make_post(photo_url=None, recipe_id=None, fat_pct=None)
    assert post.photo_url is None
    assert post.recipe_id is None
    assert post.fat_pct is None
|
||||
115
tests/community/test_store.py
Normal file
115
tests/community/test_store.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
# tests/community/test_store.py
"""Unit tests for SharedStore, the base class products use for community reads/writes.

The DB is a MagicMock with a getconn/putconn pool interface; `cur.description`
is stubbed so row-to-CommunityPost mapping can resolve column names.
"""
# NOTE(review): `patch` appears unused below — candidate for removal.
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime, timezone
from circuitforge_core.community.store import SharedStore
from circuitforge_core.community.models import CommunityPost


def make_post_row() -> dict:
    """A complete DB row for one 'plan' post, as the store would fetch it."""
    return {
        "id": 1,
        "slug": "kiwi-plan-test-pasta-week",
        "pseudonym": "PastaWitch",
        "post_type": "plan",
        "published": datetime(2026, 4, 12, 12, 0, 0, tzinfo=timezone.utc),
        "title": "Pasta Week",
        "description": None,
        "photo_url": None,
        "slots": [{"day": 0, "meal_type": "dinner", "recipe_id": 1, "recipe_name": "Spaghetti"}],
        "recipe_id": None,
        "recipe_name": None,
        "level": None,
        "outcome_notes": None,
        "seasoning_score": 0.7,
        "richness_score": 0.6,
        "brightness_score": 0.3,
        "depth_score": 0.5,
        "aroma_score": 0.4,
        "structure_score": 0.8,
        "texture_profile": "chewy",
        "dietary_tags": ["vegetarian"],
        "allergen_flags": ["gluten"],
        "flavor_molecules": [1234],
        "fat_pct": 12.5,
        "protein_pct": 10.0,
        "moisture_pct": 55.0,
        "source_product": "kiwi",
    }


@pytest.fixture
def mock_db():
    """Mock pool-style DB: getconn() -> conn, conn.cursor() context -> cur."""
    db = MagicMock()
    conn = MagicMock()
    cur = MagicMock()
    db.getconn.return_value = conn
    conn.cursor.return_value.__enter__.return_value = cur
    return db, conn, cur


def test_shared_store_get_post_by_slug(mock_db):
    """A fetched row is mapped into a CommunityPost with matching fields."""
    db, conn, cur = mock_db
    cur.fetchone.return_value = make_post_row()
    # cursor.description supplies column names for row mapping.
    cur.description = [(col,) for col in make_post_row().keys()]

    store = SharedStore(db)
    post = store.get_post_by_slug("kiwi-plan-test-pasta-week")

    assert post is not None
    assert isinstance(post, CommunityPost)
    assert post.slug == "kiwi-plan-test-pasta-week"
    assert post.pseudonym == "PastaWitch"


def test_shared_store_get_post_by_slug_not_found(mock_db):
    """Missing rows map to None, not an exception."""
    db, conn, cur = mock_db
    cur.fetchone.return_value = None

    store = SharedStore(db)
    post = store.get_post_by_slug("does-not-exist")
    assert post is None


def test_shared_store_list_posts_returns_list(mock_db):
    """list_posts() returns a list of mapped CommunityPost objects."""
    db, conn, cur = mock_db
    row = make_post_row()
    cur.fetchall.return_value = [row]
    cur.description = [(col,) for col in row.keys()]

    store = SharedStore(db)
    posts = store.list_posts(limit=10, offset=0)

    assert isinstance(posts, list)
    assert len(posts) == 1
    assert posts[0].slug == "kiwi-plan-test-pasta-week"


def test_shared_store_delete_post(mock_db):
    """rowcount == 1 (owner matched) reports a successful delete."""
    db, conn, cur = mock_db
    cur.rowcount = 1

    store = SharedStore(db)
    deleted = store.delete_post(slug="kiwi-plan-test-pasta-week", pseudonym="PastaWitch")
    assert deleted is True


def test_shared_store_delete_post_wrong_owner(mock_db):
    """rowcount == 0 (ownership mismatch) reports no delete."""
    db, conn, cur = mock_db
    cur.rowcount = 0

    store = SharedStore(db)
    deleted = store.delete_post(slug="kiwi-plan-test-pasta-week", pseudonym="WrongUser")
    assert deleted is False


def test_shared_store_returns_connection_on_error(mock_db):
    """The pooled connection must be released even when the query raises."""
    db, conn, cur = mock_db
    cur.fetchone.side_effect = Exception("DB error")

    store = SharedStore(db)
    with pytest.raises(Exception, match="DB error"):
        store.get_post_by_slug("any-slug")

    # Connection must be returned to pool even on error
    db.putconn.assert_called_once_with(conn)
|
||||
0
tests/test_config/__init__.py
Normal file
0
tests/test_config/__init__.py
Normal file
172
tests/test_config/test_license.py
Normal file
172
tests/test_config/test_license.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""Tests for circuitforge_core.config.license.

All network calls to the Heimdall license server are patched out; the
module-level result cache is cleared around every test by an autouse fixture.
"""
from __future__ import annotations

import time
from unittest.mock import MagicMock, patch

import pytest
import requests

import circuitforge_core.config.license as license_module
from circuitforge_core.config.license import get_license_tier, validate_license


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the module-level cache before each test."""
    license_module._cache.clear()
    yield
    license_module._cache.clear()


# ---------------------------------------------------------------------------
# 1. validate_license returns _INVALID when CF_LICENSE_KEY not set
# ---------------------------------------------------------------------------
def test_validate_license_no_key_returns_invalid(monkeypatch):
    monkeypatch.delenv("CF_LICENSE_KEY", raising=False)
    result = validate_license("kiwi")
    assert result == {"valid": False, "tier": "free", "user_id": ""}


# ---------------------------------------------------------------------------
# 2. validate_license calls Heimdall and returns valid result when key set
# ---------------------------------------------------------------------------
def test_validate_license_valid_response(monkeypatch):
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-KIWI-AAAA-BBBB-CCCC")
    mock_resp = MagicMock()
    mock_resp.ok = True
    mock_resp.json.return_value = {"valid": True, "tier": "paid", "user_id": "user-42"}

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp) as mock_post:
        result = validate_license("kiwi")

    mock_post.assert_called_once()
    assert result == {"valid": True, "tier": "paid", "user_id": "user-42"}


# ---------------------------------------------------------------------------
# 3. validate_license returns invalid when Heimdall returns non-ok status
# ---------------------------------------------------------------------------
def test_validate_license_non_ok_response(monkeypatch):
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-KIWI-AAAA-BBBB-CCCC")
    mock_resp = MagicMock()
    mock_resp.ok = False
    mock_resp.status_code = 403

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp):
        result = validate_license("kiwi")

    assert result == {"valid": False, "tier": "free", "user_id": ""}


# ---------------------------------------------------------------------------
# 4. validate_license returns invalid when network fails
# ---------------------------------------------------------------------------
def test_validate_license_network_error(monkeypatch):
    # Offline / unreachable Heimdall must degrade to the free tier, not raise.
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-KIWI-AAAA-BBBB-CCCC")

    with patch(
        "circuitforge_core.config.license.requests.post",
        side_effect=requests.exceptions.ConnectionError("unreachable"),
    ):
        result = validate_license("kiwi")

    assert result == {"valid": False, "tier": "free", "user_id": ""}


# ---------------------------------------------------------------------------
# 5. validate_license caches result — second call does NOT make a second request
# ---------------------------------------------------------------------------
def test_validate_license_caches_result(monkeypatch):
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-KIWI-CACHE-TEST-KEY")
    mock_resp = MagicMock()
    mock_resp.ok = True
    mock_resp.json.return_value = {"valid": True, "tier": "paid", "user_id": "user-1"}

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp) as mock_post:
        result1 = validate_license("kiwi")
        result2 = validate_license("kiwi")

    assert mock_post.call_count == 1
    assert result1 == result2


# ---------------------------------------------------------------------------
# 6. get_license_tier returns "free" when key absent
# ---------------------------------------------------------------------------
def test_get_license_tier_no_key_returns_free(monkeypatch):
    monkeypatch.delenv("CF_LICENSE_KEY", raising=False)
    assert get_license_tier("snipe") == "free"


# ---------------------------------------------------------------------------
# 7. get_license_tier returns tier string from valid Heimdall response
# ---------------------------------------------------------------------------
def test_get_license_tier_valid_key_returns_tier(monkeypatch):
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-SNPE-AAAA-BBBB-CCCC")
    mock_resp = MagicMock()
    mock_resp.ok = True
    mock_resp.json.return_value = {"valid": True, "tier": "premium", "user_id": "user-7"}

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp):
        tier = get_license_tier("snipe")

    assert tier == "premium"


# ---------------------------------------------------------------------------
# 8. get_license_tier returns "free" when Heimdall says valid=False
# ---------------------------------------------------------------------------
def test_get_license_tier_invalid_key_returns_free(monkeypatch):
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-SNPE-DEAD-DEAD-DEAD")
    mock_resp = MagicMock()
    mock_resp.ok = True
    mock_resp.json.return_value = {"valid": False, "tier": "free", "user_id": ""}

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp):
        tier = get_license_tier("snipe")

    assert tier == "free"


# ---------------------------------------------------------------------------
# 9. CF_LICENSE_URL env var overrides the default Heimdall URL
# ---------------------------------------------------------------------------
def test_cf_license_url_override(monkeypatch):
    monkeypatch.setenv("CF_LICENSE_KEY", "CFG-PRNG-AAAA-BBBB-CCCC")
    monkeypatch.setenv("CF_LICENSE_URL", "http://localhost:9000")

    mock_resp = MagicMock()
    mock_resp.ok = True
    mock_resp.json.return_value = {"valid": True, "tier": "paid", "user_id": "u1"}

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp) as mock_post:
        validate_license("peregrine")

    # First positional argument of requests.post is the URL.
    call_url = mock_post.call_args[0][0]
    assert call_url.startswith("http://localhost:9000"), (
        f"Expected URL to start with http://localhost:9000, got {call_url!r}"
    )


# ---------------------------------------------------------------------------
# 10. Expired cache entry triggers a fresh Heimdall call
# ---------------------------------------------------------------------------
def test_validate_license_expired_cache_triggers_fresh_call(monkeypatch):
    key = "CFG-KIWI-EXPR-EXPR-EXPR"
    monkeypatch.setenv("CF_LICENSE_KEY", key)

    # Inject an expired cache entry
    # (second tuple element is the expiry deadline in time.monotonic() terms,
    #  so "now - 1" is already past — TODO confirm against license module).
    expired_result = {"valid": True, "tier": "paid", "user_id": "old-user"}
    license_module._cache[(key, "kiwi")] = (expired_result, time.monotonic() - 1)

    mock_resp = MagicMock()
    mock_resp.ok = True
    mock_resp.json.return_value = {"valid": True, "tier": "premium", "user_id": "new-user"}

    with patch("circuitforge_core.config.license.requests.post", return_value=mock_resp) as mock_post:
        result = validate_license("kiwi")

    mock_post.assert_called_once()
    assert result["tier"] == "premium"
    assert result["user_id"] == "new-user"
|
||||
0
tests/test_pipeline/__init__.py
Normal file
0
tests/test_pipeline/__init__.py
Normal file
198
tests/test_pipeline/test_crystallizer.py
Normal file
198
tests/test_pipeline/test_crystallizer.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
"""Tests for pipeline.crystallizer — the core promotion logic."""
import warnings
import pytest
# NOTE(review): evaluate_new_run and Recorder are not used in this visible
# portion of the file — verify they are exercised further down before removing.
from circuitforge_core.pipeline.crystallizer import (
    CrystallizerConfig,
    crystallize,
    evaluate_new_run,
    should_crystallize,
)
from circuitforge_core.pipeline.models import PipelineRun, Step
from circuitforge_core.pipeline.recorder import Recorder


# ── Fixtures / helpers ────────────────────────────────────────────────────────

def _run(run_id, approved=True, review_ms=8000, modified=False,
         steps=None, input_hash="fixedhash",
         ts="2026-04-08T00:00:00+00:00") -> PipelineRun:
    """Build a PipelineRun for product 'osprey' / task 'ivr_navigate'.

    Defaults form an approved, unmodified run with a fixed input hash so
    tests can mix-and-match the one attribute they care about.
    """
    return PipelineRun(
        run_id=run_id,
        product="osprey",
        task_type="ivr_navigate",
        input_hash=input_hash,
        steps=steps or [Step("dtmf", {"digits": "1"})],
        approved=approved,
        review_duration_ms=review_ms,
        output_modified=modified,
        timestamp=ts,
    )


# Shared config: promote after 3 approvals; reviews under 5s look rubber-stamped.
_CFG = CrystallizerConfig(threshold=3, min_review_ms=5_000)


# ── should_crystallize ────────────────────────────────────────────────────────

class TestShouldCrystallize:
    def test_returns_false_below_threshold(self):
        runs = [_run(f"r{i}") for i in range(2)]
        assert should_crystallize(runs, _CFG) is False

    def test_returns_true_at_threshold(self):
        runs = [_run(f"r{i}") for i in range(3)]
        assert should_crystallize(runs, _CFG) is True

    def test_returns_true_above_threshold(self):
        runs = [_run(f"r{i}") for i in range(10)]
        assert should_crystallize(runs, _CFG) is True

    def test_unapproved_runs_not_counted(self):
        # Only approved runs count toward the promotion threshold.
        approved = [_run(f"r{i}") for i in range(2)]
        unapproved = [_run(f"u{i}", approved=False) for i in range(10)]
        assert should_crystallize(approved + unapproved, _CFG) is False

    def test_threshold_one(self):
        cfg = CrystallizerConfig(threshold=1)
        assert should_crystallize([_run("r1")], cfg) is True
|
||||
|
||||
|
||||
# ── crystallize ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestCrystallize:
    """crystallize(): aggregation, validation errors, and review-time warnings."""

    def _approved_runs(self, n=3, review_ms=8000):
        # Helper: n identical approved runs sharing the default input hash.
        return [_run(f"r{idx}", review_ms=review_ms) for idx in range(n)]

    def test_produces_workflow(self):
        workflow = crystallize(self._approved_runs(), _CFG)
        assert workflow.product == "osprey"
        assert workflow.task_type == "ivr_navigate"
        assert workflow.approval_count == 3

    def test_workflow_id_format(self):
        workflow = crystallize(self._approved_runs(), _CFG)
        assert workflow.workflow_id.startswith("osprey:ivr_navigate:")

    def test_avg_review_duration_computed(self):
        durations = (6000, 10000, 8000)
        batch = [_run(f"r{idx}", review_ms=ms) for idx, ms in enumerate(durations)]
        workflow = crystallize(batch, _CFG)
        assert workflow.avg_review_duration_ms == 8000

    def test_all_output_unmodified_true(self):
        workflow = crystallize(self._approved_runs(), _CFG)
        assert workflow.all_output_unmodified is True

    def test_all_output_unmodified_false_when_any_modified(self):
        batch = [_run("r0"), _run("r1"), _run("r2", modified=True)]
        assert crystallize(batch, _CFG).all_output_unmodified is False

    def test_raises_below_threshold(self):
        too_few = [_run("r0"), _run("r1")]
        with pytest.raises(ValueError, match="Need 3"):
            crystallize(too_few, _CFG)

    def test_raises_on_mixed_products(self):
        osprey_run = _run("r1")
        falcon_run = PipelineRun(
            run_id="r2",
            product="falcon",
            task_type="ivr_navigate",
            input_hash="fixedhash",
            steps=osprey_run.steps,
            approved=True,
            review_duration_ms=8000,
            output_modified=False,
        )
        with pytest.raises(ValueError, match="product"):
            crystallize([osprey_run, falcon_run, osprey_run], _CFG)

    def test_raises_on_mixed_hashes(self):
        batch = [
            _run("r0", input_hash="hash_a"),
            _run("r1", input_hash="hash_b"),
            _run("r2", input_hash="hash_a"),
        ]
        with pytest.raises(ValueError, match="input_hash"):
            crystallize(batch, _CFG)

    def test_rubber_stamp_warning(self):
        # Reviews far below min_review_ms should be flagged as rubber-stamping.
        quick = [_run(f"r{idx}", review_ms=100) for idx in range(3)]
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            crystallize(quick, _CFG)
        assert any("rubber-stamp" in str(w.message) for w in caught)

    def test_no_warning_when_min_review_ms_zero(self):
        lax_cfg = CrystallizerConfig(threshold=3, min_review_ms=0)
        instant = [_run(f"r{idx}", review_ms=1) for idx in range(3)]
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            crystallize(instant, lax_cfg)
        assert not any("rubber-stamp" in str(w.message) for w in caught)

    def test_version_increments(self):
        workflow = crystallize(self._approved_runs(), _CFG, existing_version=2)
        assert workflow.version == 3

    def test_strategy_most_recent_uses_latest(self):
        older = [Step("dtmf", {"digits": "9"})]
        newer = [Step("dtmf", {"digits": "1"})]
        batch = [
            _run("r0", steps=older, ts="2026-01-01T00:00:00+00:00"),
            _run("r1", steps=older, ts="2026-01-02T00:00:00+00:00"),
            _run("r2", steps=newer, ts="2026-04-08T00:00:00+00:00"),
        ]
        recent_cfg = CrystallizerConfig(threshold=3, strategy="most_recent")
        workflow = crystallize(batch, recent_cfg)
        assert workflow.steps[0].params["digits"] == "1"

    def test_strategy_majority_picks_common_action(self):
        common = [Step("dtmf", {"digits": "1"})]
        outlier = [Step("press_key", {"key": "2"})]
        batch = [
            _run("r0", steps=common),
            _run("r1", steps=common),
            _run("r2", steps=outlier),
        ]
        majority_cfg = CrystallizerConfig(threshold=3, strategy="majority")
        workflow = crystallize(batch, majority_cfg)
        assert workflow.steps[0].action == "dtmf"

    def test_strategy_majority_falls_back_on_length_mismatch(self):
        one_step = [Step("dtmf", {"digits": "1"})]
        two_steps = [Step("dtmf", {"digits": "1"}), Step("dtmf", {"digits": "2"})]
        batch = [
            _run("r0", steps=one_step),
            _run("r1", steps=two_steps),
            _run("r2", steps=one_step, ts="2026-04-08T00:00:00+00:00"),
        ]
        majority_cfg = CrystallizerConfig(threshold=3, strategy="majority")
        # Should not raise — falls back to most_recent
        workflow = crystallize(batch, majority_cfg)
        assert workflow.steps is not None
|
||||
|
||||
|
||||
# ── evaluate_new_run ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestEvaluateNewRun:
    """evaluate_new_run(): records every run, crystallizes at the threshold."""

    def test_returns_none_before_threshold(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        cfg = CrystallizerConfig(threshold=3, min_review_ms=0)
        assert evaluate_new_run(_run("r1"), recorder, cfg) is None

    def test_returns_workflow_at_threshold(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        cfg = CrystallizerConfig(threshold=3, min_review_ms=0)
        evaluate_new_run(_run("r0"), recorder, cfg)
        evaluate_new_run(_run("r1"), recorder, cfg)
        # The third approved run crosses the threshold and must crystallize.
        workflow = evaluate_new_run(_run("r2"), recorder, cfg)
        assert workflow is not None
        assert workflow.approval_count == 3

    def test_unapproved_run_does_not_trigger(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        cfg = CrystallizerConfig(threshold=1, min_review_ms=0)
        assert evaluate_new_run(_run("r1", approved=False), recorder, cfg) is None

    def test_run_is_recorded_even_if_not_approved(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        cfg = CrystallizerConfig(threshold=3, min_review_ms=0)
        evaluate_new_run(_run("r1", approved=False), recorder, cfg)
        assert len(recorder.load_runs("osprey", "ivr_navigate")) == 1
|
||||
96
tests/test_pipeline/test_executor.py
Normal file
96
tests/test_pipeline/test_executor.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Tests for pipeline.Executor."""
|
||||
import pytest
|
||||
from circuitforge_core.pipeline.executor import Executor, ExecutionResult
|
||||
from circuitforge_core.pipeline.models import CrystallizedWorkflow, Step
|
||||
|
||||
|
||||
def _wf(steps=None) -> CrystallizedWorkflow:
    """Minimal crystallized workflow fixture; default steps press 1 then 2."""
    if steps is None:
        steps = [Step("dtmf", {"digits": "1"}), Step("dtmf", {"digits": "2"})]
    return CrystallizedWorkflow(
        workflow_id="osprey:ivr_navigate:abc",
        product="osprey",
        task_type="ivr_navigate",
        input_hash="abc",
        steps=steps,
        crystallized_at="2026-04-08T00:00:00+00:00",
        run_ids=["r1"],
        approval_count=1,
        avg_review_duration_ms=8000,
        all_output_unmodified=True,
    )
|
||||
|
||||
|
||||
def _ok_step(_step):
|
||||
return True, "ok"
|
||||
|
||||
|
||||
def _fail_step(_step):
|
||||
return False, None
|
||||
|
||||
|
||||
def _raise_step(_step):
|
||||
raise RuntimeError("hardware error")
|
||||
|
||||
|
||||
def _llm():
|
||||
return "llm-output"
|
||||
|
||||
|
||||
class TestExecutor:
    """Executor: deterministic step replay with optional LLM fallback."""

    def test_all_steps_succeed(self):
        executor = Executor(step_fn=_ok_step, llm_fn=_llm)
        outcome = executor.execute(_wf())
        assert outcome.success is True
        assert outcome.used_deterministic is True
        assert len(outcome.step_results) == 2

    def test_failed_step_triggers_llm_fallback(self):
        executor = Executor(step_fn=_fail_step, llm_fn=_llm)
        outcome = executor.execute(_wf())
        assert outcome.success is True
        assert outcome.used_deterministic is False
        assert outcome.llm_output == "llm-output"

    def test_raising_step_triggers_llm_fallback(self):
        executor = Executor(step_fn=_raise_step, llm_fn=_llm)
        outcome = executor.execute(_wf())
        assert outcome.success is True
        assert outcome.used_deterministic is False

    def test_llm_fallback_disabled_returns_failure(self):
        executor = Executor(step_fn=_fail_step, llm_fn=_llm, llm_fallback=False)
        outcome = executor.execute(_wf())
        assert outcome.success is False
        assert "disabled" in (outcome.error or "")

    def test_run_with_fallback_no_workflow_calls_llm(self):
        executor = Executor(step_fn=_ok_step, llm_fn=_llm)
        outcome = executor.run_with_fallback(workflow=None)
        assert outcome.success is True
        assert outcome.used_deterministic is False
        assert outcome.llm_output == "llm-output"

    def test_run_with_fallback_uses_workflow_when_given(self):
        executor = Executor(step_fn=_ok_step, llm_fn=_llm)
        outcome = executor.run_with_fallback(workflow=_wf())
        assert outcome.used_deterministic is True

    def test_llm_fn_raises_returns_failure(self):
        def _broken_llm():
            raise ValueError("no model")

        executor = Executor(step_fn=_fail_step, llm_fn=_broken_llm)
        outcome = executor.execute(_wf())
        assert outcome.success is False
        assert "no model" in (outcome.error or "")

    def test_workflow_id_preserved_in_result(self):
        executor = Executor(step_fn=_ok_step, llm_fn=_llm)
        assert executor.execute(_wf()).workflow_id == "osprey:ivr_navigate:abc"

    def test_empty_workflow_succeeds_immediately(self):
        executor = Executor(step_fn=_ok_step, llm_fn=_llm)
        outcome = executor.execute(_wf(steps=[]))
        assert outcome.success is True
        assert outcome.step_results == []
|
||||
102
tests/test_pipeline/test_models.py
Normal file
102
tests/test_pipeline/test_models.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Tests for pipeline models and hash_input utility."""
|
||||
import pytest
|
||||
from circuitforge_core.pipeline.models import (
|
||||
CrystallizedWorkflow,
|
||||
PipelineRun,
|
||||
Step,
|
||||
hash_input,
|
||||
)
|
||||
|
||||
|
||||
class TestHashInput:
    """hash_input(): deterministic, key-order-insensitive SHA-256 hex digest."""

    def test_stable_across_calls(self):
        features = {"agency": "FTB", "menu_depth": 2}
        assert hash_input(features) == hash_input(features)

    def test_key_order_irrelevant(self):
        assert hash_input({"b": 2, "a": 1}) == hash_input({"a": 1, "b": 2})

    def test_different_values_differ(self):
        assert hash_input({"a": 1}) != hash_input({"a": 2})

    def test_returns_hex_string(self):
        digest = hash_input({"x": "y"})
        assert isinstance(digest, str)
        # 64 hex characters == 256 bits, i.e. SHA-256.
        assert len(digest) == 64
|
||||
|
||||
|
||||
class TestStep:
    """Step serialization round-trip and field defaults."""

    def test_roundtrip(self):
        step = Step(action="dtmf", params={"digits": "1"}, description="Press 1")
        assert Step.from_dict(step.to_dict()) == step

    def test_description_optional(self):
        # from_dict must tolerate a payload with no "description" key.
        step = Step.from_dict({"action": "dtmf", "params": {}})
        assert step.description == ""
|
||||
|
||||
|
||||
class TestPipelineRun:
    """PipelineRun serialization and field defaults."""

    def _run(self, **overrides) -> PipelineRun:
        # Canonical approved run; keyword args override any individual field.
        fields = {
            "run_id": "r1",
            "product": "osprey",
            "task_type": "ivr_navigate",
            "input_hash": "abc123",
            "steps": [Step("dtmf", {"digits": "1"})],
            "approved": True,
            "review_duration_ms": 8000,
            "output_modified": False,
        }
        fields.update(overrides)
        return PipelineRun(**fields)

    def test_roundtrip(self):
        assert PipelineRun.from_dict(self._run().to_dict()).run_id == "r1"

    def test_output_modified_false_default(self):
        payload = self._run().to_dict()
        payload.pop("output_modified", None)
        assert PipelineRun.from_dict(payload).output_modified is False

    def test_timestamp_auto_set(self):
        # timestamp must be auto-populated (non-empty) when not supplied.
        assert self._run().timestamp
|
||||
|
||||
|
||||
class TestCrystallizedWorkflow:
    """CrystallizedWorkflow serialization round-trip and defaults."""

    def _wf(self) -> CrystallizedWorkflow:
        # Fully-populated three-approval workflow fixture.
        return CrystallizedWorkflow(
            workflow_id="osprey:ivr_navigate:abc123abc123",
            product="osprey",
            task_type="ivr_navigate",
            input_hash="abc123",
            steps=[Step("dtmf", {"digits": "1"})],
            crystallized_at="2026-04-08T00:00:00+00:00",
            run_ids=["r1", "r2", "r3"],
            approval_count=3,
            avg_review_duration_ms=9000,
            all_output_unmodified=True,
        )

    def test_roundtrip(self):
        original = self._wf()
        restored = CrystallizedWorkflow.from_dict(original.to_dict())
        assert restored.workflow_id == original.workflow_id
        assert restored.avg_review_duration_ms == 9000
        assert restored.all_output_unmodified is True

    def test_active_default_true(self):
        payload = self._wf().to_dict()
        payload.pop("active", None)
        assert CrystallizedWorkflow.from_dict(payload).active is True

    def test_version_default_one(self):
        payload = self._wf().to_dict()
        payload.pop("version", None)
        assert CrystallizedWorkflow.from_dict(payload).version == 1
|
||||
198
tests/test_pipeline/test_multimodal.py
Normal file
198
tests/test_pipeline/test_multimodal.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
"""Tests for pipeline.MultimodalPipeline — mock vision and text backends."""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from circuitforge_core.documents.models import Element, StructuredDocument
|
||||
from circuitforge_core.pipeline.multimodal import (
|
||||
MultimodalConfig,
|
||||
MultimodalPipeline,
|
||||
PageResult,
|
||||
_default_prompt,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _doc(text="extracted text", page=0) -> StructuredDocument:
    """One-paragraph structured document whose raw_text equals *text*."""
    paragraph = Element(type="paragraph", text=text)
    return StructuredDocument(elements=[paragraph], raw_text=text)
|
||||
|
||||
|
||||
def _vision_ok(text="extracted text"):
|
||||
"""Mock DocuvisionClient.extract that returns a StructuredDocument."""
|
||||
mock = MagicMock()
|
||||
mock.extract.return_value = _doc(text)
|
||||
return mock
|
||||
|
||||
|
||||
def _vision_fail(exc=None):
|
||||
mock = MagicMock()
|
||||
mock.extract.side_effect = exc or ConnectionError("service down")
|
||||
return mock
|
||||
|
||||
|
||||
def _generate_fn(prompt, max_tokens=512, temperature=0.7):
|
||||
return f"generated: {prompt[:20]}"
|
||||
|
||||
|
||||
def _stream_fn(prompt, max_tokens=512, temperature=0.7):
|
||||
yield "tok1"
|
||||
yield "tok2"
|
||||
yield "tok3"
|
||||
|
||||
|
||||
def _pipe(vision_mock=None, generate_fn=None, stream_fn=None,
          vram_serialise=False, swap_fn=None, prompt_fn=None) -> MultimodalPipeline:
    """Assemble a MultimodalPipeline with mock backends injected."""
    config = MultimodalConfig(vram_serialise=vram_serialise)
    if prompt_fn:
        config.prompt_fn = prompt_fn
    pipeline = MultimodalPipeline(
        config,
        generate_fn=generate_fn or _generate_fn,
        stream_fn=stream_fn,
        swap_fn=swap_fn,
    )
    if vision_mock is not None:
        # Inject the mock vision client directly, bypassing real model loading.
        pipeline._vision = vision_mock
    return pipeline
|
||||
|
||||
|
||||
# ── DefaultPrompt ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestDefaultPrompt:
    """_default_prompt(): a page header appears only from the second page on."""

    def test_page_zero_no_header(self):
        assert _default_prompt(0, _doc("hello")) == "hello"

    def test_page_one_has_header(self):
        prompt = _default_prompt(1, _doc("content"))
        # Header is 1-based: page index 1 renders as "[Page 2]".
        assert "[Page 2]" in prompt
        assert "content" in prompt
|
||||
|
||||
|
||||
# ── run() ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMultimodalPipelineRun:
    """run(): per-page extraction + generation, with per-page error isolation."""

    def test_single_page_success(self):
        pipeline = _pipe(vision_mock=_vision_ok("resume text"))
        pages = list(pipeline.run([b"page0_bytes"]))
        assert len(pages) == 1
        assert pages[0].page_idx == 0
        assert pages[0].error is None
        assert "generated" in pages[0].generated

    def test_multiple_pages_all_yielded(self):
        pipeline = _pipe(vision_mock=_vision_ok())
        pages = list(pipeline.run([b"p0", b"p1", b"p2"]))
        assert len(pages) == 3
        assert [page.page_idx for page in pages] == [0, 1, 2]

    def test_vision_failure_yields_error_page(self):
        pipeline = _pipe(vision_mock=_vision_fail())
        page = list(pipeline.run([b"p0"]))[0]
        assert page.error is not None
        assert page.doc is None
        assert page.generated == ""

    def test_partial_failure_does_not_stop_pipeline(self):
        """One bad page should not prevent subsequent pages from processing."""
        vision = MagicMock()
        vision.extract.side_effect = [ConnectionError("fail"), _doc("good text")]
        pages = list(_pipe(vision_mock=vision).run([b"p0", b"p1"]))
        assert pages[0].error is not None
        assert pages[1].error is None

    def test_generation_failure_yields_error_page(self):
        def _oom_gen(prompt, **kwargs):
            raise RuntimeError("model OOM")

        pipeline = _pipe(vision_mock=_vision_ok(), generate_fn=_oom_gen)
        page = list(pipeline.run([b"p0"]))[0]
        assert page.error is not None
        assert "OOM" in page.error

    def test_doc_attached_to_result(self):
        pipeline = _pipe(vision_mock=_vision_ok("some text"))
        page = list(pipeline.run([b"p0"]))[0]
        assert page.doc is not None
        assert page.doc.raw_text == "some text"

    def test_empty_pages_yields_nothing(self):
        assert list(_pipe(vision_mock=_vision_ok()).run([])) == []

    def test_custom_prompt_fn_called(self):
        seen = []

        def _tracking_prompt(page_idx, doc):
            seen.append((page_idx, doc.raw_text))
            return f"custom:{doc.raw_text}"

        pipeline = _pipe(vision_mock=_vision_ok("txt"), prompt_fn=_tracking_prompt)
        list(pipeline.run([b"p0"]))
        assert seen == [(0, "txt")]

    def test_vram_serialise_calls_swap_fn(self):
        swap_calls = []
        pipeline = _pipe(vision_mock=_vision_ok(), vram_serialise=True,
                         swap_fn=lambda: swap_calls.append(1))
        list(pipeline.run([b"p0", b"p1"]))
        # Swap happens once per page when VRAM serialisation is on.
        assert len(swap_calls) == 2

    def test_vram_serialise_false_no_swap_called(self):
        swap_calls = []
        pipeline = _pipe(vision_mock=_vision_ok(), vram_serialise=False,
                         swap_fn=lambda: swap_calls.append(1))
        list(pipeline.run([b"p0"]))
        assert swap_calls == []

    def test_swap_fn_none_does_not_raise(self):
        pipeline = _pipe(vision_mock=_vision_ok(), vram_serialise=True, swap_fn=None)
        assert list(pipeline.run([b"p0"]))[0].error is None
|
||||
|
||||
|
||||
# ── stream() ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMultimodalPipelineStream:
    """stream(): yields (page_idx, token) pairs; failures become error tokens."""

    def test_yields_page_idx_token_tuples(self):
        pipeline = _pipe(vision_mock=_vision_ok(), stream_fn=_stream_fn)
        emitted = list(pipeline.stream([b"p0"]))
        assert all(isinstance(item, tuple) and len(item) == 2 for item in emitted)
        first_idx, first_token = emitted[0]
        assert first_idx == 0
        assert first_token == "tok1"

    def test_multiple_pages_interleaved_by_page(self):
        pipeline = _pipe(vision_mock=_vision_ok(), stream_fn=_stream_fn)
        indices = [idx for idx, _ in pipeline.stream([b"p0", b"p1"])]
        # Pages are processed sequentially, so indices must be non-decreasing.
        assert indices == sorted(indices)

    def test_vision_failure_yields_error_token(self):
        pipeline = _pipe(vision_mock=_vision_fail(), stream_fn=_stream_fn)
        emitted = list(pipeline.stream([b"p0"]))
        assert len(emitted) == 1
        assert "extraction error" in emitted[0][1]

    def test_stream_fn_error_yields_error_token(self):
        def _dead_stream(prompt, **kwargs):
            raise RuntimeError("GPU gone")
            yield  # unreachable, but makes this a generator function

        pipeline = _pipe(vision_mock=_vision_ok(), stream_fn=_dead_stream)
        emitted = list(pipeline.stream([b"p0"]))
        assert any("generation error" in token for _, token in emitted)

    def test_empty_pages_yields_nothing(self):
        pipeline = _pipe(vision_mock=_vision_ok(), stream_fn=_stream_fn)
        assert list(pipeline.stream([])) == []
|
||||
|
||||
|
||||
# ── Import check ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_exported_from_pipeline_package():
    """The multimodal public names must be re-exported from the package root."""
    from circuitforge_core.pipeline import MultimodalConfig, MultimodalPipeline, PageResult

    assert MultimodalPipeline is not None
|
||||
66
tests/test_pipeline/test_recorder.py
Normal file
66
tests/test_pipeline/test_recorder.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Tests for pipeline.Recorder."""
|
||||
import pytest
|
||||
from circuitforge_core.pipeline.models import PipelineRun, Step
|
||||
from circuitforge_core.pipeline.recorder import Recorder
|
||||
|
||||
|
||||
def _run(run_id="r1", approved=True, input_hash="abc", review_ms=8000,
         modified=False, ts="2026-04-08T01:00:00+00:00") -> PipelineRun:
    """Single-step osprey/ivr_navigate run with overridable metadata."""
    single_step = [Step("dtmf", {"digits": "1"})]
    return PipelineRun(
        run_id=run_id,
        product="osprey",
        task_type="ivr_navigate",
        input_hash=input_hash,
        steps=single_step,
        approved=approved,
        review_duration_ms=review_ms,
        output_modified=modified,
        timestamp=ts,
    )
|
||||
|
||||
|
||||
class TestRecorder:
    """Recorder: append-only run store with filtered, ordered loading."""

    def test_record_creates_file(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        assert recorder.record(_run()).exists()

    def test_load_runs_empty_when_no_directory(self, tmp_path):
        assert Recorder(root=tmp_path).load_runs("osprey", "ivr_navigate") == []

    def test_load_runs_returns_recorded(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        recorder.record(_run("r1"))
        recorder.record(_run("r2"))
        assert len(recorder.load_runs("osprey", "ivr_navigate")) == 2

    def test_load_runs_newest_first(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        recorder.record(_run("r_old", ts="2026-01-01T00:00:00+00:00"))
        recorder.record(_run("r_new", ts="2026-04-08T00:00:00+00:00"))
        newest = recorder.load_runs("osprey", "ivr_navigate")[0]
        assert newest.run_id == "r_new"

    def test_load_approved_filters(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        recorder.record(_run("r1", approved=True))
        recorder.record(_run("r2", approved=False))
        approved = recorder.load_approved("osprey", "ivr_navigate", "abc")
        assert all(run.approved for run in approved)
        assert len(approved) == 1

    def test_load_approved_filters_by_hash(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        recorder.record(_run("r1", input_hash="hash_a"))
        recorder.record(_run("r2", input_hash="hash_b"))
        matches = recorder.load_approved("osprey", "ivr_navigate", "hash_a")
        assert len(matches) == 1
        assert matches[0].run_id == "r1"

    def test_record_is_append_only(self, tmp_path):
        recorder = Recorder(root=tmp_path)
        for idx in range(5):
            recorder.record(_run(f"r{idx}"))
        assert len(recorder.load_runs("osprey", "ivr_navigate")) == 5
|
||||
104
tests/test_pipeline/test_registry.py
Normal file
104
tests/test_pipeline/test_registry.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Tests for pipeline.Registry — workflow lookup."""
|
||||
import pytest
|
||||
from circuitforge_core.pipeline.models import CrystallizedWorkflow, Step
|
||||
from circuitforge_core.pipeline.registry import Registry
|
||||
|
||||
|
||||
def _wf(input_hash="abc", active=True, wf_id=None) -> CrystallizedWorkflow:
    """Three-approval workflow; id defaults to osprey:ivr_navigate:<hash[:12]>."""
    return CrystallizedWorkflow(
        workflow_id=wf_id or f"osprey:ivr_navigate:{input_hash[:12]}",
        product="osprey",
        task_type="ivr_navigate",
        input_hash=input_hash,
        steps=[Step("dtmf", {"digits": "1"})],
        crystallized_at="2026-04-08T00:00:00+00:00",
        run_ids=["r1", "r2", "r3"],
        approval_count=3,
        avg_review_duration_ms=9000,
        all_output_unmodified=True,
        active=active,
    )
|
||||
|
||||
|
||||
class TestRegistry:
    """Registry: exact and fuzzy workflow lookup plus activation lifecycle."""

    def test_register_creates_file(self, tmp_path):
        registry = Registry(root=tmp_path)
        assert registry.register(_wf()).exists()

    def test_load_all_empty_when_no_directory(self, tmp_path):
        assert Registry(root=tmp_path).load_all("osprey", "ivr_navigate") == []

    def test_load_all_returns_registered(self, tmp_path):
        registry = Registry(root=tmp_path)
        registry.register(_wf("hash_a", wf_id="osprey:ivr_navigate:hash_a"))
        registry.register(_wf("hash_b", wf_id="osprey:ivr_navigate:hash_b"))
        assert len(registry.load_all("osprey", "ivr_navigate")) == 2

    def test_match_exact_hit(self, tmp_path):
        registry = Registry(root=tmp_path)
        registry.register(_wf("abc123"))
        hit = registry.match("osprey", "ivr_navigate", "abc123")
        assert hit is not None
        assert hit.input_hash == "abc123"

    def test_match_returns_none_on_miss(self, tmp_path):
        registry = Registry(root=tmp_path)
        registry.register(_wf("abc123"))
        assert registry.match("osprey", "ivr_navigate", "different") is None

    def test_match_ignores_inactive(self, tmp_path):
        registry = Registry(root=tmp_path)
        registry.register(_wf("abc123", active=False))
        assert registry.match("osprey", "ivr_navigate", "abc123") is None

    def test_deactivate_sets_active_false(self, tmp_path):
        registry = Registry(root=tmp_path)
        workflow = _wf("abc123")
        registry.register(workflow)
        registry.deactivate(workflow.workflow_id, "osprey", "ivr_navigate")
        assert registry.match("osprey", "ivr_navigate", "abc123") is None

    def test_deactivate_returns_false_when_not_found(self, tmp_path):
        registry = Registry(root=tmp_path)
        assert registry.deactivate("nonexistent", "osprey", "ivr_navigate") is False

    def test_find_falls_through_to_fuzzy(self, tmp_path):
        registry = Registry(root=tmp_path,
                            similarity_fn=lambda a, b: 1.0 if a == b else 0.5,
                            fuzzy_threshold=0.4)
        registry.register(_wf("hash_stored"))
        # No exact match for "hash_query" but similarity returns 0.5 >= 0.4
        assert registry.find("osprey", "ivr_navigate", "hash_query") is not None

    def test_fuzzy_match_raises_without_fn(self, tmp_path):
        registry = Registry(root=tmp_path)
        with pytest.raises(RuntimeError, match="similarity_fn"):
            registry.fuzzy_match("osprey", "ivr_navigate", "any")

    def test_fuzzy_match_below_threshold_returns_none(self, tmp_path):
        registry = Registry(root=tmp_path,
                            similarity_fn=lambda a, b: 0.1,
                            fuzzy_threshold=0.8)
        registry.register(_wf("hash_stored"))
        assert registry.fuzzy_match("osprey", "ivr_navigate", "hash_query") is None

    def test_find_exact_takes_priority(self, tmp_path):
        registry = Registry(root=tmp_path,
                            similarity_fn=lambda a, b: 0.9,
                            fuzzy_threshold=0.8)
        registry.register(_wf("exact_hash"))
        found = registry.find("osprey", "ivr_navigate", "exact_hash")
        # Exact match wins even though fuzzy similarity also clears the bar.
        assert found.input_hash == "exact_hash"

    def test_workflow_id_colon_safe_in_filename(self, tmp_path):
        """Colons in workflow_id must not break file creation on any OS."""
        registry = Registry(root=tmp_path)
        stored = registry.register(_wf("abc", wf_id="osprey:ivr_navigate:abc123abc123"))
        assert stored.exists()
        assert ":" not in stored.name
|
||||
|
|
@ -120,6 +120,93 @@ class TestLocalFileStore:
|
|||
|
||||
|
||||
from circuitforge_core.preferences import get_user_preference, set_user_preference
|
||||
from circuitforge_core.preferences.accessibility import (
|
||||
is_reduced_motion_preferred,
|
||||
is_high_contrast,
|
||||
get_font_size,
|
||||
is_screen_reader_mode,
|
||||
set_reduced_motion,
|
||||
PREF_REDUCED_MOTION,
|
||||
PREF_HIGH_CONTRAST,
|
||||
PREF_FONT_SIZE,
|
||||
PREF_SCREEN_READER,
|
||||
)
|
||||
|
||||
|
||||
class TestAccessibilityPreferences:
    """Accessibility preference helpers backed by a LocalFileStore."""

    def _store(self, tmp_path) -> LocalFileStore:
        # Fresh file-backed store per test, isolated under tmp_path.
        return LocalFileStore(prefs_path=tmp_path / "preferences.yaml")

    def test_reduced_motion_default_false(self, tmp_path):
        assert is_reduced_motion_preferred(store=self._store(tmp_path)) is False

    def test_set_reduced_motion_persists(self, tmp_path):
        store = self._store(tmp_path)
        set_reduced_motion(True, store=store)
        assert is_reduced_motion_preferred(store=store) is True

    def test_reduced_motion_false_roundtrip(self, tmp_path):
        store = self._store(tmp_path)
        set_reduced_motion(True, store=store)
        set_reduced_motion(False, store=store)
        assert is_reduced_motion_preferred(store=store) is False

    def test_high_contrast_default_false(self, tmp_path):
        assert is_high_contrast(store=self._store(tmp_path)) is False

    def test_high_contrast_set_and_read(self, tmp_path):
        store = self._store(tmp_path)
        store.set(user_id=None, path=PREF_HIGH_CONTRAST, value=True)
        assert is_high_contrast(store=store) is True

    def test_font_size_default(self, tmp_path):
        assert get_font_size(store=self._store(tmp_path)) == "default"

    def test_font_size_large(self, tmp_path):
        store = self._store(tmp_path)
        store.set(user_id=None, path=PREF_FONT_SIZE, value="large")
        assert get_font_size(store=store) == "large"

    def test_font_size_xlarge(self, tmp_path):
        store = self._store(tmp_path)
        store.set(user_id=None, path=PREF_FONT_SIZE, value="xlarge")
        assert get_font_size(store=store) == "xlarge"

    def test_font_size_invalid_falls_back_to_default(self, tmp_path):
        store = self._store(tmp_path)
        store.set(user_id=None, path=PREF_FONT_SIZE, value="gigantic")
        assert get_font_size(store=store) == "default"

    def test_screen_reader_mode_default_false(self, tmp_path):
        assert is_screen_reader_mode(store=self._store(tmp_path)) is False

    def test_screen_reader_mode_set(self, tmp_path):
        store = self._store(tmp_path)
        store.set(user_id=None, path=PREF_SCREEN_READER, value=True)
        assert is_screen_reader_mode(store=store) is True

    def test_preferences_are_independent(self, tmp_path):
        """Setting one a11y pref doesn't affect others."""
        store = self._store(tmp_path)
        set_reduced_motion(True, store=store)
        assert is_high_contrast(store=store) is False
        assert get_font_size(store=store) == "default"
        assert is_screen_reader_mode(store=store) is False

    def test_user_id_threaded_through(self, tmp_path):
        """user_id param is accepted (LocalFileStore ignores it, but must not error)."""
        store = self._store(tmp_path)
        set_reduced_motion(True, user_id="u999", store=store)
        assert is_reduced_motion_preferred(user_id="u999", store=store) is True

    def test_accessibility_exported_from_package(self):
        from circuitforge_core.preferences import accessibility

        assert hasattr(accessibility, "is_reduced_motion_preferred")
        assert hasattr(accessibility, "PREF_REDUCED_MOTION")
|
||||
|
||||
class TestPreferenceHelpers:
|
||||
|
|
|
|||
0
tests/test_text/__init__.py
Normal file
0
tests/test_text/__init__.py
Normal file
190
tests/test_text/test_backend.py
Normal file
190
tests/test_text/test_backend.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
"""Tests for cf-text backend selection, mock backend, and public API."""
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.text.backends.base import (
|
||||
ChatMessage,
|
||||
GenerateResult,
|
||||
TextBackend,
|
||||
_select_backend,
|
||||
make_text_backend,
|
||||
)
|
||||
from circuitforge_core.text.backends.mock import MockTextBackend
|
||||
from circuitforge_core.text import generate, chat, reset_backend, make_backend
|
||||
|
||||
|
||||
# ── _select_backend ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestSelectBackend:
    """_select_backend precedence: caller arg > CF_TEXT_BACKEND env > path heuristics.

    Fix: the explicit-backend tests previously ran with whatever
    CF_TEXT_BACKEND happened to be set in the ambient environment, so they
    implicitly relied on the caller-beats-env precedence that is itself under
    test. Each test now pins the env var state it needs.
    """

    def test_explicit_llamacpp(self, monkeypatch):
        # Clear the env override so only the explicit argument is exercised.
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        assert _select_backend("model.gguf", "llamacpp") == "llamacpp"

    def test_explicit_transformers(self, monkeypatch):
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        assert _select_backend("model.gguf", "transformers") == "transformers"

    def test_explicit_invalid_raises(self, monkeypatch):
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        with pytest.raises(ValueError, match="not valid"):
            _select_backend("model.gguf", "ctransformers")

    def test_env_override_llamacpp(self, monkeypatch):
        monkeypatch.setenv("CF_TEXT_BACKEND", "llamacpp")
        assert _select_backend("Qwen/Qwen2.5-3B", None) == "llamacpp"

    def test_env_override_transformers(self, monkeypatch):
        monkeypatch.setenv("CF_TEXT_BACKEND", "transformers")
        assert _select_backend("model.gguf", None) == "transformers"

    def test_env_override_invalid_raises(self, monkeypatch):
        monkeypatch.setenv("CF_TEXT_BACKEND", "ctransformers")
        with pytest.raises(ValueError):
            _select_backend("model.gguf", None)

    def test_caller_beats_env(self, monkeypatch):
        # Explicit caller choice must win over the env override.
        monkeypatch.setenv("CF_TEXT_BACKEND", "transformers")
        assert _select_backend("model.gguf", "llamacpp") == "llamacpp"

    def test_gguf_extension_selects_llamacpp(self, monkeypatch):
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        assert _select_backend("/models/qwen2.5-3b-q4.gguf", None) == "llamacpp"

    def test_gguf_uppercase_extension(self, monkeypatch):
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        assert _select_backend("/models/model.GGUF", None) == "llamacpp"

    def test_hf_repo_id_selects_transformers(self, monkeypatch):
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        assert _select_backend("Qwen/Qwen2.5-3B-Instruct", None) == "transformers"

    def test_safetensors_dir_selects_transformers(self, monkeypatch):
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        assert _select_backend("/models/qwen2.5-3b/", None) == "transformers"
|
||||
|
||||
|
||||
# ── ChatMessage ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestChatMessage:
    """ChatMessage role validation and wire-format serialization."""

    def test_valid_roles(self):
        """All three OpenAI-style roles construct cleanly."""
        for valid in ("system", "user", "assistant"):
            assert ChatMessage(valid, "hello").role == valid

    def test_invalid_role_raises(self):
        """Anything outside the role whitelist is rejected."""
        with pytest.raises(ValueError, match="Invalid role"):
            ChatMessage("bot", "hello")

    def test_to_dict(self):
        """to_dict() emits the {role, content} mapping the APIs expect."""
        payload = ChatMessage("user", "hello").to_dict()
        assert payload == {"role": "user", "content": "hello"}
|
||||
|
||||
|
||||
# ── MockTextBackend ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestMockTextBackend:
    """The mock backend satisfies the full TextBackend contract, sync and async."""

    def test_generate_returns_result(self):
        """generate() returns a non-empty GenerateResult."""
        out = MockTextBackend().generate("write something")
        assert isinstance(out, GenerateResult)
        assert len(out.text) > 0

    def test_vram_mb_is_zero(self):
        # The mock holds no model weights, so it reports zero VRAM.
        assert MockTextBackend().vram_mb == 0

    def test_model_name(self):
        """The constructor-supplied model name is exposed as-is."""
        assert MockTextBackend(model_name="test-model").model_name == "test-model"

    def test_generate_stream_yields_tokens(self):
        """Streaming reassembles to the same text as one-shot generation."""
        mock = MockTextBackend()
        pieces = list(mock.generate_stream("hello"))
        assert len(pieces) > 0
        assert "".join(pieces).strip() == mock.generate("hello").text.strip()

    @pytest.mark.asyncio
    async def test_generate_async(self):
        """The async path returns the same result type."""
        out = await MockTextBackend().generate_async("hello")
        assert isinstance(out, GenerateResult)

    @pytest.mark.asyncio
    async def test_generate_stream_async(self):
        """The async stream yields at least one token."""
        collected = []
        async for piece in MockTextBackend().generate_stream_async("hello"):
            collected.append(piece)
        assert len(collected) > 0

    def test_chat(self):
        """chat() accepts a message list and returns a GenerateResult."""
        reply = MockTextBackend().chat([ChatMessage("user", "hello")])
        assert isinstance(reply, GenerateResult)

    def test_isinstance_protocol(self):
        # TextBackend is checked structurally; the mock must satisfy it.
        assert isinstance(MockTextBackend(), TextBackend)
|
||||
|
||||
|
||||
# ── make_text_backend ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestMakeTextBackend:
    """Factory behavior: mock short-circuits; real backends fail loudly when absent."""

    def test_mock_flag(self):
        """mock=True always yields the mock backend."""
        assert isinstance(make_text_backend("any-model", mock=True), MockTextBackend)

    def test_mock_env(self, monkeypatch):
        """CF_TEXT_MOCK=1 forces the mock backend without the flag."""
        monkeypatch.setenv("CF_TEXT_MOCK", "1")
        assert isinstance(make_text_backend("any-model"), MockTextBackend)

    def test_real_gguf_raises_import_error(self, monkeypatch):
        """Loading a real .gguf with no file/library available must fail."""
        monkeypatch.delenv("CF_TEXT_MOCK", raising=False)
        monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
        with pytest.raises((ImportError, FileNotFoundError)):
            make_text_backend("/nonexistent/model.gguf", mock=False)

    def test_real_transformers_nonexistent_model_raises(self, monkeypatch):
        """A bogus local dir fails fast."""
        monkeypatch.delenv("CF_TEXT_MOCK", raising=False)
        monkeypatch.setenv("CF_TEXT_BACKEND", "transformers")
        # Use a clearly nonexistent local path — avoids a network hit and HF download
        with pytest.raises(Exception):
            make_text_backend("/nonexistent/local/model-dir", mock=False)
|
||||
|
||||
|
||||
# ── Public API (singleton) ────────────────────────────────────────────────────
|
||||
|
||||
class TestPublicAPI:
    """Module-level generate()/chat() singleton API, exercised via the mock."""

    def setup_method(self):
        # Start each test with no cached backend.
        reset_backend()

    def teardown_method(self):
        # Leave no singleton behind for other test modules.
        reset_backend()

    def test_generate_mock(self, monkeypatch):
        """generate() works end-to-end against the mock backend."""
        monkeypatch.setenv("CF_TEXT_MOCK", "1")
        assert isinstance(generate("write something"), GenerateResult)

    def test_generate_stream_mock(self, monkeypatch):
        """stream=True yields at least one chunk."""
        monkeypatch.setenv("CF_TEXT_MOCK", "1")
        chunks = list(generate("hello", stream=True))
        assert len(chunks) > 0

    def test_chat_mock(self, monkeypatch):
        """chat() returns a GenerateResult for a simple exchange."""
        monkeypatch.setenv("CF_TEXT_MOCK", "1")
        reply = chat([ChatMessage("user", "hello")])
        assert isinstance(reply, GenerateResult)

    def test_chat_stream_raises(self, monkeypatch):
        """Streaming chat is not implemented and says so explicitly."""
        monkeypatch.setenv("CF_TEXT_MOCK", "1")
        with pytest.raises(NotImplementedError):
            chat([ChatMessage("user", "hello")], stream=True)

    def test_make_backend_returns_mock(self):
        """The re-exported make_backend honors mock=True."""
        assert isinstance(make_backend("any", mock=True), MockTextBackend)

    def test_singleton_reused(self, monkeypatch):
        """Back-to-back calls both succeed against the shared singleton."""
        monkeypatch.setenv("CF_TEXT_MOCK", "1")
        first = generate("a")
        second = generate("b")
        # Both calls should succeed (singleton loaded once)
        assert isinstance(first, GenerateResult)
        assert isinstance(second, GenerateResult)
|
||||
67
tests/test_text/test_oai_compat.py
Normal file
67
tests/test_text/test_oai_compat.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
# tests/test_text/test_oai_compat.py
|
||||
"""Tests for the OpenAI-compatible /v1/chat/completions endpoint on cf-text."""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from circuitforge_core.text.app import create_app
|
||||
|
||||
|
||||
@pytest.fixture()
def client():
    """TestClient over an app wired to the mock text backend."""
    return TestClient(create_app(model_path="mock", mock=True))
|
||||
|
||||
|
||||
def test_oai_chat_completions_returns_200(client: TestClient) -> None:
    """A minimal valid request to /v1/chat/completions succeeds."""
    payload = {
        "model": "cf-text",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    resp = client.post("/v1/chat/completions", json=payload)
    assert resp.status_code == 200


def test_oai_chat_completions_response_shape(client: TestClient) -> None:
    """Response contains the fields LLMRouter expects: choices[0].message.content."""
    payload = {
        "model": "cf-text",
        "messages": [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Write a short greeting."},
        ],
        "max_tokens": 64,
    }
    body = client.post("/v1/chat/completions", json=payload).json()
    assert "choices" in body
    assert len(body["choices"]) == 1
    message = body["choices"][0]["message"]
    assert message["role"] == "assistant"
    assert isinstance(message["content"], str)
    assert len(message["content"]) > 0


def test_oai_chat_completions_includes_metadata(client: TestClient) -> None:
    """Response includes id, object, created, model, and usage fields."""
    body = client.post(
        "/v1/chat/completions",
        json={"model": "cf-text", "messages": [{"role": "user", "content": "Hi"}]},
    ).json()
    assert body["object"] == "chat.completion"
    assert isinstance(body["id"], str)
    assert body["id"].startswith("cftext-")
    assert isinstance(body["created"], int)
    assert "usage" in body
|
||||
|
||||
|
||||
def test_health_endpoint_still_works(client: TestClient) -> None:
    """Existing /health endpoint is unaffected by the new OAI route."""
    health = client.get("/health")
    assert health.status_code == 200
    assert health.json()["status"] == "ok"
|
||||
0
tests/test_vision/__init__.py
Normal file
0
tests/test_vision/__init__.py
Normal file
203
tests/test_vision/test_app.py
Normal file
203
tests/test_vision/test_app.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""
|
||||
Tests for the cf-vision FastAPI service (mock backend).
|
||||
|
||||
All tests use the mock backend — no GPU or model files required.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import io
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from circuitforge_core.vision.app import create_app, _parse_labels
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
def siglip_client() -> TestClient:
    """Client backed by mock-siglip (supports classify + embed, not caption)."""
    app = create_app(model_path="mock-siglip", backend="siglip", mock=True)
    client = TestClient(app)
    return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def vlm_client() -> TestClient:
    """Client backed by mock-vlm (mock supports all; VLM contract tested separately)."""
    app = create_app(model_path="mock-vlm", backend="vlm", mock=True)
    client = TestClient(app)
    return client
|
||||
|
||||
|
||||
FAKE_IMAGE = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||
|
||||
|
||||
def _image_upload(data: bytes = FAKE_IMAGE) -> tuple[str, tuple]:
|
||||
return ("image", ("test.png", io.BytesIO(data), "image/png"))
|
||||
|
||||
|
||||
# ── /health ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_health_ok(siglip_client: TestClient) -> None:
    """/health reports ok plus model/VRAM/backend metadata."""
    resp = siglip_client.get("/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "ok"
    for field in ("model", "vram_mb", "backend"):
        assert field in payload


def test_health_backend_field(siglip_client: TestClient) -> None:
    """The backend field echoes the configured backend name."""
    assert siglip_client.get("/health").json()["backend"] == "siglip"


def test_health_supports_fields(siglip_client: TestClient) -> None:
    """Capability flags are advertised so callers can route requests."""
    payload = siglip_client.get("/health").json()
    for field in ("supports_embed", "supports_caption"):
        assert field in payload
|
||||
|
||||
|
||||
# ── /classify ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_classify_json_labels(siglip_client: TestClient) -> None:
    """Labels may be sent as a JSON array string."""
    resp = siglip_client.post(
        "/classify",
        files=[_image_upload()],
        data={"labels": json.dumps(["cat", "dog", "bird"])},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["labels"] == ["cat", "dog", "bird"]
    assert len(payload["scores"]) == 3


def test_classify_csv_labels(siglip_client: TestClient) -> None:
    """Labels may also be sent as a comma-separated string."""
    resp = siglip_client.post(
        "/classify",
        files=[_image_upload()],
        data={"labels": "cat, dog, bird"},
    )
    assert resp.status_code == 200
    assert resp.json()["labels"] == ["cat", "dog", "bird"]


def test_classify_single_label(siglip_client: TestClient) -> None:
    """A lone label round-trips with exactly one score."""
    resp = siglip_client.post(
        "/classify",
        files=[_image_upload()],
        data={"labels": "document"},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["labels"] == ["document"]
    assert len(payload["scores"]) == 1


def test_classify_empty_labels_4xx(siglip_client: TestClient) -> None:
    # Empty labels should yield a 4xx — either our 400 or FastAPI's 422
    # depending on how the empty string is handled by the form layer.
    resp = siglip_client.post(
        "/classify",
        files=[_image_upload()],
        data={"labels": ""},
    )
    assert resp.status_code in (400, 422)


def test_classify_empty_image_400(siglip_client: TestClient) -> None:
    """A zero-byte upload is rejected with 400."""
    resp = siglip_client.post(
        "/classify",
        files=[("image", ("empty.png", io.BytesIO(b""), "image/png"))],
        data={"labels": "cat"},
    )
    assert resp.status_code == 400


def test_classify_model_in_response(siglip_client: TestClient) -> None:
    """Every classify response names the serving model."""
    resp = siglip_client.post(
        "/classify",
        files=[_image_upload()],
        data={"labels": "cat"},
    )
    assert "model" in resp.json()
|
||||
|
||||
|
||||
# ── /embed ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_embed_returns_vector(siglip_client: TestClient) -> None:
    """/embed returns a non-empty numeric list."""
    resp = siglip_client.post("/embed", files=[_image_upload()])
    assert resp.status_code == 200
    payload = resp.json()
    assert "embedding" in payload
    vector = payload["embedding"]
    assert isinstance(vector, list)
    assert len(vector) > 0


def test_embed_empty_image_400(siglip_client: TestClient) -> None:
    """A zero-byte upload is rejected with 400."""
    resp = siglip_client.post(
        "/embed",
        files=[("image", ("empty.png", io.BytesIO(b""), "image/png"))],
    )
    assert resp.status_code == 400


def test_embed_model_in_response(siglip_client: TestClient) -> None:
    """Every embed response names the serving model."""
    assert "model" in siglip_client.post("/embed", files=[_image_upload()]).json()
|
||||
|
||||
|
||||
# ── /caption ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_caption_returns_text(vlm_client: TestClient) -> None:
    """/caption with an empty prompt still produces a caption string."""
    resp = vlm_client.post(
        "/caption",
        files=[_image_upload()],
        data={"prompt": ""},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "caption" in payload
    assert isinstance(payload["caption"], str)


def test_caption_with_prompt(vlm_client: TestClient) -> None:
    """A custom caption prompt is accepted."""
    resp = vlm_client.post(
        "/caption",
        files=[_image_upload()],
        data={"prompt": "What text appears here?"},
    )
    assert resp.status_code == 200


def test_caption_empty_image_400(vlm_client: TestClient) -> None:
    """A zero-byte upload is rejected with 400."""
    resp = vlm_client.post(
        "/caption",
        files=[("image", ("empty.png", io.BytesIO(b""), "image/png"))],
        data={"prompt": ""},
    )
    assert resp.status_code == 400
|
||||
|
||||
|
||||
# ── Label parser ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_parse_labels_json_array() -> None:
    """A JSON array string parses into a list."""
    assert _parse_labels('["cat", "dog"]') == ["cat", "dog"]


def test_parse_labels_csv() -> None:
    """Comma-separated input splits into labels."""
    assert _parse_labels("cat, dog, bird") == ["cat", "dog", "bird"]


def test_parse_labels_single() -> None:
    """A bare word becomes a one-element list."""
    assert _parse_labels("document") == ["document"]


def test_parse_labels_empty() -> None:
    """Empty input yields no labels."""
    assert _parse_labels("") == []


def test_parse_labels_whitespace_trimmed() -> None:
    """Surrounding whitespace is stripped from each label."""
    assert _parse_labels(" cat , dog ") == ["cat", "dog"]
|
||||
247
tests/test_vision/test_backend.py
Normal file
247
tests/test_vision/test_backend.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Tests for cf-vision backends (mock) and factory routing.
|
||||
|
||||
Real SigLIP/VLM backends are not tested here — they require GPU + model downloads.
|
||||
The mock backend exercises the full Protocol surface so we can verify the contract
|
||||
without hardware dependencies.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.vision.backends.base import (
|
||||
VisionBackend,
|
||||
VisionResult,
|
||||
make_vision_backend,
|
||||
)
|
||||
from circuitforge_core.vision.backends.mock import MockVisionBackend
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
# Not a real PNG, but enough for mock
FAKE_IMAGE = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100


@pytest.fixture()
def mock_backend() -> MockVisionBackend:
    """Fresh mock backend per test, with a recognizable model name."""
    backend = MockVisionBackend(model_name="test-mock")
    return backend
|
||||
|
||||
|
||||
# ── Protocol compliance ───────────────────────────────────────────────────────
|
||||
|
||||
def test_mock_is_vision_backend(mock_backend: MockVisionBackend) -> None:
    """The mock satisfies the VisionBackend protocol check."""
    assert isinstance(mock_backend, VisionBackend)


def test_mock_model_name(mock_backend: MockVisionBackend) -> None:
    """The constructor-supplied name is exposed as-is."""
    assert mock_backend.model_name == "test-mock"


def test_mock_vram_mb(mock_backend: MockVisionBackend) -> None:
    # No weights loaded, so no VRAM claimed.
    assert mock_backend.vram_mb == 0


def test_mock_supports_embed(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.supports_embed is True


def test_mock_supports_caption(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.supports_caption is True
|
||||
|
||||
|
||||
# ── classify() ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_classify_returns_vision_result(mock_backend: MockVisionBackend) -> None:
    out = mock_backend.classify(FAKE_IMAGE, ["cat", "dog", "bird"])
    assert isinstance(out, VisionResult)


def test_classify_labels_preserved(mock_backend: MockVisionBackend) -> None:
    """Input labels come back unchanged and in order."""
    wanted = ["cat", "dog", "bird"]
    assert mock_backend.classify(FAKE_IMAGE, wanted).labels == wanted


def test_classify_scores_length_matches_labels(mock_backend: MockVisionBackend) -> None:
    wanted = ["cat", "dog", "bird"]
    out = mock_backend.classify(FAKE_IMAGE, wanted)
    assert len(out.scores) == len(wanted)


def test_classify_uniform_scores(mock_backend: MockVisionBackend) -> None:
    # The mock spreads probability mass evenly across the labels.
    out = mock_backend.classify(FAKE_IMAGE, ["cat", "dog", "bird"])
    for prob in out.scores:
        assert abs(prob - 1.0 / 3) < 1e-9


def test_classify_single_label(mock_backend: MockVisionBackend) -> None:
    """A single label receives the full probability mass."""
    out = mock_backend.classify(FAKE_IMAGE, ["document"])
    assert out.labels == ["document"]
    assert abs(out.scores[0] - 1.0) < 1e-9


def test_classify_model_name_in_result(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.classify(FAKE_IMAGE, ["x"]).model == "test-mock"
|
||||
|
||||
|
||||
# ── embed() ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_embed_returns_vision_result(mock_backend: MockVisionBackend) -> None:
    assert isinstance(mock_backend.embed(FAKE_IMAGE), VisionResult)


def test_embed_returns_embedding(mock_backend: MockVisionBackend) -> None:
    """The mock emits a fixed-size 512-d embedding."""
    vector = mock_backend.embed(FAKE_IMAGE).embedding
    assert vector is not None
    assert len(vector) == 512


def test_embed_is_unit_vector(mock_backend: MockVisionBackend) -> None:
    # The embedding is normalized to unit L2 norm.
    vector = mock_backend.embed(FAKE_IMAGE).embedding
    norm = math.sqrt(sum(component * component for component in vector))
    assert abs(norm - 1.0) < 1e-6


def test_embed_labels_empty(mock_backend: MockVisionBackend) -> None:
    """Embedding results carry no classification fields."""
    out = mock_backend.embed(FAKE_IMAGE)
    assert out.labels == []
    assert out.scores == []


def test_embed_model_name_in_result(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.embed(FAKE_IMAGE).model == "test-mock"
|
||||
|
||||
|
||||
# ── caption() ────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_caption_returns_vision_result(mock_backend: MockVisionBackend) -> None:
    assert isinstance(mock_backend.caption(FAKE_IMAGE), VisionResult)


def test_caption_returns_string(mock_backend: MockVisionBackend) -> None:
    """Captioning produces a non-empty string."""
    text = mock_backend.caption(FAKE_IMAGE).caption
    assert isinstance(text, str)
    assert len(text) > 0


def test_caption_with_prompt(mock_backend: MockVisionBackend) -> None:
    """An explicit prompt is accepted and still yields a caption."""
    out = mock_backend.caption(FAKE_IMAGE, prompt="What is in this image?")
    assert out.caption is not None


def test_caption_model_name_in_result(mock_backend: MockVisionBackend) -> None:
    assert mock_backend.caption(FAKE_IMAGE).model == "test-mock"
|
||||
|
||||
|
||||
# ── VisionResult helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def test_top_returns_sorted_pairs() -> None:
    """top(n) yields (label, score) pairs sorted by descending score."""
    result = VisionResult(
        labels=["cat", "dog", "bird"],
        scores=[0.3, 0.6, 0.1],
    )
    best_two = result.top(2)
    assert best_two[0] == ("dog", 0.6)
    assert best_two[1] == ("cat", 0.3)


def test_top_default_n1() -> None:
    """top() without arguments returns only the single best pair."""
    result = VisionResult(labels=["cat", "dog"], scores=[0.4, 0.9])
    assert result.top() == [("dog", 0.9)]
|
||||
|
||||
|
||||
# ── Factory routing ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_factory_mock_flag() -> None:
    """mock=True short-circuits to the mock backend."""
    assert isinstance(make_vision_backend("any-model", mock=True), MockVisionBackend)


def test_factory_mock_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """CF_VISION_MOCK=1 forces the mock backend without the flag."""
    monkeypatch.setenv("CF_VISION_MOCK", "1")
    assert isinstance(make_vision_backend("any-model"), MockVisionBackend)


def test_factory_mock_model_name() -> None:
    """The requested model path is echoed on the mock instance."""
    built = make_vision_backend("google/siglip-so400m-patch14-384", mock=True)
    assert built.model_name == "google/siglip-so400m-patch14-384"


def test_factory_unknown_backend_raises() -> None:
    """An unrecognized backend name is rejected up front."""
    with pytest.raises(ValueError, match="Unknown vision backend"):
        make_vision_backend("any-model", backend="nonexistent", mock=False)
|
||||
|
||||
|
||||
def test_factory_vlm_autodetect_moondream(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auto-detection should select VLM for moondream model paths."""
    # We mock at the import level to avoid requiring GPU deps
    monkeypatch.setenv("CF_VISION_MOCK", "0")
    # Just verify the ValueError is about vlm backend, not "unknown"
    # (the ImportError from missing torch is expected in CI)
    try:
        make_vision_backend("vikhyatk/moondream2", mock=False)
    except ImportError:
        pass  # Expected in CI without torch
    except ValueError as exc:
        # A ValueError here means autodetect mis-routed a known model path.
        pytest.fail(f"Should not raise ValueError for known backend: {exc}")


def test_factory_siglip_autodetect(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auto-detection should select siglip for non-VLM model paths (no ValueError)."""
    # Disable the mock so the factory exercises its real autodetect path.
    monkeypatch.setenv("CF_VISION_MOCK", "0")
    try:
        make_vision_backend("google/siglip-so400m-patch14-384", mock=False)
    except ValueError as exc:
        # NOTE: ValueError must be caught before the broad Exception below —
        # it is the one failure mode this test treats as a bug.
        pytest.fail(f"Should not raise ValueError for known backend: {exc}")
    except Exception:
        pass  # ImportError or model-loading errors are expected outside GPU CI
|
||||
|
||||
|
||||
# ── Process singleton ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_module_classify_mock(monkeypatch: pytest.MonkeyPatch) -> None:
    """Module-level classify() lazily builds the mock backend."""
    monkeypatch.setenv("CF_VISION_MOCK", "1")
    # Reset the module-level singleton
    import circuitforge_core.vision as vision_mod
    vision_mod._backend = None

    out = vision_mod.classify(FAKE_IMAGE, ["cat", "dog"])
    assert out.labels == ["cat", "dog"]
    assert len(out.scores) == 2


def test_module_embed_mock(monkeypatch: pytest.MonkeyPatch) -> None:
    """Module-level embed() produces an embedding via the mock."""
    monkeypatch.setenv("CF_VISION_MOCK", "1")
    import circuitforge_core.vision as vision_mod
    vision_mod._backend = None

    assert vision_mod.embed(FAKE_IMAGE).embedding is not None


def test_module_caption_mock(monkeypatch: pytest.MonkeyPatch) -> None:
    """Module-level caption() produces a caption via the mock."""
    monkeypatch.setenv("CF_VISION_MOCK", "1")
    import circuitforge_core.vision as vision_mod
    vision_mod._backend = None

    assert vision_mod.caption(FAKE_IMAGE, prompt="Describe").caption is not None


def test_module_make_backend_returns_fresh_instance(monkeypatch: pytest.MonkeyPatch) -> None:
    """make_backend() bypasses the singleton — each call is a distinct instance."""
    import circuitforge_core.vision as vision_mod
    first = vision_mod.make_backend("m1", mock=True)
    second = vision_mod.make_backend("m2", mock=True)
    assert first is not second
    assert first.model_name != second.model_name
|
||||
Loading…
Reference in a new issue