Compare commits
60 commits
| SHA1 |
|---|
| 01ed48808b |
| a2c768c635 |
| f7bf121aef |
| 8fa8216161 |
| b9b601aa23 |
| 433207d3c5 |
| 56fb6be4b1 |
| 0598801aaa |
| ffb95a5a30 |
| f74457d11f |
| d78310d4fd |
| a189511760 |
| 2e9e3fdc4b |
| 3082318e0d |
| 69a338bd98 |
| fc52d32574 |
| 7623c3edaf |
| 8c1daf3b6c |
| 80b0d5fd34 |
| 3075e5d3da |
| 67493048e2 |
| 5766fa82ab |
| 48d33a78ef |
| c9c4828387 |
| 19a26e02a0 |
| e5c26f0e67 |
| 3c9c765668 |
| bb2ed3e992 |
| f3bc4ac605 |
| d98d27be3d |
| 4d858af4d1 |
| 874354f235 |
| 3050179b2f |
| 378d125ba6 |
| 1cbea29817 |
| f0a9ec5c37 |
| 0a15ad9522 |
| c244260d1c |
| 2259382d0b |
| 090a86ce1b |
| c1e825c06a |
| d16bc569cf |
| ccd2a35deb |
| fe19de3d9a |
| 7837fbcad2 |
| 73cec07bd2 |
| 4c3f3a95a5 |
| d719ea2309 |
| 0d9d030320 |
| 9ee31a09c1 |
| e6cd3a2e96 |
| cb51ba72bc |
| 3deae056de |
| 9544f695e6 |
| 7397e227e2 |
| 8d87ed4c9f |
| 6e3474b97b |
| d45d4e1de6 |
| 7bb6b76bd5 |
| a54a530493 |
165 changed files with 10598 additions and 6050 deletions
.cliff.toml (new file, 33 lines)
@@ -0,0 +1,33 @@
[changelog]
header = "# Changelog\n"
body = """
{% if version %}\
## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }}
{% else %}\
## [Unreleased]
{% endif %}\
{% for group, commits in commits | group_by(attribute="group") %}
### {{ group | upper_first }}
{% for commit in commits %}
- {{ commit.message | upper_first }}\
{% endfor %}
{% endfor %}\n
"""
footer = ""
trim = true

[git]
conventional_commits = true
filter_unconventional = true
commit_parsers = [
    { message = "^feat", group = "Features" },
    { message = "^fix", group = "Bug Fixes" },
    { message = "^refactor", group = "Refactor" },
    { message = "^perf", group = "Performance" },
    { message = "^docs", group = "Documentation" },
    { message = "^test", group = "Testing" },
    { message = "^ci", group = "CI/CD" },
    { message = "^chore", group = "Miscellaneous" },
]
filter_commits = false
tag_pattern = "v[0-9].*"
.forgejo/workflows/ci.yml (new file, 30 lines)
@@ -0,0 +1,30 @@
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: pip

      - name: Install dependencies
        run: pip install -e ".[dev]"

      - name: Lint (ruff)
        run: ruff check circuitforge_core/

      - name: Type check (mypy)
        run: mypy circuitforge_core/ --ignore-missing-imports

      - name: Test
        run: pytest tests/ -v --tb=short
.forgejo/workflows/mirror.yml (new file, 31 lines)
@@ -0,0 +1,31 @@
name: Mirror

on:
  push:
    branches: [main]
    tags:
      - "v*"

jobs:
  mirror:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Mirror to GitHub
        continue-on-error: true
        env:
          GITHUB_MIRROR_TOKEN: ${{ secrets.GITHUB_MIRROR_TOKEN }}
        run: |
          git remote add github "https://x-access-token:${GITHUB_MIRROR_TOKEN}@github.com/CircuitForgeLLC/circuitforge-core.git"
          git push github --mirror

      - name: Mirror to Codeberg
        continue-on-error: true
        env:
          CODEBERG_MIRROR_TOKEN: ${{ secrets.CODEBERG_MIRROR_TOKEN }}
        run: |
          git remote add codeberg "https://CircuitForge:${CODEBERG_MIRROR_TOKEN}@codeberg.org/CircuitForge/circuitforge-core.git"
          git push codeberg --mirror
.forgejo/workflows/release-pypi.yml (new file, 52 lines)
@@ -0,0 +1,52 @@
name: Release — PyPI

on:
  push:
    tags:
      - "v*"

jobs:
  release:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Build
        run: |
          pip install build
          python -m build

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}

      - name: Create Forgejo release
        env:
          FORGEJO_TOKEN: ${{ secrets.FORGEJO_RELEASE_TOKEN }}
        run: |
          TAG="${GITHUB_REF_NAME}"
          # Check if a release already exists for this tag
          EXISTING=$(curl -sf \
            -H "Authorization: token ${FORGEJO_TOKEN}" \
            "https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases/tags/${TAG}" \
            2>/dev/null | jq -r '.id // empty')

          if [ -z "${EXISTING}" ]; then
            jq -n --arg tag "${TAG}" \
              '{"tag_name":$tag,"name":$tag,"draft":false,"prerelease":false}' \
              | curl -sf -X POST \
                -H "Authorization: token ${FORGEJO_TOKEN}" \
                -H "Content-Type: application/json" \
                "https://git.opensourcesolarpunk.com/api/v1/repos/Circuit-Forge/circuitforge-core/releases" \
                -d @-
            echo "Release created for ${TAG}"
          else
            echo "Release for ${TAG} already exists (id=${EXISTING}), skipping."
          fi
CHANGELOG.md (107 lines changed)
@@ -6,6 +6,113 @@ Versions follow [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

---

## [0.10.0] — 2026-04-12

### Added

**`circuitforge_core.community`** — shared community signal module (BSL 1.1, closes #44)

Provides the PostgreSQL-backed infrastructure for the cross-product community fine-tuning signal pipeline. Products write signals; the training pipeline reads them.

- `CommunityDB` — psycopg2 connection pool with `run_migrations()`. Picks up all `.sql` files from `circuitforge_core/community/migrations/` in filename order. Safe to call on every startup (idempotent `CREATE TABLE IF NOT EXISTS`).
- `CommunityPost` — frozen dataclass capturing a user-authored community post with a snapshot of the originating product item (`element_snapshot` as a tuple of key-value pairs for immutability).
- `SharedStore` — base class for product-specific community stores. Provides typed `pg_read()` and `pg_write()` helpers that products subclass without re-implementing connection management (usage sketch after this list).
- Migration 001: `community_posts` schema (id, product, item_id, pseudonym, title, body, element_snapshot JSONB, created_at).
- Migration 002: `community_reactions` stub (post_id FK, pseudonym, reaction_type, created_at).
- `psycopg2-binary` added to `[community]` optional extras in `pyproject.toml`.
- All community classes exported from `circuitforge_core.community`.
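The `SharedStore` bullet above describes a subclassing contract; here is a minimal sketch of that pattern. The `pg_read()` signature and the `KiwiCommunityStore` name are illustrative assumptions; only `CommunityDB.from_env()`, `run_migrations()`, `SharedStore`, and the `(db, source_product=...)` constructor appear in this changeset.

```python
# Minimal sketch, assuming pg_read() accepts (sql, params) and returns fetched rows.
# KiwiCommunityStore is hypothetical; it is not part of this diff.
from circuitforge_core.community import CommunityDB, SharedStore


class KiwiCommunityStore(SharedStore):
    def recent_posts(self, limit: int = 20):
        return self.pg_read(
            "SELECT * FROM community_posts ORDER BY published DESC LIMIT %s",
            (limit,),
        )


db = CommunityDB.from_env()  # reads COMMUNITY_DB_URL
db.run_migrations()          # idempotent; safe on every startup
store = KiwiCommunityStore(db, source_product="kiwi")
```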
---

## [0.9.0] — 2026-04-10

### Added

**`circuitforge_core.text`** — OpenAI-compatible `/v1/chat/completions` endpoint and pipeline crystallization engine.

**`circuitforge_core.pipeline`** — multimodal pipeline with staged output crystallization. Products queue draft outputs for human review before committing.

**`circuitforge_core.stt`** — speech-to-text module. `FasterWhisperBackend` for local transcription via `faster-whisper`. Managed FastAPI app mountable in any product.

**`circuitforge_core.tts`** — text-to-speech module. `ChatterboxTurbo` backend for local synthesis. Managed FastAPI app.

**Accessibility preferences** — `preferences` module extended with structured accessibility fields (motion reduction, high contrast, font size, focus highlight) under the `accessibility.*` key path.

**LLM output corrections router** — `make_corrections_router()` for collecting LLM output corrections in any product. Stores corrections in product SQLite for future fine-tuning.

---

## [0.8.0] — 2026-04-08

### Added

**`circuitforge_core.vision`** — cf-vision managed service shim. Routes vision inference requests to a local cf-vision worker (moondream2 / SigLIP). Closes #43.

**`circuitforge_core.api.feedback`** — `make_feedback_router()` shared Forgejo issue-filing router. Products mount it under `/api/feedback`; requires `FORGEJO_API_TOKEN`. Closes #30.

**License validation** — `CF_LICENSE_KEY` validation via the Heimdall REST API. Products call `validate_license(key, product)` to gate premium features. Closes #26. A short usage sketch follows.
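A hedged sketch of gating a feature on the validator described above. The import path is an assumption; only the `validate_license(key, product)` call is given in this entry.

```python
# Sketch only: module path assumed, call signature from the changelog entry above.
import os

from circuitforge_core.licensing import validate_license  # import path assumed

if validate_license(os.environ.get("CF_LICENSE_KEY", ""), product="kiwi"):
    print("premium features unlocked")
else:
    print("running in free tier")
```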
---

## [0.7.0] — 2026-04-04

### Added

**`circuitforge_core.affiliates`** — affiliate link wrapping module (closes #21)
- `wrap_url(url, retailer, user_id, get_preference)` — resolution order: opt-out → BYOK → CF env var → plain URL
- `AffiliateProgram` frozen dataclass + `register_program()` / `get_program()` registry
- Built-in programs: eBay Partner Network (`EBAY_AFFILIATE_CAMPAIGN_ID`), Amazon Associates (`AMAZON_ASSOCIATES_TAG`)
- `get_disclosure_text(retailer)` — per-retailer tooltip copy + `BANNER_COPY` first-encounter constants
- `get_preference` callable injection for opt-out + BYOK without hard-wiring a storage backend

**`circuitforge_core.preferences`** — preference persistence helpers (closes #22, self-hosted path)
- `LocalFileStore` — YAML-backed single-user preference store (`~/.config/circuitforge/preferences.yaml`)
- `get_user_preference(user_id, path, default, store)` + `set_user_preference(user_id, path, value, store)`
- `PreferenceStore` protocol — Heimdall cloud backend to follow once Heimdall#5 lands
- Dot-path utilities `get_path` / `set_path` (immutable nested dict read/write); see the sketch after this list
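A short example of the dot-path API, using the signatures given in this entry; the no-argument `LocalFileStore()` constructor is an assumption.

```python
# Assumes LocalFileStore() defaults to ~/.config/circuitforge/preferences.yaml.
from circuitforge_core.preferences import (
    LocalFileStore,
    get_user_preference,
    set_user_preference,
)

store = LocalFileStore()
set_user_preference("u123", "affiliate.opt_out", True, store)
assert get_user_preference("u123", "affiliate.opt_out", False, store) is True
```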
---

## [0.5.0] — 2026-04-02

### Added

**`circuitforge_core.manage` — cross-platform product manager** (closes #6)

Replaces bash-only `manage.sh` across all products. Works on Linux, macOS, and Windows natively — no WSL2 or Docker required.

- **`ManageConfig`**: reads `manage.toml` from the product root (TOML via stdlib `tomllib`). Falls back to the directory name when no config file is present — Docker-only products need zero configuration.
- **Docker mode** (`DockerManager`): wraps `docker compose` (v2 plugin) or `docker-compose` (v1). Auto-detected when Docker is available and a compose file exists. Commands: `start`, `stop`, `restart`, `status`, `logs`, `build`.
- **Native mode** (`NativeManager`): PID-file process management with `platformdirs`-based paths (`AppData` on Windows, `~/.local/share` on Linux/macOS). Cross-platform kill (SIGTERM→SIGKILL on Unix, `taskkill /F` on Windows). Log tailing via polling — no `tail -f`, works everywhere.
- **CLI** (`typer`): `start`, `stop`, `restart`, `status`, `logs`, `build`, `open`, `install-shims`. `--mode auto|docker|native` override.
- **`install-shims`**: writes `manage.sh` (bash, +x) and `manage.ps1` (PowerShell) into the product directory, plus `manage.toml.example`.
- **Entry points**: `python -m circuitforge_core.manage` and the `cf-manage` console script.
- **`pyproject.toml`**: `[manage]` optional extras group (`platformdirs`, `typer`).

---

## [0.4.0] — 2026-04-02

### Added

**Agent watchdog — coordinator-restart reconnect** (closes #15)
- `NodeStore`: SQLite persistence for known agent nodes (`~/.local/share/circuitforge/cf-orch-nodes.db`); `upsert` on every registration, `prune_stale` removes nodes unseen for 30+ days
- `AgentSupervisor.restore_from_store()`: reloads all previously known nodes on coordinator startup; nodes start `offline=False` and come online within one heartbeat cycle (~10 s) without touching the agent processes
- `AgentSupervisor.register()` now persists to `NodeStore` on every call
- Agent CLI: one-shot registration replaced with a persistent 30 s reconnect loop (daemon thread); after a coordinator restart, remote nodes (Navi, Strahl, etc.) reappear automatically with no manual intervention

**Ollama adopt-if-running + configurable health path** (closes #16)
- `ProcessSpec.adopt` (`bool`, default `False`): when `True`, `ServiceManager.start()` probes the health endpoint first and claims the already-running process rather than spawning a new one — designed for system daemons like Ollama (see the sketch after this list)
- `ProcessSpec.health_path` (`str`, default `"/health"`): configurable health probe path; Ollama uses `/api/tags`
- `ServiceManager._probe_health()`: shared urllib health check used by both `start()` and `is_running()` for adopted services
- Agent `/services/{service}/start` response includes `adopted: true` when the service was claimed rather than started; the coordinator sets instance state to `running` immediately (skips the probe loop wait)
- `ServiceInstance.health_path` field; `upsert_instance(health_path=)` kwarg
- Coordinator probe loop uses `inst.health_path` instead of a hardcoded `/health`
- `_get_health_path()` helper looks up the ProcessSpec health path from the profile registry
- All GPU profiles (2/4/6/8/16/24 GB + cpu-16/32 GB): the `ollama` service now has a `managed:` block with `adopt: true`, `health_path: /api/tags`, port 11434
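A hedged sketch of an adopt-style spec for Ollama. Only `adopt` and `health_path` are documented above; the other fields (`name`, `command`, `port`) and the import path are assumptions.

```python
# Sketch only: `adopt` and `health_path` come from the changelog entry above;
# the remaining fields and the import path are illustrative assumptions.
from circuitforge_core.resources import ProcessSpec  # import path assumed

ollama = ProcessSpec(
    name="ollama",
    command=["ollama", "serve"],
    port=11434,               # matches the GPU profile entries above
    adopt=True,               # claim an already-running daemon instead of spawning
    health_path="/api/tags",  # Ollama exposes /api/tags rather than /health
)
```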
---

## [0.3.0] — 2026-04-02

### Added
Dockerfile.orch (new file, 53 lines)
@@ -0,0 +1,53 @@
# cf-orch coordinator image
# Includes the coordinator + agent; designed for paid+ multi-node deployments.
#
# Usage (coordinator node):
#   docker run -d \
#     -p 7700:7700 \
#     -e HEIMDALL_URL=https://license.circuitforge.tech \
#     -e HEIMDALL_MIN_TIER=paid \
#     -e CF_ORCH_AUTH_SECRET=<secret> \
#     ghcr.io/circuit-forge/cf-orch:latest coordinator
#
# Usage (GPU agent node — connects back to coordinator):
#   docker run -d \
#     --gpus all \
#     -e CF_COORDINATOR_URL=http://<coordinator-ip>:7700 \
#     ghcr.io/circuit-forge/cf-orch:latest agent
#
# Environment variables
# ─────────────────────
# CF_ORCH_PORT          Coordinator listen port (default: 7700)
# HEIMDALL_URL          Enable license auth (omit for LAN-only / self-hosted)
# HEIMDALL_MIN_TIER     Minimum tier required (default: paid)
# CF_ORCH_AUTH_SECRET   Shared secret with Heimdall /licenses/verify
# CF_COORDINATOR_URL    Agent mode: coordinator URL to register with
# CF_AGENT_GPU_IDS      Comma-separated GPU indices for agent (default: 0)

FROM python:3.12-slim

LABEL org.opencontainers.image.source="https://git.opensourcesolarpunk.com/Circuit-Forge/circuitforge-core"
LABEL org.opencontainers.image.description="cf-orch coordinator and agent for CircuitForge multi-node GPU orchestration"
LABEL org.opencontainers.image.licenses="BSL-1.1"

WORKDIR /app

# System deps — curl for in-container health checks; avoid a full dev toolchain
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install cf-core with the resources extra (coordinator + agent deps)
COPY pyproject.toml README.md ./
COPY circuitforge_core/ ./circuitforge_core/

RUN pip install --no-cache-dir ".[resources,manage]"

ENV CF_ORCH_PORT=7700
EXPOSE 7700

COPY docker/orch-entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

ENTRYPOINT ["/entrypoint.sh"]
CMD ["coordinator"]
LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 CircuitForge LLC

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (22 lines changed)
@@ -2,15 +2,29 @@

Shared scaffold for CircuitForge products.

**Current version: 0.7.0**

## Modules

### Implemented

- `circuitforge_core.db` — SQLite connection factory and migration runner
- `circuitforge_core.llm` — LLM router with fallback chain
- `circuitforge_core.llm` — LLM router with fallback chain (Ollama, vLLM, Anthropic, OpenAI-compatible)
- `circuitforge_core.tiers` — Tier system with BYOK and local vision unlocks
- `circuitforge_core.config` — Env validation and .env loader
- `circuitforge_core.vision` — Vision router stub (v0.2+)
- `circuitforge_core.wizard` — First-run wizard base class stub
- `circuitforge_core.pipeline` — Staging queue stub (v0.2+)
- `circuitforge_core.hardware` — Hardware detection and LLM backend profile generation (VRAM tiers, GPU/CPU auto-select)
- `circuitforge_core.documents` — Document ingestion pipeline: PDF, DOCX, and image OCR → `StructuredDocument`
- `circuitforge_core.affiliates` — Affiliate URL wrapping with opt-out, BYOK user IDs, and CF env-var fallback (`wrap_url`)
- `circuitforge_core.preferences` — User preference store (local YAML file, pluggable backend); dot-path get/set API
- `circuitforge_core.tasks` — VRAM-aware LLM task scheduler; shared slot manager across services (`TaskScheduler`)
- `circuitforge_core.manage` — Cross-platform product process manager (Docker and native modes)
- `circuitforge_core.resources` — Resource coordinator and agent: VRAM allocation, eviction engine, GPU profile registry

### Stubs (in-tree, not yet implemented)

- `circuitforge_core.vision` — Vision router base class (planned: moondream2 / Claude vision dispatch)
- `circuitforge_core.wizard` — First-run wizard base class (products subclass `BaseWizard`)
- `circuitforge_core.pipeline` — Staging queue base (`StagingDB`; products provide concrete schema)

## Install
circuitforge_core/__init__.py (8 lines changed)
@@ -1 +1,8 @@
-__version__ = "0.1.0"
+__version__ = "0.10.0"
+
+try:
+    from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
+    __all__ = ["CommunityDB", "CommunityPost", "SharedStore"]
+except ImportError:
+    # psycopg2 not installed — install with: pip install circuitforge-core[community]
+    pass
circuitforge_core/affiliates/__init__.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Public API for circuitforge_core.affiliates.

Usage::

    from circuitforge_core.affiliates import wrap_url, get_disclosure_text

    # Wrap a URL — env-var mode (no preferences, no opt-out)
    url = wrap_url("https://www.ebay.com/itm/123", retailer="ebay")

    # Wrap a URL — with preference injection (opt-out + BYOK)
    url = wrap_url(
        "https://www.ebay.com/itm/123",
        retailer="ebay",
        user_id="u123",
        get_preference=my_prefs_client.get,
    )

    # Frontend disclosure tooltip
    text = get_disclosure_text("ebay")

    # Register a product-specific program at startup
    register_program(AffiliateProgram(
        name="My Shop",
        retailer_key="myshop",
        env_var="MYSHOP_AFFILIATE_ID",
        build_url=lambda url, id_: f"{url}?ref={id_}",
    ))
"""
from .disclosure import BANNER_COPY, get_disclosure_text
from .programs import AffiliateProgram, get_program, register_program, registered_keys
from .router import wrap_url

__all__ = [
    "wrap_url",
    "get_disclosure_text",
    "BANNER_COPY",
    "AffiliateProgram",
    "register_program",
    "get_program",
    "registered_keys",
]
circuitforge_core/affiliates/disclosure.py (new file, 49 lines)
@@ -0,0 +1,49 @@
"""Affiliate disclosure copy constants.

Follows the plain-language disclosure design from the affiliate links design
doc. All copy is centralized here so products don't drift out of sync and
legal/copy review has a single file to audit.
"""
from __future__ import annotations

# Per-retailer tooltip copy (shown on hover/tap of affiliate link indicator)
_TOOLTIP: dict[str, str] = {
    "ebay": (
        "Affiliate link — CircuitForge earns a small commission if you purchase "
        "on eBay. No purchase data is shared with us. [Opt out in Settings]"
    ),
    "amazon": (
        "Affiliate link — CircuitForge earns a small commission if you purchase "
        "on Amazon. No purchase data is shared with us. [Opt out in Settings]"
    ),
}

_GENERIC_TOOLTIP = (
    "Affiliate link — CircuitForge may earn a small commission if you purchase. "
    "No purchase data is shared with us. [Opt out in Settings]"
)

# First-encounter banner copy (shown once, then preference saved)
BANNER_COPY: dict[str, str] = {
    "title": "A note on purchase links",
    "body": (
        "Some links in this product go to retailers using our affiliate code. "
        "When you click one, the retailer knows you came from CircuitForge. "
        "We don't see or store what you buy. The retailer may track your "
        "purchase — that's between you and them.\n\n"
        "If you'd rather use plain links with no tracking code, you can opt "
        "out in Settings."
    ),
    "dismiss_label": "Got it",
    "opt_out_label": "Opt out now",
    "learn_more_label": "Learn more",
}


def get_disclosure_text(retailer: str) -> str:
    """Return the tooltip disclosure string for *retailer*.

    Falls back to a generic string for unregistered retailers so callers
    never receive an empty string.
    """
    return _TOOLTIP.get(retailer, _GENERIC_TOOLTIP)
circuitforge_core/affiliates/programs.py (new file, 116 lines)
@@ -0,0 +1,116 @@
"""Affiliate program definitions and URL builders.

Each ``AffiliateProgram`` knows how to append its affiliate parameters to a
plain product URL. Built-in programs (eBay EPN, Amazon Associates) are
registered at module import time. Products can register additional programs
with ``register_program()``.

Affiliate IDs are read from environment variables at call time so they pick
up values set after process startup (useful in tests).
"""
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Callable
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse


@dataclass(frozen=True)
class AffiliateProgram:
    """One affiliate program and its URL building logic.

    Attributes:
        name: Human-readable program name.
        retailer_key: Matches the ``retailer=`` argument in ``wrap_url()``.
        env_var: Environment variable holding CF's affiliate ID.
        build_url: ``(plain_url, affiliate_id) -> affiliate_url`` callable.
    """

    name: str
    retailer_key: str
    env_var: str
    build_url: Callable[[str, str], str]

    def cf_affiliate_id(self) -> str | None:
        """Return CF's configured affiliate ID, or None if the env var is unset/blank."""
        val = os.environ.get(self.env_var, "").strip()
        return val or None


# ---------------------------------------------------------------------------
# URL builders
# ---------------------------------------------------------------------------

def _build_ebay_url(url: str, affiliate_id: str) -> str:
    """Append eBay Partner Network parameters to a listing URL."""
    sep = "&" if "?" in url else "?"
    params = urlencode({
        "mkcid": "1",
        "mkrid": "711-53200-19255-0",
        "siteid": "0",
        "campid": affiliate_id,
        "toolid": "10001",
        "mkevt": "1",
    })
    return f"{url}{sep}{params}"


def _build_instacart_url(url: str, affiliate_id: str) -> str:
    """Append the Instacart affiliate parameter to a search URL."""
    sep = "&" if "?" in url else "?"
    return f"{url}{sep}aff={affiliate_id}"


def _build_amazon_url(url: str, affiliate_id: str) -> str:
    """Merge an Amazon Associates tag into a product URL's query string."""
    parsed = urlparse(url)
    qs = parse_qs(parsed.query, keep_blank_values=True)
    qs["tag"] = [affiliate_id]
    new_query = urlencode({k: v[0] for k, v in qs.items()})
    return urlunparse(parsed._replace(query=new_query))


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

_REGISTRY: dict[str, AffiliateProgram] = {}


def register_program(program: AffiliateProgram) -> None:
    """Register an affiliate program (overwrites any existing entry for the same key)."""
    _REGISTRY[program.retailer_key] = program


def get_program(retailer_key: str) -> AffiliateProgram | None:
    """Return the registered program for *retailer_key*, or None."""
    return _REGISTRY.get(retailer_key)


def registered_keys() -> list[str]:
    """Return all currently registered retailer keys."""
    return list(_REGISTRY.keys())


# Register built-ins
register_program(AffiliateProgram(
    name="eBay Partner Network",
    retailer_key="ebay",
    env_var="EBAY_AFFILIATE_CAMPAIGN_ID",
    build_url=_build_ebay_url,
))

register_program(AffiliateProgram(
    name="Amazon Associates",
    retailer_key="amazon",
    env_var="AMAZON_ASSOCIATES_TAG",
    build_url=_build_amazon_url,
))

register_program(AffiliateProgram(
    name="Instacart",
    retailer_key="instacart",
    env_var="INSTACART_AFFILIATE_ID",
    build_url=_build_instacart_url,
))
circuitforge_core/affiliates/router.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Affiliate URL wrapping — resolution logic.

Resolution order (from the affiliate links design doc):

1. User opted out? → return plain URL
2. User has a BYOK ID for this retailer? → wrap with the user's ID
3. CF has a program with its env var set? → wrap with CF's ID
4. No program / no ID configured → return plain URL

The ``get_preference`` callable is optional. When None (default), steps 1
and 2 are skipped — the module operates in env-var-only mode. Products
inject their preferences client to enable opt-out and BYOK.

Signature of ``get_preference``::

    def get_preference(user_id: str | None, path: str, default=None) -> Any: ...
"""
from __future__ import annotations

import logging
from typing import Any, Callable

from .programs import get_program

logger = logging.getLogger(__name__)

GetPreferenceFn = Callable[[str | None, str, Any], Any]


def wrap_url(
    url: str,
    retailer: str,
    user_id: str | None = None,
    get_preference: GetPreferenceFn | None = None,
) -> str:
    """Return an affiliate URL for *url*, or the plain URL if no affiliate
    link can be or should be generated.

    Args:
        url: Plain product URL to wrap.
        retailer: Retailer key (e.g. ``"ebay"``, ``"amazon"``).
        user_id: User identifier for preference lookups. None = anonymous.
        get_preference: Optional callable ``(user_id, path, default) -> value``.
            Injected by products to enable opt-out and BYOK resolution.
            When None, opt-out and BYOK checks are skipped.

    Returns:
        Affiliate URL, or *url* unchanged if:
        - The user has opted out
        - No program is registered for *retailer*
        - No affiliate ID is configured (env var unset and no BYOK)
    """
    program = get_program(retailer)
    if program is None:
        logger.debug("affiliates: no program registered for retailer=%r", retailer)
        return url

    # Step 1: opt-out check
    if get_preference is not None:
        opted_out = get_preference(user_id, "affiliate.opt_out", False)
        if opted_out:
            logger.debug("affiliates: user %r opted out — returning plain URL", user_id)
            return url

    # Step 2: BYOK — user's own affiliate ID (Premium)
    if get_preference is not None and user_id is not None:
        byok_id = get_preference(user_id, f"affiliate.byok_ids.{retailer}", None)
        if byok_id:
            logger.debug(
                "affiliates: using BYOK id for user=%r retailer=%r", user_id, retailer
            )
            return program.build_url(url, byok_id)

    # Step 3: CF's affiliate ID from env var
    cf_id = program.cf_affiliate_id()
    if cf_id:
        return program.build_url(url, cf_id)

    logger.debug(
        "affiliates: no affiliate ID configured for retailer=%r (env var %r unset)",
        retailer, program.env_var,
    )
    return url
circuitforge_core/api/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from circuitforge_core.api.feedback import make_feedback_router
from circuitforge_core.api.corrections import make_corrections_router, CORRECTIONS_MIGRATION_SQL

__all__ = ["make_feedback_router", "make_corrections_router", "CORRECTIONS_MIGRATION_SQL"]
circuitforge_core/api/corrections.py (new file, 199 lines)
@@ -0,0 +1,199 @@
"""
Shared corrections router — stores user corrections to LLM output for SFT training.

Products include this with make_corrections_router(get_db=..., product=...).
Corrections are stored locally in each product's SQLite DB and exported as JSONL
for the Avocet SFT pipeline. Separate from the bug-feedback→Forgejo-issue path.

Required DB migration (add to product migrations dir):
    -- From circuitforge_core.api.corrections import CORRECTIONS_MIGRATION_SQL
"""
from __future__ import annotations

import json
import sqlite3
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Iterator, Literal

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

# Drop this SQL into a product's migrations directory (e.g. 020_corrections.sql).
CORRECTIONS_MIGRATION_SQL = """\
CREATE TABLE IF NOT EXISTS corrections (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    item_id TEXT NOT NULL DEFAULT '',
    product TEXT NOT NULL,
    correction_type TEXT NOT NULL,
    input_text TEXT NOT NULL,
    original_output TEXT NOT NULL,
    corrected_output TEXT NOT NULL DEFAULT '',
    rating TEXT NOT NULL DEFAULT 'down',
    context TEXT NOT NULL DEFAULT '{}',
    opted_in INTEGER NOT NULL DEFAULT 0,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_corrections_product
    ON corrections (product);

CREATE INDEX IF NOT EXISTS idx_corrections_opted_in
    ON corrections (opted_in);
"""


class CorrectionRequest(BaseModel):
    item_id: str = ""
    product: str
    correction_type: str
    input_text: str
    original_output: str
    corrected_output: str = ""
    rating: Literal["up", "down"] = "down"
    context: dict = Field(default_factory=dict)
    opted_in: bool = False


class CorrectionResponse(BaseModel):
    id: int
    saved: bool


class CorrectionRecord(BaseModel):
    id: int
    item_id: str
    product: str
    correction_type: str
    input_text: str
    original_output: str
    corrected_output: str
    rating: str
    context: dict
    opted_in: bool
    created_at: str


def make_corrections_router(
    get_db: Callable[[], Iterator[sqlite3.Connection]],
    product: str,
) -> APIRouter:
    """Return a configured corrections APIRouter.

    Args:
        get_db: FastAPI dependency that yields a sqlite3.Connection.
        product: Product slug injected into every correction row (e.g. "linnet").
    """
    router = APIRouter()

    @router.post("", response_model=CorrectionResponse)
    def submit_correction(
        payload: CorrectionRequest,
        conn: sqlite3.Connection = Depends(get_db),
    ) -> CorrectionResponse:
        """Store a user correction to an LLM output."""
        # Thumbs-up with no corrected text is a valid positive signal.
        if payload.rating == "down" and not payload.corrected_output.strip():
            raise HTTPException(
                status_code=422,
                detail="corrected_output is required when rating is 'down'.",
            )

        row_id = conn.execute(
            """
            INSERT INTO corrections
                (item_id, product, correction_type, input_text, original_output,
                 corrected_output, rating, context, opted_in)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                payload.item_id,
                product,
                payload.correction_type,
                payload.input_text,
                payload.original_output,
                payload.corrected_output,
                payload.rating,
                json.dumps(payload.context),
                int(payload.opted_in),
            ),
        ).lastrowid
        conn.commit()
        return CorrectionResponse(id=row_id, saved=True)

    @router.get("", response_model=list[CorrectionRecord])
    def list_corrections(
        opted_in_only: bool = False,
        limit: int = 200,
        conn: sqlite3.Connection = Depends(get_db),
    ) -> list[CorrectionRecord]:
        """List stored corrections, optionally filtered to opted-in rows only."""
        conn.row_factory = sqlite3.Row
        query = "SELECT * FROM corrections"
        params: list = []
        if opted_in_only:
            query += " WHERE opted_in = 1"
        query += " ORDER BY created_at DESC LIMIT ?"
        params.append(max(1, min(limit, 1000)))
        rows = conn.execute(query, params).fetchall()
        return [
            CorrectionRecord(
                id=r["id"],
                item_id=r["item_id"],
                product=r["product"],
                correction_type=r["correction_type"],
                input_text=r["input_text"],
                original_output=r["original_output"],
                corrected_output=r["corrected_output"],
                rating=r["rating"],
                context=json.loads(r["context"] or "{}"),
                opted_in=bool(r["opted_in"]),
                created_at=r["created_at"],
            )
            for r in rows
        ]

    @router.get("/export")
    def export_corrections(
        opted_in_only: bool = True,
        conn: sqlite3.Connection = Depends(get_db),
    ) -> StreamingResponse:
        """Stream corrections as JSONL for the Avocet SFT pipeline.

        Each line is a JSON object with the fields expected by avocet's
        SFT candidate importer. opted_in_only=True (default) — only rows
        where the user consented to share are exported.
        """
        conn.row_factory = sqlite3.Row
        query = "SELECT * FROM corrections"
        if opted_in_only:
            query += " WHERE opted_in = 1"
        query += " ORDER BY created_at ASC"
        rows = conn.execute(query).fetchall()

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        filename = f"corrections_{product}_{timestamp}.jsonl"

        def generate() -> Iterator[str]:
            for r in rows:
                record = {
                    "input": r["input_text"],
                    "output": r["original_output"],
                    "correction": r["corrected_output"],
                    "rating": r["rating"],
                    "correction_type": r["correction_type"],
                    "product": r["product"],
                    "item_id": r["item_id"],
                    "context": json.loads(r["context"] or "{}"),
                    "created_at": r["created_at"],
                }
                yield json.dumps(record, ensure_ascii=False) + "\n"

        return StreamingResponse(
            generate(),
            media_type="application/x-ndjson",
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )

    return router
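For orientation, here is one way a product could wire this router up, consistent with the docstring above. The database filename, the `/api/corrections` prefix, and applying the migration inline (rather than via the product's migration runner) are illustrative choices, not part of this changeset.

```python
# Hypothetical wiring sketch — the path, prefix, and product slug are illustrative.
import sqlite3
from collections.abc import Iterator

from fastapi import FastAPI

from circuitforge_core.api import CORRECTIONS_MIGRATION_SQL, make_corrections_router

DB_PATH = "product.db"  # placeholder


def get_db() -> Iterator[sqlite3.Connection]:
    conn = sqlite3.connect(DB_PATH)
    try:
        yield conn
    finally:
        conn.close()


# Products normally ship CORRECTIONS_MIGRATION_SQL as a numbered migration file;
# applying it inline here keeps the sketch self-contained.
conn = sqlite3.connect(DB_PATH)
conn.executescript(CORRECTIONS_MIGRATION_SQL)
conn.close()

app = FastAPI()
app.include_router(
    make_corrections_router(get_db=get_db, product="linnet"),
    prefix="/api/corrections",  # prefix is an assumption
)
```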
circuitforge_core/api/feedback.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""
Shared feedback router — creates Forgejo issues from in-app beta feedback.
Products include this with make_feedback_router(repo=..., product=...).
"""
from __future__ import annotations

import os
import platform
import subprocess
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal

import requests
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

_LABEL_COLORS: dict[str, str] = {
    "beta-feedback": "#0075ca",
    "needs-triage": "#e4e669",
    "bug": "#d73a4a",
    "feature-request": "#a2eeef",
    "question": "#d876e3",
}

_TYPE_LABEL_MAP: dict[str, str] = {"bug": "bug", "feature": "feature-request"}
_TYPE_DISPLAY: dict[str, str] = {
    "bug": "🐛 Bug",
    "feature": "✨ Feature Request",
    "other": "💬 Other",
}


class FeedbackRequest(BaseModel):
    title: str
    description: str
    type: Literal["bug", "feature", "other"] = "other"
    repro: str = ""
    tab: str = "unknown"
    submitter: str = ""


class FeedbackResponse(BaseModel):
    issue_number: int
    issue_url: str


def _forgejo_headers() -> dict[str, str]:
    token = os.environ.get("FORGEJO_API_TOKEN", "")
    return {"Authorization": f"token {token}", "Content-Type": "application/json"}


def _ensure_labels(label_names: list[str], base: str, repo: str) -> list[int]:
    headers = _forgejo_headers()
    resp = requests.get(f"{base}/repos/{repo}/labels", headers=headers, timeout=10)
    existing = {lb["name"]: lb["id"] for lb in resp.json()} if resp.ok else {}
    ids: list[int] = []
    for name in label_names:
        if name in existing:
            ids.append(existing[name])
        else:
            r = requests.post(
                f"{base}/repos/{repo}/labels",
                headers=headers,
                json={"name": name, "color": _LABEL_COLORS.get(name, "#ededed")},
                timeout=10,
            )
            if r.ok:
                ids.append(r.json()["id"])
            else:
                raise HTTPException(
                    status_code=502,
                    detail=f"Failed to create label '{name}': {r.text[:200]}",
                )
    return ids


def _collect_context(tab: str, product: str) -> dict[str, str]:
    try:
        version = subprocess.check_output(
            ["git", "describe", "--tags", "--always"],
            cwd=Path.cwd(),
            text=True,
            timeout=5,
        ).strip()
    except (subprocess.SubprocessError, OSError):
        version = "dev"
    return {
        "product": product,
        "tab": tab,
        "version": version,
        "platform": platform.platform(),
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    }


def _build_issue_body(payload: FeedbackRequest, context: dict[str, str]) -> str:
    lines: list[str] = [
        f"## {_TYPE_DISPLAY.get(payload.type, '💬 Other')}",
        "",
        payload.description,
        "",
    ]
    if payload.type == "bug" and payload.repro:
        lines += ["### Reproduction Steps", "", payload.repro, ""]
    lines += ["### Context", ""]
    for k, v in context.items():
        lines.append(f"- **{k}:** {v}")
    lines.append("")
    if payload.submitter:
        lines += ["---", f"*Submitted by: {payload.submitter}*"]
    return "\n".join(lines)


def make_feedback_router(
    repo: str,
    product: str,
    demo_mode_fn: Callable[[], bool] | None = None,
) -> APIRouter:
    """Return a configured feedback APIRouter for the given Forgejo repo and product.

    Args:
        repo: Forgejo repo slug, e.g. "Circuit-Forge/kiwi".
        product: Product name injected into issue context, e.g. "kiwi".
        demo_mode_fn: Optional callable returning True when in demo mode.
            If None, reads the DEMO_MODE environment variable.
    """

    def _is_demo() -> bool:
        if demo_mode_fn is not None:
            return demo_mode_fn()
        return os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes")

    router = APIRouter()

    @router.get("/status")
    def feedback_status() -> dict:
        """Return whether feedback submission is configured on this instance."""
        return {"enabled": bool(os.environ.get("FORGEJO_API_TOKEN")) and not _is_demo()}

    @router.post("", response_model=FeedbackResponse)
    def submit_feedback(payload: FeedbackRequest) -> FeedbackResponse:
        """File a Forgejo issue from in-app feedback."""
        token = os.environ.get("FORGEJO_API_TOKEN", "")
        if not token:
            raise HTTPException(
                status_code=503,
                detail="Feedback disabled: FORGEJO_API_TOKEN not configured.",
            )
        if _is_demo():
            raise HTTPException(status_code=403, detail="Feedback disabled in demo mode.")

        base = os.environ.get(
            "FORGEJO_API_URL", "https://git.opensourcesolarpunk.com/api/v1"
        )
        context = _collect_context(payload.tab, product)
        body = _build_issue_body(payload, context)
        labels = [
            "beta-feedback",
            "needs-triage",
            _TYPE_LABEL_MAP.get(payload.type, "question"),
        ]
        label_ids = _ensure_labels(labels, base, repo)

        resp = requests.post(
            f"{base}/repos/{repo}/issues",
            headers=_forgejo_headers(),
            json={"title": payload.title, "body": body, "labels": label_ids},
            timeout=15,
        )
        if not resp.ok:
            raise HTTPException(
                status_code=502, detail=f"Forgejo error: {resp.text[:200]}"
            )
        data = resp.json()
        return FeedbackResponse(issue_number=data["number"], issue_url=data["html_url"])

    return router
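And the matching wiring for the feedback router. The `/api/feedback` prefix comes from the 0.8.0 changelog entry above, and the repo/product values echo the docstring's own examples.

```python
# Wiring sketch — FORGEJO_API_TOKEN must hold a real Forgejo token in production,
# otherwise submit_feedback returns 503 and /status reports enabled: false.
from fastapi import FastAPI

from circuitforge_core.api import make_feedback_router

app = FastAPI()
app.include_router(
    make_feedback_router(repo="Circuit-Forge/kiwi", product="kiwi"),
    prefix="/api/feedback",
)
# GET /api/feedback/status → {"enabled": false} until FORGEJO_API_TOKEN is set.
```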
circuitforge_core/community/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
# circuitforge_core/community/__init__.py
# MIT License

from .models import CommunityPost
from .db import CommunityDB
from .store import SharedStore
from .snipe_store import SellerTrustSignal, SnipeCommunityStore

__all__ = ["CommunityDB", "CommunityPost", "SharedStore", "SellerTrustSignal", "SnipeCommunityStore"]
circuitforge_core/community/db.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# circuitforge_core/community/db.py
# MIT License

from __future__ import annotations

import importlib.resources
import logging
from pathlib import Path

import psycopg2
from psycopg2.pool import ThreadedConnectionPool

logger = logging.getLogger(__name__)

_MIN_CONN = 1
_MAX_CONN = 10


class CommunityDB:
    """Shared PostgreSQL connection pool + migration runner for the community module.

    Products instantiate one CommunityDB at startup and pass it to SharedStore
    subclasses. The pool is thread-safe (ThreadedConnectionPool).

    Usage:
        db = CommunityDB.from_env()   # reads COMMUNITY_DB_URL
        db.run_migrations()
        store = MyProductStore(db)
        db.close()                    # at shutdown
    """

    def __init__(self, dsn: str | None) -> None:
        if not dsn:
            raise ValueError(
                "CommunityDB requires a DSN. "
                "Set COMMUNITY_DB_URL or pass dsn= explicitly."
            )
        self._pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, dsn=dsn)
        logger.debug("CommunityDB pool created (min=%d, max=%d)", _MIN_CONN, _MAX_CONN)

    @classmethod
    def from_env(cls) -> "CommunityDB":
        """Construct from the COMMUNITY_DB_URL environment variable."""
        import os
        dsn = os.environ.get("COMMUNITY_DB_URL")
        return cls(dsn=dsn)

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def getconn(self):
        """Borrow a connection from the pool. Must be returned via putconn()."""
        return self._pool.getconn()

    def putconn(self, conn) -> None:
        """Return a borrowed connection to the pool."""
        self._pool.putconn(conn)

    def close(self) -> None:
        """Close all pool connections. Call at application shutdown."""
        self._pool.closeall()
        logger.debug("CommunityDB pool closed")

    # ------------------------------------------------------------------
    # Migration runner
    # ------------------------------------------------------------------

    def _discover_migrations(self) -> list[Path]:
        """Return a sorted list of .sql migration files from the community migrations dir."""
        pkg = importlib.resources.files("circuitforge_core.community.migrations")
        files = sorted(
            [Path(str(p)) for p in pkg.iterdir() if str(p).endswith(".sql")],
            key=lambda p: p.name,
        )
        return files

    def run_migrations(self) -> None:
        """Apply all community migration SQL files in numeric order.

        Uses a simple applied-migrations table to avoid re-running already
        applied migrations. Idempotent.
        """
        conn = self.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS _community_migrations (
                        filename TEXT PRIMARY KEY,
                        applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                    )
                """)
                conn.commit()

                for migration_file in self._discover_migrations():
                    name = migration_file.name
                    cur.execute(
                        "SELECT 1 FROM _community_migrations WHERE filename = %s",
                        (name,),
                    )
                    if cur.fetchone():
                        logger.debug("Migration %s already applied, skipping", name)
                        continue

                    sql = migration_file.read_text()
                    logger.info("Applying community migration: %s", name)
                    cur.execute(sql)
                    cur.execute(
                        "INSERT INTO _community_migrations (filename) VALUES (%s)",
                        (name,),
                    )
                    conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.putconn(conn)
circuitforge_core/community/migrations/001_community_posts.sql (new file, 55 lines)
@@ -0,0 +1,55 @@
-- 001_community_posts.sql
-- Community posts table: published meal plans, recipe successes, and bloopers.
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
-- BSL boundary: this schema is MIT (data layer, no inference).

CREATE TABLE IF NOT EXISTS community_posts (
    id BIGSERIAL PRIMARY KEY,
    slug TEXT NOT NULL UNIQUE,
    pseudonym TEXT NOT NULL,
    post_type TEXT NOT NULL CHECK (post_type IN ('plan', 'recipe_success', 'recipe_blooper')),
    published TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    title TEXT NOT NULL,
    description TEXT,
    photo_url TEXT,

    -- Plan slots (JSON array: [{day, meal_type, recipe_id, recipe_name}])
    slots JSONB NOT NULL DEFAULT '[]',

    -- Recipe result fields
    recipe_id BIGINT,
    recipe_name TEXT,
    level SMALLINT CHECK (level IS NULL OR level BETWEEN 1 AND 4),
    outcome_notes TEXT,

    -- Element snapshot (denormalized from corpus at publish time)
    seasoning_score REAL,
    richness_score REAL,
    brightness_score REAL,
    depth_score REAL,
    aroma_score REAL,
    structure_score REAL,
    texture_profile TEXT,

    -- Dietary / allergen / flavor
    dietary_tags JSONB NOT NULL DEFAULT '[]',
    allergen_flags JSONB NOT NULL DEFAULT '[]',
    flavor_molecules JSONB NOT NULL DEFAULT '[]',

    -- USDA FDC macros
    fat_pct REAL,
    protein_pct REAL,
    moisture_pct REAL,

    -- Source product identifier
    source_product TEXT NOT NULL DEFAULT 'kiwi'
);

-- Indexes for common filter patterns
CREATE INDEX IF NOT EXISTS idx_community_posts_published ON community_posts (published DESC);
CREATE INDEX IF NOT EXISTS idx_community_posts_post_type ON community_posts (post_type);
CREATE INDEX IF NOT EXISTS idx_community_posts_source ON community_posts (source_product);

-- GIN index for dietary/allergen JSONB array containment queries
CREATE INDEX IF NOT EXISTS idx_community_posts_dietary_tags ON community_posts USING GIN (dietary_tags);
CREATE INDEX IF NOT EXISTS idx_community_posts_allergen_flags ON community_posts USING GIN (allergen_flags);
circuitforge_core/community/migrations/002_community_post_reactions.sql (new file, 7 lines)
@@ -0,0 +1,7 @@
-- 002_community_post_reactions.sql
-- Reserved: community post reactions (thumbs-up, saves count).
-- Not yet implemented — this migration is a stub to reserve the sequence number.
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).

-- Placeholder: no-op. Will be replaced when the reactions feature is designed.
SELECT 1;
circuitforge_core/community/migrations/ — seller trust signals migration (new file, 26 lines)
@@ -0,0 +1,26 @@
-- Seller trust signals: confirmed scammer / confirmed legitimate outcomes from Snipe.
-- Separate table from community_posts (Kiwi-specific) — seller signals are a
-- structurally different domain and should not overload the recipe post schema.
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
-- BSL boundary: table schema is MIT; the signal ingestion route in cf-orch is BSL 1.1.

CREATE TABLE IF NOT EXISTS seller_trust_signals (
    id BIGSERIAL PRIMARY KEY,
    platform TEXT NOT NULL DEFAULT 'ebay',
    platform_seller_id TEXT NOT NULL,
    confirmed_scam BOOLEAN NOT NULL,
    signal_source TEXT NOT NULL,        -- 'blocklist_add' | 'community_vote' | 'resolved'
    flags JSONB NOT NULL DEFAULT '[]',  -- red flag keys at time of signal
    source_product TEXT NOT NULL DEFAULT 'snipe',
    recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- No PII: platform_seller_id is the public eBay username or platform ID only.
CREATE INDEX IF NOT EXISTS idx_seller_trust_platform_id
    ON seller_trust_signals (platform, platform_seller_id);

CREATE INDEX IF NOT EXISTS idx_seller_trust_confirmed
    ON seller_trust_signals (confirmed_scam);

CREATE INDEX IF NOT EXISTS idx_seller_trust_recorded
    ON seller_trust_signals (recorded_at DESC);
circuitforge_core/community/migrations/004_community_categories.sql (new file, 19 lines)
@@ -0,0 +1,19 @@
-- 004_community_categories.sql
-- MIT License
-- Shared eBay category tree published by credentialed Snipe instances.
-- Credentialless instances pull from this table during refresh().
-- Privacy: only public eBay category metadata (IDs, names, paths) — no user data.

CREATE TABLE IF NOT EXISTS community_categories (
    id SERIAL PRIMARY KEY,
    platform TEXT NOT NULL DEFAULT 'ebay',
    category_id TEXT NOT NULL,
    name TEXT NOT NULL,
    full_path TEXT NOT NULL,
    source_product TEXT NOT NULL DEFAULT 'snipe',
    published_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    UNIQUE (platform, category_id)
);

CREATE INDEX IF NOT EXISTS idx_community_cat_name
    ON community_categories (platform, name);
circuitforge_core/community/migrations/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# Community module migrations
# These SQL files are shipped with circuitforge-core so cf-orch can locate them via importlib.resources.
87
circuitforge_core/community/models.py
Normal file
@@ -0,0 +1,87 @@
# circuitforge_core/community/models.py
# MIT License

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Literal

PostType = Literal["plan", "recipe_success", "recipe_blooper"]
CreativityLevel = Literal[1, 2, 3, 4]

_VALID_POST_TYPES: frozenset[str] = frozenset(["plan", "recipe_success", "recipe_blooper"])


def _validate_score(name: str, value: float) -> float:
    if not (0.0 <= value <= 1.0):
        raise ValueError(f"{name} must be between 0.0 and 1.0, got {value!r}")
    return value


@dataclass(frozen=True)
class CommunityPost:
    """Immutable snapshot of a published community post.

    Lists (dietary_tags, allergen_flags, flavor_molecules, slots) are stored as
    tuples to enforce immutability. Pass lists -- they are converted in __post_init__.
    """

    # Identity
    slug: str
    pseudonym: str
    post_type: PostType
    published: datetime
    title: str

    # Optional content
    description: str | None
    photo_url: str | None

    # Plan slots -- list[dict] for post_type="plan"
    slots: tuple

    # Recipe result fields -- for post_type="recipe_success" | "recipe_blooper"
    recipe_id: int | None
    recipe_name: str | None
    level: CreativityLevel | None
    outcome_notes: str | None

    # Element snapshot
    seasoning_score: float
    richness_score: float
    brightness_score: float
    depth_score: float
    aroma_score: float
    structure_score: float
    texture_profile: str

    # Dietary/allergen/flavor
    dietary_tags: tuple
    allergen_flags: tuple
    flavor_molecules: tuple

    # USDA FDC macros (optional -- may not be available for all recipes)
    fat_pct: float | None
    protein_pct: float | None
    moisture_pct: float | None

    def __post_init__(self) -> None:
        # Coerce list fields to tuples (frozen dataclass: use object.__setattr__)
        for key in ("slots", "dietary_tags", "allergen_flags", "flavor_molecules"):
            val = getattr(self, key)
            if isinstance(val, list):
                object.__setattr__(self, key, tuple(val))

        # Validate post_type
        if self.post_type not in _VALID_POST_TYPES:
            raise ValueError(
                f"post_type must be one of {sorted(_VALID_POST_TYPES)}, got {self.post_type!r}"
            )

        # Validate scores
        for score_name in (
            "seasoning_score", "richness_score", "brightness_score",
            "depth_score", "aroma_score", "structure_score",
        ):
            _validate_score(score_name, getattr(self, score_name))
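As a quick illustration of the coercion and validation above (the field values here are made up, not from the diff):

from datetime import datetime, timezone

from circuitforge_core.community.models import CommunityPost

post = CommunityPost(
    slug="miso-butter-leeks", pseudonym="quietbadger", post_type="recipe_success",
    published=datetime.now(timezone.utc), title="Miso butter leeks",
    description=None, photo_url=None, slots=[],
    recipe_id=42, recipe_name="Miso butter leeks", level=2, outcome_notes="Worked.",
    seasoning_score=0.7, richness_score=0.8, brightness_score=0.3,
    depth_score=0.6, aroma_score=0.5, structure_score=0.4, texture_profile="silky",
    dietary_tags=["vegetarian"], allergen_flags=["dairy"], flavor_molecules=[],
    fat_pct=None, protein_pct=None, moisture_pct=None,
)
assert post.dietary_tags == ("vegetarian",)  # list coerced to tuple in __post_init__
# A score outside [0.0, 1.0] or an unknown post_type raises ValueError instead.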
253
circuitforge_core/community/snipe_store.py
Normal file
@@ -0,0 +1,253 @@
# circuitforge_core/community/snipe_store.py
# MIT License
"""Snipe community store — publishes seller trust signals to the shared community DB.

Snipe products subclass SharedStore here to write seller trust signals
(confirmed scammer / confirmed legitimate) to the cf_community PostgreSQL.
These signals aggregate across all Snipe users to power the cross-user
seller trust classifier fine-tuning corpus.

Privacy: only platform_seller_id (public eBay username/ID) and flag keys
are written. No PII is stored.

Usage:
    from circuitforge_core.community import CommunityDB
    from circuitforge_core.community.snipe_store import SnipeCommunityStore

    db = CommunityDB.from_env()
    store = SnipeCommunityStore(db, source_product="snipe")
    store.publish_seller_signal(
        platform_seller_id="ebay-username",
        confirmed_scam=True,
        signal_source="blocklist_add",
        flags=["new_account", "suspicious_price"],
    )
"""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from datetime import datetime

from .store import SharedStore

log = logging.getLogger(__name__)


@dataclass(frozen=True)
class SellerTrustSignal:
    """Immutable snapshot of a recorded seller trust signal."""
    id: int
    platform: str
    platform_seller_id: str
    confirmed_scam: bool
    signal_source: str
    flags: tuple
    source_product: str
    recorded_at: datetime


class SnipeCommunityStore(SharedStore):
    """Community store for Snipe — seller trust signal publishing and querying."""

    def __init__(self, db, source_product: str = "snipe") -> None:
        super().__init__(db, source_product=source_product)

    def publish_seller_signal(
        self,
        platform_seller_id: str,
        confirmed_scam: bool,
        signal_source: str,
        flags: list[str] | None = None,
        platform: str = "ebay",
    ) -> SellerTrustSignal:
        """Record a seller trust outcome in the shared community DB.

        Args:
            platform_seller_id: Public eBay username or platform ID (no PII).
            confirmed_scam: True = confirmed bad actor; False = confirmed legitimate.
            signal_source: Origin of the signal.
                'blocklist_add'  — user explicitly added to local blocklist
                'community_vote' — consensus threshold reached from multiple reports
                'resolved'       — seller resolved as legitimate over time
            flags: List of red-flag keys active at signal time (e.g. ["new_account"]).
            platform: Source auction platform (default "ebay").

        Returns the inserted SellerTrustSignal.
        """
        flags = flags or []
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO seller_trust_signals
                        (platform, platform_seller_id, confirmed_scam,
                         signal_source, flags, source_product)
                    VALUES (%s, %s, %s, %s, %s::jsonb, %s)
                    RETURNING id, recorded_at
                    """,
                    (
                        platform,
                        platform_seller_id,
                        confirmed_scam,
                        signal_source,
                        json.dumps(flags),
                        self._source_product,
                    ),
                )
                row = cur.fetchone()
            conn.commit()
            return SellerTrustSignal(
                id=row[0],
                platform=platform,
                platform_seller_id=platform_seller_id,
                confirmed_scam=confirmed_scam,
                signal_source=signal_source,
                flags=tuple(flags),
                source_product=self._source_product,
                recorded_at=row[1],
            )
        except Exception:
            conn.rollback()
            log.warning(
                "Failed to publish seller signal for %s (%s)",
                platform_seller_id, signal_source, exc_info=True,
            )
            raise
        finally:
            self._db.putconn(conn)

    def list_signals_for_seller(
        self,
        platform_seller_id: str,
        platform: str = "ebay",
        limit: int = 50,
    ) -> list[SellerTrustSignal]:
        """Return recent trust signals for a specific seller."""
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, platform, platform_seller_id, confirmed_scam,
                           signal_source, flags, source_product, recorded_at
                    FROM seller_trust_signals
                    WHERE platform = %s AND platform_seller_id = %s
                    ORDER BY recorded_at DESC
                    LIMIT %s
                    """,
                    (platform, platform_seller_id, limit),
                )
                rows = cur.fetchall()
                return [
                    SellerTrustSignal(
                        id=r[0], platform=r[1], platform_seller_id=r[2],
                        confirmed_scam=r[3], signal_source=r[4],
                        flags=tuple(json.loads(r[5]) if isinstance(r[5], str) else (r[5] or [])),
                        source_product=r[6], recorded_at=r[7],
                    )
                    for r in rows
                ]
        finally:
            self._db.putconn(conn)

    def scam_signal_count(self, platform_seller_id: str, platform: str = "ebay") -> int:
        """Return the number of confirmed_scam=True signals for a seller.

        Used to determine if a seller has crossed the community consensus threshold
        for appearing in the shared blocklist.
        """
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT COUNT(*) FROM seller_trust_signals
                    WHERE platform = %s AND platform_seller_id = %s AND confirmed_scam = TRUE
                    """,
                    (platform, platform_seller_id),
                )
                return cur.fetchone()[0]
        finally:
            self._db.putconn(conn)

    def publish_categories(
        self,
        categories: list[tuple[str, str, str]],
        platform: str = "ebay",
    ) -> int:
        """Upsert a batch of eBay leaf categories into the shared community table.

        Args:
            categories: List of (category_id, name, full_path) tuples.
            platform: Source auction platform (default "ebay").

        Returns:
            Number of rows upserted.
        """
        if not categories:
            return 0
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.executemany(
                    """
                    INSERT INTO community_categories
                        (platform, category_id, name, full_path, source_product)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (platform, category_id)
                    DO UPDATE SET
                        name = EXCLUDED.name,
                        full_path = EXCLUDED.full_path,
                        source_product = EXCLUDED.source_product,
                        published_at = NOW()
                    """,
                    [
                        (platform, cid, name, path, self._source_product)
                        for cid, name, path in categories
                    ],
                )
            conn.commit()
            return len(categories)
        except Exception:
            conn.rollback()
            log.warning(
                "Failed to publish %d categories to community store",
                len(categories), exc_info=True,
            )
            raise
        finally:
            self._db.putconn(conn)

    def fetch_categories(
        self,
        platform: str = "ebay",
        limit: int = 500,
    ) -> list[tuple[str, str, str]]:
        """Fetch community-contributed eBay categories.

        Args:
            platform: Source auction platform (default "ebay").
            limit: Maximum rows to return.

        Returns:
            List of (category_id, name, full_path) tuples ordered by name.
        """
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT category_id, name, full_path
                    FROM community_categories
                    WHERE platform = %s
                    ORDER BY name
                    LIMIT %s
                    """,
                    (platform, limit),
                )
                return [(row[0], row[1], row[2]) for row in cur.fetchall()]
        finally:
            self._db.putconn(conn)
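A minimal sketch of the consensus flow these methods enable; the threshold value is a placeholder (the real cutoff is policy, not shown in this diff), and the import paths follow the module docstring above:

from circuitforge_core.community import CommunityDB
from circuitforge_core.community.snipe_store import SnipeCommunityStore

CONSENSUS_THRESHOLD = 3  # placeholder cutoff; the real policy value is not in this diff

store = SnipeCommunityStore(CommunityDB.from_env())
store.publish_seller_signal(
    platform_seller_id="ebay-username",
    confirmed_scam=True,
    signal_source="blocklist_add",
    flags=["new_account"],
)
if store.scam_signal_count("ebay-username") >= CONSENSUS_THRESHOLD:
    print("seller has crossed the shared-blocklist consensus threshold")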
209
circuitforge_core/community/store.py
Normal file
@@ -0,0 +1,209 @@
# circuitforge_core/community/store.py
# MIT License

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING

from .models import CommunityPost

if TYPE_CHECKING:
    from .db import CommunityDB

logger = logging.getLogger(__name__)


def _row_to_post(row: dict) -> CommunityPost:
    """Convert a psycopg2 row dict to a CommunityPost.

    JSONB columns (slots, dietary_tags, allergen_flags, flavor_molecules) come
    back from psycopg2 as Python lists already — no json.loads() needed.
    """
    return CommunityPost(
        slug=row["slug"],
        pseudonym=row["pseudonym"],
        post_type=row["post_type"],
        published=row["published"],
        title=row["title"],
        description=row.get("description"),
        photo_url=row.get("photo_url"),
        slots=row.get("slots") or [],
        recipe_id=row.get("recipe_id"),
        recipe_name=row.get("recipe_name"),
        level=row.get("level"),
        outcome_notes=row.get("outcome_notes"),
        seasoning_score=row["seasoning_score"] or 0.0,
        richness_score=row["richness_score"] or 0.0,
        brightness_score=row["brightness_score"] or 0.0,
        depth_score=row["depth_score"] or 0.0,
        aroma_score=row["aroma_score"] or 0.0,
        structure_score=row["structure_score"] or 0.0,
        texture_profile=row.get("texture_profile") or "",
        dietary_tags=row.get("dietary_tags") or [],
        allergen_flags=row.get("allergen_flags") or [],
        flavor_molecules=row.get("flavor_molecules") or [],
        fat_pct=row.get("fat_pct"),
        protein_pct=row.get("protein_pct"),
        moisture_pct=row.get("moisture_pct"),
    )


def _cursor_to_dict(cur, row) -> dict:
    """Convert a psycopg2 row tuple to a dict using cursor.description."""
    if isinstance(row, dict):
        return row
    return {desc[0]: val for desc, val in zip(cur.description, row)}


class SharedStore:
    """Base class for product community stores.

    Subclass this in each product:
        class KiwiCommunityStore(SharedStore):
            def list_posts_for_week(self, week_start: str) -> list[CommunityPost]: ...

    All methods return new objects (immutable pattern). Never mutate rows in-place.
    """

    def __init__(self, db: "CommunityDB", source_product: str = "kiwi") -> None:
        self._db = db
        self._source_product = source_product

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get_post_by_slug(self, slug: str) -> CommunityPost | None:
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM community_posts WHERE slug = %s LIMIT 1",
                    (slug,),
                )
                row = cur.fetchone()
                if row is None:
                    return None
                return _row_to_post(_cursor_to_dict(cur, row))
        finally:
            self._db.putconn(conn)

    def list_posts(
        self,
        limit: int = 20,
        offset: int = 0,
        post_type: str | None = None,
        dietary_tags: list[str] | None = None,
        allergen_exclude: list[str] | None = None,
        source_product: str | None = None,
    ) -> list[CommunityPost]:
        """Paginated post list with optional filters.

        dietary_tags: JSONB containment (@>) — posts must include ALL listed tags.
        allergen_exclude: JSONB overlap (?|) — posts must NOT include any listed flag.
        """
        conn = self._db.getconn()
        try:
            conditions = []
            params: list = []

            if post_type:
                conditions.append("post_type = %s")
                params.append(post_type)
            if dietary_tags:
                conditions.append("dietary_tags @> %s::jsonb")
                params.append(json.dumps(dietary_tags))
            if allergen_exclude:
                # jsonb has no && operator; ?| tests overlap against a text[]
                # of flags (psycopg2 adapts a Python list to an array).
                conditions.append("NOT (allergen_flags ?| %s)")
                params.append(allergen_exclude)
            if source_product:
                conditions.append("source_product = %s")
                params.append(source_product)

            where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
            params.extend([limit, offset])

            with conn.cursor() as cur:
                cur.execute(
                    f"SELECT * FROM community_posts {where} "
                    "ORDER BY published DESC LIMIT %s OFFSET %s",
                    params,
                )
                rows = cur.fetchall()
                return [_row_to_post(_cursor_to_dict(cur, r)) for r in rows]
        finally:
            self._db.putconn(conn)

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def insert_post(self, post: CommunityPost) -> CommunityPost:
        """Insert a new community post. Returns the inserted post (unchanged — slug is the key)."""
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO community_posts (
                        slug, pseudonym, post_type, published, title, description, photo_url,
                        slots, recipe_id, recipe_name, level, outcome_notes,
                        seasoning_score, richness_score, brightness_score,
                        depth_score, aroma_score, structure_score, texture_profile,
                        dietary_tags, allergen_flags, flavor_molecules,
                        fat_pct, protein_pct, moisture_pct, source_product
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s,
                        %s::jsonb, %s, %s, %s, %s,
                        %s, %s, %s, %s, %s, %s, %s,
                        %s::jsonb, %s::jsonb, %s::jsonb,
                        %s, %s, %s, %s
                    )
                    """,
                    (
                        post.slug, post.pseudonym, post.post_type,
                        post.published, post.title, post.description, post.photo_url,
                        json.dumps(list(post.slots)),
                        post.recipe_id, post.recipe_name, post.level, post.outcome_notes,
                        post.seasoning_score, post.richness_score, post.brightness_score,
                        post.depth_score, post.aroma_score, post.structure_score,
                        post.texture_profile,
                        json.dumps(list(post.dietary_tags)),
                        json.dumps(list(post.allergen_flags)),
                        json.dumps(list(post.flavor_molecules)),
                        post.fat_pct, post.protein_pct, post.moisture_pct,
                        self._source_product,
                    ),
                )
            conn.commit()
            return post
        except Exception:
            conn.rollback()
            raise
        finally:
            self._db.putconn(conn)

    def delete_post(self, slug: str, pseudonym: str) -> bool:
        """Hard-delete a post. Only succeeds if pseudonym matches the author.

        Returns True if a row was deleted, False if no matching row found.
        """
        conn = self._db.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM community_posts WHERE slug = %s AND pseudonym = %s",
                    (slug, pseudonym),
                )
                conn.commit()
                return cur.rowcount > 0
        except Exception:
            conn.rollback()
            raise
        finally:
            self._db.putconn(conn)
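For example, a product-side read using the filters above (the CommunityDB import path follows the snipe_store docstring; tag and flag names are illustrative):

from circuitforge_core.community import CommunityDB
from circuitforge_core.community.store import SharedStore

store = SharedStore(CommunityDB.from_env(), source_product="kiwi")

# Newest vegan successes with no tree-nut flags, second page of 20.
posts = store.list_posts(
    limit=20,
    offset=20,
    post_type="recipe_success",
    dietary_tags=["vegan"],          # must contain ALL of these (jsonb @>)
    allergen_exclude=["tree_nuts"],  # must contain NONE of these (jsonb ?|)
)
for p in posts:
    print(p.slug, p.title, p.seasoning_score)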
@@ -1,3 +1,4 @@
 from .settings import require_env, load_env
+from .license import validate_license, get_license_tier
 
-__all__ = ["require_env", "load_env"]
+__all__ = ["require_env", "load_env", "validate_license", "get_license_tier"]
104
circuitforge_core/config/license.py
Normal file
@@ -0,0 +1,104 @@
"""
License validation via Heimdall.

Products call validate_license() or get_license_tier() at startup to check
the CF_LICENSE_KEY environment variable against Heimdall.

Both functions are safe to call when CF_LICENSE_KEY is absent — they return
"free" tier gracefully rather than raising.

Environment variables:
    CF_LICENSE_KEY — Raw license key (e.g. CFG-PRNG-XXXX-XXXX-XXXX).
                     If absent, product runs as free tier.
    CF_LICENSE_URL — Heimdall base URL override.
                     Default: https://license.circuitforge.tech
"""
from __future__ import annotations

import logging
import os
import time

import requests

logger = logging.getLogger(__name__)

_DEFAULT_HEIMDALL_URL = "https://license.circuitforge.tech"
_CACHE_TTL_SECONDS = 1800  # 30 minutes

# Cache: (key, product) -> (result_dict, expires_at)
_cache: dict[tuple[str, str], tuple[dict[str, bool | str], float]] = {}

_INVALID: dict[str, bool | str] = {"valid": False, "tier": "free", "user_id": ""}


def _heimdall_url(override: str | None) -> str:
    return override or os.environ.get("CF_LICENSE_URL", _DEFAULT_HEIMDALL_URL)


def validate_license(
    product: str,
    min_tier: str = "free",
    heimdall_url: str | None = None,
) -> dict[str, bool | str]:
    """
    Validate CF_LICENSE_KEY against Heimdall for the given product.

    Returns a dict with keys: valid (bool), tier (str), user_id (str).
    Returns {"valid": False, "tier": "free", "user_id": ""} when:
      - CF_LICENSE_KEY is not set
      - Heimdall is unreachable
      - The key is invalid/expired/revoked

    Results are cached for 30 minutes per (key, product) pair.
    """
    key = os.environ.get("CF_LICENSE_KEY", "").strip()
    if not key:
        return dict(_INVALID)

    cache_key = (key, product)
    now = time.monotonic()
    if cache_key in _cache:
        cached_result, expires_at = _cache[cache_key]
        if now < expires_at:
            return dict(cached_result)

    base = _heimdall_url(heimdall_url)
    try:
        resp = requests.post(
            f"{base}/licenses/verify",
            json={"key": key, "min_tier": min_tier},
            timeout=5,
        )
        if not resp.ok:
            logger.warning("[license] Heimdall returned %s for key validation", resp.status_code)
            result = dict(_INVALID)
        else:
            data = resp.json()
            result = {
                "valid": bool(data.get("valid", False)),
                "tier": data.get("tier", "free") or "free",
                "user_id": data.get("user_id", "") or "",
            }
    except Exception as exc:
        logger.warning("[license] License validation failed: %s", exc)
        result = dict(_INVALID)

    _cache[cache_key] = (result, now + _CACHE_TTL_SECONDS)
    return result


def get_license_tier(
    product: str,
    heimdall_url: str | None = None,
) -> str:
    """
    Return the active tier for CF_LICENSE_KEY, or "free" if absent/invalid.

    Convenience wrapper around validate_license() for the common case
    where only the tier string is needed.
    """
    result = validate_license(product, min_tier="free", heimdall_url=heimdall_url)
    if not result["valid"]:
        return "free"
    return result["tier"]
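A sketch of startup gating with these helpers; the tier names are illustrative, since the diff does not enumerate tiers, and the feature hook is a placeholder:

from circuitforge_core.config import get_license_tier, validate_license


def enable_pro_features() -> None:
    ...  # placeholder for product-specific wiring


# Cheap check when only the tier matters:
if get_license_tier("snipe") != "free":
    enable_pro_features()

# Full check when the user id or validity is also needed:
info = validate_license("snipe", min_tier="pro")
if not info["valid"]:
    print("Running as free tier; set CF_LICENSE_KEY to unlock pro features.")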
@@ -4,12 +4,22 @@ Applies *.sql files from migrations_dir in filename order.
 Tracks applied migrations in a _migrations table — safe to call multiple times.
 """
 from __future__ import annotations
+import logging
 import sqlite3
 from pathlib import Path
 
+_log = logging.getLogger(__name__)
+
 
 def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
-    """Apply any unapplied *.sql migrations from migrations_dir."""
+    """Apply any unapplied *.sql migrations from migrations_dir.
+
+    Recovers from partial failures: if a migration previously crashed mid-run
+    (e.g. a process killed after some ALTER TABLE statements auto-committed via
+    executescript), the next startup re-runs that migration. Any "duplicate
+    column name" errors are then silently skipped so the migration can complete
+    and be marked as applied. All other errors still propagate.
+    """
     conn.execute(
         "CREATE TABLE IF NOT EXISTS _migrations "
         "(name TEXT PRIMARY KEY, applied_at TEXT DEFAULT CURRENT_TIMESTAMP)"
@@ -22,8 +32,92 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
     for sql_file in sql_files:
         if sql_file.name in applied:
             continue
-        conn.executescript(sql_file.read_text())
+
+        try:
+            conn.executescript(sql_file.read_text())
+        except sqlite3.OperationalError as exc:
+            if "duplicate column name" not in str(exc).lower():
+                raise
+            # A previous run partially applied this migration (some ALTER TABLE
+            # statements auto-committed before the failure). Re-run with
+            # per-statement recovery to skip already-applied columns.
+            _log.warning(
+                "Migration %s: partial-failure detected (%s) — "
+                "retrying with per-statement recovery",
+                sql_file.name,
+                exc,
+            )
+            _run_script_with_recovery(conn, sql_file)
+
+        # OR IGNORE: safe if two Store() calls race on the same DB — second writer
+        # just skips the insert rather than raising UNIQUE constraint failed.
         conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,))
         conn.commit()
+
+
+def _run_script_with_recovery(conn: sqlite3.Connection, sql_file: Path) -> None:
+    """Re-run a migration via executescript, skipping duplicate-column errors.
+
+    Used only when the first executescript() attempt raised a duplicate column
+    error (indicating a previous partial run). Splitting the script ourselves
+    is unreliable (semicolons can appear inside comments and strings), so we
+    let SQLite's own tokenizer do the work: re-issue the whole script, and when
+    it fails on a "duplicate column name" error, strip that ALTER TABLE
+    statement from the script and try again. Every other error class still
+    propagates unchanged.
+    """
+    # Successive retries are cheap: statements that already ran are either
+    # no-ops (CREATE TABLE IF NOT EXISTS, INSERT OR IGNORE, ...) or removed
+    # from the script below. Retries are capped to prevent infinite loops on
+    # genuinely broken SQL.
+    script = sql_file.read_text()
+    for attempt in range(20):
+        try:
+            conn.executescript(script)
+            return  # success
+        except sqlite3.OperationalError as exc:
+            msg = str(exc).lower()
+            if "duplicate column name" in msg:
+                # sqlite reports e.g. "duplicate column name: notes"
+                col = str(exc).split(":")[-1].strip() if ":" in str(exc) else "?"
+                _log.warning(
+                    "Migration %s (attempt %d): skipping duplicate column '%s'",
+                    sql_file.name,
+                    attempt + 1,
+                    col,
+                )
+                # Remove the offending ALTER TABLE statement from the script
+                # so the next attempt skips it. This is safe because SQLite
+                # already auto-committed that column addition on a prior run.
+                script = _remove_column_add(script, col)
+            else:
+                raise
+    raise RuntimeError(
+        f"Migration {sql_file.name}: could not complete after 20 recovery attempts"
+    )
+
+
+def _remove_column_add(script: str, column: str) -> str:
+    """Remove the ALTER TABLE ... ADD COLUMN statement for *column* from *script*."""
+    import re
+
+    # Match: ALTER TABLE <tbl> ADD [COLUMN] <column> <rest-of-line>
+    pattern = re.compile(
+        r"ALTER\s+TABLE\s+\w+\s+ADD\s+(?:COLUMN\s+)?" + re.escape(column) + r"[^\n]*\n?",
+        re.IGNORECASE,
+    )
+    return pattern.sub("", script)
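To make the recovery path concrete, a sketch of the double-run scenario (run_migrations comes from the module this hunk patches; its file path is not visible in this view, so adjust the import to your tree):

import sqlite3
import tempfile
from pathlib import Path

mig_dir = Path(tempfile.mkdtemp()) / "migrations"
mig_dir.mkdir()
(mig_dir / "001_notes.sql").write_text(
    "CREATE TABLE IF NOT EXISTS items (id INTEGER PRIMARY KEY);\n"
    "ALTER TABLE items ADD COLUMN notes TEXT;\n"
)

conn = sqlite3.connect(":memory:")
run_migrations(conn, mig_dir)  # applies 001_notes.sql and records it
run_migrations(conn, mig_dir)  # no-op: name already present in _migrations
# Had the first run died between the ALTER TABLE and the bookkeeping insert,
# the next call would re-run the script and skip the duplicate-column error.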
@@ -17,13 +17,80 @@ CONFIG_PATH = Path.home() / ".config" / "circuitforge" / "llm.yaml"
 
 class LLMRouter:
     def __init__(self, config_path: Path = CONFIG_PATH):
-        if not config_path.exists():
-            raise FileNotFoundError(
-                f"{config_path} not found. "
-                "Copy the llm.yaml.example to ~/.config/circuitforge/llm.yaml and configure your LLM backends."
-            )
-        with open(config_path) as f:
-            self.config = yaml.safe_load(f)
+        if config_path.exists():
+            with open(config_path) as f:
+                self.config = yaml.safe_load(f)
+        else:
+            env_config = self._auto_config_from_env()
+            if env_config is None:
+                raise FileNotFoundError(
+                    f"{config_path} not found and no LLM env vars detected. "
+                    "Either copy llm.yaml.example to ~/.config/circuitforge/llm.yaml, "
+                    "or set ANTHROPIC_API_KEY, OPENAI_API_KEY, or OLLAMA_HOST."
+                )
+            logger.info(
+                "[LLMRouter] No llm.yaml found — using env-var auto-config "
+                "(backends: %s)", ", ".join(env_config["fallback_order"])
+            )
+            self.config = env_config
 
+    @staticmethod
+    def _auto_config_from_env() -> dict | None:
+        """Build a minimal LLM config from well-known environment variables.
+
+        Priority order (highest to lowest):
+          1. ANTHROPIC_API_KEY → anthropic backend
+          2. OPENAI_API_KEY    → openai_compat → api.openai.com (or OPENAI_BASE_URL)
+          3. OLLAMA_HOST       → openai_compat → local Ollama (always included as last resort)
+
+        Returns None only when none of these are set and Ollama is not configured,
+        so the caller can decide whether to raise or surface a user-facing message.
+        """
+        backends: dict = {}
+        fallback_order: list[str] = []
+
+        if os.environ.get("ANTHROPIC_API_KEY"):
+            backends["anthropic"] = {
+                "type": "anthropic",
+                "enabled": True,
+                "model": os.environ.get("ANTHROPIC_MODEL", "claude-haiku-4-5-20251001"),
+                "api_key_env": "ANTHROPIC_API_KEY",
+                "supports_images": True,
+            }
+            fallback_order.append("anthropic")
+
+        if os.environ.get("OPENAI_API_KEY"):
+            backends["openai"] = {
+                "type": "openai_compat",
+                "enabled": True,
+                "base_url": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
+                "model": os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
+                "api_key": os.environ.get("OPENAI_API_KEY"),
+                "supports_images": True,
+            }
+            fallback_order.append("openai")
+
+        # Ollama — always added when any config exists, as the lowest-cost local fallback.
+        # Unreachable Ollama is harmless — _is_reachable() skips it gracefully.
+        ollama_host = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
+        if not ollama_host.startswith("http"):
+            ollama_host = f"http://{ollama_host}"
+        backends["ollama"] = {
+            "type": "openai_compat",
+            "enabled": True,
+            "base_url": ollama_host.rstrip("/") + "/v1",
+            "model": os.environ.get("OLLAMA_MODEL", "llama3.2:3b"),
+            "api_key": "any",
+            "supports_images": False,
+        }
+        fallback_order.append("ollama")
+
+        # Return None if only ollama is in the list AND no explicit host was set —
+        # that means the user set nothing at all, not even OLLAMA_HOST.
+        if fallback_order == ["ollama"] and "OLLAMA_HOST" not in os.environ:
+            return None
+
+        return {"backends": backends, "fallback_order": fallback_order}
+
     def _is_reachable(self, base_url: str) -> bool:
         """Quick health-check ping. Returns True if backend is up."""
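For instance, with only ANTHROPIC_API_KEY set, the auto-config resolves to an anthropic-first order with Ollama as the local last resort. LLMRouter's module path is not shown in this hunk, so the call is sketched in comments and the dict shape mirrors the code above:

import os

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder key

# Import LLMRouter from wherever the product exposes it, then:
#     cfg = LLMRouter._auto_config_from_env()
# yields:
#     {
#         "backends": {
#             "anthropic": {"type": "anthropic", "model": "claude-haiku-4-5-20251001", ...},
#             "ollama": {"type": "openai_compat", "base_url": "http://localhost:11434/v1", ...},
#         },
#         "fallback_order": ["anthropic", "ollama"],
#     }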
@@ -56,7 +123,7 @@ class LLMRouter:
         if not orch_url:
             return None
         try:
-            from circuitforge_core.resources.client import CFOrchClient
+            from circuitforge_orch.client import CFOrchClient
             client = CFOrchClient(orch_url)
             service = orch_cfg.get("service", "vllm")
             candidates = orch_cfg.get("model_candidates", [])
@@ -132,15 +199,18 @@ class LLMRouter:
                 continue
 
             elif backend["type"] == "openai_compat":
-                if not self._is_reachable(backend["base_url"]):
-                    print(f"[LLMRouter] {name}: unreachable, skipping")
-                    continue
-                # --- cf_orch: optionally override base_url with coordinator-allocated URL ---
+                # cf_orch: try allocation first — this may start the service on-demand.
+                # Do NOT reachability-check before allocating; the service may be stopped
+                # and the allocation is what starts it.
+                orch_ctx = orch_alloc = None
+                orch_result = self._try_cf_orch_alloc(backend)
+                if orch_result is not None:
+                    orch_ctx, orch_alloc = orch_result
+                    backend = {**backend, "base_url": orch_alloc.url + "/v1"}
+                elif not self._is_reachable(backend["base_url"]):
+                    # Static backend (no cf-orch) — skip if not reachable.
+                    print(f"[LLMRouter] {name}: unreachable, skipping")
+                    continue
                 try:
                     client = OpenAI(
                         base_url=backend["base_url"],
12
circuitforge_core/manage/__init__.py
Normal file
@@ -0,0 +1,12 @@
"""circuitforge_core.manage — cross-platform product process manager."""
from .config import ManageConfig, NativeService
from .docker_mode import DockerManager, docker_available
from .native_mode import NativeManager

__all__ = [
    "ManageConfig",
    "NativeService",
    "DockerManager",
    "docker_available",
    "NativeManager",
]

4
circuitforge_core/manage/__main__.py
Normal file
@@ -0,0 +1,4 @@
"""Entry point for `python -m circuitforge_core.manage`."""
from .cli import app

app()
237
circuitforge_core/manage/cli.py
Normal file
@@ -0,0 +1,237 @@
"""
circuitforge_core.manage.cli — cross-platform product manager CLI.

Usage (from any product directory):
    python -m circuitforge_core.manage start
    python -m circuitforge_core.manage stop
    python -m circuitforge_core.manage restart
    python -m circuitforge_core.manage status
    python -m circuitforge_core.manage logs [SERVICE]
    python -m circuitforge_core.manage open
    python -m circuitforge_core.manage build
    python -m circuitforge_core.manage install-shims

Products shim into this via a thin manage.sh / manage.ps1 that finds Python
and delegates: exec python -m circuitforge_core.manage "$@"
"""
from __future__ import annotations

import webbrowser
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional

import typer

from .config import ManageConfig
from .docker_mode import DockerManager, docker_available
from .native_mode import NativeManager

app = typer.Typer(
    name="manage",
    help="CircuitForge cross-platform product manager",
    no_args_is_help=True,
)


class Mode(str, Enum):
    auto = "auto"
    docker = "docker"
    native = "native"


def _resolve(
    mode: Mode,
    root: Path,
    config: ManageConfig,
) -> tuple[str, DockerManager | NativeManager]:
    """Return (mode_name, manager) based on mode flag and environment."""
    if mode == Mode.docker or (
        mode == Mode.auto
        and docker_available()
        and (root / config.docker.compose_file).exists()
    ):
        return "docker", DockerManager(config, root)
    return "native", NativeManager(config, root)


def _load(root: Path) -> ManageConfig:
    return ManageConfig.from_cwd(root)


# ── commands ──────────────────────────────────────────────────────────────────

@app.command()
def start(
    service: Annotated[Optional[str], typer.Argument(help="Service name (omit for all)")] = None,
    mode: Mode = Mode.auto,
    root: Path = Path("."),
) -> None:
    """Start services."""
    config = _load(root.resolve())
    mode_name, mgr = _resolve(mode, root.resolve(), config)
    typer.echo(f"[{config.app_name}] Starting ({mode_name} mode)…")
    if isinstance(mgr, DockerManager):
        mgr.start(service or "")
    else:
        started = mgr.start(service)
        if started:
            typer.echo(f"[{config.app_name}] Started: {', '.join(started)}")
        else:
            typer.echo(f"[{config.app_name}] All services already running")


@app.command()
def stop(
    service: Annotated[Optional[str], typer.Argument(help="Service name (omit for all)")] = None,
    mode: Mode = Mode.auto,
    root: Path = Path("."),
) -> None:
    """Stop services."""
    config = _load(root.resolve())
    mode_name, mgr = _resolve(mode, root.resolve(), config)
    typer.echo(f"[{config.app_name}] Stopping ({mode_name} mode)…")
    if isinstance(mgr, DockerManager):
        mgr.stop(service or "")
    else:
        stopped = mgr.stop(service)
        if stopped:
            typer.echo(f"[{config.app_name}] Stopped: {', '.join(stopped)}")
        else:
            typer.echo(f"[{config.app_name}] No running services to stop")


@app.command()
def restart(
    service: Annotated[Optional[str], typer.Argument(help="Service name (omit for all)")] = None,
    mode: Mode = Mode.auto,
    root: Path = Path("."),
) -> None:
    """Restart services."""
    config = _load(root.resolve())
    mode_name, mgr = _resolve(mode, root.resolve(), config)
    typer.echo(f"[{config.app_name}] Restarting ({mode_name} mode)…")
    if isinstance(mgr, DockerManager):
        mgr.restart(service or "")
    else:
        mgr.stop(service)
        mgr.start(service)


@app.command()
def status(
    mode: Mode = Mode.auto,
    root: Path = Path("."),
) -> None:
    """Show service status."""
    config = _load(root.resolve())
    mode_name, mgr = _resolve(mode, root.resolve(), config)
    if isinstance(mgr, DockerManager):
        mgr.status()
    else:
        rows = mgr.status()
        if not rows:
            typer.echo(f"[{config.app_name}] No native services defined in manage.toml")
            return
        typer.echo(f"\n  {config.app_name} — native services\n")
        for svc in rows:
            indicator = (
                typer.style("●", fg=typer.colors.GREEN)
                if svc.running
                else typer.style("○", fg=typer.colors.RED)
            )
            pid_str = f" pid={svc.pid}" if svc.pid else ""
            port_str = f" port={svc.port}" if svc.port else ""
            typer.echo(f"  {indicator} {svc.name:<20}{pid_str}{port_str}")
        typer.echo("")


@app.command()
def logs(
    service: Annotated[Optional[str], typer.Argument(help="Service name")] = None,
    follow: bool = typer.Option(True, "--follow/--no-follow", "-f/-F"),
    mode: Mode = Mode.auto,
    root: Path = Path("."),
) -> None:
    """Tail service logs."""
    config = _load(root.resolve())
    mode_name, mgr = _resolve(mode, root.resolve(), config)
    if isinstance(mgr, DockerManager):
        mgr.logs(service or "", follow=follow)
    else:
        if not service:
            # Default to first service when none specified
            if not config.services:
                typer.echo("No native services defined", err=True)
                raise typer.Exit(1)
            service = config.services[0].name
        mgr.logs(service, follow=follow)


@app.command()
def build(
    no_cache: bool = False,
    mode: Mode = Mode.auto,
    root: Path = Path("."),
) -> None:
    """Build/rebuild service images (Docker mode only)."""
    config = _load(root.resolve())
    mode_name, mgr = _resolve(mode, root.resolve(), config)
    if isinstance(mgr, NativeManager):
        typer.echo("build is only available in Docker mode", err=True)
        raise typer.Exit(1)
    typer.echo(f"[{config.app_name}] Building images…")
    mgr.build(no_cache=no_cache)


@app.command("open")
def open_browser(
    url: Annotated[Optional[str], typer.Option(help="Override URL")] = None,
    root: Path = Path("."),
) -> None:
    """Open the product web UI in the default browser."""
    config = _load(root.resolve())
    target = url or config.default_url
    if not target:
        typer.echo("No URL configured. Set default_url in manage.toml or pass --url.", err=True)
        raise typer.Exit(1)
    typer.echo(f"Opening {target}")
    webbrowser.open(target)


@app.command("install-shims")
def install_shims(
    root: Path = Path("."),
    force: bool = typer.Option(False, "--force", help="Overwrite existing shims"),
) -> None:
    """
    Write manage.sh and manage.ps1 shims into the product directory.

    The shims auto-detect the Python environment (conda, venv, or system Python)
    and delegate all arguments to `python -m circuitforge_core.manage`.
    """
    from importlib.resources import files as _res_files

    target = root.resolve()
    templates_pkg = "circuitforge_core.manage.templates"

    for filename in ("manage.sh", "manage.ps1"):
        dest = target / filename
        if dest.exists() and not force:
            typer.echo(f"  skipped {filename} (already exists — use --force to overwrite)")
            continue
        content = (_res_files(templates_pkg) / filename).read_text()
        dest.write_text(content)
        if filename.endswith(".sh"):
            dest.chmod(dest.stat().st_mode | 0o111)  # make executable
        typer.echo(f"  wrote {dest}")

    toml_example = target / "manage.toml.example"
    if not toml_example.exists() or force:
        content = (_res_files(templates_pkg) / "manage.toml.example").read_text()
        toml_example.write_text(content)
        typer.echo(f"  wrote {toml_example}")

    typer.echo("\nDone. Rename manage.toml.example → manage.toml and edit for your services.")


if __name__ == "__main__":
    app()
119
circuitforge_core/manage/config.py
Normal file
@@ -0,0 +1,119 @@
"""
circuitforge_core.manage.config — ManageConfig parsed from manage.toml.

Products drop a manage.toml in their root directory. manage.py reads it to
discover the app name, compose file, and native service definitions.

Minimal manage.toml (Docker-only):
----------------------------------------------------------------------
[app]
name = "kiwi"
default_url = "http://localhost:8511"
----------------------------------------------------------------------

Full manage.toml (Docker + native services):
----------------------------------------------------------------------
[app]
name = "kiwi"
default_url = "http://localhost:8511"

[docker]
compose_file = "compose.yml"   # default
project = "kiwi"               # defaults to app.name

[[native.services]]
name = "api"
command = "uvicorn app.main:app --host 0.0.0.0 --port 8512"
port = 8512

[[native.services]]
name = "frontend"
command = "npm run preview -- --host --port 8511"
port = 8511
cwd = "frontend"
----------------------------------------------------------------------
"""
from __future__ import annotations

import sys
from dataclasses import dataclass, field
from pathlib import Path

if sys.version_info >= (3, 11):
    import tomllib
else:
    try:
        import tomllib  # type: ignore[no-redef]
    except ImportError:
        import tomli as tomllib  # type: ignore[no-redef]

_DEFAULT_COMPOSE_FILE = "compose.yml"


@dataclass
class NativeService:
    """One process to manage in native mode."""
    name: str
    command: str  # shell command string
    port: int = 0  # for status / open URL
    cwd: str = ""  # relative to project root; "" = root
    env: dict[str, str] = field(default_factory=dict)


@dataclass
class DockerConfig:
    compose_file: str = _DEFAULT_COMPOSE_FILE
    project: str = ""  # docker compose -p; defaults to app name


@dataclass
class ManageConfig:
    app_name: str
    default_url: str = ""
    docker: DockerConfig = field(default_factory=DockerConfig)
    services: list[NativeService] = field(default_factory=list)

    @classmethod
    def load(cls, path: Path) -> "ManageConfig":
        """Load from a manage.toml file."""
        raw = tomllib.loads(path.read_text())
        app = raw.get("app", {})
        name = app.get("name") or path.parent.name  # fallback to directory name
        default_url = app.get("default_url", "")

        docker_raw = raw.get("docker", {})
        docker = DockerConfig(
            compose_file=docker_raw.get("compose_file", _DEFAULT_COMPOSE_FILE),
            project=docker_raw.get("project", name),
        )

        services: list[NativeService] = []
        for svc in raw.get("native", {}).get("services", []):
            services.append(NativeService(
                name=svc["name"],
                command=svc["command"],
                port=svc.get("port", 0),
                cwd=svc.get("cwd", ""),
                env=svc.get("env", {}),
            ))

        return cls(
            app_name=name,
            default_url=default_url,
            docker=docker,
            services=services,
        )

    @classmethod
    def from_cwd(cls, cwd: Path | None = None) -> "ManageConfig":
        """
        Load from manage.toml in cwd, or return a minimal config derived from
        the directory name if no manage.toml exists (Docker-only products work
        without one).
        """
        root = cwd or Path.cwd()
        toml_path = root / "manage.toml"
        if toml_path.exists():
            return cls.load(toml_path)
        # Fallback: infer from directory name, look for compose.yml
        return cls(app_name=root.name)
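A short sketch of both load paths (the directories are hypothetical):

from pathlib import Path

from circuitforge_core.manage.config import ManageConfig

cfg = ManageConfig.from_cwd(Path("/opt/kiwi"))  # reads /opt/kiwi/manage.toml if present
print(cfg.app_name, cfg.docker.compose_file, [s.name for s in cfg.services])

# With no manage.toml at all, the config degrades to Docker-only defaults:
bare = ManageConfig.from_cwd(Path("/opt/otherapp"))
assert bare.app_name == "otherapp" and bare.services == []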
115
circuitforge_core/manage/docker_mode.py
Normal file
@@ -0,0 +1,115 @@
"""
circuitforge_core.manage.docker_mode — Docker Compose wrapper.

All commands delegate to `docker compose` (v2 plugin syntax).
Falls back to `docker-compose` (v1 standalone) if the plugin is unavailable.
"""
from __future__ import annotations

import shutil
import subprocess
from pathlib import Path

from .config import ManageConfig


def _compose_bin() -> list[str]:
    """Return the docker compose command as a list (handles v1/v2 difference)."""
    # Docker Compose v2: `docker compose` (space, built-in plugin)
    # Docker Compose v1: `docker-compose` (hyphen, standalone binary)
    if shutil.which("docker"):
        return ["docker", "compose"]
    if shutil.which("docker-compose"):
        return ["docker-compose"]
    raise RuntimeError("Neither 'docker' nor 'docker-compose' found on PATH")


def docker_available() -> bool:
    """Return True if Docker is reachable (docker info succeeds)."""
    try:
        result = subprocess.run(
            ["docker", "info"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=5,
        )
        # `docker info` exits non-zero when the daemon is down, without raising.
        return result.returncode == 0
    except Exception:
        return False


class DockerManager:
    """
    Wraps `docker compose` for a single product directory.

    Args:
        config: ManageConfig for the current product.
        root: Product root directory (where compose file lives).
    """

    def __init__(self, config: ManageConfig, root: Path) -> None:
        self.config = config
        self.root = root
        self._compose_file = root / config.docker.compose_file

    def _run(self, *args: str, check: bool = True) -> subprocess.CompletedProcess:  # type: ignore[type-arg]
        cmd = [
            *_compose_bin(),
            "-f", str(self._compose_file),
            "-p", self.config.docker.project or self.config.app_name,
            *args,
        ]
        return subprocess.run(cmd, cwd=self.root, check=check)

    def _stream(self, *args: str) -> None:
        """Run a compose command, streaming output directly to the terminal."""
        cmd = [
            *_compose_bin(),
            "-f", str(self._compose_file),
            "-p", self.config.docker.project or self.config.app_name,
            *args,
        ]
        with subprocess.Popen(cmd, cwd=self.root) as proc:
            try:
                proc.wait()
            except KeyboardInterrupt:
                proc.terminate()

    def compose_file_exists(self) -> bool:
        return self._compose_file.exists()

    def start(self, service: str = "") -> None:
        args = ["up", "-d", "--build"]
        if service:
            args.append(service)
        self._run(*args)

    def stop(self, service: str = "") -> None:
        if service:
            self._run("stop", service)
        else:
            self._run("down")

    def restart(self, service: str = "") -> None:
        args = ["restart"]
        if service:
            args.append(service)
        self._run(*args)

    def status(self) -> None:
        self._run("ps", check=False)

    def logs(self, service: str = "", follow: bool = True) -> None:
        args = ["logs"]
        if follow:
            args.append("-f")
        if service:
            args.append(service)
        self._stream(*args)

    def build(self, no_cache: bool = False) -> None:
        args = ["build"]
        if no_cache:
            args.append("--no-cache")
        self._run(*args)
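Typical wiring, assuming a product checkout at a hypothetical path:

from pathlib import Path

from circuitforge_core.manage.config import ManageConfig
from circuitforge_core.manage.docker_mode import DockerManager, docker_available

root = Path("/opt/kiwi")  # hypothetical product checkout
config = ManageConfig.from_cwd(root)

if docker_available():
    mgr = DockerManager(config, root)
    if mgr.compose_file_exists():
        mgr.start()             # docker compose -f compose.yml -p kiwi up -d --build
        mgr.logs(follow=False)  # one-shot log dump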
217
circuitforge_core/manage/native_mode.py
Normal file
@@ -0,0 +1,217 @@
"""
circuitforge_core.manage.native_mode — PID-file process manager.

Manages processes directly without Docker. Designed for Windows (no WSL2,
no Docker), but works identically on Linux/macOS.

Platform conventions (via platformdirs):
    PID files : user_runtime_dir(app_name) / <service>.pid
    Log files : user_log_dir(app_name) / <service>.log

PID file format (one line each):
    <pid>
    <command_fingerprint>  (first 80 chars of the command, recorded so a
                            recycled PID can be sanity-checked against it)
"""
from __future__ import annotations

import os
import platform
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path

from platformdirs import user_log_dir, user_runtime_dir

from .config import ManageConfig, NativeService

_IS_WINDOWS = platform.system() == "Windows"
_LOG_TAIL_LINES = 50
_FOLLOW_POLL_S = 0.25


@dataclass
class ServiceStatus:
    name: str
    running: bool
    pid: int | None
    port: int
    log_path: Path


class NativeManager:
    """
    Start, stop, and monitor native processes for a product.

    Args:
        config: ManageConfig for the current product.
        root: Product root directory.
    """

    def __init__(self, config: ManageConfig, root: Path) -> None:
        self.config = config
        self.root = root
        self._pid_dir = Path(user_runtime_dir(config.app_name, ensure_exists=True))
        self._log_dir = Path(user_log_dir(config.app_name, ensure_exists=True))

    # ── helpers ───────────────────────────────────────────────────────────────

    def _pid_path(self, name: str) -> Path:
        return self._pid_dir / f"{name}.pid"

    def _log_path(self, name: str) -> Path:
        return self._log_dir / f"{name}.log"

    def _write_pid(self, name: str, pid: int, command: str) -> None:
        self._pid_path(name).write_text(f"{pid}\n{command[:80]}\n")

    def _read_pid(self, name: str) -> int | None:
        p = self._pid_path(name)
        if not p.exists():
            return None
        try:
            return int(p.read_text().splitlines()[0].strip())
        except (ValueError, IndexError):
            return None

    def _pid_alive(self, pid: int) -> bool:
        """Return True if a process with this PID is currently running."""
        if _IS_WINDOWS:
            try:
                result = subprocess.run(
                    ["tasklist", "/FI", f"PID eq {pid}", "/NH"],
                    capture_output=True, text=True,
                )
                return str(pid) in result.stdout
            except Exception:
                return False
        else:
            try:
                os.kill(pid, 0)  # signal 0 = existence check only
                return True
            except (OSError, ProcessLookupError):
                return False

    def _kill(self, pid: int) -> None:
        """Terminate a process gracefully, then force-kill if needed."""
        if _IS_WINDOWS:
            subprocess.run(["taskkill", "/F", "/PID", str(pid)],
                           capture_output=True)
        else:
            import signal
            try:
                os.kill(pid, signal.SIGTERM)
                for _ in range(30):  # wait up to 3 s
                    time.sleep(0.1)
                    if not self._pid_alive(pid):
                        return
                os.kill(pid, signal.SIGKILL)
            except (OSError, ProcessLookupError):
                pass

    def _svc(self, name: str) -> NativeService | None:
        return next((s for s in self.config.services if s.name == name), None)

    # ── public API ────────────────────────────────────────────────────────────

    def is_running(self, name: str) -> bool:
        pid = self._read_pid(name)
        return pid is not None and self._pid_alive(pid)

    def status(self) -> list[ServiceStatus]:
        result = []
        for svc in self.config.services:
            pid = self._read_pid(svc.name)
            running = pid is not None and self._pid_alive(pid)
            result.append(ServiceStatus(
                name=svc.name,
                running=running,
                pid=pid if running else None,
                port=svc.port,
                log_path=self._log_path(svc.name),
            ))
        return result

    def start(self, name: str | None = None) -> list[str]:
        """Start one or all services. Returns list of started service names."""
        targets = [self._svc(name)] if name else self.config.services
        started: list[str] = []
        for svc in targets:
            if svc is None:
                raise ValueError(f"Unknown service: {name!r}")
            if self.is_running(svc.name):
                continue
            cwd = (self.root / svc.cwd) if svc.cwd else self.root
            log_file = open(self._log_path(svc.name), "a")  # noqa: WPS515
            env = {**os.environ, **svc.env}
            if _IS_WINDOWS:
                cmd = svc.command  # Windows: pass as string to shell
                shell = True
            else:
                cmd = shlex.split(svc.command)
                shell = False
            proc = subprocess.Popen(
                cmd,
                cwd=cwd,
                env=env,
                shell=shell,
                stdout=log_file,
                stderr=log_file,
                start_new_session=True,  # detach from terminal (Unix)
            )
            self._write_pid(svc.name, proc.pid, svc.command)
            started.append(svc.name)
        return started

    def stop(self, name: str | None = None) -> list[str]:
        """Stop one or all services. Returns list of stopped service names."""
        names = [name] if name else [s.name for s in self.config.services]
        stopped: list[str] = []
        for n in names:
            pid = self._read_pid(n)
            if pid and self._pid_alive(pid):
                self._kill(pid)
                stopped.append(n)
            pid_path = self._pid_path(n)
            if pid_path.exists():
                pid_path.unlink()
        return stopped

    def logs(self, name: str, follow: bool = True, lines: int = _LOG_TAIL_LINES) -> None:
        """
        Print the last N lines of a service log, then optionally follow.

        Uses polling rather than `tail -f` so it works on Windows.
        """
        log_path = self._log_path(name)
        if not log_path.exists():
            print(f"[{name}] No log file found at {log_path}", file=sys.stderr)
            return

        # Print last N lines
        content = log_path.read_bytes()
        lines_data = content.splitlines()[-lines:]
        for line in lines_data:
            print(line.decode("utf-8", errors="replace"))

        if not follow:
            return

        # Poll for new content
        offset = len(content)
        try:
            while True:
                time.sleep(_FOLLOW_POLL_S)
                new_size = log_path.stat().st_size
                if new_size > offset:
                    with open(log_path, "rb") as f:
                        f.seek(offset)
                        chunk = f.read()
                    offset = new_size
                    for line in chunk.splitlines():
                        print(line.decode("utf-8", errors="replace"))
        except KeyboardInterrupt:
            pass
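A self-contained sketch of the native lifecycle (the demo service is illustrative):

from pathlib import Path

from circuitforge_core.manage.config import ManageConfig, NativeService
from circuitforge_core.manage.native_mode import NativeManager

config = ManageConfig(
    app_name="demo",
    services=[NativeService(name="api", command="python -m http.server 8512", port=8512)],
)
mgr = NativeManager(config, Path("."))

print(mgr.start())  # ["api"] on the first call, [] while already running
for svc in mgr.status():
    print(svc.name, svc.running, svc.pid, svc.port)
mgr.stop("api")     # SIGTERM, then SIGKILL after ~3 s if still alive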
30
circuitforge_core/manage/templates/manage.ps1
Normal file
@@ -0,0 +1,30 @@
# manage.ps1 — CircuitForge cross-platform product manager shim (Windows)
#
# Auto-detects the Python environment and delegates to
# `python -m circuitforge_core.manage`.
#
# Generated by: python -m circuitforge_core.manage install-shims
# Do not edit the logic here; edit manage.toml for product configuration.

# ── Python detection ──────────────────────────────────────────────────────────
$Python = $null

if (Test-Path ".venv\Scripts\python.exe") {
    $Python = ".venv\Scripts\python.exe"
} elseif (Test-Path "venv\Scripts\python.exe") {
    $Python = "venv\Scripts\python.exe"
} elseif ($env:CONDA_DEFAULT_ENV -and (Get-Command conda -ErrorAction SilentlyContinue)) {
    # Conda: run via `conda run` so the env is activated correctly
    & conda run -n $env:CONDA_DEFAULT_ENV python -m circuitforge_core.manage @args
    exit $LASTEXITCODE
} elseif (Get-Command python -ErrorAction SilentlyContinue) {
    $Python = "python"
} elseif (Get-Command python3 -ErrorAction SilentlyContinue) {
    $Python = "python3"
} else {
    Write-Error "No Python interpreter found. Install Python 3.11+, activate a venv, or activate a conda environment."
    exit 1
}

& $Python -m circuitforge_core.manage @args
exit $LASTEXITCODE
28
circuitforge_core/manage/templates/manage.sh
Normal file
@@ -0,0 +1,28 @@
#!/usr/bin/env bash
# manage.sh — CircuitForge cross-platform product manager shim
#
# Auto-detects the Python environment and delegates to
# `python -m circuitforge_core.manage`.
#
# Generated by: python -m circuitforge_core.manage install-shims
# Do not edit the logic here; edit manage.toml for product configuration.
set -euo pipefail

# ── Python detection ──────────────────────────────────────────────────────────
if [ -f ".venv/bin/python" ]; then
    PYTHON=".venv/bin/python"
elif [ -f "venv/bin/python" ]; then
    PYTHON="venv/bin/python"
elif command -v conda &>/dev/null && [ -n "${CONDA_DEFAULT_ENV:-}" ]; then
    PYTHON="conda run -n ${CONDA_DEFAULT_ENV} python"
elif command -v python3 &>/dev/null; then
    PYTHON="python3"
elif command -v python &>/dev/null; then
    PYTHON="python"
else
    echo "ERROR: No Python interpreter found." >&2
    echo "Install Python 3.11+, activate a venv, or activate a conda environment." >&2
    exit 1
fi

# $PYTHON is intentionally unquoted: the conda case expands to a multi-word command.
exec $PYTHON -m circuitforge_core.manage "$@"
38 circuitforge_core/manage/templates/manage.toml.example Normal file
@@ -0,0 +1,38 @@
# manage.toml — CircuitForge product manager configuration
#
# Drop this file (renamed from manage.toml.example) in your product root.
# Docker-only products only need [app]; native services require [[native.services]].

[app]
# Product name — used for log/PID directory names and display.
name = "myproduct"

# URL opened by `manage.py open`. Typically the frontend port.
default_url = "http://localhost:8511"

[docker]
# Path to the Docker Compose file, relative to this directory.
compose_file = "compose.yml"

# Docker Compose project name (defaults to app.name).
# project = "myproduct"

# ── Native mode services ───────────────────────────────────────────────────────
# Define one [[native.services]] block per process.
# Used when Docker is unavailable (Windows without Docker, or --mode native).

[[native.services]]
name = "api"
# Full shell command to launch the backend.
command = "uvicorn app.main:app --host 0.0.0.0 --port 8512 --reload"
port = 8512
# cwd is relative to the project root. Leave empty for the project root itself.
cwd = ""
# Optional extra environment variables.
# env = { PYTHONPATH = ".", DEBUG = "1" }

[[native.services]]
name = "frontend"
command = "npm run preview -- --host 0.0.0.0 --port 8511"
port = 8511
cwd = "frontend"
@@ -1,3 +1,43 @@
# circuitforge_core/pipeline — FPGA→ASIC crystallization engine
#
# Public API: call pipeline.run() from product code instead of llm.router directly.
# The module transparently checks for crystallized workflows first, falls back
# to LLM when none match, and records each run for future crystallization.
from __future__ import annotations

from .crystallizer import CrystallizerConfig, crystallize, evaluate_new_run, should_crystallize
from .executor import ExecutionResult, Executor, StepResult
from .models import CrystallizedWorkflow, PipelineRun, Step, hash_input
from .multimodal import MultimodalConfig, MultimodalPipeline, PageResult
from .recorder import Recorder
from .registry import Registry
from .staging import StagingDB

__all__ = [
    # models
    "PipelineRun",
    "CrystallizedWorkflow",
    "Step",
    "hash_input",
    # recorder
    "Recorder",
    # crystallizer
    "CrystallizerConfig",
    "crystallize",
    "evaluate_new_run",
    "should_crystallize",
    # registry
    "Registry",
    # executor
    "Executor",
    "ExecutionResult",
    "StepResult",
    # multimodal
    "MultimodalPipeline",
    "MultimodalConfig",
    "PageResult",
    # legacy stub
    "StagingDB",
]
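A minimal sketch of the intended flow, composed from the exported pieces. The header comment promises a `pipeline.run()` convenience wrapper, but only the building blocks appear in `__all__` here, so the wiring below is an assumption; `my_step_fn` and `my_llm_fn` are product-supplied placeholders.

import uuid

from circuitforge_core.pipeline import (
    CrystallizerConfig, Executor, PipelineRun, Recorder, Registry,
    Step, evaluate_new_run, hash_input,
)

features = {"carrier": "acme", "menu_depth": 3}            # normalised, PII-free
input_hash = hash_input(features)

registry = Registry()
executor = Executor(step_fn=my_step_fn, llm_fn=my_llm_fn)  # product-supplied callables

workflow = registry.find("osprey", "ivr_navigate", input_hash)
result = executor.run_with_fallback(workflow)

if not result.used_deterministic and result.success:
    # The LLM path ran — record the (human-reviewed) run so it can crystallize.
    run = PipelineRun(
        run_id=str(uuid.uuid4()),
        product="osprey",
        task_type="ivr_navigate",
        input_hash=input_hash,
        steps=[Step("dtmf", {"digits": "2"}, "Press 2 for claims")],
        approved=True,
        review_duration_ms=12_000,
        output_modified=False,
    )
    new_wf = evaluate_new_run(run, Recorder(), CrystallizerConfig(threshold=1))
    if new_wf:
        registry.register(new_wf)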
177 circuitforge_core/pipeline/crystallizer.py Normal file
@@ -0,0 +1,177 @@
# circuitforge_core/pipeline/crystallizer.py — promote approved runs → workflows
#
# MIT — pure logic, no inference backends.
from __future__ import annotations

import logging
import warnings
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Literal

from .models import CrystallizedWorkflow, PipelineRun, Step
from .recorder import Recorder

log = logging.getLogger(__name__)

# Minimum milliseconds of review that counts as "genuine".
# Runs shorter than this are accepted but trigger a warning.
_RUBBER_STAMP_THRESHOLD_MS = 5_000


@dataclass
class CrystallizerConfig:
    """Tuning knobs for one product/task-type pair.

    threshold:
        Minimum number of approved runs required before crystallization.
        Osprey sets this to 1 (first successful IVR navigation is enough);
        Peregrine uses 3+ for cover-letter templates.
    min_review_ms:
        Approved runs with review_duration_ms below this value generate a
        warning. Set to 0 to silence the check (tests, automated approvals).
    strategy:
        ``"most_recent"`` — use the latest approved run's steps verbatim.
        ``"majority"`` — pick each step by majority vote across runs (requires
        runs to have the same step count; falls back to most_recent otherwise).
    """
    threshold: int = 3
    min_review_ms: int = _RUBBER_STAMP_THRESHOLD_MS
    strategy: Literal["most_recent", "majority"] = "most_recent"


# ── Helpers ───────────────────────────────────────────────────────────────────

def _majority_steps(runs: list[PipelineRun]) -> list[Step] | None:
    """Return majority-voted steps, or None if run lengths differ."""
    lengths = {len(r.steps) for r in runs}
    if len(lengths) != 1:
        return None
    n = lengths.pop()
    result: list[Step] = []
    for i in range(n):
        counter: Counter[str] = Counter()
        step_by_action: dict[str, Step] = {}
        for r in runs:
            s = r.steps[i]
            counter[s.action] += 1
            step_by_action[s.action] = s
        winner = counter.most_common(1)[0][0]
        result.append(step_by_action[winner])
    return result


def _check_review_quality(runs: list[PipelineRun],
                          min_review_ms: int) -> None:
    """Warn if any run has a suspiciously short review duration."""
    if min_review_ms <= 0:
        return
    flagged = [r for r in runs if r.review_duration_ms < min_review_ms]
    if flagged:
        ids = ", ".join(r.run_id for r in flagged)
        warnings.warn(
            f"Crystallizing from {len(flagged)} run(s) with review_duration_ms "
            f"< {min_review_ms} ms — possible rubber-stamp approval: [{ids}]. "
            "Verify these were genuinely human-reviewed before deployment.",
            stacklevel=3,
        )


# ── Public API ────────────────────────────────────────────────────────────────

def should_crystallize(runs: list[PipelineRun],
                       config: CrystallizerConfig) -> bool:
    """Return True if *runs* meet the threshold for crystallization."""
    approved = [r for r in runs if r.approved]
    return len(approved) >= config.threshold


def crystallize(runs: list[PipelineRun],
                config: CrystallizerConfig,
                existing_version: int = 0) -> CrystallizedWorkflow:
    """Promote *runs* into a CrystallizedWorkflow.

    Raises
    ------
    ValueError
        If fewer approved runs than ``config.threshold``, or if the runs
        span more than one (product, task_type, input_hash) triple.
    """
    approved = [r for r in runs if r.approved]
    if len(approved) < config.threshold:
        raise ValueError(
            f"Need {config.threshold} approved runs, got {len(approved)}."
        )

    # Validate homogeneity
    products = {r.product for r in approved}
    task_types = {r.task_type for r in approved}
    hashes = {r.input_hash for r in approved}
    if len(products) != 1 or len(task_types) != 1 or len(hashes) != 1:
        raise ValueError(
            "All runs must share the same product, task_type, and input_hash. "
            f"Got products={products}, task_types={task_types}, hashes={hashes}."
        )

    product = products.pop()
    task_type = task_types.pop()
    input_hash = hashes.pop()

    _check_review_quality(approved, config.min_review_ms)

    # Pick canonical steps. When majority voting is impossible (differing
    # step counts) fall back to the most recent run, matching the documented
    # "majority" semantics.
    if config.strategy == "majority":
        steps = (_majority_steps(approved)
                 or sorted(approved, key=lambda r: r.timestamp)[-1].steps)
    else:
        steps = sorted(approved, key=lambda r: r.timestamp)[-1].steps

    avg_ms = sum(r.review_duration_ms for r in approved) // len(approved)
    all_unmodified = all(not r.output_modified for r in approved)

    workflow_id = f"{product}:{task_type}:{input_hash[:12]}"
    return CrystallizedWorkflow(
        workflow_id=workflow_id,
        product=product,
        task_type=task_type,
        input_hash=input_hash,
        steps=steps,
        crystallized_at=datetime.now(timezone.utc).isoformat(),
        run_ids=[r.run_id for r in approved],
        approval_count=len(approved),
        avg_review_duration_ms=avg_ms,
        all_output_unmodified=all_unmodified,
        version=existing_version + 1,
    )


def evaluate_new_run(
    run: PipelineRun,
    recorder: Recorder,
    config: CrystallizerConfig,
    existing_version: int = 0,
) -> CrystallizedWorkflow | None:
    """Record *run* and return a new workflow if the threshold is now met.

    Products call this after each human-approved execution. Returns a
    ``CrystallizedWorkflow`` if crystallization was triggered, ``None``
    otherwise.
    """
    recorder.record(run)
    if not run.approved:
        return None

    all_runs = recorder.load_approved(run.product, run.task_type, run.input_hash)
    if not should_crystallize(all_runs, config):
        log.debug(
            "pipeline: %d/%d approved runs for %s:%s — not yet crystallizing",
            len(all_runs), config.threshold, run.product, run.task_type,
        )
        return None

    workflow = crystallize(all_runs, config, existing_version=existing_version)
    log.info(
        "pipeline: crystallized %s after %d approvals",
        workflow.workflow_id, workflow.approval_count,
    )
    return workflow
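A compact sketch of calling `crystallize()` directly with the majority strategy. The three runs are fabricated fixtures (product and task names borrowed from the docstrings above), and `threshold=3` mirrors the default.

from circuitforge_core.pipeline import (
    CrystallizerConfig, PipelineRun, Step, crystallize, hash_input,
)

cfg = CrystallizerConfig(threshold=3, strategy="majority")
runs = [
    PipelineRun(
        run_id=f"run-{i}",
        product="peregrine",
        task_type="form_fill",
        input_hash=hash_input({"form": "cover_letter"}),
        steps=[Step("field_fill", {"field": "name"}, "Fill the name field")],
        approved=True,
        review_duration_ms=9_000,            # above the rubber-stamp threshold
        output_modified=False,
    )
    for i in range(3)
]

wf = crystallize(runs, cfg)
assert wf.approval_count == 3 and wf.version == 1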
157 circuitforge_core/pipeline/executor.py Normal file
@@ -0,0 +1,157 @@
# circuitforge_core/pipeline/executor.py — deterministic execution with LLM fallback
#
# MIT — orchestration logic only; calls product-supplied callables.
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any, Callable

from .models import CrystallizedWorkflow, Step

log = logging.getLogger(__name__)


@dataclass
class StepResult:
    step: Step
    success: bool
    output: Any = None
    error: str | None = None


@dataclass
class ExecutionResult:
    """Result of running a workflow (deterministic or LLM-assisted).

    Attributes
    ----------
    success:
        True if all steps completed without error.
    used_deterministic:
        True if a crystallized workflow was used; False if LLM was called.
    step_results:
        Per-step outcomes from the deterministic path.
    llm_output:
        Raw output from the LLM fallback path, if used.
    workflow_id:
        ID of the workflow used, or None for LLM path.
    error:
        Error message if the run failed entirely.
    """
    success: bool
    used_deterministic: bool
    step_results: list[StepResult] = field(default_factory=list)
    llm_output: Any = None
    workflow_id: str | None = None
    error: str | None = None


# ── Executor ──────────────────────────────────────────────────────────────────

class Executor:
    """Runs crystallized workflows with transparent LLM fallback.

    Parameters
    ----------
    step_fn:
        Called for each Step: ``step_fn(step) -> (success, output)``.
        The product supplies this — it knows how to turn a Step into a real
        action (DTMF dial, HTTP call, form field write, etc.).
    llm_fn:
        Called when no workflow matches or a step fails: ``llm_fn() -> output``.
        Products wire this to ``cf_core.llm.router`` or equivalent.
    llm_fallback:
        If False, a deterministic failure returns a failed ExecutionResult
        instead of calling llm_fn.
    """

    def __init__(
        self,
        step_fn: Callable[[Step], tuple[bool, Any]],
        llm_fn: Callable[[], Any],
        llm_fallback: bool = True,
    ) -> None:
        self._step_fn = step_fn
        self._llm_fn = llm_fn
        self._llm_fallback = llm_fallback

    def execute(
        self,
        workflow: CrystallizedWorkflow,
    ) -> ExecutionResult:
        """Run *workflow* deterministically.

        If a step fails, falls back to LLM (if ``llm_fallback`` is enabled).
        """
        step_results: list[StepResult] = []
        for step in workflow.steps:
            try:
                success, output = self._step_fn(step)
            except Exception as exc:
                log.warning("step %s raised: %s", step.action, exc)
                success, output = False, None
                error_str = str(exc)
            else:
                error_str = None if success else "step_fn returned success=False"

            step_results.append(StepResult(step=step, success=success,
                                           output=output, error=error_str))
            if not success:
                log.info(
                    "workflow %s: step %s failed — triggering LLM fallback",
                    workflow.workflow_id, step.action,
                )
                return self._llm_fallback_result(
                    step_results, workflow.workflow_id
                )

        log.info("workflow %s: all %d steps succeeded",
                 workflow.workflow_id, len(workflow.steps))
        return ExecutionResult(
            success=True,
            used_deterministic=True,
            step_results=step_results,
            workflow_id=workflow.workflow_id,
        )

    def run_with_fallback(
        self,
        workflow: CrystallizedWorkflow | None,
    ) -> ExecutionResult:
        """Run *workflow* if provided; otherwise call the LLM directly."""
        if workflow is None:
            return self._llm_fallback_result([], workflow_id=None)
        return self.execute(workflow)

    # ── Internal ──────────────────────────────────────────────────────────────

    def _llm_fallback_result(
        self,
        partial_steps: list[StepResult],
        workflow_id: str | None,
    ) -> ExecutionResult:
        if not self._llm_fallback:
            return ExecutionResult(
                success=False,
                used_deterministic=True,
                step_results=partial_steps,
                workflow_id=workflow_id,
                error="LLM fallback disabled and deterministic path failed.",
            )
        try:
            llm_output = self._llm_fn()
        except Exception as exc:
            return ExecutionResult(
                success=False,
                used_deterministic=False,
                step_results=partial_steps,
                workflow_id=workflow_id,
                error=f"LLM fallback raised: {exc}",
            )
        return ExecutionResult(
            success=True,
            used_deterministic=False,
            step_results=partial_steps,
            llm_output=llm_output,
            workflow_id=workflow_id,
        )
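A hedged sketch of the two product-supplied callables. `send_digits` and `call_llm_router` are invented placeholders for the product's own telephony and LLM-router integrations.

from typing import Any

from circuitforge_core.pipeline import Executor, Step

def step_fn(step: Step) -> tuple[bool, Any]:
    # Interpret one Step; return (success, output).
    if step.action == "dtmf":
        send_digits(step.params["digits"])   # placeholder telephony call
        return True, None
    return False, None                        # unknown action → triggers fallback

def llm_fn() -> Any:
    return call_llm_router()                  # placeholder LLM call

executor = Executor(step_fn=step_fn, llm_fn=llm_fn, llm_fallback=True)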
216 circuitforge_core/pipeline/models.py Normal file
@@ -0,0 +1,216 @@
# circuitforge_core/pipeline/models.py — crystallization data models
#
# MIT — protocol and model types only; no inference backends.
from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any


# ── Utilities ─────────────────────────────────────────────────────────────────

def hash_input(features: dict[str, Any]) -> str:
    """Return a stable SHA-256 hex digest of *features*.

    Sorts keys before serialising so insertion order doesn't affect the hash.
    Only call this on already-normalised, PII-free feature dicts — the hash is
    opaque but the source dict should never contain raw user data.
    """
    canonical = json.dumps(features, sort_keys=True, ensure_ascii=True)
    return hashlib.sha256(canonical.encode()).hexdigest()
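Key order never affects the digest — a quick check (the feature dicts here are illustrative):

from circuitforge_core.pipeline import hash_input

h1 = hash_input({"carrier": "acme", "menu_depth": 3})
h2 = hash_input({"menu_depth": 3, "carrier": "acme"})
assert h1 == h2           # canonical serialisation sorts keys
assert len(h1) == 64      # SHA-256 hex digest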
# ── Step ──────────────────────────────────────────────────────────────────────

@dataclass
class Step:
    """One atomic action in a deterministic workflow.

    The ``action`` string is product-defined (e.g. ``"dtmf"``, ``"field_fill"``,
    ``"api_call"``). ``params`` carries action-specific values; ``description``
    is a plain-English summary for the approval UI.
    """
    action: str
    params: dict[str, Any]
    description: str = ""

    def to_dict(self) -> dict[str, Any]:
        return {"action": self.action, "params": self.params,
                "description": self.description}

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Step":
        return cls(action=d["action"], params=d.get("params", {}),
                   description=d.get("description", ""))


# ── PipelineRun ───────────────────────────────────────────────────────────────

@dataclass
class PipelineRun:
    """Record of one LLM-assisted execution — the raw material for crystallization.

    Fields
    ------
    run_id:
        UUID or unique string identifying this run.
    product:
        CF product code (``"osprey"``, ``"falcon"``, ``"peregrine"`` …).
    task_type:
        Product-defined task category (``"ivr_navigate"``, ``"form_fill"`` …).
    input_hash:
        SHA-256 of normalised, PII-free input features. Never store raw input.
    steps:
        Ordered list of Steps the LLM proposed.
    approved:
        True if a human approved this run before execution.
    review_duration_ms:
        Wall-clock milliseconds between displaying the proposal and the approval
        click. Values under ~5,000 ms suggest a rubber-stamp approval — the
        crystallizer warns about runs with suspiciously short reviews.
    output_modified:
        True if the user edited any step before approving. Modifications suggest
        the LLM proposal was imperfect; too-easy crystallization from unmodified
        runs may mean the task is already deterministic and the LLM is just
        echoing a fixed pattern.
    timestamp:
        ISO 8601 UTC creation time.
    llm_model:
        Model ID that generated the steps, e.g. ``"llama3:8b-instruct"``.
    metadata:
        Freeform dict for product-specific extra fields.
    """

    run_id: str
    product: str
    task_type: str
    input_hash: str
    steps: list[Step]
    approved: bool
    review_duration_ms: int
    output_modified: bool
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    llm_model: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            "run_id": self.run_id,
            "product": self.product,
            "task_type": self.task_type,
            "input_hash": self.input_hash,
            "steps": [s.to_dict() for s in self.steps],
            "approved": self.approved,
            "review_duration_ms": self.review_duration_ms,
            "output_modified": self.output_modified,
            "timestamp": self.timestamp,
            "llm_model": self.llm_model,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "PipelineRun":
        return cls(
            run_id=d["run_id"],
            product=d["product"],
            task_type=d["task_type"],
            input_hash=d["input_hash"],
            steps=[Step.from_dict(s) for s in d.get("steps", [])],
            approved=d["approved"],
            review_duration_ms=d["review_duration_ms"],
            output_modified=d.get("output_modified", False),
            timestamp=d.get("timestamp", ""),
            llm_model=d.get("llm_model"),
            metadata=d.get("metadata", {}),
        )


# ── CrystallizedWorkflow ──────────────────────────────────────────────────────

@dataclass
class CrystallizedWorkflow:
    """A deterministic workflow promoted from N approved PipelineRuns.

    Once crystallized, the executor runs ``steps`` directly — no LLM required
    unless an edge case is encountered.

    Fields
    ------
    workflow_id:
        Unique identifier (typically ``{product}:{task_type}:{input_hash[:12]}``).
    product / task_type / input_hash:
        Same semantics as PipelineRun; the hash is the lookup key.
    steps:
        Canonical deterministic step sequence (majority-voted or most-recent,
        per CrystallizerConfig.strategy).
    crystallized_at:
        ISO 8601 UTC timestamp.
    run_ids:
        IDs of the source PipelineRuns that contributed to this workflow.
    approval_count:
        Number of approved runs that went into crystallization.
    avg_review_duration_ms:
        Mean review_duration_ms across all source runs — low values are a
        warning sign that approvals may not have been genuine.
    all_output_unmodified:
        True if every contributing run had output_modified=False. Combined with
        a very short avg_review_duration_ms this can flag workflows that may
        have crystallized from rubber-stamp approvals.
    active:
        Whether this workflow is in use. Set to False to disable without
        deleting the record.
    version:
        Increments each time the workflow is re-crystallized from new runs.
    """

    workflow_id: str
    product: str
    task_type: str
    input_hash: str
    steps: list[Step]
    crystallized_at: str
    run_ids: list[str]
    approval_count: int
    avg_review_duration_ms: int
    all_output_unmodified: bool
    active: bool = True
    version: int = 1
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            "workflow_id": self.workflow_id,
            "product": self.product,
            "task_type": self.task_type,
            "input_hash": self.input_hash,
            "steps": [s.to_dict() for s in self.steps],
            "crystallized_at": self.crystallized_at,
            "run_ids": self.run_ids,
            "approval_count": self.approval_count,
            "avg_review_duration_ms": self.avg_review_duration_ms,
            "all_output_unmodified": self.all_output_unmodified,
            "active": self.active,
            "version": self.version,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "CrystallizedWorkflow":
        return cls(
            workflow_id=d["workflow_id"],
            product=d["product"],
            task_type=d["task_type"],
            input_hash=d["input_hash"],
            steps=[Step.from_dict(s) for s in d.get("steps", [])],
            crystallized_at=d["crystallized_at"],
            run_ids=d.get("run_ids", []),
            approval_count=d["approval_count"],
            avg_review_duration_ms=d["avg_review_duration_ms"],
            all_output_unmodified=d.get("all_output_unmodified", True),
            active=d.get("active", True),
            version=d.get("version", 1),
            metadata=d.get("metadata", {}),
        )
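Both models are plain dataclasses, so a to_dict/from_dict round-trip compares equal — handy when asserting persistence behaviour in tests (`run` as built in the earlier sketches):

d = run.to_dict()                             # `run`: a PipelineRun instance
assert PipelineRun.from_dict(d) == run        # dataclass equality holds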
234 circuitforge_core/pipeline/multimodal.py Normal file
@@ -0,0 +1,234 @@
# circuitforge_core/pipeline/multimodal.py — cf-docuvision + cf-text pipeline
#
# MIT — orchestration only; vision and text inference stay in their own modules.
#
# Usage (minimal):
#
#   from circuitforge_core.pipeline.multimodal import MultimodalPipeline, MultimodalConfig
#
#   pipe = MultimodalPipeline(MultimodalConfig())
#   for result in pipe.run(page_bytes_list):
#       print(f"Page {result.page_idx}: {result.generated[:80]}")
#
# Streaming (token-by-token):
#
#   for page_idx, token in pipe.stream(page_bytes_list):
#       ui.append(page_idx, token)
#
from __future__ import annotations

import logging
from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass, field

from circuitforge_core.documents.client import DocuvisionClient
from circuitforge_core.documents.models import StructuredDocument

log = logging.getLogger(__name__)


# ── Config ────────────────────────────────────────────────────────────────────

def _default_prompt(page_idx: int, doc: StructuredDocument) -> str:
    """Build a generation prompt from a StructuredDocument."""
    header = f"[Page {page_idx + 1}]\n" if page_idx > 0 else ""
    return header + doc.raw_text


@dataclass
class MultimodalConfig:
    """Configuration for MultimodalPipeline.

    vision_url:
        Base URL of the cf-docuvision service.
    hint:
        Docuvision extraction hint — ``"auto"`` | ``"document"`` | ``"form"``
        | ``"table"`` | ``"figure"``.
    max_tokens:
        Passed to cf-text generate per page.
    temperature:
        Sampling temperature for text generation.
    vram_serialise:
        When True, ``swap_fn`` is called between the vision and text steps
        on each page. Use this on 8 GB GPUs where Dolphin-v2 and the text
        model cannot be resident simultaneously.
    prompt_fn:
        Callable ``(page_idx, StructuredDocument) -> str`` that builds the
        generation prompt. Defaults to using ``doc.raw_text`` directly.
        Products override this to add system context, few-shot examples, etc.
    vision_timeout:
        HTTP timeout in seconds for each cf-docuvision request.
    """
    vision_url: str = "http://localhost:8003"
    hint: str = "auto"
    max_tokens: int = 512
    temperature: float = 0.7
    vram_serialise: bool = False
    prompt_fn: Callable[[int, StructuredDocument], str] = field(
        default_factory=lambda: _default_prompt
    )
    vision_timeout: int = 60


# ── Results ───────────────────────────────────────────────────────────────────

@dataclass
class PageResult:
    """Result of processing one page through the vision + text pipeline.

    page_idx:
        Zero-based page index.
    doc:
        StructuredDocument from cf-docuvision.
    generated:
        Full text output from cf-text for this page.
    error:
        Non-None if extraction or generation failed for this page.
    """
    page_idx: int
    doc: StructuredDocument | None
    generated: str
    error: str | None = None


# ── Pipeline ──────────────────────────────────────────────────────────────────

class MultimodalPipeline:
    """Chunk a multi-page document through vision extraction + text generation.

    Parameters
    ----------
    config:
        Pipeline configuration.
    swap_fn:
        Optional callable with no arguments, called between the vision and text
        steps on each page when ``config.vram_serialise=True``. Products using
        cf-orch wire this to the VRAM budget API so Dolphin-v2 can offload
        before the text model loads. A no-op lambda works for testing.
    generate_fn:
        Text generation callable: ``(prompt, max_tokens, temperature) -> str``.
        Defaults to ``circuitforge_core.text.generate``. Override in tests or
        when the product manages its own text backend.
    stream_fn:
        Streaming text callable: ``(prompt, max_tokens, temperature) -> Iterator[str]``.
        Defaults to ``circuitforge_core.text.generate`` with ``stream=True``.
    """

    def __init__(
        self,
        config: MultimodalConfig | None = None,
        *,
        swap_fn: Callable[[], None] | None = None,
        generate_fn: Callable[..., str] | None = None,
        stream_fn: Callable[..., Iterator[str]] | None = None,
    ) -> None:
        self._cfg = config or MultimodalConfig()
        self._vision = DocuvisionClient(
            base_url=self._cfg.vision_url,
            timeout=self._cfg.vision_timeout,
        )
        self._swap_fn = swap_fn
        self._generate_fn = generate_fn
        self._stream_fn = stream_fn

    # ── Public ────────────────────────────────────────────────────────────────

    def run(self, pages: Iterable[bytes]) -> Iterator[PageResult]:
        """Process each page and yield a PageResult as soon as it is ready.

        Callers receive pages one at a time — the UI can begin rendering
        page 0 while pages 1..N are still being extracted and generated.
        """
        for page_idx, page_bytes in enumerate(pages):
            yield self._process_page(page_idx, page_bytes)

    def stream(self, pages: Iterable[bytes]) -> Iterator[tuple[int, str]]:
        """Yield ``(page_idx, token)`` tuples for token-level progressive rendering.

        Each page is fully extracted before text generation begins, but tokens
        are yielded as the text model produces them rather than waiting for the
        full page output.
        """
        for page_idx, page_bytes in enumerate(pages):
            doc, err = self._extract(page_idx, page_bytes)
            if err:
                yield (page_idx, f"[extraction error: {err}]")
                continue

            self._maybe_swap()

            prompt = self._cfg.prompt_fn(page_idx, doc)
            try:
                for token in self._stream_tokens(prompt):
                    yield (page_idx, token)
            except Exception as exc:
                log.error("page %d text streaming failed: %s", page_idx, exc)
                yield (page_idx, f"[generation error: {exc}]")

    # ── Internal ──────────────────────────────────────────────────────────────

    def _process_page(self, page_idx: int, page_bytes: bytes) -> PageResult:
        doc, err = self._extract(page_idx, page_bytes)
        if err:
            return PageResult(page_idx=page_idx, doc=None, generated="", error=err)

        self._maybe_swap()

        prompt = self._cfg.prompt_fn(page_idx, doc)
        try:
            text = self._generate(prompt)
        except Exception as exc:
            log.error("page %d generation failed: %s", page_idx, exc)
            return PageResult(page_idx=page_idx, doc=doc, generated="",
                              error=str(exc))

        return PageResult(page_idx=page_idx, doc=doc, generated=text)

    def _extract(
        self, page_idx: int, page_bytes: bytes
    ) -> tuple[StructuredDocument | None, str | None]:
        try:
            doc = self._vision.extract(page_bytes, hint=self._cfg.hint)
            log.debug("page %d extracted: %d chars", page_idx, len(doc.raw_text))
            return doc, None
        except Exception as exc:
            log.error("page %d vision extraction failed: %s", page_idx, exc)
            return None, str(exc)

    def _maybe_swap(self) -> None:
        if self._cfg.vram_serialise and self._swap_fn is not None:
            log.debug("vram_serialise: calling swap_fn")
            self._swap_fn()

    def _generate(self, prompt: str) -> str:
        if self._generate_fn is not None:
            return self._generate_fn(
                prompt,
                max_tokens=self._cfg.max_tokens,
                temperature=self._cfg.temperature,
            )
        from circuitforge_core.text import generate
        result = generate(
            prompt,
            max_tokens=self._cfg.max_tokens,
            temperature=self._cfg.temperature,
        )
        return result.text

    def _stream_tokens(self, prompt: str) -> Iterator[str]:
        if self._stream_fn is not None:
            yield from self._stream_fn(
                prompt,
                max_tokens=self._cfg.max_tokens,
                temperature=self._cfg.temperature,
            )
            return
        from circuitforge_core.text import generate
        tokens = generate(
            prompt,
            max_tokens=self._cfg.max_tokens,
            temperature=self._cfg.temperature,
            stream=True,
        )
        yield from tokens
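A hedged usage sketch beyond the header comments: it assumes a cf-docuvision instance is reachable at the default `vision_url`, overrides `generate_fn` so no text backend is needed, and `page0.png` is a hypothetical input file.

from circuitforge_core.pipeline.multimodal import MultimodalConfig, MultimodalPipeline

cfg = MultimodalConfig(hint="document", max_tokens=64)
pipe = MultimodalPipeline(
    cfg,
    generate_fn=lambda prompt, **kw: f"summary of: {prompt[:40]}",  # stub backend
)

with open("page0.png", "rb") as fh:           # hypothetical page image
    pages = [fh.read()]

for result in pipe.run(pages):
    print(result.page_idx, result.error or result.generated)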
70 circuitforge_core/pipeline/recorder.py Normal file
@@ -0,0 +1,70 @@
# circuitforge_core/pipeline/recorder.py — write and load PipelineRun records
#
# MIT — local file I/O only; no inference.
from __future__ import annotations

import json
import logging
from pathlib import Path

from .models import PipelineRun

log = logging.getLogger(__name__)

_DEFAULT_ROOT = Path.home() / ".config" / "circuitforge" / "pipeline" / "runs"


class Recorder:
    """Writes PipelineRun JSON records to a local directory tree.

    Layout::

        {root}/{product}/{task_type}/{run_id}.json

    The recorder is intentionally append-only — it never deletes or modifies
    existing records. Old runs accumulate as an audit trail; products that
    want retention limits should prune the directory themselves.
    """

    def __init__(self, root: Path | None = None) -> None:
        self._root = Path(root) if root else _DEFAULT_ROOT

    # ── Write ─────────────────────────────────────────────────────────────────

    def record(self, run: PipelineRun) -> Path:
        """Persist *run* to disk and return the file path written."""
        dest = self._path_for(run.product, run.task_type, run.run_id)
        dest.parent.mkdir(parents=True, exist_ok=True)
        dest.write_text(json.dumps(run.to_dict(), indent=2), encoding="utf-8")
        log.debug("recorded pipeline run %s → %s", run.run_id, dest)
        return dest

    # ── Read ──────────────────────────────────────────────────────────────────

    def load_runs(self, product: str, task_type: str) -> list[PipelineRun]:
        """Return all runs for *(product, task_type)*, newest-first."""
        directory = self._root / product / task_type
        if not directory.is_dir():
            return []
        runs: list[PipelineRun] = []
        for p in directory.glob("*.json"):
            try:
                runs.append(PipelineRun.from_dict(json.loads(p.read_text())))
            except Exception:
                log.warning("skipping unreadable run file %s", p)
        runs.sort(key=lambda r: r.timestamp, reverse=True)
        return runs

    def load_approved(self, product: str, task_type: str,
                      input_hash: str) -> list[PipelineRun]:
        """Return approved runs that match *input_hash*, newest-first."""
        return [
            r for r in self.load_runs(product, task_type)
            if r.approved and r.input_hash == input_hash
        ]

    # ── Internal ──────────────────────────────────────────────────────────────

    def _path_for(self, product: str, task_type: str, run_id: str) -> Path:
        return self._root / product / task_type / f"{run_id}.json"
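A round-trip sketch against a temp directory so nothing touches `~/.config`; `run` is assumed to be an approved PipelineRun like the one built in the earlier sketches.

import tempfile
from pathlib import Path

from circuitforge_core.pipeline import Recorder

rec = Recorder(root=Path(tempfile.mkdtemp()))
path = rec.record(run)                        # `run`: an approved PipelineRun
assert path.exists()

matches = rec.load_approved(run.product, run.task_type, run.input_hash)
assert matches and matches[0].run_id == run.run_id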
134 circuitforge_core/pipeline/registry.py Normal file
@@ -0,0 +1,134 @@
# circuitforge_core/pipeline/registry.py — workflow lookup
#
# MIT — file I/O and matching logic only.
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Callable

from .models import CrystallizedWorkflow

log = logging.getLogger(__name__)

_DEFAULT_ROOT = Path.home() / ".config" / "circuitforge" / "pipeline" / "workflows"


class Registry:
    """Loads and matches CrystallizedWorkflows from the local filesystem.

    Layout::

        {root}/{product}/{task_type}/{workflow_id}.json

    Exact matching is always available. Products that need fuzzy/semantic
    matching can supply a ``similarity_fn`` — a callable that takes two input
    hashes and returns a float in [0, 1]. The registry returns the first
    active workflow whose similarity score meets ``fuzzy_threshold``.
    """

    def __init__(
        self,
        root: Path | None = None,
        similarity_fn: Callable[[str, str], float] | None = None,
        fuzzy_threshold: float = 0.8,
    ) -> None:
        self._root = Path(root) if root else _DEFAULT_ROOT
        self._similarity_fn = similarity_fn
        self._fuzzy_threshold = fuzzy_threshold

    # ── Write ─────────────────────────────────────────────────────────────────

    def register(self, workflow: CrystallizedWorkflow) -> Path:
        """Persist *workflow* and return the path written."""
        dest = self._path_for(workflow.product, workflow.task_type,
                              workflow.workflow_id)
        dest.parent.mkdir(parents=True, exist_ok=True)
        dest.write_text(json.dumps(workflow.to_dict(), indent=2), encoding="utf-8")
        log.info("registered workflow %s (v%d)", workflow.workflow_id,
                 workflow.version)
        return dest

    def deactivate(self, workflow_id: str, product: str,
                   task_type: str) -> bool:
        """Set ``active=False`` on a stored workflow. Returns True if found."""
        path = self._path_for(product, task_type, workflow_id)
        if not path.exists():
            return False
        data = json.loads(path.read_text())
        data["active"] = False
        path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        log.info("deactivated workflow %s", workflow_id)
        return True

    # ── Read ──────────────────────────────────────────────────────────────────

    def load_all(self, product: str, task_type: str) -> list[CrystallizedWorkflow]:
        """Return all (including inactive) workflows for *(product, task_type)*."""
        directory = self._root / product / task_type
        if not directory.is_dir():
            return []
        workflows: list[CrystallizedWorkflow] = []
        for p in directory.glob("*.json"):
            try:
                workflows.append(
                    CrystallizedWorkflow.from_dict(json.loads(p.read_text()))
                )
            except Exception:
                log.warning("skipping unreadable workflow file %s", p)
        return workflows

    # ── Match ─────────────────────────────────────────────────────────────────

    def match(self, product: str, task_type: str,
              input_hash: str) -> CrystallizedWorkflow | None:
        """Return the active workflow for an exact input_hash match, or None."""
        for wf in self.load_all(product, task_type):
            if wf.active and wf.input_hash == input_hash:
                log.debug("registry exact match: %s", wf.workflow_id)
                return wf
        return None

    def fuzzy_match(self, product: str, task_type: str,
                    input_hash: str) -> CrystallizedWorkflow | None:
        """Return a workflow above the similarity threshold, or None.

        Requires a ``similarity_fn`` to have been supplied at construction.
        If none was provided, raises ``RuntimeError``.
        """
        if self._similarity_fn is None:
            raise RuntimeError(
                "fuzzy_match() requires a similarity_fn — none was supplied "
                "to Registry.__init__()."
            )
        best: CrystallizedWorkflow | None = None
        best_score = 0.0
        for wf in self.load_all(product, task_type):
            if not wf.active:
                continue
            score = self._similarity_fn(wf.input_hash, input_hash)
            if score >= self._fuzzy_threshold and score > best_score:
                best = wf
                best_score = score
        if best:
            log.debug("registry fuzzy match: %s (score=%.2f)", best.workflow_id,
                      best_score)
        return best

    def find(self, product: str, task_type: str,
             input_hash: str) -> CrystallizedWorkflow | None:
        """Exact match first; fuzzy match second (if similarity_fn is set)."""
        exact = self.match(product, task_type, input_hash)
        if exact:
            return exact
        if self._similarity_fn is not None:
            return self.fuzzy_match(product, task_type, input_hash)
        return None

    # ── Internal ──────────────────────────────────────────────────────────────

    def _path_for(self, product: str, task_type: str,
                  workflow_id: str) -> Path:
        safe_id = workflow_id.replace(":", "_")
        return self._root / product / task_type / f"{safe_id}.json"
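A sketch of wiring up fuzzy matching. The similarity function here is a toy — SHA-256 hashes of different inputs share no meaningful structure, so real products would compare embeddings of the underlying features instead.

from circuitforge_core.pipeline import Registry

def char_similarity(a: str, b: str) -> float:
    # Fraction of positions where the two hex strings agree — illustrative only.
    common = sum(1 for x, y in zip(a, b) if x == y)
    return common / max(len(a), len(b))

registry = Registry(similarity_fn=char_similarity, fuzzy_threshold=0.9)
wf = registry.find("osprey", "ivr_navigate", input_hash)   # exact, then fuzzy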
50 circuitforge_core/preferences/__init__.py Normal file
@@ -0,0 +1,50 @@
from . import store as store_module
from .paths import get_path, set_path
from .store import LocalFileStore, PreferenceStore


def get_user_preference(
    user_id: str | None,
    path: str,
    default=None,
    store: PreferenceStore | None = None,
):
    """Read a preference value at dot-separated *path*.

    Args:
        user_id: User identifier (passed to store; local store ignores it).
        path: Dot-separated preference path, e.g. ``"affiliate.opt_out"``.
        default: Returned when the path is not set.
        store: Optional store override; defaults to ``LocalFileStore`` at
            ``~/.config/circuitforge/preferences.yaml``.
    """
    s = store or store_module._DEFAULT_STORE
    return s.get(user_id=user_id, path=path, default=default)


def set_user_preference(
    user_id: str | None,
    path: str,
    value,
    store: PreferenceStore | None = None,
) -> None:
    """Write *value* at dot-separated *path*.

    Args:
        user_id: User identifier (passed to store; local store ignores it).
        path: Dot-separated preference path, e.g. ``"affiliate.byok_ids.ebay"``.
        value: Value to persist.
        store: Optional store override; defaults to ``LocalFileStore``.
    """
    s = store or store_module._DEFAULT_STORE
    s.set(user_id=user_id, path=path, value=value)


# Imported at the bottom on purpose: accessibility imports get_user_preference
# from this package, so a top-of-file import would be circular.
from . import accessibility as accessibility

__all__ = [
    "get_path", "set_path",
    "get_user_preference", "set_user_preference",
    "LocalFileStore", "PreferenceStore",
    "accessibility",
]
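The two helpers in action against the default local store:

from circuitforge_core.preferences import get_user_preference, set_user_preference

# Writes to the default LocalFileStore (~/.config/circuitforge/preferences.yaml).
set_user_preference(None, "affiliate.opt_out", True)
assert get_user_preference(None, "affiliate.opt_out", default=False) is True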
73 circuitforge_core/preferences/accessibility.py Normal file
@@ -0,0 +1,73 @@
# circuitforge_core/preferences/accessibility.py — a11y preference keys
#
# First-class accessibility preferences so every product UI reads from
# the same store path without each implementing it separately.
#
# All keys use the "accessibility.*" namespace in the preference store.
# Products read these via get_user_preference() or the convenience helpers here.
from __future__ import annotations

from circuitforge_core.preferences import get_user_preference, set_user_preference

# ── Preference key constants ──────────────────────────────────────────────────

PREF_REDUCED_MOTION = "accessibility.prefers_reduced_motion"
PREF_HIGH_CONTRAST = "accessibility.high_contrast"
PREF_FONT_SIZE = "accessibility.font_size"  # "default" | "large" | "xlarge"
PREF_SCREEN_READER = "accessibility.screen_reader_mode"  # reduces decorative content

_DEFAULTS: dict[str, object] = {
    PREF_REDUCED_MOTION: False,
    PREF_HIGH_CONTRAST: False,
    PREF_FONT_SIZE: "default",
    PREF_SCREEN_READER: False,
}


# ── Convenience helpers ───────────────────────────────────────────────────────

def is_reduced_motion_preferred(
    user_id: str | None = None,
    store=None,
) -> bool:
    """
    Return True if the user has requested reduced motion.

    Products must honour this in all animated UI elements: transitions,
    auto-playing content, parallax, loaders. This maps to the CSS
    `prefers-reduced-motion: reduce` media query and is the canonical
    source of truth across all CF product UIs.

    Default: False.
    """
    val = get_user_preference(
        user_id, PREF_REDUCED_MOTION, default=False, store=store
    )
    return bool(val)


def is_high_contrast(user_id: str | None = None, store=None) -> bool:
    """Return True if the user has requested high-contrast mode."""
    return bool(get_user_preference(user_id, PREF_HIGH_CONTRAST, default=False, store=store))


def get_font_size(user_id: str | None = None, store=None) -> str:
    """Return the user's preferred font size: 'default' | 'large' | 'xlarge'."""
    val = get_user_preference(user_id, PREF_FONT_SIZE, default="default", store=store)
    if val not in ("default", "large", "xlarge"):
        return "default"
    return str(val)


def is_screen_reader_mode(user_id: str | None = None, store=None) -> bool:
    """Return True if the user has requested screen reader optimised output."""
    return bool(get_user_preference(user_id, PREF_SCREEN_READER, default=False, store=store))


def set_reduced_motion(
    value: bool,
    user_id: str | None = None,
    store=None,
) -> None:
    """Persist the user's reduced-motion preference."""
    set_user_preference(user_id, PREF_REDUCED_MOTION, value, store=store)
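How a product UI might consult these helpers at render time. A sketch only — `disable_animations` and `apply_font_scale` are placeholder UI hooks, not part of this module.

from circuitforge_core.preferences import accessibility as a11y

if a11y.is_reduced_motion_preferred():
    disable_animations()                      # placeholder UI hook
if a11y.get_font_size() == "xlarge":
    apply_font_scale(1.5)                     # placeholder UI hook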
64 circuitforge_core/preferences/paths.py Normal file
@@ -0,0 +1,64 @@
"""Dot-path utilities for reading and writing nested preference dicts.

All operations are immutable: set_path returns a new dict rather than
mutating the input.

Path format: dot-separated keys, e.g. "affiliate.byok_ids.ebay"
"""
from __future__ import annotations

from typing import Any


def get_path(data: dict, path: str, default: Any = None) -> Any:
    """Return the value at *path* inside *data*, or *default* if missing.

    Example::

        prefs = {"affiliate": {"opt_out": False, "byok_ids": {"ebay": "my-id"}}}
        get_path(prefs, "affiliate.byok_ids.ebay")        # "my-id"
        get_path(prefs, "affiliate.missing", default="x")  # "x"
    """
    keys = path.split(".")
    node: Any = data
    for key in keys:
        if not isinstance(node, dict):
            return default
        node = node.get(key, _SENTINEL)
        if node is _SENTINEL:
            return default
    return node


def set_path(data: dict, path: str, value: Any) -> dict:
    """Return a new dict with *value* written at *path*.

    Intermediate dicts are created as needed; existing values at other paths
    are preserved. The original *data* dict is never mutated.

    Example::

        prefs = {}
        updated = set_path(prefs, "affiliate.opt_out", True)
        # {"affiliate": {"opt_out": True}}
    """
    keys = path.split(".")
    return _set_recursive(data, keys, value)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

_SENTINEL = object()


def _set_recursive(node: Any, keys: list[str], value: Any) -> dict:
    if not isinstance(node, dict):
        node = {}
    key, rest = keys[0], keys[1:]
    if rest:
        child = _set_recursive(node.get(key, {}), rest, value)
    else:
        child = value
    return {**node, key: child}
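The immutability guarantee in practice:

from circuitforge_core.preferences.paths import set_path

prefs = {"a": {"b": 1}}
updated = set_path(prefs, "a.c", 2)
assert prefs == {"a": {"b": 1}}               # input never mutated
assert updated == {"a": {"b": 1, "c": 2}}     # siblings preserved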
78 circuitforge_core/preferences/store.py Normal file
@@ -0,0 +1,78 @@
"""Preference store backends.

``LocalFileStore`` reads and writes a single YAML file at a configurable
path (default: ``~/.config/circuitforge/preferences.yaml``).

The ``PreferenceStore`` protocol describes the interface any backend must
satisfy. The Heimdall cloud backend will implement the same protocol once
Heimdall#5 (user_preferences column) lands — products swap backends by
passing a different store instance.
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Protocol, runtime_checkable

from .paths import get_path, set_path

logger = logging.getLogger(__name__)

_DEFAULT_PREFS_PATH = Path.home() / ".config" / "circuitforge" / "preferences.yaml"


@runtime_checkable
class PreferenceStore(Protocol):
    """Read/write interface for user preferences.

    ``user_id`` is passed through for cloud backends that store per-user
    data. Local single-user backends accept it but ignore it.
    """

    def get(self, user_id: str | None, path: str, default: Any = None) -> Any:
        """Return the value at *path*, or *default* if missing."""
        ...

    def set(self, user_id: str | None, path: str, value: Any) -> None:
        """Persist *value* at *path*."""
        ...


class LocalFileStore:
    """Single-user preference store backed by a YAML file.

    Reads the file on every ``get`` call and writes atomically via a
    temp-file rename on ``set``, so readers never see a half-written file.
    Concurrent writers (threads or processes) can still lose updates —
    last rename wins — so it is not suitable for concurrent writes.
    """

    def __init__(self, prefs_path: Path = _DEFAULT_PREFS_PATH) -> None:
        self._path = Path(prefs_path)

    def _load(self) -> dict:
        if not self._path.exists():
            return {}
        try:
            import yaml  # type: ignore[import]
            text = self._path.read_text(encoding="utf-8")
            data = yaml.safe_load(text)
            return data if isinstance(data, dict) else {}
        except Exception as exc:
            logger.warning("preferences: could not read %s: %s", self._path, exc)
            return {}

    def _save(self, data: dict) -> None:
        import yaml  # type: ignore[import]
        self._path.parent.mkdir(parents=True, exist_ok=True)
        tmp = self._path.with_suffix(".yaml.tmp")
        tmp.write_text(yaml.safe_dump(data, default_flow_style=False), encoding="utf-8")
        tmp.replace(self._path)

    def get(self, user_id: str | None, path: str, default: Any = None) -> Any:  # noqa: ARG002
        return get_path(self._load(), path, default=default)

    def set(self, user_id: str | None, path: str, value: Any) -> None:  # noqa: ARG002
        self._save(set_path(self._load(), path, value))


_DEFAULT_STORE: PreferenceStore = LocalFileStore()
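A sketch of a second backend satisfying the protocol — an in-memory store handy for tests. Entirely hypothetical; only ``LocalFileStore`` ships in this module today.

from typing import Any

from circuitforge_core.preferences.paths import get_path, set_path
from circuitforge_core.preferences.store import PreferenceStore

class MemoryStore:
    """Dict-backed store; satisfies PreferenceStore structurally."""

    def __init__(self) -> None:
        self._data: dict = {}

    def get(self, user_id: str | None, path: str, default: Any = None) -> Any:
        return get_path(self._data, path, default=default)

    def set(self, user_id: str | None, path: str, value: Any) -> None:
        self._data = set_path(self._data, path, value)

store = MemoryStore()
assert isinstance(store, PreferenceStore)     # runtime_checkable protocol
store.set(None, "accessibility.high_contrast", True)
assert store.get(None, "accessibility.high_contrast") is True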
@@ -1 +0,0 @@
from circuitforge_core.resources.client import CFOrchClient, Allocation  # noqa: F401
@@ -1,101 +0,0 @@
from __future__ import annotations

import logging
from typing import Any

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from circuitforge_core.resources.agent.eviction_executor import EvictionExecutor
from circuitforge_core.resources.agent.gpu_monitor import GpuMonitor
from circuitforge_core.resources.agent.service_manager import ServiceManager

logger = logging.getLogger(__name__)


class EvictRequest(BaseModel):
    pid: int
    grace_period_s: float = 5.0


class ServiceStartRequest(BaseModel):
    gpu_id: int = 0
    params: dict[str, str] = {}


def create_agent_app(
    node_id: str,
    monitor: GpuMonitor | None = None,
    executor: EvictionExecutor | None = None,
    service_manager: ServiceManager | None = None,
) -> FastAPI:
    _monitor = monitor or GpuMonitor()
    _executor = executor or EvictionExecutor()

    app = FastAPI(title=f"cf-orch-agent [{node_id}]")

    @app.get("/health")
    def health() -> dict[str, Any]:
        return {"status": "ok", "node_id": node_id}

    @app.get("/gpu-info")
    def gpu_info() -> dict[str, Any]:
        gpus = _monitor.poll()
        return {
            "node_id": node_id,
            "gpus": [
                {
                    "gpu_id": g.gpu_id,
                    "name": g.name,
                    "vram_total_mb": g.vram_total_mb,
                    "vram_used_mb": g.vram_used_mb,
                    "vram_free_mb": g.vram_free_mb,
                }
                for g in gpus
            ],
        }

    @app.post("/evict")
    def evict(req: EvictRequest) -> dict[str, Any]:
        result = _executor.evict_pid(pid=req.pid, grace_period_s=req.grace_period_s)
        return {
            "success": result.success,
            "method": result.method,
            "message": result.message,
        }

    @app.get("/resident-info")
    def resident_info() -> dict[str, Any]:
        """Return which models are currently loaded in each running managed service."""
        if service_manager is None:
            return {"residents": []}
        from circuitforge_core.resources.agent.service_probe import probe_all
        return {"residents": probe_all(service_manager)}

    if service_manager is not None:
        @app.get("/services")
        def list_services() -> dict:
            return {"running": service_manager.list_running()}

        @app.get("/services/{service}")
        def service_status(service: str) -> dict:
            running = service_manager.is_running(service)
            url = service_manager.get_url(service) if running else None
            return {"service": service, "running": running, "url": url}

        @app.post("/services/{service}/start")
        def start_service(service: str, req: ServiceStartRequest) -> dict:
            try:
                url = service_manager.start(service, req.gpu_id, req.params)
                return {"service": service, "url": url, "running": True}
            except (ValueError, NotImplementedError) as exc:
                raise HTTPException(status_code=422, detail=str(exc))
            except Exception as exc:
                raise HTTPException(status_code=500, detail=f"Failed to start {service}: {exc}")

        @app.post("/services/{service}/stop")
        def stop_service(service: str) -> dict:
            stopped = service_manager.stop(service)
            return {"service": service, "stopped": stopped}

    return app
@ -1,85 +0,0 @@
from __future__ import annotations

import logging
import os
import signal
import time
from dataclasses import dataclass

import psutil

logger = logging.getLogger(__name__)

_DEFAULT_GRACE_S = 5.0


@dataclass(frozen=True)
class EvictionResult:
    success: bool
    method: str  # "sigterm", "sigkill", "already_gone", "not_found", "error"
    message: str


class EvictionExecutor:
    def __init__(self, grace_period_s: float = _DEFAULT_GRACE_S) -> None:
        self._default_grace = grace_period_s

    def evict_pid(
        self,
        pid: int,
        grace_period_s: float | None = None,
    ) -> EvictionResult:
        grace = grace_period_s if grace_period_s is not None else self._default_grace

        if pid <= 0:
            return EvictionResult(
                success=False, method="error",
                message=f"Refusing to signal invalid PID {pid}"
            )

        if not psutil.pid_exists(pid):
            return EvictionResult(
                success=False, method="not_found",
                message=f"PID {pid} not found"
            )

        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            return EvictionResult(
                success=True, method="already_gone",
                message=f"PID {pid} vanished before SIGTERM"
            )
        except PermissionError as exc:
            return EvictionResult(
                success=False, method="error",
                message=f"Permission denied terminating PID {pid}: {exc}"
            )

        # Wait out the grace period, polling for a clean exit.
        deadline = time.monotonic() + grace
        while time.monotonic() < deadline:
            if not psutil.pid_exists(pid):
                logger.info("PID %d exited cleanly after SIGTERM", pid)
                return EvictionResult(
                    success=True, method="sigterm",
                    message=f"PID {pid} exited after SIGTERM"
                )
            time.sleep(0.05)

        # Escalate to SIGKILL.
        if psutil.pid_exists(pid):
            try:
                os.kill(pid, signal.SIGKILL)
                logger.warning("PID %d required SIGKILL", pid)
                return EvictionResult(
                    success=True, method="sigkill",
                    message=f"PID {pid} killed with SIGKILL"
                )
            except ProcessLookupError:
                pass

        return EvictionResult(
            success=True, method="sigkill",
            message=f"PID {pid} is gone"
        )
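A minimal sketch of the SIGTERM → grace → SIGKILL ladder above (not part of this diff; the module path is assumed from the package layout):

```python
# Sketch: spawn a throwaway process and evict it.
import subprocess

from circuitforge_core.resources.agent.eviction_executor import EvictionExecutor

proc = subprocess.Popen(["sleep", "60"])
result = EvictionExecutor(grace_period_s=1.0).evict_pid(proc.pid)
# Caveat: a child of the calling process lingers as a zombie until reaped, so
# psutil.pid_exists() keeps returning True and the executor escalates to
# SIGKILL here. Real targets are unrelated GPU processes, where SIGTERM
# alone usually suffices.
print(result.success, result.method)
proc.wait()  # reap the child
```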
@@ -1,52 +0,0 @@
from __future__ import annotations

import logging
import subprocess

from circuitforge_core.resources.models import GpuInfo

logger = logging.getLogger(__name__)

_NVIDIA_SMI_CMD = [
    "nvidia-smi",
    "--query-gpu=index,name,memory.total,memory.used,memory.free",
    "--format=csv,noheader,nounits",
]


class GpuMonitor:
    def poll(self) -> list[GpuInfo]:
        try:
            result = subprocess.run(
                _NVIDIA_SMI_CMD,
                capture_output=True,
                text=True,
                timeout=5,
            )
        except (FileNotFoundError, subprocess.TimeoutExpired) as exc:
            logger.warning("nvidia-smi unavailable: %s", exc)
            return []

        if result.returncode != 0:
            logger.warning("nvidia-smi exited %d", result.returncode)
            return []

        return self._parse(result.stdout)

    def _parse(self, output: str) -> list[GpuInfo]:
        gpus: list[GpuInfo] = []
        for line in output.strip().splitlines():
            parts = [p.strip() for p in line.split(",")]
            if len(parts) != 5:
                continue
            try:
                gpus.append(GpuInfo(
                    gpu_id=int(parts[0]),
                    name=parts[1],
                    vram_total_mb=int(parts[2]),
                    vram_used_mb=int(parts[3]),
                    vram_free_mb=int(parts[4]),
                ))
            except ValueError:
                logger.debug("Skipping malformed nvidia-smi line: %r", line)
        return gpus
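For reference, one line of the `csv,noheader,nounits` output requested above looks like the sample below (illustrative values), and `_parse` splits it into the five `GpuInfo` fields:

```python
# Sketch: what _parse sees for a single GPU row.
sample = "0, NVIDIA GeForce RTX 4090, 24564, 3120, 21444"

parts = [p.strip() for p in sample.split(",")]
assert len(parts) == 5
print(int(parts[0]), parts[1], int(parts[2]))  # 0 NVIDIA GeForce RTX 4090 24564
```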
@@ -1,169 +0,0 @@
"""
ServiceManager — start/stop Docker containers and processes for cf-orch managed services.

Container naming convention: cf-orch-{service}-{node_id}
"""
from __future__ import annotations

import os
import re
import socket
import subprocess
from collections import defaultdict
from typing import Any

from circuitforge_core.resources.profiles.schema import DockerSpec, GpuProfile, ProcessSpec


def _expand_volume(v: str) -> str:
    """Expand bash-style volume strings including ${VAR:-default} and $VAR."""
    def _sub(m: re.Match) -> str:  # type: ignore[type-arg]
        var, default = m.group(1), m.group(2) or ""
        return os.environ.get(var) or default
    v = re.sub(r"\$\{(\w+)(?::-(.*?))?\}", _sub, v)
    v = re.sub(r"\$(\w+)", lambda m: os.environ.get(m.group(1), m.group(0)), v)
    return v


class ServiceManager:
    def __init__(
        self,
        node_id: str,
        profile: GpuProfile,
        advertise_host: str = "127.0.0.1",
    ) -> None:
        self.node_id = node_id
        self.profile = profile
        self.advertise_host = advertise_host
        self._procs: dict[str, Any] = {}

    def container_name(self, service: str) -> str:
        return f"cf-orch-{service}-{self.node_id}"

    def _get_spec(self, service: str) -> DockerSpec | ProcessSpec | None:
        svc = self.profile.services.get(service)
        if svc is None:
            return None
        return svc.managed

    def is_running(self, service: str) -> bool:
        spec = self._get_spec(service)
        if spec is None:
            return False
        if isinstance(spec, DockerSpec):
            try:
                result = subprocess.run(
                    [
                        "docker",
                        "inspect",
                        "--format",
                        "{{.State.Running}}",
                        self.container_name(service),
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
                return result.stdout.strip() == "true"
            except subprocess.CalledProcessError:
                return False
        if isinstance(spec, ProcessSpec):
            proc = self._procs.get(service)
            if proc is None or proc.poll() is not None:
                return False
            try:
                with socket.create_connection(("127.0.0.1", spec.host_port), timeout=1):
                    return True
            except OSError:
                return False
        return False

    def start(self, service: str, gpu_id: int, params: dict[str, str]) -> str:
        spec = self._get_spec(service)
        if spec is None:
            raise ValueError(f"Service {service!r} not in profile or has no managed spec")

        if self.is_running(service):
            return f"http://{self.advertise_host}:{spec.host_port}"

        if isinstance(spec, DockerSpec):
            expanded_volumes = [_expand_volume(v) for v in spec.volumes]

            # defaultdict(str) lets format_map fill unknown placeholders with "".
            filler: dict[str, str] = defaultdict(str, params)
            expanded_command = spec.command_template.format_map(filler).split()

            cmd = [
                "docker", "run", "-d", "--rm",
                "--name", self.container_name(service),
                "--runtime", spec.runtime,
                "--gpus", f"device={gpu_id}",
                "--ipc", spec.ipc,
                "-p", f"{spec.host_port}:{spec.port}",
            ]
            for vol in expanded_volumes:
                cmd += ["-v", vol]
            for key, val in spec.env.items():
                cmd += ["-e", f"{key}={val}"]
            cmd.append(spec.image)
            cmd.extend(expanded_command)

            subprocess.run(cmd, check=True, capture_output=True, text=True)
            return f"http://{self.advertise_host}:{spec.host_port}"

        if isinstance(spec, ProcessSpec):
            filler = defaultdict(str, params)
            filler.setdefault("port", str(spec.port))
            filler.setdefault("gpu_id", str(gpu_id))
            args_expanded = spec.args_template.format_map(filler).split()

            cmd = [spec.exec_path] + args_expanded
            env = {**os.environ}
            proc = subprocess.Popen(
                cmd,
                cwd=spec.cwd or None,
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            self._procs[service] = proc
            return f"http://{self.advertise_host}:{spec.host_port}"

        raise NotImplementedError(f"Unknown spec type: {type(spec)}")

    def stop(self, service: str) -> bool:
        spec = self._get_spec(service)
        if spec is None:
            return False
        if isinstance(spec, DockerSpec):
            try:
                subprocess.run(
                    ["docker", "stop", self.container_name(service)],
                    check=True,
                    capture_output=True,
                    text=True,
                )
                return True
            except subprocess.CalledProcessError:
                return False
        if isinstance(spec, ProcessSpec):
            proc = self._procs.pop(service, None)
            if proc is not None:
                proc.terminate()
                try:
                    proc.wait(timeout=10)
                except Exception:
                    proc.kill()
                return True
        return False

    def list_running(self) -> list[str]:
        return [svc for svc in self.profile.services if self.is_running(svc)]

    def get_url(self, service: str) -> str | None:
        spec = self._get_spec(service)
        if spec is None or not self.is_running(service):
            return None
        return f"http://{self.advertise_host}:{spec.host_port}"
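A self-contained rerun of the `_expand_volume` helper to show the `${VAR:-default}` fallback (sketch, not part of this diff; the variable name is illustrative):

```python
# Sketch: the two substitution passes in _expand_volume.
import os
import re

def _expand_volume(v: str) -> str:
    def _sub(m: re.Match) -> str:
        var, default = m.group(1), m.group(2) or ""
        return os.environ.get(var) or default
    v = re.sub(r"\$\{(\w+)(?::-(.*?))?\}", _sub, v)   # ${VAR:-default}
    v = re.sub(r"\$(\w+)", lambda m: os.environ.get(m.group(1), m.group(0)), v)  # $VAR
    return v

os.environ.pop("HF_HOME", None)
print(_expand_volume("${HF_HOME:-/opt/hf}:/root/.cache/huggingface"))
# → /opt/hf:/root/.cache/huggingface  (default used because HF_HOME is unset)
```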
@@ -1,123 +0,0 @@
"""
Probe running services to detect which models are currently loaded in VRAM.

Two probe strategies run together:

1. Well-known ports — always checked, regardless of who started the service.
   Catches ollama, vLLM, etc. running outside cf-orch management.

2. Managed services — services cf-orch started via ServiceManager.
   Checked on their configured host_port, deduplicated against the
   well-known-port results.

Each service exposes a different introspection API:
- vllm:   GET /v1/models → {"data": [{"id": "<model-name>"}]}
- ollama: GET /api/ps    → {"models": [{"name": "<model>", "size_vram": <bytes>}]}

ollama can have multiple models loaded simultaneously; each is reported as a
separate entry so the dashboard shows per-model residency.

The probe is best-effort: a timeout or connection refusal means model_name=None
but the service is still reported as resident.
"""
from __future__ import annotations

import json
import logging
import urllib.request
from typing import Any

from circuitforge_core.resources.profiles.schema import DockerSpec

logger = logging.getLogger(__name__)

_PROBE_TIMEOUT_S = 2.0

# Well-known service ports probed on every heartbeat.
# Maps port → service name (the key into _PROBERS below).
_WELL_KNOWN_PORTS: dict[int, str] = {
    11434: "ollama",
    8000: "vllm",
    8080: "vllm",  # common alt vLLM port
}


def _fetch_json(url: str) -> dict[str, Any] | None:
    """GET a URL and parse JSON; returns None on any error."""
    try:
        with urllib.request.urlopen(url, timeout=_PROBE_TIMEOUT_S) as resp:
            return json.loads(resp.read())
    except Exception as exc:
        logger.debug("Probe %s: %s", url, exc)
        return None


def _probe_vllm(port: int) -> list[str]:
    data = _fetch_json(f"http://127.0.0.1:{port}/v1/models")
    if data and data.get("data"):
        return [m["id"] for m in data["data"] if m.get("id")]
    return []


def _probe_ollama(port: int) -> list[str]:
    # /api/ps lists models currently *loaded in memory*, not just downloaded.
    data = _fetch_json(f"http://127.0.0.1:{port}/api/ps")
    if data and data.get("models"):
        return [m["name"] for m in data["models"] if m.get("name")]
    return []


_PROBERS: dict[str, Any] = {
    "vllm": _probe_vllm,
    "ollama": _probe_ollama,
}


def probe_all(service_manager: Any) -> list[dict[str, Any]]:
    """
    Probe all services — both well-known ports and cf-orch managed services.

    Returns a list of dicts: [{"service": str, "model_name": str | None}].
    Multiple loaded models in one service (e.g. two ollama models) each get
    their own entry, disambiguated as "ollama/0", "ollama/1", etc.
    """
    results: list[dict[str, Any]] = []
    seen_ports: set[int] = set()

    # ── 1. Well-known ports ──────────────────────────────────────────
    for port, service in _WELL_KNOWN_PORTS.items():
        prober = _PROBERS.get(service)
        if prober is None:
            continue
        models = prober(port)
        if not models:
            continue  # nothing on this port right now
        seen_ports.add(port)
        if len(models) == 1:
            results.append({"service": service, "model_name": models[0]})
        else:
            for i, model in enumerate(models):
                results.append({"service": f"{service}/{i}", "model_name": model})

    # ── 2. Managed services (cf-orch started) ───────────────────────
    if service_manager is not None:
        for service in service_manager.list_running():
            spec = service_manager._get_spec(service)
            if not isinstance(spec, DockerSpec):
                continue
            if spec.host_port in seen_ports:
                continue  # already captured by well-known probe
            prober = _PROBERS.get(service)
            if prober is None:
                results.append({"service": service, "model_name": None})
                continue
            models = prober(spec.host_port)
            seen_ports.add(spec.host_port)
            if not models:
                results.append({"service": service, "model_name": None})
            elif len(models) == 1:
                results.append({"service": service, "model_name": models[0]})
            else:
                for i, model in enumerate(models):
                    results.append({"service": f"{service}/{i}", "model_name": model})

    return results
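With `service_manager=None` only the well-known ports are probed, so the function is safe to call on any box — services that aren't listening simply produce no entries. A sketch (not part of this diff; the example output is illustrative):

```python
from circuitforge_core.resources.agent.service_probe import probe_all

print(probe_all(None))
# e.g. [{"service": "ollama", "model_name": "llama3:8b"}] with one model warm,
# or [] when nothing is listening on 11434/8000/8080.
```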
@@ -1,208 +0,0 @@
from __future__ import annotations

import sys
from pathlib import Path
from typing import Annotated, Optional

import typer
import uvicorn

app = typer.Typer(name="cf-orch", help="CircuitForge GPU resource orchestrator")

_SYSTEMD_UNIT_PATH = Path("/etc/systemd/system/cf-orch.service")

_SYSTEMD_UNIT_TEMPLATE = """\
[Unit]
Description=CircuitForge GPU Resource Orchestrator
After=network.target

[Service]
Type=simple
ExecStart={python} -m circuitforge_core.resources.cli start
Restart=on-failure
RestartSec=5

[Install]
WantedBy=multi-user.target
"""


@app.command()
def start(
    profile: Annotated[Optional[Path], typer.Option(help="Profile YAML path")] = None,
    host: str = "0.0.0.0",
    port: int = 7700,
    node_id: str = "local",
    agent_port: int = 7701,
) -> None:
    """Start the cf-orch coordinator (auto-detects GPU profile if not specified).

    Automatically pre-registers the local agent so its GPUs appear on the
    dashboard immediately. Remote nodes self-register via POST /api/nodes.
    """
    from circuitforge_core.resources.agent.gpu_monitor import GpuMonitor
    from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor
    from circuitforge_core.resources.coordinator.app import create_coordinator_app
    from circuitforge_core.resources.coordinator.lease_manager import LeaseManager
    from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
    from circuitforge_core.resources.coordinator.service_registry import ServiceRegistry

    lease_manager = LeaseManager()
    profile_registry = ProfileRegistry()
    service_registry = ServiceRegistry()
    supervisor = AgentSupervisor(
        lease_manager=lease_manager,
        service_registry=service_registry,
        profile_registry=profile_registry,
    )

    monitor = GpuMonitor()
    gpus = monitor.poll()
    if not gpus:
        typer.echo(
            "Warning: no GPUs detected via nvidia-smi — coordinator running with 0 VRAM"
        )
    else:
        typer.echo(f"Detected {len(gpus)} GPU(s)")

    if profile:
        active_profile = profile_registry.load(profile)
        typer.echo(f"Using profile: {active_profile.name} (from {profile})")
    else:
        active_profile = (
            profile_registry.auto_detect(gpus)
            if gpus
            else profile_registry.list_public()[-1]
        )
        typer.echo(f"Auto-selected profile: {active_profile.name}")

    # Pre-register the local agent — the heartbeat loop will poll it for live GPU data.
    local_agent_url = f"http://127.0.0.1:{agent_port}"
    supervisor.register(node_id, local_agent_url)
    typer.echo(f"Registered local node '{node_id}' → {local_agent_url}")

    coordinator_app = create_coordinator_app(
        lease_manager=lease_manager,
        profile_registry=profile_registry,
        agent_supervisor=supervisor,
        service_registry=service_registry,
    )

    typer.echo(f"Starting cf-orch coordinator on {host}:{port}")
    uvicorn.run(coordinator_app, host=host, port=port)


@app.command()
def agent(
    coordinator: str = "http://localhost:7700",
    node_id: str = "local",
    host: str = "0.0.0.0",
    port: int = 7701,
    advertise_host: Optional[str] = None,
    profile: Annotated[Optional[Path], typer.Option(help="Profile YAML path")] = None,
) -> None:
    """Start a cf-orch node agent and self-register with the coordinator.

    The agent starts its HTTP server, then POSTs its URL to the coordinator
    so it appears on the dashboard without manual configuration.

    Use --advertise-host to override the IP the coordinator should use to
    reach this agent (e.g. on a multi-homed or NATted host).
    """
    import threading
    import httpx
    from circuitforge_core.resources.agent.app import create_agent_app
    from circuitforge_core.resources.agent.service_manager import ServiceManager
    from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry

    # The URL the coordinator should use to reach this agent.
    reach_host = advertise_host or ("127.0.0.1" if host in ("0.0.0.0", "::") else host)
    agent_url = f"http://{reach_host}:{port}"

    def _register_in_background() -> None:
        """POST registration to coordinator after a short delay (uvicorn needs ~1s to bind)."""
        import time
        time.sleep(2.0)
        try:
            resp = httpx.post(
                f"{coordinator}/api/nodes",
                json={"node_id": node_id, "agent_url": agent_url},
                timeout=5.0,
            )
            if resp.is_success:
                typer.echo(f"Registered with coordinator at {coordinator} as '{node_id}'")
            else:
                typer.echo(
                    f"Warning: coordinator registration returned {resp.status_code}", err=True
                )
        except Exception as exc:
            typer.echo(f"Warning: could not reach coordinator at {coordinator}: {exc}", err=True)

    # Fire registration in a daemon thread so uvicorn.run() can start blocking immediately.
    threading.Thread(target=_register_in_background, daemon=True).start()

    service_manager = None
    try:
        from circuitforge_core.resources.agent.gpu_monitor import GpuMonitor
        pr = ProfileRegistry()
        gpus = GpuMonitor().poll()
        p = pr.load(profile) if profile else pr.auto_detect(gpus)
        service_manager = ServiceManager(node_id=node_id, profile=p, advertise_host=reach_host)
        typer.echo(f"ServiceManager ready with profile: {p.name}")
    except Exception as exc:
        typer.echo(f"Warning: ServiceManager unavailable ({exc})", err=True)

    agent_app = create_agent_app(node_id=node_id, service_manager=service_manager)
    typer.echo(f"Starting cf-orch agent [{node_id}] on {host}:{port}")
    uvicorn.run(agent_app, host=host, port=port)


@app.command()
def status(coordinator: str = "http://localhost:7700") -> None:
    """Show GPU and lease status from the coordinator."""
    import httpx

    try:
        resp = httpx.get(f"{coordinator}/api/nodes", timeout=5.0)
        resp.raise_for_status()
        nodes = resp.json().get("nodes", [])
        for node in nodes:
            typer.echo(f"\nNode: {node['node_id']}")
            for gpu in node.get("gpus", []):
                typer.echo(
                    f"  GPU {gpu['gpu_id']}: {gpu['name']} — "
                    f"{gpu['vram_used_mb']}/{gpu['vram_total_mb']} MB used"
                )
    except Exception as exc:
        typer.echo(f"Coordinator unreachable at {coordinator}: {exc}", err=True)
        raise typer.Exit(1)


@app.command("install-service")
def install_service(
    dry_run: bool = typer.Option(
        False, "--dry-run", help="Print unit file without writing"
    ),
) -> None:
    """Write a systemd unit file for cf-orch (requires root)."""
    python = sys.executable
    unit_content = _SYSTEMD_UNIT_TEMPLATE.format(python=python)
    if dry_run:
        typer.echo(f"Would write to {_SYSTEMD_UNIT_PATH}:\n")
        typer.echo(unit_content)
        return
    try:
        _SYSTEMD_UNIT_PATH.write_text(unit_content)
        typer.echo(f"Written: {_SYSTEMD_UNIT_PATH}")
        typer.echo(
            "Run: sudo systemctl daemon-reload && sudo systemctl enable --now cf-orch"
        )
    except PermissionError:
        typer.echo(
            f"Permission denied writing to {_SYSTEMD_UNIT_PATH}. Run as root.", err=True
        )
        raise typer.Exit(1)


if __name__ == "__main__":
    app()
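A quick in-process smoke test of the Typer app is possible with `CliRunner` (sketch, not part of this diff); `--dry-run` prints the rendered systemd unit without touching /etc:

```python
from typer.testing import CliRunner

from circuitforge_core.resources.cli import app

runner = CliRunner()
result = runner.invoke(app, ["install-service", "--dry-run"])
print(result.output)  # the rendered unit, with {python} filled in
```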
@@ -1,126 +0,0 @@
from __future__ import annotations

import logging
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass

import httpx

logger = logging.getLogger(__name__)


@dataclass
class Allocation:
    allocation_id: str
    service: str
    node_id: str
    gpu_id: int
    model: str | None
    url: str
    started: bool
    warm: bool


class CFOrchClient:
    """
    Client for cf-orch coordinator allocation.

    Sync usage (in LLMRouter or other sync code):
        client = CFOrchClient(os.environ["CF_ORCH_URL"])
        with client.allocate("vllm", model_candidates=["Ouro-1.4B"]) as alloc:
            ...  # alloc.url is the inference endpoint

    Async usage (in FastAPI apps):
        async with client.allocate_async("vllm", model_candidates=["Ouro-1.4B"]) as alloc:
            ...

    Raises ValueError immediately if coordinator_url is empty.
    """

    def __init__(self, coordinator_url: str) -> None:
        if not coordinator_url:
            raise ValueError("coordinator_url is empty — cf-orch not configured")
        self._url = coordinator_url.rstrip("/")

    def _build_body(self, model_candidates: list[str] | None, ttl_s: float, caller: str) -> dict:
        return {
            "model_candidates": model_candidates or [],
            "ttl_s": ttl_s,
            "caller": caller,
        }

    def _parse_allocation(self, data: dict, service: str) -> Allocation:
        return Allocation(
            allocation_id=data["allocation_id"],
            service=service,
            node_id=data["node_id"],
            gpu_id=data["gpu_id"],
            model=data.get("model"),
            url=data["url"],
            started=data.get("started", False),
            warm=data.get("warm", False),
        )

    @contextmanager
    def allocate(
        self,
        service: str,
        *,
        model_candidates: list[str] | None = None,
        ttl_s: float = 3600.0,
        caller: str = "",
    ):
        """Sync context manager. Allocates on enter, releases on exit."""
        resp = httpx.post(
            f"{self._url}/api/services/{service}/allocate",
            json=self._build_body(model_candidates, ttl_s, caller),
            timeout=120.0,
        )
        if not resp.is_success:
            raise RuntimeError(
                f"cf-orch allocation failed for {service!r}: "
                f"HTTP {resp.status_code} — {resp.text[:200]}"
            )
        alloc = self._parse_allocation(resp.json(), service)
        try:
            yield alloc
        finally:
            try:
                httpx.delete(
                    f"{self._url}/api/services/{service}/allocations/{alloc.allocation_id}",
                    timeout=10.0,
                )
            except Exception as exc:
                logger.debug("cf-orch release failed (non-fatal): %s", exc)

    @asynccontextmanager
    async def allocate_async(
        self,
        service: str,
        *,
        model_candidates: list[str] | None = None,
        ttl_s: float = 3600.0,
        caller: str = "",
    ):
        """Async context manager. Allocates on enter, releases on exit."""
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                f"{self._url}/api/services/{service}/allocate",
                json=self._build_body(model_candidates, ttl_s, caller),
            )
            if not resp.is_success:
                raise RuntimeError(
                    f"cf-orch allocation failed for {service!r}: "
                    f"HTTP {resp.status_code} — {resp.text[:200]}"
                )
            alloc = self._parse_allocation(resp.json(), service)
            try:
                yield alloc
            finally:
                try:
                    await client.delete(
                        f"{self._url}/api/services/{service}/allocations/{alloc.allocation_id}",
                        timeout=10.0,
                    )
                except Exception as exc:
                    logger.debug("cf-orch async release failed (non-fatal): %s", exc)
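A minimal sketch of typical sync usage (not part of this diff; the import path and model name are assumptions), showing that the context manager releases the allocation even when the code inside it raises:

```python
import httpx

from circuitforge_core.resources.client import CFOrchClient  # path assumed

client = CFOrchClient("http://localhost:7700")
with client.allocate("vllm", model_candidates=["Qwen2.5-7B"], caller="demo") as alloc:
    resp = httpx.get(f"{alloc.url}/v1/models", timeout=10.0)
    print(alloc.node_id, alloc.gpu_id, alloc.warm, resp.status_code)
# DELETE /api/services/vllm/allocations/{id} has fired by this point.
```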
@@ -1,44 +0,0 @@
# circuitforge_core/resources/compose.yml
# One-command cf-orch deployment for Docker self-hosters:
#   docker compose -f path/to/compose.yml up cf-orch-coordinator

services:
  cf-orch-coordinator:
    image: python:3.12-slim
    command: >
      sh -c "pip install 'circuitforge-core[orch]' &&
             cf-orch start --host 0.0.0.0 --port 7700"
    ports:
      - "7700:7700"
    volumes:
      - /run/docker.sock:/var/run/docker.sock:ro
      - cf-orch-data:/data
    environment:
      - CFORCH_PROFILE=${CFORCH_PROFILE:-}
    restart: unless-stopped
    devices:
      - /dev/nvidia0:/dev/nvidia0
      - /dev/nvidiactl:/dev/nvidiactl
    runtime: nvidia

  cf-orch-agent:
    image: python:3.12-slim
    command: >
      sh -c "pip install 'circuitforge-core[orch]' &&
             cf-orch agent --coordinator http://cf-orch-coordinator:7700
             --node-id ${CFORCH_NODE_ID:-local}
             --host 0.0.0.0 --port 7701"
    ports:
      - "7701:7701"
    depends_on:
      - cf-orch-coordinator
    environment:
      - CFORCH_NODE_ID=${CFORCH_NODE_ID:-local}
    restart: unless-stopped
    devices:
      - /dev/nvidia0:/dev/nvidia0
      - /dev/nvidiactl:/dev/nvidiactl
    runtime: nvidia

volumes:
  cf-orch-data:
@@ -1,182 +0,0 @@
from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass, field

import httpx

from circuitforge_core.resources.coordinator.lease_manager import LeaseManager
from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
from circuitforge_core.resources.coordinator.service_registry import ServiceRegistry
from circuitforge_core.resources.models import GpuInfo, NodeInfo, ResidentAllocation

logger = logging.getLogger(__name__)

_HEARTBEAT_INTERVAL_S = 10.0
_AGENT_TIMEOUT_S = 5.0


@dataclass
class AgentRecord:
    node_id: str
    agent_url: str
    last_seen: float = field(default_factory=time.time)
    gpus: list[GpuInfo] = field(default_factory=list)
    online: bool = False


class AgentSupervisor:
    def __init__(
        self,
        lease_manager: LeaseManager,
        service_registry: ServiceRegistry | None = None,
        profile_registry: ProfileRegistry | None = None,
    ) -> None:
        self._agents: dict[str, AgentRecord] = {}
        self._lease_manager = lease_manager
        self._running = False
        self._service_registry = service_registry
        self._profile_registry = profile_registry
        self._heartbeat_tick = 0

    def register(self, node_id: str, agent_url: str) -> None:
        if node_id not in self._agents:
            self._agents[node_id] = AgentRecord(node_id=node_id, agent_url=agent_url)
            logger.info("Registered agent node: %s @ %s", node_id, agent_url)
        else:
            if self._agents[node_id].agent_url != agent_url:
                self._agents[node_id].agent_url = agent_url
                logger.info("Updated agent URL for %s → %s", node_id, agent_url)

    def get_node_info(self, node_id: str) -> NodeInfo | None:
        record = self._agents.get(node_id)
        if record is None:
            return None
        return NodeInfo(
            node_id=record.node_id,
            agent_url=record.agent_url,
            gpus=record.gpus,
            last_heartbeat=record.last_seen,
        )

    def all_nodes(self) -> list[NodeInfo]:
        return [
            NodeInfo(
                node_id=r.node_id,
                agent_url=r.agent_url,
                gpus=r.gpus,
                last_heartbeat=r.last_seen,
            )
            for r in self._agents.values()
        ]

    def online_agents(self) -> dict[str, AgentRecord]:
        """Return only currently-online agents, keyed by node_id."""
        return {nid: rec for nid, rec in self._agents.items() if rec.online}

    async def poll_agent(self, node_id: str) -> bool:
        record = self._agents.get(node_id)
        if record is None:
            return False
        try:
            async with httpx.AsyncClient(timeout=_AGENT_TIMEOUT_S) as client:
                gpu_resp = await client.get(f"{record.agent_url}/gpu-info")
                gpu_resp.raise_for_status()

                # Resident-info is best-effort — older agents may not have the endpoint.
                try:
                    res_resp = await client.get(f"{record.agent_url}/resident-info")
                    resident_data = res_resp.json() if res_resp.is_success else {}
                except Exception:
                    resident_data = {}

                data = gpu_resp.json()
                gpus = [
                    GpuInfo(
                        gpu_id=g["gpu_id"],
                        name=g["name"],
                        vram_total_mb=g["vram_total_mb"],
                        vram_used_mb=g["vram_used_mb"],
                        vram_free_mb=g["vram_free_mb"],
                    )
                    for g in data.get("gpus", [])
                ]
                record.gpus = gpus
                record.last_seen = time.time()
                record.online = True
                for gpu in gpus:
                    self._lease_manager.register_gpu(node_id, gpu.gpu_id, gpu.vram_total_mb)

                residents = [
                    (r["service"], r.get("model_name"))
                    for r in resident_data.get("residents", [])
                ]
                self._lease_manager.set_residents_for_node(node_id, residents)

                return True
        except Exception as exc:
            logger.warning("Agent %s unreachable: %s", node_id, exc)
            record.online = False
            return False

    async def poll_all(self) -> None:
        await asyncio.gather(*[self.poll_agent(nid) for nid in self._agents])

    def _build_idle_stop_config(self) -> dict[str, int]:
        if self._profile_registry is None:
            return {}
        config: dict[str, int] = {}
        for profile in self._profile_registry.list_public():
            for svc_name, svc in profile.services.items():
                if svc.idle_stop_after_s > 0:
                    # Keep the most aggressive (smallest) timeout across profiles.
                    existing = config.get(svc_name, 0)
                    if existing > 0:
                        config[svc_name] = min(existing, svc.idle_stop_after_s)
                    else:
                        config[svc_name] = svc.idle_stop_after_s
        return config

    async def _http_post(self, url: str) -> bool:
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(url)
                return resp.is_success
        except Exception as exc:
            logger.warning("HTTP POST %s failed: %s", url, exc)
            return False

    async def _run_idle_sweep(self) -> None:
        if self._service_registry is None:
            return
        expired = self._service_registry.sweep_expired_allocations()
        if expired:
            logger.info("TTL sweep: expired %d allocation(s): %s", len(expired), expired)
        idle_stop_config = self._build_idle_stop_config()
        if not idle_stop_config:
            return
        timed_out = self._service_registry.idle_past_timeout(idle_stop_config)
        for instance in timed_out:
            node_info = self.get_node_info(instance.node_id)
            if node_info is None:
                continue
            stop_url = f"{node_info.agent_url}/services/{instance.service}/stop"
            logger.info(
                "Idle sweep: stopping %s on %s gpu%s (idle timeout)",
                instance.service, instance.node_id, instance.gpu_id,
            )
            success = await self._http_post(stop_url)
            if success:
                self._service_registry.mark_stopped(
                    instance.service, instance.node_id, instance.gpu_id
                )

    async def run_heartbeat_loop(self) -> None:
        self._running = True
        while self._running:
            await self.poll_all()
            self._heartbeat_tick += 1
            if self._heartbeat_tick % 3 == 0:
                await self._run_idle_sweep()
            await asyncio.sleep(_HEARTBEAT_INTERVAL_S)

    def stop(self) -> None:
        self._running = False
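Heartbeats fire every 10 s (`_HEARTBEAT_INTERVAL_S`) and the idle sweep runs on every third tick, so idle services are reclaimed roughly every 30 s. A minimal sketch of running the supervisor loop standalone (not part of this diff; assumes the `LeaseManager()` constructor used elsewhere in this changeset):

```python
import asyncio

from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor
from circuitforge_core.resources.coordinator.lease_manager import LeaseManager

async def main() -> None:
    supervisor = AgentSupervisor(lease_manager=LeaseManager())
    supervisor.register("local", "http://127.0.0.1:7701")
    await supervisor.run_heartbeat_loop()  # runs until supervisor.stop()

asyncio.run(main())
```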
@@ -1,487 +0,0 @@
from __future__ import annotations

import logging
import time
import urllib.request
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any

from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor
from circuitforge_core.resources.coordinator.eviction_engine import EvictionEngine
from circuitforge_core.resources.coordinator.lease_manager import LeaseManager
from circuitforge_core.resources.coordinator.node_selector import select_node
from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
from circuitforge_core.resources.coordinator.service_registry import ServiceRegistry

logger = logging.getLogger(__name__)

_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text()

_PROBE_INTERVAL_S = 5.0   # how often to poll starting instances
_PROBE_TIMEOUT_S = 300.0  # give up and mark stopped after this many seconds


async def _run_instance_probe_loop(service_registry: ServiceRegistry) -> None:
    """
    Background loop: transition 'starting' instances to 'running' once their
    /health endpoint responds, or to 'stopped' after _PROBE_TIMEOUT_S.
    """
    import asyncio

    start_times: dict[str, float] = {}  # instance key → time first seen as starting

    while True:
        await asyncio.sleep(_PROBE_INTERVAL_S)
        now = time.time()
        for inst in service_registry.all_instances():
            key = f"{inst.service}:{inst.node_id}:{inst.gpu_id}"
            if inst.state != "starting":
                start_times.pop(key, None)
                continue
            start_times.setdefault(key, now)

            healthy = False
            if inst.url:
                try:
                    with urllib.request.urlopen(
                        inst.url.rstrip("/") + "/health", timeout=2.0
                    ) as resp:
                        healthy = resp.status == 200
                except Exception:
                    pass

            if healthy:
                service_registry.upsert_instance(
                    service=inst.service, node_id=inst.node_id, gpu_id=inst.gpu_id,
                    state="running", model=inst.model, url=inst.url,
                )
                start_times.pop(key, None)
                logger.info(
                    "Instance %s/%s gpu=%s transitioned to running",
                    inst.service, inst.node_id, inst.gpu_id,
                )
            elif now - start_times[key] > _PROBE_TIMEOUT_S:
                service_registry.upsert_instance(
                    service=inst.service, node_id=inst.node_id, gpu_id=inst.gpu_id,
                    state="stopped", model=inst.model, url=inst.url,
                )
                start_times.pop(key, None)
                logger.warning(
                    "Instance %s/%s gpu=%s timed out in starting state — marked stopped",
                    inst.service, inst.node_id, inst.gpu_id,
                )


class LeaseRequest(BaseModel):
    node_id: str
    gpu_id: int
    mb: int
    service: str
    priority: int = 2
    ttl_s: float = 0.0


class NodeRegisterRequest(BaseModel):
    node_id: str
    agent_url: str  # e.g. "http://10.1.10.71:7701"


class ServiceEnsureRequest(BaseModel):
    node_id: str
    gpu_id: int = 0
    params: dict[str, str] = {}
    ttl_s: float = 3600.0
    # Ordered list of model names to try; falls back down the list if VRAM is tight.
    # The "model" key in params is used if this list is empty.
    model_candidates: list[str] = []


class ServiceAllocateRequest(BaseModel):
    model_candidates: list[str] = []
    gpu_id: int | None = None
    params: dict[str, str] = {}
    ttl_s: float = 3600.0
    caller: str = ""


def create_coordinator_app(
    lease_manager: LeaseManager,
    profile_registry: ProfileRegistry,
    agent_supervisor: AgentSupervisor,
    service_registry: ServiceRegistry,
) -> FastAPI:
    eviction_engine = EvictionEngine(lease_manager=lease_manager)

    @asynccontextmanager
    async def _lifespan(app: FastAPI):  # type: ignore[type-arg]
        import asyncio
        heartbeat_task = asyncio.create_task(agent_supervisor.run_heartbeat_loop())
        probe_task = asyncio.create_task(_run_instance_probe_loop(service_registry))
        yield
        agent_supervisor.stop()
        heartbeat_task.cancel()
        probe_task.cancel()

    app = FastAPI(title="cf-orch-coordinator", lifespan=_lifespan)

    @app.get("/", response_class=HTMLResponse, include_in_schema=False)
    def dashboard() -> HTMLResponse:
        return HTMLResponse(content=_DASHBOARD_HTML)

    @app.get("/api/health")
    def health() -> dict[str, Any]:
        return {"status": "ok"}

    @app.get("/api/nodes")
    def get_nodes() -> dict[str, Any]:
        nodes = agent_supervisor.all_nodes()
        return {
            "nodes": [
                {
                    "node_id": n.node_id,
                    "agent_url": n.agent_url,
                    "last_heartbeat": n.last_heartbeat,
                    "gpus": [
                        {
                            "gpu_id": g.gpu_id,
                            "name": g.name,
                            "vram_total_mb": g.vram_total_mb,
                            "vram_used_mb": g.vram_used_mb,
                            "vram_free_mb": g.vram_free_mb,
                        }
                        for g in n.gpus
                    ],
                }
                for n in nodes
            ]
        }

    @app.post("/api/nodes")
    async def register_node(req: NodeRegisterRequest) -> dict[str, Any]:
        """Agents call this to self-register. Coordinator immediately polls for GPU info."""
        agent_supervisor.register(req.node_id, req.agent_url)
        await agent_supervisor.poll_agent(req.node_id)
        return {"registered": True, "node_id": req.node_id}

    @app.get("/api/profiles")
    def get_profiles() -> dict[str, Any]:
        return {
            "profiles": [
                {"name": p.name, "vram_total_mb": p.vram_total_mb}
                for p in profile_registry.list_public()
            ]
        }

    @app.get("/api/resident")
    def get_residents() -> dict[str, Any]:
        return {
            "residents": [
                {
                    "service": r.service,
                    "node_id": r.node_id,
                    "model_name": r.model_name,
                    "first_seen": r.first_seen,
                }
                for r in lease_manager.all_residents()
            ]
        }

    @app.get("/api/leases")
    def get_leases() -> dict[str, Any]:
        return {
            "leases": [
                {
                    "lease_id": lease.lease_id,
                    "node_id": lease.node_id,
                    "gpu_id": lease.gpu_id,
                    "mb_granted": lease.mb_granted,
                    "holder_service": lease.holder_service,
                    "priority": lease.priority,
                    "expires_at": lease.expires_at,
                }
                for lease in lease_manager.all_leases()
            ]
        }

    @app.post("/api/leases")
    async def request_lease(req: LeaseRequest) -> dict[str, Any]:
        node_info = agent_supervisor.get_node_info(req.node_id)
        if node_info is None:
            raise HTTPException(
                status_code=422,
                detail=f"Unknown node_id {req.node_id!r} — node not registered",
            )
        agent_url = node_info.agent_url

        lease = await eviction_engine.request_lease(
            node_id=req.node_id,
            gpu_id=req.gpu_id,
            mb=req.mb,
            service=req.service,
            priority=req.priority,
            agent_url=agent_url,
            ttl_s=req.ttl_s,
        )
        if lease is None:
            raise HTTPException(
                status_code=503,
                detail="Insufficient VRAM — no eviction candidates available",
            )
        return {
            "lease": {
                "lease_id": lease.lease_id,
                "node_id": lease.node_id,
                "gpu_id": lease.gpu_id,
                "mb_granted": lease.mb_granted,
                "holder_service": lease.holder_service,
                "priority": lease.priority,
                "expires_at": lease.expires_at,
            }
        }

    @app.delete("/api/leases/{lease_id}")
    async def release_lease(lease_id: str) -> dict[str, Any]:
        released = await lease_manager.release(lease_id)
        if not released:
            raise HTTPException(status_code=404, detail=f"Lease {lease_id!r} not found")
        return {"released": True, "lease_id": lease_id}

    @app.post("/api/services/{service}/ensure")
    async def ensure_service(service: str, req: ServiceEnsureRequest) -> dict[str, Any]:
        """
        Ensure a managed service is running on the given node.

        A pre-flight check rejects the request when the target GPU's free VRAM
        is below the service's profile ceiling. Otherwise each model candidate
        is tried in order until one starts; the selected model is returned in
        the response.
        """
        import httpx

        node_info = agent_supervisor.get_node_info(req.node_id)
        if node_info is None:
            raise HTTPException(422, detail=f"Unknown node_id {req.node_id!r}")

        # Resolve candidate list — fall back to params["model"] if not specified.
        candidates: list[str] = req.model_candidates or (
            [req.params["model"]] if "model" in req.params else []
        )
        if not candidates:
            raise HTTPException(422, detail="No model specified: set params.model or model_candidates")

        # Live free VRAM on the target GPU (used for pre-flight filtering).
        gpu = next((g for g in node_info.gpus if g.gpu_id == req.gpu_id), None)
        free_mb = gpu.vram_free_mb if gpu else 0

        # The profile's max_mb for the service gives the VRAM ceiling for this slot;
        # take it from the first public profile that defines the service.
        service_max_mb = 0
        for p in profile_registry.list_public():
            svc = p.services.get(service)
            if svc:
                service_max_mb = svc.max_mb
                break

        # Pre-flight VRAM check — require free VRAM >= the service ceiling so the
        # model can actually load without competing for VRAM with other processes.
        if service_max_mb > 0 and free_mb < service_max_mb:
            raise HTTPException(
                503,
                detail=f"Insufficient VRAM on gpu {req.gpu_id}: {free_mb}MB free, need {service_max_mb}MB",
            )

        last_error: str = ""
        async with httpx.AsyncClient(timeout=120.0) as client:
            for model in candidates:
                params_with_model = {**req.params, "model": model}
                try:
                    start_resp = await client.post(
                        f"{node_info.agent_url}/services/{service}/start",
                        json={"gpu_id": req.gpu_id, "params": params_with_model},
                    )
                    if start_resp.is_success:
                        data = start_resp.json()
                        return {
                            "service": service,
                            "node_id": req.node_id,
                            "gpu_id": req.gpu_id,
                            "model": model,
                            "url": data.get("url"),
                            "running": data.get("running", False),
                        }
                    last_error = start_resp.text
                except httpx.HTTPError as exc:
                    raise HTTPException(502, detail=f"Agent unreachable: {exc}")

        raise HTTPException(
            503,
            detail=f"All model candidates exhausted for {service!r}. Last error: {last_error}",
        )

    @app.post("/api/services/{service}/allocate")
    async def allocate_service(service: str, req: ServiceAllocateRequest) -> dict[str, Any]:
        """
        Allocate a managed service — the coordinator picks the best node automatically.
        Returns a URL + allocation_id; the allocation is tracked in the service
        registry with a TTL and released via DELETE.
        """
        import httpx

        if not req.model_candidates:
            raise HTTPException(422, detail="model_candidates must be non-empty")

        # Validate the service is known in at least one profile, regardless of gpu_id.
        if not any(service in p.services for p in profile_registry.list_public()):
            raise HTTPException(422, detail=f"Unknown service {service!r} — not in any profile")

        residents = lease_manager.resident_keys()

        if req.gpu_id is None:
            online = agent_supervisor.online_agents()
            placement = select_node(online, service, profile_registry, residents)
            if placement is None:
                raise HTTPException(
                    503,
                    detail=f"No online node has capacity for service {service!r}",
                )
            node_id, gpu_id = placement
        else:
            online = agent_supervisor.online_agents()
            node_id = next(
                (nid for nid, rec in online.items()
                 if any(g.gpu_id == req.gpu_id for g in rec.gpus)),
                None,
            )
            if node_id is None:
                raise HTTPException(422, detail=f"No online node has gpu_id={req.gpu_id}")
            gpu_id = req.gpu_id

        node_info = agent_supervisor.get_node_info(node_id)
        if node_info is None:
            raise HTTPException(422, detail=f"Node {node_id!r} not found")

        warm = f"{node_id}:{service}" in residents

        async with httpx.AsyncClient(timeout=120.0) as client:
            last_error = ""
            for model in req.model_candidates:
                try:
                    resp = await client.post(
                        f"{node_info.agent_url}/services/{service}/start",
                        json={"gpu_id": gpu_id, "params": {**req.params, "model": model}},
                    )
                    if resp.is_success:
                        data = resp.json()
                        svc_url = data.get("url", "")
                        alloc = service_registry.allocate(
                            service=service,
                            node_id=node_id,
                            gpu_id=gpu_id,
                            model=model,
                            caller=req.caller,
                            url=svc_url,
                            ttl_s=req.ttl_s,
                        )
                        # Seed the instance state for first-time starts.
                        instance_state = "running" if warm else "starting"
                        service_registry.upsert_instance(
                            service=service,
                            node_id=node_id,
                            gpu_id=gpu_id,
                            state=instance_state,
                            model=model,
                            url=svc_url,
                        )
                        return {
                            "allocation_id": alloc.allocation_id,
                            "service": service,
                            "node_id": node_id,
                            "gpu_id": gpu_id,
                            "model": model,
                            "url": data.get("url"),
                            "started": not warm,
                            "warm": warm,
                        }
                    last_error = resp.text
                except httpx.HTTPError as exc:
                    raise HTTPException(502, detail=f"Agent unreachable: {exc}")

        raise HTTPException(
            503,
            detail=f"All model candidates exhausted for {service!r}. Last error: {last_error}",
        )

    @app.delete("/api/services/{service}/allocations/{allocation_id}")
    async def release_allocation(service: str, allocation_id: str) -> dict[str, Any]:
        existing = service_registry.get_allocation(allocation_id)
        if existing is None or existing.service != service:
            raise HTTPException(404, detail=f"Allocation {allocation_id!r} not found for service {service!r}")
        released = service_registry.release(allocation_id)
        if not released:
            raise HTTPException(404, detail=f"Allocation {allocation_id!r} not found")
        return {"released": True, "allocation_id": allocation_id}

    @app.get("/api/services/{service}/status")
    def get_service_status(service: str) -> dict[str, Any]:
        instances = [i for i in service_registry.all_instances() if i.service == service]
        allocations = [a for a in service_registry.all_allocations() if a.service == service]
        return {
            "service": service,
            "instances": [
                {
                    "node_id": i.node_id,
                    "gpu_id": i.gpu_id,
                    "state": i.state,
                    "model": i.model,
                    "url": i.url,
                    "idle_since": i.idle_since,
                }
                for i in instances
            ],
            "allocations": [
                {
                    "allocation_id": a.allocation_id,
                    "node_id": a.node_id,
                    "gpu_id": a.gpu_id,
                    "model": a.model,
                    "caller": a.caller,
                    "url": a.url,
                    "expires_at": a.expires_at,
                }
                for a in allocations
            ],
        }

    @app.get("/api/services")
    def list_services() -> dict[str, Any]:
        instances = service_registry.all_instances()
        return {
            "services": [
                {
                    "service": i.service,
                    "node_id": i.node_id,
                    "gpu_id": i.gpu_id,
                    "state": i.state,
                    "model": i.model,
                    "url": i.url,
                }
                for i in instances
            ]
        }

    @app.delete("/api/services/{service}")
    async def stop_service(service: str, node_id: str) -> dict[str, Any]:
        """Stop a managed service on the given node."""
        node_info = agent_supervisor.get_node_info(node_id)
        if node_info is None:
            raise HTTPException(422, detail=f"Unknown node_id {node_id!r}")

        import httpx
        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                resp = await client.post(f"{node_info.agent_url}/services/{service}/stop")
                resp.raise_for_status()
                return {"service": service, "node_id": node_id, "stopped": resp.json().get("stopped", False)}
            except httpx.HTTPError as exc:
                raise HTTPException(502, detail=f"Agent unreachable: {exc}")

    return app
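The raw HTTP allocation round-trip that CFOrchClient wraps, as a sketch (not part of this diff; the model names are illustrative):

```python
import httpx

base = "http://localhost:7700"
resp = httpx.post(
    f"{base}/api/services/vllm/allocate",
    json={"model_candidates": ["Qwen2.5-7B", "Qwen2.5-1.5B"], "ttl_s": 600.0, "caller": "demo"},
    timeout=120.0,
)
resp.raise_for_status()
alloc = resp.json()
print(alloc["url"], alloc["model"], alloc["warm"])

# Release when done; the TTL sweep would otherwise reclaim it after ttl_s.
httpx.delete(f"{base}/api/services/vllm/allocations/{alloc['allocation_id']}")
```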
@@ -1,473 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>cf-orch · dashboard</title>
|
||||
<style>
|
||||
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
|
||||
:root {
|
||||
--bg: #0d1117;
|
||||
--bg2: #161b22;
|
||||
--bg3: #1c2129;
|
||||
--border: #30363d;
|
||||
--border-dim: #21262d;
|
||||
--text: #e6edf3;
|
||||
--muted: #8b949e;
|
||||
--dim: #4d5763;
|
||||
--indigo: #818cf8;
|
||||
--cyan: #22d3ee;
|
||||
--green: #4ade80;
|
||||
--amber: #fbbf24;
|
||||
--red: #f85149;
|
||||
--orange: #fb923c;
|
||||
--radius: 6px;
|
||||
--radius-sm: 3px;
|
||||
--font: 'JetBrains Mono', 'Fira Code', ui-monospace, monospace;
|
||||
}
|
||||
|
||||
body { background: var(--bg); color: var(--text); font-family: var(--font); font-size: 13px; line-height: 1.5; padding: 1rem; }
|
||||
|
||||
/* header */
|
||||
header { display: flex; align-items: center; gap: 1rem; margin-bottom: 1rem; padding-bottom: 0.75rem; border-bottom: 1px solid var(--border); }
|
||||
.logo { color: var(--indigo); font-size: 1.1em; font-weight: 700; }
|
||||
#refresh-badge { margin-left: auto; font-size: 0.75em; color: var(--dim); }
|
||||
#refresh-badge span { color: var(--green); }
|
||||
|
||||
/* section labels */
|
||||
.section-label { font-size: 0.72em; font-weight: 600; text-transform: uppercase; letter-spacing: 0.07em; color: var(--dim); margin-bottom: 0.5rem; }
|
||||
|
||||
/* health strip */
|
||||
#health-strip { display: flex; flex-wrap: wrap; gap: 0.4rem; margin-bottom: 1rem; padding: 0.6rem 0.75rem; background: var(--bg2); border: 1px solid var(--border); border-radius: var(--radius); min-height: 36px; }
|
||||
.pill { display: inline-flex; align-items: center; gap: 0.3rem; padding: 2px 10px; border-radius: 99px; font-size: 0.8em; font-weight: 600; }
|
||||
.pill.ok { background: rgba(74,222,128,.12); color: var(--green); }
|
||||
.pill.err { background: rgba(248,81,73,.12); color: var(--red); }
|
||||
.pill.off { background: rgba(139,148,158,.1); color: var(--dim); }
|
||||
|
||||
/* GPU grid */
|
||||
#gpu-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(180px, 1fr)); gap: 0.6rem; margin-bottom: 1rem; }
|
||||
.gpu-card { background: var(--bg3); border: 1px solid var(--border); border-radius: var(--radius); padding: 0.7rem 0.8rem; }
|
||||
.gpu-card.offline { border-color: #7c2d12; opacity: 0.7; }
|
||||
.gpu-node { font-size: 0.75em; font-weight: 700; color: var(--indigo); margin-bottom: 1px; }
|
||||
.gpu-offline .gpu-node { color: var(--orange); }
|
||||
.gpu-name { font-size: 0.78em; color: var(--text); margin-bottom: 0.4rem; }
|
||||
.vram-track { position: relative; background: var(--bg); border-radius: var(--radius-sm); height: 6px; margin-bottom: 0.3rem; overflow: hidden; }
|
||||
.vram-leased { position: absolute; left: 0; top: 0; height: 100%; background: var(--cyan); transition: width 0.4s; }
|
||||
.vram-resident { position: absolute; top: 0; height: 100%; background: var(--amber); transition: left 0.4s, width 0.4s; }
|
||||
.vram-label { font-size: 0.72em; color: var(--muted); margin-bottom: 0.25rem; }
.gpu-status { font-size: 0.72em; }
.gpu-status.idle { color: var(--green); }
.gpu-status.busy { color: var(--amber); }
.gpu-status.full { color: var(--red); }
.gpu-status.offline { color: var(--orange); }
.spark-track { height: 24px; background: var(--bg); border-radius: var(--radius-sm); margin-top: 0.4rem; overflow: hidden; }

/* shared table base */
.cf-table { width: 100%; border-collapse: collapse; background: var(--bg2); border: 1px solid var(--border); border-radius: var(--radius); overflow: hidden; margin-bottom: 1rem; }
.cf-table th { background: var(--bg3); color: var(--dim); font-size: 0.72em; font-weight: 600; text-transform: uppercase; letter-spacing: 0.05em; padding: 0.4rem 0.6rem; text-align: left; border-bottom: 1px solid var(--border); }
.cf-table td { padding: 0.35rem 0.6rem; border-bottom: 1px solid var(--border-dim); font-size: 0.8em; vertical-align: middle; }
.cf-table tr:last-child td { border-bottom: none; }
.td-service { color: var(--indigo); font-weight: 600; }
.td-node { color: var(--muted); }
.td-mb { color: var(--text); }
.td-priority { color: var(--amber); }
.td-model { color: var(--cyan); font-size: 0.75em; }
.td-warm { color: var(--amber); }
.td-none { color: var(--dim); font-style: italic; }
.ttl-wrap { display: flex; align-items: center; gap: 0.5rem; }
.ttl-label { color: var(--cyan); font-variant-numeric: tabular-nums; white-space: nowrap; }
.ttl-track { flex: 1; background: var(--bg); border-radius: var(--radius-sm); height: 4px; }
.ttl-fill { height: 100%; border-radius: var(--radius-sm); background: var(--cyan); transition: width 0.4s; }

/* service state classes */
.state-running { color: #2ecc40; }
.state-idle { color: #ff851b; }
.state-stopped { color: #aaa; }
.state-starting { color: #0074d9; }
.state-unknown { color: #ff4136; }

/* error */
#error-banner { display: none; background: rgba(248,81,73,.1); border: 1px solid var(--red); border-radius: var(--radius); color: var(--red); padding: 0.5rem 0.75rem; font-size: 0.82em; margin-bottom: 1rem; }

/* footer */
footer { border-top: 1px solid var(--border); padding-top: 0.5rem; color: var(--dim); font-size: 0.72em; display: flex; gap: 1.5rem; }
footer a { color: var(--indigo); text-decoration: none; }
footer a:hover { text-decoration: underline; }
</style>
</head>
<body>

<header>
  <span class="logo">cf-orch</span>
  <span id="cluster-label" style="color:var(--muted)">coordinator</span>
  <div id="refresh-badge">auto-refresh <span id="countdown">5</span>s</div>
</header>

<div id="error-banner"></div>

<div class="section-label">Services</div>
<div id="health-strip"></div>

<div class="section-label">GPU Nodes</div>
<div id="gpu-grid"></div>

<div id="services-section">
  <div class="section-label">Service Instances</div>
  <table class="cf-table" id="services-table">
    <thead>
      <tr>
        <th>Service</th><th>Node</th><th>GPU</th><th>State</th><th>Model</th><th>URL</th>
      </tr>
    </thead>
    <tbody id="services-body"></tbody>
  </table>
</div>

<div class="section-label">Active Leases</div>
<table class="cf-table" id="leases-table">
  <thead>
    <tr>
      <th>Service</th><th>Node / GPU</th><th>VRAM</th><th>Priority</th><th>TTL / Expires</th>
    </tr>
  </thead>
  <tbody id="leases-body"></tbody>
</table>

<div class="section-label">Warm Models</div>
<table class="cf-table" id="resident-table">
  <thead>
    <tr>
      <th>Service</th><th>Node</th><th>Model</th><th>Warm Since</th>
    </tr>
  </thead>
  <tbody id="resident-body"></tbody>
</table>

<footer>
  <span>cf-orch · circuitforge-core</span>
  <a href="/api/nodes" target="_blank">/api/nodes</a>
  <a href="/api/leases" target="_blank">/api/leases</a>
  <a href="/api/resident" target="_blank">/api/resident</a>
  <a href="/api/services" target="_blank">/api/services</a>
  <a href="/api/health" target="_blank">/api/health</a>
</footer>
<script>
"use strict";

// ── helpers ──────────────────────────────────────────────────────

/** Create an element with optional className and textContent. */
function el(tag, opts) {
  const e = document.createElement(tag);
  if (opts && opts.cls) { opts.cls.split(' ').forEach(c => c && e.classList.add(c)); }
  if (opts && opts.text != null) e.textContent = opts.text;
  if (opts && opts.style) Object.assign(e.style, opts.style);
  if (opts && opts.attr) Object.entries(opts.attr).forEach(([k, v]) => e.setAttribute(k, v));
  return e;
}

/** Append children to a parent element. Returns parent. */
function append(parent, ...children) {
  children.forEach(c => c && parent.appendChild(c));
  return parent;
}

/** Replace all children of a DOM node. */
function setChildren(parent, ...children) {
  while (parent.firstChild) parent.removeChild(parent.firstChild);
  append(parent, ...children);
}

/** Build a sparkline SVG element (no innerHTML). */
function buildSparkline(history, totalMb) {
  const ns = 'http://www.w3.org/2000/svg';
  const svg = document.createElementNS(ns, 'svg');
  svg.setAttribute('width', '100%');
  svg.setAttribute('height', '16');
  svg.setAttribute('viewBox', '0 0 100 16');

  if (!history || history.length < 2) {
    const line = document.createElementNS(ns, 'line');
    line.setAttribute('x1', '0'); line.setAttribute('y1', '14');
    line.setAttribute('x2', '100'); line.setAttribute('y2', '14');
    line.setAttribute('stroke', '#30363d'); line.setAttribute('stroke-width', '1');
    svg.appendChild(line);
    return svg;
  }

  const max = Math.max(totalMb, 1);
  const pts = history.map((v, i) => {
    const x = (i / (history.length - 1)) * 100;
    const y = 14 - ((v / max) * 12);
    return x.toFixed(1) + ',' + y.toFixed(1);
  }).join(' ');

  const poly = document.createElementNS(ns, 'polyline');
  poly.setAttribute('points', pts);
  poly.setAttribute('fill', 'none');
  poly.setAttribute('stroke', '#818cf8');
  poly.setAttribute('stroke-width', '1.5');
  poly.setAttribute('stroke-linejoin', 'round');
  svg.appendChild(poly);
  return svg;
}

/** VRAM fill colour based on utilisation fraction. */
function vramColor(pct) {
  if (pct >= 0.9) return '#f85149';
  if (pct >= 0.7) return '#fbbf24';
  return '#22d3ee';
}

// ── sparkline history ────────────────────────────────────────────
// keyed "nodeId:gpuId" → array of vram_used_mb, max 20 samples
const sparkHistory = {};

// ── countdown ────────────────────────────────────────────────────
let countdown = 5;
setInterval(() => {
  countdown = countdown <= 1 ? 5 : countdown - 1;
  document.getElementById('countdown').textContent = countdown;
}, 1000);

// ── state class helper ───────────────────────────────────────────
function stateClass(state) {
  const map = { running: 'state-running', idle: 'state-idle', stopped: 'state-stopped', starting: 'state-starting' };
  return map[state] || 'state-unknown';
}

// ── render: services table ───────────────────────────────────────
function renderServices(services) {
  const tbody = document.getElementById('services-body');
  if (!services || services.length === 0) {
    const tr = document.createElement('tr');
    const td = el('td', { cls: 'td-none', text: 'No service instances registered.' });
    td.setAttribute('colspan', '6');
    tr.appendChild(td);
    setChildren(tbody, tr);
    return;
  }

  const rows = services.map(svc => {
    const tr = document.createElement('tr');
    const fields = [
      { text: svc.service, cls: 'td-service' },
      { text: svc.node_id, cls: 'td-node' },
      { text: String(svc.gpu_id), cls: 'td-mb' },
      { text: svc.state, cls: stateClass(svc.state) },
      { text: svc.model || '\u2014', cls: 'td-model' },
      { text: svc.url || '\u2014', cls: 'td-node' },
    ];
    fields.forEach(f => tr.appendChild(el('td', { cls: f.cls, text: f.text })));
    return tr;
  });

  setChildren(tbody, ...rows);
}

// ── render: health strip ─────────────────────────────────────────
function renderHealth(ok) {
  const strip = document.getElementById('health-strip');
  const pill = el('span', { cls: 'pill ' + (ok ? 'ok' : 'err'), text: (ok ? '● ' : '✕ ') + 'coordinator' });
  setChildren(strip, pill);
}

// ── render: GPU grid ─────────────────────────────────────────────
// leasedByGpu: "nodeId:gpuId" → total MB currently leased (from active leases)
function renderNodes(nodes, leasedByGpu) {
  const grid = document.getElementById('gpu-grid');
  if (!nodes || nodes.length === 0) {
    setChildren(grid, el('div', { text: 'No nodes registered.', style: { color: 'var(--dim)', fontSize: '0.8em', padding: '0.5rem' } }));
    return;
  }

  const cards = [];
  for (const node of nodes) {
    for (const gpu of node.gpus) {
      const key = node.node_id + ':' + gpu.gpu_id;
      const total = gpu.vram_total_mb || 1;
      const used = gpu.vram_used_mb;
      const leased = leasedByGpu[key] || 0;
      // Resident = nvidia-smi used minus actively leased; clamped to [0, used].
      const resident = Math.max(0, Math.min(used - leased, used));
      const pct = used / total;

      if (!sparkHistory[key]) sparkHistory[key] = [];
      sparkHistory[key].push(used);
      if (sparkHistory[key].length > 20) sparkHistory[key].shift();

      const statusCls = pct >= 0.9 ? 'full' : pct >= 0.1 ? 'busy' : 'idle';
      const statusText = pct >= 0.9 ? 'saturated' : pct >= 0.1 ? Math.round(pct * 100) + '% used' : 'idle';

      const card = el('div', { cls: 'gpu-card' });
      const nodeLabel = el('div', { cls: 'gpu-node', text: node.node_id.toUpperCase() + ' · GPU ' + gpu.gpu_id });
      const nameLine = el('div', { cls: 'gpu-name', text: gpu.name || 'Unknown GPU' });

      // Stacked bar: cyan (leased) → amber (resident) → dark bg (free).
      const leasedPct = (leased / total * 100).toFixed(1);
      const residentPct = (resident / total * 100).toFixed(1);
      const track = el('div', { cls: 'vram-track' });
      const fillLeased = el('div', { cls: 'vram-leased', style: { width: leasedPct + '%' } });
      const fillResident = el('div', { cls: 'vram-resident', style: { left: leasedPct + '%', width: residentPct + '%' } });
      append(track, fillLeased, fillResident);

      // Breakdown label when something is allocated.
      let labelText = (used / 1024).toFixed(1) + ' / ' + (total / 1024).toFixed(1) + ' GB';
      if (leased > 0 || resident > 0) {
        const parts = [];
        if (leased > 0) parts.push((leased / 1024).toFixed(1) + 'G leased');
        if (resident > 0) parts.push((resident / 1024).toFixed(1) + 'G resident');
        labelText += ' (' + parts.join(' · ') + ')';
      }

      const vramLbl = el('div', { cls: 'vram-label', text: labelText });
      const statusEl = el('div', { cls: 'gpu-status ' + statusCls, text: statusText });
      const sparkTrack = el('div', { cls: 'spark-track' });
      sparkTrack.appendChild(buildSparkline(sparkHistory[key], total));

      append(card, nodeLabel, nameLine, track, vramLbl, statusEl, sparkTrack);
      cards.push(card);
    }
  }

  setChildren(grid, ...cards);
}

// ── render: warm models table ────────────────────────────────────
function renderResidents(residents) {
  const tbody = document.getElementById('resident-body');
  if (!residents || residents.length === 0) {
    const tr = document.createElement('tr');
    const td = el('td', { cls: 'td-none', text: 'No warm models detected.' });
    td.setAttribute('colspan', '4');
    tr.appendChild(td);
    setChildren(tbody, tr);
    return;
  }

  const now = Date.now() / 1000;
  const rows = residents.map(r => {
    const warmSecs = now - (r.first_seen || now);
    const warmText = warmSecs < 60
      ? Math.floor(warmSecs) + 's'
      : warmSecs < 3600
        ? Math.floor(warmSecs / 60) + 'm ' + String(Math.floor(warmSecs % 60)).padStart(2, '0') + 's'
        : Math.floor(warmSecs / 3600) + 'h ' + String(Math.floor((warmSecs % 3600) / 60)).padStart(2, '0') + 'm';

    const tr = document.createElement('tr');
    append(tr,
      el('td', { cls: 'td-service', text: r.service }),
      el('td', { cls: 'td-node', text: r.node_id }),
      el('td', { cls: 'td-model', text: r.model_name || '—' }),
      el('td', { cls: 'td-warm', text: warmText }),
    );
    return tr;
  });

  setChildren(tbody, ...rows);
}

// ── render: leases table ─────────────────────────────────────────
function renderLeases(leases) {
  const tbody = document.getElementById('leases-body');
  if (!leases || leases.length === 0) {
    const tr = document.createElement('tr');
    const td = el('td', { cls: 'td-none', text: 'No active leases.' });
    td.setAttribute('colspan', '5');
    tr.appendChild(td);
    setChildren(tbody, tr);
    return;
  }

  const now = Date.now() / 1000;
  const rows = leases.map(lease => {
    const mbGb = lease.mb_granted >= 1024
      ? (lease.mb_granted / 1024).toFixed(1) + ' GB'
      : lease.mb_granted + ' MB';

    const tr = document.createElement('tr');

    const tdService = el('td', { cls: 'td-service', text: lease.holder_service });
    const tdNode = el('td', { cls: 'td-node', text: lease.node_id + ' / GPU ' + lease.gpu_id });
    const tdMb = el('td', { cls: 'td-mb', text: mbGb });
    const tdPriority = el('td', { cls: 'td-priority', text: 'p' + lease.priority });

    const tdTtl = document.createElement('td');
    if (!lease.expires_at) {
      tdTtl.appendChild(el('span', { cls: 'ttl-label', text: '∞' }));
    } else {
      const remaining = Math.max(0, lease.expires_at - now);
      // Bar scale: 300 s of remaining TTL counts as a full bar — purely visual.
      const pct = Math.min(100, (remaining / 300) * 100);
      const mins = Math.floor(remaining / 60);
      const secs = Math.floor(remaining % 60);
      const label = remaining > 60
        ? mins + 'm ' + String(secs).padStart(2, '0') + 's'
        : Math.floor(remaining) + 's';

      const wrap = el('div', { cls: 'ttl-wrap' });
      const lbl = el('span', { cls: 'ttl-label', text: label });
      const track = el('div', { cls: 'ttl-track' });
      const fill = el('div', { cls: 'ttl-fill', style: { width: pct.toFixed(1) + '%' } });
      track.appendChild(fill);
      append(wrap, lbl, track);
      tdTtl.appendChild(wrap);
    }

    append(tr, tdService, tdNode, tdMb, tdPriority, tdTtl);
    return tr;
  });

  setChildren(tbody, ...rows);
}

// ── error banner ─────────────────────────────────────────────────
function showError(msg) {
  const banner = document.getElementById('error-banner'); // named to avoid shadowing the el() helper
  banner.textContent = msg; // textContent — safe
  banner.style.display = 'block';
}
function clearError() { document.getElementById('error-banner').style.display = 'none'; }

// ── poll ─────────────────────────────────────────────────────────
async function poll() {
  try {
    const [nodesRes, leasesRes, residentRes, healthRes, servicesRes] = await Promise.all([
      fetch('/api/nodes'),
      fetch('/api/leases'),
      fetch('/api/resident'),
      fetch('/api/health'),
      fetch('/api/services'),
    ]);
    if (!nodesRes.ok || !leasesRes.ok) {
      throw new Error('API error: nodes=' + nodesRes.status + ' leases=' + leasesRes.status);
    }
    const [nodesData, leasesData, residentData, servicesData] = await Promise.all([
      nodesRes.json(), leasesRes.json(),
      residentRes.ok ? residentRes.json() : Promise.resolve({ residents: [] }),
      servicesRes.ok ? servicesRes.json() : Promise.resolve({ services: [] }),
    ]);

    // Build per-GPU leased-MB index for the stacked bar.
    const leasedByGpu = {};
    for (const lease of (leasesData.leases || [])) {
      const key = lease.node_id + ':' + lease.gpu_id;
      leasedByGpu[key] = (leasedByGpu[key] || 0) + lease.mb_granted;
    }

    clearError();
    renderHealth(healthRes.ok);
    renderNodes(nodesData.nodes || [], leasedByGpu);
    renderServices(servicesData.services || []);
    renderLeases(leasesData.leases || []);
    renderResidents(residentData.residents || []);
  } catch (err) {
    showError('Failed to reach coordinator: ' + err.message);
    renderHealth(false);
  }
}

poll();
setInterval(poll, 5000);
</script>
</body>
</html>
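
The per-GPU leased-MB rollup the dashboard computes in the browser can also be reproduced against the same API from Python; a minimal sketch (the coordinator base URL is illustrative, not defined in this changeset):

import requests

COORD = "http://localhost:9000"  # coordinator base URL — illustrative

leases = requests.get(f"{COORD}/api/leases", timeout=10).json().get("leases", [])
leased_by_gpu: dict[str, int] = {}
for lease in leases:
    # Same "node_id:gpu_id" key the dashboard uses for the stacked VRAM bar.
    key = f"{lease['node_id']}:{lease['gpu_id']}"
    leased_by_gpu[key] = leased_by_gpu.get(key, 0) + lease["mb_granted"]
print(leased_by_gpu)  # e.g. {"gpu-node-1:0": 13096}
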
@ -1,81 +0,0 @@
from __future__ import annotations

import asyncio
import logging

from circuitforge_core.resources.coordinator.lease_manager import LeaseManager
from circuitforge_core.resources.models import VRAMLease

logger = logging.getLogger(__name__)

_DEFAULT_EVICTION_TIMEOUT_S = 10.0


class EvictionEngine:
    def __init__(
        self,
        lease_manager: LeaseManager,
        eviction_timeout_s: float = _DEFAULT_EVICTION_TIMEOUT_S,
    ) -> None:
        self.lease_manager = lease_manager
        self._timeout = eviction_timeout_s

    async def request_lease(
        self,
        node_id: str,
        gpu_id: int,
        mb: int,
        service: str,
        priority: int,
        agent_url: str,
        ttl_s: float = 0.0,
    ) -> VRAMLease | None:
        # Fast path: enough free VRAM
        lease = await self.lease_manager.try_grant(
            node_id, gpu_id, mb, service, priority, ttl_s
        )
        if lease is not None:
            return lease

        # Find eviction candidates
        candidates = self.lease_manager.get_eviction_candidates(
            node_id=node_id, gpu_id=gpu_id,
            needed_mb=mb, requester_priority=priority,
        )
        if not candidates:
            logger.info(
                "No eviction candidates for %s on %s:GPU%d (%dMB needed)",
                service, node_id, gpu_id, mb,
            )
            return None

        # Evict candidates
        freed_mb = sum(c.mb_granted for c in candidates)
        logger.info(
            "Evicting %d lease(s) to free %dMB for %s",
            len(candidates), freed_mb, service,
        )
        for candidate in candidates:
            await self._evict_lease(candidate, agent_url)

        # Wait for evictions to free up VRAM (poll with timeout)
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._timeout
        while loop.time() < deadline:
            lease = await self.lease_manager.try_grant(
                node_id, gpu_id, mb, service, priority, ttl_s
            )
            if lease is not None:
                return lease
            await asyncio.sleep(0.1)

        logger.warning("Eviction timed out for %s after %.1fs", service, self._timeout)
        return None

    async def _evict_lease(self, lease: VRAMLease, agent_url: str) -> None:
        """Release lease accounting. Process-level eviction deferred to Plan B."""
        await self.lease_manager.release(lease.lease_id)

    async def _call_agent_evict(self, agent_url: str, lease: VRAMLease) -> bool:
        """POST /evict to the agent. Stub for v1 — real process lookup in Plan B."""
        return True
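
A minimal sketch of the grant → evict → retry path above, assuming a registered 24 GB GPU; node IDs, sizes, and the agent URL are illustrative, and the EvictionEngine module path is assumed:

import asyncio

from circuitforge_core.resources.coordinator.eviction_engine import EvictionEngine  # module path assumed
from circuitforge_core.resources.coordinator.lease_manager import LeaseManager

async def demo() -> None:
    lm = LeaseManager()
    lm.register_gpu("gpu-node-1", 0, total_mb=24576)  # 24 GB card — illustrative

    # Low-priority bulk job takes most of the card (p4 = evictable).
    bulk = await lm.try_grant("gpu-node-1", 0, 20000, "comfyui", priority=4)
    assert bulk is not None

    # The p1 request does not fit, so the engine evicts the p4 lease and
    # retries until the accounting shows enough free VRAM.
    engine = EvictionEngine(lm)
    lease = await engine.request_lease(
        "gpu-node-1", 0, mb=9000, service="vllm",
        priority=1, agent_url="http://gpu-node-1:9100",  # agent URL illustrative
    )
    print(lease)

asyncio.run(demo())
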
@ -1,130 +0,0 @@
from __future__ import annotations

import asyncio
from collections import defaultdict

from circuitforge_core.resources.models import ResidentAllocation, VRAMLease


class LeaseManager:
    def __init__(self) -> None:
        self._leases: dict[str, VRAMLease] = {}
        self._gpu_total: dict[tuple[str, int], int] = {}
        self._gpu_used: dict[tuple[str, int], int] = defaultdict(int)
        self._lock = asyncio.Lock()
        # Resident allocations — keyed "node_id:service", updated by heartbeat.
        # No lock needed: only the single heartbeat task writes this dict.
        self._residents: dict[str, ResidentAllocation] = {}

    def register_gpu(self, node_id: str, gpu_id: int, total_mb: int) -> None:
        self._gpu_total[(node_id, gpu_id)] = total_mb

    def gpu_total_mb(self, node_id: str, gpu_id: int) -> int:
        return self._gpu_total.get((node_id, gpu_id), 0)

    def used_mb(self, node_id: str, gpu_id: int) -> int:
        return self._gpu_used[(node_id, gpu_id)]

    async def try_grant(
        self,
        node_id: str,
        gpu_id: int,
        mb: int,
        service: str,
        priority: int,
        ttl_s: float = 0.0,
    ) -> VRAMLease | None:
        async with self._lock:
            total = self._gpu_total.get((node_id, gpu_id), 0)
            used = self._gpu_used[(node_id, gpu_id)]
            if total - used < mb:
                return None
            lease = VRAMLease.create(
                gpu_id=gpu_id, node_id=node_id, mb=mb,
                service=service, priority=priority, ttl_s=ttl_s,
            )
            self._leases[lease.lease_id] = lease
            self._gpu_used[(node_id, gpu_id)] += mb
            return lease

    async def release(self, lease_id: str) -> bool:
        async with self._lock:
            lease = self._leases.pop(lease_id, None)
            if lease is None:
                return False
            self._gpu_used[(lease.node_id, lease.gpu_id)] -= lease.mb_granted
            return True

    def get_eviction_candidates(
        self,
        node_id: str,
        gpu_id: int,
        needed_mb: int,
        requester_priority: int,
    ) -> list[VRAMLease]:
        candidates = [
            lease for lease in self._leases.values()
            if lease.node_id == node_id
            and lease.gpu_id == gpu_id
            and lease.priority > requester_priority
        ]
        candidates.sort(key=lambda lease: lease.priority, reverse=True)
        selected: list[VRAMLease] = []
        freed = 0
        for candidate in candidates:
            selected.append(candidate)
            freed += candidate.mb_granted
            if freed >= needed_mb:
                break
        return selected

    def list_leases(
        self, node_id: str | None = None, gpu_id: int | None = None
    ) -> list[VRAMLease]:
        return [
            lease for lease in self._leases.values()
            if (node_id is None or lease.node_id == node_id)
            and (gpu_id is None or lease.gpu_id == gpu_id)
        ]

    def all_leases(self) -> list[VRAMLease]:
        return list(self._leases.values())

    # ── resident tracking ────────────────────────────────────────────

    def set_residents_for_node(
        self,
        node_id: str,
        residents: list[tuple[str, str | None]],  # (service, model_name)
    ) -> None:
        """
        Replace the resident snapshot for a node.

        Preserves first_seen for entries whose service+model_name are unchanged,
        so the dashboard can show how long a model has been warm.
        """
        new_keys = {f"{node_id}:{service}" for service, _ in residents}

        # Remove stale entries (service no longer running on this node).
        for key in list(self._residents):
            if key.startswith(f"{node_id}:") and key not in new_keys:
                del self._residents[key]

        # Upsert: preserve first_seen when model is unchanged, reset otherwise.
        for service, model_name in residents:
            key = f"{node_id}:{service}"
            existing = self._residents.get(key)
            if existing is not None and existing.model_name == model_name:
                continue  # same model still loaded — keep original first_seen
            self._residents[key] = ResidentAllocation(
                service=service,
                node_id=node_id,
                model_name=model_name,
            )

    def all_residents(self) -> list[ResidentAllocation]:
        return list(self._residents.values())

    def resident_keys(self) -> set[str]:
        """Return set of 'node_id:service' strings for currently-warm services."""
        return set(self._residents.keys())
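
The first_seen-preserving upsert described in the docstring above can be exercised directly; a sketch (node and model names illustrative):

from circuitforge_core.resources.coordinator.lease_manager import LeaseManager

lm = LeaseManager()

# Heartbeat 1: vllm is warm with model A.
lm.set_residents_for_node("gpu-node-1", [("vllm", "qwen2.5-7b")])
t0 = lm.all_residents()[0].first_seen

# Heartbeat 2: same model still loaded → first_seen is preserved.
lm.set_residents_for_node("gpu-node-1", [("vllm", "qwen2.5-7b")])
assert lm.all_residents()[0].first_seen == t0

# Heartbeat 3: model swapped → entry is replaced and the warm clock resets.
lm.set_residents_for_node("gpu-node-1", [("vllm", "llama-3-8b")])
assert lm.all_residents()[0].first_seen >= t0
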
@ -1,74 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from circuitforge_core.resources.coordinator.agent_supervisor import AgentRecord
    from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry

_WARM_BONUS_MB = 1000


@dataclass(frozen=True)
class _Scored:
    node_id: str
    gpu_id: int
    vram_free_mb: int
    effective_free_mb: int
    can_fit: bool
    warm: bool


def select_node(
    agents: "dict[str, AgentRecord]",
    service: str,
    profile_registry: "ProfileRegistry",
    resident_keys: set[str],
) -> tuple[str, int] | None:
    """
    Pick the best (node_id, gpu_id) for the requested service.
    Warm nodes (service already running) get priority, then sorted by free VRAM.
    Returns None if no suitable node exists.
    """
    service_max_mb = _find_service_max_mb(service, profile_registry)
    if service_max_mb is None:
        return None  # service not in any profile

    candidates: list[_Scored] = []
    for node_id, record in agents.items():
        if not record.online:
            continue
        for gpu in record.gpus:
            warm = f"{node_id}:{service}" in resident_keys
            effective = gpu.vram_free_mb + (_WARM_BONUS_MB if warm else 0)
            can_fit = gpu.vram_free_mb >= service_max_mb
            candidates.append(_Scored(
                node_id=node_id,
                gpu_id=gpu.gpu_id,
                vram_free_mb=gpu.vram_free_mb,
                effective_free_mb=effective,
                can_fit=can_fit,
                warm=warm,
            ))
    if not candidates:
        return None
    # Prefer: (1) warm nodes (model already resident — no cold start)
    #         (2) cold nodes that can fit the service (free >= max_mb)
    # Fallback: best-effort node when nothing fits and nothing is warm
    # (coordinator will attempt to start the service anyway; it may evict or fail)
    # Note: resident_keys are per-node, not per-GPU. On multi-GPU nodes, the warm
    # bonus applies to all GPUs on the node. This is a known coarseness —
    # per-GPU resident tracking requires a resident_key format change.
    preferred = [c for c in candidates if c.warm or c.can_fit]
    pool = preferred if preferred else candidates
    best = max(pool, key=lambda c: (c.warm, c.effective_free_mb))
    return best.node_id, best.gpu_id


def _find_service_max_mb(service: str, profile_registry: "ProfileRegistry") -> int | None:
    for profile in profile_registry.list_public():
        svc = profile.services.get(service)
        if svc is not None:
            return svc.max_mb
    return None
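
The preference rule above reduces to the two-level sort key (warm, effective_free_mb); a sketch using the module's own _Scored records (module path assumed, values illustrative):

from circuitforge_core.resources.coordinator.node_select import _Scored  # module path assumed

candidates = [
    _Scored(node_id="node-a", gpu_id=0, vram_free_mb=4000,
            effective_free_mb=5000, can_fit=False, warm=True),
    _Scored(node_id="node-b", gpu_id=0, vram_free_mb=12000,
            effective_free_mb=12000, can_fit=True, warm=False),
]
preferred = [c for c in candidates if c.warm or c.can_fit]
pool = preferred if preferred else candidates
best = max(pool, key=lambda c: (c.warm, c.effective_free_mb))
print(best.node_id)  # node-a — warm beats a cold node with more free VRAM
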
@ -1,65 +0,0 @@
# circuitforge_core/resources/coordinator/profile_registry.py
from __future__ import annotations

import logging
from pathlib import Path

from circuitforge_core.resources.models import GpuInfo
from circuitforge_core.resources.profiles.schema import GpuProfile, load_profile

_PUBLIC_DIR = Path(__file__).parent.parent / "profiles" / "public"

# VRAM thresholds for public profile selection (MB)
_PROFILE_THRESHOLDS = [
    (22000, "single-gpu-24gb"),
    (14000, "single-gpu-16gb"),
    (8000, "single-gpu-8gb"),
    (5500, "single-gpu-6gb"),
    (3500, "single-gpu-4gb"),
    (0, "single-gpu-2gb"),
]

_log = logging.getLogger(__name__)


class ProfileRegistry:
    def __init__(self, extra_dirs: list[Path] | None = None) -> None:
        self._profiles: dict[str, GpuProfile] = {}
        self._load_dir(_PUBLIC_DIR)
        for d in (extra_dirs or []):
            if d.exists():
                self._load_dir(d)

    def _load_dir(self, directory: Path) -> None:
        for yaml_file in directory.glob("*.yaml"):
            try:
                profile = load_profile(yaml_file)
                self._profiles[profile.name] = profile
            except Exception as exc:
                _log.warning("Skipping %s: %s", yaml_file, exc)

    def load(self, path: Path) -> GpuProfile:
        profile = load_profile(path)
        self._profiles[profile.name] = profile
        return profile

    def list_public(self) -> list[GpuProfile]:
        # CPU profiles (cpu-*) are intentionally excluded — this endpoint
        # is used to match GPU hardware. CPU inference nodes self-select
        # their profile via the CLI and are not listed for lease matching.
        return [
            p for p in self._profiles.values()
            if p.name.startswith("single-gpu-")
        ]

    def get(self, name: str) -> GpuProfile | None:
        return self._profiles.get(name)

    def auto_detect(self, gpus: list[GpuInfo]) -> GpuProfile:
        primary_vram = gpus[0].vram_total_mb if gpus else 0
        for threshold_mb, profile_name in _PROFILE_THRESHOLDS:
            if primary_vram >= threshold_mb:
                profile = self._profiles.get(profile_name)
                if profile:
                    return profile
        return self._profiles["single-gpu-2gb"]
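
The threshold walk in auto_detect picks the first entry at or below the primary GPU's VRAM; a sketch (GpuInfo values illustrative):

from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
from circuitforge_core.resources.models import GpuInfo

reg = ProfileRegistry()
gpu = GpuInfo(gpu_id=0, name="RTX 4080", vram_total_mb=16384,
              vram_used_mb=0, vram_free_mb=16384)
profile = reg.auto_detect([gpu])
print(profile.name)  # single-gpu-16gb — first threshold <= 16384 is 14000
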
@ -1,170 +0,0 @@
from __future__ import annotations

import dataclasses
import time
import uuid
from dataclasses import dataclass
from typing import Literal


@dataclass
class ServiceAllocation:
    allocation_id: str
    service: str
    node_id: str
    gpu_id: int
    model: str | None
    caller: str
    url: str
    created_at: float
    expires_at: float  # 0 = no expiry


@dataclass
class ServiceInstance:
    service: str
    node_id: str
    gpu_id: int
    state: Literal["starting", "running", "idle", "stopped"]
    model: str | None
    url: str | None
    idle_since: float | None = None


class ServiceRegistry:
    """
    In-memory registry of service allocations and instance state.

    Allocations: per-caller request — many per service instance.
    Instances: per (service, node_id, gpu_id) — one per running container.
    """

    def __init__(self) -> None:
        self._allocations: dict[str, ServiceAllocation] = {}
        self._instances: dict[str, ServiceInstance] = {}  # key: "service:node_id:gpu_id"

    # ── allocation API ────────────────────────────────────────────────

    def allocate(
        self,
        service: str,
        node_id: str,
        gpu_id: int,
        model: str | None,
        url: str,
        caller: str,
        ttl_s: float,
    ) -> ServiceAllocation:
        alloc = ServiceAllocation(
            allocation_id=str(uuid.uuid4()),
            service=service,
            node_id=node_id,
            gpu_id=gpu_id,
            model=model,
            caller=caller,
            url=url,
            created_at=time.time(),
            expires_at=time.time() + ttl_s if ttl_s > 0 else 0.0,
        )
        self._allocations[alloc.allocation_id] = alloc

        # If an instance exists in idle/stopped state, mark it running again
        key = f"{service}:{node_id}:{gpu_id}"
        if key in self._instances:
            inst = self._instances[key]
            if inst.state in ("idle", "stopped"):
                self._instances[key] = dataclasses.replace(
                    inst, state="running", idle_since=None
                )
        return alloc

    def release(self, allocation_id: str) -> bool:
        alloc = self._allocations.pop(allocation_id, None)
        if alloc is None:
            return False
        # If no active allocations remain for this instance, mark it idle
        key = f"{alloc.service}:{alloc.node_id}:{alloc.gpu_id}"
        if self.active_allocations(alloc.service, alloc.node_id, alloc.gpu_id) == 0:
            if key in self._instances:
                self._instances[key] = dataclasses.replace(
                    self._instances[key], state="idle", idle_since=time.time()
                )
        return True

    def active_allocations(self, service: str, node_id: str, gpu_id: int) -> int:
        return sum(
            1 for a in self._allocations.values()
            if a.service == service and a.node_id == node_id and a.gpu_id == gpu_id
        )

    # ── instance API ─────────────────────────────────────────────────

    def upsert_instance(
        self,
        service: str,
        node_id: str,
        gpu_id: int,
        state: Literal["starting", "running", "idle", "stopped"],
        model: str | None,
        url: str | None,
    ) -> ServiceInstance:
        key = f"{service}:{node_id}:{gpu_id}"
        existing = self._instances.get(key)
        idle_since: float | None = None
        if state == "idle":
            # Preserve idle_since if already idle; set now if transitioning into idle
            idle_since = existing.idle_since if (existing and existing.state == "idle") else time.time()
        inst = ServiceInstance(
            service=service, node_id=node_id, gpu_id=gpu_id,
            state=state, model=model, url=url, idle_since=idle_since,
        )
        self._instances[key] = inst
        return inst

    def get_allocation(self, allocation_id: str) -> ServiceAllocation | None:
        return self._allocations.get(allocation_id)

    def sweep_expired_allocations(self) -> list[str]:
        """
        Remove all allocations whose TTL has elapsed and transition the
        corresponding instance to 'idle' if no active allocations remain.
        Returns the list of expired allocation_ids.
        """
        now = time.time()
        expired = [
            alloc_id
            for alloc_id, alloc in self._allocations.items()
            if alloc.expires_at > 0 and now > alloc.expires_at
        ]
        for alloc_id in expired:
            self.release(alloc_id)
        return expired

    def all_allocations(self) -> list[ServiceAllocation]:
        return list(self._allocations.values())

    def all_instances(self) -> list[ServiceInstance]:
        return list(self._instances.values())

    def mark_stopped(self, service: str, node_id: str, gpu_id: int) -> None:
        """Transition an instance to 'stopped' state and clear idle_since."""
        key = f"{service}:{node_id}:{gpu_id}"
        if key in self._instances:
            self._instances[key] = dataclasses.replace(
                self._instances[key], state="stopped", idle_since=None
            )

    def idle_past_timeout(self, idle_stop_config: dict[str, int]) -> list[ServiceInstance]:
        """
        Return instances in 'idle' state whose idle time exceeds their configured timeout.
        idle_stop_config: {service_name: seconds} — 0 means never stop automatically.
        """
        now = time.time()
        result = []
        for inst in self._instances.values():
            if inst.state != "idle" or inst.idle_since is None:
                continue
            timeout = idle_stop_config.get(inst.service, 0)
            if timeout > 0 and (now - inst.idle_since) >= timeout:
                result.append(inst)
        return result
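
The allocation/instance lifecycle from the docstring (many allocations per instance; last release flips the instance to idle), sketched with illustrative URLs and TTLs (module path assumed):

from circuitforge_core.resources.coordinator.service_registry import ServiceRegistry  # module path assumed

reg = ServiceRegistry()
reg.upsert_instance("vllm", "gpu-node-1", 0, state="running",
                    model="qwen2.5-7b", url="http://gpu-node-1:8000")

a = reg.allocate("vllm", "gpu-node-1", 0, model="qwen2.5-7b",
                 url="http://gpu-node-1:8000", caller="cf-chat", ttl_s=120.0)

reg.release(a.allocation_id)        # last allocation gone → instance flips to idle
inst = reg.all_instances()[0]
print(inst.state, inst.idle_since)  # idle <timestamp>
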
@ -1,250 +0,0 @@
"""
cf-docuvision — managed document understanding service.

Wraps ByteDance/Dolphin-v2 (Qwen2.5-VL backbone) behind a simple HTTP API.
Managed by cf-orch; started/stopped as a ProcessSpec service.

API
---
GET  /health  → {"status": "ok", "model": "<path>"}
POST /extract → ExtractResponse

Usage (standalone)::

    python -m circuitforge_core.resources.docuvision.app \\
        --model /Library/Assets/LLM/docuvision/models/dolphin-v2 \\
        --port 8003 --gpu-id 0
"""
from __future__ import annotations

import argparse
import base64
import io
import json
import logging
from contextlib import asynccontextmanager
from typing import Any

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Module-level state — populated by _load_model() on first /extract call
_model: Any = None
_processor: Any = None
_model_path: str = ""
_device: str = "cpu"


# ── lazy loader ───────────────────────────────────────────────────────────────

def _load_model() -> None:
    """Lazy-load Dolphin-v2. Called once on first /extract request."""
    global _model, _processor, _device

    if _model is not None:
        return

    import torch
    from transformers import AutoProcessor, AutoModelForCausalLM

    logger.info("Loading Dolphin-v2 from %s ...", _model_path)
    _device = "cuda" if torch.cuda.is_available() else "cpu"

    _processor = AutoProcessor.from_pretrained(
        _model_path,
        trust_remote_code=True,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        _model_path,
        trust_remote_code=True,
        torch_dtype=torch.float16 if _device == "cuda" else torch.float32,
        device_map=_device,
    )
    _model.eval()
    logger.info("Dolphin-v2 loaded on %s", _device)


# ── FastAPI app ───────────────────────────────────────────────────────────────

@asynccontextmanager
async def _lifespan(app: FastAPI):
    yield


app = FastAPI(title="cf-docuvision", lifespan=_lifespan)


# ── request / response models ─────────────────────────────────────────────────

class ExtractRequest(BaseModel):
    """
    Either image_b64 (base64-encoded bytes) or image_path (absolute path) must
    be provided. hint guides the extraction mode:
      - "auto"  - Dolphin-v2 detects layout and element types automatically
      - "table" - optimise for tabular data (receipts, invoices, forms)
      - "text"  - optimise for dense prose (contracts, letters)
      - "form"  - optimise for form field extraction
    """
    image_b64: str | None = None
    image_path: str | None = None
    hint: str = "auto"


class ElementOut(BaseModel):
    type: str  # heading | paragraph | list | table | figure | formula | code
    text: str
    bbox: list[float] | None = None  # [x0, y0, x1, y1] normalised 0-1 if available


class TableOut(BaseModel):
    html: str
    bbox: list[float] | None = None


class ExtractResponse(BaseModel):
    elements: list[ElementOut]
    raw_text: str
    tables: list[TableOut]
    metadata: dict[str, Any]


# ── helpers ───────────────────────────────────────────────────────────────────

_HINT_PROMPTS: dict[str, str] = {
    "auto": "Parse this document. Extract all elements with their types and text content.",
    "table": "Extract all tables from this document as structured HTML. Also extract any line-item text.",
    "text": "Extract all text from this document preserving paragraph and heading structure.",
    "form": "Extract all form fields from this document. Return field labels and their values.",
}


def _image_from_request(req: ExtractRequest):
    """Return a PIL Image from either image_b64 or image_path."""
    from PIL import Image

    if req.image_b64:
        img_bytes = base64.b64decode(req.image_b64)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")

    if req.image_path:
        from pathlib import Path
        p = Path(req.image_path)
        if not p.exists():
            raise HTTPException(status_code=404, detail=f"image_path not found: {req.image_path}")
        return Image.open(p).convert("RGB")

    raise HTTPException(status_code=422, detail="Either image_b64 or image_path must be provided")


def _parse_dolphin_output(raw: str) -> tuple[list[ElementOut], list[TableOut], str]:
    """
    Parse Dolphin-v2's structured output into elements and tables.

    Dolphin-v2 returns a JSON array of element dicts with keys:
        type, text, [html], [bbox]

    Falls back gracefully if the model returns plain text instead.
    """
    elements: list[ElementOut] = []
    tables: list[TableOut] = []

    # Try JSON parse first
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            for item in parsed:
                etype = item.get("type", "paragraph")
                text = item.get("text", "")
                bbox = item.get("bbox")
                if etype == "table":
                    tables.append(TableOut(html=item.get("html", text), bbox=bbox))
                elements.append(ElementOut(type=etype, text=text, bbox=bbox))
            raw_text = "\n".join(e.text for e in elements)
            return elements, tables, raw_text
    except (json.JSONDecodeError, TypeError):
        pass

    # Plain-text fallback: treat entire output as a single paragraph
    elements = [ElementOut(type="paragraph", text=raw.strip())]
    return elements, tables, raw.strip()


# ── routes ────────────────────────────────────────────────────────────────────

@app.get("/health")
async def health() -> dict[str, str]:
    return {"status": "ok", "model": _model_path}


@app.post("/extract", response_model=ExtractResponse)
async def extract(req: ExtractRequest) -> ExtractResponse:
    _load_model()

    image = _image_from_request(req)
    prompt = _HINT_PROMPTS.get(req.hint, _HINT_PROMPTS["auto"])

    import torch

    inputs = _processor(
        text=prompt,
        images=image,
        return_tensors="pt",
    ).to(_device)

    with torch.no_grad():
        output_ids = _model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
        )

    # Decode only the newly generated tokens
    input_len = inputs["input_ids"].shape[1]
    raw_output = _processor.decode(
        output_ids[0][input_len:],
        skip_special_tokens=True,
    )

    elements, tables, raw_text = _parse_dolphin_output(raw_output)

    w, h = image.size

    return ExtractResponse(
        elements=elements,
        raw_text=raw_text,
        tables=tables,
        metadata={
            "hint": req.hint,
            "width": w,
            "height": h,
            "model": _model_path,
            "device": _device,
        },
    )


# ── CLI entry point ───────────────────────────────────────────────────────────

def main() -> None:
    parser = argparse.ArgumentParser(description="cf-docuvision service")
    parser.add_argument("--model", required=True, help="Path to Dolphin-v2 model directory")
    parser.add_argument("--port", type=int, default=8003)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    args = parser.parse_args()

    global _model_path
    _model_path = args.model

    import os
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
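
A client sketch for the /extract endpoint documented above; the port follows the module's default and the image path is illustrative:

import base64

import requests

with open("invoice.png", "rb") as f:  # image path illustrative
    payload = {"image_b64": base64.b64encode(f.read()).decode("ascii"), "hint": "table"}

resp = requests.post("http://localhost:8003/extract", json=payload, timeout=300)
resp.raise_for_status()
doc = resp.json()
for table in doc["tables"]:
    print(table["html"])
print(doc["metadata"])  # hint, width, height, model, device
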
@ -1,137 +0,0 @@
"""Generic OpenAI-compatible inference server for HuggingFace causal LMs."""
from __future__ import annotations

import argparse
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

_model: Any = None
_tokenizer: Any = None
_model_id: str = ""
_device: str = "cpu"


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield


app = FastAPI(lifespan=lifespan)


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str | None = None
    messages: list[Message]
    max_tokens: int | None = 512
    temperature: float | None = 0.7
    stream: bool | None = False


@app.get("/health")
def health() -> dict[str, str]:
    return {"status": "ok", "model": _model_id}


@app.get("/v1/models")
def list_models() -> dict[str, Any]:
    return {
        "object": "list",
        "data": [{"id": _model_id, "object": "model", "owned_by": "cf-orch"}],
    }


@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest) -> dict[str, Any]:
    if _model is None:
        raise HTTPException(503, detail="Model not loaded")
    if req.stream:
        raise HTTPException(501, detail="Streaming not supported")

    conversation = [{"role": m.role, "content": m.content} for m in req.messages]
    try:
        encoded = _tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt",
            add_generation_prompt=True,
        )
        # transformers 5.x returns BatchEncoding; 4.x returned a bare tensor
        input_ids = (encoded.input_ids if hasattr(encoded, "input_ids") else encoded).to(_device)
    except Exception as exc:
        raise HTTPException(500, detail=f"Tokenisation failed: {exc}") from exc

    max_new = req.max_tokens or 512
    temp = req.temperature if req.temperature is not None else 0.7
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": max_new,
        "do_sample": temp > 0,
        "pad_token_id": _tokenizer.eos_token_id,
    }
    if temp > 0:
        gen_kwargs["temperature"] = temp

    with torch.inference_mode():
        output_ids = _model.generate(input_ids, **gen_kwargs)

    new_tokens = output_ids[0][input_ids.shape[-1]:]
    reply = _tokenizer.decode(new_tokens, skip_special_tokens=True)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": _model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": reply},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": input_ids.shape[-1],
            "completion_tokens": len(new_tokens),
            "total_tokens": input_ids.shape[-1] + len(new_tokens),
        },
    }


def _load_model(model_path: str, gpu_id: int) -> None:
    global _model, _tokenizer, _model_id, _device
    _device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    _model_id = model_path
    _tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    _model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=torch.float16 if "cuda" in _device else torch.float32,
        device_map={"": _device},
        trust_remote_code=True,
    )
    _model.eval()


def main() -> None:
    parser = argparse.ArgumentParser(description="cf-orch generic LLM inference server")
    parser.add_argument("--model", required=True)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    args = parser.parse_args()
    _load_model(args.model, args.gpu_id)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()
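
Because the server speaks the OpenAI chat shape, a plain POST is enough to exercise it; a sketch (host/port follow the module's defaults):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
        "temperature": 0.0,  # 0 → do_sample=False, greedy decoding
    },
    timeout=120,
)
resp.raise_for_status()
body = resp.json()
print(body["choices"][0]["message"]["content"])
print(body["usage"])  # prompt_tokens / completion_tokens / total_tokens
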
@ -1,66 +0,0 @@
from __future__ import annotations

import time
import uuid
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class VRAMLease:
    lease_id: str
    gpu_id: int
    node_id: str
    mb_granted: int
    holder_service: str
    priority: int
    expires_at: float  # unix timestamp; 0.0 = no expiry

    @classmethod
    def create(
        cls,
        gpu_id: int,
        node_id: str,
        mb: int,
        service: str,
        priority: int,
        ttl_s: float = 0.0,
    ) -> VRAMLease:
        return cls(
            lease_id=str(uuid.uuid4()),
            gpu_id=gpu_id,
            node_id=node_id,
            mb_granted=mb,
            holder_service=service,
            priority=priority,
            expires_at=time.time() + ttl_s if ttl_s > 0.0 else 0.0,
        )

    def is_expired(self) -> bool:
        return self.expires_at > 0.0 and time.time() > self.expires_at


@dataclass(frozen=True)
class GpuInfo:
    gpu_id: int
    name: str
    vram_total_mb: int
    vram_used_mb: int
    vram_free_mb: int


@dataclass(frozen=True)
class ResidentAllocation:
    """A model that is loaded and warm in VRAM but not actively serving a request."""
    service: str
    node_id: str
    model_name: Optional[str]  # None if service is running but model probe failed
    first_seen: float = field(default_factory=time.time)


@dataclass
class NodeInfo:
    node_id: str
    agent_url: str
    gpus: list[GpuInfo]
    last_heartbeat: float = field(default_factory=time.time)
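
The lease factory and expiry check, sketched (services and sizes illustrative):

import time

from circuitforge_core.resources.models import VRAMLease

lease = VRAMLease.create(gpu_id=0, node_id="gpu-node-1", mb=4096,
                         service="cf-vision", priority=2, ttl_s=0.5)
assert not lease.is_expired()
time.sleep(0.6)
assert lease.is_expired()

pinned = VRAMLease.create(gpu_id=0, node_id="gpu-node-1", mb=512,
                          service="cf-embed", priority=2)  # ttl_s=0 → never expires
assert pinned.expires_at == 0.0 and not pinned.is_expired()
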
@ -1,33 +0,0 @@
schema_version: 1
name: cpu-16gb
eviction_timeout_s: 30.0
services:
  ollama:
    max_mb: 0
    priority: 1
  cf-stt:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 1
    backend: moonshine
  cf-tts:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 1
  cf-embed:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 2
    always_on: true
  cf-classify:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 2
    always_on: true
model_size_hints:
  llm_max_params: 3b-q4
  image_gen_max: none
@ -1,33 +0,0 @@
schema_version: 1
name: cpu-32gb
eviction_timeout_s: 30.0
services:
  ollama:
    max_mb: 0
    priority: 1
  cf-stt:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 2
    backend: faster-whisper
  cf-tts:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 2
  cf-embed:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 4
    always_on: true
  cf-classify:
    max_mb: 0
    priority: 2
    shared: true
    max_concurrent: 4
    always_on: true
model_size_hints:
  llm_max_params: 7b-q4
  image_gen_max: none
@ -1,65 +0,0 @@
schema_version: 1
name: single-gpu-16gb
vram_total_mb: 16384
eviction_timeout_s: 10.0
services:
  vllm:
    max_mb: 9000
    priority: 1
    idle_stop_after_s: 600
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
      port: 8000
      host_port: 8000
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  ollama:
    max_mb: 12288
    priority: 1
  cf-vision:
    max_mb: 3072
    priority: 2
    shared: true
    max_concurrent: 4
  cf-docuvision:
    max_mb: 6144
    priority: 2
    shared: true
    max_concurrent: 3
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.docuvision.app --model /Library/Assets/LLM/docuvision/models/dolphin-v2 --port {port} --gpu-id {gpu_id}"
      port: 8003
      host_port: 8003
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  cf-stt:
    max_mb: 1200
    priority: 2
    shared: true
    max_concurrent: 3
    backend: parakeet-tdt
  cf-tts:
    max_mb: 1024
    priority: 2
    shared: true
    max_concurrent: 3
  cf-embed:
    max_mb: 512
    priority: 2
    shared: true
    max_concurrent: 6
    always_on: true
  cf-classify:
    max_mb: 512
    priority: 2
    shared: true
    max_concurrent: 6
    always_on: true
  comfyui:
    max_mb: 14336
    priority: 4
model_size_hints:
  llm_max_params: 34b
  image_gen_max: flux-dev-fp8
@ -1,65 +0,0 @@
schema_version: 1
name: single-gpu-24gb
vram_total_mb: 24576
eviction_timeout_s: 10.0
services:
  vllm:
    max_mb: 9000
    priority: 1
    idle_stop_after_s: 600
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
      port: 8000
      host_port: 8000
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  ollama:
    max_mb: 18432
    priority: 1
  cf-vision:
    max_mb: 4096
    priority: 2
    shared: true
    max_concurrent: 6
  cf-docuvision:
    max_mb: 8192
    priority: 2
    shared: true
    max_concurrent: 4
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.docuvision.app --model /Library/Assets/LLM/docuvision/models/dolphin-v2 --port {port} --gpu-id {gpu_id}"
      port: 8003
      host_port: 8003
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  cf-stt:
    max_mb: 1200
    priority: 2
    shared: true
    max_concurrent: 4
    backend: parakeet-tdt
  cf-tts:
    max_mb: 1024
    priority: 2
    shared: true
    max_concurrent: 4
  cf-embed:
    max_mb: 512
    priority: 2
    shared: true
    max_concurrent: 8
    always_on: true
  cf-classify:
    max_mb: 512
    priority: 2
    shared: true
    max_concurrent: 8
    always_on: true
  comfyui:
    max_mb: 20480
    priority: 4
model_size_hints:
  llm_max_params: 70b
  image_gen_max: flux-dev-fp16
@ -1,22 +0,0 @@
schema_version: 1
name: single-gpu-2gb
vram_total_mb: 2048
eviction_timeout_s: 15.0
services:
  ollama:
    max_mb: 1536
    priority: 1
  cf-vision:
    max_mb: 512
    priority: 2
    shared: true
    max_concurrent: 1
  cf-stt:
    max_mb: 200
    priority: 2
    shared: true
    max_concurrent: 1
    backend: moonshine
model_size_hints:
  llm_max_params: 3b
  image_gen_max: none
@ -1,30 +0,0 @@
schema_version: 1
name: single-gpu-4gb
vram_total_mb: 4096
eviction_timeout_s: 15.0
services:
  ollama:
    max_mb: 3072
    priority: 1
  cf-vision:
    max_mb: 1024
    priority: 2
    shared: true
    max_concurrent: 1
  cf-stt:
    max_mb: 600
    priority: 2
    shared: true
    max_concurrent: 1
    backend: faster-whisper
  cf-tts:
    max_mb: 512
    priority: 2
    shared: true
    max_concurrent: 1
  comfyui:
    max_mb: 3584
    priority: 4
model_size_hints:
  llm_max_params: 3b
  image_gen_max: sd15-fp8
@ -1,53 +0,0 @@
schema_version: 1
name: single-gpu-6gb
vram_total_mb: 6144
eviction_timeout_s: 10.0
services:
  vllm:
    max_mb: 5500
    priority: 1
    idle_stop_after_s: 600
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
      port: 8000
      host_port: 8000
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  ollama:
    max_mb: 3584
    priority: 1
  cf-vision:
    max_mb: 1536
    priority: 2
    shared: true
    max_concurrent: 2
  cf-docuvision:
    max_mb: 3072
    priority: 2
    shared: true
    max_concurrent: 1
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.docuvision.app --model /Library/Assets/LLM/docuvision/models/dolphin-v2 --port {port} --gpu-id {gpu_id}"
      port: 8003
      host_port: 8003
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  cf-stt:
    max_mb: 600
    priority: 2
    shared: true
    max_concurrent: 2
    backend: faster-whisper
  cf-tts:
    max_mb: 768
    priority: 2
    shared: true
    max_concurrent: 1
  comfyui:
    max_mb: 5120
    priority: 4
model_size_hints:
  llm_max_params: 7b
  image_gen_max: sd15
@ -1,60 +0,0 @@
schema_version: 1
name: single-gpu-8gb
vram_total_mb: 8192
eviction_timeout_s: 10.0
services:
  vllm:
    max_mb: 6500
    priority: 1
    idle_stop_after_s: 600
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
      port: 8000
      host_port: 8000
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  ollama:
    max_mb: 4096
    priority: 1
  cf-vision:
    max_mb: 2048
    priority: 2
    shared: true
    max_concurrent: 3
  cf-docuvision:
    max_mb: 4096
    priority: 2
    shared: true
    max_concurrent: 2
    managed:
      type: process
      exec_path: "/devl/miniconda3/envs/cf/bin/python"
      args_template: "-m circuitforge_core.resources.docuvision.app --model /Library/Assets/LLM/docuvision/models/dolphin-v2 --port {port} --gpu-id {gpu_id}"
      port: 8003
      host_port: 8003
      cwd: "/Library/Development/CircuitForge/circuitforge-core"
  cf-stt:
    max_mb: 1200
    priority: 2
    shared: true
    max_concurrent: 2
    backend: parakeet-tdt
  cf-tts:
    max_mb: 1024
    priority: 2
    shared: true
    max_concurrent: 2
  comfyui:
    max_mb: 6144
    priority: 4
    managed:
      type: process
      exec_path: "/opt/miniconda3/envs/comfyui/bin/python"
      args_template: "/opt/ComfyUI/main.py --listen 0.0.0.0 --port {port} --cuda-device {gpu_id}"
      cwd: "/opt/ComfyUI"
      port: 8188
      host_port: 8188
model_size_hints:
  llm_max_params: 8b
  image_gen_max: sdxl-fp8
@@ -1,116 +0,0 @@
# circuitforge_core/resources/profiles/schema.py
from __future__ import annotations

from pathlib import Path
from typing import Any

import yaml
from pydantic import BaseModel, Field, model_validator

SUPPORTED_SCHEMA_VERSION = 1


class DockerSpec(BaseModel):
    """Spec for a Docker-managed service."""

    image: str
    port: int
    host_port: int
    command_template: str = ""
    volumes: list[str] = Field(default_factory=list)
    env: dict[str, str] = Field(default_factory=dict)
    runtime: str = "nvidia"
    ipc: str = "host"

    model_config = {"frozen": True}


class ProcessSpec(BaseModel):
    """Spec for a process-managed service (non-Docker, e.g. conda env)."""

    exec_path: str
    args_template: str = ""
    cwd: str = ""
    env: dict[str, str] = Field(default_factory=dict)
    port: int = 0
    host_port: int = 0

    model_config = {"frozen": True}


class ServiceProfile(BaseModel):
    max_mb: int
    priority: int
    shared: bool = False
    max_concurrent: int = 1
    always_on: bool = False
    idle_stop_after_s: int = 0
    backend: str | None = None
    consumers: list[str] = Field(default_factory=list)
    managed: DockerSpec | ProcessSpec | None = None

    model_config = {"frozen": True}

    @model_validator(mode="before")
    @classmethod
    def _parse_managed(cls, values: Any) -> Any:
        if not isinstance(values, dict):
            return values
        raw = values.get("managed")
        if raw is None:
            return values
        if not isinstance(raw, dict):
            return values
        spec_type = raw.get("type")
        managed_fields = {k: v for k, v in raw.items() if k != "type"}
        if spec_type == "docker":
            values["managed"] = DockerSpec(**managed_fields)
        elif spec_type == "process":
            values["managed"] = ProcessSpec(**managed_fields)
        else:
            raise ValueError(f"Unknown managed service type: {spec_type!r}")
        return values


class GpuNodeEntry(BaseModel):
    id: int
    vram_mb: int
    role: str
    card: str = "unknown"
    always_on: bool = False
    services: list[str] = Field(default_factory=list)

    model_config = {"frozen": True}


class NodeProfile(BaseModel):
    gpus: list[GpuNodeEntry]
    agent_url: str | None = None
    nas_mount: str | None = None

    model_config = {"frozen": True}


class GpuProfile(BaseModel):
    schema_version: int
    name: str
    vram_total_mb: int | None = None
    eviction_timeout_s: float = 10.0
    services: dict[str, ServiceProfile] = Field(default_factory=dict)
    model_size_hints: dict[str, str] = Field(default_factory=dict)
    nodes: dict[str, NodeProfile] = Field(default_factory=dict)

    model_config = {"frozen": True}


def load_profile(path: Path) -> GpuProfile:
    raw: dict[str, Any] = yaml.safe_load(path.read_text())
    if not isinstance(raw, dict):
        raise ValueError(f"Profile file {path} must be a YAML mapping, got {type(raw).__name__}")
    version = raw.get("schema_version")
    if version != SUPPORTED_SCHEMA_VERSION:
        raise ValueError(
            f"Unsupported schema_version {version!r} in {path}. "
            f"Expected {SUPPORTED_SCHEMA_VERSION}."
        )
    return GpuProfile.model_validate(raw)
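For reference, a minimal sketch of how this (now-removed) loader was consumed; the profile filename below is hypothetical, but any of the YAML profiles deleted above would parse:

    from pathlib import Path
    from circuitforge_core.resources.profiles.schema import load_profile

    profile = load_profile(Path("profiles/single-gpu-8gb.yaml"))  # hypothetical path
    print(profile.name, profile.vram_total_mb)
    for name, svc in profile.services.items():
        print(f"{name}: max {svc.max_mb} MB, priority {svc.priority}")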
79
circuitforge_core/stt/__init__.py
Normal file
@@ -0,0 +1,79 @@
"""
circuitforge_core.stt — Speech-to-text service module.

Quick start (mock mode — no GPU or model required):

    import os; os.environ["CF_STT_MOCK"] = "1"
    from circuitforge_core.stt import transcribe

    result = transcribe(open("audio.wav", "rb").read())
    print(result.text, result.confidence)

Real inference (faster-whisper):

    export CF_STT_MODEL=/Library/Assets/LLM/whisper/models/Whisper/faster-whisper/models--Systran--faster-whisper-medium/snapshots/<hash>
    from circuitforge_core.stt import transcribe

cf-orch service profile:

    service_type: cf-stt
    max_mb: 1024 (medium); 600 (base/small)
    max_concurrent: 3
    shared: true
    managed:
      exec: python -m circuitforge_core.stt.app
      args: --model <path> --port {port} --gpu-id {gpu_id}
    port: 8004
    health: /health
"""
from __future__ import annotations

import os

from circuitforge_core.stt.backends.base import (
    STTBackend,
    STTResult,
    STTSegment,
    make_stt_backend,
)
from circuitforge_core.stt.backends.mock import MockSTTBackend

_backend: STTBackend | None = None


def _get_backend() -> STTBackend:
    global _backend
    if _backend is None:
        model_path = os.environ.get("CF_STT_MODEL", "mock")
        mock = model_path == "mock" or os.environ.get("CF_STT_MOCK", "") == "1"
        _backend = make_stt_backend(model_path, mock=mock)
    return _backend


def transcribe(
    audio: bytes,
    *,
    language: str | None = None,
    confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
) -> STTResult:
    """Transcribe audio bytes using the process-level backend."""
    return _get_backend().transcribe(
        audio, language=language, confidence_threshold=confidence_threshold
    )


def reset_backend() -> None:
    """Reset the process-level singleton. Test teardown only."""
    global _backend
    _backend = None


__all__ = [
    "STTBackend",
    "STTResult",
    "STTSegment",
    "MockSTTBackend",
    "make_stt_backend",
    "transcribe",
    "reset_backend",
]
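A short sketch of the confidence gate this module is built around, runnable in mock mode (no GPU; the silence bytes are placeholder input):

    import os
    os.environ["CF_STT_MOCK"] = "1"

    from circuitforge_core.stt import transcribe, reset_backend

    result = transcribe(b"\x00" * 32000)  # ~1 s of 16 kHz 16-bit mono silence
    if result.below_threshold:
        # Low-confidence transcripts must not drive downstream actions.
        print("rejected:", result.confidence)
    else:
        print(result.text)

    reset_backend()  # test teardown only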
150
circuitforge_core/stt/app.py
Normal file
@@ -0,0 +1,150 @@
"""
circuitforge_core.stt.app — cf-stt FastAPI service.

Managed by cf-orch as a process-type service. cf-orch starts this via:

    python -m circuitforge_core.stt.app \
        --model /Library/Assets/LLM/whisper/models/Whisper/faster-whisper/models--Systran--faster-whisper-medium/snapshots/<hash> \
        --port 8004 \
        --gpu-id 0

Endpoints:
    GET  /health      → {"status": "ok", "model": "<name>", "vram_mb": <n>}
    POST /transcribe  → STTTranscribeResponse (multipart: audio file)

Audio format: any format ffmpeg understands (WAV, MP3, OGG, FLAC).
"""
from __future__ import annotations

import argparse
import logging
import os
import sys

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from circuitforge_core.stt.backends.base import STTResult, make_stt_backend

logger = logging.getLogger(__name__)


# ── Response model (mirrors circuitforge_orch.contracts.stt.STTTranscribeResponse) ──

class TranscribeResponse(BaseModel):
    text: str
    confidence: float
    below_threshold: bool
    language: str | None = None
    duration_s: float | None = None
    segments: list[dict] = []
    model: str = ""


# ── App factory ───────────────────────────────────────────────────────────────

def create_app(
    model_path: str,
    device: str = "cuda",
    compute_type: str = "float16",
    confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    mock: bool = False,
) -> FastAPI:
    app = FastAPI(title="cf-stt", version="0.1.0")
    backend = make_stt_backend(
        model_path, device=device, compute_type=compute_type, mock=mock
    )
    logger.info("cf-stt ready: model=%r vram=%dMB", backend.model_name, backend.vram_mb)

    @app.get("/health")
    async def health() -> dict:
        return {"status": "ok", "model": backend.model_name, "vram_mb": backend.vram_mb}

    @app.post("/transcribe", response_model=TranscribeResponse)
    async def transcribe(
        audio: UploadFile = File(..., description="Audio file (WAV, MP3, OGG, FLAC, ...)"),
        language: str | None = Form(None, description="BCP-47 language code hint, e.g. 'en'"),
        confidence_threshold_override: float | None = Form(
            None,
            description="Override default confidence threshold for this request.",
        ),
    ) -> TranscribeResponse:
        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Empty audio file")

        threshold = confidence_threshold_override or confidence_threshold
        try:
            result = backend.transcribe(
                audio_bytes, language=language, confidence_threshold=threshold
            )
        except Exception as exc:
            logger.exception("Transcription failed")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return TranscribeResponse(
            text=result.text,
            confidence=result.confidence,
            below_threshold=result.below_threshold,
            language=result.language,
            duration_s=result.duration_s,
            segments=[
                {
                    "start_s": s.start_s,
                    "end_s": s.end_s,
                    "text": s.text,
                    "confidence": s.confidence,
                }
                for s in result.segments
            ],
            model=result.model,
        )

    return app


# ── CLI entry point ───────────────────────────────────────────────────────────

def main() -> None:
    parser = argparse.ArgumentParser(description="cf-stt — CircuitForge STT service")
    parser.add_argument("--model", required=True,
                        help="Model path or size name (e.g. 'medium', or full local path)")
    parser.add_argument("--port", type=int, default=8004)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0,
                        help="CUDA device index (sets CUDA_VISIBLE_DEVICES)")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--compute-type", default="float16",
                        choices=["float16", "int8", "int8_float16", "float32"],
                        help="Quantisation / compute type passed to faster-whisper")
    parser.add_argument("--confidence-threshold", type=float,
                        default=STTResult.CONFIDENCE_DEFAULT_THRESHOLD)
    parser.add_argument("--mock", action="store_true",
                        help="Run with mock backend (no GPU, for testing)")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Let cf-orch pass --gpu-id; map to CUDA_VISIBLE_DEVICES so the process
    # only sees its assigned GPU. This prevents accidental multi-GPU usage.
    if args.device == "cuda" and not args.mock:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))

    mock = args.mock or os.environ.get("CF_STT_MOCK", "") == "1"
    app = create_app(
        model_path=args.model,
        device=args.device,
        compute_type=args.compute_type,
        confidence_threshold=args.confidence_threshold,
        mock=mock,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()
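A hedged client-side sketch against the endpoints above (the host and port are assumed from the --port default; cf-orch may assign different ones at start-up):

    import httpx

    with open("audio.wav", "rb") as f:
        resp = httpx.post(
            "http://localhost:8004/transcribe",
            files={"audio": ("audio.wav", f, "audio/wav")},
            data={"language": "en"},
            timeout=60.0,
        )
    resp.raise_for_status()
    body = resp.json()
    print(body["text"], body["confidence"], body["below_threshold"])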
4
circuitforge_core/stt/backends/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .base import STTBackend, STTResult, STTSegment, make_stt_backend
from .mock import MockSTTBackend

__all__ = ["STTBackend", "STTResult", "STTSegment", "make_stt_backend", "MockSTTBackend"]
109
circuitforge_core/stt/backends/base.py
Normal file
@@ -0,0 +1,109 @@
# circuitforge_core/stt/backends/base.py — STTBackend Protocol + factory
#
# MIT licensed. The Protocol and mock are always importable without GPU deps.
# Real backends require optional extras:
#   pip install -e "circuitforge-core[stt-faster-whisper]"
from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Protocol, runtime_checkable


# ── Result types ──────────────────────────────────────────────────────────────

@dataclass(frozen=True)
class STTSegment:
    """Word- or phrase-level segment (included when the backend supports it)."""
    start_s: float
    end_s: float
    text: str
    confidence: float


@dataclass(frozen=True)
class STTResult:
    """
    Standard result from any STTBackend.transcribe() call.

    confidence is normalised to 0.0–1.0 regardless of the backend's native metric.
    below_threshold is True when confidence < the configured threshold (default 0.75).
    This flag is safety-critical for products like Osprey: DTMF must NOT be sent
    when below_threshold is True.
    """
    text: str
    confidence: float  # 0.0–1.0
    below_threshold: bool
    language: str | None = None
    duration_s: float | None = None
    segments: list[STTSegment] = field(default_factory=list)
    model: str = ""

    CONFIDENCE_DEFAULT_THRESHOLD: float = 0.75


# ── Protocol ──────────────────────────────────────────────────────────────────

@runtime_checkable
class STTBackend(Protocol):
    """
    Abstract interface for speech-to-text backends.

    All backends load their model once at construction time and are safe to
    call concurrently (the model weights are read-only after load).
    """

    def transcribe(
        self,
        audio: bytes,
        *,
        language: str | None = None,
        confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    ) -> STTResult:
        """Synchronous transcription. audio is raw PCM or any format ffmpeg understands."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier for the loaded model (path stem or size name)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB. Used by cf-orch service registry."""
        ...


# ── Factory ───────────────────────────────────────────────────────────────────

def make_stt_backend(
    model_path: str,
    backend: str | None = None,
    mock: bool | None = None,
    device: str = "cuda",
    compute_type: str = "float16",
) -> STTBackend:
    """
    Return an STTBackend for the given model.

    mock=True or CF_STT_MOCK=1 → MockSTTBackend (no GPU, no model file needed)
    backend="faster-whisper"   → FasterWhisperBackend (default)

    device and compute_type are passed through to the backend and ignored by mock.
    """
    use_mock = mock if mock is not None else os.environ.get("CF_STT_MOCK", "") == "1"
    if use_mock:
        from circuitforge_core.stt.backends.mock import MockSTTBackend
        return MockSTTBackend(model_name=model_path)

    resolved = backend or os.environ.get("CF_STT_BACKEND", "faster-whisper")
    if resolved == "faster-whisper":
        from circuitforge_core.stt.backends.faster_whisper import FasterWhisperBackend
        return FasterWhisperBackend(
            model_path=model_path, device=device, compute_type=compute_type
        )

    raise ValueError(
        f"Unknown STT backend {resolved!r}. "
        "Expected 'faster-whisper'. Set CF_STT_BACKEND or pass backend= explicitly."
    )
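Because STTBackend is a runtime-checkable structural Protocol, any class with matching members satisfies it. A minimal hypothetical backend for illustration (not part of the module):

    from circuitforge_core.stt.backends.base import STTBackend, STTResult

    class SilenceBackend:  # hypothetical example backend
        def transcribe(self, audio, *, language=None,
                       confidence_threshold=STTResult.CONFIDENCE_DEFAULT_THRESHOLD):
            # Always reports no speech, below any sane threshold.
            return STTResult(text="", confidence=0.0, below_threshold=True,
                             language=language, model="silence")

        @property
        def model_name(self): return "silence"

        @property
        def vram_mb(self): return 0

    assert isinstance(SilenceBackend(), STTBackend)  # structural check, mirrors mock.py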
139
circuitforge_core/stt/backends/faster_whisper.py
Normal file
@@ -0,0 +1,139 @@
# circuitforge_core/stt/backends/faster_whisper.py — FasterWhisperBackend
#
# MIT licensed. Requires: pip install -e "circuitforge-core[stt-faster-whisper]"
#
# Model path can be:
#   - A size name: "base", "small", "medium", "large-v3"
#     (faster-whisper downloads and caches it on first use)
#   - A local path: "/Library/Assets/LLM/whisper/models/Whisper/faster-whisper/..."
#     (preferred for air-gapped nodes — no download needed)
from __future__ import annotations

import io
import logging
import os
import tempfile

from circuitforge_core.stt.backends.base import STTResult, STTSegment

logger = logging.getLogger(__name__)

# VRAM estimates by model size. Used by cf-orch for VRAM budgeting.
_VRAM_MB_BY_SIZE: dict[str, int] = {
    "tiny": 200,
    "base": 350,
    "small": 600,
    "medium": 1024,
    "large": 2048,
    "large-v2": 2048,
    "large-v3": 2048,
    "distil-large-v3": 1500,
}


# Aggregate confidence from per-segment no_speech_prob values.
# faster-whisper doesn't expose a direct confidence score, so we invert the
# mean no_speech_prob as a proxy. This is conservative but directionally correct.
def _aggregate_confidence(segments: list) -> float:
    if not segments:
        return 0.0
    probs = [max(0.0, 1.0 - getattr(s, "no_speech_prob", 0.0)) for s in segments]
    return sum(probs) / len(probs)


class FasterWhisperBackend:
    """
    faster-whisper STT backend.

    Thread-safe after construction: WhisperModel internally manages its own
    CUDA context and is safe to call from multiple threads.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        compute_type: str = "float16",
    ) -> None:
        try:
            from faster_whisper import WhisperModel
        except ImportError as exc:
            raise ImportError(
                "faster-whisper is not installed. "
                "Run: pip install -e 'circuitforge-core[stt-faster-whisper]'"
            ) from exc

        logger.info("Loading faster-whisper model from %r (device=%s)", model_path, device)
        self._model_path = model_path
        self._device = device
        self._compute_type = compute_type
        self._model = WhisperModel(model_path, device=device, compute_type=compute_type)
        logger.info("faster-whisper model ready")

        # Determine VRAM footprint from model name/path stem.
        stem = os.path.basename(model_path.rstrip("/")).lower()
        self._vram_mb = next(
            (v for k, v in _VRAM_MB_BY_SIZE.items() if k in stem),
            1024,  # conservative default if size can't be inferred
        )

    def transcribe(
        self,
        audio: bytes,
        *,
        language: str | None = None,
        confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    ) -> STTResult:
        """
        Transcribe raw audio bytes.

        audio can be any format ffmpeg understands (WAV, MP3, OGG, FLAC, etc.).
        faster-whisper writes audio to a temp file internally; we follow the
        same pattern to avoid holding the bytes in memory longer than needed.
        """
        with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp:
            tmp.write(audio)
            tmp_path = tmp.name

        try:
            segments_gen, info = self._model.transcribe(
                tmp_path,
                language=language,
                word_timestamps=True,
                vad_filter=True,
            )
            segments = list(segments_gen)
        finally:
            os.unlink(tmp_path)

        text = " ".join(s.text.strip() for s in segments).strip()
        confidence = _aggregate_confidence(segments)
        duration_s = info.duration if hasattr(info, "duration") else None
        detected_language = getattr(info, "language", language)

        stt_segments = [
            STTSegment(
                start_s=s.start,
                end_s=s.end,
                text=s.text.strip(),
                confidence=max(0.0, 1.0 - getattr(s, "no_speech_prob", 0.0)),
            )
            for s in segments
        ]

        return STTResult(
            text=text,
            confidence=confidence,
            below_threshold=confidence < confidence_threshold,
            language=detected_language,
            duration_s=duration_s,
            segments=stt_segments,
            model=self._model_path,
        )

    @property
    def model_name(self) -> str:
        return self._model_path

    @property
    def vram_mb(self) -> int:
        return self._vram_mb
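To make the confidence proxy concrete: two segments with no_speech_prob 0.1 and 0.3 aggregate to ((1 − 0.1) + (1 − 0.3)) / 2 = 0.8. A usage sketch (the size name "medium" downloads on first use; air-gapped nodes should pass a local snapshot path instead):

    from circuitforge_core.stt.backends.faster_whisper import FasterWhisperBackend

    backend = FasterWhisperBackend("medium", device="cuda", compute_type="float16")
    result = backend.transcribe(open("audio.wav", "rb").read(), language="en")
    print(result.text, result.confidence)
    print(backend.vram_mb)  # 1024, matched from the "medium" entry in _VRAM_MB_BY_SIZE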
54
circuitforge_core/stt/backends/mock.py
Normal file
@@ -0,0 +1,54 @@
# circuitforge_core/stt/backends/mock.py — MockSTTBackend
#
# MIT licensed. No GPU, no model file required.
# Used in tests and CI, and when CF_STT_MOCK=1.
from __future__ import annotations

from circuitforge_core.stt.backends.base import STTBackend, STTResult


class MockSTTBackend:
    """
    Deterministic mock STT backend for testing.

    Returns a fixed transcript so tests can assert on the response shape
    without needing a GPU or a model file.
    """

    def __init__(
        self,
        model_name: str = "mock",
        fixed_text: str = "mock transcription",
        fixed_confidence: float = 0.95,
    ) -> None:
        self._model_name = model_name
        self._fixed_text = fixed_text
        self._fixed_confidence = fixed_confidence

    def transcribe(
        self,
        audio: bytes,
        *,
        language: str | None = None,
        confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
    ) -> STTResult:
        return STTResult(
            text=self._fixed_text,
            confidence=self._fixed_confidence,
            below_threshold=self._fixed_confidence < confidence_threshold,
            language=language or "en",
            duration_s=float(len(audio)) / 32000,  # rough estimate: 16kHz 16-bit mono
            model=self._model_name,
        )

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        return 0


# Satisfy the Protocol at import time (no GPU needed)
assert isinstance(MockSTTBackend(), STTBackend)
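A pytest-style sketch showing how the mock exercises the threshold logic without a model:

    from circuitforge_core.stt.backends.mock import MockSTTBackend

    def test_low_confidence_is_flagged():
        backend = MockSTTBackend(fixed_confidence=0.5)
        result = backend.transcribe(b"\x00" * 1600)
        assert result.below_threshold  # 0.5 < default threshold 0.75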
@@ -1,6 +1,7 @@
 # circuitforge_core/tasks/__init__.py
 from circuitforge_core.tasks.scheduler import (
     TaskScheduler,
+    LocalScheduler,
     detect_available_vram_gb,
     get_scheduler,
     reset_scheduler,
@@ -8,6 +8,7 @@ from circuitforge_core.tasks.scheduler import (
 
 __all__ = [
     "TaskScheduler",
+    "LocalScheduler",
     "detect_available_vram_gb",
     "get_scheduler",
     "reset_scheduler",
@@ -1,21 +1,17 @@
 # circuitforge_core/tasks/scheduler.py
-"""Resource-aware batch scheduler for LLM background tasks.
+"""Task scheduler for CircuitForge products — MIT layer.
 
-Generic scheduler that any CircuitForge product can use. Products supply:
-- task_types: frozenset[str] — task type strings routed through this scheduler
-- vram_budgets: dict[str, float] — VRAM GB estimate per task type
-- run_task_fn — product's task execution function
+Provides a simple FIFO task queue with no coordinator dependency.
 
-VRAM detection priority:
-1. cf-orch coordinator /api/nodes — free VRAM (lease-aware, cooperative)
-2. scripts.preflight.get_gpus() — total GPU VRAM (Peregrine-era fallback)
-3. 999.0 — unlimited (CPU-only or no detection available)
+For coordinator-aware VRAM-budgeted scheduling on paid/premium tiers, install
+circuitforge-orch and use OrchestratedScheduler instead.
 
 Public API:
-    TaskScheduler — the scheduler class
-    detect_available_vram_gb() — standalone VRAM query helper
-    get_scheduler() — lazy process-level singleton
-    reset_scheduler() — test teardown only
+    TaskScheduler — Protocol defining the scheduler interface
+    LocalScheduler — Simple FIFO queue implementation (MIT, no coordinator)
+    detect_available_vram_gb() — Returns 999.0 (unlimited; no coordinator on free tier)
+    get_scheduler() — Lazy process-level singleton returning a LocalScheduler
+    reset_scheduler() — Test teardown only
 """
 from __future__ import annotations
 
@@ -24,12 +20,7 @@ import sqlite3
 import threading
 from collections import deque
 from pathlib import Path
-from typing import Callable, NamedTuple, Optional
-
-try:
-    import httpx as httpx
-except ImportError:
-    httpx = None  # type: ignore[assignment]
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
 
 logger = logging.getLogger(__name__)
 
@@ -41,68 +32,45 @@ class TaskSpec(NamedTuple):
     job_id: int
     params: Optional[str]
 
 
 _DEFAULT_MAX_QUEUE_DEPTH = 500
 
 
-def detect_available_vram_gb(
-    coordinator_url: str = "http://localhost:7700",
-) -> float:
-    """Detect available VRAM GB for task scheduling.
+def detect_available_vram_gb() -> float:
+    """Return available VRAM for task scheduling.
 
-    Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler
-    cooperates with other cf-orch consumers. Falls back to preflight total VRAM,
-    then 999.0 (unlimited) if nothing is reachable.
+    Free tier (no coordinator): always returns 999.0 — no VRAM gating.
+    For coordinator-aware VRAM detection use circuitforge_orch.scheduler.
     """
-    # 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
-    # cf-orch consumers (vision service, inference services, etc.)
-    if httpx is not None:
-        try:
-            resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
-            if resp.status_code == 200:
-                nodes = resp.json().get("nodes", [])
-                total_free_mb = sum(
-                    gpu.get("vram_free_mb", 0)
-                    for node in nodes
-                    for gpu in node.get("gpus", [])
-                )
-                if total_free_mb > 0:
-                    free_gb = total_free_mb / 1024.0
-                    logger.debug(
-                        "Scheduler VRAM from cf-orch: %.1f GB free", free_gb
-                    )
-                    return free_gb
-        except Exception:
-            pass
-
-    # 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
-    try:
-        from scripts.preflight import get_gpus  # type: ignore[import]
-
-        gpus = get_gpus()
-        if gpus:
-            total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
-            logger.debug(
-                "Scheduler VRAM from preflight: %.1f GB total", total_gb
-            )
-            return total_gb
-    except Exception:
-        pass
-
-    logger.debug(
-        "Scheduler VRAM detection unavailable — using unlimited (999 GB)"
-    )
     return 999.0
 
 
-class TaskScheduler:
-    """Resource-aware LLM task batch scheduler.
+@runtime_checkable
+class TaskScheduler(Protocol):
+    """Protocol for task schedulers across free and paid tiers.
 
-    Runs one batch-worker thread per task type while total reserved VRAM
-    stays within the detected available budget. Always allows at least one
-    batch to start even if its budget exceeds available VRAM (prevents
-    permanent starvation on low-VRAM systems).
+    Both LocalScheduler (MIT) and OrchestratedScheduler (BSL, circuitforge-orch)
+    implement this interface so products can inject either without API changes.
+    """
 
-    Thread-safety: all queue/active state protected by self._lock.
+    def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
+        """Add a task to the queue. Returns True if enqueued, False if queue full."""
+        ...
+
+    def start(self) -> None:
+        """Start the background worker thread."""
+        ...
+
+    def shutdown(self, timeout: float = 5.0) -> None:
+        """Stop the scheduler and wait for it to exit."""
+        ...
+
+
+class LocalScheduler:
+    """Simple FIFO task scheduler with no coordinator dependency.
+
+    Processes tasks serially per task type. No VRAM gating — all tasks run.
+    Suitable for free tier (single-host, up to 2 GPUs, static config).
 
     Usage::
 
@@ -112,11 +80,7 @@ class TaskScheduler:
             task_types=frozenset({"cover_letter", "research"}),
             vram_budgets={"cover_letter": 2.5, "research": 5.0},
         )
-        task_id, is_new = insert_task(db_path, "cover_letter", job_id)
-        if is_new:
-            enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
-            if not enqueued:
-                mark_task_failed(db_path, task_id, "Queue full")
+        enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
     """
 
     def __init__(
@@ -125,11 +89,7 @@ class TaskScheduler:
         run_task_fn: RunTaskFn,
         task_types: frozenset[str],
         vram_budgets: dict[str, float],
-        available_vram_gb: Optional[float] = None,
         max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
-        coordinator_url: str = "http://localhost:7700",
-        service_name: str = "peregrine",
-        lease_priority: int = 2,
     ) -> None:
         self._db_path = db_path
         self._run_task = run_task_fn
@@ -137,54 +97,22 @@ class TaskScheduler:
         self._budgets: dict[str, float] = dict(vram_budgets)
         self._max_queue_depth = max_queue_depth
 
-        self._coordinator_url = coordinator_url.rstrip("/")
-        self._service_name = service_name
-        self._lease_priority = lease_priority
-
         self._lock = threading.Lock()
         self._wake = threading.Event()
         self._stop = threading.Event()
         self._queues: dict[str, deque[TaskSpec]] = {}
         self._active: dict[str, threading.Thread] = {}
-        self._reserved_vram: float = 0.0
         self._thread: Optional[threading.Thread] = None
 
-        self._available_vram: float = (
-            available_vram_gb
-            if available_vram_gb is not None
-            else detect_available_vram_gb()
-        )
-
-        for t in self._task_types:
-            if t not in self._budgets:
-                logger.warning(
-                    "No VRAM budget defined for task type %r — "
-                    "defaulting to 0.0 GB (no VRAM gating for this type)",
-                    t,
-                )
-
         self._load_queued_tasks()
 
-    def enqueue(
-        self,
-        task_id: int,
-        task_type: str,
-        job_id: int,
-        params: Optional[str],
-    ) -> bool:
-        """Add a task to the scheduler queue.
-
-        Returns True if enqueued successfully.
-        Returns False if the queue is full — caller should mark the task failed.
-        """
+    def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
        with self._lock:
            q = self._queues.setdefault(task_type, deque())
            if len(q) >= self._max_queue_depth:
                logger.warning(
                    "Queue depth limit for %s (max=%d) — task %d dropped",
-                    task_type,
-                    self._max_queue_depth,
-                    task_id,
+                    task_type, self._max_queue_depth, task_id,
                )
                return False
            q.append(TaskSpec(task_id, job_id, params))
@@ -192,28 +120,19 @@ class TaskScheduler:
             return True
 
     def start(self) -> None:
         """Start the background scheduler loop thread. Call once after construction."""
         self._thread = threading.Thread(
             target=self._scheduler_loop, name="task-scheduler", daemon=True
         )
         self._thread.start()
-        # Wake the loop immediately so tasks loaded from DB at startup are dispatched
-        with self._lock:
-            if any(self._queues.values()):
-                self._wake.set()
 
     def shutdown(self, timeout: float = 5.0) -> None:
         """Signal the scheduler to stop and wait for it to exit.
 
-        Joins both the scheduler loop thread and any active batch worker
-        threads so callers can rely on clean state (e.g. _reserved_vram == 0)
-        immediately after this returns.
         """
         self._stop.set()
         self._wake.set()
         if self._thread and self._thread.is_alive():
             self._thread.join(timeout=timeout)
-        # Join active batch workers so _reserved_vram is settled on return
         with self._lock:
             workers = list(self._active.values())
         for worker in workers:
@@ -224,103 +143,25 @@ class TaskScheduler:
             self._wake.wait(timeout=30)
             self._wake.clear()
             with self._lock:
                 # Reap batch threads that finished without waking us.
                 # VRAM accounting is handled solely by _batch_worker's finally block;
                 # the reaper only removes dead entries from _active.
                 for t, thread in list(self._active.items()):
                     if not thread.is_alive():
                         del self._active[t]
-                # Start new type batches while VRAM budget allows
                 candidates = sorted(
-                    [
-                        t
-                        for t in self._queues
-                        if self._queues[t] and t not in self._active
-                    ],
+                    [t for t in self._queues if self._queues[t] and t not in self._active],
                     key=lambda t: len(self._queues[t]),
                     reverse=True,
                 )
                 for task_type in candidates:
-                    budget = self._budgets.get(task_type, 0.0)
-                    # Always allow at least one batch to run
-                    if (
-                        self._reserved_vram == 0.0
-                        or self._reserved_vram + budget <= self._available_vram
-                    ):
-                        thread = threading.Thread(
-                            target=self._batch_worker,
-                            args=(task_type,),
-                            name=f"batch-{task_type}",
-                            daemon=True,
-                        )
-                        self._active[task_type] = thread
-                        self._reserved_vram += budget
-                        thread.start()
-
-    def _acquire_lease(self, task_type: str) -> Optional[str]:
-        """Request a VRAM lease from the coordinator. Returns lease_id or None."""
-        if httpx is None:
-            return None
-        budget_gb = self._budgets.get(task_type, 0.0)
-        if budget_gb <= 0:
-            return None
-        mb = int(budget_gb * 1024)
-        try:
-            # Pick the GPU with the most free VRAM on the first registered node
-            resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
-            if resp.status_code != 200:
-                return None
-            nodes = resp.json().get("nodes", [])
-            if not nodes:
-                return None
-            best_node = best_gpu = best_free = None
-            for node in nodes:
-                for gpu in node.get("gpus", []):
-                    free = gpu.get("vram_free_mb", 0)
-                    if best_free is None or free > best_free:
-                        best_node = node["node_id"]
-                        best_gpu = gpu["gpu_id"]
-                        best_free = free
-            if best_node is None:
-                return None
-            lease_resp = httpx.post(
-                f"{self._coordinator_url}/api/leases",
-                json={
-                    "node_id": best_node,
-                    "gpu_id": best_gpu,
-                    "mb": mb,
-                    "service": self._service_name,
-                    "priority": self._lease_priority,
-                },
-                timeout=3.0,
-            )
-            if lease_resp.status_code == 200:
-                lease_id = lease_resp.json()["lease"]["lease_id"]
-                logger.debug(
-                    "Acquired VRAM lease %s for task_type=%s (%d MB)",
-                    lease_id, task_type, mb,
-                )
-                return lease_id
-        except Exception as exc:
-            logger.debug("Lease acquire failed (non-fatal): %s", exc)
-        return None
-
-    def _release_lease(self, lease_id: str) -> None:
-        """Release a coordinator VRAM lease. Best-effort; failures are logged only."""
-        if httpx is None or not lease_id:
-            return
-        try:
-            httpx.delete(
-                f"{self._coordinator_url}/api/leases/{lease_id}",
-                timeout=3.0,
-            )
-            logger.debug("Released VRAM lease %s", lease_id)
-        except Exception as exc:
-            logger.debug("Lease release failed (non-fatal): %s", exc)
+                    thread = threading.Thread(
+                        target=self._batch_worker,
+                        args=(task_type,),
+                        name=f"batch-{task_type}",
+                        daemon=True,
+                    )
+                    self._active[task_type] = thread
+                    thread.start()
 
     def _batch_worker(self, task_type: str) -> None:
         """Serial consumer for one task type. Runs until the type's deque is empty."""
-        lease_id: Optional[str] = self._acquire_lease(task_type)
         try:
             while True:
                 with self._lock:
@@ -328,19 +169,13 @@ class TaskScheduler:
                     if not q:
                         break
                     task = q.popleft()
-                self._run_task(
-                    self._db_path, task.id, task_type, task.job_id, task.params
-                )
+                self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
         finally:
-            if lease_id:
-                self._release_lease(lease_id)
             with self._lock:
                 self._active.pop(task_type, None)
-                self._reserved_vram -= self._budgets.get(task_type, 0.0)
                 self._wake.set()
 
     def _load_queued_tasks(self) -> None:
         """Reload surviving 'queued' tasks from SQLite into deques at startup."""
         if not self._task_types:
             return
         task_types_list = sorted(self._task_types)
@@ -354,68 +189,58 @@ class TaskScheduler:
                 task_types_list,
             ).fetchall()
         except sqlite3.OperationalError:
             # Table not yet created (first run before migrations)
             rows = []
 
         for row_id, task_type, job_id, params in rows:
             q = self._queues.setdefault(task_type, deque())
             q.append(TaskSpec(row_id, job_id, params))
 
         if rows:
-            logger.info(
-                "Scheduler: resumed %d queued task(s) from prior run", len(rows)
-            )
+            logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
 
 
 # ── Process-level singleton ────────────────────────────────────────────────────
 
-_scheduler: Optional[TaskScheduler] = None
+_scheduler: Optional[LocalScheduler] = None
 _scheduler_lock = threading.Lock()
 
 
 def get_scheduler(
-    db_path: Path,
+    db_path: Optional[Path] = None,
     run_task_fn: Optional[RunTaskFn] = None,
     task_types: Optional[frozenset[str]] = None,
    vram_budgets: Optional[dict[str, float]] = None,
    max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
    coordinator_url: str = "http://localhost:7700",
    service_name: str = "peregrine",
-) -> TaskScheduler:
-    """Return the process-level TaskScheduler singleton.
+) -> LocalScheduler:
+    """Return the process-level LocalScheduler singleton.
 
-    ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
-    first call; ignored on subsequent calls (singleton already constructed).
+    ``run_task_fn``, ``task_types``, ``vram_budgets``, and ``db_path`` are
+    required on the first call; ignored on subsequent calls.
 
-    VRAM detection (which may involve a network call) is performed outside the
-    lock so the lock is never held across blocking I/O.
+    ``coordinator_url`` and ``service_name`` are accepted but ignored —
+    LocalScheduler has no coordinator. They exist for API compatibility with
+    OrchestratedScheduler call sites.
     """
     global _scheduler
     if _scheduler is not None:
         return _scheduler
-    # Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
-    # which makes an httpx network call (up to 2 s). Holding the lock during that
-    # would block any concurrent caller for the full duration.
-    if run_task_fn is None or task_types is None or vram_budgets is None:
+    if run_task_fn is None or task_types is None or vram_budgets is None or db_path is None:
        raise ValueError(
-            "run_task_fn, task_types, and vram_budgets are required "
+            "db_path, run_task_fn, task_types, and vram_budgets are required "
            "on the first call to get_scheduler()"
        )
-    candidate = TaskScheduler(
+    candidate = LocalScheduler(
        db_path=db_path,
        run_task_fn=run_task_fn,
        task_types=task_types,
        vram_budgets=vram_budgets,
        max_queue_depth=max_queue_depth,
-        coordinator_url=coordinator_url,
-        service_name=service_name,
    )
    candidate.start()
    with _scheduler_lock:
        if _scheduler is None:
            _scheduler = candidate
        else:
            # Another thread beat us — shut down our candidate and use the winner.
            candidate.shutdown()
    return _scheduler
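A wiring sketch for the free-tier path after this change — the db path and task body below are hypothetical; get_scheduler() starts the worker loop itself:

    from pathlib import Path
    from circuitforge_core.tasks import get_scheduler

    def run_task(db_path, task_id, task_type, job_id, params):
        # Hypothetical task body; products supply their own RunTaskFn.
        print(f"running {task_type} task {task_id} for job {job_id}")

    sched = get_scheduler(
        db_path=Path("tasks.db"),  # hypothetical path
        run_task_fn=run_task,
        task_types=frozenset({"cover_letter"}),
        vram_budgets={"cover_letter": 2.5},  # required, but unused for gating here
    )
    sched.enqueue(1, "cover_letter", job_id=42, params=None)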
144
circuitforge_core/text/__init__.py
Normal file
@@ -0,0 +1,144 @@
"""
|
||||
circuitforge_core.text — direct text generation service module.
|
||||
|
||||
Provides lightweight, low-overhead text generation that bypasses ollama/vllm
|
||||
for products that need fast, frequent inference from small local models.
|
||||
|
||||
Quick start (mock mode — no model required):
|
||||
|
||||
import os; os.environ["CF_TEXT_MOCK"] = "1"
|
||||
from circuitforge_core.text import generate, chat, ChatMessage
|
||||
|
||||
result = generate("Write a short cover letter intro.")
|
||||
print(result.text)
|
||||
|
||||
reply = chat([
|
||||
ChatMessage("system", "You are a helpful recipe assistant."),
|
||||
ChatMessage("user", "What can I make with eggs, spinach, and feta?"),
|
||||
])
|
||||
print(reply.text)
|
||||
|
||||
Real inference (GGUF model):
|
||||
|
||||
export CF_TEXT_MODEL=/Library/Assets/LLM/qwen2.5-3b-instruct-q4_k_m.gguf
|
||||
from circuitforge_core.text import generate
|
||||
result = generate("Summarise this job posting in 2 sentences: ...")
|
||||
|
||||
Backend selection (CF_TEXT_BACKEND env or explicit):
|
||||
|
||||
from circuitforge_core.text import make_backend
|
||||
backend = make_backend("/path/to/model.gguf", backend="llamacpp")
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-text
|
||||
max_mb: per-model (3B Q4 ≈ 2048, 7B Q4 ≈ 4096)
|
||||
preferred_compute: 7.5 minimum (INT8 tensor cores)
|
||||
max_concurrent: 2
|
||||
shared: true
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.text.backends.base import (
|
||||
ChatMessage,
|
||||
GenerateResult,
|
||||
TextBackend,
|
||||
make_text_backend,
|
||||
)
|
||||
from circuitforge_core.text.backends.mock import MockTextBackend
|
||||
|
||||
# ── Process-level singleton backend ──────────────────────────────────────────
|
||||
# Lazily initialised on first call to generate() or chat().
|
||||
# Products that need per-user or per-request backends should use make_backend().
|
||||
|
||||
_backend: TextBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> TextBackend:
|
||||
global _backend
|
||||
if _backend is None:
|
||||
model_path = os.environ.get("CF_TEXT_MODEL", "mock")
|
||||
mock = model_path == "mock" or os.environ.get("CF_TEXT_MOCK", "") == "1"
|
||||
_backend = make_text_backend(model_path, mock=mock)
|
||||
return _backend
|
||||
|
||||
|
||||
def make_backend(
|
||||
model_path: str,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
) -> TextBackend:
|
||||
"""
|
||||
Create a TextBackend for the given model.
|
||||
|
||||
Use this when you need a dedicated backend per request or per user,
|
||||
rather than the process-level singleton used by generate() and chat().
|
||||
"""
|
||||
return make_text_backend(model_path, backend=backend, mock=mock)
|
||||
|
||||
|
||||
# ── Convenience functions (singleton path) ────────────────────────────────────
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
stop: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Generate text from a prompt using the process-level backend.
|
||||
|
||||
stream=True returns an Iterator[str] of tokens instead of GenerateResult.
|
||||
model is accepted for API symmetry with LLMRouter but ignored by the
|
||||
singleton path — set CF_TEXT_MODEL to change the loaded model.
|
||||
"""
|
||||
backend = _get_backend()
|
||||
if stream:
|
||||
return backend.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
|
||||
return backend.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
|
||||
|
||||
|
||||
def chat(
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> GenerateResult:
|
||||
"""
|
||||
Chat completion using the process-level backend.
|
||||
|
||||
messages should be a list of ChatMessage(role, content) objects.
|
||||
stream=True is not yet supported on the chat path; pass stream=False.
|
||||
"""
|
||||
if stream:
|
||||
raise NotImplementedError(
|
||||
"stream=True is not yet supported for chat(). "
|
||||
"Use generate_stream() directly on a backend instance."
|
||||
)
|
||||
return _get_backend().chat(messages, max_tokens=max_tokens, temperature=temperature)
|
||||
|
||||
|
||||
def reset_backend() -> None:
|
||||
"""Reset the process-level singleton. Test teardown only."""
|
||||
global _backend
|
||||
_backend = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChatMessage",
|
||||
"GenerateResult",
|
||||
"TextBackend",
|
||||
"MockTextBackend",
|
||||
"make_backend",
|
||||
"generate",
|
||||
"chat",
|
||||
"reset_backend",
|
||||
]
|
||||
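A sketch of the per-request path in mock mode, so no model file is needed (the exact mock output is backend-defined; this only shows the call shapes):

    from circuitforge_core.text import make_backend, ChatMessage

    backend = make_backend("mock", mock=True)  # dedicated instance, not the singleton
    result = backend.generate("Say hello.", max_tokens=16)
    print(result.text)

    reply = backend.chat([ChatMessage("user", "One-line greeting?")])
    print(reply.text)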
226
circuitforge_core/text/app.py
Normal file
@@ -0,0 +1,226 @@
"""
|
||||
cf-text FastAPI service — managed by cf-orch.
|
||||
|
||||
Lightweight local text generation. Supports GGUF models via llama.cpp and
|
||||
HuggingFace transformers. Sits alongside vllm/ollama for products that need
|
||||
fast, frequent inference from small local models (3B–7B Q4).
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": str, "vram_mb": int, "backend": str}
|
||||
POST /generate → GenerateResponse
|
||||
POST /chat → GenerateResponse
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.text.app \
|
||||
--model /Library/Assets/LLM/qwen2.5-3b-instruct-q4_k_m.gguf \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
Mock mode (no model or GPU required):
|
||||
CF_TEXT_MOCK=1 python -m circuitforge_core.text.app --port 8006
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage
|
||||
from circuitforge_core.text.backends.base import make_text_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_backend = None
|
||||
|
||||
|
||||
# ── Request / response models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
stop: list[str] | None = None
|
||||
|
||||
|
||||
class ChatMessageModel(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: list[ChatMessageModel]
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
text: str
|
||||
tokens_used: int = 0
|
||||
model: str = ""
|
||||
|
||||
|
||||
# ── OpenAI-compat request / response (for LLMRouter openai_compat path) ──────
|
||||
|
||||
|
||||
class OAIMessageModel(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class OAIChatRequest(BaseModel):
|
||||
model: str = "cf-text"
|
||||
messages: list[OAIMessageModel]
|
||||
max_tokens: int | None = None
|
||||
temperature: float = 0.7
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class OAIChoice(BaseModel):
|
||||
index: int = 0
|
||||
message: OAIMessageModel
|
||||
finish_reason: str = "stop"
|
||||
|
||||
|
||||
class OAIUsage(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class OAIChatResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: list[OAIChoice]
|
||||
usage: OAIUsage
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_app(
|
||||
model_path: str,
|
||||
gpu_id: int = 0,
|
||||
backend: str | None = None,
|
||||
mock: bool = False,
|
||||
) -> FastAPI:
|
||||
global _backend
|
||||
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(gpu_id))
|
||||
|
||||
_backend = make_text_backend(model_path, backend=backend, mock=mock)
|
||||
logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)
|
||||
|
||||
app = FastAPI(title="cf-text", version="0.1.0")
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
return {
|
||||
"status": "ok",
|
||||
"model": _backend.model_name,
|
||||
"vram_mb": _backend.vram_mb,
|
||||
}
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(req: GenerateRequest) -> GenerateResponse:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
result = await _backend.generate_async(
|
||||
req.prompt,
|
||||
max_tokens=req.max_tokens,
|
||||
temperature=req.temperature,
|
||||
stop=req.stop,
|
||||
)
|
||||
return GenerateResponse(
|
||||
text=result.text,
|
||||
tokens_used=result.tokens_used,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
@app.post("/chat")
|
||||
async def chat(req: ChatRequest) -> GenerateResponse:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
|
||||
# chat() is sync-only in the Protocol; run in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
partial(_backend.chat, messages,
|
||||
max_tokens=req.max_tokens, temperature=req.temperature),
|
||||
)
|
||||
return GenerateResponse(
|
||||
text=result.text,
|
||||
tokens_used=result.tokens_used,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def oai_chat_completions(req: OAIChatRequest) -> OAIChatResponse:
|
||||
"""OpenAI-compatible chat completions endpoint.
|
||||
|
||||
Allows LLMRouter (and any openai_compat client) to use cf-text
|
||||
without a custom backend type — just set base_url to this service's
|
||||
/v1 prefix.
|
||||
"""
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
|
||||
max_tok = req.max_tokens or 512
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature),
|
||||
)
|
||||
return OAIChatResponse(
|
||||
id=f"cftext-{uuid.uuid4().hex[:12]}",
|
||||
created=int(time.time()),
|
||||
model=result.model or req.model,
|
||||
choices=[OAIChoice(message=OAIMessageModel(role="assistant", content=result.text))],
|
||||
usage=OAIUsage(completion_tokens=result.tokens_used, total_tokens=result.tokens_used),
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ── CLI entrypoint ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="cf-text inference server")
|
||||
parser.add_argument("--model", default=os.environ.get("CF_TEXT_MODEL", "mock"),
|
||||
help="Path to GGUF file or HF model ID")
|
||||
parser.add_argument("--port", type=int, default=8006)
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--gpu-id", type=int, default=0,
|
||||
help="CUDA device index to use")
|
||||
parser.add_argument("--backend", choices=["llamacpp", "transformers"], default=None)
|
||||
parser.add_argument("--mock", action="store_true",
|
||||
help="Run in mock mode (no model or GPU needed)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s — %(message)s")
|
||||
args = _parse_args()
|
||||
mock = args.mock or os.environ.get("CF_TEXT_MOCK", "") == "1" or args.model == "mock"
|
||||
app = create_app(
|
||||
model_path=args.model,
|
||||
gpu_id=args.gpu_id,
|
||||
backend=args.backend,
|
||||
mock=mock,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
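Because of the /v1/chat/completions route, any OpenAI-compatible client can target this service. A sketch assuming the openai Python package and the default --port:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8006/v1", api_key="unused")
    resp = client.chat.completions.create(
        model="cf-text",
        messages=[{"role": "user", "content": "Say hi in five words."}],
    )
    print(resp.choices[0].message.content)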
10
circuitforge_core/text/backends/__init__.py
Normal file
@@ -0,0 +1,10 @@
from .base import ChatMessage, GenerateResult, TextBackend, make_text_backend
from .mock import MockTextBackend

__all__ = [
    "ChatMessage",
    "GenerateResult",
    "TextBackend",
    "MockTextBackend",
    "make_text_backend",
]
182
circuitforge_core/text/backends/base.py
Normal file
@@ -0,0 +1,182 @@
# circuitforge_core/text/backends/base.py — TextBackend Protocol + factory
#
# MIT licensed. The Protocol and mock backend are always importable.
# Real backends (LlamaCppBackend, TransformersBackend) require optional extras.
from __future__ import annotations

import os
from typing import AsyncIterator, Iterator, Protocol, runtime_checkable


# ── Shared result types ───────────────────────────────────────────────────────


class GenerateResult:
    """Result from a single non-streaming generate() call."""

    def __init__(self, text: str, tokens_used: int = 0, model: str = "") -> None:
        self.text = text
        self.tokens_used = tokens_used
        self.model = model

    def __repr__(self) -> str:
        return f"GenerateResult(text={self.text!r:.40}, tokens={self.tokens_used})"


class ChatMessage:
    """A single message in a chat conversation."""

    def __init__(self, role: str, content: str) -> None:
        if role not in ("system", "user", "assistant"):
            raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.")
        self.role = role
        self.content = content

    def to_dict(self) -> dict:
        return {"role": self.role, "content": self.content}


# ── TextBackend Protocol ──────────────────────────────────────────────────────


@runtime_checkable
class TextBackend(Protocol):
    """
    Abstract interface for direct text generation backends.

    All generate/chat methods have both sync and async variants.
    Streaming variants yield str tokens rather than a complete result.

    Implementations must be safe to construct once and call concurrently
    (the model is loaded at construction time and reused across calls).
    """

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Synchronous generate — blocks until the full response is produced."""
        ...

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        """Synchronous streaming — yields tokens as they are produced."""
        ...

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Async generate — runs in thread pool, never blocks the event loop."""
        ...

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Async streaming — yields tokens without blocking the event loop."""
        ...

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        """Chat completion — formats messages into a prompt and generates."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier for the loaded model (path stem or HF repo ID)."""
        ...

    @property
    def vram_mb(self) -> int:
        """Approximate VRAM footprint in MB. Used by cf-orch service registry."""
        ...


# ── Backend selection ─────────────────────────────────────────────────────────


def _select_backend(model_path: str, backend: str | None) -> str:
    """
    Return "llamacpp" or "transformers" for the given model path.

    Parameters
    ----------
    model_path  Path to the model file or HuggingFace repo ID (e.g. "Qwen/Qwen2.5-3B").
    backend     Explicit override from the caller ("llamacpp" | "transformers" | None).
                When provided, trust it without inspection.

    Return "llamacpp" or "transformers". Raise ValueError for unrecognised values.
    """
    _VALID = ("llamacpp", "transformers")

    # 1. Caller-supplied override — highest trust, no inspection needed.
    resolved = backend or os.environ.get("CF_TEXT_BACKEND")
    if resolved:
        if resolved not in _VALID:
            raise ValueError(
                f"CF_TEXT_BACKEND={resolved!r} is not valid. Choose: {', '.join(_VALID)}"
            )
        return resolved

    # 2. Format detection — GGUF files are unambiguously llama-cpp territory.
    if model_path.lower().endswith(".gguf"):
        return "llamacpp"

    # 3. Safe default — transformers covers HF repo IDs and safetensors dirs.
    return "transformers"
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_text_backend(
|
||||
model_path: str,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
) -> "TextBackend":
|
||||
"""
|
||||
Return a TextBackend for the given model.
|
||||
|
||||
mock=True or CF_TEXT_MOCK=1 → MockTextBackend (no GPU, no model file needed)
|
||||
Otherwise → backend resolved via _select_backend()
|
||||
"""
|
||||
use_mock = mock if mock is not None else os.environ.get("CF_TEXT_MOCK", "") == "1"
|
||||
if use_mock:
|
||||
from circuitforge_core.text.backends.mock import MockTextBackend
|
||||
return MockTextBackend(model_name=model_path)
|
||||
|
||||
resolved = _select_backend(model_path, backend)
|
||||
|
||||
if resolved == "llamacpp":
|
||||
from circuitforge_core.text.backends.llamacpp import LlamaCppBackend
|
||||
return LlamaCppBackend(model_path=model_path)
|
||||
|
||||
if resolved == "transformers":
|
||||
from circuitforge_core.text.backends.transformers import TransformersBackend
|
||||
return TransformersBackend(model_path=model_path)
|
||||
|
||||
raise ValueError(f"Unknown backend {resolved!r}. Expected 'llamacpp' or 'transformers'.")
|
||||
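To make the selection rules above concrete, a short usage sketch (the model paths are illustrative, not shipped fixtures):

    from circuitforge_core.text.backends import make_text_backend

    # .gguf suffix → LlamaCppBackend via _select_backend()
    llm = make_text_backend("/models/qwen2.5-3b-instruct-q4_k_m.gguf")

    # HF repo ID → TransformersBackend (the safe default)
    # llm = make_text_backend("Qwen/Qwen2.5-3B")

    # mock=True or CF_TEXT_MOCK=1 → MockTextBackend, importable everywhere
    mock = make_text_backend("anything", mock=True)
    print(mock.generate("ping").text)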
192
circuitforge_core/text/backends/llamacpp.py
Normal file

@@ -0,0 +1,192 @@
# circuitforge_core/text/backends/llamacpp.py — llama-cpp-python backend
#
# BSL 1.1: real inference. Requires llama-cpp-python + a GGUF model file.
# Install: pip install circuitforge-core[text-llamacpp]
#
# VRAM estimates (Q4_K_M quant):
#   1B → ~700MB    3B → ~2048MB    7B → ~4096MB
#   13B → ~7500MB  70B → ~40000MB
from __future__ import annotations

import asyncio
import logging
import os
from pathlib import Path
from typing import AsyncIterator, Iterator

from circuitforge_core.text.backends.base import ChatMessage, GenerateResult

logger = logging.getLogger(__name__)

# Q4_K_M is the recommended default — best accuracy/size tradeoff for local use.
_DEFAULT_N_CTX = int(os.environ.get("CF_TEXT_CTX", "4096"))
_DEFAULT_N_GPU_LAYERS = int(os.environ.get("CF_TEXT_GPU_LAYERS", "-1"))  # -1 = all layers


def _estimate_vram_mb(model_path: str) -> int:
    """Rough VRAM estimate from file size. Accurate enough for cf-orch budgeting."""
    try:
        size_mb = Path(model_path).stat().st_size // (1024 * 1024)
        # GGUF models typically need ~1.1× file size in VRAM (KV cache overhead)
        return int(size_mb * 1.1)
    except OSError:
        return 4096  # conservative default


class LlamaCppBackend:
    """
    Direct llama-cpp-python inference backend for GGUF models.

    The model is loaded once at construction. All inference runs in a thread
    pool executor so async callers never block the event loop.

    Context window, GPU layers, and thread count are configurable via env:
        CF_TEXT_CTX         token context window (default 4096)
        CF_TEXT_GPU_LAYERS  GPU layers to offload, -1 = all (default -1)
        CF_TEXT_THREADS     CPU thread count (default: auto)

    Requires: pip install circuitforge-core[text-llamacpp]
    """

    def __init__(self, model_path: str) -> None:
        try:
            from llama_cpp import Llama  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "llama-cpp-python is required for LlamaCppBackend. "
                "Install with: pip install circuitforge-core[text-llamacpp]"
            ) from exc

        if not Path(model_path).exists():
            raise FileNotFoundError(
                f"GGUF model not found: {model_path}\n"
                "Download a GGUF model and set CF_TEXT_MODEL to its path."
            )

        n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None
        logger.info(
            "Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
            model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
        )
        self._llm = Llama(
            model_path=model_path,
            n_ctx=_DEFAULT_N_CTX,
            n_gpu_layers=_DEFAULT_N_GPU_LAYERS,
            n_threads=n_threads,
            verbose=False,
        )
        self._model_path = model_path
        self._vram_mb = _estimate_vram_mb(model_path)

    @property
    def model_name(self) -> str:
        return Path(self._model_path).stem

    @property
    def vram_mb(self) -> int:
        return self._vram_mb

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        output = self._llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop or [],
            stream=False,
        )
        text = output["choices"][0]["text"]
        tokens_used = output["usage"]["completion_tokens"]
        return GenerateResult(text=text, tokens_used=tokens_used, model=self.model_name)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        for chunk in self._llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop or [],
            stream=True,
        ):
            yield chunk["choices"][0]["text"]

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop),
        )

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        # llama_cpp streaming is synchronous — run in executor and re-emit tokens
        import queue
        import threading

        token_queue: queue.Queue = queue.Queue()
        _DONE = object()

        def _produce() -> None:
            try:
                for chunk in self._llm(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop or [],
                    stream=True,
                ):
                    token_queue.put(chunk["choices"][0]["text"])
            finally:
                token_queue.put(_DONE)

        thread = threading.Thread(target=_produce, daemon=True)
        thread.start()

        loop = asyncio.get_running_loop()
        while True:
            token = await loop.run_in_executor(None, token_queue.get)
            if token is _DONE:
                break
            yield token

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        # llama-cpp-python has native chat completion for instruct models
        output = self._llm.create_chat_completion(
            messages=[m.to_dict() for m in messages],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        text = output["choices"][0]["message"]["content"]
        tokens_used = output["usage"]["completion_tokens"]
        return GenerateResult(text=text, tokens_used=tokens_used, model=self.model_name)
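The queue bridge in generate_stream_async lets an asyncio caller consume tokens as the synchronous llama_cpp loop produces them. A consumer sketch (the model path is illustrative):

    import asyncio
    from circuitforge_core.text.backends.llamacpp import LlamaCppBackend

    async def main() -> None:
        backend = LlamaCppBackend("/models/qwen2.5-3b-instruct-q4_k_m.gguf")
        async for token in backend.generate_stream_async("Explain Ohm's law:", max_tokens=64):
            print(token, end="", flush=True)

    asyncio.run(main())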
104
circuitforge_core/text/backends/mock.py
Normal file

@@ -0,0 +1,104 @@
# circuitforge_core/text/backends/mock.py — synthetic text backend
#
# MIT licensed. No model file, no GPU, no extras required.
# Used in dev, CI, and free-tier nodes below the minimum VRAM threshold.
from __future__ import annotations

import asyncio
from typing import AsyncIterator, Iterator

from circuitforge_core.text.backends.base import ChatMessage, GenerateResult

_MOCK_RESPONSE = (
    "This is a synthetic response from MockTextBackend. "
    "Install a real backend (llama-cpp-python or transformers) and provide a model path "
    "to generate real text."
)


class MockTextBackend:
    """
    Deterministic synthetic text backend for development and CI.

    Always returns the same fixed response so tests are reproducible without
    a GPU or model file. Streaming emits the response word-by-word with a
    configurable delay so UI streaming paths can be exercised.
    """

    def __init__(
        self,
        model_name: str = "mock",
        token_delay_s: float = 0.0,
    ) -> None:
        self._model_name = model_name
        self._token_delay_s = token_delay_s

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def vram_mb(self) -> int:
        return 0

    def _response_for(self, prompt_or_messages: str) -> str:
        return _MOCK_RESPONSE

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        text = self._response_for(prompt)
        return GenerateResult(text=text, tokens_used=len(text.split()), model=self._model_name)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        import time
        for word in self._response_for(prompt).split():
            yield word + " "
            if self._token_delay_s:
                time.sleep(self._token_delay_s)

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        return self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        for word in self._response_for(prompt).split():
            yield word + " "
            if self._token_delay_s:
                await asyncio.sleep(self._token_delay_s)

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        # Format messages into a simple prompt for the mock response
        prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
        return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
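Because the mock always returns the same fixed text, tests can assert on exact output with no GPU or model file. A pytest-style sketch:

    from circuitforge_core.text.backends import MockTextBackend

    def test_mock_is_deterministic() -> None:
        backend = MockTextBackend(model_name="mock")
        first = backend.generate("any prompt")
        second = backend.generate("another prompt")
        assert first.text == second.text
        assert backend.vram_mb == 0  # safe to schedule on VRAM-free nodes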
197
circuitforge_core/text/backends/transformers.py
Normal file

@@ -0,0 +1,197 @@
# circuitforge_core/text/backends/transformers.py — HuggingFace transformers backend
#
# BSL 1.1: real inference. Requires torch + transformers + a model checkpoint.
# Install: pip install circuitforge-core[text-transformers]
#
# Best for: HF repo IDs, safetensors checkpoints, models without GGUF versions.
# For GGUF models prefer LlamaCppBackend — lower overhead, smaller install.
from __future__ import annotations

import asyncio
import logging
import os
from typing import AsyncIterator, Iterator

from circuitforge_core.text.backends.base import ChatMessage, GenerateResult

logger = logging.getLogger(__name__)

_DEFAULT_MAX_NEW_TOKENS = 512
_LOAD_IN_4BIT = os.environ.get("CF_TEXT_4BIT", "0") == "1"
_LOAD_IN_8BIT = os.environ.get("CF_TEXT_8BIT", "0") == "1"


class TransformersBackend:
    """
    HuggingFace transformers inference backend.

    Loads any causal LM available on HuggingFace Hub or a local checkpoint dir.
    Supports 4-bit and 8-bit quantization via bitsandbytes when VRAM is limited:
        CF_TEXT_4BIT=1 — load_in_4bit (requires bitsandbytes)
        CF_TEXT_8BIT=1 — load_in_8bit (requires bitsandbytes)

    Chat completion uses the tokenizer's apply_chat_template() when available,
    falling back to a simple "User: / Assistant:" prompt format.

    Requires: pip install circuitforge-core[text-transformers]
    """

    def __init__(self, model_path: str) -> None:
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        except ImportError as exc:
            raise ImportError(
                "torch and transformers are required for TransformersBackend. "
                "Install with: pip install circuitforge-core[text-transformers]"
            ) from exc

        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Loading transformers model %s on %s", model_path, self._device)

        load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
        if _LOAD_IN_4BIT:
            load_kwargs["load_in_4bit"] = True
        elif _LOAD_IN_8BIT:
            load_kwargs["load_in_8bit"] = True

        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
        if self._device == "cpu":
            self._model = self._model.to("cpu")

        self._model_path = model_path
        self._TextIteratorStreamer = TextIteratorStreamer

    @property
    def model_name(self) -> str:
        # HF repo IDs contain "/" — use the part after the slash as a short name
        return self._model_path.split("/")[-1]

    @property
    def vram_mb(self) -> int:
        try:
            import torch
            if torch.cuda.is_available():
                return torch.cuda.memory_allocated() // (1024 * 1024)
        except Exception:
            pass
        return 0

    def _build_inputs(self, prompt: str):
        return self._tokenizer(prompt, return_tensors="pt").to(self._device)

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,  # accepted for interface parity; not applied by this backend
    ) -> GenerateResult:
        inputs = self._build_inputs(prompt)
        input_len = inputs["input_ids"].shape[1]
        outputs = self._model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=self._tokenizer.eos_token_id,
        )
        new_tokens = outputs[0][input_len:]
        text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)
        return GenerateResult(text=text, tokens_used=len(new_tokens), model=self.model_name)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        import threading

        inputs = self._build_inputs(prompt)
        streamer = self._TextIteratorStreamer(
            self._tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            streamer=streamer,
            pad_token_id=self._tokenizer.eos_token_id,
        )
        thread = threading.Thread(target=self._model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        yield from streamer

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop),
        )

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        import queue
        import threading

        token_queue: queue.Queue = queue.Queue()
        _DONE = object()

        def _produce() -> None:
            try:
                for token in self.generate_stream(
                    prompt, max_tokens=max_tokens, temperature=temperature
                ):
                    token_queue.put(token)
            finally:
                token_queue.put(_DONE)

        threading.Thread(target=_produce, daemon=True).start()
        loop = asyncio.get_running_loop()
        while True:
            token = await loop.run_in_executor(None, token_queue.get)
            if token is _DONE:
                break
            yield token

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        # Use the tokenizer's chat template when available (instruct models)
        if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template:
            prompt = self._tokenizer.apply_chat_template(
                [m.to_dict() for m in messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            prompt = "\n".join(
                f"{'User' if m.role == 'user' else 'Assistant'}: {m.content}"
                for m in messages
                if m.role != "system"
            ) + "\nAssistant:"

        return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
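A chat-path sketch exercising the template logic above (the repo ID is illustrative; instruct models take the apply_chat_template branch, bare LMs fall back to the User/Assistant format):

    from circuitforge_core.text.backends.base import ChatMessage
    from circuitforge_core.text.backends.transformers import TransformersBackend

    backend = TransformersBackend("Qwen/Qwen2.5-3B-Instruct")
    reply = backend.chat(
        [
            ChatMessage("system", "You are a terse electronics tutor."),
            ChatMessage("user", "What does a flyback diode do?"),
        ],
        max_tokens=128,
    )
    print(reply.text)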
87
circuitforge_core/tts/__init__.py
Normal file

@@ -0,0 +1,87 @@
"""
circuitforge_core.tts — Text-to-speech service module.

Quick start (mock mode — no GPU or model required):

    import os; os.environ["CF_TTS_MOCK"] = "1"
    from circuitforge_core.tts import synthesize

    result = synthesize("Hello world")
    open("out.ogg", "wb").write(result.audio_bytes)

Real inference (chatterbox-turbo):

    export CF_TTS_MODEL=/Library/Assets/LLM/chatterbox/hub/models--ResembleAI--chatterbox-turbo/snapshots/<hash>
    from circuitforge_core.tts import synthesize

cf-orch service profile:

    service_type: cf-tts
    max_mb: 768
    max_concurrent: 1
    shared: true
    managed:
      exec: python -m circuitforge_core.tts.app
      args: --model <path> --port {port} --gpu-id {gpu_id}
      port: 8005
      health: /health
"""
from __future__ import annotations

import os

from circuitforge_core.tts.backends.base import (
    AudioFormat,
    TTSBackend,
    TTSResult,
    make_tts_backend,
)
from circuitforge_core.tts.backends.mock import MockTTSBackend

_backend: TTSBackend | None = None


def _get_backend() -> TTSBackend:
    global _backend
    if _backend is None:
        model_path = os.environ.get("CF_TTS_MODEL", "mock")
        mock = model_path == "mock" or os.environ.get("CF_TTS_MOCK", "") == "1"
        _backend = make_tts_backend(model_path, mock=mock)
    return _backend


def synthesize(
    text: str,
    *,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    audio_prompt: bytes | None = None,
    format: AudioFormat = "ogg",
) -> TTSResult:
    """Synthesize speech from text using the process-level backend."""
    return _get_backend().synthesize(
        text,
        exaggeration=exaggeration,
        cfg_weight=cfg_weight,
        temperature=temperature,
        audio_prompt=audio_prompt,
        format=format,
    )


def reset_backend() -> None:
    """Reset the process-level singleton. Test teardown only."""
    global _backend
    _backend = None


__all__ = [
    "AudioFormat",
    "TTSBackend",
    "TTSResult",
    "MockTTSBackend",
    "make_tts_backend",
    "synthesize",
    "reset_backend",
]
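Since _get_backend() caches a process-level singleton, tests that flip CF_TTS_MOCK need to reset it between cases. A pytest fixture sketch:

    import pytest
    from circuitforge_core import tts

    @pytest.fixture(autouse=True)
    def tts_mock_mode(monkeypatch):
        monkeypatch.setenv("CF_TTS_MOCK", "1")
        yield
        tts.reset_backend()  # drop the singleton so the next test re-resolves it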
103
circuitforge_core/tts/app.py
Normal file

@@ -0,0 +1,103 @@
"""
cf-tts FastAPI service — managed by cf-orch.

Endpoints:
    GET  /health      → {"status": "ok", "model": str, "vram_mb": int}
    POST /synthesize  → audio bytes (Content-Type: audio/ogg or audio/wav or audio/mpeg)

Usage:
    python -m circuitforge_core.tts.app \
        --model /Library/Assets/LLM/chatterbox/hub/models--ResembleAI--chatterbox-turbo/snapshots/<hash> \
        --port 8005 \
        --gpu-id 0
"""
from __future__ import annotations

import argparse
import os
from typing import Annotated

from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import Response

from circuitforge_core.tts.backends.base import AudioFormat, TTSBackend, make_tts_backend

_CONTENT_TYPES: dict[str, str] = {
    "ogg": "audio/ogg",
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
}

app = FastAPI(title="cf-tts")
_backend: TTSBackend | None = None


@app.get("/health")
def health() -> dict:
    if _backend is None:
        raise HTTPException(503, detail="backend not initialised")
    return {"status": "ok", "model": _backend.model_name, "vram_mb": _backend.vram_mb}


@app.post("/synthesize")
async def synthesize(
    text: Annotated[str, Form()],
    format: Annotated[AudioFormat, Form()] = "ogg",
    exaggeration: Annotated[float, Form()] = 0.5,
    cfg_weight: Annotated[float, Form()] = 0.5,
    temperature: Annotated[float, Form()] = 0.8,
    audio_prompt: UploadFile | None = None,
) -> Response:
    if _backend is None:
        raise HTTPException(503, detail="backend not initialised")
    if not text.strip():
        raise HTTPException(422, detail="text must not be empty")

    prompt_bytes: bytes | None = None
    if audio_prompt is not None:
        prompt_bytes = await audio_prompt.read()

    result = _backend.synthesize(
        text,
        exaggeration=exaggeration,
        cfg_weight=cfg_weight,
        temperature=temperature,
        audio_prompt=prompt_bytes,
        format=format,
    )
    return Response(
        content=result.audio_bytes,
        media_type=_CONTENT_TYPES.get(result.format, "audio/ogg"),
        headers={
            "X-Duration-S": str(round(result.duration_s, 3)),
            "X-Model": result.model,
            "X-Sample-Rate": str(result.sample_rate),
        },
    )


def _parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="cf-tts service")
    p.add_argument("--model", required=True)
    p.add_argument("--port", type=int, default=8005)
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--gpu-id", type=int, default=0)
    p.add_argument("--mock", action="store_true")
    return p.parse_args()


if __name__ == "__main__":
    import uvicorn

    args = _parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    mock = args.mock or args.model == "mock"
    device = "cpu" if mock else "cuda"

    # Module-level assignment already rebinds the global; no `global` statement needed here.
    _backend = make_tts_backend(args.model, mock=mock, device=device)
    print(f"cf-tts backend ready: {_backend.model_name} ({_backend.vram_mb} MB)")

    uvicorn.run(app, host=args.host, port=args.port)
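A client-side sketch against the /synthesize endpoint above (it uses the third-party requests package, which is not a dependency shown in this diff):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8005/synthesize",
        data={"text": "Hello from cf-tts", "format": "wav"},
        timeout=120,
    )
    resp.raise_for_status()
    open("out.wav", "wb").write(resp.content)
    print(resp.headers["X-Duration-S"], resp.headers["X-Sample-Rate"])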
4
circuitforge_core/tts/backends/__init__.py
Normal file

@@ -0,0 +1,4 @@
from .base import AudioFormat, TTSBackend, TTSResult, make_tts_backend
from .mock import MockTTSBackend

__all__ = ["AudioFormat", "TTSBackend", "TTSResult", "make_tts_backend", "MockTTSBackend"]
84
circuitforge_core/tts/backends/base.py
Normal file

@@ -0,0 +1,84 @@
"""
TTSBackend Protocol — backend-agnostic TTS interface.

All backends return TTSResult with audio bytes in the requested format.
Supported formats: ogg (default, smallest), wav (uncompressed, always works), mp3.
"""
from __future__ import annotations

import io
from dataclasses import dataclass
from typing import Literal, Protocol, runtime_checkable

AudioFormat = Literal["ogg", "wav", "mp3"]


@dataclass(frozen=True)
class TTSResult:
    audio_bytes: bytes
    sample_rate: int
    duration_s: float
    format: AudioFormat = "ogg"
    model: str = ""


@runtime_checkable
class TTSBackend(Protocol):
    def synthesize(
        self,
        text: str,
        *,
        exaggeration: float = 0.5,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        audio_prompt: bytes | None = None,
        format: AudioFormat = "ogg",
    ) -> TTSResult: ...

    @property
    def model_name(self) -> str: ...

    @property
    def vram_mb(self) -> int: ...


def _encode_audio(
    wav_tensor,  # torch.Tensor of shape [1, T] or [T]
    sample_rate: int,
    format: AudioFormat,
) -> bytes:
    """Convert a torch tensor to audio bytes in the requested format."""
    import torch
    import torchaudio

    wav = wav_tensor
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    wav = wav.to(torch.float32).cpu()

    buf = io.BytesIO()
    if format == "wav":
        torchaudio.save(buf, wav, sample_rate, format="wav")
    elif format == "ogg":
        torchaudio.save(buf, wav, sample_rate, format="ogg", encoding="vorbis")
    elif format == "mp3":
        # torchaudio MP3 encode requires the ffmpeg backend; fall back to wav on failure
        try:
            torchaudio.save(buf, wav, sample_rate, format="mp3")
        except Exception:
            buf = io.BytesIO()
            torchaudio.save(buf, wav, sample_rate, format="wav")
    return buf.getvalue()


def make_tts_backend(
    model_path: str,
    *,
    mock: bool = False,
    device: str = "cuda",
) -> TTSBackend:
    if mock:
        from circuitforge_core.tts.backends.mock import MockTTSBackend
        return MockTTSBackend()
    from circuitforge_core.tts.backends.chatterbox import ChatterboxTurboBackend
    return ChatterboxTurboBackend(model_path=model_path, device=device)
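A smoke-test sketch for the factory and result shape, using the mock backend so no GPU or torchaudio codecs are exercised (this assumes the mock returns non-empty synthetic audio, which its diff does not show here):

    from circuitforge_core.tts.backends import make_tts_backend

    backend = make_tts_backend("mock", mock=True)
    result = backend.synthesize("Smoke test", format="wav")
    assert result.audio_bytes and result.sample_rate > 0  # assumption: mock emits real bytes
    with open("/tmp/smoke.wav", "wb") as f:
        f.write(result.audio_bytes)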
82
circuitforge_core/tts/backends/chatterbox.py
Normal file

@@ -0,0 +1,82 @@
"""ChatterboxTurboBackend — ResembleAI chatterbox-turbo TTS via the chatterbox-tts package."""
from __future__ import annotations

import os
import tempfile

from circuitforge_core.tts.backends.base import (
    AudioFormat,
    TTSBackend,
    TTSResult,
    _encode_audio,
)

_VRAM_MB = 768  # conservative estimate for chatterbox-turbo weights


class ChatterboxTurboBackend:
    def __init__(self, model_path: str, device: str = "cuda") -> None:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
        from chatterbox.models.s3gen import S3GEN_SR
        from chatterbox.tts import ChatterboxTTS

        self._sr = S3GEN_SR
        self._device = device
        self._model = ChatterboxTTS.from_local(model_path, device=device)
        self._model_path = model_path

    @property
    def model_name(self) -> str:
        return f"chatterbox-turbo@{os.path.basename(self._model_path)}"

    @property
    def vram_mb(self) -> int:
        return _VRAM_MB

    def synthesize(
        self,
        text: str,
        *,
        exaggeration: float = 0.5,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        audio_prompt: bytes | None = None,
        format: AudioFormat = "ogg",
    ) -> TTSResult:
        audio_prompt_path: str | None = None
        _tmp = None

        if audio_prompt is not None:
            # ChatterboxTTS takes a file path, so spill the prompt bytes to a temp wav
            _tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            _tmp.write(audio_prompt)
            _tmp.flush()
            audio_prompt_path = _tmp.name

        try:
            wav = self._model.generate(
                text,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
                temperature=temperature,
                audio_prompt_path=audio_prompt_path,
            )
        finally:
            if _tmp is not None:
                _tmp.close()
                os.unlink(_tmp.name)

        duration_s = wav.shape[-1] / self._sr
        audio_bytes = _encode_audio(wav, self._sr, format)
        return TTSResult(
            audio_bytes=audio_bytes,
            sample_rate=self._sr,
            duration_s=duration_s,
            format=format,
            model=self.model_name,
        )


assert isinstance(
    ChatterboxTurboBackend.__new__(ChatterboxTurboBackend), TTSBackend
), "ChatterboxTurboBackend must satisfy TTSBackend Protocol"
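Voice cloning rides on the audio_prompt path above. A sketch (the model snapshot path and reference wav are hypothetical local files):

    from circuitforge_core.tts.backends.chatterbox import ChatterboxTurboBackend

    backend = ChatterboxTurboBackend("/path/to/chatterbox-turbo/snapshot", device="cuda")
    with open("reference_voice.wav", "rb") as f:
        prompt = f.read()
    result = backend.synthesize(
        "This line should sound like the reference speaker.",
        audio_prompt=prompt,
        format="ogg",
    )
    open("cloned.ogg", "wb").write(result.audio_bytes)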
Some files were not shown because too many files have changed in this diff.