Compare commits
No commits in common. "a92a83db4bcd414b26fd623fb3a5a21e00579e99" and "5a363f3b6cddea1c69bd420e6eeba304514086c4" have entirely different histories.
a92a83db4b
...
5a363f3b6c
30 changed files with 62 additions and 2664 deletions
|
|
@ -1,11 +1,4 @@
|
||||||
name: Release — PyPI + Forgejo Packages
|
name: Release — PyPI
|
||||||
|
|
||||||
# circuitforge-core is MIT — published to both public PyPI and the Circuit-Forge
|
|
||||||
# Forgejo Packages index so cf-orch can resolve it from a single --extra-index-url.
|
|
||||||
#
|
|
||||||
# Required secrets:
|
|
||||||
# PYPI_API_TOKEN — public PyPI upload token
|
|
||||||
# FORGEJO_PYPI_TOKEN — Forgejo token with package:write scope
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
|
@ -26,36 +19,29 @@ jobs:
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
pip install build twine
|
pip install build
|
||||||
python -m build
|
python -m build
|
||||||
|
|
||||||
- name: Publish to public PyPI
|
- name: Publish to PyPI
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
with:
|
with:
|
||||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
|
||||||
- name: Publish to Forgejo Packages
|
|
||||||
env:
|
|
||||||
TWINE_USERNAME: pypi-token
|
|
||||||
TWINE_PASSWORD: ${{ secrets.FORGEJO_PYPI_TOKEN }}
|
|
||||||
TWINE_REPOSITORY_URL: https://git.opensourcesolarpunk.com/api/packages/Circuit-Forge/pypi
|
|
||||||
run: twine upload dist/*
|
|
||||||
|
|
||||||
- name: Create Forgejo release
|
- name: Create Forgejo release
|
||||||
env:
|
env:
|
||||||
FORGEJO_TOKEN: ${{ secrets.FORGEJO_PYPI_TOKEN }}
|
FORGEJO_TOKEN: ${{ secrets.FORGEJO_RELEASE_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
TAG="${GITHUB_REF_NAME}"
|
TAG="${GITHUB_REF_NAME}"
|
||||||
|
# Check if release already exists for this tag
|
||||||
EXISTING=$(curl -sf \
|
EXISTING=$(curl -sf \
|
||||||
-H "Authorization: token ${FORGEJO_TOKEN}" \
|
-H "Authorization: token ${FORGEJO_TOKEN}" \
|
||||||
"https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases/tags/${TAG}" \
|
"https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases/tags/${TAG}" \
|
||||||
2>/dev/null \
|
2>/dev/null | jq -r '.id // empty')
|
||||||
| python3 -c "import sys,json; print(json.load(sys.stdin).get('id',''))" 2>/dev/null || true)
|
|
||||||
if [ -z "${EXISTING}" ]; then
|
if [ -z "${EXISTING}" ]; then
|
||||||
python3 -c "
|
jq -n --arg tag "${TAG}" \
|
||||||
import json
|
'{"tag_name":$tag,"name":$tag,"draft":false,"prerelease":false}' \
|
||||||
print(json.dumps({'tag_name':'${TAG}','name':'${TAG}','draft':False,'prerelease':False}))
|
| curl -sf -X POST \
|
||||||
" | curl -sf -X POST \
|
|
||||||
-H "Authorization: token ${FORGEJO_TOKEN}" \
|
-H "Authorization: token ${FORGEJO_TOKEN}" \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
"https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases" \
|
"https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases" \
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,4 @@
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
__version__ = "0.18.0"
|
||||||
|
|
||||||
try:
|
|
||||||
__version__ = version("circuitforge-core")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
__version__ = "dev" # running from source without an editable install
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
|
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
|
||||||
|
|
|
||||||
|
|
@ -39,13 +39,6 @@ from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
try:
|
|
||||||
from starlette.requests import Request as _Request
|
|
||||||
from starlette.responses import Response as _Response
|
|
||||||
except ImportError: # pragma: no cover — starlette may be absent in non-web envs
|
|
||||||
_Request = Any # type: ignore[assignment,misc]
|
|
||||||
_Response = Any # type: ignore[assignment,misc]
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
TIERS: list[str] = ["free", "paid", "premium", "ultra"]
|
TIERS: list[str] = ["free", "paid", "premium", "ultra"]
|
||||||
|
|
@ -255,40 +248,22 @@ class CloudSessionFactory:
|
||||||
request.headers.get("x-real-ip", "")
|
request.headers.get("x-real-ip", "")
|
||||||
or (request.client.host if request.client else "")
|
or (request.client.host if request.client else "")
|
||||||
)
|
)
|
||||||
is_bypass = _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips)
|
if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips):
|
||||||
|
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
|
||||||
|
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
||||||
|
|
||||||
raw_session = (
|
raw_session = (
|
||||||
request.headers.get("x-cf-session", "").strip()
|
request.headers.get("x-cf-session", "").strip()
|
||||||
or request.cookies.get("cf_session", "").strip()
|
or request.cookies.get("cf_session", "").strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bypass IPs skip the JWT *requirement* but not JWT *validation*.
|
|
||||||
# If a token is present (dev is logged in), honour it so they land on
|
|
||||||
# their own account DB rather than the shared local-dev DB.
|
|
||||||
if not raw_session:
|
if not raw_session:
|
||||||
if is_bypass:
|
|
||||||
log.debug("Bypass IP %s, no token — returning local-dev session for product %s", client_ip, self.product)
|
|
||||||
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
|
||||||
return self._resolve_guest(request, response)
|
return self._resolve_guest(request, response)
|
||||||
|
|
||||||
token = _extract_session_token(raw_session)
|
token = _extract_session_token(raw_session)
|
||||||
if not token:
|
if not token:
|
||||||
return self._resolve_guest(request, response)
|
return self._resolve_guest(request, response)
|
||||||
|
|
||||||
# Soft-fail on invalid/expired JWT: downgrade to guest rather than
|
user_id = self.validate_jwt(token)
|
||||||
# hard-erroring with 401. Public endpoints (e.g. community blocklist)
|
|
||||||
# should remain accessible even when the browser has a stale cookie.
|
|
||||||
# Routes that genuinely require an authenticated identity should gate
|
|
||||||
# themselves with require_tier() — that's where the 401/403 belongs.
|
|
||||||
try:
|
|
||||||
user_id = self.validate_jwt(token)
|
|
||||||
except Exception:
|
|
||||||
log.warning(
|
|
||||||
"JWT validation failed for product %s (expired or tampered) — falling back to guest",
|
|
||||||
self.product,
|
|
||||||
)
|
|
||||||
return self._resolve_guest(request, response)
|
|
||||||
|
|
||||||
self._ensure_provisioned(user_id)
|
self._ensure_provisioned(user_id)
|
||||||
tier_data = self._resolve_tier(user_id)
|
tier_data = self._resolve_tier(user_id)
|
||||||
tier = tier_data.get("tier", "free")
|
tier = tier_data.get("tier", "free")
|
||||||
|
|
@ -308,11 +283,11 @@ class CloudSessionFactory:
|
||||||
meta=meta,
|
meta=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dependency(self) -> Callable[["_Request", "_Response"], CloudUser]:
|
def dependency(self) -> Callable[[Any, Any], CloudUser]:
|
||||||
"""Return a FastAPI-compatible dependency function (use with Depends())."""
|
"""Return a FastAPI-compatible dependency function (use with Depends())."""
|
||||||
factory = self
|
factory = self
|
||||||
|
|
||||||
def _get_session(request: _Request, response: _Response) -> CloudUser:
|
def _get_session(request: Any, response: Any) -> CloudUser:
|
||||||
return factory.resolve(request, response)
|
return factory.resolve(request, response)
|
||||||
|
|
||||||
return _get_session
|
return _get_session
|
||||||
|
|
|
||||||
|
|
@ -1,54 +0,0 @@
|
||||||
"""circuitforge_core.memory — persistent knowledge graph via mnemo sidecar.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
|
|
||||||
Requires the mnemo sidecar to be running (https://github.com/zaydmulani09/mnemo).
|
|
||||||
If the sidecar is not available, all operations silently no-op so products
|
|
||||||
can call memory methods unconditionally.
|
|
||||||
|
|
||||||
Quick start (in a FastAPI lifespan)::
|
|
||||||
|
|
||||||
from circuitforge_core.memory import MemoryClient, MemoryConfig
|
|
||||||
|
|
||||||
memory = MemoryClient(MemoryConfig.from_env())
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app):
|
|
||||||
await memory.connect()
|
|
||||||
yield
|
|
||||||
await memory.close()
|
|
||||||
|
|
||||||
# In a route:
|
|
||||||
await memory.remember("User avoids shellfish", source="dietary-prefs")
|
|
||||||
context = await memory.recall("What are this user's food restrictions?")
|
|
||||||
|
|
||||||
Docker Compose setup::
|
|
||||||
|
|
||||||
services:
|
|
||||||
mnemo:
|
|
||||||
image: ghcr.io/zaydmulani09/mnemo:latest
|
|
||||||
ports: ["8080:8080"]
|
|
||||||
environment:
|
|
||||||
MNEMO_LLM_PROVIDER: ollama
|
|
||||||
MNEMO_LLM_BASE_URL: http://ollama:11434/v1
|
|
||||||
MNEMO_LLM_MODEL: llama3
|
|
||||||
volumes:
|
|
||||||
- mnemo-data:/data
|
|
||||||
|
|
||||||
Environment variables (for MemoryConfig.from_env())::
|
|
||||||
|
|
||||||
MNEMO_HOST — default: localhost
|
|
||||||
MNEMO_PORT — default: 8080
|
|
||||||
MNEMO_TIMEOUT — default: 10.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
from circuitforge_core.memory.client import MemoryClient, MemoryUnavailableError
|
|
||||||
from circuitforge_core.memory.models import MemoryConfig, MemoryEntity, MemoryStats
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MemoryClient",
|
|
||||||
"MemoryConfig",
|
|
||||||
"MemoryEntity",
|
|
||||||
"MemoryStats",
|
|
||||||
"MemoryUnavailableError",
|
|
||||||
]
|
|
||||||
|
|
@ -1,317 +0,0 @@
|
||||||
"""MemoryClient — async wrapper around the mnemo persistent knowledge graph.
|
|
||||||
|
|
||||||
mnemo is an optional sidecar (https://github.com/zaydmulani09/mnemo).
|
|
||||||
When the sidecar is not running, all operations silently no-op so products
|
|
||||||
can call memory methods unconditionally without try/except.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from circuitforge_core.memory.models import MemoryConfig, MemoryEntity, MemoryStats
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Backoff schedule: 5 * 2^(failure-1), capped at _MAX_BACKOFF seconds.
|
|
||||||
# failure 1 → 5s, 2 → 10s, 3 → 20s, 4 → 40s, 5+ → 60s
|
|
||||||
_MAX_FAILURES: int = 3
|
|
||||||
_MAX_BACKOFF: float = 60.0
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryUnavailableError(RuntimeError):
|
|
||||||
"""Raised only when strict=True and mnemo is not reachable."""
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryClient:
|
|
||||||
"""Async interface to the mnemo knowledge graph sidecar.
|
|
||||||
|
|
||||||
Resilience model:
|
|
||||||
- If the sidecar is unreachable at connect(), logs once and enters no-op mode.
|
|
||||||
- If a live call fails, the failure is counted. Each failure schedules an
|
|
||||||
exponentially increasing cooldown before the next reconnect attempt.
|
|
||||||
- After _MAX_FAILURES consecutive failures the client is marked unavailable;
|
|
||||||
all calls no-op until the cooldown elapses and a reconnect succeeds.
|
|
||||||
- Any successful call resets the failure counter.
|
|
||||||
|
|
||||||
Usage (in a FastAPI lifespan)::
|
|
||||||
|
|
||||||
from circuitforge_core.memory import MemoryClient, MemoryConfig
|
|
||||||
|
|
||||||
memory = MemoryClient(MemoryConfig.from_env())
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app):
|
|
||||||
await memory.connect()
|
|
||||||
yield
|
|
||||||
await memory.close()
|
|
||||||
|
|
||||||
Then in handlers::
|
|
||||||
|
|
||||||
await memory.remember("User prefers dark mode", source="settings")
|
|
||||||
context = await memory.recall("What are the user's UI preferences?")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: MemoryConfig | None = None, *, strict: bool = False) -> None:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
config: connection settings; defaults to MemoryConfig.from_env()
|
|
||||||
strict: if True, MemoryUnavailableError is raised on connect failure
|
|
||||||
or after _MAX_FAILURES consecutive call failures
|
|
||||||
"""
|
|
||||||
self._config = config or MemoryConfig.from_env()
|
|
||||||
self._strict = strict
|
|
||||||
self._available = False
|
|
||||||
self._client: Any = None # mnemo AsyncMnemoClient, set in connect()
|
|
||||||
self._failure_count: int = 0
|
|
||||||
self._retry_at: float | None = None # monotonic timestamp; None = no retry pending
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available(self) -> bool:
|
|
||||||
"""True if the mnemo sidecar was reachable at last health check."""
|
|
||||||
return self._available
|
|
||||||
|
|
||||||
@property
|
|
||||||
def failure_count(self) -> int:
|
|
||||||
"""Consecutive call failures since the last success."""
|
|
||||||
return self._failure_count
|
|
||||||
|
|
||||||
# ── Lifecycle ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def connect(self) -> None:
|
|
||||||
"""Attempt to connect to the mnemo sidecar and run a health check.
|
|
||||||
|
|
||||||
Safe to call multiple times (used internally for reconnect). If the
|
|
||||||
sidecar is not reachable, logs a warning and enters no-op mode.
|
|
||||||
Does NOT raise unless strict=True.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from mnemo import AsyncMnemoClient
|
|
||||||
except ImportError:
|
|
||||||
logger.debug(
|
|
||||||
"mnemo-sdk not installed — memory module disabled. "
|
|
||||||
"Install with: pip install circuitforge-core[memory]"
|
|
||||||
)
|
|
||||||
self._available = False
|
|
||||||
return
|
|
||||||
|
|
||||||
self._client = AsyncMnemoClient(
|
|
||||||
base_url=self._config.base_url,
|
|
||||||
timeout=self._config.timeout,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
health = await self._client.health()
|
|
||||||
if health.status == "ok":
|
|
||||||
self._available = True
|
|
||||||
self._on_call_success()
|
|
||||||
logger.info(
|
|
||||||
"mnemo memory sidecar connected at %s (LLM: %s/%s)",
|
|
||||||
self._config.base_url,
|
|
||||||
health.provider_type,
|
|
||||||
health.provider_model,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._handle_unavailable("connect", reason=f"health status={health.status!r}")
|
|
||||||
except Exception as exc:
|
|
||||||
self._handle_unavailable("connect", reason=str(exc))
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the underlying HTTP client."""
|
|
||||||
if self._client is not None:
|
|
||||||
try:
|
|
||||||
await self._client.__aexit__(None, None, None)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._client = None
|
|
||||||
self._available = False
|
|
||||||
self._retry_at = None
|
|
||||||
|
|
||||||
# ── Core API ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def remember(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
*,
|
|
||||||
source: str = "cf-core",
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Store a text fragment in the knowledge graph.
|
|
||||||
|
|
||||||
mnemo extracts named entities and relationships from the text and
|
|
||||||
updates its graph. Large texts should be pre-chunked by the caller
|
|
||||||
(mnemo stores each call as a single chunk with no sub-splitting).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: the text to store (conversation turn, fact, note, etc.)
|
|
||||||
source: label for the origin (e.g. "chat", "settings", "search")
|
|
||||||
session_id: optional session grouping for multi-turn retrieval
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if stored, False if sidecar unavailable.
|
|
||||||
"""
|
|
||||||
if not await self._maybe_reconnect():
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
await self._client.ingest(content=text, source=source, session_id=session_id)
|
|
||||||
self._on_call_success()
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
self._on_call_error("remember", exc)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def recall(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Retrieve a formatted context block relevant to query.
|
|
||||||
|
|
||||||
Returns a prompt-ready string (or empty string if unavailable).
|
|
||||||
Inject the result directly into a system prompt::
|
|
||||||
|
|
||||||
context = await memory.recall("user dietary restrictions")
|
|
||||||
system = f"You are a helpful assistant.\\n\\n{context}"
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: natural language question or topic to retrieve context for
|
|
||||||
session_id: restrict retrieval to a specific session (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted context string, or "" if sidecar unavailable.
|
|
||||||
"""
|
|
||||||
if not await self._maybe_reconnect():
|
|
||||||
return ""
|
|
||||||
try:
|
|
||||||
result = await self._client.get_context(text=query, session_id=session_id)
|
|
||||||
self._failure_count = 0
|
|
||||||
return result
|
|
||||||
except Exception as exc:
|
|
||||||
self._on_call_error("recall", exc)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
async def entities(self, *, limit: int = 50) -> list[MemoryEntity]:
|
|
||||||
"""Return the most recent named entities in the knowledge graph.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit: max entities to return (default 50)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of MemoryEntity objects, or [] if unavailable.
|
|
||||||
"""
|
|
||||||
if not await self._maybe_reconnect():
|
|
||||||
return []
|
|
||||||
try:
|
|
||||||
raw = await self._client.list_entities(limit=limit)
|
|
||||||
self._on_call_success()
|
|
||||||
return [MemoryEntity.from_mnemo(e) for e in raw]
|
|
||||||
except Exception as exc:
|
|
||||||
self._on_call_error("entities", exc)
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def stats(self) -> MemoryStats | None:
|
|
||||||
"""Return knowledge graph statistics, or None if unavailable."""
|
|
||||||
if not await self._maybe_reconnect():
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
s = await self._client.stats()
|
|
||||||
self._on_call_success()
|
|
||||||
return MemoryStats(
|
|
||||||
entity_count=s.entity_count,
|
|
||||||
chunk_count=s.chunk_count,
|
|
||||||
node_count=s.node_count,
|
|
||||||
edge_count=s.edge_count,
|
|
||||||
uptime_seconds=s.uptime_seconds,
|
|
||||||
available=True,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
self._on_call_error("stats", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def wipe(self) -> bool:
|
|
||||||
"""Delete all stored memory. Irreversible.
|
|
||||||
|
|
||||||
Returns True on success, False if unavailable or failed.
|
|
||||||
"""
|
|
||||||
if not await self._maybe_reconnect():
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
await self._client.wipe()
|
|
||||||
self._on_call_success()
|
|
||||||
logger.warning("mnemo memory wiped — all entities and chunks deleted")
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
self._on_call_error("wipe", exc)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# ── Internal ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _maybe_reconnect(self) -> bool:
|
|
||||||
"""Return True if the client is available (or just became available).
|
|
||||||
|
|
||||||
Called at the top of every public method. If the client is unavailable
|
|
||||||
but the retry cooldown has elapsed, silently attempts reconnect before
|
|
||||||
answering. No-ops immediately if still within the cooldown window.
|
|
||||||
"""
|
|
||||||
if self._available:
|
|
||||||
return True
|
|
||||||
if self._retry_at is not None and time.monotonic() >= self._retry_at:
|
|
||||||
logger.info(
|
|
||||||
"mnemo: cooldown elapsed after %d failure(s) — attempting reconnect",
|
|
||||||
self._failure_count,
|
|
||||||
)
|
|
||||||
self._retry_at = None
|
|
||||||
self._client = None
|
|
||||||
await self.connect()
|
|
||||||
return self._available
|
|
||||||
|
|
||||||
def _on_call_success(self) -> None:
|
|
||||||
"""Reset failure state after a successful call."""
|
|
||||||
self._failure_count = 0
|
|
||||||
self._retry_at = None
|
|
||||||
|
|
||||||
def _handle_unavailable(self, operation: str, reason: str = "") -> None:
|
|
||||||
"""Called when the sidecar is unreachable at connect() time."""
|
|
||||||
self._available = False
|
|
||||||
msg = f"mnemo memory sidecar unavailable (operation={operation!r})"
|
|
||||||
if reason:
|
|
||||||
msg += f": {reason}"
|
|
||||||
if self._strict:
|
|
||||||
raise MemoryUnavailableError(msg)
|
|
||||||
logger.warning("%s — memory features disabled", msg)
|
|
||||||
|
|
||||||
def _on_call_error(self, operation: str, exc: Exception) -> None:
|
|
||||||
"""Count consecutive failures and schedule exponential backoff retry.
|
|
||||||
|
|
||||||
Backoff: 5 * 2^(failure-1) seconds, capped at 60s.
|
|
||||||
failure 1 → 5s
|
|
||||||
failure 2 → 10s
|
|
||||||
failure 3 → 20s ← _MAX_FAILURES default; client disabled here
|
|
||||||
failure 4 → 40s
|
|
||||||
failure 5+ → 60s
|
|
||||||
|
|
||||||
After _MAX_FAILURES, _available is set to False and all calls no-op
|
|
||||||
until _maybe_reconnect() fires after the cooldown elapses.
|
|
||||||
"""
|
|
||||||
self._failure_count += 1
|
|
||||||
backoff = min(5.0 * (2 ** (self._failure_count - 1)), _MAX_BACKOFF)
|
|
||||||
self._retry_at = time.monotonic() + backoff
|
|
||||||
|
|
||||||
if self._failure_count >= _MAX_FAILURES:
|
|
||||||
self._available = False
|
|
||||||
logger.warning(
|
|
||||||
"mnemo %r failed %d consecutive times (%s) — disabled, reconnect in %.0fs",
|
|
||||||
operation, self._failure_count, exc, backoff,
|
|
||||||
)
|
|
||||||
if self._strict:
|
|
||||||
raise MemoryUnavailableError(
|
|
||||||
f"mnemo {operation!r} failed {self._failure_count} consecutive times: {exc}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"mnemo %r failed (%d/%d): %s — retry in %.0fs",
|
|
||||||
operation, self._failure_count, _MAX_FAILURES, exc, backoff,
|
|
||||||
)
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
"""Data models for the cf-core memory module.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MemoryConfig:
|
|
||||||
"""Connection config for a mnemo sidecar."""
|
|
||||||
|
|
||||||
host: str = "localhost"
|
|
||||||
port: int = 8080
|
|
||||||
timeout: float = 10.0
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_env(cls) -> MemoryConfig:
|
|
||||||
"""Read config from environment variables.
|
|
||||||
|
|
||||||
Variables:
|
|
||||||
MNEMO_HOST — default: localhost
|
|
||||||
MNEMO_PORT — default: 8080
|
|
||||||
MNEMO_TIMEOUT — default: 10.0
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
host=os.environ.get("MNEMO_HOST", "localhost"),
|
|
||||||
port=int(os.environ.get("MNEMO_PORT", "8080")),
|
|
||||||
timeout=float(os.environ.get("MNEMO_TIMEOUT", "10.0")),
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def base_url(self) -> str:
|
|
||||||
return f"http://{self.host}:{self.port}"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MemoryEntity:
|
|
||||||
"""A named entity extracted and stored by the mnemo knowledge graph."""
|
|
||||||
|
|
||||||
entity_id: str
|
|
||||||
name: str
|
|
||||||
entity_type: str
|
|
||||||
aliases: list[str] = field(default_factory=list)
|
|
||||||
confidence: float = 1.0
|
|
||||||
source_count: int = 1
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_mnemo(cls, obj) -> MemoryEntity:
|
|
||||||
"""Convert a mnemo-sdk Entity object to MemoryEntity."""
|
|
||||||
return cls(
|
|
||||||
entity_id=str(obj.id),
|
|
||||||
name=obj.name,
|
|
||||||
entity_type=obj.entity_type,
|
|
||||||
aliases=list(obj.aliases or []),
|
|
||||||
confidence=float(obj.confidence or 1.0),
|
|
||||||
source_count=int(obj.source_count or 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MemoryStats:
|
|
||||||
"""Snapshot of the mnemo knowledge graph state."""
|
|
||||||
|
|
||||||
entity_count: int
|
|
||||||
chunk_count: int
|
|
||||||
node_count: int
|
|
||||||
edge_count: int
|
|
||||||
uptime_seconds: float
|
|
||||||
available: bool
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
"""circuitforge_core.mqtt — async MQTT client with topic routing and
|
|
||||||
Meshtastic adapter support.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
|
|
||||||
Quick start::
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt import MQTTClient, MQTTConfig
|
|
||||||
|
|
||||||
cfg = MQTTConfig(host="localhost")
|
|
||||||
client = MQTTClient(cfg)
|
|
||||||
|
|
||||||
@client.on("sensors/#")
|
|
||||||
async def handle(msg):
|
|
||||||
print(msg.topic, msg.text())
|
|
||||||
|
|
||||||
await client.run()
|
|
||||||
|
|
||||||
For Meshtastic::
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.meshtastic import make_backend
|
|
||||||
|
|
||||||
backend = make_backend({
|
|
||||||
"backend": "mqtt",
|
|
||||||
"broker_host": "mqtt.example.com",
|
|
||||||
"topic_prefix": "msh/#",
|
|
||||||
})
|
|
||||||
async for pkt in backend.packets():
|
|
||||||
print(pkt.summary())
|
|
||||||
"""
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.client import MQTTClient
|
|
||||||
from circuitforge_core.mqtt.models import MQTTConfig, MQTTMessage
|
|
||||||
from circuitforge_core.mqtt.router import TopicRouter, matches
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MQTTClient",
|
|
||||||
"MQTTConfig",
|
|
||||||
"MQTTMessage",
|
|
||||||
"TopicRouter",
|
|
||||||
"matches",
|
|
||||||
]
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
"""Async MQTT client wrapper around aiomqtt.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.models import MQTTConfig, MQTTMessage
|
|
||||||
from circuitforge_core.mqtt.router import TopicRouter
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MQTTClient:
|
|
||||||
"""Async MQTT client that subscribes to topics and dispatches messages.
|
|
||||||
|
|
||||||
Usage (with a router)::
|
|
||||||
|
|
||||||
cfg = MQTTConfig(host="localhost")
|
|
||||||
client = MQTTClient(cfg)
|
|
||||||
|
|
||||||
@client.on("msh/#")
|
|
||||||
async def handle_mesh(msg: MQTTMessage):
|
|
||||||
print(msg.topic, msg.text())
|
|
||||||
|
|
||||||
await client.run()
|
|
||||||
|
|
||||||
Usage (iterate raw messages)::
|
|
||||||
|
|
||||||
async with MQTTClient(cfg) as messages:
|
|
||||||
async for msg in messages:
|
|
||||||
print(msg.topic)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: MQTTConfig, router: TopicRouter | None = None) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._router = router or TopicRouter()
|
|
||||||
|
|
||||||
def on(self, pattern: str):
|
|
||||||
"""Shorthand decorator — forwards to the internal router."""
|
|
||||||
return self._router.on(pattern)
|
|
||||||
|
|
||||||
async def run(self) -> None:
|
|
||||||
"""Subscribe to all registered patterns and dispatch until cancelled.
|
|
||||||
|
|
||||||
Reconnects automatically if the connection drops.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import aiomqtt
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"aiomqtt is required for MQTTClient. "
|
|
||||||
"Install with: pip install circuitforge-core[mqtt]"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
cfg = self._config
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"hostname": cfg.host,
|
|
||||||
"port": cfg.port,
|
|
||||||
"keepalive": cfg.keepalive,
|
|
||||||
"tls_params": aiomqtt.TLSParameters() if cfg.tls else None,
|
|
||||||
}
|
|
||||||
if cfg.client_id:
|
|
||||||
kwargs["identifier"] = cfg.client_id
|
|
||||||
if cfg.username is not None:
|
|
||||||
kwargs["username"] = cfg.username
|
|
||||||
if cfg.password is not None:
|
|
||||||
kwargs["password"] = cfg.password
|
|
||||||
|
|
||||||
async with aiomqtt.Client(**kwargs) as ac:
|
|
||||||
patterns = self._router.patterns
|
|
||||||
if not patterns:
|
|
||||||
logger.warning("MQTTClient started with no subscriptions")
|
|
||||||
for p in patterns:
|
|
||||||
await ac.subscribe(p)
|
|
||||||
logger.debug("Subscribed to %r on %s:%d", p, cfg.host, cfg.port)
|
|
||||||
logger.info("MQTT connected to %s:%d", cfg.host, cfg.port)
|
|
||||||
|
|
||||||
async for raw in ac.messages:
|
|
||||||
msg = MQTTMessage(
|
|
||||||
topic=str(raw.topic),
|
|
||||||
payload=raw.payload if isinstance(raw.payload, bytes) else str(raw.payload).encode(),
|
|
||||||
qos=raw.qos,
|
|
||||||
retain=raw.retain,
|
|
||||||
received_at=datetime.now(tz=timezone.utc),
|
|
||||||
)
|
|
||||||
await self._router.dispatch(msg)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("MQTTClient cancelled")
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"MQTT connection to %s:%d failed (%s), retrying in %.0fs",
|
|
||||||
cfg.host, cfg.port, exc, cfg.reconnect_interval,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(cfg.reconnect_interval)
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def connect(self) -> AsyncIterator[AsyncIterator[MQTTMessage]]:
|
|
||||||
"""Context manager that yields an async iterator of raw messages.
|
|
||||||
|
|
||||||
Useful when the caller wants to do its own routing::
|
|
||||||
|
|
||||||
async with client.connect() as messages:
|
|
||||||
async for msg in messages:
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import aiomqtt
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"aiomqtt is required. Install with: pip install circuitforge-core[mqtt]"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
cfg = self._config
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"hostname": cfg.host,
|
|
||||||
"port": cfg.port,
|
|
||||||
"keepalive": cfg.keepalive,
|
|
||||||
"tls_params": aiomqtt.TLSParameters() if cfg.tls else None,
|
|
||||||
}
|
|
||||||
if cfg.client_id:
|
|
||||||
kwargs["identifier"] = cfg.client_id
|
|
||||||
if cfg.username is not None:
|
|
||||||
kwargs["username"] = cfg.username
|
|
||||||
if cfg.password is not None:
|
|
||||||
kwargs["password"] = cfg.password
|
|
||||||
|
|
||||||
async with aiomqtt.Client(**kwargs) as ac:
|
|
||||||
for p in self._router.patterns:
|
|
||||||
await ac.subscribe(p)
|
|
||||||
|
|
||||||
async def _iter() -> AsyncIterator[MQTTMessage]:
|
|
||||||
async for raw in ac.messages:
|
|
||||||
yield MQTTMessage(
|
|
||||||
topic=str(raw.topic),
|
|
||||||
payload=raw.payload if isinstance(raw.payload, bytes) else str(raw.payload).encode(),
|
|
||||||
qos=raw.qos,
|
|
||||||
retain=raw.retain,
|
|
||||||
received_at=datetime.now(tz=timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
yield _iter()
|
|
||||||
|
|
@ -1,76 +0,0 @@
|
||||||
"""Meshtastic adapter for circuitforge-core.
|
|
||||||
|
|
||||||
Two backends are available:
|
|
||||||
|
|
||||||
- ``MQTTMeshtasticBackend`` — subscribes to a Meshtastic MQTT bridge
|
|
||||||
- ``SerialMeshtasticBackend`` — direct serial/TCP connection via the
|
|
||||||
``meshtastic`` Python library
|
|
||||||
|
|
||||||
Use ``make_backend()`` for config-driven selection.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.meshtastic.interface import MeshtasticInterface
|
|
||||||
from circuitforge_core.mqtt.meshtastic.models import (
|
|
||||||
MeshtasticPacket,
|
|
||||||
MeshtasticPosition,
|
|
||||||
MeshtasticTelemetry,
|
|
||||||
)
|
|
||||||
from circuitforge_core.mqtt.meshtastic.mqtt_backend import MQTTMeshtasticBackend
|
|
||||||
from circuitforge_core.mqtt.meshtastic.serial_backend import SerialMeshtasticBackend
|
|
||||||
from circuitforge_core.mqtt.models import MQTTConfig
|
|
||||||
|
|
||||||
|
|
||||||
def make_backend(config: dict) -> MeshtasticInterface:
|
|
||||||
"""Construct a Meshtastic backend from a config dict.
|
|
||||||
|
|
||||||
Config keys:
|
|
||||||
backend (str): ``"mqtt"`` or ``"serial"`` (required)
|
|
||||||
|
|
||||||
For ``"mqtt"`` backend:
|
|
||||||
broker_host (str): MQTT broker hostname
|
|
||||||
broker_port (int): MQTT broker port (default 1883)
|
|
||||||
broker_username (str|None): optional
|
|
||||||
broker_password (str|None): optional
|
|
||||||
topic_prefix (str): topic to subscribe to (default ``msh/#``)
|
|
||||||
|
|
||||||
For ``"serial"`` backend:
|
|
||||||
dev_path (str|None): serial device, e.g. ``/dev/ttyUSB0``
|
|
||||||
tcp_host (str|None): TCP hostname for TCP mode
|
|
||||||
tcp_port (int): TCP port (default 4403)
|
|
||||||
"""
|
|
||||||
backend = config.get("backend", "mqtt").lower()
|
|
||||||
|
|
||||||
if backend == "mqtt":
|
|
||||||
mqtt_cfg = MQTTConfig(
|
|
||||||
host=config["broker_host"],
|
|
||||||
port=int(config.get("broker_port", 1883)),
|
|
||||||
username=config.get("broker_username"),
|
|
||||||
password=config.get("broker_password"),
|
|
||||||
)
|
|
||||||
return MQTTMeshtasticBackend(
|
|
||||||
mqtt_config=mqtt_cfg,
|
|
||||||
topic_prefix=config.get("topic_prefix", "msh/#"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if backend == "serial":
|
|
||||||
return SerialMeshtasticBackend(
|
|
||||||
dev_path=config.get("dev_path"),
|
|
||||||
tcp_host=config.get("tcp_host"),
|
|
||||||
tcp_port=int(config.get("tcp_port", 4403)),
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown Meshtastic backend: {backend!r}. Must be 'mqtt' or 'serial'.")
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MeshtasticInterface",
|
|
||||||
"MeshtasticPacket",
|
|
||||||
"MeshtasticPosition",
|
|
||||||
"MeshtasticTelemetry",
|
|
||||||
"MQTTMeshtasticBackend",
|
|
||||||
"SerialMeshtasticBackend",
|
|
||||||
"make_backend",
|
|
||||||
]
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
"""Abstract interface for Meshtastic backends.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
|
|
||||||
|
|
||||||
class MeshtasticInterface(ABC):
|
|
||||||
"""Async interface for receiving and sending Meshtastic packets.
|
|
||||||
|
|
||||||
Two concrete backends exist:
|
|
||||||
|
|
||||||
- MQTTMeshtasticBackend — subscribes to a Meshtastic MQTT bridge
|
|
||||||
- SerialMeshtasticBackend — connects directly via the meshtastic Python library
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def packets(self) -> AsyncIterator:
|
|
||||||
"""Async generator of MeshtasticPacket objects.
|
|
||||||
|
|
||||||
Yields packets as they arrive. Runs until cancelled.
|
|
||||||
Concrete types are ``MeshtasticPacket`` from
|
|
||||||
``circuitforge_core.mqtt.meshtastic.models``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def send_text(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
dest_id: int = 0xFFFFFFFF,
|
|
||||||
channel: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Send a text message to dest_id (default: broadcast)."""
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
"""Data models for Meshtastic packets.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
# Meshtastic portnum → our label
|
|
||||||
PacketType = Literal[
|
|
||||||
"text",
|
|
||||||
"position",
|
|
||||||
"nodeinfo",
|
|
||||||
"telemetry",
|
|
||||||
"routing",
|
|
||||||
"admin",
|
|
||||||
"unknown",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MeshtasticPosition:
|
|
||||||
latitude: float | None = None
|
|
||||||
longitude: float | None = None
|
|
||||||
altitude_m: int | None = None
|
|
||||||
timestamp: datetime | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MeshtasticTelemetry:
|
|
||||||
battery_level: int | None = None # 0-100 %
|
|
||||||
voltage: float | None = None # volts
|
|
||||||
channel_util: float | None = None # 0-100 %
|
|
||||||
air_util_tx: float | None = None # 0-100 %
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MeshtasticPacket:
|
|
||||||
"""Normalized Meshtastic packet from any backend."""
|
|
||||||
|
|
||||||
packet_type: PacketType
|
|
||||||
from_id: str # hex node ID, e.g. "!deadbeef"
|
|
||||||
from_num: int # numeric node ID
|
|
||||||
to_num: int # 0xffffffff = broadcast
|
|
||||||
channel: int
|
|
||||||
received_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
|
||||||
|
|
||||||
# Type-specific payloads (only one is populated per packet type)
|
|
||||||
text: str | None = None
|
|
||||||
position: MeshtasticPosition | None = None
|
|
||||||
telemetry: MeshtasticTelemetry | None = None
|
|
||||||
node_longname: str | None = None
|
|
||||||
node_shortname: str | None = None
|
|
||||||
hardware: int | None = None
|
|
||||||
|
|
||||||
# Original raw payload dict for downstream consumers that need all fields
|
|
||||||
raw: dict = field(default_factory=dict, compare=False, hash=False)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_broadcast(self) -> bool:
|
|
||||||
return self.to_num == 0xFFFFFFFF
|
|
||||||
|
|
||||||
def summary(self) -> str:
|
|
||||||
"""One-line human-readable description."""
|
|
||||||
src = self.from_id or f"!{self.from_num:08x}"
|
|
||||||
if self.packet_type == "text":
|
|
||||||
return f"[{src}] {self.text}"
|
|
||||||
if self.packet_type == "position" and self.position:
|
|
||||||
p = self.position
|
|
||||||
return f"[{src}] position {p.latitude:.5f},{p.longitude:.5f}"
|
|
||||||
if self.packet_type == "nodeinfo":
|
|
||||||
return f"[{src}] node info: {self.node_longname!r} ({self.node_shortname})"
|
|
||||||
if self.packet_type == "telemetry" and self.telemetry:
|
|
||||||
t = self.telemetry
|
|
||||||
parts = []
|
|
||||||
if t.battery_level is not None:
|
|
||||||
parts.append(f"batt={t.battery_level}%")
|
|
||||||
if t.voltage is not None:
|
|
||||||
parts.append(f"v={t.voltage:.2f}V")
|
|
||||||
return f"[{src}] telemetry {' '.join(parts)}"
|
|
||||||
return f"[{src}] {self.packet_type} packet"
|
|
||||||
|
|
@ -1,214 +0,0 @@
|
||||||
"""Meshtastic MQTT bridge backend.
|
|
||||||
|
|
||||||
Subscribes to the JSON MQTT topics that Meshtastic firmware publishes when
|
|
||||||
the MQTT uplink is enabled on a node.
|
|
||||||
|
|
||||||
Topic schema (Meshtastic firmware >=2.1):
|
|
||||||
msh/{region}/{gateway}/2/json/{portnum}/{fromId}
|
|
||||||
|
|
||||||
The payload is a JSON object. Examples by type:
|
|
||||||
|
|
||||||
Text message:
|
|
||||||
{"channel":0,"from":123456789,"id":987,"payload":{"text":"hello"},
|
|
||||||
"sender":"!07558d85","timestamp":1716200000,"to":4294967295,"type":"sendtext"}
|
|
||||||
|
|
||||||
Position:
|
|
||||||
{"channel":0,"from":123456789,"payload":{"altitude":50,
|
|
||||||
"latitude_i":374208130,"longitude_i":-1220848320,"time":1716200000},
|
|
||||||
"type":"position"}
|
|
||||||
|
|
||||||
Node info:
|
|
||||||
{"channel":0,"from":123456789,"payload":{"hardware":43,
|
|
||||||
"id":"!07558d85","longname":"Alan Node","shortname":"AN"},
|
|
||||||
"type":"nodeinfo"}
|
|
||||||
|
|
||||||
Telemetry:
|
|
||||||
{"channel":0,"from":123456789,"payload":{"battery_level":82,
|
|
||||||
"voltage":4.09,"channel_utilization":0.5,"air_util_tx":0.01,
|
|
||||||
"time":1716200000},"type":"telemetry"}
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.client import MQTTClient
|
|
||||||
from circuitforge_core.mqtt.meshtastic.interface import MeshtasticInterface
|
|
||||||
from circuitforge_core.mqtt.meshtastic.models import (
|
|
||||||
MeshtasticPacket,
|
|
||||||
MeshtasticPosition,
|
|
||||||
MeshtasticTelemetry,
|
|
||||||
)
|
|
||||||
from circuitforge_core.mqtt.models import MQTTConfig, MQTTMessage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# latitude_i / longitude_i are stored as integer × 1e7 in Meshtastic protobuf.
|
|
||||||
_COORD_SCALE = 1e-7
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_packet(raw_json: str | bytes, topic: str) -> MeshtasticPacket | None:
|
|
||||||
"""Parse a Meshtastic MQTT JSON payload into a MeshtasticPacket.
|
|
||||||
|
|
||||||
Returns None if the payload cannot be parsed or is an encrypted packet
|
|
||||||
(payload is a base64 blob instead of a dict).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
obj = json.loads(raw_json)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.debug("Non-JSON Meshtastic payload on topic %r", topic)
|
|
||||||
return None
|
|
||||||
|
|
||||||
payload = obj.get("payload")
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
# Encrypted packet — payload is a base64 string; skip.
|
|
||||||
return None
|
|
||||||
|
|
||||||
from_num: int = obj.get("from", 0)
|
|
||||||
sender: str = obj.get("sender", f"!{from_num:08x}")
|
|
||||||
channel: int = obj.get("channel", 0)
|
|
||||||
to_num: int = obj.get("to", 0xFFFFFFFF)
|
|
||||||
raw_ts: int | None = payload.get("time") or obj.get("timestamp")
|
|
||||||
received_at = (
|
|
||||||
datetime.fromtimestamp(raw_ts, tz=timezone.utc) if raw_ts else datetime.now(tz=timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
ptype: str = obj.get("type", "unknown").lower()
|
|
||||||
|
|
||||||
if ptype in ("sendtext", "text"):
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="text",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_num,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=received_at,
|
|
||||||
text=payload.get("text", ""),
|
|
||||||
raw=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if ptype == "position":
|
|
||||||
lat_i: int | None = payload.get("latitude_i")
|
|
||||||
lon_i: int | None = payload.get("longitude_i")
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="position",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_num,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=received_at,
|
|
||||||
position=MeshtasticPosition(
|
|
||||||
latitude=lat_i * _COORD_SCALE if lat_i is not None else None,
|
|
||||||
longitude=lon_i * _COORD_SCALE if lon_i is not None else None,
|
|
||||||
altitude_m=payload.get("altitude"),
|
|
||||||
timestamp=received_at,
|
|
||||||
),
|
|
||||||
raw=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if ptype == "nodeinfo":
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="nodeinfo",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_num,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=received_at,
|
|
||||||
node_longname=payload.get("longname"),
|
|
||||||
node_shortname=payload.get("shortname"),
|
|
||||||
hardware=payload.get("hardware"),
|
|
||||||
raw=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if ptype == "telemetry":
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="telemetry",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_num,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=received_at,
|
|
||||||
telemetry=MeshtasticTelemetry(
|
|
||||||
battery_level=payload.get("battery_level"),
|
|
||||||
voltage=payload.get("voltage"),
|
|
||||||
channel_util=payload.get("channel_utilization"),
|
|
||||||
air_util_tx=payload.get("air_util_tx"),
|
|
||||||
),
|
|
||||||
raw=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Routing, admin, and other packet types — return minimal packet.
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="unknown",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_num,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=received_at,
|
|
||||||
raw=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MQTTMeshtasticBackend(MeshtasticInterface):
|
|
||||||
"""Receive Meshtastic packets via a Meshtastic MQTT bridge.
|
|
||||||
|
|
||||||
Requires a Meshtastic node with the MQTT uplink enabled, publishing to
|
|
||||||
the configured broker. Set ``topic_prefix`` to match the region prefix
|
|
||||||
configured on the node (default ``msh/#`` matches all regions).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mqtt_config: broker connection settings
|
|
||||||
topic_prefix: MQTT topic pattern to subscribe to (default ``msh/#``)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mqtt_config: MQTTConfig,
|
|
||||||
topic_prefix: str = "msh/#",
|
|
||||||
) -> None:
|
|
||||||
self._mqtt_config = mqtt_config
|
|
||||||
self._topic_prefix = topic_prefix
|
|
||||||
|
|
||||||
async def packets(self) -> AsyncIterator[MeshtasticPacket]:
|
|
||||||
client = MQTTClient(self._mqtt_config)
|
|
||||||
|
|
||||||
queue: asyncio.Queue[MeshtasticPacket] = asyncio.Queue()
|
|
||||||
|
|
||||||
@client.on(self._topic_prefix)
|
|
||||||
async def _handle(msg: MQTTMessage) -> None:
|
|
||||||
pkt = _parse_packet(msg.payload, msg.topic)
|
|
||||||
if pkt is not None:
|
|
||||||
await queue.put(pkt)
|
|
||||||
|
|
||||||
runner = asyncio.create_task(client.run())
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
yield await queue.get()
|
|
||||||
finally:
|
|
||||||
runner.cancel()
|
|
||||||
try:
|
|
||||||
await runner
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_text(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
dest_id: int = 0xFFFFFFFF,
|
|
||||||
channel: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Publishing back to MQTT is not supported by this backend.
|
|
||||||
|
|
||||||
Meshtastic nodes consume from MQTT in a different topic namespace;
|
|
||||||
use the serial backend or a direct Meshtastic MQTT channel config
|
|
||||||
for two-way messaging.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"MQTTMeshtasticBackend is receive-only. "
|
|
||||||
"Use SerialMeshtasticBackend for send support."
|
|
||||||
)
|
|
||||||
|
|
@ -1,210 +0,0 @@
|
||||||
"""Meshtastic serial/TCP backend using the meshtastic Python library.
|
|
||||||
|
|
||||||
Connects directly to a Meshtastic node over serial port or TCP (e.g.
|
|
||||||
when a node exposes Meshtastic's native TCP API on port 4403).
|
|
||||||
|
|
||||||
The ``meshtastic`` library is synchronous and uses threading + PyPubSub
|
|
||||||
for callbacks. This backend bridges into asyncio via an asyncio.Queue:
|
|
||||||
the sync callback puts packets on the queue, and ``packets()`` awaits
|
|
||||||
items from it.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.meshtastic.interface import MeshtasticInterface
|
|
||||||
from circuitforge_core.mqtt.meshtastic.models import (
|
|
||||||
MeshtasticPacket,
|
|
||||||
MeshtasticPosition,
|
|
||||||
MeshtasticTelemetry,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_COORD_SCALE = 1e-7
|
|
||||||
|
|
||||||
|
|
||||||
def _packet_from_decoded(decoded: dict, from_id: int) -> MeshtasticPacket:
|
|
||||||
"""Convert a meshtastic-library decoded packet dict to MeshtasticPacket."""
|
|
||||||
portnum: str = decoded.get("portnum", "UNKNOWN_APP")
|
|
||||||
sender = f"!{from_id:08x}"
|
|
||||||
to_num: int = decoded.get("to", 0xFFFFFFFF)
|
|
||||||
channel: int = decoded.get("channel", 0)
|
|
||||||
now = datetime.now(tz=timezone.utc)
|
|
||||||
|
|
||||||
if portnum == "TEXT_MESSAGE_APP":
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="text",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_id,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=now,
|
|
||||||
text=decoded.get("decoded", {}).get("text", ""),
|
|
||||||
raw=decoded,
|
|
||||||
)
|
|
||||||
|
|
||||||
if portnum == "POSITION_APP":
|
|
||||||
pos = decoded.get("decoded", {}).get("position", {})
|
|
||||||
lat_i = pos.get("latitudeI")
|
|
||||||
lon_i = pos.get("longitudeI")
|
|
||||||
alt = pos.get("altitude")
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="position",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_id,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=now,
|
|
||||||
position=MeshtasticPosition(
|
|
||||||
latitude=lat_i * _COORD_SCALE if lat_i is not None else None,
|
|
||||||
longitude=lon_i * _COORD_SCALE if lon_i is not None else None,
|
|
||||||
altitude_m=alt,
|
|
||||||
timestamp=now,
|
|
||||||
),
|
|
||||||
raw=decoded,
|
|
||||||
)
|
|
||||||
|
|
||||||
if portnum == "NODEINFO_APP":
|
|
||||||
info = decoded.get("decoded", {}).get("user", {})
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="nodeinfo",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_id,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=now,
|
|
||||||
node_longname=info.get("longName"),
|
|
||||||
node_shortname=info.get("shortName"),
|
|
||||||
hardware=info.get("hwModel"),
|
|
||||||
raw=decoded,
|
|
||||||
)
|
|
||||||
|
|
||||||
if portnum == "TELEMETRY_APP":
|
|
||||||
telem = decoded.get("decoded", {}).get("telemetry", {})
|
|
||||||
dev = telem.get("deviceMetrics", {})
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="telemetry",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_id,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=now,
|
|
||||||
telemetry=MeshtasticTelemetry(
|
|
||||||
battery_level=dev.get("batteryLevel"),
|
|
||||||
voltage=dev.get("voltage"),
|
|
||||||
channel_util=dev.get("channelUtilization"),
|
|
||||||
air_util_tx=dev.get("airUtilTx"),
|
|
||||||
),
|
|
||||||
raw=decoded,
|
|
||||||
)
|
|
||||||
|
|
||||||
return MeshtasticPacket(
|
|
||||||
packet_type="unknown",
|
|
||||||
from_id=sender,
|
|
||||||
from_num=from_id,
|
|
||||||
to_num=to_num,
|
|
||||||
channel=channel,
|
|
||||||
received_at=now,
|
|
||||||
raw=decoded,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SerialMeshtasticBackend(MeshtasticInterface):
|
|
||||||
"""Receive and send Meshtastic packets via serial port or TCP.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dev_path: serial device path (e.g. ``/dev/ttyUSB0``) or ``None``
|
|
||||||
to auto-detect the first connected Meshtastic device.
|
|
||||||
tcp_host: hostname for TCP connection. If set, ``dev_path`` is ignored
|
|
||||||
and a TCP connection to port 4403 is used.
|
|
||||||
tcp_port: TCP port (default 4403).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dev_path: str | None = None,
|
|
||||||
tcp_host: str | None = None,
|
|
||||||
tcp_port: int = 4403,
|
|
||||||
) -> None:
|
|
||||||
self._dev_path = dev_path
|
|
||||||
self._tcp_host = tcp_host
|
|
||||||
self._tcp_port = tcp_port
|
|
||||||
|
|
||||||
def _make_interface(self):
|
|
||||||
try:
|
|
||||||
import meshtastic.serial_interface
|
|
||||||
import meshtastic.tcp_interface
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"meshtastic is required for SerialMeshtasticBackend. "
|
|
||||||
"Install with: pip install circuitforge-core[meshtastic-serial]"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
if self._tcp_host:
|
|
||||||
return meshtastic.tcp_interface.TCPInterface(
|
|
||||||
hostname=self._tcp_host,
|
|
||||||
portNumber=self._tcp_port,
|
|
||||||
)
|
|
||||||
return meshtastic.serial_interface.SerialInterface(devPath=self._dev_path)
|
|
||||||
|
|
||||||
async def packets(self) -> AsyncIterator[MeshtasticPacket]:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
queue: asyncio.Queue[MeshtasticPacket | None] = asyncio.Queue()
|
|
||||||
|
|
||||||
def _on_receive(packet: dict, interface) -> None:
|
|
||||||
try:
|
|
||||||
from_id: int = packet.get("from", 0)
|
|
||||||
pkt = _packet_from_decoded(packet, from_id)
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, pkt)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error decoding Meshtastic serial packet")
|
|
||||||
|
|
||||||
def _on_connection_closed(interface) -> None:
|
|
||||||
logger.warning("Meshtastic serial connection closed")
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, None)
|
|
||||||
|
|
||||||
iface = await loop.run_in_executor(None, self._make_interface)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pubsub import pub
|
|
||||||
pub.subscribe(_on_receive, "meshtastic.receive")
|
|
||||||
pub.subscribe(_on_connection_closed, "meshtastic.connection.lost")
|
|
||||||
except ImportError:
|
|
||||||
await loop.run_in_executor(None, iface.close)
|
|
||||||
raise ImportError(
|
|
||||||
"pypubsub is required for SerialMeshtasticBackend. "
|
|
||||||
"Install with: pip install circuitforge-core[meshtastic-serial]"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
pkt = await queue.get()
|
|
||||||
if pkt is None:
|
|
||||||
break
|
|
||||||
yield pkt
|
|
||||||
finally:
|
|
||||||
pub.unsubscribe(_on_receive, "meshtastic.receive")
|
|
||||||
pub.unsubscribe(_on_connection_closed, "meshtastic.connection.lost")
|
|
||||||
await loop.run_in_executor(None, iface.close)
|
|
||||||
|
|
||||||
async def send_text(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
dest_id: int = 0xFFFFFFFF,
|
|
||||||
channel: int = 0,
|
|
||||||
) -> None:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
iface = await loop.run_in_executor(None, self._make_interface)
|
|
||||||
try:
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
lambda: iface.sendText(text, destinationId=dest_id, channelIndex=channel),
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await loop.run_in_executor(None, iface.close)
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
"""Data models for the MQTT client module.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MQTTConfig:
|
|
||||||
"""Connection config for an MQTT broker."""
|
|
||||||
|
|
||||||
host: str
|
|
||||||
port: int = 1883
|
|
||||||
username: str | None = None
|
|
||||||
password: str | None = None
|
|
||||||
client_id: str = ""
|
|
||||||
keepalive: int = 60
|
|
||||||
tls: bool = False
|
|
||||||
reconnect_interval: float = 5.0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MQTTMessage:
|
|
||||||
"""A single received MQTT message."""
|
|
||||||
|
|
||||||
topic: str
|
|
||||||
payload: bytes
|
|
||||||
qos: int = 0
|
|
||||||
retain: bool = False
|
|
||||||
received_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
|
||||||
|
|
||||||
def text(self, encoding: str = "utf-8") -> str:
|
|
||||||
return self.payload.decode(encoding, errors="replace")
|
|
||||||
|
|
||||||
def json(self) -> dict:
|
|
||||||
return json.loads(self.payload)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def topic_parts(self) -> list[str]:
|
|
||||||
return self.topic.split("/")
|
|
||||||
|
|
@ -1,74 +0,0 @@
|
||||||
"""MQTT topic router with wildcard pattern matching.
|
|
||||||
|
|
||||||
MIT licensed.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable, Coroutine
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from circuitforge_core.mqtt.models import MQTTMessage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
Handler = Callable[[MQTTMessage], Coroutine[Any, Any, None]]
|
|
||||||
|
|
||||||
|
|
||||||
def matches(pattern: str, topic: str) -> bool:
|
|
||||||
"""Return True if topic matches the MQTT wildcard pattern.
|
|
||||||
|
|
||||||
MQTT wildcard rules:
|
|
||||||
- '+' matches exactly one topic level (segment between '/' separators)
|
|
||||||
- '#' matches zero or more levels and MUST appear at the end of the pattern
|
|
||||||
- All other characters match literally
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
matches("sensor/+/temp", "sensor/room1/temp") → True
|
|
||||||
matches("sensor/+/temp", "sensor/a/b/temp") → False
|
|
||||||
matches("sensor/#", "sensor/room1/temp") → True
|
|
||||||
matches("sensor/#", "sensor") → True (# = zero levels)
|
|
||||||
matches("#", "any/topic/here") → True
|
|
||||||
matches("a/b/c", "a/b/c") → True
|
|
||||||
"""
|
|
||||||
# TODO: implement wildcard matching
|
|
||||||
# Hint: split both pattern and topic on '/' and walk them in parallel.
|
|
||||||
# Handle '#' early (if it appears, everything past that point in topic matches).
|
|
||||||
# '+' must cover exactly one (non-empty) level.
|
|
||||||
raise NotImplementedError("matches() is not yet implemented")
|
|
||||||
|
|
||||||
|
|
||||||
class TopicRouter:
|
|
||||||
"""Register async handlers for MQTT topic patterns and dispatch messages."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._routes: list[tuple[str, Handler]] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def patterns(self) -> list[str]:
|
|
||||||
return [p for p, _ in self._routes]
|
|
||||||
|
|
||||||
def register(self, pattern: str, handler: Handler) -> None:
|
|
||||||
"""Add a handler for the given topic pattern."""
|
|
||||||
self._routes.append((pattern, handler))
|
|
||||||
|
|
||||||
def on(self, pattern: str) -> Callable[[Handler], Handler]:
|
|
||||||
"""Decorator: @router.on("sensor/#") async def handle(msg): ..."""
|
|
||||||
def decorator(fn: Handler) -> Handler:
|
|
||||||
self.register(pattern, fn)
|
|
||||||
return fn
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
async def dispatch(self, message: MQTTMessage) -> None:
|
|
||||||
"""Call all handlers whose pattern matches message.topic."""
|
|
||||||
for pattern, handler in self._routes:
|
|
||||||
try:
|
|
||||||
if matches(pattern, message.topic):
|
|
||||||
if inspect.iscoroutinefunction(handler):
|
|
||||||
await handler(message)
|
|
||||||
else:
|
|
||||||
handler(message)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Handler for %r raised on topic %r", pattern, message.topic)
|
|
||||||
|
|
@ -51,12 +51,9 @@ cf-orch service profile (Phase 3 — remote backend):
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
|
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
|
||||||
from circuitforge_core.reranker.adapters.mock import MockTextReranker
|
from circuitforge_core.reranker.adapters.mock import MockTextReranker
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -169,15 +169,7 @@ class LocalScheduler:
|
||||||
if not q:
|
if not q:
|
||||||
break
|
break
|
||||||
task = q.popleft()
|
task = q.popleft()
|
||||||
try:
|
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
|
||||||
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
|
|
||||||
except Exception as exc:
|
|
||||||
# run_task_fn should handle its own exceptions. If it leaks one,
|
|
||||||
# log it so the task doesn't silently stay 'queued' with no trace.
|
|
||||||
logger.exception(
|
|
||||||
"Unhandled exception in batch worker task %d (%s): %s",
|
|
||||||
task.id, task_type, exc,
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._active.pop(task_type, None)
|
self._active.pop(task_type, None)
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
"""
|
"""
|
||||||
cf-text FastAPI service — managed by cf-orch.
|
cf-text FastAPI service — managed by cf-orch.
|
||||||
|
|
||||||
Lightweight local text generation and PII filtering. Supports GGUF models via
|
Lightweight local text generation. Supports GGUF models via llama.cpp and
|
||||||
llama.cpp, HuggingFace transformers, and token-classification models (classifier
|
HuggingFace transformers. Sits alongside vllm/ollama for products that need
|
||||||
backend) for PII detection and redaction.
|
fast, frequent inference from small local models (3B–7B Q4).
|
||||||
|
|
||||||
Endpoints:
|
Endpoints:
|
||||||
GET /health → {"status": "ok", "model": str, "vram_mb": int, "backend": str}
|
GET /health → {"status": "ok", "model": str, "vram_mb": int, "backend": str}
|
||||||
POST /generate → GenerateResponse (text-gen backends only)
|
POST /generate → GenerateResponse
|
||||||
POST /chat → GenerateResponse (text-gen backends only)
|
POST /chat → GenerateResponse
|
||||||
POST /filter → FilterResponse (classifier backend only)
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python -m circuitforge_core.text.app \
|
python -m circuitforge_core.text.app \
|
||||||
|
|
@ -35,46 +34,17 @@ import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Annotated, Literal, Union
|
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage
|
from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage
|
||||||
from circuitforge_core.text.backends.base import make_classifier_backend, make_text_backend
|
from circuitforge_core.text.backends.base import make_text_backend
|
||||||
from circuitforge_core.text.filter import FilterResult, PIIFilter
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_backend = None
|
_backend = None
|
||||||
_pii_filter: PIIFilter | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Content block types (OpenAI multimodal format) ────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ContentBlockText(BaseModel):
|
|
||||||
type: Literal["text"]
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class ContentBlockImageURL(BaseModel):
|
|
||||||
type: Literal["image_url"]
|
|
||||||
image_url: dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
ContentBlock = Annotated[
|
|
||||||
Union[ContentBlockText, ContentBlockImageURL],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _to_backend_message(role: str, content: "str | list[ContentBlock]") -> "BackendChatMessage":
|
|
||||||
"""Convert an API message to a BackendChatMessage with raw content dicts."""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return BackendChatMessage(role, content)
|
|
||||||
return BackendChatMessage(role, [b.model_dump() for b in content])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request / response models ─────────────────────────────────────────────────
|
# ── Request / response models ─────────────────────────────────────────────────
|
||||||
|
|
@ -89,7 +59,7 @@ class GenerateRequest(BaseModel):
|
||||||
|
|
||||||
class ChatMessageModel(BaseModel):
|
class ChatMessageModel(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: Union[str, list[ContentBlock]] = ""
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
|
|
@ -104,31 +74,12 @@ class GenerateResponse(BaseModel):
|
||||||
model: str = ""
|
model: str = ""
|
||||||
|
|
||||||
|
|
||||||
class FilterRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class PIISpanResponse(BaseModel):
|
|
||||||
label: str
|
|
||||||
start: int
|
|
||||||
end: int
|
|
||||||
text: str
|
|
||||||
score: float
|
|
||||||
|
|
||||||
|
|
||||||
class FilterResponse(BaseModel):
|
|
||||||
redacted_text: str
|
|
||||||
spans: list[PIISpanResponse]
|
|
||||||
original_text: str
|
|
||||||
model: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
# ── OpenAI-compat request / response (for LLMRouter openai_compat path) ──────
|
# ── OpenAI-compat request / response (for LLMRouter openai_compat path) ──────
|
||||||
|
|
||||||
|
|
||||||
class OAIMessageModel(BaseModel):
|
class OAIMessageModel(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: Union[str, list[ContentBlock]] = ""
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class OAIChatRequest(BaseModel):
|
class OAIChatRequest(BaseModel):
|
||||||
|
|
@ -169,7 +120,6 @@ def create_app(
|
||||||
gpu_ids: str | None = None,
|
gpu_ids: str | None = None,
|
||||||
backend: str | None = None,
|
backend: str | None = None,
|
||||||
mock: bool = False,
|
mock: bool = False,
|
||||||
mmproj_path: str = "",
|
|
||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Start the cf-text FastAPI app.
|
"""Start the cf-text FastAPI app.
|
||||||
|
|
||||||
|
|
@ -177,12 +127,8 @@ def create_app(
|
||||||
(e.g. "0,1"). When set, overrides ``gpu_id`` and sets
|
(e.g. "0,1"). When set, overrides ``gpu_id`` and sets
|
||||||
``CUDA_VISIBLE_DEVICES`` to the full list so HuggingFace Accelerate's
|
``CUDA_VISIBLE_DEVICES`` to the full list so HuggingFace Accelerate's
|
||||||
``device_map="auto"`` can shard the model across all listed devices.
|
``device_map="auto"`` can shard the model across all listed devices.
|
||||||
|
|
||||||
When ``backend="classifier"``, the service skips the text-gen backends
|
|
||||||
and loads a token-classification pipeline instead. Only ``POST /filter``
|
|
||||||
is available in that mode; ``/generate`` and ``/chat`` return 501.
|
|
||||||
"""
|
"""
|
||||||
global _backend, _pii_filter
|
global _backend
|
||||||
|
|
||||||
if not mock and not model_path:
|
if not mock and not model_path:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -193,26 +139,13 @@ def create_app(
|
||||||
visible = gpu_ids if gpu_ids else str(gpu_id)
|
visible = gpu_ids if gpu_ids else str(gpu_id)
|
||||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible)
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible)
|
||||||
|
|
||||||
resolved_backend = backend or os.environ.get("CF_TEXT_BACKEND", "")
|
_backend = make_text_backend(model_path, backend=backend, mock=mock)
|
||||||
if resolved_backend == "classifier" or (not resolved_backend and False):
|
logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)
|
||||||
classifier_backend = make_classifier_backend(model_path)
|
|
||||||
_pii_filter = PIIFilter.from_backend(classifier_backend)
|
|
||||||
logger.info(
|
|
||||||
"cf-text (classifier) ready: model=%r vram=%dMB",
|
|
||||||
classifier_backend.model_name,
|
|
||||||
classifier_backend.vram_mb,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_backend = make_text_backend(model_path, backend=backend, mock=mock, mmproj_path=mmproj_path)
|
|
||||||
logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)
|
|
||||||
|
|
||||||
app = FastAPI(title="cf-text", version="0.1.0")
|
app = FastAPI(title="cf-text", version="0.1.0")
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health() -> dict:
|
def health() -> dict:
|
||||||
if _pii_filter is not None:
|
|
||||||
b = _pii_filter._backend
|
|
||||||
return {"status": "ok", "model": b.model_name, "vram_mb": b.vram_mb, "backend": "classifier"}
|
|
||||||
if _backend is None:
|
if _backend is None:
|
||||||
raise HTTPException(503, detail="backend not initialised")
|
raise HTTPException(503, detail="backend not initialised")
|
||||||
return {
|
return {
|
||||||
|
|
@ -221,35 +154,8 @@ def create_app(
|
||||||
"vram_mb": _backend.vram_mb,
|
"vram_mb": _backend.vram_mb,
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.post("/filter")
|
|
||||||
async def filter_text(req: FilterRequest) -> FilterResponse:
|
|
||||||
if _pii_filter is None:
|
|
||||||
raise HTTPException(
|
|
||||||
501,
|
|
||||||
detail="This cf-text instance is not running a classifier backend. "
|
|
||||||
"Start with --backend classifier and a token-classification model.",
|
|
||||||
)
|
|
||||||
result = await _pii_filter.filter_async(req.text)
|
|
||||||
return FilterResponse(
|
|
||||||
redacted_text=result.redacted_text,
|
|
||||||
spans=[
|
|
||||||
PIISpanResponse(
|
|
||||||
label=s.label,
|
|
||||||
start=s.start,
|
|
||||||
end=s.end,
|
|
||||||
text=s.text,
|
|
||||||
score=s.score,
|
|
||||||
)
|
|
||||||
for s in result.spans
|
|
||||||
],
|
|
||||||
original_text=result.original_text,
|
|
||||||
model=_pii_filter._backend.model_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
async def generate(req: GenerateRequest) -> GenerateResponse:
|
async def generate(req: GenerateRequest) -> GenerateResponse:
|
||||||
if _pii_filter is not None:
|
|
||||||
raise HTTPException(501, detail="classifier backend loaded — use POST /filter")
|
|
||||||
if _backend is None:
|
if _backend is None:
|
||||||
raise HTTPException(503, detail="backend not initialised")
|
raise HTTPException(503, detail="backend not initialised")
|
||||||
result = await _backend.generate_async(
|
result = await _backend.generate_async(
|
||||||
|
|
@ -266,20 +172,16 @@ def create_app(
|
||||||
|
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
async def chat(req: ChatRequest) -> GenerateResponse:
|
async def chat(req: ChatRequest) -> GenerateResponse:
|
||||||
if _pii_filter is not None:
|
|
||||||
raise HTTPException(501, detail="classifier backend loaded — use POST /filter")
|
|
||||||
if _backend is None:
|
if _backend is None:
|
||||||
raise HTTPException(503, detail="backend not initialised")
|
raise HTTPException(503, detail="backend not initialised")
|
||||||
messages = [_to_backend_message(m.role, m.content) for m in req.messages]
|
messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
|
||||||
|
# chat() is sync-only in the Protocol; run in thread pool to avoid blocking
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
result = await loop.run_in_executor(
|
||||||
result = await loop.run_in_executor(
|
None,
|
||||||
None,
|
partial(_backend.chat, messages,
|
||||||
partial(_backend.chat, messages,
|
max_tokens=req.max_tokens, temperature=req.temperature),
|
||||||
max_tokens=req.max_tokens, temperature=req.temperature),
|
)
|
||||||
)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise HTTPException(422, detail=str(exc)) from exc
|
|
||||||
return GenerateResponse(
|
return GenerateResponse(
|
||||||
text=result.text,
|
text=result.text,
|
||||||
tokens_used=result.tokens_used,
|
tokens_used=result.tokens_used,
|
||||||
|
|
@ -296,16 +198,13 @@ def create_app(
|
||||||
"""
|
"""
|
||||||
if _backend is None:
|
if _backend is None:
|
||||||
raise HTTPException(503, detail="backend not initialised")
|
raise HTTPException(503, detail="backend not initialised")
|
||||||
messages = [_to_backend_message(m.role, m.content) for m in req.messages]
|
messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
|
||||||
max_tok = req.max_tokens or 512
|
max_tok = req.max_tokens or 512
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
result = await loop.run_in_executor(
|
||||||
result = await loop.run_in_executor(
|
None,
|
||||||
None,
|
partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature),
|
||||||
partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature),
|
)
|
||||||
)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise HTTPException(422, detail=str(exc)) from exc
|
|
||||||
return OAIChatResponse(
|
return OAIChatResponse(
|
||||||
id=f"cftext-{uuid.uuid4().hex[:12]}",
|
id=f"cftext-{uuid.uuid4().hex[:12]}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
|
|
@ -331,16 +230,7 @@ def _parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--gpu-ids", default=None,
|
parser.add_argument("--gpu-ids", default=None,
|
||||||
help="Comma-separated CUDA device indices for multi-GPU spanning "
|
help="Comma-separated CUDA device indices for multi-GPU spanning "
|
||||||
"(e.g. '0,1'). Overrides --gpu-id when set.")
|
"(e.g. '0,1'). Overrides --gpu-id when set.")
|
||||||
parser.add_argument(
|
parser.add_argument("--backend", choices=["llamacpp", "transformers"], default=None)
|
||||||
"--backend",
|
|
||||||
choices=["llamacpp", "transformers", "ollama", "vllm", "classifier"],
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mmproj", default="",
|
|
||||||
help="Path to multimodal projector file for VLM GGUF models (LLaVA-style). "
|
|
||||||
"Qwen2-VL and other self-contained VLMs don't need this.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--mock", action="store_true",
|
parser.add_argument("--mock", action="store_true",
|
||||||
help="Run in mock mode (no model or GPU needed)")
|
help="Run in mock mode (no model or GPU needed)")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
@ -357,6 +247,5 @@ if __name__ == "__main__":
|
||||||
gpu_ids=args.gpu_ids,
|
gpu_ids=args.gpu_ids,
|
||||||
backend=args.backend,
|
backend=args.backend,
|
||||||
mock=mock,
|
mock=mock,
|
||||||
mmproj_path=args.mmproj,
|
|
||||||
)
|
)
|
||||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||||
|
|
|
||||||
|
|
@ -24,44 +24,17 @@ class GenerateResult:
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage:
|
class ChatMessage:
|
||||||
"""A single message in a chat conversation.
|
"""A single message in a chat conversation."""
|
||||||
|
|
||||||
``content`` is either a plain string or a list of OpenAI-format content
|
def __init__(self, role: str, content: str) -> None:
|
||||||
blocks (dicts with ``type: "text"`` or ``type: "image_url"``). Backends
|
|
||||||
that do not support images should call ``text_only`` to get the string
|
|
||||||
form before passing to the model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, role: str, content: "str | list") -> None:
|
|
||||||
if role not in ("system", "user", "assistant"):
|
if role not in ("system", "user", "assistant"):
|
||||||
raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.")
|
raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.")
|
||||||
self.role = role
|
self.role = role
|
||||||
self.content: "str | list" = content
|
self.content = content
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {"role": self.role, "content": self.content}
|
return {"role": self.role, "content": self.content}
|
||||||
|
|
||||||
@property
|
|
||||||
def has_images(self) -> bool:
|
|
||||||
"""True when at least one content block is an image_url block."""
|
|
||||||
if isinstance(self.content, str):
|
|
||||||
return False
|
|
||||||
return any(
|
|
||||||
isinstance(b, dict) and b.get("type") == "image_url"
|
|
||||||
for b in self.content
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def text_only(self) -> str:
|
|
||||||
"""Flatten multimodal content to text. Returns content as-is if already str."""
|
|
||||||
if isinstance(self.content, str):
|
|
||||||
return self.content
|
|
||||||
return "\n".join(
|
|
||||||
b["text"]
|
|
||||||
for b in self.content
|
|
||||||
if isinstance(b, dict) and b.get("type") == "text"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── TextBackend Protocol ──────────────────────────────────────────────────────
|
# ── TextBackend Protocol ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -143,33 +116,6 @@ class TextBackend(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
# ── FilterBackend Protocol ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class FilterBackend(Protocol):
|
|
||||||
"""
|
|
||||||
Abstract interface for token-classification / PII-filter backends.
|
|
||||||
|
|
||||||
Separate from TextBackend — returns entity spans and redacted text,
|
|
||||||
not generated text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def classify(self, text: str) -> list[dict]:
|
|
||||||
"""Synchronous classify — returns list of entity span dicts."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def classify_async(self, text: str) -> list[dict]:
|
|
||||||
"""Async classify — runs in thread pool."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_name(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def vram_mb(self) -> int: ...
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backend selection ─────────────────────────────────────────────────────────
|
# ── Backend selection ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -187,7 +133,7 @@ def _select_backend(model_path: str, backend: str | None) -> str:
|
||||||
|
|
||||||
Raise ValueError for unrecognised override values.
|
Raise ValueError for unrecognised override values.
|
||||||
"""
|
"""
|
||||||
_VALID = ("llamacpp", "transformers", "ollama", "vllm", "classifier")
|
_VALID = ("llamacpp", "transformers", "ollama", "vllm")
|
||||||
|
|
||||||
# 1. Caller-supplied override — highest trust, no inspection needed.
|
# 1. Caller-supplied override — highest trust, no inspection needed.
|
||||||
resolved = backend or os.environ.get("CF_TEXT_BACKEND")
|
resolved = backend or os.environ.get("CF_TEXT_BACKEND")
|
||||||
|
|
@ -207,11 +153,6 @@ def _select_backend(model_path: str, backend: str | None) -> str:
|
||||||
# 3. Format detection — GGUF files are unambiguously llama-cpp territory.
|
# 3. Format detection — GGUF files are unambiguously llama-cpp territory.
|
||||||
if model_path.lower().endswith(".gguf"):
|
if model_path.lower().endswith(".gguf"):
|
||||||
return "llamacpp"
|
return "llamacpp"
|
||||||
# 3b. GGUF directory — avocet downloads whole repos; scan for .gguf contents.
|
|
||||||
if os.path.isdir(model_path):
|
|
||||||
import glob as _glob
|
|
||||||
if _glob.glob(os.path.join(model_path, "*.gguf")) or _glob.glob(os.path.join(model_path, "*.GGUF")):
|
|
||||||
return "llamacpp"
|
|
||||||
|
|
||||||
# 4. Safe default — transformers covers HF repo IDs and safetensors dirs.
|
# 4. Safe default — transformers covers HF repo IDs and safetensors dirs.
|
||||||
return "transformers"
|
return "transformers"
|
||||||
|
|
@ -224,7 +165,6 @@ def make_text_backend(
|
||||||
model_path: str,
|
model_path: str,
|
||||||
backend: str | None = None,
|
backend: str | None = None,
|
||||||
mock: bool | None = None,
|
mock: bool | None = None,
|
||||||
mmproj_path: str = "",
|
|
||||||
) -> "TextBackend":
|
) -> "TextBackend":
|
||||||
"""
|
"""
|
||||||
Return a TextBackend for the given model.
|
Return a TextBackend for the given model.
|
||||||
|
|
@ -241,7 +181,7 @@ def make_text_backend(
|
||||||
|
|
||||||
if resolved == "llamacpp":
|
if resolved == "llamacpp":
|
||||||
from circuitforge_core.text.backends.llamacpp import LlamaCppBackend
|
from circuitforge_core.text.backends.llamacpp import LlamaCppBackend
|
||||||
return LlamaCppBackend(model_path=model_path, mmproj_path=mmproj_path)
|
return LlamaCppBackend(model_path=model_path)
|
||||||
|
|
||||||
if resolved == "transformers":
|
if resolved == "transformers":
|
||||||
from circuitforge_core.text.backends.transformers import TransformersBackend
|
from circuitforge_core.text.backends.transformers import TransformersBackend
|
||||||
|
|
@ -255,22 +195,4 @@ def make_text_backend(
|
||||||
from circuitforge_core.text.backends.vllm import VllmBackend
|
from circuitforge_core.text.backends.vllm import VllmBackend
|
||||||
return VllmBackend(model_path=model_path)
|
return VllmBackend(model_path=model_path)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(f"Unknown backend {resolved!r}. Expected 'llamacpp', 'transformers', 'ollama', or 'vllm'.")
|
||||||
f"Unknown backend {resolved!r}. "
|
|
||||||
"Expected 'llamacpp', 'transformers', 'ollama', 'vllm', or 'classifier'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_classifier_backend(model_path: str) -> "FilterBackend":
|
|
||||||
"""
|
|
||||||
Return a FilterBackend for the given token-classification model.
|
|
||||||
|
|
||||||
CF_TEXT_MOCK=1 → MockClassifierBackend (no GPU, no model file needed)
|
|
||||||
Otherwise → ClassifierBackend via transformers pipeline
|
|
||||||
"""
|
|
||||||
if os.environ.get("CF_TEXT_MOCK", "") == "1":
|
|
||||||
from circuitforge_core.text.backends.mock import MockClassifierBackend
|
|
||||||
return MockClassifierBackend(model_name=model_path)
|
|
||||||
|
|
||||||
from circuitforge_core.text.backends.classifier import ClassifierBackend
|
|
||||||
return ClassifierBackend(model_path=model_path)
|
|
||||||
|
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
||||||
# circuitforge_core/text/backends/classifier.py — HuggingFace token-classification backend
|
|
||||||
#
|
|
||||||
# BSL 1.1. Requires torch + transformers.
|
|
||||||
# Install: pip install circuitforge-core[text-transformers]
|
|
||||||
#
|
|
||||||
# Wraps pipeline("token-classification") for PII/entity detection.
|
|
||||||
# Returns spans with char offsets, entity labels, and confidence scores.
|
|
||||||
# Use make_classifier_backend() from base.py to instantiate.
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierBackend:
|
|
||||||
"""
|
|
||||||
HuggingFace token-classification backend for PII detection and entity labeling.
|
|
||||||
|
|
||||||
Loads any token-classification model from HuggingFace Hub or a local checkpoint.
|
|
||||||
Returns aggregated entity spans with char offsets — suitable for redaction or audit.
|
|
||||||
|
|
||||||
Aggregation strategy "simple" merges consecutive BIO-tagged subwords into word-level
|
|
||||||
spans and strips the B-/I- prefixes so callers see "NAME" not "B-NAME".
|
|
||||||
|
|
||||||
Requires: pip install circuitforge-core[text-transformers]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_path: str) -> None:
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
from transformers import pipeline as hf_pipeline
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"torch and transformers are required for ClassifierBackend. "
|
|
||||||
"Install with: pip install circuitforge-core[text-transformers]"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
device = 0 if torch.cuda.is_available() else -1
|
|
||||||
cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
|
|
||||||
if cuda_devices:
|
|
||||||
device = 0
|
|
||||||
|
|
||||||
logger.info("Loading classifier model %s on device %s", model_path, device)
|
|
||||||
|
|
||||||
self._pipeline = hf_pipeline(
|
|
||||||
"token-classification",
|
|
||||||
model=model_path,
|
|
||||||
aggregation_strategy="simple",
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self._model_path = model_path
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_name(self) -> str:
|
|
||||||
return self._model_path.split("/")[-1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def vram_mb(self) -> int:
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return torch.cuda.memory_allocated() // (1024 * 1024)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def classify(self, text: str) -> list[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Run token classification synchronously.
|
|
||||||
|
|
||||||
Returns a list of entity dicts with keys:
|
|
||||||
entity_group: str — label without BIO prefix (e.g. "NAME", "EMAIL")
|
|
||||||
score: float — aggregated confidence
|
|
||||||
word: str — matched text span
|
|
||||||
start: int — char offset (start, inclusive)
|
|
||||||
end: int — char offset (end, exclusive)
|
|
||||||
"""
|
|
||||||
results: list[dict[str, Any]] = self._pipeline(text)
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def classify_async(self, text: str) -> list[dict[str, Any]]:
|
|
||||||
"""Async classify — runs pipeline in thread pool to avoid blocking the event loop."""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(None, self.classify, text)
|
|
||||||
|
|
@ -48,16 +48,7 @@ class LlamaCppBackend:
|
||||||
Requires: pip install circuitforge-core[text-llamacpp]
|
Requires: pip install circuitforge-core[text-llamacpp]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_path: str, mmproj_path: str = "", chat_format: str = "") -> None:
|
def __init__(self, model_path: str) -> None:
|
||||||
"""Load a GGUF model.
|
|
||||||
|
|
||||||
``mmproj_path``: path to a separate multimodal projector file (needed
|
|
||||||
for LLaVA-style VLMs where the visual encoder is a separate .gguf).
|
|
||||||
Qwen2-VL and similar models with an embedded projector don't need this.
|
|
||||||
|
|
||||||
``chat_format``: llama-cpp chat template override (e.g. "llava-1-5",
|
|
||||||
"moondream"). Required when mmproj_path is set.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
from llama_cpp import Llama # type: ignore[import]
|
from llama_cpp import Llama # type: ignore[import]
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
|
|
@ -72,53 +63,20 @@ class LlamaCppBackend:
|
||||||
"Download a GGUF model and set CF_TEXT_MODEL to its path."
|
"Download a GGUF model and set CF_TEXT_MODEL to its path."
|
||||||
)
|
)
|
||||||
|
|
||||||
# If given a directory, find the .gguf file inside it.
|
|
||||||
if Path(model_path).is_dir():
|
|
||||||
candidates = sorted(Path(model_path).glob("*.gguf")) or sorted(Path(model_path).glob("*.GGUF"))
|
|
||||||
if not candidates:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"No .gguf file found in directory: {model_path}"
|
|
||||||
)
|
|
||||||
model_path = str(candidates[0])
|
|
||||||
|
|
||||||
n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None
|
n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None
|
||||||
|
logger.info(
|
||||||
kwargs: dict = dict(
|
"Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
|
||||||
|
model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
|
||||||
|
)
|
||||||
|
self._llm = Llama(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
n_ctx=_DEFAULT_N_CTX,
|
n_ctx=_DEFAULT_N_CTX,
|
||||||
n_gpu_layers=_DEFAULT_N_GPU_LAYERS,
|
n_gpu_layers=_DEFAULT_N_GPU_LAYERS,
|
||||||
n_threads=n_threads,
|
n_threads=n_threads,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
if mmproj_path:
|
|
||||||
kwargs["clip_model_path"] = mmproj_path
|
|
||||||
kwargs["chat_format"] = chat_format or "llava-1-5"
|
|
||||||
logger.info(
|
|
||||||
"Loading VLM %s with mmproj %s (ctx=%d, gpu_layers=%d)",
|
|
||||||
model_path, mmproj_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
|
|
||||||
model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._llm = Llama(**kwargs)
|
|
||||||
self._model_path = model_path
|
self._model_path = model_path
|
||||||
self._vram_mb = _estimate_vram_mb(model_path)
|
self._vram_mb = _estimate_vram_mb(model_path)
|
||||||
# True when the model was initialised with a visual encoder (explicit
|
|
||||||
# mmproj) or when it is a known self-contained VLM (Qwen2-VL, etc.).
|
|
||||||
self._is_vlm = bool(mmproj_path) or self._detect_embedded_vlm()
|
|
||||||
|
|
||||||
def _detect_embedded_vlm(self) -> bool:
|
|
||||||
"""Heuristic: check model metadata for a known multimodal architecture."""
|
|
||||||
try:
|
|
||||||
meta = self._llm.metadata or {}
|
|
||||||
arch = str(meta.get("general.architecture", "")).lower()
|
|
||||||
# Qwen2-VL and similar embed the vision encoder inside the GGUF.
|
|
||||||
return any(tag in arch for tag in ("qwen2_vl", "llava", "moondream", "minicpm-v"))
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self) -> str:
|
def model_name(self) -> str:
|
||||||
|
|
@ -223,14 +181,7 @@ class LlamaCppBackend:
|
||||||
max_tokens: int = 512,
|
max_tokens: int = 512,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
) -> GenerateResult:
|
) -> GenerateResult:
|
||||||
# Detect image content before calling the model.
|
# llama-cpp-python has native chat_completion for instruct models
|
||||||
if any(m.has_images for m in messages) and not self._is_vlm:
|
|
||||||
raise ValueError(
|
|
||||||
"model does not support image input — "
|
|
||||||
"load a VLM (with mmproj_path) or route to cf-vision/cf-docuvision"
|
|
||||||
)
|
|
||||||
# llama-cpp-python create_chat_completion accepts content as str or
|
|
||||||
# list-of-blocks (OpenAI multimodal format) natively.
|
|
||||||
output = self._llm.create_chat_completion(
|
output = self._llm.create_chat_completion(
|
||||||
messages=[m.to_dict() for m in messages],
|
messages=[m.to_dict() for m in messages],
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
|
|
||||||
|
|
@ -102,49 +102,3 @@ class MockTextBackend:
|
||||||
# Format messages into a simple prompt for the mock response
|
# Format messages into a simple prompt for the mock response
|
||||||
prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
|
prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
|
||||||
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
# Synthetic PII spans injected by MockClassifierBackend — predictable in tests.
|
|
||||||
_MOCK_SPANS = [
|
|
||||||
{
|
|
||||||
"entity_group": "NAME",
|
|
||||||
"score": 0.99,
|
|
||||||
"word": "Jane Doe",
|
|
||||||
"start": 0,
|
|
||||||
"end": 8,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"entity_group": "EMAIL",
|
|
||||||
"score": 0.97,
|
|
||||||
"word": "jane@example.com",
|
|
||||||
"start": 18,
|
|
||||||
"end": 34,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class MockClassifierBackend:
|
|
||||||
"""
|
|
||||||
Deterministic mock classifier backend for development and CI.
|
|
||||||
|
|
||||||
Always returns the same two synthetic PII spans regardless of input.
|
|
||||||
Allows filter.py logic (redaction, span conversion) to be tested without
|
|
||||||
a real model or GPU.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "mock-classifier") -> None:
|
|
||||||
self._model_name = model_name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_name(self) -> str:
|
|
||||||
return self._model_name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def vram_mb(self) -> int:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def classify(self, text: str) -> list[dict]:
|
|
||||||
return list(_MOCK_SPANS)
|
|
||||||
|
|
||||||
async def classify_async(self, text: str) -> list[dict]:
|
|
||||||
return self.classify(text)
|
|
||||||
|
|
|
||||||
|
|
@ -50,12 +50,10 @@ class TransformersBackend:
|
||||||
logger.info("Loading transformers model %s on %s", model_path, self._device)
|
logger.info("Loading transformers model %s on %s", model_path, self._device)
|
||||||
|
|
||||||
load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
|
load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
|
||||||
if _LOAD_IN_4BIT or _LOAD_IN_8BIT:
|
if _LOAD_IN_4BIT:
|
||||||
from transformers import BitsAndBytesConfig
|
load_kwargs["load_in_4bit"] = True
|
||||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
elif _LOAD_IN_8BIT:
|
||||||
load_in_4bit=_LOAD_IN_4BIT,
|
load_kwargs["load_in_8bit"] = True
|
||||||
load_in_8bit=_LOAD_IN_8BIT,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
|
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
|
self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
|
||||||
|
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
# circuitforge_core/text/filter.py — PII detection and redaction
|
|
||||||
#
|
|
||||||
# BSL 1.1. Products import PIIFilter for pre-send redaction and audit trails.
|
|
||||||
# Requires a running cf-filter service (or ClassifierBackend for in-process use).
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from circuitforge_core.text.backends.base import FilterBackend, make_classifier_backend
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class PIISpan:
|
|
||||||
"""A single detected PII entity in the source text."""
|
|
||||||
|
|
||||||
label: str # e.g. NAME | EMAIL | PHONE_NUM | ADDRESS | SSN | DOB | IP_ADDRESS
|
|
||||||
start: int # char offset (inclusive) in original_text
|
|
||||||
end: int # char offset (exclusive) in original_text
|
|
||||||
text: str # original span text
|
|
||||||
score: float # confidence score from the classifier
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class FilterResult:
|
|
||||||
"""Output of PIIFilter.filter().
|
|
||||||
|
|
||||||
``redacted_text``: safe-to-send copy with each span replaced by ``[LABEL]``.
|
|
||||||
``spans``: all detected entities — for audit logs or caller-side decisions.
|
|
||||||
``original_text``: the input text (stored for round-trip comparisons).
|
|
||||||
"""
|
|
||||||
|
|
||||||
redacted_text: str
|
|
||||||
spans: list[PIISpan] = field(default_factory=list)
|
|
||||||
original_text: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
def _redact(text: str, spans: list[PIISpan]) -> str:
|
|
||||||
"""Replace each span in text with ``[LABEL]``, processing right-to-left so
|
|
||||||
earlier offsets remain valid after each substitution."""
|
|
||||||
result = text
|
|
||||||
for span in sorted(spans, key=lambda s: s.start, reverse=True):
|
|
||||||
result = result[: span.start] + f"[{span.label}]" + result[span.end :]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _spans_from_pipeline(raw: list[dict[str, Any]]) -> list[PIISpan]:
|
|
||||||
"""Convert raw pipeline output dicts into typed PIISpan objects.
|
|
||||||
|
|
||||||
Pipeline returns dicts with keys: entity_group, score, word, start, end.
|
|
||||||
Normalise label to uppercase and strip any residual BIO prefixes.
|
|
||||||
"""
|
|
||||||
spans: list[PIISpan] = []
|
|
||||||
for item in raw:
|
|
||||||
label = re.sub(r"^[BI]-", "", item.get("entity_group", "")).upper()
|
|
||||||
spans.append(
|
|
||||||
PIISpan(
|
|
||||||
label=label,
|
|
||||||
start=int(item["start"]),
|
|
||||||
end=int(item["end"]),
|
|
||||||
text=item.get("word", ""),
|
|
||||||
score=float(item.get("score", 0.0)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return spans
|
|
||||||
|
|
||||||
|
|
||||||
class PIIFilter:
|
|
||||||
"""
|
|
||||||
High-level PII filter backed by a token-classification model.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
pii_filter = PIIFilter.from_model("openai/privacy-filter")
|
|
||||||
result = await pii_filter.filter_async(resume_text)
|
|
||||||
safe_text = result.redacted_text # send to cloud LLM
|
|
||||||
spans = result.spans # store for audit trail
|
|
||||||
|
|
||||||
For in-process use (no cf-orch), pass a model path and it loads directly.
|
|
||||||
For service-backed use, see PIIFilter.from_backend().
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, backend: FilterBackend) -> None:
|
|
||||||
self._backend = backend
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_model(cls, model_path: str) -> "PIIFilter":
|
|
||||||
"""Load a classifier model in-process (no cf-orch required)."""
|
|
||||||
return cls(make_classifier_backend(model_path))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_backend(cls, backend: FilterBackend) -> "PIIFilter":
|
|
||||||
"""Wrap an already-constructed FilterBackend."""
|
|
||||||
return cls(backend)
|
|
||||||
|
|
||||||
def filter(self, text: str) -> FilterResult:
|
|
||||||
"""Synchronous filter — blocks until classification is complete."""
|
|
||||||
raw = self._backend.classify(text)
|
|
||||||
spans = _spans_from_pipeline(raw)
|
|
||||||
return FilterResult(
|
|
||||||
redacted_text=_redact(text, spans),
|
|
||||||
spans=spans,
|
|
||||||
original_text=text,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def filter_async(self, text: str) -> FilterResult:
|
|
||||||
"""Async filter — runs classifier in thread pool."""
|
|
||||||
raw = await self._backend.classify_async(text)
|
|
||||||
spans = _spans_from_pipeline(raw)
|
|
||||||
return FilterResult(
|
|
||||||
redacted_text=_redact(text, spans),
|
|
||||||
spans=spans,
|
|
||||||
original_text=text,
|
|
||||||
)
|
|
||||||
|
|
@ -5,7 +5,7 @@ circuitforge-core is distributed as an editable install from a local clone. It i
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
- Python 3.11+
|
- Python 3.11+
|
||||||
- A Python environment — conda or venv (see options below)
|
- A conda environment (CircuitForge uses `cf` by convention; older envs may be named `job-seeker`)
|
||||||
- The `circuitforge-core` repo cloned alongside your product repo
|
- The `circuitforge-core` repo cloned alongside your product repo
|
||||||
|
|
||||||
## Typical layout
|
## Typical layout
|
||||||
|
|
@ -21,10 +21,6 @@ circuitforge-core is distributed as an editable install from a local clone. It i
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
### Option A: conda (dev machines)
|
|
||||||
|
|
||||||
The CircuitForge conda environment is named `cf`:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# From inside a product repo, assuming circuitforge-core is a sibling
|
# From inside a product repo, assuming circuitforge-core is a sibling
|
||||||
conda run -n cf pip install -e ../circuitforge-core
|
conda run -n cf pip install -e ../circuitforge-core
|
||||||
|
|
@ -34,29 +30,13 @@ conda activate cf
|
||||||
pip install -e ../circuitforge-core
|
pip install -e ../circuitforge-core
|
||||||
```
|
```
|
||||||
|
|
||||||
### Option B: venv (server and beta-host deployments)
|
|
||||||
|
|
||||||
For hosts that don't use conda (CI runners, beta VMs, Xander's orchard nodes):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -m venv .venv
|
|
||||||
source .venv/bin/activate
|
|
||||||
pip install -e /path/to/circuitforge-core
|
|
||||||
```
|
|
||||||
|
|
||||||
Or if cf-core is a sibling directory of the product:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -e ../circuitforge-core
|
|
||||||
```
|
|
||||||
|
|
||||||
The editable install means changes to circuitforge-core source are reflected immediately in all products without reinstalling. Only restart the product's process after changes (or Docker container if running in Docker).
|
The editable install means changes to circuitforge-core source are reflected immediately in all products without reinstalling. Only restart the product's process after changes (or Docker container if running in Docker).
|
||||||
|
|
||||||
## Verify
|
## Verify
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import circuitforge_core
|
import circuitforge_core
|
||||||
print(circuitforge_core.__version__) # e.g. 0.21.0
|
print(circuitforge_core.__version__) # 0.9.0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Inside Docker
|
## Inside Docker
|
||||||
|
|
|
||||||
|
|
@ -1,151 +0,0 @@
|
||||||
# circuitforge_core.memory
|
|
||||||
|
|
||||||
Persistent knowledge graph for CF products, backed by the
|
|
||||||
[mnemo](https://github.com/zaydmulani09/mnemo) sidecar.
|
|
||||||
|
|
||||||
## What it does
|
|
||||||
|
|
||||||
mnemo runs as a sidecar process alongside a product's FastAPI backend. It:
|
|
||||||
|
|
||||||
- Extracts named entities and relationships from text you feed it
|
|
||||||
- Persists them in a local SQLite database with WAL mode
|
|
||||||
- Returns a formatted context block for prompt injection in under 5ms
|
|
||||||
|
|
||||||
`cf_core.memory` wraps mnemo's Python SDK with CF-standard config,
|
|
||||||
graceful degradation (no-ops when the sidecar is absent), and
|
|
||||||
exponential backoff with automatic reconnect after transient failures.
|
|
||||||
|
|
||||||
## Install
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install circuitforge-core[memory]
|
|
||||||
```
|
|
||||||
|
|
||||||
## Docker Compose setup
|
|
||||||
|
|
||||||
Add the `mnemo` service to your product's `compose.yml` alongside `ollama`.
|
|
||||||
Peregrine is the reference implementation — copy the block from
|
|
||||||
`peregrine/compose.yml`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
services:
|
|
||||||
|
|
||||||
mnemo:
|
|
||||||
image: ghcr.io/zaydmulani09/mnemo:latest
|
|
||||||
ports:
|
|
||||||
- "${MNEMO_PORT:-8080}:8080"
|
|
||||||
volumes:
|
|
||||||
- mnemo-data:/data
|
|
||||||
environment:
|
|
||||||
- MNEMO_DB_PATH=/data/mnemo.db
|
|
||||||
- MNEMO_LLM_PROVIDER=${MNEMO_LLM_PROVIDER:-ollama}
|
|
||||||
- MNEMO_LLM_BASE_URL=${MNEMO_LLM_BASE_URL:-http://ollama:11434/v1}
|
|
||||||
- MNEMO_LLM_API_KEY=${MNEMO_LLM_API_KEY:-ollama}
|
|
||||||
- MNEMO_LLM_MODEL=${MNEMO_LLM_MODEL:-llama3.2:3b}
|
|
||||||
depends_on:
|
|
||||||
- ollama
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:8080/health"]
|
|
||||||
interval: 15s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 3
|
|
||||||
profiles: [memory]
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
mnemo-data:
|
|
||||||
```
|
|
||||||
|
|
||||||
Add these to the product's api service environment:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
environment:
|
|
||||||
- MNEMO_HOST=${MNEMO_HOST:-mnemo}
|
|
||||||
- MNEMO_PORT=${MNEMO_PORT:-8080}
|
|
||||||
```
|
|
||||||
|
|
||||||
Launch with:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose --profile memory --profile cpu up -d
|
|
||||||
# or alongside a GPU profile:
|
|
||||||
docker compose --profile memory --profile single-gpu up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## Environment variables
|
|
||||||
|
|
||||||
| Variable | Default | Description |
|
|
||||||
|---|---|---|
|
|
||||||
| `MNEMO_HOST` | `localhost` | Sidecar hostname (use `mnemo` in Docker) |
|
|
||||||
| `MNEMO_PORT` | `8080` | Sidecar port |
|
|
||||||
| `MNEMO_TIMEOUT` | `10.0` | HTTP timeout in seconds |
|
|
||||||
|
|
||||||
The sidecar itself is configured via `MNEMO_LLM_*` env vars (see compose block above).
|
|
||||||
|
|
||||||
## FastAPI integration
|
|
||||||
|
|
||||||
```python
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from circuitforge_core.memory import MemoryClient, MemoryConfig
|
|
||||||
|
|
||||||
memory = MemoryClient(MemoryConfig.from_env())
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
await memory.connect() # no-op + warning if sidecar absent
|
|
||||||
yield
|
|
||||||
await memory.close()
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
|
||||||
```
|
|
||||||
|
|
||||||
## API
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Store a text fragment (conversation turn, fact, user preference, etc.)
|
|
||||||
await memory.remember("User avoids shellfish and prefers dark mode", source="settings")
|
|
||||||
|
|
||||||
# Retrieve a prompt-ready context block
|
|
||||||
context = await memory.recall("What are this user's dietary restrictions?")
|
|
||||||
system_prompt = f"You are a helpful assistant.\n\n{context}"
|
|
||||||
|
|
||||||
# List extracted entities
|
|
||||||
entities = await memory.entities(limit=20)
|
|
||||||
|
|
||||||
# Stats snapshot
|
|
||||||
stats = await memory.stats() # MemoryStats | None
|
|
||||||
|
|
||||||
# Wipe everything (irreversible)
|
|
||||||
await memory.wipe()
|
|
||||||
```
|
|
||||||
|
|
||||||
All methods return empty values (`False`, `""`, `[]`, `None`) when the
|
|
||||||
sidecar is not available — no try/except needed in product code.
|
|
||||||
|
|
||||||
## Resilience model
|
|
||||||
|
|
||||||
| Event | Behaviour |
|
|
||||||
|---|---|
|
|
||||||
| Sidecar absent at startup | `connect()` logs once, enters no-op mode |
|
|
||||||
| First call failure | Warning logged, 5s backoff scheduled |
|
|
||||||
| Nth consecutive failure | Backoff doubles each time (5→10→20→40→60s cap) |
|
|
||||||
| After `_MAX_FAILURES` (3) | Client marked unavailable; all calls no-op |
|
|
||||||
| Cooldown elapses | Next call silently attempts reconnect |
|
|
||||||
| Successful call | Failure counter and retry timer reset |
|
|
||||||
| `strict=True` | `MemoryUnavailableError` raised instead of no-op |
|
|
||||||
|
|
||||||
## Chunking note
|
|
||||||
|
|
||||||
mnemo stores each `remember()` call as a single chunk — it does **not**
|
|
||||||
automatically split large texts. For best retrieval quality, chunk on the
|
|
||||||
caller side before ingesting:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Good: one turn per ingest call
|
|
||||||
for turn in conversation_turns:
|
|
||||||
await memory.remember(turn, source="chat", session_id=session_id)
|
|
||||||
|
|
||||||
# Avoid: one giant blob
|
|
||||||
await memory.remember(entire_conversation_as_one_string)
|
|
||||||
```
|
|
||||||
|
|
@ -14,9 +14,6 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
memory = [
|
|
||||||
"mnemo-sdk>=0.1.0",
|
|
||||||
]
|
|
||||||
community = [
|
community = [
|
||||||
"psycopg2>=2.9",
|
"psycopg2>=2.9",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,281 +0,0 @@
|
||||||
"""Tests for circuitforge_core.memory.
|
|
||||||
|
|
||||||
These tests mock the mnemo SDK so no live sidecar is required.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from types import ModuleType
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from circuitforge_core.memory import MemoryClient, MemoryConfig, MemoryUnavailableError
|
|
||||||
from circuitforge_core.memory.client import _MAX_FAILURES
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _make_mock_mnemo(health_ok: bool = True):
|
|
||||||
"""Return a (mock_module, mock_inner_client) pair."""
|
|
||||||
mock_health = MagicMock(
|
|
||||||
status="ok" if health_ok else "error",
|
|
||||||
provider_type="ollama",
|
|
||||||
provider_model="llama3",
|
|
||||||
)
|
|
||||||
mock_client = AsyncMock()
|
|
||||||
mock_client.health = AsyncMock(return_value=mock_health)
|
|
||||||
mock_client.ingest = AsyncMock(return_value=MagicMock(chunk_id="abc", entities_extracted=2))
|
|
||||||
mock_client.get_context = AsyncMock(return_value="Relevant context: user prefers dark mode")
|
|
||||||
mock_client.list_entities = AsyncMock(return_value=[])
|
|
||||||
mock_client.stats = AsyncMock(return_value=MagicMock(
|
|
||||||
entity_count=5, chunk_count=10, node_count=5, edge_count=3, uptime_seconds=120.0
|
|
||||||
))
|
|
||||||
mock_client.wipe = AsyncMock(return_value=None)
|
|
||||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
mock_module = ModuleType("mnemo")
|
|
||||||
mock_module.AsyncMnemoClient = MagicMock(return_value=mock_client)
|
|
||||||
return mock_module, mock_client
|
|
||||||
|
|
||||||
|
|
||||||
async def _connected(health_ok: bool = True):
|
|
||||||
"""Return a connected MemoryClient with mock inner client attached."""
|
|
||||||
mock_module, mock_inner = _make_mock_mnemo(health_ok=health_ok)
|
|
||||||
client = MemoryClient(MemoryConfig())
|
|
||||||
with patch.dict(sys.modules, {"mnemo": mock_module}):
|
|
||||||
await client.connect()
|
|
||||||
client._mock_inner = mock_inner
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
# ── Config ────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestMemoryConfig:
|
|
||||||
def test_defaults(self):
|
|
||||||
cfg = MemoryConfig()
|
|
||||||
assert cfg.host == "localhost"
|
|
||||||
assert cfg.port == 8080
|
|
||||||
assert cfg.base_url == "http://localhost:8080"
|
|
||||||
|
|
||||||
def test_from_env(self, monkeypatch):
|
|
||||||
monkeypatch.setenv("MNEMO_HOST", "mnemo-sidecar")
|
|
||||||
monkeypatch.setenv("MNEMO_PORT", "9090")
|
|
||||||
monkeypatch.setenv("MNEMO_TIMEOUT", "30.0")
|
|
||||||
cfg = MemoryConfig.from_env()
|
|
||||||
assert cfg.host == "mnemo-sidecar"
|
|
||||||
assert cfg.port == 9090
|
|
||||||
assert cfg.timeout == 30.0
|
|
||||||
|
|
||||||
def test_base_url(self):
|
|
||||||
cfg = MemoryConfig(host="10.1.10.5", port=8080)
|
|
||||||
assert cfg.base_url == "http://10.1.10.5:8080"
|
|
||||||
|
|
||||||
|
|
||||||
# ── connect() ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestConnect:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_connect_success(self):
|
|
||||||
client = await _connected(health_ok=True)
|
|
||||||
assert client.available is True
|
|
||||||
assert client.failure_count == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_connect_bad_health_status(self):
|
|
||||||
client = await _connected(health_ok=False)
|
|
||||||
assert client.available is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_connect_sidecar_unreachable(self):
|
|
||||||
mock_module, mock_client = _make_mock_mnemo()
|
|
||||||
mock_client.health.side_effect = ConnectionRefusedError("refused")
|
|
||||||
client = MemoryClient(MemoryConfig())
|
|
||||||
with patch.dict(sys.modules, {"mnemo": mock_module}):
|
|
||||||
await client.connect() # must not raise
|
|
||||||
assert client.available is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_connect_strict_raises(self):
|
|
||||||
mock_module, mock_client = _make_mock_mnemo()
|
|
||||||
mock_client.health.side_effect = ConnectionRefusedError("refused")
|
|
||||||
client = MemoryClient(MemoryConfig(), strict=True)
|
|
||||||
with patch.dict(sys.modules, {"mnemo": mock_module}):
|
|
||||||
with pytest.raises(MemoryUnavailableError):
|
|
||||||
await client.connect()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_connect_missing_sdk(self):
|
|
||||||
client = MemoryClient(MemoryConfig())
|
|
||||||
with patch.dict(sys.modules, {"mnemo": None}):
|
|
||||||
await client.connect()
|
|
||||||
assert client.available is False
|
|
||||||
|
|
||||||
|
|
||||||
# ── No-op when unavailable ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestNoopWhenUnavailable:
|
|
||||||
@pytest.fixture
|
|
||||||
def unavailable(self):
|
|
||||||
return MemoryClient(MemoryConfig())
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_remember_noop(self, unavailable):
|
|
||||||
assert await unavailable.remember("text") is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_recall_noop(self, unavailable):
|
|
||||||
assert await unavailable.recall("query") == ""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_entities_noop(self, unavailable):
|
|
||||||
assert await unavailable.entities() == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_stats_noop(self, unavailable):
|
|
||||||
assert await unavailable.stats() is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wipe_noop(self, unavailable):
|
|
||||||
assert await unavailable.wipe() is False
|
|
||||||
|
|
||||||
|
|
||||||
# ── Live calls when connected ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestLiveCalls:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_remember_calls_ingest(self):
|
|
||||||
client = await _connected()
|
|
||||||
result = await client.remember("hello world", source="test")
|
|
||||||
assert result is True
|
|
||||||
client._mock_inner.ingest.assert_awaited_once_with(
|
|
||||||
content="hello world", source="test", session_id=None
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_remember_resets_failure_count(self):
|
|
||||||
client = await _connected()
|
|
||||||
client._failure_count = 2 # simulate prior failures
|
|
||||||
await client.remember("text")
|
|
||||||
assert client.failure_count == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_recall_returns_context(self):
|
|
||||||
client = await _connected()
|
|
||||||
ctx = await client.recall("dark mode preference")
|
|
||||||
assert "dark mode" in ctx
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_recall_with_session(self):
|
|
||||||
client = await _connected()
|
|
||||||
await client.recall("query", session_id="user-123")
|
|
||||||
client._mock_inner.get_context.assert_awaited_once_with(
|
|
||||||
text="query", session_id="user-123"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_stats_returns_memory_stats(self):
|
|
||||||
from circuitforge_core.memory import MemoryStats
|
|
||||||
client = await _connected()
|
|
||||||
result = await client.stats()
|
|
||||||
assert isinstance(result, MemoryStats)
|
|
||||||
assert result.available is True
|
|
||||||
assert result.entity_count == 5
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backoff and reconnect ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestBackoffAndReconnect:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_failure_count_increments(self):
|
|
||||||
client = await _connected()
|
|
||||||
client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
|
|
||||||
await client.remember("text")
|
|
||||||
assert client.failure_count == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_client_disabled_after_max_failures(self):
|
|
||||||
client = await _connected()
|
|
||||||
client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
|
|
||||||
# drive failures to the limit
|
|
||||||
for _ in range(_MAX_FAILURES):
|
|
||||||
await client.remember("text")
|
|
||||||
assert client.available is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_retry_at_set_after_failure(self):
|
|
||||||
client = await _connected()
|
|
||||||
client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
|
|
||||||
before = time.monotonic()
|
|
||||||
await client.remember("text")
|
|
||||||
assert client._retry_at is not None
|
|
||||||
assert client._retry_at > before
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_backoff_increases_with_failures(self):
|
|
||||||
client = await _connected()
|
|
||||||
client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
|
|
||||||
|
|
||||||
retry_times = []
|
|
||||||
t0 = time.monotonic()
|
|
||||||
for _ in range(3):
|
|
||||||
await client.remember("text")
|
|
||||||
retry_times.append(client._retry_at - t0)
|
|
||||||
|
|
||||||
# Each cooldown should be longer than the previous
|
|
||||||
assert retry_times[1] > retry_times[0]
|
|
||||||
assert retry_times[2] > retry_times[1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reconnect_attempted_after_cooldown(self):
|
|
||||||
"""Once the retry window elapses, the next call triggers a reconnect."""
|
|
||||||
client = await _connected()
|
|
||||||
# Force unavailable with an expired retry window
|
|
||||||
client._available = False
|
|
||||||
client._retry_at = time.monotonic() - 1.0 # already elapsed
|
|
||||||
|
|
||||||
mock_module, mock_inner = _make_mock_mnemo(health_ok=True)
|
|
||||||
with patch.dict(sys.modules, {"mnemo": mock_module}):
|
|
||||||
result = await client.remember("text after reconnect")
|
|
||||||
|
|
||||||
# Reconnect should have restored availability
|
|
||||||
assert client.available is True
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_reconnect_during_cooldown(self):
|
|
||||||
"""Within the cooldown window, calls no-op without attempting reconnect."""
|
|
||||||
client = await _connected()
|
|
||||||
client._available = False
|
|
||||||
client._retry_at = time.monotonic() + 999.0 # far in the future
|
|
||||||
|
|
||||||
mock_module, _ = _make_mock_mnemo(health_ok=True)
|
|
||||||
with patch.dict(sys.modules, {"mnemo": mock_module}):
|
|
||||||
result = await client.remember("text during cooldown")
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
assert client.available is False # no reconnect fired
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_success_resets_retry_state(self):
|
|
||||||
"""A successful call clears failure_count and retry_at."""
|
|
||||||
client = await _connected()
|
|
||||||
client._failure_count = 2
|
|
||||||
client._retry_at = time.monotonic() + 30.0
|
|
||||||
|
|
||||||
await client.remember("successful call")
|
|
||||||
|
|
||||||
assert client.failure_count == 0
|
|
||||||
assert client._retry_at is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_strict_raises_after_max_failures(self):
|
|
||||||
"""strict=True raises MemoryUnavailableError once failure threshold is hit."""
|
|
||||||
client = await _connected()
|
|
||||||
client._strict = True
|
|
||||||
client._mock_inner.ingest.side_effect = ConnectionResetError("reset")
|
|
||||||
|
|
||||||
with pytest.raises(MemoryUnavailableError):
|
|
||||||
for _ in range(_MAX_FAILURES):
|
|
||||||
await client.remember("text")
|
|
||||||
|
|
@ -1,78 +0,0 @@
|
||||||
"""Tests for MQTT topic wildcard matching in circuitforge_core.mqtt.router."""
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: matches() currently raises NotImplementedError — tests will fail
|
|
||||||
# until you implement it. Run these to verify correctness once implemented.
|
|
||||||
|
|
||||||
def _matches(pattern: str, topic: str) -> bool:
|
|
||||||
from circuitforge_core.mqtt.router import matches
|
|
||||||
return matches(pattern, topic)
|
|
||||||
|
|
||||||
|
|
||||||
class TestExactMatch:
|
|
||||||
def test_exact(self):
|
|
||||||
assert _matches("a/b/c", "a/b/c")
|
|
||||||
|
|
||||||
def test_no_match(self):
|
|
||||||
assert not _matches("a/b/c", "a/b/d")
|
|
||||||
|
|
||||||
def test_empty_topic(self):
|
|
||||||
assert _matches("", "")
|
|
||||||
|
|
||||||
|
|
||||||
class TestSingleLevelWildcard:
|
|
||||||
def test_plus_middle(self):
|
|
||||||
assert _matches("sensor/+/temp", "sensor/room1/temp")
|
|
||||||
|
|
||||||
def test_plus_no_match_extra_level(self):
|
|
||||||
assert not _matches("sensor/+/temp", "sensor/a/b/temp")
|
|
||||||
|
|
||||||
def test_plus_start(self):
|
|
||||||
assert _matches("+/b/c", "a/b/c")
|
|
||||||
|
|
||||||
def test_plus_end(self):
|
|
||||||
assert _matches("a/b/+", "a/b/anything")
|
|
||||||
|
|
||||||
def test_multiple_plus(self):
|
|
||||||
assert _matches("+/+/+", "x/y/z")
|
|
||||||
|
|
||||||
def test_plus_no_match_empty_segment(self):
|
|
||||||
# '+' must match exactly one level — a leading slash creates an empty segment
|
|
||||||
# This edge case depends on the implementation; just check consistent behavior.
|
|
||||||
result = _matches("+", "a/b")
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestMultiLevelWildcard:
|
|
||||||
def test_hash_root(self):
|
|
||||||
assert _matches("#", "a/b/c")
|
|
||||||
|
|
||||||
def test_hash_prefix(self):
|
|
||||||
assert _matches("sensor/#", "sensor/room1/temp")
|
|
||||||
|
|
||||||
def test_hash_zero_levels(self):
|
|
||||||
# '#' matches zero or more levels — "sensor/#" should match "sensor"
|
|
||||||
assert _matches("sensor/#", "sensor")
|
|
||||||
|
|
||||||
def test_hash_must_be_last(self):
|
|
||||||
# '#' in the middle is invalid MQTT but we should handle gracefully
|
|
||||||
# Just verify it doesn't crash; exact behavior is implementation-defined.
|
|
||||||
try:
|
|
||||||
_matches("sensor/#/foo", "sensor/bar/foo")
|
|
||||||
except Exception:
|
|
||||||
pass # either False or ValueError is acceptable
|
|
||||||
|
|
||||||
def test_hash_only(self):
|
|
||||||
assert _matches("#", "anything")
|
|
||||||
|
|
||||||
def test_hash_no_match_different_prefix(self):
|
|
||||||
assert not _matches("sensor/#", "actuator/fan")
|
|
||||||
|
|
||||||
|
|
||||||
class TestMixedWildcards:
|
|
||||||
def test_plus_and_hash(self):
|
|
||||||
assert _matches("msh/+/#", "msh/us-west/node1/json/TEXT_MESSAGE_APP/!deadbeef")
|
|
||||||
|
|
||||||
def test_plus_before_hash(self):
|
|
||||||
assert _matches("+/#", "region/any/nested/topic")
|
|
||||||
|
|
@ -1,151 +0,0 @@
|
||||||
# tests/test_text/test_classifier.py — PII filter backend and endpoint tests
|
|
||||||
import pytest
|
|
||||||
from httpx import AsyncClient, ASGITransport
|
|
||||||
|
|
||||||
from circuitforge_core.text.backends.mock import MockClassifierBackend
|
|
||||||
from circuitforge_core.text.filter import PIIFilter, PIISpan, FilterResult, _redact, _spans_from_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: _spans_from_pipeline ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_spans_from_pipeline_normalises_bio_prefix():
|
|
||||||
raw = [{"entity_group": "B-NAME", "score": 0.9, "word": "Alice", "start": 0, "end": 5}]
|
|
||||||
spans = _spans_from_pipeline(raw)
|
|
||||||
assert spans[0].label == "NAME"
|
|
||||||
|
|
||||||
|
|
||||||
def test_spans_from_pipeline_uppercase():
|
|
||||||
raw = [{"entity_group": "email", "score": 0.8, "word": "a@b.com", "start": 10, "end": 17}]
|
|
||||||
spans = _spans_from_pipeline(raw)
|
|
||||||
assert spans[0].label == "EMAIL"
|
|
||||||
|
|
||||||
|
|
||||||
def test_spans_from_pipeline_returns_typed_objects():
|
|
||||||
raw = [{"entity_group": "PHONE_NUM", "score": 0.95, "word": "555-1234", "start": 5, "end": 13}]
|
|
||||||
spans = _spans_from_pipeline(raw)
|
|
||||||
assert isinstance(spans[0], PIISpan)
|
|
||||||
assert spans[0].score == pytest.approx(0.95)
|
|
||||||
assert spans[0].start == 5
|
|
||||||
assert spans[0].end == 13
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: _redact ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_redact_replaces_spans():
|
|
||||||
text = "Call Alice at 555-1234 now"
|
|
||||||
spans = [
|
|
||||||
PIISpan(label="NAME", start=5, end=10, text="Alice", score=0.99),
|
|
||||||
PIISpan(label="PHONE_NUM", start=14, end=22, text="555-1234", score=0.97),
|
|
||||||
]
|
|
||||||
assert _redact(text, spans) == "Call [NAME] at [PHONE_NUM] now"
|
|
||||||
|
|
||||||
|
|
||||||
def test_redact_handles_overlapping_order():
|
|
||||||
# Spans processed right-to-left — earlier offsets must still be valid
|
|
||||||
text = "Jane Doe jane@example.com"
|
|
||||||
spans = [
|
|
||||||
PIISpan(label="NAME", start=0, end=8, text="Jane Doe", score=0.99),
|
|
||||||
PIISpan(label="EMAIL", start=9, end=25, text="jane@example.com", score=0.97),
|
|
||||||
]
|
|
||||||
result = _redact(text, spans)
|
|
||||||
assert "[NAME]" in result
|
|
||||||
assert "[EMAIL]" in result
|
|
||||||
assert "Jane Doe" not in result
|
|
||||||
assert "jane@example.com" not in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_redact_no_spans_returns_original():
|
|
||||||
text = "No PII here"
|
|
||||||
assert _redact(text, []) == text
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: PIIFilter with MockClassifierBackend ────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_pii_filter_sync():
|
|
||||||
backend = MockClassifierBackend()
|
|
||||||
pii_filter = PIIFilter.from_backend(backend)
|
|
||||||
# Mock backend returns spans for "Jane Doe" at 0-8 and "jane@example.com" at 18-34
|
|
||||||
result = pii_filter.filter("Jane Doe emailed jane@example.com today")
|
|
||||||
assert isinstance(result, FilterResult)
|
|
||||||
assert "[NAME]" in result.redacted_text
|
|
||||||
assert "[EMAIL]" in result.redacted_text
|
|
||||||
assert len(result.spans) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_pii_filter_preserves_original_text():
|
|
||||||
backend = MockClassifierBackend()
|
|
||||||
pii_filter = PIIFilter.from_backend(backend)
|
|
||||||
text = "Jane Doe emailed jane@example.com today"
|
|
||||||
result = pii_filter.filter(text)
|
|
||||||
assert result.original_text == text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pii_filter_async():
|
|
||||||
backend = MockClassifierBackend()
|
|
||||||
pii_filter = PIIFilter.from_backend(backend)
|
|
||||||
result = await pii_filter.filter_async("Jane Doe emailed jane@example.com today")
|
|
||||||
assert "[NAME]" in result.redacted_text
|
|
||||||
assert len(result.spans) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_pii_filter_result_is_frozen():
|
|
||||||
backend = MockClassifierBackend()
|
|
||||||
pii_filter = PIIFilter.from_backend(backend)
|
|
||||||
result = pii_filter.filter("test")
|
|
||||||
with pytest.raises((AttributeError, TypeError)):
|
|
||||||
result.redacted_text = "mutated" # type: ignore[misc]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Integration: /filter HTTP endpoint ───────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def classifier_app(monkeypatch):
|
|
||||||
"""cf-text app in classifier mode using mock backend."""
|
|
||||||
import os
|
|
||||||
monkeypatch.setenv("CF_TEXT_MOCK", "1")
|
|
||||||
monkeypatch.setenv("CF_TEXT_BACKEND", "classifier")
|
|
||||||
import importlib
|
|
||||||
import circuitforge_core.text.app as app_mod
|
|
||||||
importlib.reload(app_mod)
|
|
||||||
yield app_mod.create_app(model_path="openai/privacy-filter", backend="classifier", mock=False)
|
|
||||||
monkeypatch.delenv("CF_TEXT_MOCK", raising=False)
|
|
||||||
monkeypatch.delenv("CF_TEXT_BACKEND", raising=False)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_filter_endpoint_returns_redacted(classifier_app):
|
|
||||||
async with AsyncClient(transport=ASGITransport(app=classifier_app), base_url="http://test") as client:
|
|
||||||
resp = await client.post("/filter", json={"text": "Jane Doe emailed jane@example.com today"})
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert "[NAME]" in body["redacted_text"]
|
|
||||||
assert "[EMAIL]" in body["redacted_text"]
|
|
||||||
assert len(body["spans"]) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_filter_endpoint_includes_original(classifier_app):
|
|
||||||
text = "Jane Doe emailed jane@example.com today"
|
|
||||||
async with AsyncClient(transport=ASGITransport(app=classifier_app), base_url="http://test") as client:
|
|
||||||
resp = await client.post("/filter", json={"text": text})
|
|
||||||
assert resp.json()["original_text"] == text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_generate_returns_501_in_classifier_mode(classifier_app):
|
|
||||||
async with AsyncClient(transport=ASGITransport(app=classifier_app), base_url="http://test") as client:
|
|
||||||
resp = await client.post("/generate", json={"prompt": "hello"})
|
|
||||||
assert resp.status_code == 501
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_health_reports_classifier_backend(classifier_app):
|
|
||||||
async with AsyncClient(transport=ASGITransport(app=classifier_app), base_url="http://test") as client:
|
|
||||||
resp = await client.get("/health")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["backend"] == "classifier"
|
|
||||||
Loading…
Reference in a new issue