Compare commits

..

No commits in common. "a92a83db4bcd414b26fd623fb3a5a21e00579e99" and "5a363f3b6cddea1c69bd420e6eeba304514086c4" have entirely different histories.

30 changed files with 62 additions and 2664 deletions

View file

@ -1,11 +1,4 @@
name: Release — PyPI + Forgejo Packages name: Release — PyPI
# circuitforge-core is MIT — published to both public PyPI and the Circuit-Forge
# Forgejo Packages index so cf-orch can resolve it from a single --extra-index-url.
#
# Required secrets:
# PYPI_API_TOKEN — public PyPI upload token
# FORGEJO_PYPI_TOKEN — Forgejo token with package:write scope
on: on:
push: push:
@ -26,36 +19,29 @@ jobs:
- name: Build - name: Build
run: | run: |
pip install build twine pip install build
python -m build python -m build
- name: Publish to public PyPI - name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
password: ${{ secrets.PYPI_API_TOKEN }} password: ${{ secrets.PYPI_API_TOKEN }}
- name: Publish to Forgejo Packages
env:
TWINE_USERNAME: pypi-token
TWINE_PASSWORD: ${{ secrets.FORGEJO_PYPI_TOKEN }}
TWINE_REPOSITORY_URL: https://git.opensourcesolarpunk.com/api/packages/Circuit-Forge/pypi
run: twine upload dist/*
- name: Create Forgejo release - name: Create Forgejo release
env: env:
FORGEJO_TOKEN: ${{ secrets.FORGEJO_PYPI_TOKEN }} FORGEJO_TOKEN: ${{ secrets.FORGEJO_RELEASE_TOKEN }}
run: | run: |
TAG="${GITHUB_REF_NAME}" TAG="${GITHUB_REF_NAME}"
# Check if release already exists for this tag
EXISTING=$(curl -sf \ EXISTING=$(curl -sf \
-H "Authorization: token ${FORGEJO_TOKEN}" \ -H "Authorization: token ${FORGEJO_TOKEN}" \
"https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases/tags/${TAG}" \ "https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases/tags/${TAG}" \
2>/dev/null \ 2>/dev/null | jq -r '.id // empty')
| python3 -c "import sys,json; print(json.load(sys.stdin).get('id',''))" 2>/dev/null || true)
if [ -z "${EXISTING}" ]; then if [ -z "${EXISTING}" ]; then
python3 -c " jq -n --arg tag "${TAG}" \
import json '{"tag_name":$tag,"name":$tag,"draft":false,"prerelease":false}' \
print(json.dumps({'tag_name':'${TAG}','name':'${TAG}','draft':False,'prerelease':False})) | curl -sf -X POST \
" | curl -sf -X POST \
-H "Authorization: token ${FORGEJO_TOKEN}" \ -H "Authorization: token ${FORGEJO_TOKEN}" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
"https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases" \ "https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases" \

View file

@ -1,9 +1,4 @@
from importlib.metadata import PackageNotFoundError, version __version__ = "0.18.0"
try:
__version__ = version("circuitforge-core")
except PackageNotFoundError:
__version__ = "dev" # running from source without an editable install
try: try:
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore

View file

@ -39,13 +39,6 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable
try:
from starlette.requests import Request as _Request
from starlette.responses import Response as _Response
except ImportError: # pragma: no cover — starlette may be absent in non-web envs
_Request = Any # type: ignore[assignment,misc]
_Response = Any # type: ignore[assignment,misc]
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
TIERS: list[str] = ["free", "paid", "premium", "ultra"] TIERS: list[str] = ["free", "paid", "premium", "ultra"]
@ -255,40 +248,22 @@ class CloudSessionFactory:
request.headers.get("x-real-ip", "") request.headers.get("x-real-ip", "")
or (request.client.host if request.client else "") or (request.client.host if request.client else "")
) )
is_bypass = _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips) if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips):
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
raw_session = ( raw_session = (
request.headers.get("x-cf-session", "").strip() request.headers.get("x-cf-session", "").strip()
or request.cookies.get("cf_session", "").strip() or request.cookies.get("cf_session", "").strip()
) )
# Bypass IPs skip the JWT *requirement* but not JWT *validation*.
# If a token is present (dev is logged in), honour it so they land on
# their own account DB rather than the shared local-dev DB.
if not raw_session: if not raw_session:
if is_bypass:
log.debug("Bypass IP %s, no token — returning local-dev session for product %s", client_ip, self.product)
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
return self._resolve_guest(request, response) return self._resolve_guest(request, response)
token = _extract_session_token(raw_session) token = _extract_session_token(raw_session)
if not token: if not token:
return self._resolve_guest(request, response) return self._resolve_guest(request, response)
# Soft-fail on invalid/expired JWT: downgrade to guest rather than user_id = self.validate_jwt(token)
# hard-erroring with 401. Public endpoints (e.g. community blocklist)
# should remain accessible even when the browser has a stale cookie.
# Routes that genuinely require an authenticated identity should gate
# themselves with require_tier() — that's where the 401/403 belongs.
try:
user_id = self.validate_jwt(token)
except Exception:
log.warning(
"JWT validation failed for product %s (expired or tampered) — falling back to guest",
self.product,
)
return self._resolve_guest(request, response)
self._ensure_provisioned(user_id) self._ensure_provisioned(user_id)
tier_data = self._resolve_tier(user_id) tier_data = self._resolve_tier(user_id)
tier = tier_data.get("tier", "free") tier = tier_data.get("tier", "free")
@ -308,11 +283,11 @@ class CloudSessionFactory:
meta=meta, meta=meta,
) )
def dependency(self) -> Callable[["_Request", "_Response"], CloudUser]: def dependency(self) -> Callable[[Any, Any], CloudUser]:
"""Return a FastAPI-compatible dependency function (use with Depends()).""" """Return a FastAPI-compatible dependency function (use with Depends())."""
factory = self factory = self
def _get_session(request: _Request, response: _Response) -> CloudUser: def _get_session(request: Any, response: Any) -> CloudUser:
return factory.resolve(request, response) return factory.resolve(request, response)
return _get_session return _get_session

View file

@ -1,54 +0,0 @@
"""circuitforge_core.memory — persistent knowledge graph via mnemo sidecar.
MIT licensed.
Requires the mnemo sidecar to be running (https://github.com/zaydmulani09/mnemo).
If the sidecar is not available, all operations silently no-op so products
can call memory methods unconditionally.
Quick start (in a FastAPI lifespan)::
from circuitforge_core.memory import MemoryClient, MemoryConfig
memory = MemoryClient(MemoryConfig.from_env())
@asynccontextmanager
async def lifespan(app):
await memory.connect()
yield
await memory.close()
# In a route:
await memory.remember("User avoids shellfish", source="dietary-prefs")
context = await memory.recall("What are this user's food restrictions?")
Docker Compose setup::
services:
mnemo:
image: ghcr.io/zaydmulani09/mnemo:latest
ports: ["8080:8080"]
environment:
MNEMO_LLM_PROVIDER: ollama
MNEMO_LLM_BASE_URL: http://ollama:11434/v1
MNEMO_LLM_MODEL: llama3
volumes:
- mnemo-data:/data
Environment variables (for MemoryConfig.from_env())::
MNEMO_HOST default: localhost
MNEMO_PORT default: 8080
MNEMO_TIMEOUT default: 10.0
"""
from circuitforge_core.memory.client import MemoryClient, MemoryUnavailableError
from circuitforge_core.memory.models import MemoryConfig, MemoryEntity, MemoryStats
__all__ = [
"MemoryClient",
"MemoryConfig",
"MemoryEntity",
"MemoryStats",
"MemoryUnavailableError",
]

View file

@ -1,317 +0,0 @@
"""MemoryClient — async wrapper around the mnemo persistent knowledge graph.
mnemo is an optional sidecar (https://github.com/zaydmulani09/mnemo).
When the sidecar is not running, all operations silently no-op so products
can call memory methods unconditionally without try/except.
MIT licensed.
"""
from __future__ import annotations
import logging
import time
from typing import Any
from circuitforge_core.memory.models import MemoryConfig, MemoryEntity, MemoryStats
logger = logging.getLogger(__name__)
# Backoff schedule: 5 * 2^(failure-1), capped at _MAX_BACKOFF seconds.
# failure 1 → 5s, 2 → 10s, 3 → 20s, 4 → 40s, 5+ → 60s
# Number of consecutive call failures after which MemoryClient marks itself
# unavailable (see MemoryClient._on_call_error).
_MAX_FAILURES: int = 3
# Upper bound, in seconds, on the delay before the next reconnect attempt.
_MAX_BACKOFF: float = 60.0
class MemoryUnavailableError(RuntimeError):
    """Raised only when strict=True and mnemo is not reachable.

    In the default (non-strict) mode MemoryClient logs a warning and no-ops
    instead of raising this error.
    """
class MemoryClient:
    """Async interface to the mnemo knowledge graph sidecar.

    Resilience model:
      - If the sidecar is unreachable at connect(), logs once and enters no-op mode.
      - If a live call fails, the failure is counted. Each failure schedules an
        exponentially increasing cooldown before the next reconnect attempt.
      - After _MAX_FAILURES consecutive failures the client is marked unavailable;
        all calls no-op until the cooldown elapses and a reconnect succeeds.
        A reconnect attempt that fails counts as another failure and schedules
        the next (longer) cooldown, so the client keeps retrying.
      - Any successful call resets the failure counter.

    Usage (in a FastAPI lifespan)::

        from circuitforge_core.memory import MemoryClient, MemoryConfig

        memory = MemoryClient(MemoryConfig.from_env())

        @asynccontextmanager
        async def lifespan(app):
            await memory.connect()
            yield
            await memory.close()

    Then in handlers::

        await memory.remember("User prefers dark mode", source="settings")
        context = await memory.recall("What are the user's UI preferences?")
    """

    def __init__(self, config: MemoryConfig | None = None, *, strict: bool = False) -> None:
        """
        Args:
            config: connection settings; defaults to MemoryConfig.from_env()
            strict: if True, MemoryUnavailableError is raised on connect failure
                or after _MAX_FAILURES consecutive call failures
        """
        self._config = config or MemoryConfig.from_env()
        self._strict = strict
        self._available = False
        self._client: Any = None  # mnemo AsyncMnemoClient, set in connect()
        self._failure_count: int = 0
        self._retry_at: float | None = None  # monotonic timestamp; None = no retry pending

    @property
    def available(self) -> bool:
        """True if the mnemo sidecar was reachable at last health check."""
        return self._available

    @property
    def failure_count(self) -> int:
        """Consecutive call failures since the last success."""
        return self._failure_count

    # ── Lifecycle ─────────────────────────────────────────────────────────────

    async def connect(self) -> None:
        """Attempt to connect to the mnemo sidecar and run a health check.

        Safe to call multiple times (used internally for reconnect); a client
        left over from an earlier call is closed before a new one is created.
        If the sidecar is not reachable, logs a warning and enters no-op mode.

        Raises:
            MemoryUnavailableError: only when strict=True and the health check
                fails or reports a non-"ok" status.
        """
        try:
            from mnemo import AsyncMnemoClient
        except ImportError:
            logger.debug(
                "mnemo-sdk not installed — memory module disabled. "
                "Install with: pip install circuitforge-core[memory]"
            )
            self._available = False
            return

        # Do not leak the previous HTTP client when connect() is called again.
        await self._discard_client()
        self._client = AsyncMnemoClient(
            base_url=self._config.base_url,
            timeout=self._config.timeout,
        )
        try:
            health = await self._client.health()
            if health.status == "ok":
                self._available = True
                self._on_call_success()
                logger.info(
                    "mnemo memory sidecar connected at %s (LLM: %s/%s)",
                    self._config.base_url,
                    health.provider_type,
                    health.provider_model,
                )
            else:
                self._handle_unavailable("connect", reason=f"health status={health.status!r}")
        except MemoryUnavailableError:
            # Raised by _handle_unavailable() in strict mode; let it through
            # untouched instead of wrapping its message a second time.
            raise
        except Exception as exc:
            self._handle_unavailable("connect", reason=str(exc))

    async def close(self) -> None:
        """Close the underlying HTTP client and cancel any pending retry."""
        await self._discard_client()
        self._available = False
        self._retry_at = None

    # ── Core API ──────────────────────────────────────────────────────────────

    async def remember(
        self,
        text: str,
        *,
        source: str = "cf-core",
        session_id: str | None = None,
    ) -> bool:
        """Store a text fragment in the knowledge graph.

        mnemo extracts named entities and relationships from the text and
        updates its graph. Large texts should be pre-chunked by the caller
        (mnemo stores each call as a single chunk with no sub-splitting).

        Args:
            text: the text to store (conversation turn, fact, note, etc.)
            source: label for the origin (e.g. "chat", "settings", "search")
            session_id: optional session grouping for multi-turn retrieval

        Returns:
            True if stored, False if sidecar unavailable.
        """
        if not await self._maybe_reconnect():
            return False
        try:
            await self._client.ingest(content=text, source=source, session_id=session_id)
            self._on_call_success()
            return True
        except Exception as exc:
            self._on_call_error("remember", exc)
            return False

    async def recall(
        self,
        query: str,
        *,
        session_id: str | None = None,
    ) -> str:
        """Retrieve a formatted context block relevant to query.

        Returns a prompt-ready string (or empty string if unavailable).
        Inject the result directly into a system prompt::

            context = await memory.recall("user dietary restrictions")
            system = f"You are a helpful assistant.\\n\\n{context}"

        Args:
            query: natural language question or topic to retrieve context for
            session_id: restrict retrieval to a specific session (optional)

        Returns:
            Formatted context string, or "" if sidecar unavailable.
        """
        if not await self._maybe_reconnect():
            return ""
        try:
            result = await self._client.get_context(text=query, session_id=session_id)
            # Same reset as every other call, so a stale retry timer is cleared too.
            self._on_call_success()
            return result
        except Exception as exc:
            self._on_call_error("recall", exc)
            return ""

    async def entities(self, *, limit: int = 50) -> list[MemoryEntity]:
        """Return the most recent named entities in the knowledge graph.

        Args:
            limit: max entities to return (default 50)

        Returns:
            List of MemoryEntity objects, or [] if unavailable.
        """
        if not await self._maybe_reconnect():
            return []
        try:
            raw = await self._client.list_entities(limit=limit)
            self._on_call_success()
            return [MemoryEntity.from_mnemo(e) for e in raw]
        except Exception as exc:
            self._on_call_error("entities", exc)
            return []

    async def stats(self) -> MemoryStats | None:
        """Return knowledge graph statistics, or None if unavailable."""
        if not await self._maybe_reconnect():
            return None
        try:
            s = await self._client.stats()
            self._on_call_success()
            return MemoryStats(
                entity_count=s.entity_count,
                chunk_count=s.chunk_count,
                node_count=s.node_count,
                edge_count=s.edge_count,
                uptime_seconds=s.uptime_seconds,
                available=True,
            )
        except Exception as exc:
            self._on_call_error("stats", exc)
            return None

    async def wipe(self) -> bool:
        """Delete all stored memory. Irreversible.

        Returns True on success, False if unavailable or failed.
        """
        if not await self._maybe_reconnect():
            return False
        try:
            await self._client.wipe()
            self._on_call_success()
            logger.warning("mnemo memory wiped — all entities and chunks deleted")
            return True
        except Exception as exc:
            self._on_call_error("wipe", exc)
            return False

    # ── Internal ──────────────────────────────────────────────────────────────

    @staticmethod
    def _backoff_seconds(failures: int) -> float:
        """Return the cooldown for the given consecutive-failure count.

        5 * 2^(failures-1) seconds, capped at _MAX_BACKOFF. The exponent is
        clamped so a long outage cannot overflow the float multiplication.
        """
        exponent = min(max(failures - 1, 0), 16)
        return min(5.0 * (2 ** exponent), _MAX_BACKOFF)

    async def _discard_client(self) -> None:
        """Close and forget the current client, if any (best effort)."""
        client, self._client = self._client, None
        if client is None:
            return
        try:
            await client.__aexit__(None, None, None)
        except Exception as exc:
            # Closing is best-effort; a failure here must not mask the caller's flow.
            logger.debug("mnemo: error while closing client (ignored): %s", exc)

    async def _maybe_reconnect(self) -> bool:
        """Return True if the client is available (or just became available).

        Called at the top of every public method. If the client is unavailable
        but the retry cooldown has elapsed, silently attempts reconnect before
        answering. No-ops immediately if still within the cooldown window.
        A failed reconnect is counted as another failure and the next attempt
        is scheduled with a longer cooldown.
        """
        if self._available:
            return True
        if self._retry_at is None or time.monotonic() < self._retry_at:
            return False

        logger.info(
            "mnemo: cooldown elapsed after %d failure(s) — attempting reconnect",
            self._failure_count,
        )
        self._retry_at = None
        await self._discard_client()
        try:
            await self.connect()
        finally:
            # Runs for a quiet failure and for a strict-mode raise alike:
            # without a new deadline the client would stay disabled forever.
            if not self._available:
                self._failure_count += 1
                self._retry_at = time.monotonic() + self._backoff_seconds(self._failure_count)
        return self._available

    def _on_call_success(self) -> None:
        """Reset failure state after a successful call."""
        self._failure_count = 0
        self._retry_at = None

    def _handle_unavailable(self, operation: str, reason: str = "") -> None:
        """Called when the sidecar is unreachable at connect() time.

        Marks the client unavailable, then either raises (strict mode) or
        logs a warning.
        """
        self._available = False
        msg = f"mnemo memory sidecar unavailable (operation={operation!r})"
        if reason:
            msg += f": {reason}"
        if self._strict:
            raise MemoryUnavailableError(msg)
        logger.warning("%s — memory features disabled", msg)

    def _on_call_error(self, operation: str, exc: Exception) -> None:
        """Count consecutive failures and schedule exponential backoff retry.

        Backoff: 5 * 2^(failure-1) seconds, capped at 60s.
            failure 1  → 5s
            failure 2  → 10s
            failure 3  → 20s   (_MAX_FAILURES default; client disabled here)
            failure 4  → 40s
            failure 5+ → 60s

        After _MAX_FAILURES, _available is set to False and all calls no-op
        until _maybe_reconnect() fires after the cooldown elapses.

        Raises:
            MemoryUnavailableError: in strict mode, once the failure count
                reaches _MAX_FAILURES.
        """
        self._failure_count += 1
        backoff = self._backoff_seconds(self._failure_count)
        self._retry_at = time.monotonic() + backoff

        if self._failure_count >= _MAX_FAILURES:
            self._available = False
            logger.warning(
                "mnemo %r failed %d consecutive times (%s) — disabled, reconnect in %.0fs",
                operation, self._failure_count, exc, backoff,
            )
            if self._strict:
                raise MemoryUnavailableError(
                    f"mnemo {operation!r} failed {self._failure_count} consecutive times: {exc}"
                )
        else:
            logger.warning(
                "mnemo %r failed (%d/%d): %s — retry in %.0fs",
                operation, self._failure_count, _MAX_FAILURES, exc, backoff,
            )

View file

@ -1,73 +0,0 @@
"""Data models for the cf-core memory module.
MIT licensed.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from datetime import datetime
@dataclass(frozen=True)
class MemoryConfig:
    """Immutable connection settings for a mnemo sidecar."""

    host: str = "localhost"
    port: int = 8080
    timeout: float = 10.0

    @classmethod
    def from_env(cls) -> MemoryConfig:
        """Build a config from the process environment.

        Variables (each falls back to the default shown when unset):
            MNEMO_HOST      default: localhost
            MNEMO_PORT      default: 8080
            MNEMO_TIMEOUT   default: 10.0

        Raises:
            ValueError: if MNEMO_PORT or MNEMO_TIMEOUT is not numeric.
        """
        env = os.environ
        raw_host = env.get("MNEMO_HOST", "localhost")
        raw_port = env.get("MNEMO_PORT", "8080")
        raw_timeout = env.get("MNEMO_TIMEOUT", "10.0")
        return cls(host=raw_host, port=int(raw_port), timeout=float(raw_timeout))

    @property
    def base_url(self) -> str:
        """HTTP base URL of the sidecar, e.g. ``http://localhost:8080``."""
        return "http://{}:{}".format(self.host, self.port)
@dataclass(frozen=True)
class MemoryEntity:
    """A named entity extracted and stored by the mnemo knowledge graph."""

    entity_id: str
    name: str
    entity_type: str
    aliases: list[str] = field(default_factory=list)
    confidence: float = 1.0
    source_count: int = 1

    @classmethod
    def from_mnemo(cls, obj) -> MemoryEntity:
        """Convert a mnemo-sdk Entity object to MemoryEntity.

        Args:
            obj: object exposing ``id``, ``name``, ``entity_type``,
                ``aliases``, ``confidence`` and ``source_count`` attributes.

        Returns:
            A new MemoryEntity. Missing (``None``) ``aliases``, ``confidence``
            and ``source_count`` fall back to ``[]``, ``1.0`` and ``1``.
        """
        confidence = obj.confidence
        source_count = obj.source_count
        return cls(
            entity_id=str(obj.id),
            name=obj.name,
            entity_type=obj.entity_type,
            aliases=list(obj.aliases or []),
            # Only None means "missing": a real 0.0 / 0 must not be replaced
            # by the default, which a plain `value or default` would do.
            confidence=1.0 if confidence is None else float(confidence),
            source_count=1 if source_count is None else int(source_count),
        )
@dataclass(frozen=True)
class MemoryStats:
    """Snapshot of the mnemo knowledge graph state.

    Immutable value object; every field is required.
    """

    # The counts below presumably mirror the same-named fields of the mnemo
    # stats response — confirm exact meanings against the mnemo API.
    entity_count: int
    chunk_count: int
    node_count: int
    edge_count: int
    uptime_seconds: float  # sidecar uptime, in seconds per the field name
    available: bool  # whether the sidecar answered when the snapshot was taken

View file

@ -1,42 +0,0 @@
"""circuitforge_core.mqtt — async MQTT client with topic routing and
Meshtastic adapter support.
MIT licensed.
Quick start::
from circuitforge_core.mqtt import MQTTClient, MQTTConfig
cfg = MQTTConfig(host="localhost")
client = MQTTClient(cfg)
@client.on("sensors/#")
async def handle(msg):
print(msg.topic, msg.text())
await client.run()
For Meshtastic::
from circuitforge_core.mqtt.meshtastic import make_backend
backend = make_backend({
"backend": "mqtt",
"broker_host": "mqtt.example.com",
"topic_prefix": "msh/#",
})
async for pkt in backend.packets():
print(pkt.summary())
"""
from circuitforge_core.mqtt.client import MQTTClient
from circuitforge_core.mqtt.models import MQTTConfig, MQTTMessage
from circuitforge_core.mqtt.router import TopicRouter, matches
__all__ = [
"MQTTClient",
"MQTTConfig",
"MQTTMessage",
"TopicRouter",
"matches",
]

View file

@ -1,152 +0,0 @@
"""Async MQTT client wrapper around aiomqtt.
MIT licensed.
"""
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any
from circuitforge_core.mqtt.models import MQTTConfig, MQTTMessage
from circuitforge_core.mqtt.router import TopicRouter
logger = logging.getLogger(__name__)
class MQTTClient:
    """Async MQTT client that subscribes to topics and dispatches messages.

    Usage (with a router)::

        cfg = MQTTConfig(host="localhost")
        client = MQTTClient(cfg)

        @client.on("msh/#")
        async def handle_mesh(msg: MQTTMessage):
            print(msg.topic, msg.text())

        await client.run()

    Usage (iterate raw messages)::

        async with MQTTClient(cfg).connect() as messages:
            async for msg in messages:
                print(msg.topic)
    """

    def __init__(self, config: MQTTConfig, router: TopicRouter | None = None) -> None:
        """
        Args:
            config: broker connection settings
            router: topic router to dispatch to; a fresh TopicRouter is
                created when omitted
        """
        self._config = config
        # Compare against None so a caller-supplied router is never replaced
        # just because it happens to evaluate as falsy.
        self._router = router if router is not None else TopicRouter()

    def on(self, pattern: str):
        """Shorthand decorator — forwards to the internal router."""
        return self._router.on(pattern)

    def _connect_kwargs(self, aiomqtt_module: Any) -> dict[str, Any]:
        """Build the keyword arguments for ``aiomqtt.Client`` from the config.

        ``aiomqtt_module`` is passed in because aiomqtt is imported lazily by
        the callers; it is only touched when TLS is enabled.
        """
        cfg = self._config
        kwargs: dict[str, Any] = {
            "hostname": cfg.host,
            "port": cfg.port,
            "keepalive": cfg.keepalive,
            "tls_params": aiomqtt_module.TLSParameters() if cfg.tls else None,
        }
        if cfg.client_id:
            kwargs["identifier"] = cfg.client_id
        if cfg.username is not None:
            kwargs["username"] = cfg.username
        if cfg.password is not None:
            kwargs["password"] = cfg.password
        return kwargs

    @staticmethod
    def _payload_bytes(payload: Any) -> bytes:
        """Normalize a raw MQTT payload to ``bytes``.

        Bytes-like payloads are copied as-is (``str()`` on a bytearray would
        yield its repr, not its content); anything else is stringified and
        UTF-8 encoded.
        """
        if isinstance(payload, bytes):
            return payload
        if isinstance(payload, (bytearray, memoryview)):
            return bytes(payload)
        return str(payload).encode()

    @classmethod
    def _to_message(cls, raw: Any) -> MQTTMessage:
        """Convert an aiomqtt message into an MQTTMessage stamped with now (UTC)."""
        return MQTTMessage(
            topic=str(raw.topic),
            payload=cls._payload_bytes(raw.payload),
            qos=raw.qos,
            retain=raw.retain,
            received_at=datetime.now(tz=timezone.utc),
        )

    async def run(self) -> None:
        """Subscribe to all registered patterns and dispatch until cancelled.

        Reconnects automatically if the connection drops: any exception other
        than cancellation is logged and followed by a sleep of
        ``reconnect_interval`` seconds before the next attempt.

        Raises:
            ImportError: if aiomqtt is not installed.
            asyncio.CancelledError: re-raised when the task is cancelled.
        """
        try:
            import aiomqtt
        except ImportError as exc:
            raise ImportError(
                "aiomqtt is required for MQTTClient. "
                "Install with: pip install circuitforge-core[mqtt]"
            ) from exc

        cfg = self._config
        while True:
            try:
                async with aiomqtt.Client(**self._connect_kwargs(aiomqtt)) as ac:
                    patterns = self._router.patterns
                    if not patterns:
                        logger.warning("MQTTClient started with no subscriptions")
                    for p in patterns:
                        await ac.subscribe(p)
                        logger.debug("Subscribed to %r on %s:%d", p, cfg.host, cfg.port)

                    logger.info("MQTT connected to %s:%d", cfg.host, cfg.port)
                    async for raw in ac.messages:
                        await self._router.dispatch(self._to_message(raw))
            except asyncio.CancelledError:
                logger.info("MQTTClient cancelled")
                raise
            except Exception as exc:
                logger.warning(
                    "MQTT connection to %s:%d failed (%s), retrying in %.0fs",
                    cfg.host, cfg.port, exc, cfg.reconnect_interval,
                )
                await asyncio.sleep(cfg.reconnect_interval)

    @asynccontextmanager
    async def connect(self) -> AsyncIterator[AsyncIterator[MQTTMessage]]:
        """Context manager that yields an async iterator of raw messages.

        Useful when the caller wants to do its own routing::

            async with client.connect() as messages:
                async for msg in messages:
                    ...

        Unlike run(), this does not reconnect; connection errors propagate.

        Raises:
            ImportError: if aiomqtt is not installed.
        """
        try:
            import aiomqtt
        except ImportError as exc:
            raise ImportError(
                "aiomqtt is required. Install with: pip install circuitforge-core[mqtt]"
            ) from exc

        async with aiomqtt.Client(**self._connect_kwargs(aiomqtt)) as ac:
            for p in self._router.patterns:
                await ac.subscribe(p)

            async def _iter() -> AsyncIterator[MQTTMessage]:
                async for raw in ac.messages:
                    yield self._to_message(raw)

            yield _iter()

View file

@ -1,76 +0,0 @@
"""Meshtastic adapter for circuitforge-core.
Two backends are available:
- ``MQTTMeshtasticBackend`` subscribes to a Meshtastic MQTT bridge
- ``SerialMeshtasticBackend`` direct serial/TCP connection via the
``meshtastic`` Python library
Use ``make_backend()`` for config-driven selection.
MIT licensed.
"""
from __future__ import annotations
from circuitforge_core.mqtt.meshtastic.interface import MeshtasticInterface
from circuitforge_core.mqtt.meshtastic.models import (
MeshtasticPacket,
MeshtasticPosition,
MeshtasticTelemetry,
)
from circuitforge_core.mqtt.meshtastic.mqtt_backend import MQTTMeshtasticBackend
from circuitforge_core.mqtt.meshtastic.serial_backend import SerialMeshtasticBackend
from circuitforge_core.mqtt.models import MQTTConfig
def make_backend(config: dict) -> MeshtasticInterface:
    """Build the Meshtastic backend selected by a config dict.

    Config keys:
        backend (str): ``"mqtt"`` or ``"serial"``, case-insensitive
            (defaults to ``"mqtt"`` when omitted)

    For ``"mqtt"`` backend:
        broker_host (str): MQTT broker hostname (required)
        broker_port (int): MQTT broker port (default 1883)
        broker_username (str|None): optional
        broker_password (str|None): optional
        topic_prefix (str): topic to subscribe to (default ``msh/#``)

    For ``"serial"`` backend:
        dev_path (str|None): serial device, e.g. ``/dev/ttyUSB0``
        tcp_host (str|None): TCP hostname for TCP mode
        tcp_port (int): TCP port (default 4403)

    Raises:
        ValueError: if ``backend`` names neither supported backend.
        KeyError: if the mqtt backend is selected without ``broker_host``.
    """
    kind = config.get("backend", "mqtt").lower()

    # Reject unknown names up front so the two builders below read linearly.
    if kind not in ("mqtt", "serial"):
        raise ValueError(f"Unknown Meshtastic backend: {kind!r}. Must be 'mqtt' or 'serial'.")

    if kind == "serial":
        serial_options = {
            "dev_path": config.get("dev_path"),
            "tcp_host": config.get("tcp_host"),
            "tcp_port": int(config.get("tcp_port", 4403)),
        }
        return SerialMeshtasticBackend(**serial_options)

    broker = MQTTConfig(
        host=config["broker_host"],
        port=int(config.get("broker_port", 1883)),
        username=config.get("broker_username"),
        password=config.get("broker_password"),
    )
    prefix = config.get("topic_prefix", "msh/#")
    return MQTTMeshtasticBackend(mqtt_config=broker, topic_prefix=prefix)
__all__ = [
"MeshtasticInterface",
"MeshtasticPacket",
"MeshtasticPosition",
"MeshtasticTelemetry",
"MQTTMeshtasticBackend",
"SerialMeshtasticBackend",
"make_backend",
]

View file

@ -1,36 +0,0 @@
"""Abstract interface for Meshtastic backends.
MIT licensed.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
class MeshtasticInterface(ABC):
    """Async interface for receiving and sending Meshtastic packets.

    Two concrete backends exist:
      - MQTTMeshtasticBackend    subscribes to a Meshtastic MQTT bridge
      - SerialMeshtasticBackend  connects directly via the meshtastic Python library

    Subclasses must implement both ``packets`` and ``send_text``.
    """

    @abstractmethod
    def packets(self) -> AsyncIterator:
        """Async generator of MeshtasticPacket objects.

        Yields packets as they arrive. Runs until cancelled.
        Concrete types are ``MeshtasticPacket`` from
        ``circuitforge_core.mqtt.meshtastic.models``.
        """

    @abstractmethod
    async def send_text(
        self,
        text: str,
        dest_id: int = 0xFFFFFFFF,
        channel: int = 0,
    ) -> None:
        """Send a text message to dest_id (default: broadcast).

        Args:
            text: message body to send
            dest_id: numeric destination node ID; 0xFFFFFFFF means broadcast
            channel: channel number to send on (default 0)
        """

View file

@ -1,83 +0,0 @@
"""Data models for Meshtastic packets.
MIT licensed.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal
# Meshtastic portnum → our label
# Closed set of values allowed for MeshtasticPacket.packet_type; anything a
# backend cannot classify is labelled "unknown".
PacketType = Literal[
    "text",
    "position",
    "nodeinfo",
    "telemetry",
    "routing",
    "admin",
    "unknown",
]
@dataclass(frozen=True)
class MeshtasticPosition:
    """Position payload of a Meshtastic packet.

    Every field is optional and defaults to None when not reported.
    """

    latitude: float | None = None  # presumably decimal degrees — confirm
    longitude: float | None = None  # presumably decimal degrees — confirm
    altitude_m: int | None = None  # metres, per the field name
    timestamp: datetime | None = None
@dataclass(frozen=True)
class MeshtasticTelemetry:
    """Telemetry payload of a Meshtastic packet.

    Every field is optional and defaults to None when not reported.
    """

    battery_level: int | None = None  # 0-100 %
    voltage: float | None = None  # volts
    channel_util: float | None = None  # 0-100 %
    air_util_tx: float | None = None  # 0-100 %
@dataclass(frozen=True)
class MeshtasticPacket:
    """Normalized Meshtastic packet from any backend."""

    packet_type: PacketType
    from_id: str  # hex node ID, e.g. "!deadbeef"
    from_num: int  # numeric node ID
    to_num: int  # 0xffffffff = broadcast
    channel: int
    received_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))

    # Type-specific payloads (only one is populated per packet type)
    text: str | None = None
    position: MeshtasticPosition | None = None
    telemetry: MeshtasticTelemetry | None = None
    node_longname: str | None = None
    node_shortname: str | None = None
    hardware: int | None = None

    # Original raw payload dict for downstream consumers that need all fields
    raw: dict = field(default_factory=dict, compare=False, hash=False)

    @property
    def is_broadcast(self) -> bool:
        """True when the packet is addressed to the broadcast ID 0xFFFFFFFF."""
        return self.to_num == 0xFFFFFFFF

    def summary(self) -> str:
        """One-line human-readable description.

        The source is ``from_id``, or ``!<from_num as 8 hex digits>`` when
        ``from_id`` is empty. A position packet without both coordinates is
        reported as ``position unknown`` rather than failing to format.
        """
        src = self.from_id or f"!{self.from_num:08x}"
        if self.packet_type == "text":
            return f"[{src}] {self.text}"
        if self.packet_type == "position" and self.position:
            p = self.position
            # Coordinates are optional; formatting None with :.5f would raise.
            if p.latitude is None or p.longitude is None:
                return f"[{src}] position unknown"
            return f"[{src}] position {p.latitude:.5f},{p.longitude:.5f}"
        if self.packet_type == "nodeinfo":
            return f"[{src}] node info: {self.node_longname!r} ({self.node_shortname})"
        if self.packet_type == "telemetry" and self.telemetry:
            t = self.telemetry
            parts = []
            if t.battery_level is not None:
                parts.append(f"batt={t.battery_level}%")
            if t.voltage is not None:
                parts.append(f"v={t.voltage:.2f}V")
            return f"[{src}] telemetry {' '.join(parts)}"
        return f"[{src}] {self.packet_type} packet"

View file

@ -1,214 +0,0 @@
"""Meshtastic MQTT bridge backend.
Subscribes to the JSON MQTT topics that Meshtastic firmware publishes when
the MQTT uplink is enabled on a node.
Topic schema (Meshtastic firmware >=2.1):
msh/{region}/{gateway}/2/json/{portnum}/{fromId}
The payload is a JSON object. Examples by type:
Text message:
{"channel":0,"from":123456789,"id":987,"payload":{"text":"hello"},
"sender":"!07558d85","timestamp":1716200000,"to":4294967295,"type":"sendtext"}
Position:
{"channel":0,"from":123456789,"payload":{"altitude":50,
"latitude_i":374208130,"longitude_i":-1220848320,"time":1716200000},
"type":"position"}
Node info:
{"channel":0,"from":123456789,"payload":{"hardware":43,
"id":"!07558d85","longname":"Alan Node","shortname":"AN"},
"type":"nodeinfo"}
Telemetry:
{"channel":0,"from":123456789,"payload":{"battery_level":82,
"voltage":4.09,"channel_utilization":0.5,"air_util_tx":0.01,
"time":1716200000},"type":"telemetry"}
MIT licensed.
"""
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from circuitforge_core.mqtt.client import MQTTClient
from circuitforge_core.mqtt.meshtastic.interface import MeshtasticInterface
from circuitforge_core.mqtt.meshtastic.models import (
MeshtasticPacket,
MeshtasticPosition,
MeshtasticTelemetry,
)
from circuitforge_core.mqtt.models import MQTTConfig, MQTTMessage
logger = logging.getLogger(__name__)
# latitude_i / longitude_i are stored as integer × 1e7 in Meshtastic protobuf.
_COORD_SCALE = 1e-7
def _parse_packet(raw_json: str | bytes, topic: str) -> MeshtasticPacket | None:
"""Parse a Meshtastic MQTT JSON payload into a MeshtasticPacket.
Returns None if the payload cannot be parsed or is an encrypted packet
(payload is a base64 blob instead of a dict).
"""
try:
obj = json.loads(raw_json)
except json.JSONDecodeError:
logger.debug("Non-JSON Meshtastic payload on topic %r", topic)
return None
payload = obj.get("payload")
if not isinstance(payload, dict):
# Encrypted packet — payload is a base64 string; skip.
return None
from_num: int = obj.get("from", 0)
sender: str = obj.get("sender", f"!{from_num:08x}")
channel: int = obj.get("channel", 0)
to_num: int = obj.get("to", 0xFFFFFFFF)
raw_ts: int | None = payload.get("time") or obj.get("timestamp")
received_at = (
datetime.fromtimestamp(raw_ts, tz=timezone.utc) if raw_ts else datetime.now(tz=timezone.utc)
)
ptype: str = obj.get("type", "unknown").lower()
if ptype in ("sendtext", "text"):
return MeshtasticPacket(
packet_type="text",
from_id=sender,
from_num=from_num,
to_num=to_num,
channel=channel,
received_at=received_at,
text=payload.get("text", ""),
raw=obj,
)
if ptype == "position":
lat_i: int | None = payload.get("latitude_i")
lon_i: int | None = payload.get("longitude_i")
return MeshtasticPacket(
packet_type="position",
from_id=sender,
from_num=from_num,
to_num=to_num,
channel=channel,
received_at=received_at,
position=MeshtasticPosition(
latitude=lat_i * _COORD_SCALE if lat_i is not None else None,
longitude=lon_i * _COORD_SCALE if lon_i is not None else None,
altitude_m=payload.get("altitude"),
timestamp=received_at,
),
raw=obj,
)
if ptype == "nodeinfo":
return MeshtasticPacket(
packet_type="nodeinfo",
from_id=sender,
from_num=from_num,
to_num=to_num,
channel=channel,
received_at=received_at,
node_longname=payload.get("longname"),
node_shortname=payload.get("shortname"),
hardware=payload.get("hardware"),
raw=obj,
)
if ptype == "telemetry":
return MeshtasticPacket(
packet_type="telemetry",
from_id=sender,
from_num=from_num,
to_num=to_num,
channel=channel,
received_at=received_at,
telemetry=MeshtasticTelemetry(
battery_level=payload.get("battery_level"),
voltage=payload.get("voltage"),
channel_util=payload.get("channel_utilization"),
air_util_tx=payload.get("air_util_tx"),
),
raw=obj,
)
# Routing, admin, and other packet types — return minimal packet.
return MeshtasticPacket(
packet_type="unknown",
from_id=sender,
from_num=from_num,
to_num=to_num,
channel=channel,
received_at=received_at,
raw=obj,
)
class MQTTMeshtasticBackend(MeshtasticInterface):
    """Receive Meshtastic packets via a Meshtastic MQTT bridge.

    Requires a Meshtastic node with the MQTT uplink enabled, publishing to
    the configured broker. Set ``topic_prefix`` to match the region prefix
    configured on the node (default ``msh/#`` matches all regions).

    Args:
        mqtt_config: broker connection settings
        topic_prefix: MQTT topic pattern to subscribe to (default ``msh/#``)
    """

    def __init__(
        self,
        mqtt_config: MQTTConfig,
        topic_prefix: str = "msh/#",
    ) -> None:
        self._mqtt_config = mqtt_config
        self._topic_prefix = topic_prefix

    async def packets(self) -> AsyncIterator[MeshtasticPacket]:
        """Yield packets parsed from messages on ``topic_prefix``.

        A fresh ``MQTTClient`` is created per call and driven by a background
        task for as long as the generator is being consumed.  Messages that
        ``_parse_packet`` maps to ``None`` are dropped.

        The stream ends when the client task finishes.  If that task failed,
        its exception is re-raised to the consumer instead of leaving the
        generator blocked on an empty queue forever.
        """
        client = MQTTClient(self._mqtt_config)
        # ``None`` is the end-of-stream sentinel (same convention as the
        # serial backend); real packets are never ``None``.
        queue: asyncio.Queue[MeshtasticPacket | None] = asyncio.Queue()

        @client.on(self._topic_prefix)
        async def _handle(msg: MQTTMessage) -> None:
            pkt = _parse_packet(msg.payload, msg.topic)
            if pkt is not None:
                await queue.put(pkt)

        runner = asyncio.create_task(client.run())
        # Wake the consumer loop when the client task ends for any reason.
        # The queue is unbounded, so put_nowait cannot raise QueueFull.
        runner.add_done_callback(lambda _task: queue.put_nowait(None))
        try:
            while True:
                pkt = await queue.get()
                if pkt is None:
                    break
                yield pkt
        finally:
            # cancel() is a no-op on a finished task; awaiting it then either
            # returns normally or re-raises the failure that ended the task.
            runner.cancel()
            try:
                await runner
            except asyncio.CancelledError:
                pass

    async def send_text(
        self,
        text: str,
        dest_id: int = 0xFFFFFFFF,
        channel: int = 0,
    ) -> None:
        """Publishing back to MQTT is not supported by this backend.

        Meshtastic nodes consume from MQTT in a different topic namespace;
        use the serial backend or a direct Meshtastic MQTT channel config
        for two-way messaging.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "MQTTMeshtasticBackend is receive-only. "
            "Use SerialMeshtasticBackend for send support."
        )

View file

@ -1,210 +0,0 @@
"""Meshtastic serial/TCP backend using the meshtastic Python library.
Connects directly to a Meshtastic node over serial port or TCP (e.g.
when a node exposes Meshtastic's native TCP API on port 4403).
The ``meshtastic`` library is synchronous and uses threading + PyPubSub
for callbacks. This backend bridges into asyncio via an asyncio.Queue:
the sync callback puts packets on the queue, and ``packets()`` awaits
items from it.
MIT licensed.
"""
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from circuitforge_core.mqtt.meshtastic.interface import MeshtasticInterface
from circuitforge_core.mqtt.meshtastic.models import (
MeshtasticPacket,
MeshtasticPosition,
MeshtasticTelemetry,
)
logger = logging.getLogger(__name__)
_COORD_SCALE = 1e-7
def _packet_from_decoded(decoded: dict, from_id: int) -> MeshtasticPacket:
    """Convert a meshtastic-library packet dict to MeshtasticPacket.

    Args:
        decoded: the packet dict delivered by the ``meshtastic`` library.
            Routing fields (``to``, ``channel``) are read from the top level;
            the application payload is read from the nested ``"decoded"``
            sub-dict.
        from_id: numeric node id of the sender.

    Returns:
        A ``MeshtasticPacket`` whose ``packet_type`` is ``"text"``,
        ``"position"``, ``"nodeinfo"``, ``"telemetry"`` or ``"unknown"``.
        ``received_at`` is the local wall-clock time of the conversion (UTC).
    """
    inner: dict = decoded.get("decoded") or {}
    # The library reports the port number inside the "decoded" sub-dict.
    # Fall back to a top-level key so flattened dicts keep working.
    portnum: str = inner.get("portnum") or decoded.get("portnum", "UNKNOWN_APP")
    now = datetime.now(tz=timezone.utc)
    # Fields shared by every packet type.
    common = {
        "from_id": f"!{from_id:08x}",
        "from_num": from_id,
        "to_num": decoded.get("to", 0xFFFFFFFF),
        "channel": decoded.get("channel", 0),
        "received_at": now,
        "raw": decoded,
    }

    if portnum == "TEXT_MESSAGE_APP":
        return MeshtasticPacket(
            packet_type="text",
            text=inner.get("text", ""),
            **common,
        )

    if portnum == "POSITION_APP":
        pos = inner.get("position") or {}
        lat_i = pos.get("latitudeI")
        lon_i = pos.get("longitudeI")
        return MeshtasticPacket(
            packet_type="position",
            position=MeshtasticPosition(
                # Integer coordinates are scaled to floats via _COORD_SCALE.
                latitude=lat_i * _COORD_SCALE if lat_i is not None else None,
                longitude=lon_i * _COORD_SCALE if lon_i is not None else None,
                altitude_m=pos.get("altitude"),
                timestamp=now,
            ),
            **common,
        )

    if portnum == "NODEINFO_APP":
        info = inner.get("user") or {}
        return MeshtasticPacket(
            packet_type="nodeinfo",
            node_longname=info.get("longName"),
            node_shortname=info.get("shortName"),
            hardware=info.get("hwModel"),
            **common,
        )

    if portnum == "TELEMETRY_APP":
        telem = inner.get("telemetry") or {}
        dev = telem.get("deviceMetrics") or {}
        return MeshtasticPacket(
            packet_type="telemetry",
            telemetry=MeshtasticTelemetry(
                battery_level=dev.get("batteryLevel"),
                voltage=dev.get("voltage"),
                channel_util=dev.get("channelUtilization"),
                air_util_tx=dev.get("airUtilTx"),
            ),
            **common,
        )

    # Any other port number: return a minimal packet carrying only the
    # routing fields and the raw dict.
    return MeshtasticPacket(packet_type="unknown", **common)
class SerialMeshtasticBackend(MeshtasticInterface):
    """Receive and send Meshtastic packets via serial port or TCP.

    Args:
        dev_path: serial device path (e.g. ``/dev/ttyUSB0``) or ``None``
            to auto-detect the first connected Meshtastic device.
        tcp_host: hostname for TCP connection. If set, ``dev_path`` is ignored
            and a TCP connection to port 4403 is used.
        tcp_port: TCP port (default 4403).
    """

    def __init__(
        self,
        dev_path: str | None = None,
        tcp_host: str | None = None,
        tcp_port: int = 4403,
    ) -> None:
        self._dev_path = dev_path
        self._tcp_host = tcp_host
        self._tcp_port = tcp_port

    def _make_interface(self):
        """Open and return a new meshtastic interface (blocking).

        Returns a ``TCPInterface`` when ``tcp_host`` is set, otherwise a
        ``SerialInterface`` on ``dev_path``.

        Raises:
            ImportError: if the optional ``meshtastic`` package is missing.
        """
        try:
            import meshtastic.serial_interface
            import meshtastic.tcp_interface
        except ImportError as exc:
            raise ImportError(
                "meshtastic is required for SerialMeshtasticBackend. "
                "Install with: pip install circuitforge-core[meshtastic-serial]"
            ) from exc
        if self._tcp_host:
            return meshtastic.tcp_interface.TCPInterface(
                hostname=self._tcp_host,
                portNumber=self._tcp_port,
            )
        return meshtastic.serial_interface.SerialInterface(devPath=self._dev_path)

    async def packets(self) -> AsyncIterator[MeshtasticPacket]:
        """Yield packets received on this backend's own interface.

        The library invokes the callbacks from its own thread, so they hand
        results to the event loop with ``call_soon_threadsafe``.  The stream
        ends when the interface reports that its connection was lost.

        Raises:
            ImportError: if ``meshtastic`` or ``pypubsub`` is not installed.
        """
        loop = asyncio.get_running_loop()
        # ``None`` is the end-of-stream sentinel.
        queue: asyncio.Queue[MeshtasticPacket | None] = asyncio.Queue()
        iface = None  # assigned below, before any callback can fire

        def _on_receive(packet: dict, interface) -> None:
            # pubsub topics are process-global: ignore events that belong to
            # other interfaces (e.g. the short-lived one send_text() opens).
            if interface is not iface:
                return
            try:
                from_id: int = packet.get("from", 0)
                pkt = _packet_from_decoded(packet, from_id)
                loop.call_soon_threadsafe(queue.put_nowait, pkt)
            except Exception:
                logger.exception("Error decoding Meshtastic serial packet")

        def _on_connection_closed(interface) -> None:
            # Only the loss of *our* connection should end this stream.
            if interface is not iface:
                return
            logger.warning("Meshtastic serial connection closed")
            loop.call_soon_threadsafe(queue.put_nowait, None)

        # Opening the device blocks, so do it off the event loop.
        iface = await loop.run_in_executor(None, self._make_interface)
        try:
            from pubsub import pub
        except ImportError as exc:
            await loop.run_in_executor(None, iface.close)
            raise ImportError(
                "pypubsub is required for SerialMeshtasticBackend. "
                "Install with: pip install circuitforge-core[meshtastic-serial]"
            ) from exc
        pub.subscribe(_on_receive, "meshtastic.receive")
        pub.subscribe(_on_connection_closed, "meshtastic.connection.lost")
        try:
            while True:
                pkt = await queue.get()
                if pkt is None:
                    break
                yield pkt
        finally:
            try:
                pub.unsubscribe(_on_receive, "meshtastic.receive")
                pub.unsubscribe(_on_connection_closed, "meshtastic.connection.lost")
            finally:
                # Release the device even if unsubscribing failed.
                await loop.run_in_executor(None, iface.close)

    async def send_text(
        self,
        text: str,
        dest_id: int = 0xFFFFFFFF,
        channel: int = 0,
    ) -> None:
        """Send a text message, opening a fresh interface for this one call.

        Args:
            text: message body.
            dest_id: destination node number (default ``0xFFFFFFFF``).
            channel: channel index to send on.

        The interface is closed again afterwards, whether or not the send
        succeeded.  All blocking library calls run in the default executor.
        """
        loop = asyncio.get_running_loop()
        iface = await loop.run_in_executor(None, self._make_interface)
        try:
            await loop.run_in_executor(
                None,
                lambda: iface.sendText(text, destinationId=dest_id, channelIndex=channel),
            )
        finally:
            await loop.run_in_executor(None, iface.close)

View file

@ -1,44 +0,0 @@
"""Data models for the MQTT client module.
MIT licensed.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
@dataclass(frozen=True)
class MQTTConfig:
    """Connection config for an MQTT broker.

    Immutable; only ``host`` is required.  The password is excluded from the
    generated ``repr`` so that logging or printing a config object never
    leaks the credential.  It still takes part in equality and hashing.
    """

    host: str
    port: int = 1883
    username: str | None = None
    password: str | None = field(default=None, repr=False)
    client_id: str = ""
    keepalive: int = 60  # presumably seconds — confirm against the client
    tls: bool = False
    reconnect_interval: float = 5.0  # presumably seconds — confirm against the client
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)


@dataclass(frozen=True)
class MQTTMessage:
    """Immutable record of one message received from the broker."""

    topic: str
    payload: bytes
    qos: int = 0
    retain: bool = False
    # Stamped when the instance is created unless the caller supplies a value.
    received_at: datetime = field(default_factory=_utc_now)

    def text(self, encoding: str = "utf-8") -> str:
        """Decode the payload as text; undecodable bytes are replaced."""
        return str(self.payload, encoding, "replace")

    def json(self) -> dict:
        """Parse the payload as JSON and return the parsed value."""
        parsed = json.loads(self.payload)
        return parsed

    @property
    def topic_parts(self) -> list[str]:
        """The topic split into its ``/``-separated levels."""
        levels = self.topic.split("/")
        return levels

View file

@ -1,74 +0,0 @@
"""MQTT topic router with wildcard pattern matching.
MIT licensed.
"""
from __future__ import annotations
import asyncio
import inspect
import logging
from collections.abc import Callable, Coroutine
from typing import Any
from circuitforge_core.mqtt.models import MQTTMessage
logger = logging.getLogger(__name__)
Handler = Callable[[MQTTMessage], Coroutine[Any, Any, None]]
def matches(pattern: str, topic: str) -> bool:
    """Return True if topic matches the MQTT wildcard pattern.

    MQTT wildcard rules:
    - '+' matches exactly one topic level (segment between '/' separators)
    - '#' matches zero or more levels and MUST appear at the end of the pattern
    - All other characters match literally

    A '#' anywhere but the final level makes the pattern invalid, and an
    invalid pattern matches nothing.  '+' does not match an empty level.

    Examples:
        matches("sensor/+/temp", "sensor/room1/temp")   True
        matches("sensor/+/temp", "sensor/a/b/temp")     False
        matches("sensor/#", "sensor/room1/temp")        True
        matches("sensor/#", "sensor")                   True  (# = zero levels)
        matches("#", "any/topic/here")                  True
        matches("a/b/c", "a/b/c")                       True
    """
    pattern_levels = pattern.split("/")
    topic_levels = topic.split("/")
    last = len(pattern_levels) - 1

    for depth, expected in enumerate(pattern_levels):
        if expected == "#":
            # Everything before '#' already matched; the remainder of the
            # topic (possibly nothing) is accepted if '#' is the final level.
            return depth == last
        if depth >= len(topic_levels):
            # Topic ran out of levels before the pattern did.
            return False
        actual = topic_levels[depth]
        if expected == "+":
            if not actual:
                return False
        elif expected != actual:
            return False

    # No '#' seen: the topic must not have extra trailing levels.
    return len(topic_levels) == len(pattern_levels)
class TopicRouter:
    """Register async handlers for MQTT topic patterns and dispatch messages."""

    def __init__(self) -> None:
        # (pattern, handler) pairs in registration order.
        self._routes: list[tuple[str, Handler]] = []

    @property
    def patterns(self) -> list[str]:
        """Registered patterns, in registration order (duplicates kept)."""
        return [pattern for pattern, _ in self._routes]

    def register(self, pattern: str, handler: Handler) -> None:
        """Add a handler for the given topic pattern."""
        self._routes.append((pattern, handler))

    def on(self, pattern: str) -> Callable[[Handler], Handler]:
        """Decorator: @router.on("sensor/#") async def handle(msg): ..."""

        def decorator(fn: Handler) -> Handler:
            self.register(pattern, fn)
            return fn

        return decorator

    async def dispatch(self, message: MQTTMessage) -> None:
        """Call all handlers whose pattern matches message.topic.

        Handlers run one after another in registration order.  Whatever a
        handler returns is awaited if it is awaitable, so plain functions,
        coroutine functions, and async callables (objects with an async
        ``__call__``, wrappers returning a coroutine) are all supported.
        An exception from one handler is logged and does not stop the rest.
        """
        for pattern, handler in self._routes:
            try:
                if not matches(pattern, message.topic):
                    continue
                outcome = handler(message)
                # Check the result rather than the callable: this also
                # covers async callables that iscoroutinefunction() misses.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception:
                logger.exception("Handler for %r raised on topic %r", pattern, message.topic)
View file

@ -51,12 +51,9 @@ cf-orch service profile (Phase 3 — remote backend):
""" """
from __future__ import annotations from __future__ import annotations
import logging
import os import os
from typing import Sequence from typing import Sequence
logger = logging.getLogger(__name__)
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
from circuitforge_core.reranker.adapters.mock import MockTextReranker from circuitforge_core.reranker.adapters.mock import MockTextReranker

View file

@ -169,15 +169,7 @@ class LocalScheduler:
if not q: if not q:
break break
task = q.popleft() task = q.popleft()
try: self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
except Exception as exc:
# run_task_fn should handle its own exceptions. If it leaks one,
# log it so the task doesn't silently stay 'queued' with no trace.
logger.exception(
"Unhandled exception in batch worker task %d (%s): %s",
task.id, task_type, exc,
)
finally: finally:
with self._lock: with self._lock:
self._active.pop(task_type, None) self._active.pop(task_type, None)

View file

@ -1,15 +1,14 @@
""" """
cf-text FastAPI service managed by cf-orch. cf-text FastAPI service managed by cf-orch.
Lightweight local text generation and PII filtering. Supports GGUF models via Lightweight local text generation. Supports GGUF models via llama.cpp and
llama.cpp, HuggingFace transformers, and token-classification models (classifier HuggingFace transformers. Sits alongside vllm/ollama for products that need
backend) for PII detection and redaction. fast, frequent inference from small local models (3B7B Q4).
Endpoints: Endpoints:
GET /health {"status": "ok", "model": str, "vram_mb": int, "backend": str} GET /health {"status": "ok", "model": str, "vram_mb": int, "backend": str}
POST /generate GenerateResponse (text-gen backends only) POST /generate GenerateResponse
POST /chat GenerateResponse (text-gen backends only) POST /chat GenerateResponse
POST /filter FilterResponse (classifier backend only)
Usage: Usage:
python -m circuitforge_core.text.app \ python -m circuitforge_core.text.app \
@ -35,46 +34,17 @@ import os
import time import time
import uuid import uuid
from functools import partial from functools import partial
from typing import Annotated, Literal, Union
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel
from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage
from circuitforge_core.text.backends.base import make_classifier_backend, make_text_backend from circuitforge_core.text.backends.base import make_text_backend
from circuitforge_core.text.filter import FilterResult, PIIFilter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_backend = None _backend = None
_pii_filter: PIIFilter | None = None
# ── Content block types (OpenAI multimodal format) ────────────────────────────
class ContentBlockText(BaseModel):
    # Text content block of a multimodal message (OpenAI-style format).
    # The literal ``type`` value is the discriminator used by ``ContentBlock``.
    type: Literal["text"]
    text: str
class ContentBlockImageURL(BaseModel):
    # Image content block of a multimodal message (OpenAI-style format).
    # The literal ``type`` value is the discriminator used by ``ContentBlock``.
    type: Literal["image_url"]
    # String-to-string mapping describing the image; presumably carries a
    # "url" key as in the OpenAI format — confirm against the backends.
    image_url: dict[str, str]
ContentBlock = Annotated[
Union[ContentBlockText, ContentBlockImageURL],
Field(discriminator="type"),
]
def _to_backend_message(role: str, content: "str | list[ContentBlock]") -> "BackendChatMessage":
    """Build a BackendChatMessage from an API-level message.

    String content is passed through untouched; a list of content blocks is
    converted to a list of plain dicts via each block's ``model_dump()``.
    """
    payload = content if isinstance(content, str) else [block.model_dump() for block in content]
    return BackendChatMessage(role, payload)
# ── Request / response models ───────────────────────────────────────────────── # ── Request / response models ─────────────────────────────────────────────────
@ -89,7 +59,7 @@ class GenerateRequest(BaseModel):
class ChatMessageModel(BaseModel): class ChatMessageModel(BaseModel):
role: str role: str
content: Union[str, list[ContentBlock]] = "" content: str
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
@ -104,31 +74,12 @@ class GenerateResponse(BaseModel):
model: str = "" model: str = ""
class FilterRequest(BaseModel):
    # Request body for POST /filter.
    # The text handed to the PII filter.
    text: str
class PIISpanResponse(BaseModel):
    # One detected span in a /filter response; each field is copied from the
    # same-named attribute of a span in the filter result.
    label: str
    # Presumably character offsets into the original text — confirm against
    # the PIIFilter implementation.
    start: int
    end: int
    text: str
    score: float
class FilterResponse(BaseModel):
    # Response body for POST /filter, populated from the filter result.
    redacted_text: str
    spans: list[PIISpanResponse]
    original_text: str
    # Name of the model behind the filter; empty string when not set.
    model: str = ""
# ── OpenAI-compat request / response (for LLMRouter openai_compat path) ────── # ── OpenAI-compat request / response (for LLMRouter openai_compat path) ──────
class OAIMessageModel(BaseModel): class OAIMessageModel(BaseModel):
role: str role: str
content: Union[str, list[ContentBlock]] = "" content: str
class OAIChatRequest(BaseModel): class OAIChatRequest(BaseModel):
@ -169,7 +120,6 @@ def create_app(
gpu_ids: str | None = None, gpu_ids: str | None = None,
backend: str | None = None, backend: str | None = None,
mock: bool = False, mock: bool = False,
mmproj_path: str = "",
) -> FastAPI: ) -> FastAPI:
"""Start the cf-text FastAPI app. """Start the cf-text FastAPI app.
@ -177,12 +127,8 @@ def create_app(
(e.g. "0,1"). When set, overrides ``gpu_id`` and sets (e.g. "0,1"). When set, overrides ``gpu_id`` and sets
``CUDA_VISIBLE_DEVICES`` to the full list so HuggingFace Accelerate's ``CUDA_VISIBLE_DEVICES`` to the full list so HuggingFace Accelerate's
``device_map="auto"`` can shard the model across all listed devices. ``device_map="auto"`` can shard the model across all listed devices.
When ``backend="classifier"``, the service skips the text-gen backends
and loads a token-classification pipeline instead. Only ``POST /filter``
is available in that mode; ``/generate`` and ``/chat`` return 501.
""" """
global _backend, _pii_filter global _backend
if not mock and not model_path: if not mock and not model_path:
raise ValueError( raise ValueError(
@ -193,26 +139,13 @@ def create_app(
visible = gpu_ids if gpu_ids else str(gpu_id) visible = gpu_ids if gpu_ids else str(gpu_id)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible) os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible)
resolved_backend = backend or os.environ.get("CF_TEXT_BACKEND", "") _backend = make_text_backend(model_path, backend=backend, mock=mock)
if resolved_backend == "classifier" or (not resolved_backend and False): logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)
classifier_backend = make_classifier_backend(model_path)
_pii_filter = PIIFilter.from_backend(classifier_backend)
logger.info(
"cf-text (classifier) ready: model=%r vram=%dMB",
classifier_backend.model_name,
classifier_backend.vram_mb,
)
else:
_backend = make_text_backend(model_path, backend=backend, mock=mock, mmproj_path=mmproj_path)
logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)
app = FastAPI(title="cf-text", version="0.1.0") app = FastAPI(title="cf-text", version="0.1.0")
@app.get("/health") @app.get("/health")
def health() -> dict: def health() -> dict:
if _pii_filter is not None:
b = _pii_filter._backend
return {"status": "ok", "model": b.model_name, "vram_mb": b.vram_mb, "backend": "classifier"}
if _backend is None: if _backend is None:
raise HTTPException(503, detail="backend not initialised") raise HTTPException(503, detail="backend not initialised")
return { return {
@ -221,35 +154,8 @@ def create_app(
"vram_mb": _backend.vram_mb, "vram_mb": _backend.vram_mb,
} }
@app.post("/filter")
async def filter_text(req: FilterRequest) -> FilterResponse:
if _pii_filter is None:
raise HTTPException(
501,
detail="This cf-text instance is not running a classifier backend. "
"Start with --backend classifier and a token-classification model.",
)
result = await _pii_filter.filter_async(req.text)
return FilterResponse(
redacted_text=result.redacted_text,
spans=[
PIISpanResponse(
label=s.label,
start=s.start,
end=s.end,
text=s.text,
score=s.score,
)
for s in result.spans
],
original_text=result.original_text,
model=_pii_filter._backend.model_name,
)
@app.post("/generate") @app.post("/generate")
async def generate(req: GenerateRequest) -> GenerateResponse: async def generate(req: GenerateRequest) -> GenerateResponse:
if _pii_filter is not None:
raise HTTPException(501, detail="classifier backend loaded — use POST /filter")
if _backend is None: if _backend is None:
raise HTTPException(503, detail="backend not initialised") raise HTTPException(503, detail="backend not initialised")
result = await _backend.generate_async( result = await _backend.generate_async(
@ -266,20 +172,16 @@ def create_app(
@app.post("/chat") @app.post("/chat")
async def chat(req: ChatRequest) -> GenerateResponse: async def chat(req: ChatRequest) -> GenerateResponse:
if _pii_filter is not None:
raise HTTPException(501, detail="classifier backend loaded — use POST /filter")
if _backend is None: if _backend is None:
raise HTTPException(503, detail="backend not initialised") raise HTTPException(503, detail="backend not initialised")
messages = [_to_backend_message(m.role, m.content) for m in req.messages] messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
# chat() is sync-only in the Protocol; run in thread pool to avoid blocking
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: result = await loop.run_in_executor(
result = await loop.run_in_executor( None,
None, partial(_backend.chat, messages,
partial(_backend.chat, messages, max_tokens=req.max_tokens, temperature=req.temperature),
max_tokens=req.max_tokens, temperature=req.temperature), )
)
except ValueError as exc:
raise HTTPException(422, detail=str(exc)) from exc
return GenerateResponse( return GenerateResponse(
text=result.text, text=result.text,
tokens_used=result.tokens_used, tokens_used=result.tokens_used,
@ -296,16 +198,13 @@ def create_app(
""" """
if _backend is None: if _backend is None:
raise HTTPException(503, detail="backend not initialised") raise HTTPException(503, detail="backend not initialised")
messages = [_to_backend_message(m.role, m.content) for m in req.messages] messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
max_tok = req.max_tokens or 512 max_tok = req.max_tokens or 512
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: result = await loop.run_in_executor(
result = await loop.run_in_executor( None,
None, partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature),
partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature), )
)
except ValueError as exc:
raise HTTPException(422, detail=str(exc)) from exc
return OAIChatResponse( return OAIChatResponse(
id=f"cftext-{uuid.uuid4().hex[:12]}", id=f"cftext-{uuid.uuid4().hex[:12]}",
created=int(time.time()), created=int(time.time()),
@ -331,16 +230,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument("--gpu-ids", default=None, parser.add_argument("--gpu-ids", default=None,
help="Comma-separated CUDA device indices for multi-GPU spanning " help="Comma-separated CUDA device indices for multi-GPU spanning "
"(e.g. '0,1'). Overrides --gpu-id when set.") "(e.g. '0,1'). Overrides --gpu-id when set.")
parser.add_argument( parser.add_argument("--backend", choices=["llamacpp", "transformers"], default=None)
"--backend",
choices=["llamacpp", "transformers", "ollama", "vllm", "classifier"],
default=None,
)
parser.add_argument(
"--mmproj", default="",
help="Path to multimodal projector file for VLM GGUF models (LLaVA-style). "
"Qwen2-VL and other self-contained VLMs don't need this.",
)
parser.add_argument("--mock", action="store_true", parser.add_argument("--mock", action="store_true",
help="Run in mock mode (no model or GPU needed)") help="Run in mock mode (no model or GPU needed)")
return parser.parse_args() return parser.parse_args()
@ -357,6 +247,5 @@ if __name__ == "__main__":
gpu_ids=args.gpu_ids, gpu_ids=args.gpu_ids,
backend=args.backend, backend=args.backend,
mock=mock, mock=mock,
mmproj_path=args.mmproj,
) )
uvicorn.run(app, host=args.host, port=args.port, log_level="info") uvicorn.run(app, host=args.host, port=args.port, log_level="info")

View file

@ -24,44 +24,17 @@ class GenerateResult:
class ChatMessage: class ChatMessage:
"""A single message in a chat conversation. """A single message in a chat conversation."""
``content`` is either a plain string or a list of OpenAI-format content def __init__(self, role: str, content: str) -> None:
blocks (dicts with ``type: "text"`` or ``type: "image_url"``). Backends
that do not support images should call ``text_only`` to get the string
form before passing to the model.
"""
def __init__(self, role: str, content: "str | list") -> None:
if role not in ("system", "user", "assistant"): if role not in ("system", "user", "assistant"):
raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.") raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.")
self.role = role self.role = role
self.content: "str | list" = content self.content = content
def to_dict(self) -> dict: def to_dict(self) -> dict:
return {"role": self.role, "content": self.content} return {"role": self.role, "content": self.content}
@property
def has_images(self) -> bool:
"""True when at least one content block is an image_url block."""
if isinstance(self.content, str):
return False
return any(
isinstance(b, dict) and b.get("type") == "image_url"
for b in self.content
)
@property
def text_only(self) -> str:
"""Flatten multimodal content to text. Returns content as-is if already str."""
if isinstance(self.content, str):
return self.content
return "\n".join(
b["text"]
for b in self.content
if isinstance(b, dict) and b.get("type") == "text"
)
# ── TextBackend Protocol ────────────────────────────────────────────────────── # ── TextBackend Protocol ──────────────────────────────────────────────────────
@ -143,33 +116,6 @@ class TextBackend(Protocol):
... ...
# ── FilterBackend Protocol ────────────────────────────────────────────────────
@runtime_checkable
class FilterBackend(Protocol):
    """
    Abstract interface for token-classification / PII-filter backends.

    Separate from TextBackend — returns entity spans and redacted text,
    not generated text.
    """

    def classify(self, text: str) -> list[dict]:
        """Synchronous classify — returns list of entity span dicts."""
        ...

    async def classify_async(self, text: str) -> list[dict]:
        """Async classify — runs in thread pool."""
        ...

    # Name of the loaded model.
    @property
    def model_name(self) -> str: ...

    # VRAM figure for the backend, in megabytes going by the name.
    @property
    def vram_mb(self) -> int: ...
# ── Backend selection ───────────────────────────────────────────────────────── # ── Backend selection ─────────────────────────────────────────────────────────
@ -187,7 +133,7 @@ def _select_backend(model_path: str, backend: str | None) -> str:
Raise ValueError for unrecognised override values. Raise ValueError for unrecognised override values.
""" """
_VALID = ("llamacpp", "transformers", "ollama", "vllm", "classifier") _VALID = ("llamacpp", "transformers", "ollama", "vllm")
# 1. Caller-supplied override — highest trust, no inspection needed. # 1. Caller-supplied override — highest trust, no inspection needed.
resolved = backend or os.environ.get("CF_TEXT_BACKEND") resolved = backend or os.environ.get("CF_TEXT_BACKEND")
@ -207,11 +153,6 @@ def _select_backend(model_path: str, backend: str | None) -> str:
# 3. Format detection — GGUF files are unambiguously llama-cpp territory. # 3. Format detection — GGUF files are unambiguously llama-cpp territory.
if model_path.lower().endswith(".gguf"): if model_path.lower().endswith(".gguf"):
return "llamacpp" return "llamacpp"
# 3b. GGUF directory — avocet downloads whole repos; scan for .gguf contents.
if os.path.isdir(model_path):
import glob as _glob
if _glob.glob(os.path.join(model_path, "*.gguf")) or _glob.glob(os.path.join(model_path, "*.GGUF")):
return "llamacpp"
# 4. Safe default — transformers covers HF repo IDs and safetensors dirs. # 4. Safe default — transformers covers HF repo IDs and safetensors dirs.
return "transformers" return "transformers"
@ -224,7 +165,6 @@ def make_text_backend(
model_path: str, model_path: str,
backend: str | None = None, backend: str | None = None,
mock: bool | None = None, mock: bool | None = None,
mmproj_path: str = "",
) -> "TextBackend": ) -> "TextBackend":
""" """
Return a TextBackend for the given model. Return a TextBackend for the given model.
@ -241,7 +181,7 @@ def make_text_backend(
if resolved == "llamacpp": if resolved == "llamacpp":
from circuitforge_core.text.backends.llamacpp import LlamaCppBackend from circuitforge_core.text.backends.llamacpp import LlamaCppBackend
return LlamaCppBackend(model_path=model_path, mmproj_path=mmproj_path) return LlamaCppBackend(model_path=model_path)
if resolved == "transformers": if resolved == "transformers":
from circuitforge_core.text.backends.transformers import TransformersBackend from circuitforge_core.text.backends.transformers import TransformersBackend
@ -255,22 +195,4 @@ def make_text_backend(
from circuitforge_core.text.backends.vllm import VllmBackend from circuitforge_core.text.backends.vllm import VllmBackend
return VllmBackend(model_path=model_path) return VllmBackend(model_path=model_path)
raise ValueError( raise ValueError(f"Unknown backend {resolved!r}. Expected 'llamacpp', 'transformers', 'ollama', or 'vllm'.")
f"Unknown backend {resolved!r}. "
"Expected 'llamacpp', 'transformers', 'ollama', 'vllm', or 'classifier'."
)
def make_classifier_backend(model_path: str) -> "FilterBackend":
    """
    Build the FilterBackend for a token-classification model.

    With ``CF_TEXT_MOCK=1`` in the environment a MockClassifierBackend is
    returned (``model_path`` becomes its model name); otherwise a
    ClassifierBackend is loaded from ``model_path``.  Backend modules are
    imported lazily so only the selected one is loaded.
    """
    use_mock = os.environ.get("CF_TEXT_MOCK", "") == "1"
    if not use_mock:
        from circuitforge_core.text.backends.classifier import ClassifierBackend

        return ClassifierBackend(model_path=model_path)

    from circuitforge_core.text.backends.mock import MockClassifierBackend

    return MockClassifierBackend(model_name=model_path)

View file

@ -1,88 +0,0 @@
# circuitforge_core/text/backends/classifier.py — HuggingFace token-classification backend
#
# BSL 1.1. Requires torch + transformers.
# Install: pip install circuitforge-core[text-transformers]
#
# Wraps pipeline("token-classification") for PII/entity detection.
# Returns spans with char offsets, entity labels, and confidence scores.
# Use make_classifier_backend() from base.py to instantiate.
from __future__ import annotations
import asyncio
import logging
import os
from typing import Any
logger = logging.getLogger(__name__)
class ClassifierBackend:
    """
    HuggingFace token-classification backend for PII detection and entity labeling.

    Loads any token-classification model from HuggingFace Hub or a local checkpoint.
    Returns aggregated entity spans with char offsets suitable for redaction or audit.

    Aggregation strategy "simple" merges consecutive BIO-tagged subwords into word-level
    spans and strips the B-/I- prefixes so callers see "NAME" not "B-NAME".

    Requires: pip install circuitforge-core[text-transformers]
    """

    def __init__(self, model_path: str) -> None:
        """Load the token-classification pipeline for ``model_path``.

        Raises:
            ImportError: if torch or transformers is not installed.
        """
        try:
            import torch
            from transformers import pipeline as hf_pipeline
        except ImportError as exc:
            raise ImportError(
                "torch and transformers are required for ClassifierBackend. "
                "Install with: pip install circuitforge-core[text-transformers]"
            ) from exc

        # Use the first visible GPU when CUDA is usable, otherwise the CPU.
        # CUDA_VISIBLE_DEVICES alone must not force a GPU index: the variable
        # can be set on hosts without a working CUDA runtime, where device 0
        # does not exist.
        device = 0 if torch.cuda.is_available() else -1
        logger.info("Loading classifier model %s on device %s", model_path, device)
        self._pipeline = hf_pipeline(
            "token-classification",
            model=model_path,
            aggregation_strategy="simple",
            device=device,
        )
        self._model_path = model_path

    @property
    def model_name(self) -> str:
        """Last path component of the model path (trailing slashes ignored)."""
        return self._model_path.rstrip("/").split("/")[-1]

    @property
    def vram_mb(self) -> int:
        """Memory currently allocated by torch on CUDA, in MiB; 0 if unavailable."""
        try:
            import torch

            if torch.cuda.is_available():
                return torch.cuda.memory_allocated() // (1024 * 1024)
        except Exception:
            # Best-effort metric: report 0 rather than fail the caller.
            logger.debug("Could not query CUDA memory usage", exc_info=True)
        return 0

    def classify(self, text: str) -> list[dict[str, Any]]:
        """
        Run token classification synchronously.

        Returns a list of entity dicts with keys:
            entity_group: str   label without BIO prefix (e.g. "NAME", "EMAIL")
            score: float        aggregated confidence
            word: str           matched text span
            start: int          char offset (start, inclusive)
            end: int            char offset (end, exclusive)
        """
        results: list[dict[str, Any]] = self._pipeline(text)
        return results

    async def classify_async(self, text: str) -> list[dict[str, Any]]:
        """Async classify — runs pipeline in thread pool to avoid blocking the event loop."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.classify, text)

View file

@ -48,16 +48,7 @@ class LlamaCppBackend:
Requires: pip install circuitforge-core[text-llamacpp] Requires: pip install circuitforge-core[text-llamacpp]
""" """
def __init__(self, model_path: str, mmproj_path: str = "", chat_format: str = "") -> None: def __init__(self, model_path: str) -> None:
"""Load a GGUF model.
``mmproj_path``: path to a separate multimodal projector file (needed
for LLaVA-style VLMs where the visual encoder is a separate .gguf).
Qwen2-VL and similar models with an embedded projector don't need this.
``chat_format``: llama-cpp chat template override (e.g. "llava-1-5",
"moondream"). Required when mmproj_path is set.
"""
try: try:
from llama_cpp import Llama # type: ignore[import] from llama_cpp import Llama # type: ignore[import]
except ImportError as exc: except ImportError as exc:
@ -72,53 +63,20 @@ class LlamaCppBackend:
"Download a GGUF model and set CF_TEXT_MODEL to its path." "Download a GGUF model and set CF_TEXT_MODEL to its path."
) )
# If given a directory, find the .gguf file inside it.
if Path(model_path).is_dir():
candidates = sorted(Path(model_path).glob("*.gguf")) or sorted(Path(model_path).glob("*.GGUF"))
if not candidates:
raise FileNotFoundError(
f"No .gguf file found in directory: {model_path}"
)
model_path = str(candidates[0])
n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None
logger.info(
kwargs: dict = dict( "Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
)
self._llm = Llama(
model_path=model_path, model_path=model_path,
n_ctx=_DEFAULT_N_CTX, n_ctx=_DEFAULT_N_CTX,
n_gpu_layers=_DEFAULT_N_GPU_LAYERS, n_gpu_layers=_DEFAULT_N_GPU_LAYERS,
n_threads=n_threads, n_threads=n_threads,
verbose=False, verbose=False,
) )
if mmproj_path:
kwargs["clip_model_path"] = mmproj_path
kwargs["chat_format"] = chat_format or "llava-1-5"
logger.info(
"Loading VLM %s with mmproj %s (ctx=%d, gpu_layers=%d)",
model_path, mmproj_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
)
else:
logger.info(
"Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
)
self._llm = Llama(**kwargs)
self._model_path = model_path self._model_path = model_path
self._vram_mb = _estimate_vram_mb(model_path) self._vram_mb = _estimate_vram_mb(model_path)
# True when the model was initialised with a visual encoder (explicit
# mmproj) or when it is a known self-contained VLM (Qwen2-VL, etc.).
self._is_vlm = bool(mmproj_path) or self._detect_embedded_vlm()
def _detect_embedded_vlm(self) -> bool:
"""Heuristic: check model metadata for a known multimodal architecture."""
try:
meta = self._llm.metadata or {}
arch = str(meta.get("general.architecture", "")).lower()
# Qwen2-VL and similar embed the vision encoder inside the GGUF.
return any(tag in arch for tag in ("qwen2_vl", "llava", "moondream", "minicpm-v"))
except Exception:
return False
@property @property
def model_name(self) -> str: def model_name(self) -> str:
@ -223,14 +181,7 @@ class LlamaCppBackend:
max_tokens: int = 512, max_tokens: int = 512,
temperature: float = 0.7, temperature: float = 0.7,
) -> GenerateResult: ) -> GenerateResult:
# Detect image content before calling the model. # llama-cpp-python has native chat_completion for instruct models
if any(m.has_images for m in messages) and not self._is_vlm:
raise ValueError(
"model does not support image input — "
"load a VLM (with mmproj_path) or route to cf-vision/cf-docuvision"
)
# llama-cpp-python create_chat_completion accepts content as str or
# list-of-blocks (OpenAI multimodal format) natively.
output = self._llm.create_chat_completion( output = self._llm.create_chat_completion(
messages=[m.to_dict() for m in messages], messages=[m.to_dict() for m in messages],
max_tokens=max_tokens, max_tokens=max_tokens,

View file

@ -102,49 +102,3 @@ class MockTextBackend:
# Format messages into a simple prompt for the mock response # Format messages into a simple prompt for the mock response
prompt = "\n".join(f"{m.role}: {m.content}" for m in messages) prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature) return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
# Synthetic PII spans injected by MockClassifierBackend — predictable in tests.
# The offsets are fixed values; they are not derived from the text passed to
# classify().
_MOCK_SPANS = [
    {
        "entity_group": "NAME",
        "score": 0.99,
        "word": "Jane Doe",
        "start": 0,
        "end": 8,
    },
    {
        "entity_group": "EMAIL",
        "score": 0.97,
        "word": "jane@example.com",
        "start": 18,
        "end": 34,
    },
]


class MockClassifierBackend:
    """
    Deterministic mock classifier backend for development and CI.

    Always returns the same two synthetic PII spans regardless of input.
    Allows filter.py logic (redaction, span conversion) to be tested without
    a real model or GPU.
    """

    def __init__(self, model_name: str = "mock-classifier") -> None:
        self._model_name = model_name

    @property
    def model_name(self) -> str:
        """Name supplied at construction time."""
        return self._model_name

    @property
    def vram_mb(self) -> int:
        """Always 0 — the mock loads no model."""
        return 0

    def classify(self, text: str) -> list[dict]:
        """Return fresh copies of the canned spans; ``text`` is ignored.

        Each dict is copied so a caller that mutates a result cannot alter
        the shared ``_MOCK_SPANS`` constant and change later results.
        """
        return [dict(span) for span in _MOCK_SPANS]

    async def classify_async(self, text: str) -> list[dict]:
        """Async wrapper around :meth:`classify`; runs inline."""
        return self.classify(text)

View file

@ -50,12 +50,10 @@ class TransformersBackend:
logger.info("Loading transformers model %s on %s", model_path, self._device) logger.info("Loading transformers model %s on %s", model_path, self._device)
load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None} load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
if _LOAD_IN_4BIT or _LOAD_IN_8BIT: if _LOAD_IN_4BIT:
from transformers import BitsAndBytesConfig load_kwargs["load_in_4bit"] = True
load_kwargs["quantization_config"] = BitsAndBytesConfig( elif _LOAD_IN_8BIT:
load_in_4bit=_LOAD_IN_4BIT, load_kwargs["load_in_8bit"] = True
load_in_8bit=_LOAD_IN_8BIT,
)
self._tokenizer = AutoTokenizer.from_pretrained(model_path) self._tokenizer = AutoTokenizer.from_pretrained(model_path)
self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs) self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

View file

@ -1,114 +0,0 @@
# circuitforge_core/text/filter.py — PII detection and redaction
#
# BSL 1.1. Products import PIIFilter for pre-send redaction and audit trails.
# Requires a running cf-filter service (or ClassifierBackend for in-process use).
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
from circuitforge_core.text.backends.base import FilterBackend, make_classifier_backend
@dataclass(frozen=True)
class PIISpan:
    """One PII entity located in the source text.

    Attributes:
        label: entity label, e.g. NAME | EMAIL | PHONE_NUM | ADDRESS | SSN | DOB | IP_ADDRESS
        start: char offset (inclusive) in original_text
        end:   char offset (exclusive) in original_text
        text:  original span text
        score: confidence score from the classifier
    """

    label: str
    start: int
    end: int
    text: str
    score: float
@dataclass(frozen=True)
class FilterResult:
    """Value returned by PIIFilter.filter() and PIIFilter.filter_async().

    Attributes:
        redacted_text: safe-to-send copy with each span replaced by ``[LABEL]``.
        spans:         all detected entities, for audit logs or caller-side decisions.
        original_text: the input text (stored for round-trip comparisons).
    """

    redacted_text: str
    spans: list[PIISpan] = field(default_factory=list)
    original_text: str = ""
def _redact(text: str, spans: list[PIISpan]) -> str:
"""Replace each span in text with ``[LABEL]``, processing right-to-left so
earlier offsets remain valid after each substitution."""
result = text
for span in sorted(spans, key=lambda s: s.start, reverse=True):
result = result[: span.start] + f"[{span.label}]" + result[span.end :]
return result
def _spans_from_pipeline(raw: list[dict[str, Any]]) -> list[PIISpan]:
    """Convert raw pipeline output dicts into typed PIISpan objects.

    Each item is expected to carry the keys ``entity_group``, ``score``,
    ``word``, ``start`` and ``end``. ``start`` and ``end`` are required;
    the others fall back to ``""`` / ``0.0`` when absent.

    The label is upper-cased and any residual BIO prefix (``B-`` / ``I-``)
    is removed, regardless of the prefix's original case.

    Raises:
        KeyError: if an item has no ``start`` or ``end``.
    """
    spans: list[PIISpan] = []
    for item in raw:
        # Upper-case before stripping so a lower-case prefix ("b-name") is
        # removed too; `or ""` keeps a None label from crashing re.sub.
        raw_label = str(item.get("entity_group") or "").upper()
        label = re.sub(r"^[BI]-", "", raw_label)
        spans.append(
            PIISpan(
                label=label,
                start=int(item["start"]),
                end=int(item["end"]),
                text=item.get("word", ""),
                score=float(item.get("score", 0.0)),
            )
        )
    return spans
class PIIFilter:
    """
    PII detection and redaction on top of a classifier backend.

    Usage:
        pii_filter = PIIFilter.from_model("openai/privacy-filter")
        result = await pii_filter.filter_async(resume_text)
        safe_text = result.redacted_text   # send to cloud LLM
        spans = result.spans               # store for audit trail

    Use from_model() to load a classifier in-process from a model path, or
    from_backend() to wrap a backend object that already exists.
    """

    def __init__(self, backend: FilterBackend) -> None:
        self._backend = backend

    @classmethod
    def from_model(cls, model_path: str) -> "PIIFilter":
        """Build a filter around a classifier created by make_classifier_backend()."""
        backend = make_classifier_backend(model_path)
        return cls(backend)

    @classmethod
    def from_backend(cls, backend: FilterBackend) -> "PIIFilter":
        """Build a filter around an existing FilterBackend."""
        return cls(backend)

    def filter(self, text: str) -> FilterResult:
        """Classify ``text`` synchronously and return the redaction result."""
        return self._to_result(text, self._backend.classify(text))

    async def filter_async(self, text: str) -> FilterResult:
        """Same as filter(), but awaits the backend's ``classify_async``."""
        raw = await self._backend.classify_async(text)
        return self._to_result(text, raw)

    @staticmethod
    def _to_result(text: str, raw: list[dict[str, Any]]) -> FilterResult:
        """Turn raw classifier output for ``text`` into a FilterResult."""
        spans = _spans_from_pipeline(raw)
        return FilterResult(
            redacted_text=_redact(text, spans),
            spans=spans,
            original_text=text,
        )

View file

@ -5,7 +5,7 @@ circuitforge-core is distributed as an editable install from a local clone. It i
## Prerequisites ## Prerequisites
- Python 3.11+ - Python 3.11+
- A Python environment — conda or venv (see options below) - A conda environment (CircuitForge uses `cf` by convention; older envs may be named `job-seeker`)
- The `circuitforge-core` repo cloned alongside your product repo - The `circuitforge-core` repo cloned alongside your product repo
## Typical layout ## Typical layout
@ -21,10 +21,6 @@ circuitforge-core is distributed as an editable install from a local clone. It i
## Install ## Install
### Option A: conda (dev machines)
The CircuitForge conda environment is named `cf`:
```bash ```bash
# From inside a product repo, assuming circuitforge-core is a sibling # From inside a product repo, assuming circuitforge-core is a sibling
conda run -n cf pip install -e ../circuitforge-core conda run -n cf pip install -e ../circuitforge-core
@ -34,29 +30,13 @@ conda activate cf
pip install -e ../circuitforge-core pip install -e ../circuitforge-core
``` ```
### Option B: venv (server and beta-host deployments)
For hosts that don't use conda (CI runners, beta VMs, Xander's orchard nodes):
```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -e /path/to/circuitforge-core
```
Or if cf-core is a sibling directory of the product:
```bash
pip install -e ../circuitforge-core
```
The editable install means changes to circuitforge-core source are reflected immediately in all products without reinstalling. Only restart the product's process after changes (or Docker container if running in Docker). The editable install means changes to circuitforge-core source are reflected immediately in all products without reinstalling. Only restart the product's process after changes (or Docker container if running in Docker).
## Verify ## Verify
```python ```python
import circuitforge_core import circuitforge_core
print(circuitforge_core.__version__) # e.g. 0.21.0 print(circuitforge_core.__version__) # 0.9.0
``` ```
## Inside Docker ## Inside Docker

View file

@ -1,151 +0,0 @@
# circuitforge_core.memory
Persistent knowledge graph for CF products, backed by the
[mnemo](https://github.com/zaydmulani09/mnemo) sidecar.
## What it does
mnemo runs as a sidecar process alongside a product's FastAPI backend. It:
- Extracts named entities and relationships from text you feed it
- Persists them in a local SQLite database with WAL mode
- Returns a formatted context block for prompt injection in under 5ms
`circuitforge_core.memory` wraps mnemo's Python SDK with CF-standard config,
graceful degradation (no-ops when the sidecar is absent), and
exponential backoff with automatic reconnect after transient failures.
## Install
```bash
pip install circuitforge-core[memory]
```
## Docker Compose setup
Add the `mnemo` service to your product's `compose.yml` alongside `ollama`.
Peregrine is the reference implementation — copy the block from
`peregrine/compose.yml`:
```yaml
services:
mnemo:
image: ghcr.io/zaydmulani09/mnemo:latest
ports:
- "${MNEMO_PORT:-8080}:8080"
volumes:
- mnemo-data:/data
environment:
- MNEMO_DB_PATH=/data/mnemo.db
- MNEMO_LLM_PROVIDER=${MNEMO_LLM_PROVIDER:-ollama}
- MNEMO_LLM_BASE_URL=${MNEMO_LLM_BASE_URL:-http://ollama:11434/v1}
- MNEMO_LLM_API_KEY=${MNEMO_LLM_API_KEY:-ollama}
- MNEMO_LLM_MODEL=${MNEMO_LLM_MODEL:-llama3.2:3b}
depends_on:
- ollama
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:8080/health"]
interval: 15s
timeout: 5s
retries: 3
profiles: [memory]
restart: unless-stopped
volumes:
mnemo-data:
```
Add these to the product's api service environment:
```yaml
environment:
- MNEMO_HOST=${MNEMO_HOST:-mnemo}
- MNEMO_PORT=${MNEMO_PORT:-8080}
```
Launch with:
```bash
docker compose --profile memory --profile cpu up -d
# or alongside a GPU profile:
docker compose --profile memory --profile single-gpu up -d
```
## Environment variables
| Variable | Default | Description |
|---|---|---|
| `MNEMO_HOST` | `localhost` | Sidecar hostname (use `mnemo` in Docker) |
| `MNEMO_PORT` | `8080` | Sidecar port |
| `MNEMO_TIMEOUT` | `10.0` | HTTP timeout in seconds |
The sidecar itself is configured via `MNEMO_LLM_*` env vars (see compose block above).
## FastAPI integration
```python
from contextlib import asynccontextmanager
from fastapi import FastAPI
from circuitforge_core.memory import MemoryClient, MemoryConfig
memory = MemoryClient(MemoryConfig.from_env())
@asynccontextmanager
async def lifespan(app: FastAPI):
await memory.connect() # no-op + warning if sidecar absent
yield
await memory.close()
app = FastAPI(lifespan=lifespan)
```
## API
```python
# Store a text fragment (conversation turn, fact, user preference, etc.)
await memory.remember("User avoids shellfish and prefers dark mode", source="settings")
# Retrieve a prompt-ready context block
context = await memory.recall("What are this user's dietary restrictions?")
system_prompt = f"You are a helpful assistant.\n\n{context}"
# List extracted entities
entities = await memory.entities(limit=20)
# Stats snapshot
stats = await memory.stats() # MemoryStats | None
# Wipe everything (irreversible)
await memory.wipe()
```
All methods return empty values (`False`, `""`, `[]`, `None`) when the
sidecar is not available — no try/except needed in product code, unless the
client was created with `strict=True` (see the resilience model below).
## Resilience model
| Event | Behaviour |
|---|---|
| Sidecar absent at startup | `connect()` logs once, enters no-op mode |
| First call failure | Warning logged, 5s backoff scheduled |
| Nth consecutive failure | Backoff doubles each time (5→10→20→40→60s cap) |
| After `_MAX_FAILURES` (3) | Client marked unavailable; all calls no-op |
| Cooldown elapses | Next call silently attempts reconnect |
| Successful call | Failure counter and retry timer reset |
| `strict=True` | `MemoryUnavailableError` raised instead of no-op |
## Chunking note
mnemo stores each `remember()` call as a single chunk — it does **not**
automatically split large texts. For best retrieval quality, chunk on the
caller side before ingesting:
```python
# Good: one turn per ingest call
for turn in conversation_turns:
await memory.remember(turn, source="chat", session_id=session_id)
# Avoid: one giant blob
await memory.remember(entire_conversation_as_one_string)
```

View file

@ -14,9 +14,6 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
memory = [
"mnemo-sdk>=0.1.0",
]
community = [ community = [
"psycopg2>=2.9", "psycopg2>=2.9",
] ]

View file

@ -1,281 +0,0 @@
"""Tests for circuitforge_core.memory.
These tests mock the mnemo SDK so no live sidecar is required.
"""
from __future__ import annotations
import sys
import time
from types import ModuleType
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from circuitforge_core.memory import MemoryClient, MemoryConfig, MemoryUnavailableError
from circuitforge_core.memory.client import _MAX_FAILURES
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_mock_mnemo(health_ok: bool = True):
"""Return a (mock_module, mock_inner_client) pair."""
mock_health = MagicMock(
status="ok" if health_ok else "error",
provider_type="ollama",
provider_model="llama3",
)
mock_client = AsyncMock()
mock_client.health = AsyncMock(return_value=mock_health)
mock_client.ingest = AsyncMock(return_value=MagicMock(chunk_id="abc", entities_extracted=2))
mock_client.get_context = AsyncMock(return_value="Relevant context: user prefers dark mode")
mock_client.list_entities = AsyncMock(return_value=[])
mock_client.stats = AsyncMock(return_value=MagicMock(
entity_count=5, chunk_count=10, node_count=5, edge_count=3, uptime_seconds=120.0
))
mock_client.wipe = AsyncMock(return_value=None)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_module = ModuleType("mnemo")
mock_module.AsyncMnemoClient = MagicMock(return_value=mock_client)
return mock_module, mock_client
async def _connected(health_ok: bool = True):
    """Create a MemoryClient and await connect() against the fake mnemo module.

    The fake inner client is attached as ``_mock_inner`` so tests can inspect
    or reconfigure it afterwards.
    """
    fake_module, inner = _make_mock_mnemo(health_ok=health_ok)
    memory_client = MemoryClient(MemoryConfig())
    with patch.dict(sys.modules, {"mnemo": fake_module}):
        await memory_client.connect()
    memory_client._mock_inner = inner
    return memory_client
# ── Config ────────────────────────────────────────────────────────────────────
class TestMemoryConfig:
    """MemoryConfig defaults, environment loading and base-URL composition."""

    def test_defaults(self):
        config = MemoryConfig()
        assert config.host == "localhost"
        assert config.port == 8080
        assert config.base_url == "http://localhost:8080"

    def test_from_env(self, monkeypatch):
        overrides = {
            "MNEMO_HOST": "mnemo-sidecar",
            "MNEMO_PORT": "9090",
            "MNEMO_TIMEOUT": "30.0",
        }
        for name, value in overrides.items():
            monkeypatch.setenv(name, value)
        config = MemoryConfig.from_env()
        assert config.host == "mnemo-sidecar"
        assert config.port == 9090
        assert config.timeout == 30.0

    def test_base_url(self):
        config = MemoryConfig(host="10.1.10.5", port=8080)
        assert config.base_url == "http://10.1.10.5:8080"
# ── connect() ─────────────────────────────────────────────────────────────────
class TestConnect:
    """Outcomes of MemoryClient.connect() against a mocked mnemo SDK."""

    @pytest.mark.asyncio
    async def test_connect_success(self):
        client = await _connected(health_ok=True)
        assert client.available is True
        assert client.failure_count == 0

    @pytest.mark.asyncio
    async def test_connect_bad_health_status(self):
        # The mocked health() reports status "error" here.
        client = await _connected(health_ok=False)
        assert client.available is False

    @pytest.mark.asyncio
    async def test_connect_sidecar_unreachable(self):
        mock_module, mock_client = _make_mock_mnemo()
        mock_client.health.side_effect = ConnectionRefusedError("refused")
        client = MemoryClient(MemoryConfig())
        with patch.dict(sys.modules, {"mnemo": mock_module}):
            await client.connect()  # must not raise
        assert client.available is False

    @pytest.mark.asyncio
    async def test_connect_strict_raises(self):
        mock_module, mock_client = _make_mock_mnemo()
        mock_client.health.side_effect = ConnectionRefusedError("refused")
        client = MemoryClient(MemoryConfig(), strict=True)
        with patch.dict(sys.modules, {"mnemo": mock_module}):
            with pytest.raises(MemoryUnavailableError):
                await client.connect()

    @pytest.mark.asyncio
    async def test_connect_missing_sdk(self):
        client = MemoryClient(MemoryConfig())
        # A None entry in sys.modules makes `import mnemo` raise ImportError.
        with patch.dict(sys.modules, {"mnemo": None}):
            await client.connect()
        assert client.available is False
# ── No-op when unavailable ────────────────────────────────────────────────────
class TestNoopWhenUnavailable:
    """A client that never connected returns empty values from every method."""

    @pytest.fixture
    def unavailable(self):
        # connect() is never awaited on this instance.
        return MemoryClient(MemoryConfig())

    @pytest.mark.asyncio
    async def test_remember_noop(self, unavailable):
        stored = await unavailable.remember("text")
        assert stored is False

    @pytest.mark.asyncio
    async def test_recall_noop(self, unavailable):
        context = await unavailable.recall("query")
        assert context == ""

    @pytest.mark.asyncio
    async def test_entities_noop(self, unavailable):
        found = await unavailable.entities()
        assert found == []

    @pytest.mark.asyncio
    async def test_stats_noop(self, unavailable):
        snapshot = await unavailable.stats()
        assert snapshot is None

    @pytest.mark.asyncio
    async def test_wipe_noop(self, unavailable):
        wiped = await unavailable.wipe()
        assert wiped is False
# ── Live calls when connected ─────────────────────────────────────────────────
class TestLiveCalls:
    """Calls on a connected client are forwarded to the mocked mnemo client."""

    @pytest.mark.asyncio
    async def test_remember_calls_ingest(self):
        client = await _connected()
        result = await client.remember("hello world", source="test")
        assert result is True
        # remember() must forward to ingest() with exactly these keyword arguments.
        client._mock_inner.ingest.assert_awaited_once_with(
            content="hello world", source="test", session_id=None
        )

    @pytest.mark.asyncio
    async def test_remember_resets_failure_count(self):
        client = await _connected()
        client._failure_count = 2  # simulate prior failures
        await client.remember("text")
        assert client.failure_count == 0

    @pytest.mark.asyncio
    async def test_recall_returns_context(self):
        client = await _connected()
        ctx = await client.recall("dark mode preference")
        # The mocked get_context() returns a string containing "dark mode".
        assert "dark mode" in ctx

    @pytest.mark.asyncio
    async def test_recall_with_session(self):
        client = await _connected()
        await client.recall("query", session_id="user-123")
        client._mock_inner.get_context.assert_awaited_once_with(
            text="query", session_id="user-123"
        )

    @pytest.mark.asyncio
    async def test_stats_returns_memory_stats(self):
        from circuitforge_core.memory import MemoryStats
        client = await _connected()
        result = await client.stats()
        assert isinstance(result, MemoryStats)
        assert result.available is True
        # Matches entity_count=5 configured on the mocked stats() result.
        assert result.entity_count == 5
# ── Backoff and reconnect ─────────────────────────────────────────────────────
class TestBackoffAndReconnect:
    """Failure counting, retry scheduling and reconnect behaviour of MemoryClient.

    Several tests set the private ``_available``, ``_failure_count`` and
    ``_retry_at`` attributes directly to put the client into a known state.
    """

    @pytest.mark.asyncio
    async def test_failure_count_increments(self):
        client = await _connected()
        client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
        await client.remember("text")
        assert client.failure_count == 1

    @pytest.mark.asyncio
    async def test_client_disabled_after_max_failures(self):
        client = await _connected()
        client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
        # drive failures to the limit
        for _ in range(_MAX_FAILURES):
            await client.remember("text")
        assert client.available is False

    @pytest.mark.asyncio
    async def test_retry_at_set_after_failure(self):
        client = await _connected()
        client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
        before = time.monotonic()
        await client.remember("text")
        assert client._retry_at is not None
        assert client._retry_at > before

    @pytest.mark.asyncio
    async def test_backoff_increases_with_failures(self):
        client = await _connected()
        client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
        retry_times = []
        # All cooldowns are measured against the same t0 so they are comparable.
        t0 = time.monotonic()
        for _ in range(3):
            await client.remember("text")
            retry_times.append(client._retry_at - t0)
        # Each cooldown should be longer than the previous
        assert retry_times[1] > retry_times[0]
        assert retry_times[2] > retry_times[1]

    @pytest.mark.asyncio
    async def test_reconnect_attempted_after_cooldown(self):
        """Once the retry window elapses, the next call triggers a reconnect."""
        client = await _connected()
        # Force unavailable with an expired retry window
        client._available = False
        client._retry_at = time.monotonic() - 1.0  # already elapsed
        mock_module, mock_inner = _make_mock_mnemo(health_ok=True)
        with patch.dict(sys.modules, {"mnemo": mock_module}):
            result = await client.remember("text after reconnect")
        # Reconnect should have restored availability
        assert client.available is True
        assert result is True

    @pytest.mark.asyncio
    async def test_no_reconnect_during_cooldown(self):
        """Within the cooldown window, calls no-op without attempting reconnect."""
        client = await _connected()
        client._available = False
        client._retry_at = time.monotonic() + 999.0  # far in the future
        mock_module, _ = _make_mock_mnemo(health_ok=True)
        with patch.dict(sys.modules, {"mnemo": mock_module}):
            result = await client.remember("text during cooldown")
        assert result is False
        assert client.available is False  # no reconnect fired

    @pytest.mark.asyncio
    async def test_success_resets_retry_state(self):
        """A successful call clears failure_count and retry_at."""
        client = await _connected()
        client._failure_count = 2
        client._retry_at = time.monotonic() + 30.0
        await client.remember("successful call")
        assert client.failure_count == 0
        assert client._retry_at is None

    @pytest.mark.asyncio
    async def test_strict_raises_after_max_failures(self):
        """strict=True raises MemoryUnavailableError once failure threshold is hit."""
        client = await _connected()
        # Strict mode is switched on after connecting, via the private flag.
        client._strict = True
        client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
        with pytest.raises(MemoryUnavailableError):
            for _ in range(_MAX_FAILURES):
                await client.remember("text")

View file

@ -1,78 +0,0 @@
"""Tests for MQTT topic wildcard matching in circuitforge_core.mqtt.router."""
import pytest
# NOTE: matches() currently raises NotImplementedError — tests will fail
# until you implement it. Run these to verify correctness once implemented.
def _matches(pattern: str, topic: str) -> bool:
    """Call router.matches(), importing it lazily at call time.

    The import lives inside the function so that collecting this test module
    does not require the router module to be importable.
    """
    from circuitforge_core.mqtt.router import matches as topic_matches

    return topic_matches(pattern, topic)
class TestExactMatch:
    """Patterns without wildcards match only the identical topic."""

    def test_exact(self):
        topic = "a/b/c"
        assert _matches("a/b/c", topic)

    def test_no_match(self):
        matched = _matches("a/b/c", "a/b/d")
        assert not matched

    def test_empty_topic(self):
        # An empty pattern matches an empty topic.
        assert _matches("", "")
class TestSingleLevelWildcard:
    """Patterns containing the single-level wildcard ``+``."""

    def test_plus_middle(self):
        pattern = "sensor/+/temp"
        assert _matches(pattern, "sensor/room1/temp")

    def test_plus_no_match_extra_level(self):
        pattern = "sensor/+/temp"
        assert not _matches(pattern, "sensor/a/b/temp")

    def test_plus_start(self):
        pattern = "+/b/c"
        assert _matches(pattern, "a/b/c")

    def test_plus_end(self):
        pattern = "a/b/+"
        assert _matches(pattern, "a/b/anything")

    def test_multiple_plus(self):
        pattern = "+/+/+"
        assert _matches(pattern, "x/y/z")

    def test_plus_no_match_empty_segment(self):
        # '+' must match exactly one level, so a lone '+' must not match the
        # two-level topic "a/b". The result has to be the literal False.
        assert _matches("+", "a/b") is False
class TestMultiLevelWildcard:
    """Patterns containing the multi-level wildcard ``#``."""

    def test_hash_root(self):
        topic = "a/b/c"
        assert _matches("#", topic)

    def test_hash_prefix(self):
        topic = "sensor/room1/temp"
        assert _matches("sensor/#", topic)

    def test_hash_zero_levels(self):
        # '#' matches zero or more levels — "sensor/#" should match "sensor"
        topic = "sensor"
        assert _matches("sensor/#", topic)

    def test_hash_must_be_last(self):
        # '#' in the middle is invalid MQTT; the exact behaviour is
        # implementation-defined, so no result is asserted here.
        try:
            _matches("sensor/#/foo", "sensor/bar/foo")
        except Exception:
            # either False or ValueError is acceptable
            return

    def test_hash_only(self):
        topic = "anything"
        assert _matches("#", topic)

    def test_hash_no_match_different_prefix(self):
        matched = _matches("sensor/#", "actuator/fan")
        assert not matched
class TestMixedWildcards:
    """Patterns that combine ``+`` and ``#``."""

    def test_plus_and_hash(self):
        topic = "msh/us-west/node1/json/TEXT_MESSAGE_APP/!deadbeef"
        assert _matches("msh/+/#", topic)

    def test_plus_before_hash(self):
        topic = "region/any/nested/topic"
        assert _matches("+/#", topic)

View file

@ -1,151 +0,0 @@
# tests/test_text/test_classifier.py — PII filter backend and endpoint tests
import pytest
from httpx import AsyncClient, ASGITransport
from circuitforge_core.text.backends.mock import MockClassifierBackend
from circuitforge_core.text.filter import PIIFilter, PIISpan, FilterResult, _redact, _spans_from_pipeline
# ── Unit: _spans_from_pipeline ────────────────────────────────────────────────
def test_spans_from_pipeline_normalises_bio_prefix():
    """A leading ``B-`` prefix is removed from the label."""
    item = {"entity_group": "B-NAME", "score": 0.9, "word": "Alice", "start": 0, "end": 5}
    (span,) = _spans_from_pipeline([item])[:1]
    assert span.label == "NAME"
def test_spans_from_pipeline_uppercase():
    """A lower-case label comes back upper-cased."""
    item = {"entity_group": "email", "score": 0.8, "word": "a@b.com", "start": 10, "end": 17}
    converted = _spans_from_pipeline([item])
    assert converted[0].label == "EMAIL"
def test_spans_from_pipeline_returns_typed_objects():
    """Items are converted to PIISpan instances carrying score and offsets."""
    item = {"entity_group": "PHONE_NUM", "score": 0.95, "word": "555-1234", "start": 5, "end": 13}
    first = _spans_from_pipeline([item])[0]
    assert isinstance(first, PIISpan)
    assert first.score == pytest.approx(0.95)
    assert first.start == 5
    assert first.end == 13
# ── Unit: _redact ─────────────────────────────────────────────────────────────
def test_redact_replaces_spans():
    """Two separate spans are each replaced by their ``[LABEL]``."""
    name_span = PIISpan(label="NAME", start=5, end=10, text="Alice", score=0.99)
    phone_span = PIISpan(label="PHONE_NUM", start=14, end=22, text="555-1234", score=0.97)
    redacted = _redact("Call Alice at 555-1234 now", [name_span, phone_span])
    assert redacted == "Call [NAME] at [PHONE_NUM] now"
def test_redact_handles_overlapping_order():
    """Redacting one span must not invalidate the offsets of the other.

    Note: the two spans used here are distinct ranges (0-8 and 9-25); they
    do not actually overlap.
    """
    source = "Jane Doe jane@example.com"
    name_span = PIISpan(label="NAME", start=0, end=8, text="Jane Doe", score=0.99)
    email_span = PIISpan(label="EMAIL", start=9, end=25, text="jane@example.com", score=0.97)
    redacted = _redact(source, [name_span, email_span])
    for expected in ("[NAME]", "[EMAIL]"):
        assert expected in redacted
    for leaked in ("Jane Doe", "jane@example.com"):
        assert leaked not in redacted
def test_redact_no_spans_returns_original():
    """With no spans the text comes back unchanged."""
    source = "No PII here"
    redacted = _redact(source, [])
    assert redacted == source
# ── Unit: PIIFilter with MockClassifierBackend ────────────────────────────────
def test_pii_filter_sync():
    """filter() redacts both canned spans supplied by the mock backend."""
    # Mock backend returns spans for "Jane Doe" at 0-8 and "jane@example.com" at 18-34
    pii_filter = PIIFilter.from_backend(MockClassifierBackend())
    outcome = pii_filter.filter("Jane Doe emailed jane@example.com today")
    assert isinstance(outcome, FilterResult)
    assert "[NAME]" in outcome.redacted_text
    assert "[EMAIL]" in outcome.redacted_text
    assert len(outcome.spans) == 2
def test_pii_filter_preserves_original_text():
    """The unmodified input is kept on the result as ``original_text``."""
    source = "Jane Doe emailed jane@example.com today"
    pii_filter = PIIFilter.from_backend(MockClassifierBackend())
    outcome = pii_filter.filter(source)
    assert outcome.original_text == source
@pytest.mark.asyncio
async def test_pii_filter_async():
    """filter_async() yields the same redaction as the sync path."""
    pii_filter = PIIFilter.from_backend(MockClassifierBackend())
    outcome = await pii_filter.filter_async("Jane Doe emailed jane@example.com today")
    assert "[NAME]" in outcome.redacted_text
    assert len(outcome.spans) == 2
def test_pii_filter_result_is_frozen():
    """Assigning to a field of the returned FilterResult must raise."""
    outcome = PIIFilter.from_backend(MockClassifierBackend()).filter("test")
    with pytest.raises((AttributeError, TypeError)):
        outcome.redacted_text = "mutated"  # type: ignore[misc]
# ── Integration: /filter HTTP endpoint ───────────────────────────────────────
@pytest.fixture
def classifier_app(monkeypatch):
    """cf-text app in classifier mode using mock backend.

    Sets ``CF_TEXT_MOCK`` and ``CF_TEXT_BACKEND``, reloads the app module
    with those values in the environment, and returns a freshly created app.
    """
    monkeypatch.setenv("CF_TEXT_MOCK", "1")
    monkeypatch.setenv("CF_TEXT_BACKEND", "classifier")

    import importlib

    import circuitforge_core.text.app as app_mod

    # Reload so the module is re-executed with the variables set above.
    importlib.reload(app_mod)
    # NOTE(review): mock=False is passed while CF_TEXT_MOCK=1 — presumably the
    # environment variable selects the mock backend; confirm against create_app.
    # No manual teardown: monkeypatch restores the environment on its own.
    return app_mod.create_app(model_path="openai/privacy-filter", backend="classifier", mock=False)
@pytest.mark.asyncio
async def test_filter_endpoint_returns_redacted(classifier_app):
    """POST /filter responds 200 with redacted text and both spans."""
    transport = ASGITransport(app=classifier_app)
    payload = {"text": "Jane Doe emailed jane@example.com today"}
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/filter", json=payload)
    assert resp.status_code == 200
    body = resp.json()
    assert "[NAME]" in body["redacted_text"]
    assert "[EMAIL]" in body["redacted_text"]
    assert len(body["spans"]) == 2
@pytest.mark.asyncio
async def test_filter_endpoint_includes_original(classifier_app):
    """The /filter response echoes the submitted text as ``original_text``."""
    source = "Jane Doe emailed jane@example.com today"
    transport = ASGITransport(app=classifier_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/filter", json={"text": source})
    assert resp.json()["original_text"] == source
@pytest.mark.asyncio
async def test_generate_returns_501_in_classifier_mode(classifier_app):
    """POST /generate answers 501 when the app runs in classifier mode."""
    transport = ASGITransport(app=classifier_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/generate", json={"prompt": "hello"})
    assert resp.status_code == 501
@pytest.mark.asyncio
async def test_health_reports_classifier_backend(classifier_app):
    """GET /health responds 200 and names ``classifier`` as the backend."""
    transport = ASGITransport(app=classifier_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.get("/health")
    assert resp.status_code == 200
    assert resp.json()["backend"] == "classifier"