Compare commits
No commits in common. "main" and "feature/license-validation" have entirely different histories.
main
...
feature/li
199 changed files with 121 additions and 19017 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -11,4 +11,3 @@ build/
|
|||
|
||||
# cf-orch private profiles (commit on personal/heimdall branch only)
|
||||
circuitforge_core/resources/profiles/private/
|
||||
.worktrees/
|
||||
|
|
|
|||
170
CHANGELOG.md
170
CHANGELOG.md
|
|
@ -6,176 +6,6 @@ Versions follow [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|||
|
||||
---
|
||||
|
||||
## [0.20.0] — 2026-05-05
|
||||
|
||||
### Fixed / Enhanced
|
||||
|
||||
**`circuitforge_core.llm.LLMRouter`** — Pagepiper-driven improvements (closes #59, #60)
|
||||
|
||||
- **#59 — dict init** (`LLMRouter(config_path: Path | dict)`): `__init__` now accepts an inline config dict in addition to a `Path`. Ingest scripts that construct Ollama URLs from product-specific env vars (e.g. `PAGEPIPER_OLLAMA_URL`) can pass the dict directly without writing a temp file. Passing a dict previously raised `AttributeError: 'dict' object has no attribute 'exists'`. Tests: `test_init_accepts_inline_dict`, `test_init_dict_is_used_directly`.
|
||||
|
||||
- **#60 — Ollama preflight** (`_check_ollama_model_pulled()`): Before the first `embed()` call on an Ollama backend, `GET /api/tags` is checked to verify the configured embedding model is pulled. If it is not, a `RuntimeError` with an actionable `ollama pull <model>` hint is raised immediately — replacing the opaque `All LLM backends exhausted for embed()` error. Results are cached per base URL for the router's lifetime (one HTTP call, not one per `embed()` invocation). Non-Ollama backends (vLLM, etc.) don't expose `/api/tags` — a non-200 response causes the check to be silently skipped. Tests: `test_embed_raises_actionable_error_when_model_not_pulled`, `test_embed_proceeds_when_model_is_pulled`, `test_embed_skips_preflight_when_tags_endpoint_unavailable`, `test_ollama_tags_cache_is_hit_only_once`.
|
||||
|
||||
---
|
||||
|
||||
## [0.17.0] — 2026-04-27
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.reranker`** — shared reranker module for RAG pipelines across the orchard (MIT, closes #54)
|
||||
|
||||
Five adapters covering local and cloud paths:
|
||||
|
||||
- `adapters/bge.py` — `BGETextReranker`: FlagEmbedding cross-encoder (`BAAI/bge-reranker-*`). Batches all pairs in a single `compute_score()` call via `rerank_batch()`. Thread-safe with internal lock. Free tier.
|
||||
- `adapters/qwen3.py` — `Qwen3TextReranker`: generative reranker using `AutoModelForCausalLM`. Scores by reading yes/no token logits at the last input position after pre-filling the assistant `<think>\n\n</think>` block — one forward pass per batch, no generation loop. Left-pads for consistent last-token position across batch. Free / Paid tier.
|
||||
- `adapters/cross_encoder.py` — `CrossEncoderTextReranker`: sentence-transformers `CrossEncoder`. Broader model coverage: `mxbai-rerank-*`, `ms-marco-MiniLM-*`, `jina-reranker-*`. Free tier.
|
||||
- `adapters/cohere.py` — `CohereTextReranker`: Cohere Rerank API (BYOK cloud path). Reads `COHERE_API_KEY` from env or explicit `api_key=` arg. Restores original candidate order from Cohere's score-sorted response. Paid / BYOK.
|
||||
- `adapters/remote.py` — `RemoteTextReranker`: HTTP delegate to a cf-reranker service endpoint. `from_cf_orch()` classmethod allocates via cf-orch on demand. `release()` method returns the lease.
|
||||
- `adapters/mock.py` — `MockTextReranker`: Jaccard-similarity scorer, no model required. Used in tests and `CF_RERANKER_MOCK=1` mode.
|
||||
|
||||
`app.py` — `cf-reranker` FastAPI service (port 8011). Managed by cf-orch as a process-type service. Exposes `GET /health` and `POST /rerank`. Defaults to `Qwen3-Reranker-0.6B`.
|
||||
|
||||
**Auto cf-orch routing:** `make_reranker()` checks `CF_ORCH_URL` at construction time. When set (cloud deployments), it automatically allocates a `cf-reranker` service via cf-orch and returns a `RemoteTextReranker` — no code changes needed in Kiwi, Peregrine, or Snipe. Local dev (no `CF_ORCH_URL`) falls back to local BGE inference.
|
||||
|
||||
**Public API:**
|
||||
- `rerank(query, candidates, top_n)` — process-level singleton, mock-safe
|
||||
- `make_reranker(model_id, backend, mock)` — explicit instance
|
||||
- `reset_reranker()` — test teardown only
|
||||
- `RerankResult(candidate, score, rank)` — frozen dataclass result type
|
||||
|
||||
**`pyproject.toml` extras:** `reranker-bge`, `reranker-qwen3`, `reranker-cross-encoder`, `reranker-cohere`, `reranker-service`
|
||||
|
||||
54 tests across all adapters.
|
||||
|
||||
---
|
||||
|
||||
## [0.14.0] — 2026-04-20
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.activitypub`** — ActivityPub actor management, object construction, HTTP Signature signing, delivery, and Lemmy integration (MIT, closes #51)
|
||||
|
||||
- `actor.py` — `CFActor` frozen dataclass; `generate_rsa_keypair(bits)`; `make_actor()`; `load_actor_from_key_file()`. `to_ap_dict()` produces an ActivityPub Application/Person object and never includes the private key.
|
||||
- `objects.py` — `make_note()`, `make_offer()`, `make_request()` (CF namespace extension), `make_create()`. All return plain dicts; IDs minted with UUID4. `make_request` uses `https://circuitforge.tech/ns/activitystreams` context extension for the non-AS2 Request type.
|
||||
- `signing.py` — `sign_headers()` (draft-cavage-http-signatures-08, rsa-sha256; signs `(request-target)`, `host`, `date`, `digest`, `content-type`). `verify_signature()` re-computes Digest from actual body after signature verification to catch body-swap attacks.
|
||||
- `delivery.py` — `deliver_activity(activity, inbox_url, actor)` — synchronous `requests.post` with signed headers and `Content-Type: application/activity+json`.
|
||||
- `lemmy.py` — `LemmyConfig` frozen dataclass; `LemmyClient` with `login()`, `resolve_community()` (bare name or `!community@instance` address), `post_to_community()`. Uses Lemmy v0.19+ REST API (JWT auth). `LemmyAuthError` / `LemmyCommunityNotFound` exceptions.
|
||||
- `inbox.py` — `make_inbox_router(handlers, verify_key_fetcher, path)` — FastAPI APIRouter stub; dispatches by activity type; optional HTTP Signature verification via async `verify_key_fetcher` callback. FastAPI imported at module level with `_FASTAPI_AVAILABLE` guard (avoids annotation-resolution bug with lazy string annotations).
|
||||
- 105 tests across all six files.
|
||||
|
||||
**Key design notes:**
|
||||
- `inbox` not re-exported from `__init__` — requires fastapi, imported explicitly by products that need it
|
||||
- Signing Digest + re-verifying digest against body on verify — prevents body-swap attacks even when signature is valid
|
||||
- `from __future__ import annotations` intentionally omitted in `inbox.py` — FastAPI resolves `Request` annotation against module globals at route registration time
|
||||
|
||||
---
|
||||
|
||||
## [0.13.0] — 2026-04-20
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.preferences.currency`** — per-user currency code preference + formatting utility (MIT, closes #52)
|
||||
|
||||
- `PREF_CURRENCY_CODE = "currency.code"` — shared store key; all products read from the same path
|
||||
- `get_currency_code(user_id, store)` — priority fallback: store → `CURRENCY_DEFAULT` env var → `"USD"`
|
||||
- `set_currency_code(currency_code, user_id, store)` — persists ISO 4217 code, uppercased
|
||||
- `format_currency(amount, currency_code, locale="en_US")` — uses `babel.numbers.format_currency` when available; falls back to a built-in 30-currency symbol table (no hard babel dependency)
|
||||
- Symbol table covers: USD, CAD, AUD, NZD, GBP, EUR, CHF, SEK/NOK/DKK, JPY, CNY, KRW, INR, BRL, MXN, ZAR, SGD, HKD, THB, PLN, CZK, HUF, RUB, TRY, ILS, AED, SAR, CLP, COP, ARS, VND, IDR, MYR, PHP
|
||||
- JPY/KRW/HUF/CLP/COP/VND/IDR format with 0 decimal places per ISO 4217 minor-unit convention
|
||||
- Exported from `circuitforge_core.preferences` as `currency` submodule
|
||||
- 30 tests (preference store, env var fallback, format dispatch, symbol table, edge cases)
|
||||
|
||||
---
|
||||
|
||||
## [0.12.0] — 2026-04-20
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.job_quality`** — deterministic trust scorer for job listings (MIT, closes #48)
|
||||
|
||||
Pure signal processing module. No LLM calls, no network calls, no file I/O. Fully auditable and independently unit-testable per signal.
|
||||
|
||||
- `models.py` — `JobListing`, `JobEnrichment`, `SignalResult`, `JobQualityScore` (Pydantic)
|
||||
- `signals.py` — 12 signal functions with weights: `listing_age` (0.25), `repost_detected` (0.25), `no_salary_transparency` (0.20), `always_open_pattern` (0.20), `staffing_agency` (0.15), `requirement_overload` (0.12), `layoff_news` (0.12), `jd_vagueness` (0.10), `ats_blackhole` (0.10), `high_applicant_count` (0.08), `poor_response_history` (0.08), `weekend_posted` (0.04)
|
||||
- `scorer.py` — `score_job(listing, enrichment=None) -> JobQualityScore`; trust_score = 1 − clamp(sum(triggered weights), 0, 1); confidence = fraction of signals with available evidence
|
||||
- Salary transparency enforcement for CO, CA, NY, WA, IL, MA; ATS blackhole detection (Lever, Greenhouse, Workday, iCIMS, Taleo)
|
||||
- `ALL_SIGNALS` registry for iteration and extension
|
||||
- 83 tests across models, signals (all 12 individually), and scorer — 100% pass
|
||||
|
||||
---
|
||||
|
||||
## [0.11.0] — 2026-04-20
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.audio`** — shared PCM and audio signal utilities (MIT, numpy-only, closes #50)
|
||||
|
||||
Pure signal processing module. No model weights, no HuggingFace, no torch dependency.
|
||||
|
||||
- `convert.py` — `pcm_to_float32`, `float32_to_pcm`, `bytes_to_float32` (int16 ↔ float32 with correct int16 asymmetry handling)
|
||||
- `gate.py` — `is_silent`, `rms` (RMS energy gate; default 0.005 threshold extracted from cf-voice)
|
||||
- `resample.py` — `resample` (scipy `resample_poly` when available; numpy linear interpolation fallback)
|
||||
- `buffer.py` — `ChunkAccumulator` (window-based chunk collector with `flush`, `reset`, bounds enforcement)
|
||||
- Replaces hand-rolled equivalents in cf-voice `stt.py` + `context.py`. Also consumed by Sparrow and Linnet.
|
||||
|
||||
**`circuitforge_core.musicgen` tests** — 21 tests covering mock backend, factory, and FastAPI app endpoints (closes #49). Module was already implemented; tests were the missing deliverable.
|
||||
|
||||
### Fixed
|
||||
|
||||
**SQLCipher PRAGMA injection** (closes #45) — `db/base.py` now uses `PRAGMA key=?` parameterized form instead of f-string interpolation. Regression tests added (skipped gracefully when `pysqlcipher3` is not installed).
|
||||
|
||||
**`circuitforge_core.text.app`** — early validation on empty `--model` argument: raises `ValueError` with a clear message before reaching the HuggingFace loader. Prevents the cryptic `HFValidationError` surfaced by cf-orch #46 when no model candidates were provided.
|
||||
|
||||
---
|
||||
|
||||
## [0.10.0] — 2026-04-12
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.community`** — shared community signal module (BSL 1.1, closes #44)
|
||||
|
||||
Provides the PostgreSQL-backed infrastructure for the cross-product community fine-tuning signal pipeline. Products write signals; the training pipeline reads them.
|
||||
|
||||
- `CommunityDB` — psycopg2 connection pool with `run_migrations()`. Picks up all `.sql` files from `circuitforge_core/community/migrations/` in filename order. Safe to call on every startup (idempotent `CREATE TABLE IF NOT EXISTS`).
|
||||
- `CommunityPost` — frozen dataclass capturing a user-authored community post with a snapshot of the originating product item (`element_snapshot` as a tuple of key-value pairs for immutability).
|
||||
- `SharedStore` — base class for product-specific community stores. Provides typed `pg_read()` and `pg_write()` helpers that products subclass without re-implementing connection management.
|
||||
- Migration 001: `community_posts` schema (id, product, item_id, pseudonym, title, body, element_snapshot JSONB, created_at).
|
||||
- Migration 002: `community_reactions` stub (post_id FK, pseudonym, reaction_type, created_at).
|
||||
- `psycopg2-binary` added to `[community]` optional extras in `pyproject.toml`.
|
||||
- All community classes exported from `circuitforge_core.community`.
|
||||
|
||||
---
|
||||
|
||||
## [0.9.0] — 2026-04-10
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.text`** — OpenAI-compatible `/v1/chat/completions` endpoint and pipeline crystallization engine.
|
||||
|
||||
**`circuitforge_core.pipeline`** — multimodal pipeline with staged output crystallization. Products queue draft outputs for human review before committing.
|
||||
|
||||
**`circuitforge_core.stt`** — speech-to-text module. `FasterWhisperBackend` for local transcription via `faster-whisper`. Managed FastAPI app mountable in any product.
|
||||
|
||||
**`circuitforge_core.tts`** — text-to-speech module. `ChatterboxTurbo` backend for local synthesis. Managed FastAPI app.
|
||||
|
||||
**Accessibility preferences** — `preferences` module extended with structured accessibility fields (motion reduction, high contrast, font size, focus highlight) under `accessibility.*` key path.
|
||||
|
||||
**LLM output corrections router** — `make_corrections_router()` for collecting LLM output corrections in any product. Stores corrections in product SQLite for future fine-tuning.
|
||||
|
||||
---
|
||||
|
||||
## [0.8.0] — 2026-04-08
|
||||
|
||||
### Added
|
||||
|
||||
**`circuitforge_core.vision`** — cf-vision managed service shim. Routes vision inference requests to a local cf-vision worker (moondream2 / SigLIP). Closes #43.
|
||||
|
||||
**`circuitforge_core.api.feedback`** — `make_feedback_router()` shared Forgejo issue-filing router. Products mount it under `/api/feedback`; requires `FORGEJO_API_TOKEN`. Closes #30.
|
||||
|
||||
**License validation** — `CF_LICENSE_KEY` validation via Heimdall REST API. Products call `validate_license(key, product)` to gate premium features. Closes #26.
|
||||
|
||||
---
|
||||
|
||||
## [0.7.0] — 2026-04-04
|
||||
|
||||
### Added
|
||||
|
|
|
|||
21
LICENSE
21
LICENSE
|
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2026 CircuitForge LLC
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
193
README.md
193
README.md
|
|
@ -1,186 +1,37 @@
|
|||
<p align="center">
|
||||
<img src="docs/cf-logo.png" alt="CircuitForge logo" width="120" />
|
||||
</p>
|
||||
# circuitforge-core
|
||||
|
||||
<h1 align="center">circuitforge-core</h1>
|
||||
Shared scaffold for CircuitForge products.
|
||||
|
||||
<p align="center">Shared Python scaffold for privacy-first, self-hosted AI tools</p>
|
||||
**Current version: 0.7.0**
|
||||
|
||||
<p align="center">
|
||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-green.svg" alt="MIT License" /></a>
|
||||
<img src="https://img.shields.io/badge/version-0.20.0-blue.svg" alt="v0.20.0" />
|
||||
<img src="https://img.shields.io/badge/python-3.11%2B-blue.svg" alt="Python 3.11+" />
|
||||
<a href="https://git.opensourcesolarpunk.com/Circuit-Forge/circuitforge-core"><img src="https://img.shields.io/badge/repo-Forgejo-orange.svg" alt="Forgejo" /></a>
|
||||
</p>
|
||||
## Modules
|
||||
|
||||
---
|
||||
### Implemented
|
||||
|
||||
## Why circuitforge-core?
|
||||
- `circuitforge_core.db` — SQLite connection factory and migration runner
|
||||
- `circuitforge_core.llm` — LLM router with fallback chain (Ollama, vLLM, Anthropic, OpenAI-compatible)
|
||||
- `circuitforge_core.tiers` — Tier system with BYOK and local vision unlocks
|
||||
- `circuitforge_core.config` — Env validation and .env loader
|
||||
- `circuitforge_core.hardware` — Hardware detection and LLM backend profile generation (VRAM tiers, GPU/CPU auto-select)
|
||||
- `circuitforge_core.documents` — Document ingestion pipeline: PDF, DOCX, and image OCR → `StructuredDocument`
|
||||
- `circuitforge_core.affiliates` — Affiliate URL wrapping with opt-out, BYOK user IDs, and CF env-var fallback (`wrap_url`)
|
||||
- `circuitforge_core.preferences` — User preference store (local YAML file, pluggable backend); dot-path get/set API
|
||||
- `circuitforge_core.tasks` — VRAM-aware LLM task scheduler; shared slot manager across services (`TaskScheduler`)
|
||||
- `circuitforge_core.manage` — Cross-platform product process manager (Docker and native modes)
|
||||
- `circuitforge_core.resources` — Resource coordinator and agent: VRAM allocation, eviction engine, GPU profile registry
|
||||
|
||||
- **Local inference first.** The LLM router defaults to Ollama on localhost. Cloud APIs are a configurable fallback, not the default path. No telemetry, no round-trips you didn't ask for.
|
||||
- **VRAM-aware scheduling.** The task scheduler and resource coordinator track GPU memory across concurrent services, allocate slots before loading models, and evict backends gracefully when VRAM is scarce.
|
||||
- **Consistent tier system across products.** One `tiers` module handles Free / Paid / Premium / Ultra tiers, BYOK (bring your own key) unlocks, and local-vision capability gates — the same way in every product.
|
||||
- **Uniform developer experience.** DB migrations, config validation, document ingestion, process management, and preference storage all share a single, tested implementation. Products extend, not reimplement.
|
||||
### Stubs (in-tree, not yet implemented)
|
||||
|
||||
---
|
||||
- `circuitforge_core.vision` — Vision router base class (planned: moondream2 / Claude vision dispatch)
|
||||
- `circuitforge_core.wizard` — First-run wizard base class (products subclass `BaseWizard`)
|
||||
- `circuitforge_core.pipeline` — Staging queue base (`StagingDB`; products provide concrete schema)
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install circuitforge-core
|
||||
|
||||
# Editable install from source (recommended for product development)
|
||||
pip install -e /path/to/circuitforge-core
|
||||
|
||||
# With optional extras
|
||||
pip install circuitforge-core[pdf] # PDF/DOCX/OCR document ingestion
|
||||
pip install circuitforge-core[vector] # SQLite-vec vector store
|
||||
pip install circuitforge-core[text-transformers] # Local transformer inference (cf-text)
|
||||
pip install circuitforge-core[stt-faster-whisper] # Speech-to-text via Faster Whisper
|
||||
pip install circuitforge-core[tts-chatterbox] # Text-to-speech via Chatterbox
|
||||
pip install circuitforge-core[reranker-qwen3] # Reranking via Qwen3
|
||||
pip install circuitforge-core[community] # PostgreSQL-backed community store
|
||||
pip install circuitforge-core[manage] # cf-manage CLI (Typer)
|
||||
pip install circuitforge-core[dev] # All dev dependencies
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Modules
|
||||
|
||||
| Module | Status | Description |
|
||||
|---|---|---|
|
||||
| `db` | Implemented | SQLite connection factory and migration runner |
|
||||
| `llm` | Implemented | LLM router with priority fallback chain (Ollama, vLLM, Anthropic, OpenAI-compatible) |
|
||||
| `tiers` | Implemented | Tier system with BYOK and local-vision unlocks (Free / Paid / Premium / Ultra) |
|
||||
| `config` | Implemented | Env validation and `.env` loader with startup fail-fast |
|
||||
| `hardware` | Implemented | GPU/CPU detection, VRAM profiling, backend profile generation |
|
||||
| `documents` | Implemented | PDF, DOCX, and image OCR ingestion into `StructuredDocument` |
|
||||
| `affiliates` | Implemented | Affiliate URL wrapping with per-user opt-out and env-var fallback |
|
||||
| `preferences` | Implemented | User preference store — local YAML with pluggable backend; dot-path get/set |
|
||||
| `tasks` | Implemented | VRAM-aware LLM task scheduler; shared slot manager across services |
|
||||
| `manage` | Implemented | Cross-platform product process manager (Docker and native modes) |
|
||||
| `resources` | Implemented | VRAM allocation, eviction engine, GPU profile registry |
|
||||
| `text` | Implemented | Text processing utilities and local transformer inference service |
|
||||
| `activitypub` | Implemented | ActivityPub actor, inbox, delivery, and Lemmy federation primitives |
|
||||
| `audio` | Implemented | Audio buffer, format conversion, resampling, and VAD (voice activity detection) gate |
|
||||
| `stt` | Implemented | Speech-to-text service (Faster Whisper backend) |
|
||||
| `tts` | Implemented | Text-to-speech service (Chatterbox backend) |
|
||||
| `musicgen` | Implemented | Music generation service (AudioCraft/MusicGen backend) |
|
||||
| `reranker` | Implemented | Result reranking — BGE, Qwen3, cross-encoder, and Cohere adapters |
|
||||
| `vector` | Implemented | SQLite-vec vector store with pluggable embedding backend |
|
||||
| `api` | Implemented | Shared API helpers — corrections and feedback endpoints |
|
||||
| `community` | Implemented | Community feed and social store (PostgreSQL-backed) |
|
||||
| `platforms` | Implemented | Platform-specific integrations (eBay) |
|
||||
| `cloud_session` | Implemented | Cloud session management primitives |
|
||||
| `input` | Implemented | Input handling — MediaPipe gesture recognition |
|
||||
| `job_quality` | Implemented | Job listing quality scoring and signal extraction |
|
||||
| `vision` | Stub | Vision router (moondream2 / SigLIP dispatch — planned) |
|
||||
| `wizard` | Stub | First-run wizard base class — products subclass `BaseWizard` |
|
||||
| `pipeline` | Stub | Staging queue base — products provide concrete schema |
|
||||
|
||||
---
|
||||
|
||||
## Usage: LLM Router
|
||||
|
||||
The LLM router reads a config file at `~/.config/circuitforge/llm.yaml`, tries each backend in fallback order, and skips unreachable or disabled entries transparently.
|
||||
|
||||
```python
|
||||
from circuitforge_core.llm import LLMRouter
|
||||
|
||||
# Auto-detects from env vars when llm.yaml is absent:
|
||||
# ANTHROPIC_API_KEY, OPENAI_API_KEY / OPENAI_BASE_URL, OLLAMA_HOST
|
||||
router = LLMRouter()
|
||||
|
||||
response = router.complete(
|
||||
messages=[{"role": "user", "content": "Summarize this in one sentence."}],
|
||||
system="You are a concise assistant.",
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
**Example `llm.yaml`** (Ollama local, Anthropic cloud fallback):
|
||||
|
||||
```yaml
|
||||
fallback_order:
|
||||
- ollama
|
||||
- anthropic
|
||||
|
||||
backends:
|
||||
ollama:
|
||||
type: openai_compat
|
||||
enabled: true
|
||||
base_url: http://localhost:11434/v1
|
||||
model: llama3.2:3b
|
||||
|
||||
anthropic:
|
||||
type: anthropic
|
||||
enabled: true
|
||||
model: claude-haiku-4-5-20251001
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
supports_images: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Usage: Database + Migrations
|
||||
|
||||
```python
|
||||
from circuitforge_core.db import get_connection, run_migrations
|
||||
from pathlib import Path
|
||||
|
||||
# Run product migrations on startup
|
||||
run_migrations(db_path=Path("data/app.db"), migrations_dir=Path("db/migrations"))
|
||||
|
||||
# Get a connection anywhere in your app
|
||||
with get_connection(Path("data/app.db")) as conn:
|
||||
conn.execute("INSERT INTO items (name) VALUES (?)", ("example",))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Used by
|
||||
|
||||
| Product | Description |
|
||||
|---|---|
|
||||
| [peregrine](https://git.opensourcesolarpunk.com/Circuit-Forge/peregrine) | Job search — discovery, cover letters, interview prep |
|
||||
| [snipe](https://git.opensourcesolarpunk.com/Circuit-Forge/snipe) | Auction sniping — eBay trust scoring, bid timing |
|
||||
| [kiwi](https://git.opensourcesolarpunk.com/Circuit-Forge/kiwi) | Pantry tracker with barcode/receipt OCR and recipe suggestions |
|
||||
| [avocet](https://git.opensourcesolarpunk.com/Circuit-Forge/avocet) | Email classifier training and benchmark harness |
|
||||
| [osprey](https://git.opensourcesolarpunk.com/Circuit-Forge/osprey) | Government hold-line automation |
|
||||
| [linnet](https://git.opensourcesolarpunk.com/Circuit-Forge/linnet) | Real-time tone annotation and voice transcription |
|
||||
| pagepiper | PDF/rulebook RAG (retrieval-augmented generation) search |
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
circuitforge-core is MIT licensed. Contributions are welcome.
|
||||
|
||||
```bash
|
||||
git clone https://git.opensourcesolarpunk.com/Circuit-Forge/circuitforge-core
|
||||
cd circuitforge-core
|
||||
pip install -e ".[dev]"
|
||||
pytest
|
||||
```
|
||||
|
||||
- New modules belong in `circuitforge_core/<module>/` as a package, not a flat file
|
||||
- Keep modules focused — extract when a module exceeds 400 lines
|
||||
- All public functions need type annotations
|
||||
- Tests live in `tests/` — aim for 80% coverage on new code
|
||||
- Use `ruff` for linting before submitting a PR
|
||||
|
||||
Open issues and PRs at: [git.opensourcesolarpunk.com/Circuit-Forge/circuitforge-core](https://git.opensourcesolarpunk.com/Circuit-Forge/circuitforge-core)
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT — see [LICENSE](LICENSE).
|
||||
|
||||
This is the fully open layer of the CircuitForge stack. Products built on top of circuitforge-core may carry different licenses (BSL 1.1 for AI features, proprietary for fine-tuned weights). The scaffold itself is and will remain MIT.
|
||||
|
||||
---
|
||||
|
||||
Humans own design, architecture, code review, testing, and verification. LLMs are part of our development workflow. [Our positions on LLM use →](https://circuitforge.tech/positions)
|
||||
BSL 1.1 — see LICENSE
|
||||
|
|
|
|||
|
|
@ -1,8 +1 @@
|
|||
__version__ = "0.18.0"
|
||||
|
||||
try:
|
||||
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
|
||||
__all__ = ["CommunityDB", "CommunityPost", "SharedStore"]
|
||||
except ImportError:
|
||||
# psycopg2 not installed — install with: pip install circuitforge-core[community]
|
||||
pass
|
||||
__version__ = "0.8.0"
|
||||
|
|
|
|||
|
|
@ -1,55 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.activitypub — ActivityPub actor management, object construction,
|
||||
HTTP Signature signing, delivery, and Lemmy integration.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
|
||||
from circuitforge_core.activitypub.actor import (
|
||||
CFActor,
|
||||
generate_rsa_keypair,
|
||||
load_actor_from_key_file,
|
||||
make_actor,
|
||||
)
|
||||
from circuitforge_core.activitypub.delivery import deliver_activity
|
||||
from circuitforge_core.activitypub.lemmy import (
|
||||
LemmyAuthError,
|
||||
LemmyClient,
|
||||
LemmyCommunityNotFound,
|
||||
LemmyConfig,
|
||||
)
|
||||
from circuitforge_core.activitypub.objects import (
|
||||
PUBLIC,
|
||||
make_create,
|
||||
make_note,
|
||||
make_offer,
|
||||
make_request,
|
||||
)
|
||||
from circuitforge_core.activitypub.signing import sign_headers, verify_signature
|
||||
|
||||
__all__ = [
|
||||
# Actor
|
||||
"CFActor",
|
||||
"generate_rsa_keypair",
|
||||
"load_actor_from_key_file",
|
||||
"make_actor",
|
||||
# Objects
|
||||
"PUBLIC",
|
||||
"make_note",
|
||||
"make_offer",
|
||||
"make_request",
|
||||
"make_create",
|
||||
# Signing
|
||||
"sign_headers",
|
||||
"verify_signature",
|
||||
# Delivery
|
||||
"deliver_activity",
|
||||
# Lemmy
|
||||
"LemmyConfig",
|
||||
"LemmyClient",
|
||||
"LemmyAuthError",
|
||||
"LemmyCommunityNotFound",
|
||||
]
|
||||
|
||||
# inbox is optional (requires fastapi) — import it when needed:
|
||||
# from circuitforge_core.activitypub.inbox import make_inbox_router
|
||||
|
|
@ -1,146 +0,0 @@
|
|||
"""
|
||||
CFActor — ActivityPub actor identity for CircuitForge products.
|
||||
|
||||
An actor holds RSA key material and its ActivityPub identity URLs.
|
||||
The private key is in-memory only; to_ap_dict() never includes it.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CFActor:
|
||||
"""ActivityPub actor for a CircuitForge product instance."""
|
||||
|
||||
actor_id: str # e.g. "https://kiwi.circuitforge.tech/actors/kiwi"
|
||||
username: str
|
||||
display_name: str
|
||||
inbox_url: str
|
||||
outbox_url: str
|
||||
public_key_pem: str
|
||||
private_key_pem: str # Never included in to_ap_dict()
|
||||
icon_url: str | None = None
|
||||
summary: str | None = None
|
||||
|
||||
def to_ap_dict(self) -> dict:
|
||||
"""Return an ActivityPub Person/Application object (public only)."""
|
||||
obj: dict = {
|
||||
"@context": [
|
||||
"https://www.w3.org/ns/activitystreams",
|
||||
"https://w3id.org/security/v1",
|
||||
],
|
||||
"id": self.actor_id,
|
||||
"type": "Application",
|
||||
"preferredUsername": self.username,
|
||||
"name": self.display_name,
|
||||
"inbox": self.inbox_url,
|
||||
"outbox": self.outbox_url,
|
||||
"publicKey": {
|
||||
"id": f"{self.actor_id}#main-key",
|
||||
"owner": self.actor_id,
|
||||
"publicKeyPem": self.public_key_pem,
|
||||
},
|
||||
}
|
||||
if self.summary:
|
||||
obj["summary"] = self.summary
|
||||
if self.icon_url:
|
||||
obj["icon"] = {
|
||||
"type": "Image",
|
||||
"mediaType": "image/png",
|
||||
"url": self.icon_url,
|
||||
}
|
||||
return obj
|
||||
|
||||
|
||||
def generate_rsa_keypair(bits: int = 2048) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a new RSA keypair.
|
||||
|
||||
Returns:
|
||||
(private_key_pem, public_key_pem) as PEM-encoded strings.
|
||||
"""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=bits)
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
).decode()
|
||||
public_pem = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode()
|
||||
return private_pem, public_pem
|
||||
|
||||
|
||||
def make_actor(
|
||||
actor_id: str,
|
||||
username: str,
|
||||
display_name: str,
|
||||
private_key_pem: str,
|
||||
public_key_pem: str,
|
||||
icon_url: str | None = None,
|
||||
summary: str | None = None,
|
||||
) -> CFActor:
|
||||
"""
|
||||
Construct a CFActor from an existing keypair.
|
||||
|
||||
Inbox and outbox URLs are derived from actor_id by convention:
|
||||
{actor_id}/inbox and {actor_id}/outbox
|
||||
"""
|
||||
return CFActor(
|
||||
actor_id=actor_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
inbox_url=f"{actor_id}/inbox",
|
||||
outbox_url=f"{actor_id}/outbox",
|
||||
public_key_pem=public_key_pem,
|
||||
private_key_pem=private_key_pem,
|
||||
icon_url=icon_url,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
|
||||
def load_actor_from_key_file(
|
||||
actor_id: str,
|
||||
username: str,
|
||||
display_name: str,
|
||||
private_key_path: str,
|
||||
icon_url: str | None = None,
|
||||
summary: str | None = None,
|
||||
) -> CFActor:
|
||||
"""
|
||||
Load a CFActor from a PEM private key file on disk.
|
||||
|
||||
The public key is derived from the private key — no separate public key
|
||||
file is required.
|
||||
"""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
|
||||
pem_bytes = Path(private_key_path).read_bytes()
|
||||
private_key = load_pem_private_key(pem_bytes, password=None)
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
).decode()
|
||||
public_pem = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode()
|
||||
return make_actor(
|
||||
actor_id=actor_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
private_key_pem=private_pem,
|
||||
public_key_pem=public_pem,
|
||||
icon_url=icon_url,
|
||||
summary=summary,
|
||||
)
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
"""
|
||||
ActivityPub HTTP delivery — POST a signed activity to a remote inbox.
|
||||
|
||||
Synchronous (uses requests). Async callers can wrap in asyncio.to_thread.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
|
||||
from circuitforge_core.activitypub.signing import sign_headers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from circuitforge_core.activitypub.actor import CFActor
|
||||
|
||||
ACTIVITY_CONTENT_TYPE = "application/activity+json"
|
||||
|
||||
|
||||
def deliver_activity(
|
||||
activity: dict,
|
||||
inbox_url: str,
|
||||
actor: "CFActor",
|
||||
timeout: float = 10.0,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
POST a signed ActivityPub activity to a remote inbox.
|
||||
|
||||
The activity dict is serialized to JSON, signed with the actor's private
|
||||
key (HTTP Signatures, rsa-sha256), and delivered via HTTP POST.
|
||||
|
||||
Args:
|
||||
activity: ActivityPub activity dict (e.g. from make_create()).
|
||||
inbox_url: Target inbox URL (e.g. "https://lemmy.ml/inbox").
|
||||
actor: CFActor whose key signs the request.
|
||||
timeout: Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
The raw requests.Response. Caller decides retry / error policy.
|
||||
|
||||
Raises:
|
||||
requests.RequestException: On network-level failure.
|
||||
"""
|
||||
body = json.dumps(activity).encode()
|
||||
base_headers = {"Content-Type": ACTIVITY_CONTENT_TYPE}
|
||||
signed = sign_headers(
|
||||
method="POST",
|
||||
url=inbox_url,
|
||||
headers=base_headers,
|
||||
body=body,
|
||||
actor=actor,
|
||||
)
|
||||
return requests.post(inbox_url, data=body, headers=signed, timeout=timeout)
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
"""
|
||||
ActivityPub inbox router — FastAPI stub for receiving federated activities.
|
||||
|
||||
Products mount this router to handle incoming Create, Follow, Like, Announce,
|
||||
and other ActivityPub activities from the Fediverse.
|
||||
|
||||
Requires fastapi (optional dep). ImportError is raised with a clear message
|
||||
when fastapi is not installed.
|
||||
|
||||
NOTE: from __future__ import annotations is intentionally omitted here.
|
||||
FastAPI resolves route parameter annotations against module globals at
|
||||
definition time; lazy string annotations break the Request injection.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
|
||||
import json as _json
|
||||
import re
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
# Handler type: receives (activity_dict, request_headers) and returns None
|
||||
InboxHandler = Callable[[dict, dict], Awaitable[None]]
|
||||
|
||||
# FastAPI imports at module level so annotations resolve correctly.
|
||||
# Products that don't use the inbox router are not affected by this import
|
||||
# since circuitforge_core.activitypub.__init__ does NOT import inbox.
|
||||
try:
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
_FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
_FASTAPI_AVAILABLE = False
|
||||
# Provide stubs so the module can be imported without fastapi
|
||||
APIRouter = None # type: ignore[assignment,misc]
|
||||
HTTPException = None # type: ignore[assignment]
|
||||
Request = None # type: ignore[assignment]
|
||||
JSONResponse = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def make_inbox_router(
|
||||
handlers: dict[str, InboxHandler] | None = None,
|
||||
verify_key_fetcher: Callable[[str], Awaitable[str | None]] | None = None,
|
||||
path: str = "/inbox",
|
||||
) -> "APIRouter": # type: ignore[name-defined]
|
||||
"""
|
||||
Build a FastAPI router that handles ActivityPub inbox POSTs.
|
||||
|
||||
The router:
|
||||
1. Parses the JSON body into an activity dict
|
||||
2. Optionally verifies the HTTP Signature (when verify_key_fetcher is provided)
|
||||
3. Dispatches activity["type"] to the matching handler from *handlers*
|
||||
4. Returns 202 Accepted on success, 400 on bad JSON, 401 on bad signature
|
||||
|
||||
Args:
|
||||
handlers: Dict mapping activity type strings (e.g. "Create",
|
||||
"Follow") to async handler callables.
|
||||
verify_key_fetcher: Async callable that takes a keyId URL and returns the
|
||||
actor's public key PEM, or None if not found.
|
||||
When None, signature verification is skipped (dev mode).
|
||||
path: Inbox endpoint path (default "/inbox").
|
||||
|
||||
Returns:
|
||||
FastAPI APIRouter.
|
||||
|
||||
Example::
|
||||
|
||||
async def on_create(activity: dict, headers: dict) -> None:
|
||||
print("Received Create:", activity)
|
||||
|
||||
router = make_inbox_router(handlers={"Create": on_create})
|
||||
app.include_router(router, prefix="/actors/kiwi")
|
||||
"""
|
||||
if not _FASTAPI_AVAILABLE:
|
||||
raise ImportError(
|
||||
"circuitforge_core.activitypub.inbox requires fastapi. "
|
||||
"Install with: pip install fastapi"
|
||||
)
|
||||
|
||||
from circuitforge_core.activitypub.signing import verify_signature
|
||||
|
||||
router = APIRouter()
|
||||
_handlers: dict[str, InboxHandler] = handlers or {}
|
||||
|
||||
@router.post(path, status_code=202)
|
||||
async def inbox_endpoint(request: Request) -> JSONResponse:
|
||||
# Parse body — read bytes first (needed for signature verification),
|
||||
# then decode JSON manually to avoid double-read issues.
|
||||
try:
|
||||
body = await request.body()
|
||||
activity = _json.loads(body)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON body.")
|
||||
|
||||
# Optional signature verification
|
||||
if verify_key_fetcher is not None:
|
||||
sig_header = request.headers.get("Signature", "")
|
||||
key_id = _parse_key_id(sig_header)
|
||||
if not key_id:
|
||||
raise HTTPException(status_code=401, detail="Missing or malformed Signature header.")
|
||||
public_key_pem = await verify_key_fetcher(key_id)
|
||||
if public_key_pem is None:
|
||||
raise HTTPException(status_code=401, detail=f"Unknown keyId: {key_id}")
|
||||
ok = verify_signature(
|
||||
headers=dict(request.headers),
|
||||
method="POST",
|
||||
path=request.url.path,
|
||||
body=body,
|
||||
public_key_pem=public_key_pem,
|
||||
)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=401, detail="Signature verification failed.")
|
||||
|
||||
activity_type = activity.get("type", "")
|
||||
handler = _handlers.get(activity_type)
|
||||
if handler is None:
|
||||
# Unknown types are silently accepted per AP spec — return 202
|
||||
return JSONResponse(status_code=202, content={"status": "accepted", "type": activity_type})
|
||||
|
||||
await handler(activity, dict(request.headers))
|
||||
return JSONResponse(status_code=202, content={"status": "accepted"})
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def _parse_key_id(sig_header: str) -> str | None:
|
||||
"""Extract keyId value from a Signature header string."""
|
||||
match = re.search(r'keyId="([^"]+)"', sig_header)
|
||||
return match.group(1) if match else None
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
"""
|
||||
Lemmy REST API client for posting to Lemmy communities.
|
||||
|
||||
Uses JWT authentication (Lemmy v0.19+ API). Does not require ActivityPub
|
||||
federation setup — the Lemmy REST API is simpler and more reliable for
|
||||
the initial integration.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class LemmyAuthError(Exception):
|
||||
"""Raised when Lemmy login fails."""
|
||||
|
||||
|
||||
class LemmyCommunityNotFound(Exception):
|
||||
"""Raised when a community cannot be resolved by name."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LemmyConfig:
|
||||
"""Connection config for a Lemmy instance."""
|
||||
|
||||
instance_url: str # e.g. "https://lemmy.ml" (no trailing slash)
|
||||
username: str
|
||||
password: str # Load from env/config; never hardcode
|
||||
|
||||
|
||||
class LemmyClient:
|
||||
"""
|
||||
Lemmy REST API client.
|
||||
|
||||
Usage::
|
||||
|
||||
config = LemmyConfig(instance_url="https://lemmy.ml", username="bot", password="...")
|
||||
client = LemmyClient(config)
|
||||
client.login()
|
||||
community_id = client.resolve_community("!cooking@lemmy.world")
|
||||
client.post_to_community(community_id, title="Fresh pesto recipe", body="...")
|
||||
"""
|
||||
|
||||
def __init__(self, config: LemmyConfig) -> None:
|
||||
self._config = config
|
||||
self._jwt: str | None = None
|
||||
self._session = requests.Session()
|
||||
self._session.headers.update({"Content-Type": "application/json"})
|
||||
|
||||
@property
|
||||
def _api(self) -> str:
|
||||
return f"{self._config.instance_url.rstrip('/')}/api/v3"
|
||||
|
||||
def _auth_headers(self) -> dict[str, str]:
|
||||
if not self._jwt:
|
||||
raise LemmyAuthError("Not logged in — call login() first.")
|
||||
return {"Authorization": f"Bearer {self._jwt}"}
|
||||
|
||||
def login(self) -> None:
|
||||
"""
|
||||
Authenticate with the Lemmy instance and store the JWT.
|
||||
|
||||
Raises:
|
||||
LemmyAuthError: If credentials are rejected or the request fails.
|
||||
"""
|
||||
resp = self._session.post(
|
||||
f"{self._api}/user/login",
|
||||
json={"username_or_email": self._config.username, "password": self._config.password},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise LemmyAuthError(
|
||||
f"Lemmy login failed ({resp.status_code}): {resp.text[:200]}"
|
||||
)
|
||||
data = resp.json()
|
||||
token = data.get("jwt")
|
||||
if not token:
|
||||
raise LemmyAuthError("Lemmy login response missing 'jwt' field.")
|
||||
self._jwt = token
|
||||
|
||||
def resolve_community(self, name: str) -> int:
|
||||
"""
|
||||
Resolve a community name or address to its numeric Lemmy ID.
|
||||
|
||||
Accepts:
|
||||
- Bare name: "cooking"
|
||||
- Fediverse address: "!cooking@lemmy.world"
|
||||
- Display name search (best-effort)
|
||||
|
||||
Args:
|
||||
name: Community identifier.
|
||||
|
||||
Returns:
|
||||
Numeric community ID.
|
||||
|
||||
Raises:
|
||||
LemmyCommunityNotFound: If not found or multiple matches are ambiguous.
|
||||
LemmyAuthError: If not logged in.
|
||||
"""
|
||||
# Strip leading ! for address lookups
|
||||
lookup = name.lstrip("!")
|
||||
resp = self._session.get(
|
||||
f"{self._api}/search",
|
||||
params={"q": lookup, "type_": "Communities", "limit": 5},
|
||||
headers=self._auth_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise LemmyCommunityNotFound(
|
||||
f"Community search failed ({resp.status_code}): {resp.text[:200]}"
|
||||
)
|
||||
communities = resp.json().get("communities", [])
|
||||
if not communities:
|
||||
raise LemmyCommunityNotFound(f"No communities found for '{name}'.")
|
||||
# Prefer exact actor_id match (e.g. !cooking@lemmy.world)
|
||||
for item in communities:
|
||||
view = item.get("community", {})
|
||||
if "@" in lookup:
|
||||
actor_id: str = view.get("actor_id", "")
|
||||
if lookup.lower() in actor_id.lower():
|
||||
return int(view["id"])
|
||||
else:
|
||||
if view.get("name", "").lower() == lookup.lower():
|
||||
return int(view["id"])
|
||||
# Fall back to first result
|
||||
return int(communities[0]["community"]["id"])
|
||||
|
||||
def post_to_community(
|
||||
self,
|
||||
community_id: int,
|
||||
title: str,
|
||||
body: str,
|
||||
url: str | None = None,
|
||||
nsfw: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a post in a Lemmy community.
|
||||
|
||||
Args:
|
||||
community_id: Numeric community ID (from resolve_community()).
|
||||
title: Post title.
|
||||
body: Markdown post body.
|
||||
url: Optional external URL to attach.
|
||||
nsfw: Mark NSFW (default False).
|
||||
|
||||
Returns:
|
||||
Lemmy API response dict (contains 'post_view', etc.).
|
||||
|
||||
Raises:
|
||||
LemmyAuthError: If not logged in.
|
||||
requests.RequestException: On network failure.
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"community_id": community_id,
|
||||
"name": title,
|
||||
"body": body,
|
||||
"nsfw": nsfw,
|
||||
}
|
||||
if url:
|
||||
payload["url"] = url
|
||||
|
||||
resp = self._session.post(
|
||||
f"{self._api}/post",
|
||||
json=payload,
|
||||
headers=self._auth_headers(),
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
|
@ -1,168 +0,0 @@
|
|||
"""
|
||||
ActivityStreams 2.0 object constructors.
|
||||
|
||||
All functions return plain dicts (no classes) — they are serialized to JSON
|
||||
for delivery. IDs are minted with UUID4 so callers don't need to track them.
|
||||
|
||||
Custom types:
|
||||
- "Offer" — AS2 Offer (Rook exchange offers)
|
||||
- "Request" — custom CF extension (Rook exchange requests); not in core AS2
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from circuitforge_core.activitypub.actor import CFActor
|
||||
|
||||
# AS2 public address (all followers)
|
||||
PUBLIC = "https://www.w3.org/ns/activitystreams#Public"
|
||||
|
||||
# Custom context extension for CF-specific types
|
||||
_CF_CONTEXT = "https://circuitforge.tech/ns/activitystreams"
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _mint_id(actor_id: str, type_slug: str) -> str:
|
||||
"""Generate a unique ID scoped to the actor's namespace."""
|
||||
return f"{actor_id}/{type_slug}/{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def make_note(
|
||||
actor_id: str,
|
||||
content: str,
|
||||
to: list[str] | None = None,
|
||||
cc: list[str] | None = None,
|
||||
in_reply_to: str | None = None,
|
||||
tag: list[dict] | None = None,
|
||||
published: datetime | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Construct an AS2 Note object.
|
||||
|
||||
Args:
|
||||
actor_id: The actor's ID URL (attributedTo).
|
||||
content: HTML or plain-text body.
|
||||
to: Direct recipients (defaults to [PUBLIC]).
|
||||
cc: CC recipients.
|
||||
in_reply_to: URL of the parent note when replying.
|
||||
tag: Mention/hashtag tag dicts.
|
||||
published: Post timestamp (defaults to now UTC).
|
||||
"""
|
||||
note: dict = {
|
||||
"@context": "https://www.w3.org/ns/activitystreams",
|
||||
"id": _mint_id(actor_id, "notes"),
|
||||
"type": "Note",
|
||||
"attributedTo": actor_id,
|
||||
"content": content,
|
||||
"to": to if to is not None else [PUBLIC],
|
||||
"published": published.isoformat().replace("+00:00", "Z") if published else _now_iso(),
|
||||
}
|
||||
if cc:
|
||||
note["cc"] = cc
|
||||
if in_reply_to:
|
||||
note["inReplyTo"] = in_reply_to
|
||||
if tag:
|
||||
note["tag"] = tag
|
||||
return note
|
||||
|
||||
|
||||
def make_offer(
|
||||
actor_id: str,
|
||||
summary: str,
|
||||
content: str,
|
||||
to: list[str] | None = None,
|
||||
cc: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Construct an AS2 Offer object (Rook exchange offers).
|
||||
|
||||
The Offer type is part of core ActivityStreams 2.0.
|
||||
|
||||
Args:
|
||||
actor_id: The actor's ID URL (actor field).
|
||||
summary: Short one-line description (used as title in Lemmy).
|
||||
content: Full HTML/plain-text description.
|
||||
to: Recipients (defaults to [PUBLIC]).
|
||||
cc: CC recipients.
|
||||
"""
|
||||
return {
|
||||
"@context": "https://www.w3.org/ns/activitystreams",
|
||||
"id": _mint_id(actor_id, "offers"),
|
||||
"type": "Offer",
|
||||
"actor": actor_id,
|
||||
"summary": summary,
|
||||
"content": content,
|
||||
"to": to if to is not None else [PUBLIC],
|
||||
"cc": cc or [],
|
||||
"published": _now_iso(),
|
||||
}
|
||||
|
||||
|
||||
def make_request(
|
||||
actor_id: str,
|
||||
summary: str,
|
||||
content: str,
|
||||
to: list[str] | None = None,
|
||||
cc: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Construct a CF-extension Request object (Rook exchange requests).
|
||||
|
||||
"Request" is not in core AS2 vocabulary — the CF namespace context
|
||||
extension is included so federating servers don't reject it.
|
||||
|
||||
Args:
|
||||
actor_id: The actor's ID URL.
|
||||
summary: Short one-line description.
|
||||
content: Full HTML/plain-text description.
|
||||
to: Recipients (defaults to [PUBLIC]).
|
||||
cc: CC recipients.
|
||||
"""
|
||||
return {
|
||||
"@context": [
|
||||
"https://www.w3.org/ns/activitystreams",
|
||||
_CF_CONTEXT,
|
||||
],
|
||||
"id": _mint_id(actor_id, "requests"),
|
||||
"type": "Request",
|
||||
"actor": actor_id,
|
||||
"summary": summary,
|
||||
"content": content,
|
||||
"to": to if to is not None else [PUBLIC],
|
||||
"cc": cc or [],
|
||||
"published": _now_iso(),
|
||||
}
|
||||
|
||||
|
||||
def make_create(actor: "CFActor", obj: dict) -> dict:
|
||||
"""
|
||||
Wrap any object dict in an AS2 Create activity.
|
||||
|
||||
The Create activity's id, actor, to, cc, and published fields are
|
||||
derived from the wrapped object where available.
|
||||
|
||||
Args:
|
||||
actor: The CFActor originating the Create.
|
||||
obj: An object dict (Note, Offer, Request, etc.).
|
||||
"""
|
||||
# Propagate context from inner object if it's a list (custom types)
|
||||
ctx = obj.get("@context", "https://www.w3.org/ns/activitystreams")
|
||||
|
||||
return {
|
||||
"@context": ctx,
|
||||
"id": _mint_id(actor.actor_id, "activities"),
|
||||
"type": "Create",
|
||||
"actor": actor.actor_id,
|
||||
"to": obj.get("to", [PUBLIC]),
|
||||
"cc": obj.get("cc", []),
|
||||
"published": obj.get("published", _now_iso()),
|
||||
"object": obj,
|
||||
}
|
||||
|
|
@ -1,197 +0,0 @@
|
|||
"""
|
||||
HTTP Signatures for ActivityPub (draft-cavage-http-signatures-08).
|
||||
|
||||
This is the signing convention used by Mastodon, Lemmy, and the broader
|
||||
ActivityPub ecosystem. It is distinct from the newer RFC 9421.
|
||||
|
||||
Signing algorithm: rsa-sha256
|
||||
Signed headers: (request-target) host date [digest] content-type
|
||||
Digest header: SHA-256 of request body (when body is present)
|
||||
keyId: {actor.actor_id}#main-key
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
from email.utils import formatdate
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from circuitforge_core.activitypub.actor import CFActor
|
||||
|
||||
|
||||
def _rfc1123_now() -> str:
|
||||
"""Return current UTC time in RFC 1123 format as required by HTTP Date header."""
|
||||
return formatdate(usegmt=True)
|
||||
|
||||
|
||||
def _sha256_digest(body: bytes) -> str:
|
||||
"""Return 'SHA-256=<base64>' digest string for body."""
|
||||
digest = hashlib.sha256(body).digest()
|
||||
return f"SHA-256={base64.b64encode(digest).decode()}"
|
||||
|
||||
|
||||
def sign_headers(
|
||||
method: str,
|
||||
url: str,
|
||||
headers: dict,
|
||||
body: bytes | None,
|
||||
actor: "CFActor", # type: ignore[name-defined]
|
||||
) -> dict:
|
||||
"""
|
||||
Return a new headers dict with Date, Digest (if body), and Signature added.
|
||||
|
||||
The input *headers* dict is not mutated.
|
||||
|
||||
Args:
|
||||
method: HTTP method string (e.g. "POST"), case-insensitive.
|
||||
url: Full request URL.
|
||||
headers: Existing headers dict (Content-Type, etc.).
|
||||
body: Request body bytes, or None for bodyless requests.
|
||||
actor: CFActor whose private key signs the request.
|
||||
|
||||
Returns:
|
||||
New dict with all original headers plus Date, Digest (if body), Signature.
|
||||
"""
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
|
||||
parsed = urlparse(url)
|
||||
host = parsed.netloc
|
||||
path = parsed.path or "/"
|
||||
if parsed.query:
|
||||
path = f"{path}?{parsed.query}"
|
||||
|
||||
method_lower = method.lower()
|
||||
date = _rfc1123_now()
|
||||
|
||||
out = dict(headers)
|
||||
out["Date"] = date
|
||||
out["Host"] = host
|
||||
|
||||
signed_header_names = ["(request-target)", "host", "date"]
|
||||
|
||||
if body is not None:
|
||||
digest = _sha256_digest(body)
|
||||
out["Digest"] = digest
|
||||
signed_header_names.append("digest")
|
||||
|
||||
if "Content-Type" in out:
|
||||
signed_header_names.append("content-type")
|
||||
|
||||
# Build the signature string — header names in the spec are lowercase,
|
||||
# but the dict uses Title-Case HTTP convention, so look up case-insensitively.
|
||||
def _ci_get(d: dict, key: str) -> str:
|
||||
for k, v in d.items():
|
||||
if k.lower() == key.lower():
|
||||
return v
|
||||
raise KeyError(key)
|
||||
|
||||
lines = []
|
||||
for name in signed_header_names:
|
||||
if name == "(request-target)":
|
||||
lines.append(f"(request-target): {method_lower} {path}")
|
||||
else:
|
||||
lines.append(f"{name}: {_ci_get(out, name)}")
|
||||
|
||||
signature_string = "\n".join(lines).encode()
|
||||
|
||||
private_key = load_pem_private_key(actor.private_key_pem.encode(), password=None)
|
||||
raw_sig = private_key.sign(signature_string, padding.PKCS1v15(), hashes.SHA256())
|
||||
b64_sig = base64.b64encode(raw_sig).decode()
|
||||
|
||||
key_id = f"{actor.actor_id}#main-key"
|
||||
headers_param = " ".join(signed_header_names)
|
||||
|
||||
out["Signature"] = (
|
||||
f'keyId="{key_id}",'
|
||||
f'algorithm="rsa-sha256",'
|
||||
f'headers="{headers_param}",'
|
||||
f'signature="{b64_sig}"'
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def verify_signature(
|
||||
headers: dict,
|
||||
method: str,
|
||||
path: str,
|
||||
body: bytes | None,
|
||||
public_key_pem: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an incoming ActivityPub HTTP Signature.
|
||||
|
||||
Returns False on any parse or verification failure — never raises.
|
||||
|
||||
Args:
|
||||
headers: Request headers dict (case-insensitive lookup attempted).
|
||||
method: HTTP method (e.g. "POST").
|
||||
path: Request path (e.g. "/actors/kiwi/inbox").
|
||||
body: Raw request body bytes, or None.
|
||||
public_key_pem: PEM-encoded RSA public key of the signing actor.
|
||||
"""
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
||||
|
||||
try:
|
||||
# Case-insensitive header lookup helper
|
||||
def _get(name: str) -> str | None:
|
||||
name_lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == name_lower:
|
||||
return v
|
||||
return None
|
||||
|
||||
sig_header = _get("Signature")
|
||||
if not sig_header:
|
||||
return False
|
||||
|
||||
# Parse Signature header key=value pairs
|
||||
params: dict[str, str] = {}
|
||||
for match in re.finditer(r'(\w+)="([^"]*)"', sig_header):
|
||||
params[match.group(1)] = match.group(2)
|
||||
|
||||
if "signature" not in params or "headers" not in params:
|
||||
return False
|
||||
|
||||
signed_header_names = params["headers"].split()
|
||||
method_lower = method.lower()
|
||||
|
||||
lines = []
|
||||
for name in signed_header_names:
|
||||
if name == "(request-target)":
|
||||
lines.append(f"(request-target): {method_lower} {path}")
|
||||
else:
|
||||
val = _get(name)
|
||||
if val is None:
|
||||
return False
|
||||
lines.append(f"{name}: {val}")
|
||||
|
||||
signature_string = "\n".join(lines).encode()
|
||||
raw_sig = base64.b64decode(params["signature"])
|
||||
|
||||
public_key = load_pem_public_key(public_key_pem.encode())
|
||||
public_key.verify(raw_sig, signature_string, padding.PKCS1v15(), hashes.SHA256())
|
||||
|
||||
# Also verify the Digest header matches the actual body, if both are present.
|
||||
# Signing the Digest header proves it wasn't swapped; re-computing it proves
|
||||
# the body wasn't replaced after signing.
|
||||
digest_val = _get("Digest")
|
||||
if digest_val and body is not None:
|
||||
expected = _sha256_digest(body)
|
||||
if digest_val != expected:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except (InvalidSignature, Exception):
|
||||
return False
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
from circuitforge_core.api.feedback import make_feedback_router
|
||||
from circuitforge_core.api.corrections import make_corrections_router, CORRECTIONS_MIGRATION_SQL
|
||||
|
||||
__all__ = ["make_feedback_router", "make_corrections_router", "CORRECTIONS_MIGRATION_SQL"]
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
"""
|
||||
Shared corrections router — stores user corrections to LLM output for SFT training.
|
||||
|
||||
Products include this with make_corrections_router(get_db=..., product=...).
|
||||
Corrections are stored locally in each product's SQLite DB and exported as JSONL
|
||||
for the Avocet SFT pipeline. Separate from the bug-feedback→Forgejo-issue path.
|
||||
|
||||
Required DB migration (add to product migrations dir):
|
||||
-- From circuitforge_core.api.corrections import CORRECTIONS_MIGRATION_SQL
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterator, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Drop this SQL into a product's migrations directory (e.g. 020_corrections.sql).
|
||||
CORRECTIONS_MIGRATION_SQL = """\
|
||||
CREATE TABLE IF NOT EXISTS corrections (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
item_id TEXT NOT NULL DEFAULT '',
|
||||
product TEXT NOT NULL,
|
||||
correction_type TEXT NOT NULL,
|
||||
input_text TEXT NOT NULL,
|
||||
original_output TEXT NOT NULL,
|
||||
corrected_output TEXT NOT NULL DEFAULT '',
|
||||
rating TEXT NOT NULL DEFAULT 'down',
|
||||
context TEXT NOT NULL DEFAULT '{}',
|
||||
opted_in INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_corrections_product
|
||||
ON corrections (product);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_corrections_opted_in
|
||||
ON corrections (opted_in);
|
||||
"""
|
||||
|
||||
|
||||
class CorrectionRequest(BaseModel):
|
||||
item_id: str = ""
|
||||
product: str
|
||||
correction_type: str
|
||||
input_text: str
|
||||
original_output: str
|
||||
corrected_output: str = ""
|
||||
rating: Literal["up", "down"] = "down"
|
||||
context: dict = Field(default_factory=dict)
|
||||
opted_in: bool = False
|
||||
|
||||
|
||||
class CorrectionResponse(BaseModel):
|
||||
id: int
|
||||
saved: bool
|
||||
|
||||
|
||||
class CorrectionRecord(BaseModel):
|
||||
id: int
|
||||
item_id: str
|
||||
product: str
|
||||
correction_type: str
|
||||
input_text: str
|
||||
original_output: str
|
||||
corrected_output: str
|
||||
rating: str
|
||||
context: dict
|
||||
opted_in: bool
|
||||
created_at: str
|
||||
|
||||
|
||||
def make_corrections_router(
|
||||
get_db: Callable[[], Iterator[sqlite3.Connection]],
|
||||
product: str,
|
||||
) -> APIRouter:
|
||||
"""Return a configured corrections APIRouter.
|
||||
|
||||
Args:
|
||||
get_db: FastAPI dependency that yields a sqlite3.Connection.
|
||||
product: Product slug injected into every correction row (e.g. "linnet").
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("", response_model=CorrectionResponse)
|
||||
def submit_correction(
|
||||
payload: CorrectionRequest,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> CorrectionResponse:
|
||||
"""Store a user correction to an LLM output."""
|
||||
# Thumbs-up with no corrected text is a valid positive signal.
|
||||
if payload.rating == "down" and not payload.corrected_output.strip():
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="corrected_output is required when rating is 'down'.",
|
||||
)
|
||||
|
||||
row_id = conn.execute(
|
||||
"""
|
||||
INSERT INTO corrections
|
||||
(item_id, product, correction_type, input_text, original_output,
|
||||
corrected_output, rating, context, opted_in)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
payload.item_id,
|
||||
product,
|
||||
payload.correction_type,
|
||||
payload.input_text,
|
||||
payload.original_output,
|
||||
payload.corrected_output,
|
||||
payload.rating,
|
||||
json.dumps(payload.context),
|
||||
int(payload.opted_in),
|
||||
),
|
||||
).lastrowid
|
||||
conn.commit()
|
||||
return CorrectionResponse(id=row_id, saved=True)
|
||||
|
||||
@router.get("", response_model=list[CorrectionRecord])
|
||||
def list_corrections(
|
||||
opted_in_only: bool = False,
|
||||
limit: int = 200,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> list[CorrectionRecord]:
|
||||
"""List stored corrections, optionally filtered to opted-in rows only."""
|
||||
conn.row_factory = sqlite3.Row
|
||||
query = "SELECT * FROM corrections"
|
||||
params: list = []
|
||||
if opted_in_only:
|
||||
query += " WHERE opted_in = 1"
|
||||
query += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(max(1, min(limit, 1000)))
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [
|
||||
CorrectionRecord(
|
||||
id=r["id"],
|
||||
item_id=r["item_id"],
|
||||
product=r["product"],
|
||||
correction_type=r["correction_type"],
|
||||
input_text=r["input_text"],
|
||||
original_output=r["original_output"],
|
||||
corrected_output=r["corrected_output"],
|
||||
rating=r["rating"],
|
||||
context=json.loads(r["context"] or "{}"),
|
||||
opted_in=bool(r["opted_in"]),
|
||||
created_at=r["created_at"],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
@router.get("/export")
|
||||
def export_corrections(
|
||||
opted_in_only: bool = True,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
"""Stream corrections as JSONL for the Avocet SFT pipeline.
|
||||
|
||||
Each line is a JSON object with the fields expected by avocet's
|
||||
SFT candidate importer. opted_in_only=True (default) — only rows
|
||||
where the user consented to share are exported.
|
||||
"""
|
||||
conn.row_factory = sqlite3.Row
|
||||
query = "SELECT * FROM corrections"
|
||||
if opted_in_only:
|
||||
query += " WHERE opted_in = 1"
|
||||
query += " ORDER BY created_at ASC"
|
||||
rows = conn.execute(query).fetchall()
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
filename = f"corrections_{product}_{timestamp}.jsonl"
|
||||
|
||||
def generate() -> Iterator[str]:
|
||||
for r in rows:
|
||||
record = {
|
||||
"input": r["input_text"],
|
||||
"output": r["original_output"],
|
||||
"correction": r["corrected_output"],
|
||||
"rating": r["rating"],
|
||||
"correction_type": r["correction_type"],
|
||||
"product": r["product"],
|
||||
"item_id": r["item_id"],
|
||||
"context": json.loads(r["context"] or "{}"),
|
||||
"created_at": r["created_at"],
|
||||
}
|
||||
yield json.dumps(record, ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
return router
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.audio — shared PCM and audio signal utilities.
|
||||
|
||||
MIT licensed. No model weights. No HuggingFace. Dependency: numpy only
|
||||
(scipy optional for high-quality resampling).
|
||||
|
||||
Consumers:
|
||||
cf-voice — replaces hand-rolled PCM conversion in stt.py / context.py
|
||||
Sparrow — torchaudio stitching, export, acoustic analysis
|
||||
Avocet — audio preprocessing for classifier training corpus
|
||||
Linnet — chunk accumulation for real-time tone annotation
|
||||
"""
|
||||
from circuitforge_core.audio.convert import (
|
||||
bytes_to_float32,
|
||||
float32_to_pcm,
|
||||
pcm_to_float32,
|
||||
)
|
||||
from circuitforge_core.audio.gate import is_silent
|
||||
from circuitforge_core.audio.resample import resample
|
||||
from circuitforge_core.audio.buffer import ChunkAccumulator
|
||||
|
||||
__all__ = [
|
||||
"bytes_to_float32",
|
||||
"float32_to_pcm",
|
||||
"pcm_to_float32",
|
||||
"is_silent",
|
||||
"resample",
|
||||
"ChunkAccumulator",
|
||||
]
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
"""
|
||||
ChunkAccumulator — collect fixed-size audio chunks into a classify window.
|
||||
|
||||
Used by cf-voice and Linnet to gather N × 100ms frames before firing
|
||||
a classification pass. The window size trades latency against context:
|
||||
a 2-second window (20 × 100ms) gives the classifier enough signal to
|
||||
detect tone/affect reliably without lagging the conversation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ChunkAccumulator:
|
||||
"""Accumulate audio chunks and flush when the window is full.
|
||||
|
||||
Args:
|
||||
window_chunks: Number of chunks to collect before is_ready() is True.
|
||||
dtype: numpy dtype of the accumulated array. Default float32.
|
||||
"""
|
||||
|
||||
def __init__(self, window_chunks: int, *, dtype: np.dtype = np.float32) -> None:
|
||||
if window_chunks < 1:
|
||||
raise ValueError(f"window_chunks must be >= 1, got {window_chunks}")
|
||||
self._window = window_chunks
|
||||
self._dtype = dtype
|
||||
self._buf: deque[np.ndarray] = deque()
|
||||
|
||||
def accumulate(self, chunk: np.ndarray) -> None:
|
||||
"""Add a chunk to the buffer. Oldest chunks are dropped once the
|
||||
buffer exceeds window_chunks to bound memory."""
|
||||
self._buf.append(chunk.astype(self._dtype))
|
||||
while len(self._buf) > self._window:
|
||||
self._buf.popleft()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True when window_chunks have been accumulated."""
|
||||
return len(self._buf) >= self._window
|
||||
|
||||
def flush(self) -> np.ndarray:
|
||||
"""Concatenate accumulated chunks and reset the buffer.
|
||||
|
||||
Returns:
|
||||
float32 ndarray of concatenated audio.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if fewer than window_chunks have been accumulated.
|
||||
"""
|
||||
if not self.is_ready():
|
||||
raise RuntimeError(
|
||||
f"Not enough chunks accumulated: have {len(self._buf)}, "
|
||||
f"need {self._window}. Check is_ready() before calling flush()."
|
||||
)
|
||||
result = np.concatenate(list(self._buf), axis=-1).astype(self._dtype)
|
||||
self._buf.clear()
|
||||
return result
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Discard all buffered audio without returning it."""
|
||||
self._buf.clear()
|
||||
|
||||
@property
|
||||
def chunk_count(self) -> int:
|
||||
"""Current number of buffered chunks."""
|
||||
return len(self._buf)
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
"""
|
||||
PCM / float32 conversion utilities.
|
||||
|
||||
All functions operate on raw audio bytes or numpy arrays. No torch dependency.
|
||||
|
||||
Standard pipeline:
|
||||
bytes (int16 PCM) -> float32 ndarray -> signal processing -> bytes (int16 PCM)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pcm_to_float32(pcm_bytes: bytes, *, dtype: np.dtype = np.int16) -> np.ndarray:
|
||||
"""Convert raw PCM bytes to a float32 numpy array in [-1.0, 1.0].
|
||||
|
||||
Args:
|
||||
pcm_bytes: Raw PCM audio bytes.
|
||||
dtype: Sample dtype of the input. Default: int16 (standard mic input).
|
||||
|
||||
Returns:
|
||||
float32 ndarray, values in [-1.0, 1.0].
|
||||
"""
|
||||
scale = np.iinfo(dtype).max
|
||||
return np.frombuffer(pcm_bytes, dtype=dtype).astype(np.float32) / scale
|
||||
|
||||
|
||||
def bytes_to_float32(pcm_bytes: bytes) -> np.ndarray:
|
||||
"""Alias for pcm_to_float32 with default int16 dtype.
|
||||
|
||||
Matches the naming used in cf-voice context.py for easier migration.
|
||||
"""
|
||||
return pcm_to_float32(pcm_bytes)
|
||||
|
||||
|
||||
def float32_to_pcm(audio: np.ndarray, *, dtype: np.dtype = np.int16) -> bytes:
|
||||
"""Convert a float32 ndarray in [-1.0, 1.0] to raw PCM bytes.
|
||||
|
||||
Clips to [-1.0, 1.0] before scaling to prevent wraparound distortion.
|
||||
|
||||
Args:
|
||||
audio: float32 ndarray, values nominally in [-1.0, 1.0].
|
||||
dtype: Target PCM sample dtype. Default: int16.
|
||||
|
||||
Returns:
|
||||
Raw PCM bytes.
|
||||
"""
|
||||
scale = np.iinfo(dtype).max
|
||||
clipped = np.clip(audio, -1.0, 1.0)
|
||||
return (clipped * scale).astype(dtype).tobytes()
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
Energy gate — silence detection via RMS amplitude.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Default threshold extracted from cf-voice stt.py.
|
||||
# Signals below this RMS level are considered silent.
|
||||
_DEFAULT_RMS_THRESHOLD = 0.005
|
||||
|
||||
|
||||
def is_silent(
|
||||
audio: np.ndarray,
|
||||
*,
|
||||
rms_threshold: float = _DEFAULT_RMS_THRESHOLD,
|
||||
) -> bool:
|
||||
"""Return True when the audio clip is effectively silent.
|
||||
|
||||
Uses root-mean-square amplitude as the energy estimate. This is a fast
|
||||
frame-level gate — not a VAD model. Use it to skip inference on empty
|
||||
audio frames before they hit a more expensive transcription or
|
||||
classification pipeline.
|
||||
|
||||
Args:
|
||||
audio: float32 ndarray, values in [-1.0, 1.0].
|
||||
rms_threshold: Clips with RMS below this value are silent.
|
||||
Default 0.005 is conservative — genuine speech at
|
||||
normal mic levels sits well above this.
|
||||
|
||||
Returns:
|
||||
True if silent, False if the clip contains meaningful signal.
|
||||
"""
|
||||
if audio.size == 0:
|
||||
return True
|
||||
rms = float(np.sqrt(np.mean(audio.astype(np.float32) ** 2)))
|
||||
return rms < rms_threshold
|
||||
|
||||
|
||||
def rms(audio: np.ndarray) -> float:
|
||||
"""Return the RMS amplitude of an audio array."""
|
||||
if audio.size == 0:
|
||||
return 0.0
|
||||
return float(np.sqrt(np.mean(audio.astype(np.float32) ** 2)))
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
"""
|
||||
Audio resampling — change sample rate of a float32 audio array.
|
||||
|
||||
Uses scipy.signal.resample_poly when available (high-quality, anti-aliased).
|
||||
Falls back to linear interpolation via numpy when scipy is absent — acceptable
|
||||
for 16kHz speech but not for music.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def resample(audio: np.ndarray, from_hz: int, to_hz: int) -> np.ndarray:
|
||||
"""Resample audio from one sample rate to another.
|
||||
|
||||
Args:
|
||||
audio: float32 ndarray, shape (samples,) or (channels, samples).
|
||||
from_hz: Source sample rate in Hz.
|
||||
to_hz: Target sample rate in Hz.
|
||||
|
||||
Returns:
|
||||
Resampled float32 ndarray at to_hz.
|
||||
"""
|
||||
if from_hz == to_hz:
|
||||
return audio.astype(np.float32)
|
||||
|
||||
try:
|
||||
from scipy.signal import resample_poly # type: ignore[import]
|
||||
from math import gcd
|
||||
g = gcd(from_hz, to_hz)
|
||||
up, down = to_hz // g, from_hz // g
|
||||
return resample_poly(audio.astype(np.float32), up, down, axis=-1)
|
||||
except ImportError:
|
||||
# Numpy linear interpolation fallback — lower quality but no extra deps.
|
||||
# Adequate for 16kHz ↔ 8kHz conversion on speech; avoid for music.
|
||||
n_out = int(len(audio) * to_hz / from_hz)
|
||||
x_old = np.linspace(0, 1, len(audio), endpoint=False)
|
||||
x_new = np.linspace(0, 1, n_out, endpoint=False)
|
||||
return np.interp(x_new, x_old, audio.astype(np.float32)).astype(np.float32)
|
||||
|
|
@ -1,340 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.cloud_session — shared cloud session resolution for all CF products.
|
||||
|
||||
Usage (FastAPI product):
|
||||
|
||||
from circuitforge_core.cloud_session import CloudSessionFactory
|
||||
from pathlib import Path
|
||||
|
||||
_sessions = CloudSessionFactory(
|
||||
product="avocet",
|
||||
local_db=Path("data/avocet.db"),
|
||||
)
|
||||
get_session = _sessions.dependency()
|
||||
require_tier = _sessions.require_tier
|
||||
|
||||
@router.get("/api/imitate")
|
||||
def imitate(session: CloudUser = Depends(get_session)):
|
||||
# session.user_id is the Directus UUID for cloud users, "local" for self-hosted
|
||||
...
|
||||
|
||||
Environment variables (set per-product via .env / compose):
|
||||
CLOUD_MODE 1/true/yes to enable cloud auth (default: off)
|
||||
CLOUD_DATA_ROOT Root directory for per-user data (default: /devl/<product>-cloud-data)
|
||||
DIRECTUS_JWT_SECRET HS256 secret used to sign cf_session JWTs (required in cloud mode)
|
||||
HEIMDALL_URL License server base URL (default: https://license.circuitforge.tech)
|
||||
HEIMDALL_ADMIN_TOKEN Heimdall admin bearer token (required for tier resolution)
|
||||
CF_SERVER_SECRET Server-side secret for deriving per-user encryption keys
|
||||
CLOUD_AUTH_BYPASS_IPS Comma-separated IPs/CIDRs to skip JWT auth (dev LAN only)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TIERS: list[str] = ["free", "paid", "premium", "ultra"]
|
||||
|
||||
# ── CloudUser ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CloudUser:
|
||||
"""Resolved user identity for one HTTP request.
|
||||
|
||||
user_id: Directus UUID for authenticated cloud users.
|
||||
"local" for self-hosted / CLOUD_MODE=false.
|
||||
"local-dev" for dev-bypass-IP sessions.
|
||||
"anon-<uuid>" for unauthenticated guest visitors.
|
||||
tier: free | paid | premium | ultra | local
|
||||
product: Which CF product this session belongs to (e.g. "avocet").
|
||||
meta: Product-specific extras (e.g. household_id for Kiwi).
|
||||
Access via session.meta.get("household_id").
|
||||
"""
|
||||
user_id: str
|
||||
tier: str
|
||||
product: str
|
||||
has_byok: bool = False
|
||||
meta: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_bypass_nets(raw: str) -> tuple[list[ipaddress.IPv4Network | ipaddress.IPv6Network], frozenset[str]]:
|
||||
nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
|
||||
ips: set[str] = set()
|
||||
for entry in (e.strip() for e in raw.split(",") if e.strip()):
|
||||
try:
|
||||
nets.append(ipaddress.ip_network(entry, strict=False))
|
||||
except ValueError:
|
||||
ips.add(entry)
|
||||
return nets, frozenset(ips)
|
||||
|
||||
|
||||
def _is_bypass_ip(
|
||||
ip: str,
|
||||
nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
|
||||
ips: frozenset[str],
|
||||
) -> bool:
|
||||
if not ip or (not nets and not ips):
|
||||
return False
|
||||
if ip in ips:
|
||||
return True
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip)
|
||||
return any(addr in net for net in nets)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_session_token(header_value: str) -> str:
|
||||
"""Pull cf_session value out of a raw Cookie header or return the value as-is."""
|
||||
m = re.search(r'(?:^|;)\s*cf_session=([^;]+)', header_value)
|
||||
return m.group(1).strip() if m else header_value.strip()
|
||||
|
||||
|
||||
# ── CloudSessionFactory ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class CloudSessionFactory:
|
||||
"""Per-product session factory. Instantiate once at module level.
|
||||
|
||||
Args:
|
||||
product: Product code string (e.g. "avocet", "kiwi").
|
||||
extra_meta: Optional async-or-sync callable that receives
|
||||
(user_id: str, tier: str) and returns a dict merged
|
||||
into CloudUser.meta. Use for product-specific fields
|
||||
like household_id.
|
||||
byok_detector: Callable() → bool. Override to detect BYOK for this
|
||||
product's config path. Default: always False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
product: str,
|
||||
extra_meta: Callable[[str, str], dict[str, Any]] | None = None,
|
||||
byok_detector: Callable[[], bool] | None = None,
|
||||
) -> None:
|
||||
self.product = product
|
||||
self._extra_meta = extra_meta
|
||||
self._byok_detector = byok_detector or (lambda: False)
|
||||
|
||||
# Config — read from environment at construction time so tests can patch env
|
||||
self._cloud_mode: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||
self._directus_secret: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
||||
self._heimdall_url: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech")
|
||||
self._heimdall_token: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||
self._cloud_data_root: Path = Path(
|
||||
os.environ.get("CLOUD_DATA_ROOT", f"/devl/{product}-cloud-data")
|
||||
)
|
||||
|
||||
_bypass_raw = os.environ.get("CLOUD_AUTH_BYPASS_IPS", "")
|
||||
self._bypass_nets, self._bypass_ips = _parse_bypass_nets(_bypass_raw)
|
||||
|
||||
# Tier resolution cache: {user_id: (result_dict, timestamp)}
|
||||
self._tier_cache: dict[str, tuple[dict, float]] = {}
|
||||
self._tier_cache_ttl: float = 300.0 # 5 minutes
|
||||
|
||||
# ── JWT ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def validate_jwt(self, token: str) -> str:
|
||||
"""Validate a cf_session JWT and return the Directus user_id. Raises HTTPException on failure."""
|
||||
try:
|
||||
import jwt as pyjwt # lazy — not needed in local mode
|
||||
from fastapi import HTTPException
|
||||
payload = pyjwt.decode(
|
||||
token,
|
||||
self._directus_secret,
|
||||
algorithms=["HS256"],
|
||||
options={"require": ["id", "exp"]},
|
||||
)
|
||||
return payload["id"]
|
||||
except Exception as exc:
|
||||
log.debug("JWT validation failed: %s", exc)
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
||||
|
||||
# ── Heimdall ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _ensure_provisioned(self, user_id: str) -> None:
|
||||
if not self._heimdall_token:
|
||||
return
|
||||
try:
|
||||
import requests
|
||||
requests.post(
|
||||
f"{self._heimdall_url}/admin/provision",
|
||||
json={"directus_user_id": user_id, "product": self.product, "tier": "free"},
|
||||
headers={"Authorization": f"Bearer {self._heimdall_token}"},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||
|
||||
def _resolve_tier(self, user_id: str) -> dict[str, Any]:
|
||||
"""Returns dict with keys: tier, license_key (and any product extras)."""
|
||||
now = time.monotonic()
|
||||
cached = self._tier_cache.get(user_id)
|
||||
if cached and (now - cached[1]) < self._tier_cache_ttl:
|
||||
return cached[0]
|
||||
|
||||
result: dict[str, Any] = {"tier": "free", "license_key": None}
|
||||
if self._heimdall_token:
|
||||
try:
|
||||
import requests
|
||||
resp = requests.post(
|
||||
f"{self._heimdall_url}/admin/cloud/resolve",
|
||||
json={"directus_user_id": user_id, "product": self.product},
|
||||
headers={"Authorization": f"Bearer {self._heimdall_token}"},
|
||||
timeout=5,
|
||||
)
|
||||
if resp.ok:
|
||||
data = resp.json()
|
||||
result["tier"] = data.get("tier", "free")
|
||||
result["license_key"] = data.get("key_display")
|
||||
# Forward any extra fields Heimdall returns (household_id etc.)
|
||||
result.update({k: v for k, v in data.items() if k not in result})
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall tier resolve failed for %s: %s", user_id, exc)
|
||||
else:
|
||||
log.debug("HEIMDALL_ADMIN_TOKEN not set — defaulting tier to free")
|
||||
|
||||
self._tier_cache[user_id] = (result, now)
|
||||
return result
|
||||
|
||||
# ── Guest sessions ────────────────────────────────────────────────────────
|
||||
|
||||
_GUEST_COOKIE = "cf_guest_id"
|
||||
_GUEST_COOKIE_MAX_AGE = 60 * 60 * 24 * 90 # 90 days
|
||||
|
||||
def _resolve_guest(self, request: Any, response: Any) -> CloudUser:
|
||||
guest_id = (request.cookies.get(self._GUEST_COOKIE) or "").strip()
|
||||
if not guest_id:
|
||||
guest_id = str(uuid.uuid4())
|
||||
is_https = request.headers.get("x-forwarded-proto", "http").lower() == "https"
|
||||
response.set_cookie(
|
||||
key=self._GUEST_COOKIE,
|
||||
value=guest_id,
|
||||
max_age=self._GUEST_COOKIE_MAX_AGE,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
secure=is_https,
|
||||
)
|
||||
return CloudUser(
|
||||
user_id=f"anon-{guest_id}",
|
||||
tier="free",
|
||||
product=self.product,
|
||||
has_byok=self._byok_detector(),
|
||||
)
|
||||
|
||||
# ── Core resolver ─────────────────────────────────────────────────────────
|
||||
|
||||
def resolve(self, request: Any, response: Any) -> CloudUser:
|
||||
"""Resolve the CloudUser for a FastAPI request. Suitable as a Depends() target."""
|
||||
has_byok = self._byok_detector()
|
||||
|
||||
if not self._cloud_mode:
|
||||
return CloudUser(user_id="local", tier="local", product=self.product, has_byok=has_byok)
|
||||
|
||||
client_ip = (
|
||||
request.headers.get("x-real-ip", "")
|
||||
or (request.client.host if request.client else "")
|
||||
)
|
||||
if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips):
|
||||
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
|
||||
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
||||
|
||||
raw_session = (
|
||||
request.headers.get("x-cf-session", "").strip()
|
||||
or request.cookies.get("cf_session", "").strip()
|
||||
)
|
||||
if not raw_session:
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
token = _extract_session_token(raw_session)
|
||||
if not token:
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
user_id = self.validate_jwt(token)
|
||||
self._ensure_provisioned(user_id)
|
||||
tier_data = self._resolve_tier(user_id)
|
||||
tier = tier_data.get("tier", "free")
|
||||
|
||||
meta: dict[str, Any] = {}
|
||||
if self._extra_meta:
|
||||
meta = self._extra_meta(user_id, tier) or {}
|
||||
# Merge any extra fields from Heimdall response (e.g. household_id)
|
||||
meta.update({k: v for k, v in tier_data.items() if k not in ("tier", "license_key")})
|
||||
meta["license_key"] = tier_data.get("license_key")
|
||||
|
||||
return CloudUser(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
product=self.product,
|
||||
has_byok=has_byok,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def dependency(self) -> Callable[[Any, Any], CloudUser]:
|
||||
"""Return a FastAPI-compatible dependency function (use with Depends())."""
|
||||
factory = self
|
||||
|
||||
def _get_session(request: Any, response: Any) -> CloudUser:
|
||||
return factory.resolve(request, response)
|
||||
|
||||
return _get_session
|
||||
|
||||
def require_tier(self, min_tier: str) -> Callable:
|
||||
"""Dependency factory — raises 403 if the session tier is below min_tier."""
|
||||
from fastapi import Depends, HTTPException
|
||||
min_idx = TIERS.index(min_tier)
|
||||
get_session = self.dependency()
|
||||
|
||||
def _check(session: CloudUser = Depends(get_session)) -> CloudUser:
|
||||
if session.tier in ("local", "local-dev"):
|
||||
return session
|
||||
try:
|
||||
if TIERS.index(session.tier) < min_idx:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"This feature requires {min_tier} tier or above.",
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Unknown tier.")
|
||||
return session
|
||||
|
||||
return _check
|
||||
|
||||
|
||||
# ── BYOK detection ────────────────────────────────────────────────────────────
|
||||
|
||||
def detect_byok(config_path: Path | None = None) -> bool:
|
||||
"""Return True if at least one enabled non-vision LLM backend is configured.
|
||||
|
||||
Reads the shared llm.yaml that LLMRouter uses. Local (Ollama, vLLM) and
|
||||
API-key backends both count — the policy is "user is supplying compute",
|
||||
regardless of where that compute lives.
|
||||
|
||||
Args:
|
||||
config_path: Override the default config location. Useful in tests.
|
||||
"""
|
||||
import yaml
|
||||
if config_path is None:
|
||||
config_path = Path.home() / ".config" / "circuitforge" / "llm.yaml"
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
return any(
|
||||
b.get("enabled", True) and b.get("type") != "vision_service"
|
||||
for b in cfg.get("backends", {}).values()
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# circuitforge_core/community/__init__.py
|
||||
# MIT License
|
||||
|
||||
from .models import CommunityPost
|
||||
from .db import CommunityDB
|
||||
from .store import SharedStore
|
||||
from .snipe_store import SellerTrustSignal, SnipeCommunityStore
|
||||
|
||||
__all__ = ["CommunityDB", "CommunityPost", "SharedStore", "SellerTrustSignal", "SnipeCommunityStore"]
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
# circuitforge_core/community/db.py
|
||||
# MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MIN_CONN = 1
|
||||
_MAX_CONN = 10
|
||||
|
||||
|
||||
class CommunityDB:
|
||||
"""Shared PostgreSQL connection pool + migration runner for the community module.
|
||||
|
||||
Products instantiate one CommunityDB at startup and pass it to SharedStore
|
||||
subclasses. The pool is thread-safe (ThreadedConnectionPool).
|
||||
|
||||
Usage:
|
||||
db = CommunityDB.from_env() # reads COMMUNITY_DB_URL
|
||||
db.run_migrations()
|
||||
store = MyProductStore(db)
|
||||
db.close() # at shutdown
|
||||
"""
|
||||
|
||||
def __init__(self, dsn: str | None) -> None:
|
||||
if not dsn:
|
||||
raise ValueError(
|
||||
"CommunityDB requires a DSN. "
|
||||
"Set COMMUNITY_DB_URL or pass dsn= explicitly."
|
||||
)
|
||||
self._pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, dsn=dsn)
|
||||
logger.debug("CommunityDB pool created (min=%d, max=%d)", _MIN_CONN, _MAX_CONN)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "CommunityDB":
|
||||
"""Construct from the COMMUNITY_DB_URL environment variable."""
|
||||
import os
|
||||
dsn = os.environ.get("COMMUNITY_DB_URL")
|
||||
return cls(dsn=dsn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def getconn(self):
|
||||
"""Borrow a connection from the pool. Must be returned via putconn()."""
|
||||
return self._pool.getconn()
|
||||
|
||||
def putconn(self, conn) -> None:
|
||||
"""Return a borrowed connection to the pool."""
|
||||
self._pool.putconn(conn)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close all pool connections. Call at application shutdown."""
|
||||
self._pool.closeall()
|
||||
logger.debug("CommunityDB pool closed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Migration runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _discover_migrations(self) -> list[Path]:
|
||||
"""Return sorted list of .sql migration files from the community migrations dir."""
|
||||
pkg = importlib.resources.files("circuitforge_core.community.migrations")
|
||||
files = sorted(
|
||||
[Path(str(p)) for p in pkg.iterdir() if str(p).endswith(".sql")],
|
||||
key=lambda p: p.name,
|
||||
)
|
||||
return files
|
||||
|
||||
def run_migrations(self) -> None:
|
||||
"""Apply all community migration SQL files in numeric order.
|
||||
|
||||
Uses a simple applied-migrations table to avoid re-running already
|
||||
applied migrations. Idempotent.
|
||||
"""
|
||||
conn = self.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS _community_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
for migration_file in self._discover_migrations():
|
||||
name = migration_file.name
|
||||
cur.execute(
|
||||
"SELECT 1 FROM _community_migrations WHERE filename = %s",
|
||||
(name,),
|
||||
)
|
||||
if cur.fetchone():
|
||||
logger.debug("Migration %s already applied, skipping", name)
|
||||
continue
|
||||
|
||||
sql = migration_file.read_text()
|
||||
logger.info("Applying community migration: %s", name)
|
||||
cur.execute(sql)
|
||||
cur.execute(
|
||||
"INSERT INTO _community_migrations (filename) VALUES (%s)",
|
||||
(name,),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.putconn(conn)
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
-- 001_community_posts.sql
|
||||
-- Community posts table: published meal plans, recipe successes, and bloopers.
|
||||
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
|
||||
-- BSL boundary: this schema is MIT (data layer, no inference).
|
||||
|
||||
CREATE TABLE IF NOT EXISTS community_posts (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
slug TEXT NOT NULL UNIQUE,
|
||||
pseudonym TEXT NOT NULL,
|
||||
post_type TEXT NOT NULL CHECK (post_type IN ('plan', 'recipe_success', 'recipe_blooper')),
|
||||
published TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
title TEXT NOT NULL,
|
||||
description TEXT,
|
||||
photo_url TEXT,
|
||||
|
||||
-- Plan slots (JSON array: [{day, meal_type, recipe_id, recipe_name}])
|
||||
slots JSONB NOT NULL DEFAULT '[]',
|
||||
|
||||
-- Recipe result fields
|
||||
recipe_id BIGINT,
|
||||
recipe_name TEXT,
|
||||
level SMALLINT CHECK (level IS NULL OR level BETWEEN 1 AND 4),
|
||||
outcome_notes TEXT,
|
||||
|
||||
-- Element snapshot (denormalized from corpus at publish time)
|
||||
seasoning_score REAL,
|
||||
richness_score REAL,
|
||||
brightness_score REAL,
|
||||
depth_score REAL,
|
||||
aroma_score REAL,
|
||||
structure_score REAL,
|
||||
texture_profile TEXT,
|
||||
|
||||
-- Dietary / allergen / flavor
|
||||
dietary_tags JSONB NOT NULL DEFAULT '[]',
|
||||
allergen_flags JSONB NOT NULL DEFAULT '[]',
|
||||
flavor_molecules JSONB NOT NULL DEFAULT '[]',
|
||||
|
||||
-- USDA FDC macros
|
||||
fat_pct REAL,
|
||||
protein_pct REAL,
|
||||
moisture_pct REAL,
|
||||
|
||||
-- Source product identifier
|
||||
source_product TEXT NOT NULL DEFAULT 'kiwi'
|
||||
);
|
||||
|
||||
-- Indexes for common filter patterns
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_published ON community_posts (published DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_post_type ON community_posts (post_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_source ON community_posts (source_product);
|
||||
|
||||
-- GIN index for dietary/allergen JSONB array containment queries
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_dietary_tags ON community_posts USING GIN (dietary_tags);
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_allergen_flags ON community_posts USING GIN (allergen_flags);
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
-- 002_community_post_reactions.sql
|
||||
-- Reserved: community post reactions (thumbs-up, saves count).
|
||||
-- Not yet implemented -- this migration is a stub to reserve the sequence number.
|
||||
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
|
||||
|
||||
-- Placeholder: no-op. Will be replaced when reactions feature is designed.
|
||||
SELECT 1;
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
-- Seller trust signals: confirmed scammer / confirmed legitimate outcomes from Snipe.
|
||||
-- Separate table from community_posts (Kiwi-specific) — seller signals are a
|
||||
-- structurally different domain and should not overload the recipe post schema.
|
||||
-- Applies to: cf_community PostgreSQL database (hosted by cf-orch).
|
||||
-- BSL boundary: table schema is MIT; signal ingestion route in cf-orch is BSL 1.1.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS seller_trust_signals (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
platform TEXT NOT NULL DEFAULT 'ebay',
|
||||
platform_seller_id TEXT NOT NULL,
|
||||
confirmed_scam BOOLEAN NOT NULL,
|
||||
signal_source TEXT NOT NULL, -- 'blocklist_add' | 'community_vote' | 'resolved'
|
||||
flags JSONB NOT NULL DEFAULT '[]', -- red flag keys at time of signal
|
||||
source_product TEXT NOT NULL DEFAULT 'snipe',
|
||||
recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- No PII: platform_seller_id is the public eBay username or platform ID only.
|
||||
CREATE INDEX IF NOT EXISTS idx_seller_trust_platform_id
|
||||
ON seller_trust_signals (platform, platform_seller_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_seller_trust_confirmed
|
||||
ON seller_trust_signals (confirmed_scam);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_seller_trust_recorded
|
||||
ON seller_trust_signals (recorded_at DESC);
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
-- 004_community_categories.sql
|
||||
-- MIT License
|
||||
-- Shared eBay category tree published by credentialed Snipe instances.
|
||||
-- Credentialless instances pull from this table during refresh().
|
||||
-- Privacy: only public eBay category metadata (IDs, names, paths) — no user data.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS community_categories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
platform TEXT NOT NULL DEFAULT 'ebay',
|
||||
category_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
full_path TEXT NOT NULL,
|
||||
source_product TEXT NOT NULL DEFAULT 'snipe',
|
||||
published_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (platform, category_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_community_cat_name
|
||||
ON community_categories (platform, name);
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
-- 005_recipe_tags.sql
|
||||
-- Community-contributed recipe subcategory tags.
|
||||
--
|
||||
-- Users can tag corpus recipes (from a product's local recipe dataset) with a
|
||||
-- domain/category/subcategory from that product's browse taxonomy. Tags are
|
||||
-- keyed by (recipe_source, recipe_ref) so a single table serves all CF products
|
||||
-- that have a recipe corpus (currently: kiwi).
|
||||
--
|
||||
-- Acceptance threshold: upvotes >= 2 (submitter's implicit vote counts as 1,
|
||||
-- so one additional voter is enough to publish). Browse counts caches merge
|
||||
-- accepted tags into subcategory totals on each nightly refresh.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS recipe_tags (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
recipe_source TEXT NOT NULL CHECK (recipe_source IN ('corpus')),
|
||||
recipe_ref TEXT NOT NULL, -- corpus integer recipe ID stored as text
|
||||
domain TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
subcategory TEXT, -- NULL = category-level tag (no subcategory)
|
||||
pseudonym TEXT NOT NULL,
|
||||
upvotes INTEGER NOT NULL DEFAULT 1, -- starts at 1 (submitter's own vote)
|
||||
source_product TEXT NOT NULL DEFAULT 'kiwi',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
-- one tag per (recipe, location, user) — prevents submitting the same tag twice
|
||||
UNIQUE (recipe_source, recipe_ref, domain, category, subcategory, pseudonym)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_recipe_tags_lookup
|
||||
ON recipe_tags (source_product, domain, category, subcategory)
|
||||
WHERE upvotes >= 2;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_recipe_tags_recipe
|
||||
ON recipe_tags (recipe_source, recipe_ref);
|
||||
|
||||
-- Tracks who voted on which tag to prevent double-voting.
|
||||
-- The submitter's self-vote is inserted here at submission time.
|
||||
CREATE TABLE IF NOT EXISTS recipe_tag_votes (
|
||||
tag_id BIGINT NOT NULL REFERENCES recipe_tags(id) ON DELETE CASCADE,
|
||||
pseudonym TEXT NOT NULL,
|
||||
voted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (tag_id, pseudonym)
|
||||
);
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
-- 006_community_dedup.sql
|
||||
-- Adds variation-linking and title search support for community recipe dedup.
|
||||
-- Applies to: cf_community PostgreSQL database.
|
||||
-- BSL boundary: MIT (data layer, no inference).
|
||||
|
||||
-- Nullable self-referential FK: user-declared "this is a variation of X"
|
||||
ALTER TABLE community_posts
|
||||
ADD COLUMN IF NOT EXISTS similar_to_ref TEXT REFERENCES community_posts(slug) ON DELETE SET NULL;
|
||||
|
||||
-- Index for variation lookup (find all variations of a parent post)
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_similar_ref
|
||||
ON community_posts (similar_to_ref)
|
||||
WHERE similar_to_ref IS NOT NULL;
|
||||
|
||||
-- Index to speed up title ILIKE prefix and substring searches
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_title_lower
|
||||
ON community_posts (lower(title));
|
||||
|
||||
-- Index on recipe_id for exact-recipe duplicate detection
|
||||
CREATE INDEX IF NOT EXISTS idx_community_posts_recipe_id
|
||||
ON community_posts (recipe_id)
|
||||
WHERE recipe_id IS NOT NULL;
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
# Community module migrations
|
||||
# These SQL files are shipped with circuitforge-core so cf-orch can locate them via importlib.resources.
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
# circuitforge_core/community/models.py
|
||||
# MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
PostType = Literal["plan", "recipe_success", "recipe_blooper"]
|
||||
CreativityLevel = Literal[1, 2, 3, 4]
|
||||
|
||||
_VALID_POST_TYPES: frozenset[str] = frozenset(["plan", "recipe_success", "recipe_blooper"])
|
||||
|
||||
|
||||
def _validate_score(name: str, value: float) -> float:
|
||||
if not (0.0 <= value <= 1.0):
|
||||
raise ValueError(f"{name} must be between 0.0 and 1.0, got {value!r}")
|
||||
return value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommunityPost:
|
||||
"""Immutable snapshot of a published community post.
|
||||
|
||||
Lists (dietary_tags, allergen_flags, flavor_molecules, slots) are stored as
|
||||
tuples to enforce immutability. Pass lists -- they are converted in __post_init__.
|
||||
"""
|
||||
|
||||
# Identity
|
||||
slug: str
|
||||
pseudonym: str
|
||||
post_type: PostType
|
||||
published: datetime
|
||||
title: str
|
||||
|
||||
# Optional content
|
||||
description: str | None
|
||||
photo_url: str | None
|
||||
|
||||
# Plan slots -- list[dict] for post_type="plan"
|
||||
slots: tuple
|
||||
|
||||
# Recipe result fields -- for post_type="recipe_success" | "recipe_blooper"
|
||||
recipe_id: int | None
|
||||
recipe_name: str | None
|
||||
level: CreativityLevel | None
|
||||
outcome_notes: str | None
|
||||
|
||||
# Element snapshot
|
||||
seasoning_score: float
|
||||
richness_score: float
|
||||
brightness_score: float
|
||||
depth_score: float
|
||||
aroma_score: float
|
||||
structure_score: float
|
||||
texture_profile: str
|
||||
|
||||
# Dietary/allergen/flavor
|
||||
dietary_tags: tuple
|
||||
allergen_flags: tuple
|
||||
flavor_molecules: tuple
|
||||
|
||||
# USDA FDC macros (optional -- may not be available for all recipes)
|
||||
fat_pct: float | None
|
||||
protein_pct: float | None
|
||||
moisture_pct: float | None
|
||||
|
||||
# Variation link: slug of the parent post this is explicitly a variation of
|
||||
similar_to_ref: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce list fields to tuples (frozen dataclass: use object.__setattr__)
|
||||
for key in ("slots", "dietary_tags", "allergen_flags", "flavor_molecules"):
|
||||
val = getattr(self, key)
|
||||
if isinstance(val, list):
|
||||
object.__setattr__(self, key, tuple(val))
|
||||
|
||||
# Validate post_type
|
||||
if self.post_type not in _VALID_POST_TYPES:
|
||||
raise ValueError(
|
||||
f"post_type must be one of {sorted(_VALID_POST_TYPES)}, got {self.post_type!r}"
|
||||
)
|
||||
|
||||
# Validate scores
|
||||
for score_name in (
|
||||
"seasoning_score", "richness_score", "brightness_score",
|
||||
"depth_score", "aroma_score", "structure_score",
|
||||
):
|
||||
_validate_score(score_name, getattr(self, score_name))
|
||||
|
|
@ -1,253 +0,0 @@
|
|||
# circuitforge_core/community/snipe_store.py
|
||||
# MIT License
|
||||
"""Snipe community store — publishes seller trust signals to the shared community DB.
|
||||
|
||||
Snipe products subclass SharedStore here to write seller trust signals
|
||||
(confirmed scammer / confirmed legitimate) to the cf_community PostgreSQL.
|
||||
These signals aggregate across all Snipe users to power the cross-user
|
||||
seller trust classifier fine-tuning corpus.
|
||||
|
||||
Privacy: only platform_seller_id (public eBay username/ID) and flag keys
|
||||
are written. No PII is stored.
|
||||
|
||||
Usage:
|
||||
from circuitforge_core.community import CommunityDB
|
||||
from circuitforge_core.community.snipe_store import SnipeCommunityStore
|
||||
|
||||
db = CommunityDB.from_env()
|
||||
store = SnipeCommunityStore(db, source_product="snipe")
|
||||
store.publish_seller_signal(
|
||||
platform_seller_id="ebay-username",
|
||||
confirmed_scam=True,
|
||||
signal_source="blocklist_add",
|
||||
flags=["new_account", "suspicious_price"],
|
||||
)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .store import SharedStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SellerTrustSignal:
|
||||
"""Immutable snapshot of a recorded seller trust signal."""
|
||||
id: int
|
||||
platform: str
|
||||
platform_seller_id: str
|
||||
confirmed_scam: bool
|
||||
signal_source: str
|
||||
flags: tuple
|
||||
source_product: str
|
||||
recorded_at: datetime
|
||||
|
||||
|
||||
class SnipeCommunityStore(SharedStore):
|
||||
"""Community store for Snipe — seller trust signal publishing and querying."""
|
||||
|
||||
def __init__(self, db, source_product: str = "snipe") -> None:
|
||||
super().__init__(db, source_product=source_product)
|
||||
|
||||
def publish_seller_signal(
|
||||
self,
|
||||
platform_seller_id: str,
|
||||
confirmed_scam: bool,
|
||||
signal_source: str,
|
||||
flags: list[str] | None = None,
|
||||
platform: str = "ebay",
|
||||
) -> SellerTrustSignal:
|
||||
"""Record a seller trust outcome in the shared community DB.
|
||||
|
||||
Args:
|
||||
platform_seller_id: Public eBay username or platform ID (no PII).
|
||||
confirmed_scam: True = confirmed bad actor; False = confirmed legitimate.
|
||||
signal_source: Origin of the signal.
|
||||
'blocklist_add' — user explicitly added to local blocklist
|
||||
'community_vote' — consensus threshold reached from multiple reports
|
||||
'resolved' — seller resolved as legitimate over time
|
||||
flags: List of red-flag keys active at signal time (e.g. ["new_account"]).
|
||||
platform: Source auction platform (default "ebay").
|
||||
|
||||
Returns the inserted SellerTrustSignal.
|
||||
"""
|
||||
flags = flags or []
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO seller_trust_signals
|
||||
(platform, platform_seller_id, confirmed_scam,
|
||||
signal_source, flags, source_product)
|
||||
VALUES (%s, %s, %s, %s, %s::jsonb, %s)
|
||||
RETURNING id, recorded_at
|
||||
""",
|
||||
(
|
||||
platform,
|
||||
platform_seller_id,
|
||||
confirmed_scam,
|
||||
signal_source,
|
||||
json.dumps(flags),
|
||||
self._source_product,
|
||||
),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
return SellerTrustSignal(
|
||||
id=row[0],
|
||||
platform=platform,
|
||||
platform_seller_id=platform_seller_id,
|
||||
confirmed_scam=confirmed_scam,
|
||||
signal_source=signal_source,
|
||||
flags=tuple(flags),
|
||||
source_product=self._source_product,
|
||||
recorded_at=row[1],
|
||||
)
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
log.warning(
|
||||
"Failed to publish seller signal for %s (%s)",
|
||||
platform_seller_id, signal_source, exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def list_signals_for_seller(
|
||||
self,
|
||||
platform_seller_id: str,
|
||||
platform: str = "ebay",
|
||||
limit: int = 50,
|
||||
) -> list[SellerTrustSignal]:
|
||||
"""Return recent trust signals for a specific seller."""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, platform, platform_seller_id, confirmed_scam,
|
||||
signal_source, flags, source_product, recorded_at
|
||||
FROM seller_trust_signals
|
||||
WHERE platform = %s AND platform_seller_id = %s
|
||||
ORDER BY recorded_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(platform, platform_seller_id, limit),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [
|
||||
SellerTrustSignal(
|
||||
id=r[0], platform=r[1], platform_seller_id=r[2],
|
||||
confirmed_scam=r[3], signal_source=r[4],
|
||||
flags=tuple(json.loads(r[5]) if isinstance(r[5], str) else r[5] or []),
|
||||
source_product=r[6], recorded_at=r[7],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def scam_signal_count(self, platform_seller_id: str, platform: str = "ebay") -> int:
|
||||
"""Return the number of confirmed_scam=True signals for a seller.
|
||||
|
||||
Used to determine if a seller has crossed the community consensus threshold
|
||||
for appearing in the shared blocklist.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM seller_trust_signals
|
||||
WHERE platform = %s AND platform_seller_id = %s AND confirmed_scam = TRUE
|
||||
""",
|
||||
(platform, platform_seller_id),
|
||||
)
|
||||
return cur.fetchone()[0]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def publish_categories(
|
||||
self,
|
||||
categories: list[tuple[str, str, str]],
|
||||
platform: str = "ebay",
|
||||
) -> int:
|
||||
"""Upsert a batch of eBay leaf categories into the shared community table.
|
||||
|
||||
Args:
|
||||
categories: List of (category_id, name, full_path) tuples.
|
||||
platform: Source auction platform (default "ebay").
|
||||
|
||||
Returns:
|
||||
Number of rows upserted.
|
||||
"""
|
||||
if not categories:
|
||||
return 0
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.executemany(
|
||||
"""
|
||||
INSERT INTO community_categories
|
||||
(platform, category_id, name, full_path, source_product)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
ON CONFLICT (platform, category_id)
|
||||
DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
full_path = EXCLUDED.full_path,
|
||||
source_product = EXCLUDED.source_product,
|
||||
published_at = NOW()
|
||||
""",
|
||||
[
|
||||
(platform, cid, name, path, self._source_product)
|
||||
for cid, name, path in categories
|
||||
],
|
||||
)
|
||||
conn.commit()
|
||||
return len(categories)
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
log.warning(
|
||||
"Failed to publish %d categories to community store",
|
||||
len(categories), exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def fetch_categories(
|
||||
self,
|
||||
platform: str = "ebay",
|
||||
limit: int = 500,
|
||||
) -> list[tuple[str, str, str]]:
|
||||
"""Fetch community-contributed eBay categories.
|
||||
|
||||
Args:
|
||||
platform: Source auction platform (default "ebay").
|
||||
limit: Maximum rows to return.
|
||||
|
||||
Returns:
|
||||
List of (category_id, name, full_path) tuples ordered by name.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT category_id, name, full_path
|
||||
FROM community_categories
|
||||
WHERE platform = %s
|
||||
ORDER BY name
|
||||
LIMIT %s
|
||||
""",
|
||||
(platform, limit),
|
||||
)
|
||||
return [(row[0], row[1], row[2]) for row in cur.fetchall()]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
|
@ -1,434 +0,0 @@
|
|||
# circuitforge_core/community/store.py
|
||||
# MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import CommunityPost
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import CommunityDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _row_to_post(row: dict) -> CommunityPost:
|
||||
"""Convert a psycopg2 row dict to a CommunityPost.
|
||||
|
||||
JSONB columns (slots, dietary_tags, allergen_flags, flavor_molecules) come
|
||||
back from psycopg2 as Python lists already — no json.loads() needed.
|
||||
"""
|
||||
return CommunityPost(
|
||||
slug=row["slug"],
|
||||
pseudonym=row["pseudonym"],
|
||||
post_type=row["post_type"],
|
||||
published=row["published"],
|
||||
title=row["title"],
|
||||
description=row.get("description"),
|
||||
photo_url=row.get("photo_url"),
|
||||
slots=row.get("slots") or [],
|
||||
recipe_id=row.get("recipe_id"),
|
||||
recipe_name=row.get("recipe_name"),
|
||||
level=row.get("level"),
|
||||
outcome_notes=row.get("outcome_notes"),
|
||||
seasoning_score=row["seasoning_score"] or 0.0,
|
||||
richness_score=row["richness_score"] or 0.0,
|
||||
brightness_score=row["brightness_score"] or 0.0,
|
||||
depth_score=row["depth_score"] or 0.0,
|
||||
aroma_score=row["aroma_score"] or 0.0,
|
||||
structure_score=row["structure_score"] or 0.0,
|
||||
texture_profile=row.get("texture_profile") or "",
|
||||
dietary_tags=row.get("dietary_tags") or [],
|
||||
allergen_flags=row.get("allergen_flags") or [],
|
||||
flavor_molecules=row.get("flavor_molecules") or [],
|
||||
fat_pct=row.get("fat_pct"),
|
||||
protein_pct=row.get("protein_pct"),
|
||||
moisture_pct=row.get("moisture_pct"),
|
||||
similar_to_ref=row.get("similar_to_ref"),
|
||||
)
|
||||
|
||||
|
||||
def _cursor_to_dict(cur, row) -> dict:
|
||||
"""Convert a psycopg2 row tuple to a dict using cursor.description."""
|
||||
if isinstance(row, dict):
|
||||
return row
|
||||
return {desc[0]: val for desc, val in zip(cur.description, row)}
|
||||
|
||||
|
||||
class SharedStore:
|
||||
"""Base class for product community stores.
|
||||
|
||||
Subclass this in each product:
|
||||
class KiwiCommunityStore(SharedStore):
|
||||
def list_posts_for_week(self, week_start: str) -> list[CommunityPost]: ...
|
||||
|
||||
All methods return new objects (immutable pattern). Never mutate rows in-place.
|
||||
"""
|
||||
|
||||
def __init__(self, db: "CommunityDB", source_product: str = "kiwi") -> None:
|
||||
self._db = db
|
||||
self._source_product = source_product
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reads
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_post_by_slug(self, slug: str) -> CommunityPost | None:
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM community_posts WHERE slug = %s LIMIT 1",
|
||||
(slug,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return _row_to_post(_cursor_to_dict(cur, row))
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def list_posts(
|
||||
self,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
post_type: str | None = None,
|
||||
dietary_tags: list[str] | None = None,
|
||||
allergen_exclude: list[str] | None = None,
|
||||
source_product: str | None = None,
|
||||
) -> list[CommunityPost]:
|
||||
"""Paginated post list with optional filters.
|
||||
|
||||
dietary_tags: JSONB containment — posts must include ALL listed tags.
|
||||
allergen_exclude: JSONB overlap exclusion — posts must NOT include any listed flag.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
conditions = []
|
||||
params: list = []
|
||||
|
||||
if post_type:
|
||||
conditions.append("post_type = %s")
|
||||
params.append(post_type)
|
||||
if dietary_tags:
|
||||
import json
|
||||
conditions.append("dietary_tags @> %s::jsonb")
|
||||
params.append(json.dumps(dietary_tags))
|
||||
if allergen_exclude:
|
||||
import json
|
||||
conditions.append("NOT (allergen_flags && %s::jsonb)")
|
||||
params.append(json.dumps(allergen_exclude))
|
||||
if source_product:
|
||||
conditions.append("source_product = %s")
|
||||
params.append(source_product)
|
||||
|
||||
where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
|
||||
params.extend([limit, offset])
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT * FROM community_posts {where} "
|
||||
"ORDER BY published DESC LIMIT %s OFFSET %s",
|
||||
params,
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [_row_to_post(_cursor_to_dict(cur, r)) for r in rows]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def search_similar_posts(
|
||||
self,
|
||||
title: str,
|
||||
recipe_id: int | None = None,
|
||||
post_type: str | None = None,
|
||||
limit: int = 8,
|
||||
) -> list[CommunityPost]:
|
||||
"""Return posts similar to the given title or with the same recipe_id.
|
||||
|
||||
Used by the dedup check before a new post is submitted. Matches on:
|
||||
- exact recipe_id (strongest signal)
|
||||
- case-insensitive title substring match
|
||||
|
||||
Results are ordered: recipe_id matches first, then by published desc.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
conditions: list[str] = []
|
||||
params: list = []
|
||||
|
||||
title_condition = "lower(title) LIKE lower(%s)"
|
||||
title_param = f"%{title.lower()[:80]}%"
|
||||
|
||||
if recipe_id is not None:
|
||||
conditions.append(f"(recipe_id = %s OR {title_condition})")
|
||||
params.extend([recipe_id, title_param])
|
||||
else:
|
||||
conditions.append(title_condition)
|
||||
params.append(title_param)
|
||||
|
||||
if post_type:
|
||||
conditions.append("post_type = %s")
|
||||
params.append(post_type)
|
||||
|
||||
where = "WHERE " + " AND ".join(conditions)
|
||||
params.append(limit)
|
||||
|
||||
order_clause = (
|
||||
"ORDER BY (recipe_id = %s) DESC, published DESC"
|
||||
if recipe_id is not None
|
||||
else "ORDER BY published DESC"
|
||||
)
|
||||
if recipe_id is not None:
|
||||
params.insert(-1, recipe_id)
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT * FROM community_posts {where} {order_clause} LIMIT %s",
|
||||
params,
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [_row_to_post(_cursor_to_dict(cur, r)) for r in rows]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Writes
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def insert_post(self, post: CommunityPost) -> CommunityPost:
|
||||
"""Insert a new community post. Returns the inserted post (unchanged — slug is the key)."""
|
||||
import json
|
||||
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO community_posts (
|
||||
slug, pseudonym, post_type, published, title, description, photo_url,
|
||||
slots, recipe_id, recipe_name, level, outcome_notes,
|
||||
seasoning_score, richness_score, brightness_score,
|
||||
depth_score, aroma_score, structure_score, texture_profile,
|
||||
dietary_tags, allergen_flags, flavor_molecules,
|
||||
fat_pct, protein_pct, moisture_pct, source_product,
|
||||
similar_to_ref
|
||||
) VALUES (
|
||||
%s, %s, %s, %s, %s, %s, %s,
|
||||
%s::jsonb, %s, %s, %s, %s,
|
||||
%s, %s, %s, %s, %s, %s, %s,
|
||||
%s::jsonb, %s::jsonb, %s::jsonb,
|
||||
%s, %s, %s, %s, %s
|
||||
)
|
||||
""",
|
||||
(
|
||||
post.slug, post.pseudonym, post.post_type,
|
||||
post.published, post.title, post.description, post.photo_url,
|
||||
json.dumps(list(post.slots)),
|
||||
post.recipe_id, post.recipe_name, post.level, post.outcome_notes,
|
||||
post.seasoning_score, post.richness_score, post.brightness_score,
|
||||
post.depth_score, post.aroma_score, post.structure_score,
|
||||
post.texture_profile,
|
||||
json.dumps(list(post.dietary_tags)),
|
||||
json.dumps(list(post.allergen_flags)),
|
||||
json.dumps(list(post.flavor_molecules)),
|
||||
post.fat_pct, post.protein_pct, post.moisture_pct,
|
||||
self._source_product,
|
||||
post.similar_to_ref,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return post
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def delete_post(self, slug: str, pseudonym: str) -> bool:
|
||||
"""Hard-delete a post. Only succeeds if pseudonym matches the author.
|
||||
|
||||
Returns True if a row was deleted, False if no matching row found.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM community_posts WHERE slug = %s AND pseudonym = %s",
|
||||
(slug, pseudonym),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
# ── Recipe tags ───────────────────────────────────────────────────────────
|
||||
|
||||
def submit_recipe_tag(
|
||||
self,
|
||||
recipe_id: int,
|
||||
domain: str,
|
||||
category: str,
|
||||
subcategory: str | None,
|
||||
pseudonym: str,
|
||||
source_product: str = "kiwi",
|
||||
) -> dict:
|
||||
"""Submit a new subcategory tag for a corpus recipe.
|
||||
|
||||
Inserts the tag with upvotes=1 and records the submitter's self-vote in
|
||||
recipe_tag_votes. Returns the created tag row as a dict.
|
||||
|
||||
Raises psycopg2.errors.UniqueViolation if the same user has already
|
||||
tagged this recipe to this location — let the caller handle it.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO recipe_tags
|
||||
(recipe_source, recipe_ref, domain, category, subcategory,
|
||||
pseudonym, upvotes, source_product)
|
||||
VALUES ('corpus', %s, %s, %s, %s, %s, 1, %s)
|
||||
RETURNING id, recipe_ref, domain, category, subcategory,
|
||||
pseudonym, upvotes, created_at
|
||||
""",
|
||||
(str(recipe_id), domain, category, subcategory,
|
||||
pseudonym, source_product),
|
||||
)
|
||||
row = dict(zip([d[0] for d in cur.description], cur.fetchone()))
|
||||
# Record submitter's self-vote
|
||||
cur.execute(
|
||||
"INSERT INTO recipe_tag_votes (tag_id, pseudonym) VALUES (%s, %s)",
|
||||
(row["id"], pseudonym),
|
||||
)
|
||||
conn.commit()
|
||||
return row
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def upvote_recipe_tag(self, tag_id: int, pseudonym: str) -> int:
|
||||
"""Add an upvote to a tag from pseudonym. Returns new upvote count.
|
||||
|
||||
Raises psycopg2.errors.UniqueViolation if this pseudonym already voted.
|
||||
Raises ValueError if the tag does not exist.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO recipe_tag_votes (tag_id, pseudonym) VALUES (%s, %s)",
|
||||
(tag_id, pseudonym),
|
||||
)
|
||||
cur.execute(
|
||||
"UPDATE recipe_tags SET upvotes = upvotes + 1 WHERE id = %s"
|
||||
" RETURNING upvotes",
|
||||
(tag_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
raise ValueError(f"recipe_tag {tag_id} not found")
|
||||
conn.commit()
|
||||
return row[0]
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def get_recipe_tag_by_id(self, tag_id: int) -> dict | None:
|
||||
"""Return a single recipe_tag row by ID, or None if not found."""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, recipe_ref, domain, category, subcategory,
|
||||
pseudonym, upvotes, created_at
|
||||
FROM recipe_tags WHERE id = %s
|
||||
""",
|
||||
(tag_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return dict(zip([d[0] for d in cur.description], row))
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def list_tags_for_recipe(
|
||||
self,
|
||||
recipe_id: int,
|
||||
source_product: str = "kiwi",
|
||||
) -> list[dict]:
|
||||
"""Return all tags for a corpus recipe, accepted or not, newest first."""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, domain, category, subcategory, pseudonym,
|
||||
upvotes, created_at
|
||||
FROM recipe_tags
|
||||
WHERE recipe_source = 'corpus'
|
||||
AND recipe_ref = %s
|
||||
AND source_product = %s
|
||||
ORDER BY upvotes DESC, created_at DESC
|
||||
""",
|
||||
(str(recipe_id), source_product),
|
||||
)
|
||||
cols = [d[0] for d in cur.description]
|
||||
return [dict(zip(cols, r)) for r in cur.fetchall()]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
||||
def get_accepted_recipe_ids_for_subcategory(
|
||||
self,
|
||||
domain: str,
|
||||
category: str,
|
||||
subcategory: str | None,
|
||||
source_product: str = "kiwi",
|
||||
threshold: int = 2,
|
||||
) -> list[int]:
|
||||
"""Return corpus recipe IDs with accepted community tags for a subcategory.
|
||||
|
||||
Used by browse_counts_cache refresh and browse_recipes() FTS fallback.
|
||||
Only includes tags that have reached the acceptance threshold.
|
||||
"""
|
||||
conn = self._db.getconn()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
if subcategory is None:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT recipe_ref::INTEGER
|
||||
FROM recipe_tags
|
||||
WHERE source_product = %s
|
||||
AND domain = %s AND category = %s
|
||||
AND subcategory IS NULL
|
||||
AND upvotes >= %s
|
||||
""",
|
||||
(source_product, domain, category, threshold),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT recipe_ref::INTEGER
|
||||
FROM recipe_tags
|
||||
WHERE source_product = %s
|
||||
AND domain = %s AND category = %s
|
||||
AND subcategory = %s
|
||||
AND upvotes >= %s
|
||||
""",
|
||||
(source_product, domain, category, subcategory, threshold),
|
||||
)
|
||||
return [r[0] for r in cur.fetchall()]
|
||||
finally:
|
||||
self._db.putconn(conn)
|
||||
|
|
@ -23,7 +23,7 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection:
|
|||
if cloud_mode and key:
|
||||
from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore
|
||||
conn = _sqlcipher.connect(str(db_path), timeout=30)
|
||||
conn.execute("PRAGMA key=?", (key,))
|
||||
conn.execute(f"PRAGMA key='{key}'")
|
||||
return conn
|
||||
# timeout=30: retry for up to 30s when another writer holds the lock (WAL mode
|
||||
# allows concurrent readers but only one writer at a time).
|
||||
|
|
|
|||
|
|
@ -4,22 +4,12 @@ Applies *.sql files from migrations_dir in filename order.
|
|||
Tracks applied migrations in a _migrations table — safe to call multiple times.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
|
||||
"""Apply any unapplied *.sql migrations from migrations_dir.
|
||||
|
||||
Resilient to partial-failure recovery: if a migration previously crashed
|
||||
mid-run (e.g. a process killed after some ALTER TABLE statements
|
||||
auto-committed via executescript), the next startup re-runs that migration.
|
||||
Any "duplicate column name" errors are silently skipped so the migration
|
||||
can complete and be marked as applied. All other errors still propagate.
|
||||
"""
|
||||
"""Apply any unapplied *.sql migrations from migrations_dir."""
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS _migrations "
|
||||
"(name TEXT PRIMARY KEY, applied_at TEXT DEFAULT CURRENT_TIMESTAMP)"
|
||||
|
|
@ -32,92 +22,8 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
|
|||
for sql_file in sql_files:
|
||||
if sql_file.name in applied:
|
||||
continue
|
||||
|
||||
try:
|
||||
conn.executescript(sql_file.read_text())
|
||||
except sqlite3.OperationalError as exc:
|
||||
if "duplicate column name" not in str(exc).lower():
|
||||
raise
|
||||
# A previous run partially applied this migration (some ALTER TABLE
|
||||
# statements auto-committed before the failure). Re-run with
|
||||
# per-statement recovery to skip already-applied columns.
|
||||
_log.warning(
|
||||
"Migration %s: partial-failure detected (%s) — "
|
||||
"retrying with per-statement recovery",
|
||||
sql_file.name,
|
||||
exc,
|
||||
)
|
||||
_run_script_with_recovery(conn, sql_file)
|
||||
|
||||
conn.executescript(sql_file.read_text())
|
||||
# OR IGNORE: safe if two Store() calls race on the same DB — second writer
|
||||
# just skips the insert rather than raising UNIQUE constraint failed.
|
||||
conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _run_script_with_recovery(conn: sqlite3.Connection, sql_file: Path) -> None:
|
||||
"""Re-run a migration via executescript, skipping duplicate-column errors.
|
||||
|
||||
Used only when the first executescript() attempt raised a duplicate column
|
||||
error (indicating a previous partial run). Splits the script on the
|
||||
double-dash comment prefix pattern to re-issue each logical statement,
|
||||
catching only the known-safe "duplicate column name" error class.
|
||||
|
||||
Splitting is done via SQLite's own parser — we feed the script to a
|
||||
temporary in-memory connection using executescript (which commits
|
||||
auto-matically per DDL statement) and mirror the results on the real
|
||||
connection statement by statement. That's circular, so instead we use
|
||||
the simpler approach: executescript handles tokenization; we wrap the
|
||||
whole call in a try/except and retry after removing the offending statement.
|
||||
|
||||
Simpler approach: use conn.execute() per statement from the script.
|
||||
This avoids the semicolon-in-comment tokenization problem by not splitting
|
||||
ourselves — instead we let the DB tell us which statement failed and only
|
||||
skip that exact error class.
|
||||
"""
|
||||
# executescript() uses SQLite's real tokenizer, so re-issuing it after a
|
||||
# partial failure will hit "duplicate column name" again. We catch and
|
||||
# ignore that specific error class only, re-running until the script
|
||||
# completes or a different error is raised.
|
||||
#
|
||||
# Implementation: issue the whole script again; catch duplicate-column
|
||||
# errors; keep trying. Since executescript auto-commits per statement,
|
||||
# each successful statement in successive retries is a no-op (CREATE TABLE
|
||||
# IF NOT EXISTS, etc.) or a benign duplicate skip.
|
||||
#
|
||||
# Limit retries to prevent infinite loops on genuinely broken SQL.
|
||||
script = sql_file.read_text()
|
||||
for attempt in range(20):
|
||||
try:
|
||||
conn.executescript(script)
|
||||
return # success
|
||||
except sqlite3.OperationalError as exc:
|
||||
msg = str(exc).lower()
|
||||
if "duplicate column name" in msg:
|
||||
col = str(exc).split(":")[-1].strip() if ":" in str(exc) else "?"
|
||||
_log.warning(
|
||||
"Migration %s (attempt %d): skipping duplicate column '%s'",
|
||||
sql_file.name,
|
||||
attempt + 1,
|
||||
col,
|
||||
)
|
||||
# Remove the offending ALTER TABLE statement from the script
|
||||
# so the next attempt skips it. This is safe because SQLite
|
||||
# already auto-committed that column addition on a prior run.
|
||||
script = _remove_column_add(script, col)
|
||||
else:
|
||||
raise
|
||||
raise RuntimeError(
|
||||
f"Migration {sql_file.name}: could not complete after 20 recovery attempts"
|
||||
)
|
||||
|
||||
|
||||
def _remove_column_add(script: str, column: str) -> str:
|
||||
"""Remove the ALTER TABLE ADD COLUMN statement for *column* from *script*."""
|
||||
import re
|
||||
# Match: ALTER TABLE <tbl> ADD COLUMN <column> <rest-of-line>
|
||||
pattern = re.compile(
|
||||
r"ALTER\s+TABLE\s+\w+\s+ADD\s+COLUMN\s+" + re.escape(column) + r"[^\n]*\n?",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return pattern.sub("", script)
|
||||
|
|
|
|||
|
|
@ -1,133 +0,0 @@
|
|||
# circuitforge_core/documents/pdf.py
|
||||
"""
|
||||
circuitforge_core.documents.pdf — PDF text extraction and page-level chunking.
|
||||
|
||||
Primary path: pdfplumber (selectable text layers).
|
||||
Fallback: pytesseract OCR (scanned / image-only pages).
|
||||
|
||||
Usage::
|
||||
|
||||
from circuitforge_core.documents.pdf import PDFExtractor
|
||||
|
||||
chunks = PDFExtractor().chunk_pages("/path/to/book.pdf")
|
||||
for chunk in chunks:
|
||||
print(f"[p.{chunk.page_number}] ({chunk.source}) {chunk.text[:80]}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import pdfplumber
|
||||
except ImportError: # pragma: no cover
|
||||
pdfplumber = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
except ImportError: # pragma: no cover
|
||||
pytesseract = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError: # pragma: no cover
|
||||
Image = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PageChunk:
|
||||
"""Text content extracted from a single PDF page."""
|
||||
|
||||
page_number: int # 1-indexed
|
||||
text: str
|
||||
source: str # "text_layer" | "ocr"
|
||||
word_count: int
|
||||
|
||||
|
||||
class PDFExtractor:
|
||||
"""
|
||||
Extract page-level text chunks from PDF files.
|
||||
|
||||
Args:
|
||||
ocr_min_words: Pages with fewer words from the text layer trigger OCR.
|
||||
"""
|
||||
|
||||
def __init__(self, ocr_min_words: int = 10) -> None:
|
||||
self.ocr_min_words = ocr_min_words
|
||||
|
||||
def chunk_pages(self, pdf_path: str | Path) -> list[PageChunk]:
|
||||
"""
|
||||
Primary entry point. Returns one PageChunk per page.
|
||||
|
||||
Uses text-layer extraction per page; falls back to OCR when text is sparse.
|
||||
Empty PDFs return an empty list.
|
||||
"""
|
||||
if pdfplumber is None:
|
||||
raise ImportError(
|
||||
"pdfplumber is required for PDF extraction. "
|
||||
"Install it with: pip install pdfplumber"
|
||||
)
|
||||
|
||||
path = Path(pdf_path)
|
||||
chunks: list[PageChunk] = []
|
||||
|
||||
with pdfplumber.open(path) as pdf:
|
||||
for i, page in enumerate(pdf.pages, start=1):
|
||||
text = page.extract_text() or ""
|
||||
words = text.split()
|
||||
|
||||
if len(words) >= self.ocr_min_words:
|
||||
chunks.append(
|
||||
PageChunk(
|
||||
page_number=i,
|
||||
text=text.strip(),
|
||||
source="text_layer",
|
||||
word_count=len(words),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"pdf: page %d sparse (%d words), falling back to OCR",
|
||||
i,
|
||||
len(words),
|
||||
)
|
||||
chunks.append(self._ocr_page(page, i))
|
||||
|
||||
return chunks
|
||||
|
||||
def _ocr_page(self, page: object, page_number: int) -> PageChunk:
|
||||
"""Render page to image and extract text via tesseract."""
|
||||
try:
|
||||
rendered = page.to_image(resolution=200).original # type: ignore[attr-defined]
|
||||
rendered = _ensure_pil_image(rendered)
|
||||
text = pytesseract.image_to_string(rendered) # type: ignore[union-attr]
|
||||
words = text.split()
|
||||
return PageChunk(
|
||||
page_number=page_number,
|
||||
text=text.strip(),
|
||||
source="ocr",
|
||||
word_count=len(words),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("pdf: OCR failed for page %d: %s", page_number, exc)
|
||||
return PageChunk(
|
||||
page_number=page_number, text="", source="ocr", word_count=0
|
||||
)
|
||||
|
||||
|
||||
def _ensure_pil_image(rendered: object) -> object:
|
||||
"""Return *rendered* as a PIL Image, converting from bytes if needed."""
|
||||
if Image is None:
|
||||
return rendered
|
||||
try:
|
||||
if not isinstance(rendered, Image.Image):
|
||||
rendered = Image.open(io.BytesIO(rendered)) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
# Image may be patched (e.g. in tests); skip the conversion.
|
||||
pass
|
||||
return rendered
|
||||
|
|
@ -69,7 +69,7 @@ VRAM_TIERS: list[VramTier] = [
|
|||
profile_name="single-gpu-8gb",
|
||||
ollama_model="qwen2.5:7b-instruct",
|
||||
vllm_candidates=["Qwen2.5-3B-Instruct", "Phi-4-mini-instruct"],
|
||||
services=["ollama", "vllm", "cf-vision", "cf-docuvision", "cf-stt", "cf-tts", "cf-musicgen"],
|
||||
services=["ollama", "vllm", "cf-vision", "cf-docuvision", "cf-stt", "cf-tts"],
|
||||
llm_max_params="8b",
|
||||
),
|
||||
VramTier(
|
||||
|
|
@ -79,7 +79,7 @@ VRAM_TIERS: list[VramTier] = [
|
|||
ollama_model="qwen2.5:14b-instruct-q4_k_m",
|
||||
vllm_candidates=["Qwen2.5-14B-Instruct", "Qwen2.5-3B-Instruct", "Phi-4-mini-instruct"],
|
||||
services=["ollama", "vllm", "cf-vision", "cf-docuvision", "cf-stt", "cf-tts",
|
||||
"cf-musicgen", "cf-embed", "cf-classify"],
|
||||
"cf-embed", "cf-classify"],
|
||||
llm_max_params="14b",
|
||||
),
|
||||
VramTier(
|
||||
|
|
@ -89,7 +89,7 @@ VRAM_TIERS: list[VramTier] = [
|
|||
ollama_model="qwen2.5:32b-instruct-q4_k_m",
|
||||
vllm_candidates=["Qwen2.5-14B-Instruct", "Qwen2.5-3B-Instruct", "Phi-4-mini-instruct"],
|
||||
services=["ollama", "vllm", "cf-vision", "cf-docuvision", "cf-stt", "cf-tts",
|
||||
"cf-musicgen", "cf-embed", "cf-classify", "comfyui"],
|
||||
"cf-embed", "cf-classify", "comfyui"],
|
||||
llm_max_params="32b-q4",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
"""
|
||||
cf_input.gestures — camera capture, hand detection, landmark normalization.
|
||||
|
||||
Public API:
|
||||
CameraCapture — OpenCV frame source
|
||||
HandsDetector — MediaPipe Hands wrapper
|
||||
HandLandmarks — immutable detected hand dataclass
|
||||
normalize_hand() — scale/translation-invariant feature vector
|
||||
"""
|
||||
|
||||
from circuitforge_core.input.gestures.camera import CameraCapture
|
||||
from circuitforge_core.input.gestures.hands import HandLandmarks, HandsDetector
|
||||
from circuitforge_core.input.gestures.normalizer import normalize_hand
|
||||
|
||||
__all__ = ["CameraCapture", "HandLandmarks", "HandsDetector", "normalize_hand"]
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
"""
|
||||
OpenCV camera capture — context manager wrapping VideoCapture.
|
||||
|
||||
Yields BGR frames. Callers convert to RGB before passing to HandsDetector:
|
||||
frame_rgb = frame_bgr[:, :, ::-1]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CameraCapture:
|
||||
"""
|
||||
Thin wrapper around cv2.VideoCapture.
|
||||
|
||||
Usage:
|
||||
with CameraCapture(device_index=0) as cam:
|
||||
for frame_bgr in cam.frames():
|
||||
process(frame_bgr)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_index: int = 0,
|
||||
width: int = 640,
|
||||
height: int = 480,
|
||||
fps: int = 30,
|
||||
) -> None:
|
||||
self._cap = cv2.VideoCapture(device_index)
|
||||
self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
|
||||
self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||||
self._cap.set(cv2.CAP_PROP_FPS, fps)
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
return self._cap.isOpened()
|
||||
|
||||
def frames(self) -> Iterator[np.ndarray]:
|
||||
"""Yield BGR uint8 frames until camera fails or caller breaks."""
|
||||
while self._cap.isOpened():
|
||||
ok, frame = self._cap.read()
|
||||
if not ok:
|
||||
break
|
||||
yield frame
|
||||
|
||||
def release(self) -> None:
|
||||
self._cap.release()
|
||||
|
||||
def __enter__(self) -> CameraCapture:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: object) -> None:
|
||||
self.release()
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
"""
|
||||
MediaPipe Hands wrapper.
|
||||
|
||||
Produces immutable HandLandmarks dataclasses from RGB video frames.
|
||||
The caller is responsible for BGR→RGB conversion before passing frames.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mediapipe as mp
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HandLandmarks:
|
||||
"""Immutable snapshot of one detected hand."""
|
||||
|
||||
points: np.ndarray # shape (21, 3) — x, y, z in [0,1] normalized image space
|
||||
handedness: str # 'Left' | 'Right' (mirror of physical hand)
|
||||
confidence: float # [0.0, 1.0]
|
||||
|
||||
|
||||
class HandsDetector:
|
||||
"""
|
||||
Thin wrapper around mediapipe.solutions.hands.Hands.
|
||||
|
||||
Usage:
|
||||
detector = HandsDetector()
|
||||
for frame_bgr in camera.frames():
|
||||
frame_rgb = frame_bgr[:, :, ::-1]
|
||||
hands = detector.detect(frame_rgb)
|
||||
for hand in hands:
|
||||
vec = normalize_hand(hand.points)
|
||||
...
|
||||
detector.close()
|
||||
|
||||
Or use as a context manager:
|
||||
with HandsDetector() as detector:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_hands: int = 2,
|
||||
min_detection_confidence: float = 0.7,
|
||||
min_tracking_confidence: float = 0.5,
|
||||
) -> None:
|
||||
self._hands = mp.solutions.hands.Hands(
|
||||
static_image_mode=False,
|
||||
max_num_hands=max_hands,
|
||||
min_detection_confidence=min_detection_confidence,
|
||||
min_tracking_confidence=min_tracking_confidence,
|
||||
)
|
||||
|
||||
def detect(self, rgb_frame: np.ndarray) -> list[HandLandmarks]:
|
||||
"""
|
||||
Run hand detection on one RGB frame.
|
||||
|
||||
Args:
|
||||
rgb_frame: (H, W, 3) uint8 RGB image.
|
||||
|
||||
Returns:
|
||||
List of HandLandmarks, one per detected hand (up to max_hands).
|
||||
Empty list if no hands detected.
|
||||
"""
|
||||
results = self._hands.process(rgb_frame)
|
||||
if not results.multi_hand_landmarks:
|
||||
return []
|
||||
out: list[HandLandmarks] = []
|
||||
for lm, hand in zip(results.multi_hand_landmarks, results.multi_handedness):
|
||||
points = np.array([[p.x, p.y, p.z] for p in lm.landmark], dtype=np.float32)
|
||||
points.flags.writeable = False # enforce immutability of stored array
|
||||
out.append(
|
||||
HandLandmarks(
|
||||
points=points,
|
||||
handedness=hand.classification[0].label,
|
||||
confidence=float(hand.classification[0].score),
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
def close(self) -> None:
|
||||
self._hands.close()
|
||||
|
||||
def __enter__(self) -> HandsDetector:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: object) -> None:
|
||||
self.close()
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
"""
|
||||
Landmark normalization for MediaPipe hand landmarks.
|
||||
|
||||
Converts raw (21, 3) landmark array into a 63-element translation- and
|
||||
scale-invariant feature vector suitable for gesture classifiers.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def normalize_hand(points: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize 21 MediaPipe hand landmarks into a scale/translation-invariant
|
||||
63-element feature vector.
|
||||
|
||||
Steps:
|
||||
1. Translate so wrist (landmark 0) is at origin.
|
||||
2. Scale so distance from wrist to middle-finger MCP (landmark 9) = 1.0.
|
||||
If that distance is near-zero (degenerate hand), return zeros.
|
||||
3. Flatten to shape (63,).
|
||||
|
||||
Args:
|
||||
points: (21, 3) float32 array — raw MediaPipe landmark coords.
|
||||
|
||||
Returns:
|
||||
(63,) float32 feature vector.
|
||||
"""
|
||||
pts = points.astype(np.float32).copy()
|
||||
pts -= pts[0] # translate: wrist → origin
|
||||
scale = float(np.linalg.norm(pts[9])) # wrist-to-middle-MCP distance
|
||||
if scale > 1e-6:
|
||||
pts /= scale
|
||||
return pts.flatten()
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.job_quality — deterministic trust scorer for job listings.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
|
||||
from circuitforge_core.job_quality.models import (
|
||||
JobEnrichment,
|
||||
JobListing,
|
||||
JobQualityScore,
|
||||
SignalResult,
|
||||
)
|
||||
from circuitforge_core.job_quality.scorer import score_job
|
||||
from circuitforge_core.job_quality.signals import ALL_SIGNALS
|
||||
|
||||
__all__ = [
|
||||
"JobEnrichment",
|
||||
"JobListing",
|
||||
"JobQualityScore",
|
||||
"SignalResult",
|
||||
"score_job",
|
||||
"ALL_SIGNALS",
|
||||
]
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
"""
|
||||
Pydantic models for the job_quality trust scorer.
|
||||
|
||||
MIT licensed — no LLM calls, no network calls, no file I/O.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class JobListing(BaseModel):
|
||||
"""Input data sourced directly from a job board scraper or ATS export."""
|
||||
|
||||
# Core identity
|
||||
title: str = ""
|
||||
company: str = ""
|
||||
location: str = ""
|
||||
state_code: str = "" # Two-letter US state code, e.g. "CA"
|
||||
|
||||
# Salary / compensation
|
||||
salary_min: float | None = None
|
||||
salary_max: float | None = None
|
||||
salary_text: str = "" # Raw salary string from the listing
|
||||
|
||||
# Posting metadata
|
||||
posted_at: datetime | None = None
|
||||
repost_count: int = 0 # Times the same listing has been reposted
|
||||
applicant_count: int | None = None
|
||||
is_staffing_agency: bool = False
|
||||
is_always_open: bool = False # Evergreen/always-accepting flag
|
||||
|
||||
# Content
|
||||
description: str = ""
|
||||
requirements: list[str] = Field(default_factory=list)
|
||||
ats_url: str = "" # ATS apply URL (Greenhouse, Lever, Workday, etc.)
|
||||
|
||||
# Signals from scraper enrichment
|
||||
weekend_posted: bool = False # Posted on Saturday or Sunday
|
||||
|
||||
|
||||
class JobEnrichment(BaseModel):
|
||||
"""Optional enrichment data gathered outside the listing (news, history, etc.)."""
|
||||
|
||||
has_layoff_news: bool = False # Recent layoff news for this company
|
||||
avg_response_days: float | None = None # Average recruiter response time (days)
|
||||
no_response_rate: float | None = None # Fraction of applicants with no response (0–1)
|
||||
|
||||
|
||||
class SignalResult(BaseModel):
|
||||
"""Output of a single signal function."""
|
||||
|
||||
name: str
|
||||
triggered: bool
|
||||
weight: float
|
||||
penalty: float # weight * triggered (0.0 when not triggered)
|
||||
detail: str = "" # Human-readable explanation
|
||||
|
||||
|
||||
class JobQualityScore(BaseModel):
|
||||
"""Aggregated trust score for a job listing."""
|
||||
|
||||
trust_score: float # 0.0 (low trust) – 1.0 (high trust)
|
||||
confidence: float # 0.0 – 1.0: fraction of signals with available evidence
|
||||
signals: list[SignalResult]
|
||||
raw_penalty: float # Sum of triggered weights before clamping
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
"""
|
||||
score_job: aggregate all signals into a JobQualityScore.
|
||||
|
||||
MIT licensed — pure function, no I/O.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from circuitforge_core.job_quality.models import JobEnrichment, JobListing, JobQualityScore, SignalResult
|
||||
from circuitforge_core.job_quality.signals import ALL_SIGNALS
|
||||
|
||||
|
||||
def score_job(
|
||||
listing: JobListing,
|
||||
enrichment: JobEnrichment | None = None,
|
||||
) -> JobQualityScore:
|
||||
"""
|
||||
Score a job listing for trust/quality.
|
||||
|
||||
Each signal produces a penalty in [0, weight]. The raw penalty is the sum of
|
||||
all triggered signal weights. trust_score = 1 - clamp(raw_penalty, 0, 1).
|
||||
|
||||
confidence reflects what fraction of signals had enough data to evaluate.
|
||||
Signals that return triggered=False with a "not available" detail are counted
|
||||
as unevaluable — they reduce confidence without adding penalty.
|
||||
"""
|
||||
results: list[SignalResult] = []
|
||||
evaluable_count = 0
|
||||
|
||||
for fn in ALL_SIGNALS:
|
||||
result = fn(listing, enrichment)
|
||||
results.append(result)
|
||||
# A signal is evaluable when it either triggered or had data to decide it didn't.
|
||||
# Signals that skip due to missing data always set triggered=False AND include
|
||||
# "not available" or "No" in their detail.
|
||||
if result.triggered or _has_data(result):
|
||||
evaluable_count += 1
|
||||
|
||||
raw_penalty = sum(r.penalty for r in results)
|
||||
trust_score = max(0.0, min(1.0, 1.0 - raw_penalty))
|
||||
confidence = evaluable_count / len(ALL_SIGNALS) if ALL_SIGNALS else 0.0
|
||||
|
||||
return JobQualityScore(
|
||||
trust_score=round(trust_score, 4),
|
||||
confidence=round(confidence, 4),
|
||||
signals=results,
|
||||
raw_penalty=round(raw_penalty, 4),
|
||||
)
|
||||
|
||||
|
||||
def _has_data(result: SignalResult) -> bool:
|
||||
"""Return True when the signal's detail indicates it actually evaluated data."""
|
||||
skip_phrases = (
|
||||
"not available",
|
||||
"No enrichment",
|
||||
"No posted_at",
|
||||
"No response rate",
|
||||
"No salary information",
|
||||
)
|
||||
return not any(phrase.lower() in result.detail.lower() for phrase in skip_phrases)
|
||||
|
|
@ -1,275 +0,0 @@
|
|||
"""
|
||||
Individual signal functions for the job_quality trust scorer.
|
||||
|
||||
Each function takes a JobListing and optional JobEnrichment and returns a SignalResult.
|
||||
All signals are pure functions: no I/O, no LLM calls, no side effects.
|
||||
|
||||
MIT licensed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from circuitforge_core.job_quality.models import JobEnrichment, JobListing, SignalResult
|
||||
|
||||
# US states with salary transparency laws (as of 2026)
|
||||
_SALARY_TRANSPARENCY_STATES = {"CO", "CA", "NY", "WA", "IL", "MA"}
|
||||
|
||||
# ATS providers whose apply URLs are commonly associated with high ghosting rates
|
||||
_GHOSTING_ATS_PATTERNS = ("lever.co", "greenhouse.io", "workday.com", "icims.com", "taleo.net")
|
||||
|
||||
# Threshold for "always open" detection: repost every N days for M months
|
||||
_ALWAYS_OPEN_REPOST_THRESHOLD = 3
|
||||
|
||||
# Requirement count above which a listing is considered overloaded
|
||||
_REQUIREMENT_OVERLOAD_COUNT = 12
|
||||
|
||||
# Vagueness: description length below this suggests bare-minimum content
|
||||
_VAGUE_DESCRIPTION_CHARS = 400
|
||||
|
||||
# Applicant count above which competition is considered very high
|
||||
_HIGH_APPLICANT_THRESHOLD = 200
|
||||
|
||||
# Listing age above which staleness is likely
|
||||
_STALE_DAYS = 30
|
||||
|
||||
# Response rate above which the role is considered a high-ghosting source
|
||||
_NO_RESPONSE_RATE_THRESHOLD = 0.60
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-weight signals (0.15 – 0.25)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def listing_age(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Listing older than 30 days is likely stale or already filled."""
|
||||
weight = 0.25
|
||||
if listing.posted_at is None:
|
||||
return SignalResult(name="listing_age", triggered=False, weight=weight, penalty=0.0,
|
||||
detail="No posted_at date available.")
|
||||
age_days = (_now() - listing.posted_at.astimezone(timezone.utc)).days
|
||||
triggered = age_days > _STALE_DAYS
|
||||
return SignalResult(
|
||||
name="listing_age",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"Listing is {age_days} days old (threshold: {_STALE_DAYS}).",
|
||||
)
|
||||
|
||||
|
||||
def repost_detected(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Listing has been reposted multiple times — a strong ghost-job indicator."""
|
||||
weight = 0.25
|
||||
triggered = listing.repost_count >= _ALWAYS_OPEN_REPOST_THRESHOLD
|
||||
return SignalResult(
|
||||
name="repost_detected",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"Repost count: {listing.repost_count} (threshold: {_ALWAYS_OPEN_REPOST_THRESHOLD}).",
|
||||
)
|
||||
|
||||
|
||||
def no_salary_transparency(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""No salary info despite being in a transparency-law state, or generally absent."""
|
||||
weight = 0.20
|
||||
has_range = listing.salary_min is not None or listing.salary_max is not None
|
||||
has_text = bool(listing.salary_text.strip())
|
||||
has_salary = has_range or has_text
|
||||
in_transparency_state = listing.state_code.upper() in _SALARY_TRANSPARENCY_STATES
|
||||
|
||||
if not has_salary:
|
||||
if in_transparency_state:
|
||||
detail = (f"No salary disclosed despite {listing.state_code} transparency law. "
|
||||
"Possible compliance violation.")
|
||||
else:
|
||||
detail = "No salary information provided."
|
||||
triggered = True
|
||||
else:
|
||||
triggered = False
|
||||
detail = "Salary information present."
|
||||
|
||||
return SignalResult(
|
||||
name="no_salary_transparency",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def always_open_pattern(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Listing is flagged as always-accepting or evergreen — pipeline filler."""
|
||||
weight = 0.20
|
||||
triggered = listing.is_always_open
|
||||
return SignalResult(
|
||||
name="always_open_pattern",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail="Listing marked as always-open/evergreen." if triggered else "Not always-open.",
|
||||
)
|
||||
|
||||
|
||||
def staffing_agency(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Posted by a staffing or recruiting agency rather than the hiring company directly."""
|
||||
weight = 0.15
|
||||
triggered = listing.is_staffing_agency
|
||||
return SignalResult(
|
||||
name="staffing_agency",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail="Listed by a staffing/recruiting agency." if triggered else "Direct employer listing.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Medium-weight signals (0.08 – 0.12)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def requirement_overload(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Excessive requirements list suggests a wish-list role or perpetual search."""
|
||||
weight = 0.12
|
||||
count = len(listing.requirements)
|
||||
triggered = count > _REQUIREMENT_OVERLOAD_COUNT
|
||||
return SignalResult(
|
||||
name="requirement_overload",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"{count} requirements listed (threshold: {_REQUIREMENT_OVERLOAD_COUNT}).",
|
||||
)
|
||||
|
||||
|
||||
def layoff_news(listing: JobListing, enrichment: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Company has recent layoff news — new hires may be at high risk."""
|
||||
weight = 0.12
|
||||
if enrichment is None:
|
||||
return SignalResult(name="layoff_news", triggered=False, weight=weight, penalty=0.0,
|
||||
detail="No enrichment data available.")
|
||||
triggered = enrichment.has_layoff_news
|
||||
return SignalResult(
|
||||
name="layoff_news",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail="Recent layoff news detected for this company." if triggered else "No layoff news found.",
|
||||
)
|
||||
|
||||
|
||||
def jd_vagueness(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Job description is suspiciously short — may not represent a real open role."""
|
||||
weight = 0.10
|
||||
char_count = len(listing.description.strip())
|
||||
triggered = char_count < _VAGUE_DESCRIPTION_CHARS
|
||||
return SignalResult(
|
||||
name="jd_vagueness",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"Description is {char_count} characters (threshold: {_VAGUE_DESCRIPTION_CHARS}).",
|
||||
)
|
||||
|
||||
|
||||
def ats_blackhole(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Apply URL routes through a high-volume ATS known for candidate ghosting."""
|
||||
weight = 0.10
|
||||
url_lower = listing.ats_url.lower()
|
||||
matched = next((p for p in _GHOSTING_ATS_PATTERNS if p in url_lower), None)
|
||||
triggered = matched is not None
|
||||
return SignalResult(
|
||||
name="ats_blackhole",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"ATS matches high-ghosting pattern '{matched}'." if triggered else "No high-ghosting ATS detected.",
|
||||
)
|
||||
|
||||
|
||||
def high_applicant_count(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Very high applicant count means low odds and possible ghost-collection."""
|
||||
weight = 0.08
|
||||
if listing.applicant_count is None:
|
||||
return SignalResult(name="high_applicant_count", triggered=False, weight=weight, penalty=0.0,
|
||||
detail="Applicant count not available.")
|
||||
triggered = listing.applicant_count > _HIGH_APPLICANT_THRESHOLD
|
||||
return SignalResult(
|
||||
name="high_applicant_count",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"{listing.applicant_count} applicants (threshold: {_HIGH_APPLICANT_THRESHOLD}).",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-weight signals (0.04 – 0.08)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def weekend_posted(listing: JobListing, _: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Posted on a weekend — may indicate bulk/automated ghost-job pipeline posting."""
|
||||
weight = 0.04
|
||||
if listing.posted_at is None and not listing.weekend_posted:
|
||||
return SignalResult(name="weekend_posted", triggered=False, weight=weight, penalty=0.0,
|
||||
detail="No posted_at date available.")
|
||||
if listing.weekend_posted:
|
||||
triggered = True
|
||||
else:
|
||||
triggered = listing.posted_at.weekday() >= 5 # type: ignore[union-attr]
|
||||
return SignalResult(
|
||||
name="weekend_posted",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail="Posted on a weekend." if triggered else "Posted on a weekday.",
|
||||
)
|
||||
|
||||
|
||||
def poor_response_history(listing: JobListing, enrichment: JobEnrichment | None = None) -> SignalResult:
|
||||
"""Company/ATS historically does not respond to applicants."""
|
||||
weight = 0.08
|
||||
if enrichment is None:
|
||||
return SignalResult(name="poor_response_history", triggered=False, weight=weight, penalty=0.0,
|
||||
detail="No enrichment data available.")
|
||||
rate = enrichment.no_response_rate
|
||||
if rate is None:
|
||||
return SignalResult(name="poor_response_history", triggered=False, weight=weight, penalty=0.0,
|
||||
detail="No response rate data available.")
|
||||
triggered = rate > _NO_RESPONSE_RATE_THRESHOLD
|
||||
return SignalResult(
|
||||
name="poor_response_history",
|
||||
triggered=triggered,
|
||||
weight=weight,
|
||||
penalty=weight if triggered else 0.0,
|
||||
detail=f"No-response rate: {rate:.0%} (threshold: {_NO_RESPONSE_RATE_THRESHOLD:.0%}).",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signal registry — ordered by weight descending for scorer iteration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ALL_SIGNALS = [
|
||||
listing_age,
|
||||
repost_detected,
|
||||
no_salary_transparency,
|
||||
always_open_pattern,
|
||||
staffing_agency,
|
||||
requirement_overload,
|
||||
layoff_news,
|
||||
jd_vagueness,
|
||||
ats_blackhole,
|
||||
high_applicant_count,
|
||||
weekend_posted,
|
||||
poor_response_history,
|
||||
]
|
||||
|
|
@ -1,49 +1,8 @@
|
|||
"""
|
||||
LLM abstraction layer with priority fallback chain.
|
||||
|
||||
Reads config from ~/.config/circuitforge/llm.yaml (or the path passed to
|
||||
LLMRouter.__init__). Tries backends in fallback_order; skips unreachable or
|
||||
disabled entries and falls back to the next until one succeeds.
|
||||
|
||||
## Backend types
|
||||
|
||||
**openai_compat** — OpenAI-compatible /v1/chat/completions endpoint.
|
||||
Used for: Ollama, vLLM, GitHub Copilot wrapper, Claude Code wrapper,
|
||||
and the cf-orch trunk services (cf-text, cf-voice).
|
||||
|
||||
With a cf_orch block the router first allocates via cf-orch, which
|
||||
starts the service on-demand and returns its URL. Without cf_orch the
|
||||
router does a static reachability check against base_url.
|
||||
|
||||
**anthropic** — Direct Anthropic API via the anthropic SDK.
|
||||
|
||||
**vision_service** — cf-vision managed service (moondream2 / SigLIP).
|
||||
Posts to /analyze; only used when images= is provided to complete().
|
||||
Supports cf_orch allocation to start cf-vision on-demand.
|
||||
|
||||
## Trunk services (The Orchard architecture)
|
||||
|
||||
These services live in cf-orch as branches; cf-core wires them as backends.
|
||||
Products declare them in llm.yaml using the openai_compat type plus a
|
||||
cf_orch block — the router handles allocation and URL injection transparently.
|
||||
|
||||
cf-text — Local transformer inference (/v1/chat/completions, port 8008).
|
||||
Default model set by default_model in the node's service
|
||||
profile; override via model_candidates in the cf_orch block.
|
||||
|
||||
cf-voice — STT/TTS pipeline endpoint (/v1/chat/completions, port 8009).
|
||||
Same allocation pattern as cf-text.
|
||||
|
||||
cf-vision — Vision inference (moondream2 / SigLIP), vision_service type.
|
||||
Used via the vision_fallback_order when images are present.
|
||||
|
||||
## Config auto-detection (no llm.yaml)
|
||||
|
||||
When llm.yaml is absent, the router builds a minimal config from environment
|
||||
variables: ANTHROPIC_API_KEY, OPENAI_API_KEY / OPENAI_BASE_URL, OLLAMA_HOST.
|
||||
Ollama on localhost:11434 is always included as the lowest-cost local fallback.
|
||||
Reads config from ~/.config/circuitforge/llm.yaml.
|
||||
Tries backends in order; falls back on any error.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import yaml
|
||||
|
|
@ -57,11 +16,8 @@ CONFIG_PATH = Path.home() / ".config" / "circuitforge" / "llm.yaml"
|
|||
|
||||
|
||||
class LLMRouter:
|
||||
def __init__(self, config_path: Path | dict = CONFIG_PATH):
|
||||
self._ollama_tags_cache: dict[str, set[str]] = {}
|
||||
if isinstance(config_path, dict):
|
||||
self.config = config_path
|
||||
elif config_path.exists():
|
||||
def __init__(self, config_path: Path = CONFIG_PATH):
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
else:
|
||||
|
|
@ -74,8 +30,7 @@ class LLMRouter:
|
|||
)
|
||||
logger.info(
|
||||
"[LLMRouter] No llm.yaml found — using env-var auto-config "
|
||||
"(backends: %s)",
|
||||
", ".join(env_config["fallback_order"]),
|
||||
"(backends: %s)", ", ".join(env_config["fallback_order"])
|
||||
)
|
||||
self.config = env_config
|
||||
|
||||
|
|
@ -108,9 +63,7 @@ class LLMRouter:
|
|||
backends["openai"] = {
|
||||
"type": "openai_compat",
|
||||
"enabled": True,
|
||||
"base_url": os.environ.get(
|
||||
"OPENAI_BASE_URL", "https://api.openai.com/v1"
|
||||
),
|
||||
"base_url": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
|
||||
"model": os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
|
||||
"api_key": os.environ.get("OPENAI_API_KEY"),
|
||||
"supports_images": True,
|
||||
|
|
@ -148,37 +101,6 @@ class LLMRouter:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_ollama_model_pulled(self, base_url: str, model: str) -> None:
|
||||
"""Raise RuntimeError with actionable message if model is not pulled in Ollama.
|
||||
|
||||
Silently skips the check if the /api/tags endpoint is unavailable (e.g. vLLM).
|
||||
Results are cached per base_url for the lifetime of this router instance.
|
||||
"""
|
||||
tags_url = base_url.rstrip("/").removesuffix("/v1") + "/api/tags"
|
||||
if not hasattr(self, "_ollama_tags_cache"):
|
||||
self._ollama_tags_cache = {}
|
||||
if base_url not in self._ollama_tags_cache:
|
||||
try:
|
||||
resp = requests.get(tags_url, timeout=3)
|
||||
if resp.status_code != 200:
|
||||
return
|
||||
pulled = {
|
||||
m["name"].split(":")[0]
|
||||
for m in resp.json().get("models", [])
|
||||
}
|
||||
self._ollama_tags_cache[base_url] = pulled
|
||||
except Exception:
|
||||
return # can't verify — let the actual embed call fail naturally
|
||||
pulled_models = self._ollama_tags_cache.get(base_url)
|
||||
if pulled_models is None:
|
||||
return
|
||||
model_base = model.split(":")[0]
|
||||
if model_base not in pulled_models:
|
||||
raise RuntimeError(
|
||||
f'Ollama embedding model "{model}" is not pulled.\n'
|
||||
f"Fix: ollama pull {model}"
|
||||
)
|
||||
|
||||
def _resolve_model(self, client: OpenAI, model: str) -> str:
|
||||
"""Resolve __auto__ to the first model served by vLLM."""
|
||||
if model != "__auto__":
|
||||
|
|
@ -190,19 +112,10 @@ class LLMRouter:
|
|||
"""
|
||||
If backend config has a cf_orch block and CF_ORCH_URL is set (env takes
|
||||
precedence over yaml url), allocate via cf-orch and return (ctx, alloc).
|
||||
|
||||
Two allocation modes:
|
||||
- task-based (preferred): cf_orch block has `product` + `task` keys.
|
||||
Calls POST /api/inference/task; coordinator resolves model/node from
|
||||
assignments.yaml. No hardcoded model IDs in product config.
|
||||
- service-based (legacy): cf_orch block has `service` key.
|
||||
Calls allocate(service=...) directly.
|
||||
|
||||
Returns None if not configured or allocation fails.
|
||||
Caller MUST call ctx.__exit__(None, None, None) in a finally block.
|
||||
"""
|
||||
import os
|
||||
|
||||
orch_cfg = backend.get("cf_orch")
|
||||
if not orch_cfg:
|
||||
return None
|
||||
|
|
@ -211,46 +124,22 @@ class LLMRouter:
|
|||
return None
|
||||
try:
|
||||
from circuitforge_orch.client import CFOrchClient
|
||||
|
||||
client = CFOrchClient(orch_url)
|
||||
service = orch_cfg.get("service", "vllm")
|
||||
candidates = orch_cfg.get("model_candidates", [])
|
||||
ttl_s = float(orch_cfg.get("ttl_s", 3600.0))
|
||||
|
||||
# Task-based allocation: product+task → coordinator resolves model/node.
|
||||
task = orch_cfg.get("task")
|
||||
product = orch_cfg.get("product") or os.environ.get("CF_APP_NAME") or None
|
||||
if task and product:
|
||||
ctx = client.task_allocate(product, task, ttl_s=ttl_s)
|
||||
alloc = ctx.__enter__()
|
||||
return (ctx, alloc)
|
||||
|
||||
# Service-based allocation (legacy path).
|
||||
cf_app = os.environ.get("CF_APP_NAME") or None
|
||||
caller = f"{cf_app}.llm-router" if cf_app else "llm-router"
|
||||
ctx = client.allocate(
|
||||
orch_cfg.get("service", "vllm"),
|
||||
model_candidates=orch_cfg.get("model_candidates", []),
|
||||
ttl_s=ttl_s,
|
||||
caller=caller,
|
||||
pipeline=cf_app,
|
||||
)
|
||||
ctx = client.allocate(service, model_candidates=candidates, ttl_s=ttl_s, caller="llm-router")
|
||||
alloc = ctx.__enter__()
|
||||
return (ctx, alloc)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[LLMRouter] cf_orch allocation failed, using base_url directly: %s",
|
||||
exc,
|
||||
)
|
||||
logger.warning("[LLMRouter] cf_orch allocation failed, using base_url directly: %s", exc)
|
||||
return None
|
||||
|
||||
def complete(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str | None = None,
|
||||
model_override: str | None = None,
|
||||
fallback_order: list[str] | None = None,
|
||||
images: list[str] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> str:
|
||||
def complete(self, prompt: str, system: str | None = None,
|
||||
model_override: str | None = None,
|
||||
fallback_order: list[str] | None = None,
|
||||
images: list[str] | None = None,
|
||||
max_tokens: int | None = None) -> str:
|
||||
"""
|
||||
Generate a completion. Tries each backend in fallback_order.
|
||||
|
||||
|
|
@ -268,11 +157,7 @@ class LLMRouter:
|
|||
"AI inference is disabled in the public demo. "
|
||||
"Run your own instance to use AI features."
|
||||
)
|
||||
order = (
|
||||
fallback_order
|
||||
if fallback_order is not None
|
||||
else self.config["fallback_order"]
|
||||
)
|
||||
order = fallback_order if fallback_order is not None else self.config["fallback_order"]
|
||||
for name in order:
|
||||
backend = self.config["backends"][name]
|
||||
|
||||
|
|
@ -294,14 +179,7 @@ class LLMRouter:
|
|||
continue
|
||||
|
||||
if is_vision_service:
|
||||
# cf_orch: try allocation first (same pattern as openai_compat).
|
||||
# Allocation can start the vision service on-demand on the cluster.
|
||||
orch_ctx = orch_alloc = None
|
||||
orch_result = self._try_cf_orch_alloc(backend)
|
||||
if orch_result is not None:
|
||||
orch_ctx, orch_alloc = orch_result
|
||||
backend = {**backend, "base_url": orch_alloc.url}
|
||||
elif not self._is_reachable(backend["base_url"]):
|
||||
if not self._is_reachable(backend["base_url"]):
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
try:
|
||||
|
|
@ -319,23 +197,17 @@ class LLMRouter:
|
|||
except Exception as e:
|
||||
print(f"[LLMRouter] {name}: error — {e}, trying next")
|
||||
continue
|
||||
finally:
|
||||
if orch_ctx is not None:
|
||||
orch_ctx.__exit__(None, None, None)
|
||||
|
||||
elif backend["type"] == "openai_compat":
|
||||
# cf_orch: try allocation first — this may start the service on-demand.
|
||||
# Do NOT reachability-check before allocating; the service may be stopped
|
||||
# and the allocation is what starts it.
|
||||
if not self._is_reachable(backend["base_url"]):
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
# --- cf_orch: optionally override base_url with coordinator-allocated URL ---
|
||||
orch_ctx = orch_alloc = None
|
||||
orch_result = self._try_cf_orch_alloc(backend)
|
||||
if orch_result is not None:
|
||||
orch_ctx, orch_alloc = orch_result
|
||||
backend = {**backend, "base_url": orch_alloc.url + "/v1"}
|
||||
elif not self._is_reachable(backend["base_url"]):
|
||||
# Static backend (no cf-orch) — skip if not reachable.
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=backend["base_url"],
|
||||
|
|
@ -349,14 +221,10 @@ class LLMRouter:
|
|||
if images and supports_images:
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
for img in images:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{img}"
|
||||
},
|
||||
}
|
||||
)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{img}"},
|
||||
})
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
|
@ -381,27 +249,18 @@ class LLMRouter:
|
|||
elif backend["type"] == "anthropic":
|
||||
api_key = os.environ.get(backend["api_key_env"], "")
|
||||
if not api_key:
|
||||
print(
|
||||
f"[LLMRouter] {name}: {backend['api_key_env']} not set, skipping"
|
||||
)
|
||||
print(f"[LLMRouter] {name}: {backend['api_key_env']} not set, skipping")
|
||||
continue
|
||||
try:
|
||||
import anthropic as _anthropic
|
||||
|
||||
client = _anthropic.Anthropic(api_key=api_key)
|
||||
if images and supports_images:
|
||||
content = []
|
||||
for img in images:
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": img,
|
||||
},
|
||||
}
|
||||
)
|
||||
content.append({
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": img},
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
else:
|
||||
content = prompt
|
||||
|
|
@ -421,84 +280,6 @@ class LLMRouter:
|
|||
|
||||
raise RuntimeError("All LLM backends exhausted")
|
||||
|
||||
def embed(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_override: str | None = None,
|
||||
fallback_order: list[str] | None = None,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of texts.
|
||||
|
||||
Only openai_compat backends are tried — Ollama and vLLM expose
|
||||
/v1/embeddings; anthropic and vision_service do not.
|
||||
|
||||
Uses ``embedding_model`` from backend config when present;
|
||||
falls back to ``model`` (the chat model) otherwise.
|
||||
|
||||
Args:
|
||||
texts: Texts to embed (batched in a single API call).
|
||||
model_override: Override the embedding model for this call.
|
||||
fallback_order: Override the backend fallback order for this call.
|
||||
|
||||
Returns:
|
||||
List of float vectors, one per input text, in input order.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all eligible backends are exhausted.
|
||||
"""
|
||||
if os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes"):
|
||||
raise RuntimeError(
|
||||
"AI inference is disabled in the public demo. "
|
||||
"Run your own instance to use AI features."
|
||||
)
|
||||
order = (
|
||||
fallback_order
|
||||
if fallback_order is not None
|
||||
else self.config["fallback_order"]
|
||||
)
|
||||
for name in order:
|
||||
backend = self.config["backends"][name]
|
||||
if not backend.get("enabled", True):
|
||||
continue
|
||||
if backend["type"] != "openai_compat":
|
||||
continue
|
||||
|
||||
orch_ctx = orch_alloc = None
|
||||
orch_result = self._try_cf_orch_alloc(backend)
|
||||
if orch_result is not None:
|
||||
orch_ctx, orch_alloc = orch_result
|
||||
backend = {**backend, "base_url": orch_alloc.url + "/v1"}
|
||||
elif not self._is_reachable(backend["base_url"]):
|
||||
print(f"[LLMRouter] {name}: unreachable, skipping")
|
||||
continue
|
||||
|
||||
embed_model = model_override or backend.get(
|
||||
"embedding_model", backend["model"]
|
||||
)
|
||||
self._check_ollama_model_pulled(backend["base_url"], embed_model)
|
||||
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=backend["base_url"],
|
||||
api_key=backend.get("api_key") or "any",
|
||||
)
|
||||
model = embed_model
|
||||
resp = client.embeddings.create(model=model, input=texts)
|
||||
print(f"[LLMRouter] embed: used backend {name} ({model})")
|
||||
return [item.embedding for item in resp.data]
|
||||
except Exception as e:
|
||||
print(f"[LLMRouter] {name}: embed error — {e}, trying next")
|
||||
continue
|
||||
finally:
|
||||
if orch_ctx is not None:
|
||||
try:
|
||||
orch_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise RuntimeError("All LLM backends exhausted for embed()")
|
||||
|
||||
|
||||
# Module-level singleton for convenience
|
||||
_router: LLMRouter | None = None
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
"""circuitforge_core.musicgen — music continuation service (BSL 1.1)."""
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
"""
|
||||
cf-musicgen FastAPI service — managed by cf-orch.
|
||||
|
||||
Endpoints:
|
||||
GET /health -> {"status": "ok", "model": str, "vram_mb": int}
|
||||
POST /continue -> audio bytes (Content-Type: audio/wav or audio/mpeg)
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.musicgen.app \
|
||||
--model facebook/musicgen-melody \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
The service streams back raw audio bytes. Headers include:
|
||||
X-Duration-S generated duration in seconds
|
||||
X-Prompt-Duration-S how many seconds of the input were used as prompt
|
||||
X-Model model name
|
||||
X-Sample-Rate output sample rate (32000 for all MusicGen variants)
|
||||
|
||||
Model weights are cached at /Library/Assets/LLM/musicgen/.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from circuitforge_core.musicgen.backends.base import (
|
||||
MODEL_MELODY,
|
||||
MODEL_SMALL,
|
||||
AudioFormat,
|
||||
MusicGenBackend,
|
||||
make_musicgen_backend,
|
||||
)
|
||||
|
||||
_CONTENT_TYPES: dict[str, str] = {
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
}
|
||||
|
||||
app = FastAPI(title="cf-musicgen", version="0.1.0")
|
||||
_backend: MusicGenBackend | None = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
return {
|
||||
"status": "ok",
|
||||
"model": _backend.model_name,
|
||||
"vram_mb": _backend.vram_mb,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/continue")
|
||||
async def continue_audio(
|
||||
audio: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, OGG, ...)"),
|
||||
description: Annotated[str | None, Form()] = None,
|
||||
duration_s: Annotated[float, Form()] = 15.0,
|
||||
prompt_duration_s: Annotated[float, Form()] = 10.0,
|
||||
format: Annotated[AudioFormat, Form()] = "wav",
|
||||
) -> Response:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
if duration_s <= 0 or duration_s > 60:
|
||||
raise HTTPException(422, detail="duration_s must be between 0 and 60")
|
||||
if prompt_duration_s <= 0 or prompt_duration_s > 30:
|
||||
raise HTTPException(422, detail="prompt_duration_s must be between 0 and 30")
|
||||
|
||||
audio_bytes = await audio.read()
|
||||
if not audio_bytes:
|
||||
raise HTTPException(400, detail="Empty audio file")
|
||||
|
||||
try:
|
||||
result = _backend.continue_audio(
|
||||
audio_bytes,
|
||||
description=description or None,
|
||||
duration_s=duration_s,
|
||||
prompt_duration_s=prompt_duration_s,
|
||||
format=format,
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.exception("Music continuation failed")
|
||||
raise HTTPException(500, detail=str(exc)) from exc
|
||||
|
||||
return Response(
|
||||
content=result.audio_bytes,
|
||||
media_type=_CONTENT_TYPES.get(result.format, "audio/wav"),
|
||||
headers={
|
||||
"X-Duration-S": str(round(result.duration_s, 3)),
|
||||
"X-Prompt-Duration-S": str(round(result.prompt_duration_s, 3)),
|
||||
"X-Model": result.model,
|
||||
"X-Sample-Rate": str(result.sample_rate),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="cf-musicgen service")
|
||||
p.add_argument(
|
||||
"--model",
|
||||
default=MODEL_MELODY,
|
||||
choices=[MODEL_MELODY, MODEL_SMALL, "facebook/musicgen-medium", "facebook/musicgen-large"],
|
||||
help="MusicGen model variant",
|
||||
)
|
||||
p.add_argument("--port", type=int, default=8006)
|
||||
p.add_argument("--host", default="0.0.0.0")
|
||||
p.add_argument("--gpu-id", type=int, default=0,
|
||||
help="CUDA device index (sets CUDA_VISIBLE_DEVICES)")
|
||||
p.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
|
||||
p.add_argument("--mock", action="store_true",
|
||||
help="Run with mock backend (no GPU, for testing)")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
args = _parse_args()
|
||||
|
||||
if args.device == "cuda" and not args.mock:
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))
|
||||
|
||||
mock = args.mock or args.model == "mock"
|
||||
device = "cpu" if mock else args.device
|
||||
|
||||
_backend = make_musicgen_backend(model_name=args.model, mock=mock, device=device)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""MusicGen backend implementations."""
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
"""
|
||||
AudioCraft MusicGen backend — music continuation via Meta's MusicGen.
|
||||
|
||||
Models are downloaded to /Library/Assets/LLM/musicgen/ (HF hub cache).
|
||||
The melody model (~8 GB VRAM) is the default; small (~1.5 GB) is available
|
||||
for lower-VRAM nodes.
|
||||
|
||||
Continuation workflow:
|
||||
1. Decode input audio with torchaudio (any format ffmpeg understands)
|
||||
2. Trim to the last `prompt_duration_s` seconds — this anchors the generation
|
||||
3. Call model.generate_continuation(prompt_waveform, prompt_sample_rate, ...)
|
||||
4. Output tensor is the NEW audio only (not prompt + continuation)
|
||||
5. Encode to the requested format and return
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from circuitforge_core.musicgen.backends.base import (
|
||||
AudioFormat,
|
||||
MusicContinueResult,
|
||||
decode_audio,
|
||||
encode_audio,
|
||||
)
|
||||
|
||||
# All MusicGen/AudioCraft weights land here — consistent with other CF model dirs.
|
||||
_MUSICGEN_CACHE = "/Library/Assets/LLM/musicgen"
|
||||
|
||||
# VRAM estimates (MB) per model variant
|
||||
_VRAM_MB: dict[str, int] = {
|
||||
"facebook/musicgen-small": 1500,
|
||||
"facebook/musicgen-medium": 4500,
|
||||
"facebook/musicgen-melody": 8000,
|
||||
"facebook/musicgen-large": 8500,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioCraftBackend:
|
||||
"""MusicGen backend using Meta's AudioCraft library."""
|
||||
|
||||
def __init__(self, model_name: str = "facebook/musicgen-melody", device: str = "cuda") -> None:
|
||||
# Redirect HF hub cache before the first import so weights go to /Library/Assets
|
||||
os.environ.setdefault("HF_HOME", _MUSICGEN_CACHE)
|
||||
os.makedirs(_MUSICGEN_CACHE, exist_ok=True)
|
||||
|
||||
from audiocraft.models import MusicGen # noqa: PLC0415
|
||||
|
||||
logger.info("Loading MusicGen model: %s on %s", model_name, device)
|
||||
self._model = MusicGen.get_pretrained(model_name, device=device)
|
||||
self._model_name = model_name
|
||||
self._device = device
|
||||
logger.info("MusicGen ready: %s", model_name)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return _VRAM_MB.get(self._model_name, 8000)
|
||||
|
||||
def continue_audio(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
*,
|
||||
description: str | None = None,
|
||||
duration_s: float = 15.0,
|
||||
prompt_duration_s: float = 10.0,
|
||||
format: AudioFormat = "wav",
|
||||
) -> MusicContinueResult:
|
||||
import torch
|
||||
|
||||
# Decode input audio -> [C, T] tensor
|
||||
wav, sr = decode_audio(audio_bytes)
|
||||
|
||||
# Trim to the last `prompt_duration_s` seconds to form the conditioning prompt.
|
||||
# Using the end of the track (not the beginning) gives the model the musical
|
||||
# context closest to where we want to continue.
|
||||
max_prompt_samples = int(prompt_duration_s * sr)
|
||||
if wav.shape[-1] > max_prompt_samples:
|
||||
wav = wav[..., -max_prompt_samples:]
|
||||
|
||||
# MusicGen expects [batch, channels, time]
|
||||
prompt_tensor = wav.unsqueeze(0).to(self._device)
|
||||
|
||||
# Build descriptions list — one entry per batch item (batch=1 here)
|
||||
descriptions = [description] if description else [None]
|
||||
|
||||
self._model.set_generation_params(
|
||||
duration=duration_s,
|
||||
top_k=250,
|
||||
temperature=1.0,
|
||||
cfg_coef=3.0,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Generating %.1fs continuation (prompt=%.1fs) model=%s",
|
||||
duration_s,
|
||||
prompt_duration_s,
|
||||
self._model_name,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = self._model.generate_continuation(
|
||||
prompt=prompt_tensor,
|
||||
prompt_sample_rate=sr,
|
||||
descriptions=descriptions,
|
||||
progress=True,
|
||||
)
|
||||
|
||||
# output: [batch, channels, time] at model sample rate (32 kHz)
|
||||
output_wav = output[0] # [C, T]
|
||||
model_sr = self._model.sample_rate
|
||||
|
||||
actual_duration_s = output_wav.shape[-1] / model_sr
|
||||
audio_bytes_out = encode_audio(output_wav, model_sr, format)
|
||||
|
||||
return MusicContinueResult(
|
||||
audio_bytes=audio_bytes_out,
|
||||
sample_rate=model_sr,
|
||||
duration_s=actual_duration_s,
|
||||
format=format,
|
||||
model=self._model_name,
|
||||
prompt_duration_s=prompt_duration_s,
|
||||
)
|
||||
|
|
@ -1,97 +0,0 @@
|
|||
"""
|
||||
MusicGenBackend Protocol — backend-agnostic music continuation interface.
|
||||
|
||||
All backends accept an audio prompt (raw bytes, any ffmpeg-readable format) and
|
||||
return MusicContinueResult with the generated continuation as audio bytes.
|
||||
|
||||
The continuation is the *new* audio only (not prompt + continuation). Callers
|
||||
that want a seamless joined file can concatenate the original + result themselves.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
AudioFormat = Literal["wav", "mp3"]
|
||||
|
||||
MODEL_SMALL = "facebook/musicgen-small"
|
||||
MODEL_MELODY = "facebook/musicgen-melody"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MusicContinueResult:
|
||||
audio_bytes: bytes
|
||||
sample_rate: int
|
||||
duration_s: float
|
||||
format: AudioFormat
|
||||
model: str
|
||||
prompt_duration_s: float
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MusicGenBackend(Protocol):
|
||||
def continue_audio(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
*,
|
||||
description: str | None = None,
|
||||
duration_s: float = 15.0,
|
||||
prompt_duration_s: float = 10.0,
|
||||
format: AudioFormat = "wav",
|
||||
) -> MusicContinueResult: ...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int: ...
|
||||
|
||||
|
||||
def encode_audio(wav_tensor, sample_rate: int, format: AudioFormat) -> bytes:
|
||||
"""Encode a [C, T] or [1, C, T] torch tensor to audio bytes."""
|
||||
import io
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
wav = wav_tensor
|
||||
if wav.dim() == 3:
|
||||
wav = wav.squeeze(0) # [1, C, T] -> [C, T]
|
||||
if wav.dim() == 1:
|
||||
wav = wav.unsqueeze(0) # [T] -> [1, T]
|
||||
wav = wav.to(torch.float32).cpu()
|
||||
|
||||
buf = io.BytesIO()
|
||||
if format == "wav":
|
||||
torchaudio.save(buf, wav, sample_rate, format="wav")
|
||||
elif format == "mp3":
|
||||
try:
|
||||
torchaudio.save(buf, wav, sample_rate, format="mp3")
|
||||
except Exception:
|
||||
# ffmpeg backend not available; fall back to wav
|
||||
buf = io.BytesIO()
|
||||
torchaudio.save(buf, wav, sample_rate, format="wav")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def decode_audio(audio_bytes: bytes) -> tuple:
|
||||
"""Decode arbitrary audio bytes to (waveform [C, T], sample_rate)."""
|
||||
import io
|
||||
import torchaudio
|
||||
|
||||
buf = io.BytesIO(audio_bytes)
|
||||
wav, sr = torchaudio.load(buf)
|
||||
return wav, sr
|
||||
|
||||
|
||||
def make_musicgen_backend(
|
||||
model_name: str = MODEL_MELODY,
|
||||
*,
|
||||
mock: bool = False,
|
||||
device: str = "cuda",
|
||||
) -> MusicGenBackend:
|
||||
if mock:
|
||||
from circuitforge_core.musicgen.backends.mock import MockMusicGenBackend
|
||||
return MockMusicGenBackend()
|
||||
from circuitforge_core.musicgen.backends.audiocraft import AudioCraftBackend
|
||||
return AudioCraftBackend(model_name=model_name, device=device)
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
"""
|
||||
Mock MusicGenBackend — returns silent WAV audio; no GPU required.
|
||||
|
||||
Used in unit tests and CI where GPU is unavailable.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from circuitforge_core.musicgen.backends.base import AudioFormat, MusicContinueResult
|
||||
|
||||
|
||||
class MockMusicGenBackend:
|
||||
"""Returns a silent WAV file of the requested duration."""
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return "mock"
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return 0
|
||||
|
||||
def continue_audio(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
*,
|
||||
description: str | None = None,
|
||||
duration_s: float = 15.0,
|
||||
prompt_duration_s: float = 10.0,
|
||||
format: AudioFormat = "wav",
|
||||
) -> MusicContinueResult:
|
||||
sample_rate = 32000
|
||||
n_samples = int(duration_s * sample_rate)
|
||||
silent_samples = b"\x00\x00" * n_samples # 16-bit PCM silence
|
||||
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(silent_samples)
|
||||
|
||||
return MusicContinueResult(
|
||||
audio_bytes=buf.getvalue(),
|
||||
sample_rate=sample_rate,
|
||||
duration_s=duration_s,
|
||||
format="wav",
|
||||
model="mock",
|
||||
prompt_duration_s=prompt_duration_s,
|
||||
)
|
||||
|
|
@ -1,43 +1,3 @@
|
|||
# circuitforge_core/pipeline — FPGA→ASIC crystallization engine
|
||||
#
|
||||
# Public API: call pipeline.run() from product code instead of llm.router directly.
|
||||
# The module transparently checks for crystallized workflows first, falls back
|
||||
# to LLM when none match, and records each run for future crystallization.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from .crystallizer import CrystallizerConfig, crystallize, evaluate_new_run, should_crystallize
|
||||
from .executor import ExecutionResult, Executor, StepResult
|
||||
from .models import CrystallizedWorkflow, PipelineRun, Step, hash_input
|
||||
from .multimodal import MultimodalConfig, MultimodalPipeline, PageResult
|
||||
from .recorder import Recorder
|
||||
from .registry import Registry
|
||||
from .staging import StagingDB
|
||||
|
||||
__all__ = [
|
||||
# models
|
||||
"PipelineRun",
|
||||
"CrystallizedWorkflow",
|
||||
"Step",
|
||||
"hash_input",
|
||||
# recorder
|
||||
"Recorder",
|
||||
# crystallizer
|
||||
"CrystallizerConfig",
|
||||
"crystallize",
|
||||
"evaluate_new_run",
|
||||
"should_crystallize",
|
||||
# registry
|
||||
"Registry",
|
||||
# executor
|
||||
"Executor",
|
||||
"ExecutionResult",
|
||||
"StepResult",
|
||||
# multimodal
|
||||
"MultimodalPipeline",
|
||||
"MultimodalConfig",
|
||||
"PageResult",
|
||||
# legacy stub
|
||||
"StagingDB",
|
||||
]
|
||||
__all__ = ["StagingDB"]
|
||||
|
|
|
|||
|
|
@ -1,177 +0,0 @@
|
|||
# circuitforge_core/pipeline/crystallizer.py — promote approved runs → workflows
|
||||
#
|
||||
# MIT — pure logic, no inference backends.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
|
||||
from .models import CrystallizedWorkflow, PipelineRun, Step
|
||||
from .recorder import Recorder
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Minimum milliseconds of review that counts as "genuine".
|
||||
# Runs shorter than this are accepted but trigger a warning.
|
||||
_RUBBER_STAMP_THRESHOLD_MS = 5_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrystallizerConfig:
|
||||
"""Tuning knobs for one product/task-type pair.
|
||||
|
||||
threshold:
|
||||
Minimum number of approved runs required before crystallization.
|
||||
Osprey sets this to 1 (first successful IVR navigation is enough);
|
||||
Peregrine uses 3+ for cover-letter templates.
|
||||
min_review_ms:
|
||||
Approved runs with review_duration_ms below this value generate a
|
||||
warning. Set to 0 to silence the check (tests, automated approvals).
|
||||
strategy:
|
||||
``"most_recent"`` — use the latest approved run's steps verbatim.
|
||||
``"majority"`` — pick each step by majority vote across runs (requires
|
||||
runs to have the same step count; falls back to most_recent otherwise).
|
||||
"""
|
||||
threshold: int = 3
|
||||
min_review_ms: int = _RUBBER_STAMP_THRESHOLD_MS
|
||||
strategy: Literal["most_recent", "majority"] = "most_recent"
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _majority_steps(runs: list[PipelineRun]) -> list[Step] | None:
|
||||
"""Return majority-voted steps, or None if run lengths differ."""
|
||||
lengths = {len(r.steps) for r in runs}
|
||||
if len(lengths) != 1:
|
||||
return None
|
||||
n = lengths.pop()
|
||||
result: list[Step] = []
|
||||
for i in range(n):
|
||||
counter: Counter[str] = Counter()
|
||||
step_by_action: dict[str, Step] = {}
|
||||
for r in runs:
|
||||
s = r.steps[i]
|
||||
counter[s.action] += 1
|
||||
step_by_action[s.action] = s
|
||||
winner = counter.most_common(1)[0][0]
|
||||
result.append(step_by_action[winner])
|
||||
return result
|
||||
|
||||
|
||||
def _check_review_quality(runs: list[PipelineRun],
|
||||
min_review_ms: int) -> None:
|
||||
"""Warn if any run has a suspiciously short review duration."""
|
||||
if min_review_ms <= 0:
|
||||
return
|
||||
flagged = [r for r in runs if r.review_duration_ms < min_review_ms]
|
||||
if flagged:
|
||||
ids = ", ".join(r.run_id for r in flagged)
|
||||
warnings.warn(
|
||||
f"Crystallizing from {len(flagged)} run(s) with review_duration_ms "
|
||||
f"< {min_review_ms} ms — possible rubber-stamp approval: [{ids}]. "
|
||||
"Verify these were genuinely human-reviewed before deployment.",
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def should_crystallize(runs: list[PipelineRun],
|
||||
config: CrystallizerConfig) -> bool:
|
||||
"""Return True if *runs* meet the threshold for crystallization."""
|
||||
approved = [r for r in runs if r.approved]
|
||||
return len(approved) >= config.threshold
|
||||
|
||||
|
||||
def crystallize(runs: list[PipelineRun],
|
||||
config: CrystallizerConfig,
|
||||
existing_version: int = 0) -> CrystallizedWorkflow:
|
||||
"""Promote *runs* into a CrystallizedWorkflow.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If fewer approved runs than ``config.threshold``, or if the runs
|
||||
span more than one (product, task_type, input_hash) triple.
|
||||
"""
|
||||
approved = [r for r in runs if r.approved]
|
||||
if len(approved) < config.threshold:
|
||||
raise ValueError(
|
||||
f"Need {config.threshold} approved runs, got {len(approved)}."
|
||||
)
|
||||
|
||||
# Validate homogeneity
|
||||
products = {r.product for r in approved}
|
||||
task_types = {r.task_type for r in approved}
|
||||
hashes = {r.input_hash for r in approved}
|
||||
if len(products) != 1 or len(task_types) != 1 or len(hashes) != 1:
|
||||
raise ValueError(
|
||||
"All runs must share the same product, task_type, and input_hash. "
|
||||
f"Got products={products}, task_types={task_types}, hashes={hashes}."
|
||||
)
|
||||
|
||||
product = products.pop()
|
||||
task_type = task_types.pop()
|
||||
input_hash = hashes.pop()
|
||||
|
||||
_check_review_quality(approved, config.min_review_ms)
|
||||
|
||||
# Pick canonical steps
|
||||
if config.strategy == "majority":
|
||||
steps = _majority_steps(approved) or approved[-1].steps
|
||||
else:
|
||||
steps = sorted(approved, key=lambda r: r.timestamp)[-1].steps
|
||||
|
||||
avg_ms = sum(r.review_duration_ms for r in approved) // len(approved)
|
||||
all_unmodified = all(not r.output_modified for r in approved)
|
||||
|
||||
workflow_id = f"{product}:{task_type}:{input_hash[:12]}"
|
||||
return CrystallizedWorkflow(
|
||||
workflow_id=workflow_id,
|
||||
product=product,
|
||||
task_type=task_type,
|
||||
input_hash=input_hash,
|
||||
steps=steps,
|
||||
crystallized_at=datetime.now(timezone.utc).isoformat(),
|
||||
run_ids=[r.run_id for r in approved],
|
||||
approval_count=len(approved),
|
||||
avg_review_duration_ms=avg_ms,
|
||||
all_output_unmodified=all_unmodified,
|
||||
version=existing_version + 1,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_new_run(
|
||||
run: PipelineRun,
|
||||
recorder: Recorder,
|
||||
config: CrystallizerConfig,
|
||||
existing_version: int = 0,
|
||||
) -> CrystallizedWorkflow | None:
|
||||
"""Record *run* and return a new workflow if the threshold is now met.
|
||||
|
||||
Products call this after each human-approved execution. Returns a
|
||||
``CrystallizedWorkflow`` if crystallization was triggered, ``None``
|
||||
otherwise.
|
||||
"""
|
||||
recorder.record(run)
|
||||
if not run.approved:
|
||||
return None
|
||||
|
||||
all_runs = recorder.load_approved(run.product, run.task_type, run.input_hash)
|
||||
if not should_crystallize(all_runs, config):
|
||||
log.debug(
|
||||
"pipeline: %d/%d approved runs for %s:%s — not yet crystallizing",
|
||||
len(all_runs), config.threshold, run.product, run.task_type,
|
||||
)
|
||||
return None
|
||||
|
||||
workflow = crystallize(all_runs, config, existing_version=existing_version)
|
||||
log.info(
|
||||
"pipeline: crystallized %s after %d approvals",
|
||||
workflow.workflow_id, workflow.approval_count,
|
||||
)
|
||||
return workflow
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
# circuitforge_core/pipeline/executor.py — deterministic execution with LLM fallback
|
||||
#
|
||||
# MIT — orchestration logic only; calls product-supplied callables.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .models import CrystallizedWorkflow, Step
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepResult:
|
||||
step: Step
|
||||
success: bool
|
||||
output: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult:
|
||||
"""Result of running a workflow (deterministic or LLM-assisted).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
success:
|
||||
True if all steps completed without error.
|
||||
used_deterministic:
|
||||
True if a crystallized workflow was used; False if LLM was called.
|
||||
step_results:
|
||||
Per-step outcomes from the deterministic path.
|
||||
llm_output:
|
||||
Raw output from the LLM fallback path, if used.
|
||||
workflow_id:
|
||||
ID of the workflow used, or None for LLM path.
|
||||
error:
|
||||
Error message if the run failed entirely.
|
||||
"""
|
||||
success: bool
|
||||
used_deterministic: bool
|
||||
step_results: list[StepResult] = field(default_factory=list)
|
||||
llm_output: Any = None
|
||||
workflow_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ── Executor ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class Executor:
|
||||
"""Runs crystallized workflows with transparent LLM fallback.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
step_fn:
|
||||
Called for each Step: ``step_fn(step) -> (success, output)``.
|
||||
The product supplies this — it knows how to turn a Step into a real
|
||||
action (DTMF dial, HTTP call, form field write, etc.).
|
||||
llm_fn:
|
||||
Called when no workflow matches or a step fails: ``llm_fn() -> output``.
|
||||
Products wire this to ``cf_core.llm.router`` or equivalent.
|
||||
llm_fallback:
|
||||
If False, raise RuntimeError instead of calling llm_fn on miss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_fn: Callable[[Step], tuple[bool, Any]],
|
||||
llm_fn: Callable[[], Any],
|
||||
llm_fallback: bool = True,
|
||||
) -> None:
|
||||
self._step_fn = step_fn
|
||||
self._llm_fn = llm_fn
|
||||
self._llm_fallback = llm_fallback
|
||||
|
||||
def execute(
|
||||
self,
|
||||
workflow: CrystallizedWorkflow,
|
||||
) -> ExecutionResult:
|
||||
"""Run *workflow* deterministically.
|
||||
|
||||
If a step fails, falls back to LLM (if ``llm_fallback`` is enabled).
|
||||
"""
|
||||
step_results: list[StepResult] = []
|
||||
for step in workflow.steps:
|
||||
try:
|
||||
success, output = self._step_fn(step)
|
||||
except Exception as exc:
|
||||
log.warning("step %s raised: %s", step.action, exc)
|
||||
success, output = False, None
|
||||
error_str = str(exc)
|
||||
else:
|
||||
error_str = None if success else "step_fn returned success=False"
|
||||
|
||||
step_results.append(StepResult(step=step, success=success,
|
||||
output=output, error=error_str))
|
||||
if not success:
|
||||
log.info(
|
||||
"workflow %s: step %s failed — triggering LLM fallback",
|
||||
workflow.workflow_id, step.action,
|
||||
)
|
||||
return self._llm_fallback_result(
|
||||
step_results, workflow.workflow_id
|
||||
)
|
||||
|
||||
log.info("workflow %s: all %d steps succeeded",
|
||||
workflow.workflow_id, len(workflow.steps))
|
||||
return ExecutionResult(
|
||||
success=True,
|
||||
used_deterministic=True,
|
||||
step_results=step_results,
|
||||
workflow_id=workflow.workflow_id,
|
||||
)
|
||||
|
||||
def run_with_fallback(
|
||||
self,
|
||||
workflow: CrystallizedWorkflow | None,
|
||||
) -> ExecutionResult:
|
||||
"""Run *workflow* if provided; otherwise call the LLM directly."""
|
||||
if workflow is None:
|
||||
return self._llm_fallback_result([], workflow_id=None)
|
||||
return self.execute(workflow)
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _llm_fallback_result(
|
||||
self,
|
||||
partial_steps: list[StepResult],
|
||||
workflow_id: str | None,
|
||||
) -> ExecutionResult:
|
||||
if not self._llm_fallback:
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
used_deterministic=True,
|
||||
step_results=partial_steps,
|
||||
workflow_id=workflow_id,
|
||||
error="LLM fallback disabled and deterministic path failed.",
|
||||
)
|
||||
try:
|
||||
llm_output = self._llm_fn()
|
||||
except Exception as exc:
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
used_deterministic=False,
|
||||
step_results=partial_steps,
|
||||
workflow_id=workflow_id,
|
||||
error=f"LLM fallback raised: {exc}",
|
||||
)
|
||||
return ExecutionResult(
|
||||
success=True,
|
||||
used_deterministic=False,
|
||||
step_results=partial_steps,
|
||||
llm_output=llm_output,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
|
@ -1,216 +0,0 @@
|
|||
# circuitforge_core/pipeline/models.py — crystallization data models
|
||||
#
|
||||
# MIT — protocol and model types only; no inference backends.
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ── Utilities ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def hash_input(features: dict[str, Any]) -> str:
|
||||
"""Return a stable SHA-256 hex digest of *features*.
|
||||
|
||||
Sorts keys before serialising so insertion order doesn't affect the hash.
|
||||
Only call this on already-normalised, PII-free feature dicts — the hash is
|
||||
opaque but the source dict should never contain raw user data.
|
||||
"""
|
||||
canonical = json.dumps(features, sort_keys=True, ensure_ascii=True)
|
||||
return hashlib.sha256(canonical.encode()).hexdigest()
|
||||
|
||||
|
||||
# ── Step ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class Step:
|
||||
"""One atomic action in a deterministic workflow.
|
||||
|
||||
The ``action`` string is product-defined (e.g. ``"dtmf"``, ``"field_fill"``,
|
||||
``"api_call"``). ``params`` carries action-specific values; ``description``
|
||||
is a plain-English summary for the approval UI.
|
||||
"""
|
||||
action: str
|
||||
params: dict[str, Any]
|
||||
description: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"action": self.action, "params": self.params,
|
||||
"description": self.description}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any]) -> "Step":
|
||||
return cls(action=d["action"], params=d.get("params", {}),
|
||||
description=d.get("description", ""))
|
||||
|
||||
|
||||
# ── PipelineRun ───────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class PipelineRun:
|
||||
"""Record of one LLM-assisted execution — the raw material for crystallization.
|
||||
|
||||
Fields
|
||||
------
|
||||
run_id:
|
||||
UUID or unique string identifying this run.
|
||||
product:
|
||||
CF product code (``"osprey"``, ``"falcon"``, ``"peregrine"`` …).
|
||||
task_type:
|
||||
Product-defined task category (``"ivr_navigate"``, ``"form_fill"`` …).
|
||||
input_hash:
|
||||
SHA-256 of normalised, PII-free input features. Never store raw input.
|
||||
steps:
|
||||
Ordered list of Steps the LLM proposed.
|
||||
approved:
|
||||
True if a human approved this run before execution.
|
||||
review_duration_ms:
|
||||
Wall-clock milliseconds between displaying the proposal and the approval
|
||||
click. Values under ~5 000 ms indicate a rubber-stamp — the
|
||||
crystallizer may reject runs with suspiciously short reviews.
|
||||
output_modified:
|
||||
True if the user edited any step before approving. Modifications suggest
|
||||
the LLM proposal was imperfect; too-easy crystallization from unmodified
|
||||
runs may mean the task is already deterministic and the LLM is just
|
||||
echoing a fixed pattern.
|
||||
timestamp:
|
||||
ISO 8601 UTC creation time.
|
||||
llm_model:
|
||||
Model ID that generated the steps, e.g. ``"llama3:8b-instruct"``.
|
||||
metadata:
|
||||
Freeform dict for product-specific extra fields.
|
||||
"""
|
||||
|
||||
run_id: str
|
||||
product: str
|
||||
task_type: str
|
||||
input_hash: str
|
||||
steps: list[Step]
|
||||
approved: bool
|
||||
review_duration_ms: int
|
||||
output_modified: bool
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
llm_model: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"run_id": self.run_id,
|
||||
"product": self.product,
|
||||
"task_type": self.task_type,
|
||||
"input_hash": self.input_hash,
|
||||
"steps": [s.to_dict() for s in self.steps],
|
||||
"approved": self.approved,
|
||||
"review_duration_ms": self.review_duration_ms,
|
||||
"output_modified": self.output_modified,
|
||||
"timestamp": self.timestamp,
|
||||
"llm_model": self.llm_model,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any]) -> "PipelineRun":
|
||||
return cls(
|
||||
run_id=d["run_id"],
|
||||
product=d["product"],
|
||||
task_type=d["task_type"],
|
||||
input_hash=d["input_hash"],
|
||||
steps=[Step.from_dict(s) for s in d.get("steps", [])],
|
||||
approved=d["approved"],
|
||||
review_duration_ms=d["review_duration_ms"],
|
||||
output_modified=d.get("output_modified", False),
|
||||
timestamp=d.get("timestamp", ""),
|
||||
llm_model=d.get("llm_model"),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
# ── CrystallizedWorkflow ──────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class CrystallizedWorkflow:
|
||||
"""A deterministic workflow promoted from N approved PipelineRuns.
|
||||
|
||||
Once crystallized, the executor runs ``steps`` directly — no LLM required
|
||||
unless an edge case is encountered.
|
||||
|
||||
Fields
|
||||
------
|
||||
workflow_id:
|
||||
Unique identifier (typically ``{product}:{task_type}:{input_hash[:12]}``).
|
||||
product / task_type / input_hash:
|
||||
Same semantics as PipelineRun; the hash is the lookup key.
|
||||
steps:
|
||||
Canonical deterministic step sequence (majority-voted or most-recent,
|
||||
per CrystallizerConfig.strategy).
|
||||
crystallized_at:
|
||||
ISO 8601 UTC timestamp.
|
||||
run_ids:
|
||||
IDs of the source PipelineRuns that contributed to this workflow.
|
||||
approval_count:
|
||||
Number of approved runs that went into crystallization.
|
||||
avg_review_duration_ms:
|
||||
Mean review_duration_ms across all source runs — low values are a
|
||||
warning sign that approvals may not have been genuine.
|
||||
all_output_unmodified:
|
||||
True if every contributing run had output_modified=False. Combined with
|
||||
a very short avg_review_duration_ms this can flag workflows that may
|
||||
have crystallized from rubber-stamp approvals.
|
||||
active:
|
||||
Whether this workflow is in use. Set to False to disable without
|
||||
deleting the record.
|
||||
version:
|
||||
Increments each time the workflow is re-crystallized from new runs.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
product: str
|
||||
task_type: str
|
||||
input_hash: str
|
||||
steps: list[Step]
|
||||
crystallized_at: str
|
||||
run_ids: list[str]
|
||||
approval_count: int
|
||||
avg_review_duration_ms: int
|
||||
all_output_unmodified: bool
|
||||
active: bool = True
|
||||
version: int = 1
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"workflow_id": self.workflow_id,
|
||||
"product": self.product,
|
||||
"task_type": self.task_type,
|
||||
"input_hash": self.input_hash,
|
||||
"steps": [s.to_dict() for s in self.steps],
|
||||
"crystallized_at": self.crystallized_at,
|
||||
"run_ids": self.run_ids,
|
||||
"approval_count": self.approval_count,
|
||||
"avg_review_duration_ms": self.avg_review_duration_ms,
|
||||
"all_output_unmodified": self.all_output_unmodified,
|
||||
"active": self.active,
|
||||
"version": self.version,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any]) -> "CrystallizedWorkflow":
|
||||
return cls(
|
||||
workflow_id=d["workflow_id"],
|
||||
product=d["product"],
|
||||
task_type=d["task_type"],
|
||||
input_hash=d["input_hash"],
|
||||
steps=[Step.from_dict(s) for s in d.get("steps", [])],
|
||||
crystallized_at=d["crystallized_at"],
|
||||
run_ids=d.get("run_ids", []),
|
||||
approval_count=d["approval_count"],
|
||||
avg_review_duration_ms=d["avg_review_duration_ms"],
|
||||
all_output_unmodified=d.get("all_output_unmodified", True),
|
||||
active=d.get("active", True),
|
||||
version=d.get("version", 1),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
|
|
@ -1,234 +0,0 @@
|
|||
# circuitforge_core/pipeline/multimodal.py — cf-docuvision + cf-text pipeline
|
||||
#
|
||||
# MIT — orchestration only; vision and text inference stay in their own modules.
|
||||
#
|
||||
# Usage (minimal):
|
||||
#
|
||||
# from circuitforge_core.pipeline.multimodal import MultimodalPipeline, MultimodalConfig
|
||||
#
|
||||
# pipe = MultimodalPipeline(MultimodalConfig())
|
||||
# for result in pipe.run(page_bytes_list):
|
||||
# print(f"Page {result.page_idx}: {result.generated[:80]}")
|
||||
#
|
||||
# Streaming (token-by-token):
|
||||
#
|
||||
# for page_idx, token in pipe.stream(page_bytes_list):
|
||||
# ui.append(page_idx, token)
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from circuitforge_core.documents.client import DocuvisionClient
|
||||
from circuitforge_core.documents.models import StructuredDocument
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _default_prompt(page_idx: int, doc: StructuredDocument) -> str:
|
||||
"""Build a generation prompt from a StructuredDocument."""
|
||||
header = f"[Page {page_idx + 1}]\n" if page_idx > 0 else ""
|
||||
return header + doc.raw_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultimodalConfig:
|
||||
"""Configuration for MultimodalPipeline.
|
||||
|
||||
vision_url:
|
||||
Base URL of the cf-docuvision service.
|
||||
hint:
|
||||
Docuvision extraction hint — ``"auto"`` | ``"document"`` | ``"form"``
|
||||
| ``"table"`` | ``"figure"``.
|
||||
max_tokens:
|
||||
Passed to cf-text generate per page.
|
||||
temperature:
|
||||
Sampling temperature for text generation.
|
||||
vram_serialise:
|
||||
When True, ``swap_fn`` is called between the vision and text steps
|
||||
on each page. Use this on 8GB GPUs where Dolphin-v2 and the text
|
||||
model cannot be resident simultaneously.
|
||||
prompt_fn:
|
||||
Callable ``(page_idx, StructuredDocument) -> str`` that builds the
|
||||
generation prompt. Defaults to using ``doc.raw_text`` directly.
|
||||
Products override this to add system context, few-shot examples, etc.
|
||||
vision_timeout:
|
||||
HTTP timeout in seconds for each cf-docuvision request.
|
||||
"""
|
||||
vision_url: str = "http://localhost:8003"
|
||||
hint: str = "auto"
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
vram_serialise: bool = False
|
||||
prompt_fn: Callable[[int, StructuredDocument], str] = field(
|
||||
default_factory=lambda: _default_prompt
|
||||
)
|
||||
vision_timeout: int = 60
|
||||
|
||||
|
||||
# ── Results ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class PageResult:
|
||||
"""Result of processing one page through the vision + text pipeline.
|
||||
|
||||
page_idx:
|
||||
Zero-based page index.
|
||||
doc:
|
||||
StructuredDocument from cf-docuvision.
|
||||
generated:
|
||||
Full text output from cf-text for this page.
|
||||
error:
|
||||
Non-None if extraction or generation failed for this page.
|
||||
"""
|
||||
page_idx: int
|
||||
doc: StructuredDocument | None
|
||||
generated: str
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ── Pipeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class MultimodalPipeline:
|
||||
"""Chunk a multi-page document through vision extraction + text generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config:
|
||||
Pipeline configuration.
|
||||
swap_fn:
|
||||
Optional callable with no arguments, called between the vision and text
|
||||
steps on each page when ``config.vram_serialise=True``. Products using
|
||||
cf-orch wire this to the VRAM budget API so Dolphin-v2 can offload
|
||||
before the text model loads. A no-op lambda works for testing.
|
||||
generate_fn:
|
||||
Text generation callable: ``(prompt, max_tokens, temperature) -> str``.
|
||||
Defaults to ``circuitforge_core.text.generate``. Override in tests or
|
||||
when the product manages its own text backend.
|
||||
stream_fn:
|
||||
Streaming text callable: ``(prompt, max_tokens, temperature) -> Iterator[str]``.
|
||||
Defaults to ``circuitforge_core.text.generate`` with ``stream=True``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MultimodalConfig | None = None,
|
||||
*,
|
||||
swap_fn: Callable[[], None] | None = None,
|
||||
generate_fn: Callable[..., str] | None = None,
|
||||
stream_fn: Callable[..., Iterator[str]] | None = None,
|
||||
) -> None:
|
||||
self._cfg = config or MultimodalConfig()
|
||||
self._vision = DocuvisionClient(
|
||||
base_url=self._cfg.vision_url,
|
||||
timeout=self._cfg.vision_timeout,
|
||||
)
|
||||
self._swap_fn = swap_fn
|
||||
self._generate_fn = generate_fn
|
||||
self._stream_fn = stream_fn
|
||||
|
||||
# ── Public ────────────────────────────────────────────────────────────────
|
||||
|
||||
def run(self, pages: Iterable[bytes]) -> Iterator[PageResult]:
|
||||
"""Process each page and yield a PageResult as soon as it is ready.
|
||||
|
||||
Callers receive pages one at a time — the UI can begin rendering
|
||||
page 0 while pages 1..N are still being extracted and generated.
|
||||
"""
|
||||
for page_idx, page_bytes in enumerate(pages):
|
||||
yield self._process_page(page_idx, page_bytes)
|
||||
|
||||
def stream(self, pages: Iterable[bytes]) -> Iterator[tuple[int, str]]:
|
||||
"""Yield ``(page_idx, token)`` tuples for token-level progressive rendering.
|
||||
|
||||
Each page is fully extracted before text generation begins, but tokens
|
||||
are yielded as the text model produces them rather than waiting for the
|
||||
full page output.
|
||||
"""
|
||||
for page_idx, page_bytes in enumerate(pages):
|
||||
doc, err = self._extract(page_idx, page_bytes)
|
||||
if err:
|
||||
yield (page_idx, f"[extraction error: {err}]")
|
||||
continue
|
||||
|
||||
self._maybe_swap()
|
||||
|
||||
prompt = self._cfg.prompt_fn(page_idx, doc)
|
||||
try:
|
||||
for token in self._stream_tokens(prompt):
|
||||
yield (page_idx, token)
|
||||
except Exception as exc:
|
||||
log.error("page %d text streaming failed: %s", page_idx, exc)
|
||||
yield (page_idx, f"[generation error: {exc}]")
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _process_page(self, page_idx: int, page_bytes: bytes) -> PageResult:
|
||||
doc, err = self._extract(page_idx, page_bytes)
|
||||
if err:
|
||||
return PageResult(page_idx=page_idx, doc=None, generated="", error=err)
|
||||
|
||||
self._maybe_swap()
|
||||
|
||||
prompt = self._cfg.prompt_fn(page_idx, doc)
|
||||
try:
|
||||
text = self._generate(prompt)
|
||||
except Exception as exc:
|
||||
log.error("page %d generation failed: %s", page_idx, exc)
|
||||
return PageResult(page_idx=page_idx, doc=doc, generated="",
|
||||
error=str(exc))
|
||||
|
||||
return PageResult(page_idx=page_idx, doc=doc, generated=text)
|
||||
|
||||
def _extract(
|
||||
self, page_idx: int, page_bytes: bytes
|
||||
) -> tuple[StructuredDocument | None, str | None]:
|
||||
try:
|
||||
doc = self._vision.extract(page_bytes, hint=self._cfg.hint)
|
||||
log.debug("page %d extracted: %d chars", page_idx, len(doc.raw_text))
|
||||
return doc, None
|
||||
except Exception as exc:
|
||||
log.error("page %d vision extraction failed: %s", page_idx, exc)
|
||||
return None, str(exc)
|
||||
|
||||
def _maybe_swap(self) -> None:
|
||||
if self._cfg.vram_serialise and self._swap_fn is not None:
|
||||
log.debug("vram_serialise: calling swap_fn")
|
||||
self._swap_fn()
|
||||
|
||||
def _generate(self, prompt: str) -> str:
|
||||
if self._generate_fn is not None:
|
||||
return self._generate_fn(
|
||||
prompt,
|
||||
max_tokens=self._cfg.max_tokens,
|
||||
temperature=self._cfg.temperature,
|
||||
)
|
||||
from circuitforge_core.text import generate
|
||||
result = generate(
|
||||
prompt,
|
||||
max_tokens=self._cfg.max_tokens,
|
||||
temperature=self._cfg.temperature,
|
||||
)
|
||||
return result.text
|
||||
|
||||
def _stream_tokens(self, prompt: str) -> Iterator[str]:
|
||||
if self._stream_fn is not None:
|
||||
yield from self._stream_fn(
|
||||
prompt,
|
||||
max_tokens=self._cfg.max_tokens,
|
||||
temperature=self._cfg.temperature,
|
||||
)
|
||||
return
|
||||
from circuitforge_core.text import generate
|
||||
tokens = generate(
|
||||
prompt,
|
||||
max_tokens=self._cfg.max_tokens,
|
||||
temperature=self._cfg.temperature,
|
||||
stream=True,
|
||||
)
|
||||
yield from tokens
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
# circuitforge_core/pipeline/recorder.py — write and load PipelineRun records
|
||||
#
|
||||
# MIT — local file I/O only; no inference.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from .models import PipelineRun
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ROOT = Path.home() / ".config" / "circuitforge" / "pipeline" / "runs"
|
||||
|
||||
|
||||
class Recorder:
|
||||
"""Writes PipelineRun JSON records to a local directory tree.
|
||||
|
||||
Layout::
|
||||
|
||||
{root}/{product}/{task_type}/{run_id}.json
|
||||
|
||||
The recorder is intentionally append-only — it never deletes or modifies
|
||||
existing records. Old runs accumulate as an audit trail; products that
|
||||
want retention limits should prune the directory themselves.
|
||||
"""
|
||||
|
||||
def __init__(self, root: Path | None = None) -> None:
|
||||
self._root = Path(root) if root else _DEFAULT_ROOT
|
||||
|
||||
# ── Write ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def record(self, run: PipelineRun) -> Path:
|
||||
"""Persist *run* to disk and return the file path written."""
|
||||
dest = self._path_for(run.product, run.task_type, run.run_id)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(json.dumps(run.to_dict(), indent=2), encoding="utf-8")
|
||||
log.debug("recorded pipeline run %s → %s", run.run_id, dest)
|
||||
return dest
|
||||
|
||||
# ── Read ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def load_runs(self, product: str, task_type: str) -> list[PipelineRun]:
|
||||
"""Return all runs for *(product, task_type)*, newest-first."""
|
||||
directory = self._root / product / task_type
|
||||
if not directory.is_dir():
|
||||
return []
|
||||
runs: list[PipelineRun] = []
|
||||
for p in directory.glob("*.json"):
|
||||
try:
|
||||
runs.append(PipelineRun.from_dict(json.loads(p.read_text())))
|
||||
except Exception:
|
||||
log.warning("skipping unreadable run file %s", p)
|
||||
runs.sort(key=lambda r: r.timestamp, reverse=True)
|
||||
return runs
|
||||
|
||||
def load_approved(self, product: str, task_type: str,
|
||||
input_hash: str) -> list[PipelineRun]:
|
||||
"""Return approved runs that match *input_hash*, newest-first."""
|
||||
return [
|
||||
r for r in self.load_runs(product, task_type)
|
||||
if r.approved and r.input_hash == input_hash
|
||||
]
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _path_for(self, product: str, task_type: str, run_id: str) -> Path:
|
||||
return self._root / product / task_type / f"{run_id}.json"
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
# circuitforge_core/pipeline/registry.py — workflow lookup
|
||||
#
|
||||
# MIT — file I/O and matching logic only.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from .models import CrystallizedWorkflow
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ROOT = Path.home() / ".config" / "circuitforge" / "pipeline" / "workflows"
|
||||
|
||||
|
||||
class Registry:
|
||||
"""Loads and matches CrystallizedWorkflows from the local filesystem.
|
||||
|
||||
Layout::
|
||||
|
||||
{root}/{product}/{task_type}/{workflow_id}.json
|
||||
|
||||
Exact matching is always available. Products that need fuzzy/semantic
|
||||
matching can supply a ``similarity_fn`` — a callable that takes two input
|
||||
hashes and returns a float in [0, 1]. The registry returns the first
|
||||
active workflow whose similarity score meets ``fuzzy_threshold``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Path | None = None,
|
||||
similarity_fn: Callable[[str, str], float] | None = None,
|
||||
fuzzy_threshold: float = 0.8,
|
||||
) -> None:
|
||||
self._root = Path(root) if root else _DEFAULT_ROOT
|
||||
self._similarity_fn = similarity_fn
|
||||
self._fuzzy_threshold = fuzzy_threshold
|
||||
|
||||
# ── Write ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def register(self, workflow: CrystallizedWorkflow) -> Path:
|
||||
"""Persist *workflow* and return the path written."""
|
||||
dest = self._path_for(workflow.product, workflow.task_type,
|
||||
workflow.workflow_id)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(json.dumps(workflow.to_dict(), indent=2), encoding="utf-8")
|
||||
log.info("registered workflow %s (v%d)", workflow.workflow_id,
|
||||
workflow.version)
|
||||
return dest
|
||||
|
||||
def deactivate(self, workflow_id: str, product: str,
|
||||
task_type: str) -> bool:
|
||||
"""Set ``active=False`` on a stored workflow. Returns True if found."""
|
||||
path = self._path_for(product, task_type, workflow_id)
|
||||
if not path.exists():
|
||||
return False
|
||||
data = json.loads(path.read_text())
|
||||
data["active"] = False
|
||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
log.info("deactivated workflow %s", workflow_id)
|
||||
return True
|
||||
|
||||
# ── Read ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def load_all(self, product: str, task_type: str) -> list[CrystallizedWorkflow]:
|
||||
"""Return all (including inactive) workflows for *(product, task_type)*."""
|
||||
directory = self._root / product / task_type
|
||||
if not directory.is_dir():
|
||||
return []
|
||||
workflows: list[CrystallizedWorkflow] = []
|
||||
for p in directory.glob("*.json"):
|
||||
try:
|
||||
workflows.append(
|
||||
CrystallizedWorkflow.from_dict(json.loads(p.read_text()))
|
||||
)
|
||||
except Exception:
|
||||
log.warning("skipping unreadable workflow file %s", p)
|
||||
return workflows
|
||||
|
||||
# ── Match ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def match(self, product: str, task_type: str,
|
||||
input_hash: str) -> CrystallizedWorkflow | None:
|
||||
"""Return the active workflow for an exact input_hash match, or None."""
|
||||
for wf in self.load_all(product, task_type):
|
||||
if wf.active and wf.input_hash == input_hash:
|
||||
log.debug("registry exact match: %s", wf.workflow_id)
|
||||
return wf
|
||||
return None
|
||||
|
||||
def fuzzy_match(self, product: str, task_type: str,
|
||||
input_hash: str) -> CrystallizedWorkflow | None:
|
||||
"""Return a workflow above the similarity threshold, or None.
|
||||
|
||||
Requires a ``similarity_fn`` to have been supplied at construction.
|
||||
If none was provided, raises ``RuntimeError``.
|
||||
"""
|
||||
if self._similarity_fn is None:
|
||||
raise RuntimeError(
|
||||
"fuzzy_match() requires a similarity_fn — none was supplied "
|
||||
"to Registry.__init__()."
|
||||
)
|
||||
best: CrystallizedWorkflow | None = None
|
||||
best_score = 0.0
|
||||
for wf in self.load_all(product, task_type):
|
||||
if not wf.active:
|
||||
continue
|
||||
score = self._similarity_fn(wf.input_hash, input_hash)
|
||||
if score >= self._fuzzy_threshold and score > best_score:
|
||||
best = wf
|
||||
best_score = score
|
||||
if best:
|
||||
log.debug("registry fuzzy match: %s (score=%.2f)", best.workflow_id,
|
||||
best_score)
|
||||
return best
|
||||
|
||||
def find(self, product: str, task_type: str,
|
||||
input_hash: str) -> CrystallizedWorkflow | None:
|
||||
"""Exact match first; fuzzy match second (if similarity_fn is set)."""
|
||||
exact = self.match(product, task_type, input_hash)
|
||||
if exact:
|
||||
return exact
|
||||
if self._similarity_fn is not None:
|
||||
return self.fuzzy_match(product, task_type, input_hash)
|
||||
return None
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _path_for(self, product: str, task_type: str,
|
||||
workflow_id: str) -> Path:
|
||||
safe_id = workflow_id.replace(":", "_")
|
||||
return self._root / product / task_type / f"{safe_id}.json"
|
||||
|
|
@ -1,183 +0,0 @@
|
|||
"""eBay OAuth Authorization Code flow — user-level token manager.
|
||||
|
||||
Implements the Authorization Code Grant for eBay's Trading API.
|
||||
App-level client credentials (Browse API) are handled separately in
|
||||
the product-level EbayTokenManager (snipe/app/platforms/ebay/auth.py).
|
||||
|
||||
Usage (Snipe):
|
||||
manager = EbayUserTokenManager(
|
||||
client_id=app_id,
|
||||
client_secret=cert_id,
|
||||
runame=runame,
|
||||
redirect_uri=redirect_uri,
|
||||
env="production",
|
||||
)
|
||||
|
||||
# 1. Send user to eBay
|
||||
url = manager.get_authorization_url(state="csrf-token-here")
|
||||
redirect(url)
|
||||
|
||||
# 2. Handle callback
|
||||
tokens = manager.exchange_code(code) # returns EbayUserTokens
|
||||
# store tokens.access_token, tokens.refresh_token, tokens.expires_at
|
||||
|
||||
# 3. Get a fresh access token for API calls
|
||||
access_token = manager.refresh(stored_refresh_token)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import time
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
EBAY_AUTH_URLS = {
|
||||
"production": "https://auth.ebay.com/oauth2/authorize",
|
||||
"sandbox": "https://auth.sandbox.ebay.com/oauth2/authorize",
|
||||
}
|
||||
|
||||
EBAY_TOKEN_URLS = {
|
||||
"production": "https://api.ebay.com/identity/v1/oauth2/token",
|
||||
"sandbox": "https://api.sandbox.ebay.com/identity/v1/oauth2/token",
|
||||
}
|
||||
|
||||
# Scopes needed for Trading API GetUser (account age + category feedback).
|
||||
# https://developer.ebay.com/api-docs/static/oauth-scopes.html
|
||||
DEFAULT_SCOPES = [
|
||||
"https://api.ebay.com/oauth/api_scope",
|
||||
"https://api.ebay.com/oauth/api_scope/sell.account.readonly",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EbayUserTokens:
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_at: float # epoch seconds
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class EbayUserTokenManager:
|
||||
"""Manages eBay Authorization Code OAuth tokens for a single user.
|
||||
|
||||
One instance per user session. Does NOT persist tokens — callers are
|
||||
responsible for storing/loading tokens via the DB migration
|
||||
013_ebay_user_tokens.sql.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
runame: str,
|
||||
redirect_uri: str,
|
||||
env: str = "production",
|
||||
scopes: Optional[list[str]] = None,
|
||||
):
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
self._runame = runame
|
||||
self._redirect_uri = redirect_uri
|
||||
self._auth_url = EBAY_AUTH_URLS[env]
|
||||
self._token_url = EBAY_TOKEN_URLS[env]
|
||||
self._scopes = scopes or DEFAULT_SCOPES
|
||||
|
||||
# ── Authorization URL ──────────────────────────────────────────────────────
|
||||
|
||||
def get_authorization_url(self, state: str = "") -> str:
|
||||
"""Build the eBay OAuth authorization URL to redirect the user to.
|
||||
|
||||
Args:
|
||||
state: CSRF token or opaque value passed through unchanged.
|
||||
|
||||
Returns:
|
||||
Full URL string to redirect the user's browser to.
|
||||
"""
|
||||
params = {
|
||||
"client_id": self._client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": self._runame, # eBay uses RuName, not the raw URI
|
||||
"scope": " ".join(self._scopes),
|
||||
}
|
||||
if state:
|
||||
params["state"] = state
|
||||
return f"{self._auth_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
# ── Code exchange ──────────────────────────────────────────────────────────
|
||||
|
||||
def exchange_code(self, code: str) -> EbayUserTokens:
|
||||
"""Exchange an authorization code for access + refresh tokens.
|
||||
|
||||
Called from the OAuth callback endpoint after eBay redirects back.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError on non-2xx eBay response.
|
||||
KeyError if eBay response is missing expected fields.
|
||||
"""
|
||||
resp = requests.post(
|
||||
self._token_url,
|
||||
headers={
|
||||
"Authorization": f"Basic {self._credentials_b64()}",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self._runame,
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return self._parse_token_response(resp.json())
|
||||
|
||||
# ── Token refresh ──────────────────────────────────────────────────────────
|
||||
|
||||
def refresh(self, refresh_token: str) -> EbayUserTokens:
|
||||
"""Exchange a refresh token for a new access token.
|
||||
|
||||
eBay refresh tokens are valid for 18 months. Access tokens last 2h.
|
||||
Call this before making Trading API requests when the stored token
|
||||
is within 60 seconds of expiry.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError if the refresh token is expired or revoked.
|
||||
"""
|
||||
resp = requests.post(
|
||||
self._token_url,
|
||||
headers={
|
||||
"Authorization": f"Basic {self._credentials_b64()}",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"scope": " ".join(self._scopes),
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
# Refresh responses do NOT include a new refresh_token — the original stays valid
|
||||
data = resp.json()
|
||||
return EbayUserTokens(
|
||||
access_token=data["access_token"],
|
||||
refresh_token=refresh_token, # unchanged
|
||||
expires_at=time.time() + data["expires_in"],
|
||||
scopes=data.get("scope", "").split(),
|
||||
)
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _credentials_b64(self) -> str:
|
||||
raw = f"{self._client_id}:{self._client_secret}"
|
||||
return base64.b64encode(raw.encode()).decode()
|
||||
|
||||
def _parse_token_response(self, data: dict) -> EbayUserTokens:
|
||||
return EbayUserTokens(
|
||||
access_token=data["access_token"],
|
||||
refresh_token=data["refresh_token"],
|
||||
expires_at=time.time() + data["expires_in"],
|
||||
scopes=data.get("scope", "").split(),
|
||||
)
|
||||
|
|
@ -40,13 +40,8 @@ def set_user_preference(
|
|||
s.set(user_id=user_id, path=path, value=value)
|
||||
|
||||
|
||||
from . import accessibility as accessibility
|
||||
from . import currency as currency
|
||||
|
||||
__all__ = [
|
||||
"get_path", "set_path",
|
||||
"get_user_preference", "set_user_preference",
|
||||
"LocalFileStore", "PreferenceStore",
|
||||
"accessibility",
|
||||
"currency",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,73 +0,0 @@
|
|||
# circuitforge_core/preferences/accessibility.py — a11y preference keys
|
||||
#
|
||||
# First-class accessibility preferences so every product UI reads from
|
||||
# the same store path without each implementing it separately.
|
||||
#
|
||||
# All keys use the "accessibility.*" namespace in the preference store.
|
||||
# Products read these via get_user_preference() or the convenience helpers here.
|
||||
from __future__ import annotations
|
||||
|
||||
from circuitforge_core.preferences import get_user_preference, set_user_preference
|
||||
|
||||
# ── Preference key constants ──────────────────────────────────────────────────
|
||||
|
||||
PREF_REDUCED_MOTION = "accessibility.prefers_reduced_motion"
|
||||
PREF_HIGH_CONTRAST = "accessibility.high_contrast"
|
||||
PREF_FONT_SIZE = "accessibility.font_size" # "default" | "large" | "xlarge"
|
||||
PREF_SCREEN_READER = "accessibility.screen_reader_mode" # reduces decorative content
|
||||
|
||||
_DEFAULTS: dict[str, object] = {
|
||||
PREF_REDUCED_MOTION: False,
|
||||
PREF_HIGH_CONTRAST: False,
|
||||
PREF_FONT_SIZE: "default",
|
||||
PREF_SCREEN_READER: False,
|
||||
}
|
||||
|
||||
|
||||
# ── Convenience helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def is_reduced_motion_preferred(
|
||||
user_id: str | None = None,
|
||||
store=None,
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if the user has requested reduced motion.
|
||||
|
||||
Products must honour this in all animated UI elements: transitions,
|
||||
auto-playing content, parallax, loaders. This maps to the CSS
|
||||
`prefers-reduced-motion: reduce` media query and is the canonical
|
||||
source of truth across all CF product UIs.
|
||||
|
||||
Default: False.
|
||||
"""
|
||||
val = get_user_preference(
|
||||
user_id, PREF_REDUCED_MOTION, default=False, store=store
|
||||
)
|
||||
return bool(val)
|
||||
|
||||
|
||||
def is_high_contrast(user_id: str | None = None, store=None) -> bool:
|
||||
"""Return True if the user has requested high-contrast mode."""
|
||||
return bool(get_user_preference(user_id, PREF_HIGH_CONTRAST, default=False, store=store))
|
||||
|
||||
|
||||
def get_font_size(user_id: str | None = None, store=None) -> str:
|
||||
"""Return the user's preferred font size: 'default' | 'large' | 'xlarge'."""
|
||||
val = get_user_preference(user_id, PREF_FONT_SIZE, default="default", store=store)
|
||||
if val not in ("default", "large", "xlarge"):
|
||||
return "default"
|
||||
return str(val)
|
||||
|
||||
|
||||
def is_screen_reader_mode(user_id: str | None = None, store=None) -> bool:
|
||||
"""Return True if the user has requested screen reader optimised output."""
|
||||
return bool(get_user_preference(user_id, PREF_SCREEN_READER, default=False, store=store))
|
||||
|
||||
|
||||
def set_reduced_motion(
|
||||
value: bool,
|
||||
user_id: str | None = None,
|
||||
store=None,
|
||||
) -> None:
|
||||
"""Persist the user's reduced-motion preference."""
|
||||
set_user_preference(user_id, PREF_REDUCED_MOTION, value, store=store)
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
# circuitforge_core/preferences/currency.py — currency preference + display formatting
|
||||
#
|
||||
# Stores a per-user ISO 4217 currency code and provides format_currency() so every
|
||||
# product formats prices consistently without rolling its own formatter.
|
||||
#
|
||||
# Priority fallback chain for get_currency_code():
|
||||
# 1. User preference store ("currency.code")
|
||||
# 2. CURRENCY_DEFAULT env var
|
||||
# 3. Hard default: "USD"
|
||||
#
|
||||
# format_currency() tries babel for full locale support; falls back to a built-in
|
||||
# symbol table when babel is not installed (no hard dependency on cf-core).
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.preferences import get_user_preference, set_user_preference
|
||||
|
||||
# ── Preference key constants ──────────────────────────────────────────────────
|
||||
|
||||
PREF_CURRENCY_CODE = "currency.code"
|
||||
DEFAULT_CURRENCY_CODE = "USD"
|
||||
|
||||
# ── Built-in symbol table (babel fallback) ────────────────────────────────────
|
||||
# Covers the currencies most likely to appear across CF product consumers.
|
||||
# Symbol is prepended; decimal places follow ISO 4217 minor-unit convention.
|
||||
|
||||
_CURRENCY_META: dict[str, tuple[str, int]] = {
|
||||
# (symbol, decimal_places)
|
||||
"USD": ("$", 2),
|
||||
"CAD": ("CA$", 2),
|
||||
"AUD": ("A$", 2),
|
||||
"NZD": ("NZ$", 2),
|
||||
"GBP": ("£", 2),
|
||||
"EUR": ("€", 2),
|
||||
"CHF": ("CHF ", 2),
|
||||
"SEK": ("kr", 2),
|
||||
"NOK": ("kr", 2),
|
||||
"DKK": ("kr", 2),
|
||||
"JPY": ("¥", 0),
|
||||
"CNY": ("¥", 2),
|
||||
"KRW": ("₩", 0),
|
||||
"INR": ("₹", 2),
|
||||
"BRL": ("R$", 2),
|
||||
"MXN": ("$", 2),
|
||||
"ZAR": ("R", 2),
|
||||
"SGD": ("S$", 2),
|
||||
"HKD": ("HK$", 2),
|
||||
"THB": ("฿", 2),
|
||||
"PLN": ("zł", 2),
|
||||
"CZK": ("Kč", 2),
|
||||
"HUF": ("Ft", 0),
|
||||
"RUB": ("₽", 2),
|
||||
"TRY": ("₺", 2),
|
||||
"ILS": ("₪", 2),
|
||||
"AED": ("د.إ", 2),
|
||||
"SAR": ("﷼", 2),
|
||||
"CLP": ("$", 0),
|
||||
"COP": ("$", 0),
|
||||
"ARS": ("$", 2),
|
||||
"VND": ("₫", 0),
|
||||
"IDR": ("Rp", 0),
|
||||
"MYR": ("RM", 2),
|
||||
"PHP": ("₱", 2),
|
||||
}
|
||||
|
||||
# ── Preference helpers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_currency_code(
|
||||
user_id: str | None = None,
|
||||
store=None,
|
||||
) -> str:
|
||||
"""
|
||||
Return the user's preferred ISO 4217 currency code.
|
||||
|
||||
Fallback chain:
|
||||
1. Value in preference store at "currency.code"
|
||||
2. CURRENCY_DEFAULT environment variable
|
||||
3. "USD"
|
||||
"""
|
||||
stored = get_user_preference(user_id, PREF_CURRENCY_CODE, default=None, store=store)
|
||||
if stored is not None:
|
||||
return str(stored).upper()
|
||||
env_default = os.environ.get("CURRENCY_DEFAULT", "").strip().upper()
|
||||
if env_default:
|
||||
return env_default
|
||||
return DEFAULT_CURRENCY_CODE
|
||||
|
||||
|
||||
def set_currency_code(
|
||||
currency_code: str,
|
||||
user_id: str | None = None,
|
||||
store=None,
|
||||
) -> None:
|
||||
"""Persist *currency_code* (ISO 4217, e.g. 'GBP') to the preference store."""
|
||||
set_user_preference(user_id, PREF_CURRENCY_CODE, currency_code.upper(), store=store)
|
||||
|
||||
|
||||
# ── Formatting ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def format_currency(
|
||||
amount: float,
|
||||
currency_code: str,
|
||||
locale: str = "en_US",
|
||||
) -> str:
|
||||
"""
|
||||
Format *amount* as a locale-aware currency string.
|
||||
|
||||
Examples::
|
||||
|
||||
format_currency(12.5, "GBP") # "£12.50"
|
||||
format_currency(1234.99, "USD") # "$1,234.99"
|
||||
format_currency(1500, "JPY") # "¥1,500"
|
||||
|
||||
Uses ``babel.numbers.format_currency`` when babel is installed, which gives
|
||||
full locale-aware grouping, decimal separators, and symbol placement.
|
||||
Falls back to a built-in symbol table for the common currencies.
|
||||
|
||||
Args:
|
||||
amount: Numeric amount to format.
|
||||
currency_code: ISO 4217 code (e.g. "USD", "GBP", "EUR").
|
||||
locale: BCP 47 locale string (e.g. "en_US", "de_DE"). Only used
|
||||
when babel is available.
|
||||
|
||||
Returns:
|
||||
Formatted string, e.g. "£12.50".
|
||||
"""
|
||||
code = currency_code.upper()
|
||||
try:
|
||||
from babel.numbers import format_currency as babel_format # type: ignore[import]
|
||||
return babel_format(amount, code, locale=locale)
|
||||
except ImportError:
|
||||
return _fallback_format(amount, code)
|
||||
|
||||
|
||||
def _fallback_format(amount: float, code: str) -> str:
|
||||
"""Format without babel using the built-in symbol table."""
|
||||
symbol, decimals = _CURRENCY_META.get(code, (f"{code} ", 2))
|
||||
# Group thousands with commas
|
||||
if decimals == 0:
|
||||
value_str = f"{int(round(amount)):,}"
|
||||
else:
|
||||
value_str = f"{amount:,.{decimals}f}"
|
||||
return f"{symbol}{value_str}"
|
||||
|
|
@ -1,174 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.reranker — shared reranker module for RAG pipelines.
|
||||
|
||||
Provides a modality-aware scoring interface for ranking candidates against a
|
||||
query. Built to handle text today and audio/image/video in future branches.
|
||||
|
||||
Architecture:
|
||||
|
||||
Reranker (Protocol / trunk)
|
||||
└── TextReranker (branch)
|
||||
├── MockTextReranker — no deps, deterministic, for tests
|
||||
├── BGETextReranker — FlagEmbedding cross-encoder, MIT, Free tier
|
||||
└── Qwen3TextReranker — generative reranker, MIT/BSL, Paid tier (Phase 2)
|
||||
|
||||
Quick start (mock mode — no model required):
|
||||
|
||||
import os; os.environ["CF_RERANKER_MOCK"] = "1"
|
||||
from circuitforge_core.reranker import rerank
|
||||
|
||||
results = rerank("chicken soup", ["hearty chicken noodle", "chocolate cake", "tomato basil soup"])
|
||||
for r in results:
|
||||
print(r.rank, r.score, r.candidate[:40])
|
||||
|
||||
Real inference (BGE cross-encoder):
|
||||
|
||||
export CF_RERANKER_MODEL=BAAI/bge-reranker-base
|
||||
from circuitforge_core.reranker import rerank
|
||||
results = rerank(query, candidates, top_n=20)
|
||||
|
||||
Explicit backend (per-request or per-user):
|
||||
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
reranker = make_reranker("BAAI/bge-reranker-v2-m3", backend="bge")
|
||||
results = reranker.rerank(query, candidates, top_n=10)
|
||||
|
||||
Batch scoring (efficient for large corpora):
|
||||
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
reranker = make_reranker("BAAI/bge-reranker-base")
|
||||
batch = reranker.rerank_batch(queries, candidate_lists, top_n=10)
|
||||
|
||||
Environment variables:
|
||||
CF_RERANKER_MODEL model ID or path (default: "BAAI/bge-reranker-base")
|
||||
CF_RERANKER_BACKEND backend override: "bge" | "mock" (default: auto-detect)
|
||||
CF_RERANKER_MOCK set to "1" to force mock backend (no model required)
|
||||
|
||||
cf-orch service profile (Phase 3 — remote backend):
|
||||
service_type: cf-reranker
|
||||
max_mb: per-model (base ≈ 600, large ≈ 1400, 8B ≈ 8192)
|
||||
shared: true
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
|
||||
from circuitforge_core.reranker.adapters.mock import MockTextReranker
|
||||
|
||||
# ── Process-level singleton ───────────────────────────────────────────────────
|
||||
|
||||
_reranker: TextReranker | None = None
|
||||
|
||||
_DEFAULT_MODEL = "BAAI/bge-reranker-base"
|
||||
|
||||
|
||||
def _get_reranker() -> TextReranker:
|
||||
global _reranker
|
||||
if _reranker is None:
|
||||
_reranker = make_reranker()
|
||||
return _reranker
|
||||
|
||||
|
||||
def make_reranker(
|
||||
model_id: str | None = None,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
) -> TextReranker:
|
||||
"""
|
||||
Create a TextReranker for the given model.
|
||||
|
||||
Use this when you need an explicit reranker instance (e.g. per-service
|
||||
with a specific model) rather than the process-level singleton.
|
||||
|
||||
model_id — HuggingFace model ID or local path. Defaults to
|
||||
CF_RERANKER_MODEL env var, then BAAI/bge-reranker-base.
|
||||
backend — "bge" | "mock". Auto-detected from model_id if omitted.
|
||||
mock — Force mock backend. Defaults to CF_RERANKER_MOCK env var.
|
||||
"""
|
||||
_mock = mock if mock is not None else os.environ.get("CF_RERANKER_MOCK", "") == "1"
|
||||
if _mock:
|
||||
return MockTextReranker()
|
||||
|
||||
_model_id = model_id or os.environ.get("CF_RERANKER_MODEL", _DEFAULT_MODEL)
|
||||
_backend = backend or os.environ.get("CF_RERANKER_BACKEND", "")
|
||||
|
||||
# Auto-route to cf-orch when CF_ORCH_URL is set and no explicit backend override.
|
||||
# Cloud deployments set CF_ORCH_URL; local dev leaves it unset → local inference.
|
||||
if not _backend:
|
||||
orch_url = os.environ.get("CF_ORCH_URL", "")
|
||||
if orch_url:
|
||||
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
|
||||
logger.info("[reranker] CF_ORCH_URL set — using remote cf-reranker via cf-orch")
|
||||
return RemoteTextReranker.from_cf_orch(
|
||||
orch_url=orch_url,
|
||||
service="cf-reranker",
|
||||
ttl_s=float(os.environ.get("CF_RERANKER_TTL", "3600")),
|
||||
)
|
||||
_backend = "bge" # local default
|
||||
|
||||
if _backend == "mock":
|
||||
return MockTextReranker()
|
||||
|
||||
if _backend == "bge":
|
||||
from circuitforge_core.reranker.adapters.bge import BGETextReranker
|
||||
return BGETextReranker(_model_id)
|
||||
|
||||
if _backend == "qwen3":
|
||||
from circuitforge_core.reranker.adapters.qwen3 import Qwen3TextReranker
|
||||
return Qwen3TextReranker(_model_id)
|
||||
|
||||
if _backend == "cross-encoder":
|
||||
from circuitforge_core.reranker.adapters.cross_encoder import CrossEncoderTextReranker
|
||||
return CrossEncoderTextReranker(_model_id)
|
||||
|
||||
if _backend == "cohere":
|
||||
from circuitforge_core.reranker.adapters.cohere import CohereTextReranker
|
||||
return CohereTextReranker(model=_model_id)
|
||||
|
||||
if _backend == "remote":
|
||||
from circuitforge_core.reranker.adapters.remote import RemoteTextReranker
|
||||
return RemoteTextReranker(_model_id)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reranker backend {_backend!r}. "
|
||||
"Valid options: 'bge', 'qwen3', 'cross-encoder', 'cohere', 'remote', 'mock'."
|
||||
)
|
||||
|
||||
|
||||
# ── Convenience functions (singleton path) ────────────────────────────────────
|
||||
|
||||
|
||||
def rerank(
|
||||
query: str,
|
||||
candidates: Sequence[str],
|
||||
top_n: int = 0,
|
||||
) -> list[RerankResult]:
|
||||
"""
|
||||
Score and sort candidates against query using the process-level reranker.
|
||||
|
||||
Returns a list of RerankResult sorted by score descending (rank 0 first).
|
||||
top_n=0 returns all candidates.
|
||||
|
||||
For large corpora, prefer rerank_batch() on an explicit reranker instance
|
||||
to amortise model load time and batch the forward pass.
|
||||
"""
|
||||
return _get_reranker().rerank(query, candidates, top_n=top_n)
|
||||
|
||||
|
||||
def reset_reranker() -> None:
|
||||
"""Reset the process-level singleton. Test teardown only."""
|
||||
global _reranker
|
||||
_reranker = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
"TextReranker",
|
||||
"RerankResult",
|
||||
"MockTextReranker",
|
||||
"make_reranker",
|
||||
"rerank",
|
||||
"reset_reranker",
|
||||
]
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
# circuitforge_core/reranker/adapters/bge.py — BGE cross-encoder reranker
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-bge]
|
||||
# Tested with FlagEmbedding>=1.2 (BAAI/bge-reranker-* family).
|
||||
#
|
||||
# MIT licensed — local inference only, no tier gate.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy import sentinel — FlagEmbedding is an optional dep.
|
||||
try:
|
||||
from FlagEmbedding import FlagReranker as _FlagReranker # type: ignore[import]
|
||||
except ImportError:
|
||||
_FlagReranker = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class BGETextReranker(TextReranker):
|
||||
"""
|
||||
Cross-encoder reranker using the BAAI BGE reranker family.
|
||||
|
||||
Scores (query, candidate) pairs via FlagEmbedding.FlagReranker.
|
||||
Thread-safe: a lock serialises concurrent _score_pairs calls since
|
||||
FlagReranker is not guaranteed to be reentrant.
|
||||
|
||||
Recommended free-tier models:
|
||||
BAAI/bge-reranker-base ~570MB VRAM, fast
|
||||
BAAI/bge-reranker-v2-m3 ~570MB VRAM, multilingual
|
||||
BAAI/bge-reranker-large ~1.3GB VRAM, higher quality
|
||||
|
||||
Usage:
|
||||
reranker = BGETextReranker("BAAI/bge-reranker-base")
|
||||
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
|
||||
"""
|
||||
|
||||
def __init__(self, model_id: str = "BAAI/bge-reranker-base") -> None:
|
||||
self._model_id = model_id
|
||||
self._reranker: object | None = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
"""Explicitly load model weights. Called automatically on first rerank()."""
|
||||
if _FlagReranker is None:
|
||||
raise ImportError(
|
||||
"FlagEmbedding is not installed. "
|
||||
"Run: pip install circuitforge-core[reranker-bge]"
|
||||
)
|
||||
with self._lock:
|
||||
if self._reranker is None:
|
||||
logger.info("Loading BGE reranker: %s (fp16=%s)", self._model_id, _cuda_available())
|
||||
self._reranker = _FlagReranker(self._model_id, use_fp16=_cuda_available())
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Release model weights. Useful for VRAM management between tasks."""
|
||||
with self._lock:
|
||||
self._reranker = None
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
if self._reranker is None:
|
||||
self.load()
|
||||
pairs = [[query, c] for c in candidates]
|
||||
with self._lock:
|
||||
scores: list[float] = self._reranker.compute_score( # type: ignore[union-attr]
|
||||
pairs, normalize=True
|
||||
)
|
||||
# compute_score may return a single float when given one pair.
|
||||
if isinstance(scores, float):
|
||||
scores = [scores]
|
||||
return scores
|
||||
|
||||
def rerank_batch(
|
||||
self,
|
||||
queries: Sequence[str],
|
||||
candidates: Sequence[Sequence[str]],
|
||||
top_n: int = 0,
|
||||
) -> list[list[object]]:
|
||||
"""Batch all pairs into a single compute_score call for efficiency."""
|
||||
from circuitforge_core.reranker.base import RerankResult
|
||||
|
||||
if self._reranker is None:
|
||||
self.load()
|
||||
|
||||
# Flatten all pairs, recording group boundaries for reconstruction.
|
||||
all_pairs: list[list[str]] = []
|
||||
group_sizes: list[int] = []
|
||||
for q, cs in zip(queries, candidates):
|
||||
cands = list(cs)
|
||||
group_sizes.append(len(cands))
|
||||
all_pairs.extend([q, c] for c in cands)
|
||||
|
||||
if not all_pairs:
|
||||
return [[] for _ in queries]
|
||||
|
||||
with self._lock:
|
||||
all_scores: list[float] = self._reranker.compute_score( # type: ignore[union-attr]
|
||||
all_pairs, normalize=True
|
||||
)
|
||||
if isinstance(all_scores, float):
|
||||
all_scores = [all_scores]
|
||||
|
||||
# Reconstruct per-query result lists.
|
||||
results: list[list[RerankResult]] = []
|
||||
offset = 0
|
||||
for (q, cs), size in zip(zip(queries, candidates), group_sizes):
|
||||
cands = list(cs)
|
||||
scores = all_scores[offset : offset + size]
|
||||
offset += size
|
||||
sorted_results = sorted(
|
||||
(RerankResult(candidate=c, score=s, rank=0) for c, s in zip(cands, scores)),
|
||||
key=lambda r: r.score,
|
||||
reverse=True,
|
||||
)
|
||||
if top_n > 0:
|
||||
sorted_results = sorted_results[:top_n]
|
||||
results.append([
|
||||
RerankResult(candidate=r.candidate, score=r.score, rank=i)
|
||||
for i, r in enumerate(sorted_results)
|
||||
])
|
||||
return results
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
# circuitforge_core/reranker/adapters/cohere.py — Cohere Rerank API (BYOK cloud)
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-cohere]
|
||||
# API key: set COHERE_API_KEY env var, or pass api_key= explicitly.
|
||||
#
|
||||
# Models (as of 2026):
|
||||
# rerank-english-v3.0 English-only, highest quality
|
||||
# rerank-multilingual-v3.0 Multilingual
|
||||
# rerank-english-v2.0 Legacy, lower cost
|
||||
#
|
||||
# BYOK unlock path: free-tier users who supply their own Cohere key get cloud
|
||||
# reranking without needing a cf-orch node. Same pattern as the Anthropic
|
||||
# backend in LLMRouter.
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import cohere as _cohere # type: ignore[import]
|
||||
except ImportError:
|
||||
_cohere = None # type: ignore[assignment]
|
||||
|
||||
_DEFAULT_MODEL = "rerank-english-v3.0"
|
||||
|
||||
|
||||
class CohereTextReranker(TextReranker):
|
||||
"""
|
||||
Cloud reranker backed by the Cohere Rerank API.
|
||||
|
||||
BYOK (bring your own key): pass api_key= or set COHERE_API_KEY in the
|
||||
environment. No model weights loaded locally.
|
||||
|
||||
Usage:
|
||||
reranker = CohereTextReranker() # reads COHERE_API_KEY from env
|
||||
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
|
||||
|
||||
With an explicit key and model:
|
||||
reranker = CohereTextReranker(
|
||||
api_key="co-...",
|
||||
model="rerank-multilingual-v3.0",
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str = _DEFAULT_MODEL,
|
||||
max_chunks_per_doc: int = 1,
|
||||
) -> None:
|
||||
self._api_key_arg = api_key
|
||||
self._model = model
|
||||
self._max_chunks_per_doc = max_chunks_per_doc
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return f"cohere:{self._model}"
|
||||
|
||||
def _get_client(self) -> object:
|
||||
if _cohere is None:
|
||||
raise ImportError(
|
||||
"cohere is not installed. "
|
||||
"Run: pip install circuitforge-core[reranker-cohere]"
|
||||
)
|
||||
api_key = self._api_key_arg or os.environ.get("COHERE_API_KEY", "")
|
||||
if not api_key:
|
||||
raise RuntimeError(
|
||||
"Cohere API key is not set. "
|
||||
"Pass api_key= to CohereTextReranker or set COHERE_API_KEY."
|
||||
)
|
||||
return _cohere.Client(api_key=api_key)
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
client = self._get_client()
|
||||
response = client.rerank( # type: ignore[union-attr]
|
||||
query=query,
|
||||
documents=candidates,
|
||||
model=self._model,
|
||||
top_n=len(candidates),
|
||||
max_chunks_per_doc=self._max_chunks_per_doc,
|
||||
)
|
||||
# response.results is sorted by relevance_score desc; rebuild
|
||||
# in original candidate order so TextReranker.rerank() re-sorts correctly.
|
||||
score_map: dict[int, float] = {
|
||||
r.index: r.relevance_score for r in response.results
|
||||
}
|
||||
return [score_map.get(i, 0.0) for i in range(len(candidates))]
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
# circuitforge_core/reranker/adapters/cross_encoder.py — sentence-transformers CrossEncoder
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-cross-encoder]
|
||||
#
|
||||
# Covers models not in the FlagEmbedding ecosystem:
|
||||
# mixedbread-ai/mxbai-rerank-base-v1 ~570MB VRAM, strong general-purpose
|
||||
# mixedbread-ai/mxbai-rerank-large-v1 ~1.3GB VRAM, higher quality
|
||||
# cross-encoder/ms-marco-MiniLM-L-6-v2 ~90MB, fast, English-only
|
||||
# cross-encoder/ms-marco-MiniLM-L-12-v2 ~130MB, balanced
|
||||
# jinaai/jina-reranker-v2-base-multilingual ~280MB, multilingual
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder as _CrossEncoder # type: ignore[import]
|
||||
except ImportError:
|
||||
_CrossEncoder = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class CrossEncoderTextReranker(TextReranker):
|
||||
"""
|
||||
Cross-encoder reranker using the sentence-transformers CrossEncoder class.
|
||||
|
||||
Broader model compatibility than BGETextReranker — any HuggingFace model
|
||||
with a sequence-classification head works here. Particularly well-suited
|
||||
for the mxbai-rerank and ms-marco families.
|
||||
|
||||
Usage:
|
||||
reranker = CrossEncoderTextReranker("mixedbread-ai/mxbai-rerank-base-v1")
|
||||
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "mixedbread-ai/mxbai-rerank-base-v1",
|
||||
max_length: int = 512,
|
||||
) -> None:
|
||||
self._model_id = model_id
|
||||
self._max_length = max_length
|
||||
self._model: object | None = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
"""Explicitly load model weights. Called automatically on first rerank()."""
|
||||
if _CrossEncoder is None:
|
||||
raise ImportError(
|
||||
"sentence-transformers is not installed. "
|
||||
"Run: pip install circuitforge-core[reranker-cross-encoder]"
|
||||
)
|
||||
with self._lock:
|
||||
if self._model is not None:
|
||||
return
|
||||
device = "cuda" if _cuda_available() else "cpu"
|
||||
logger.info(
|
||||
"Loading CrossEncoder reranker: %s (device=%s)", self._model_id, device
|
||||
)
|
||||
self._model = _CrossEncoder(
|
||||
self._model_id,
|
||||
max_length=self._max_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Release model weights."""
|
||||
with self._lock:
|
||||
self._model = None
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
if self._model is None:
|
||||
self.load()
|
||||
pairs = [(query, c) for c in candidates]
|
||||
with self._lock:
|
||||
raw = self._model.predict(pairs) # type: ignore[union-attr]
|
||||
# predict() returns a numpy array or list; normalise to plain floats.
|
||||
return [float(s) for s in raw]
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
# circuitforge_core/reranker/adapters/mock.py — deterministic mock reranker
|
||||
#
|
||||
# Always importable, no optional deps. Used in tests and CF_RERANKER_MOCK=1 mode.
|
||||
# Scores by descending overlap of query tokens with candidate tokens so results
|
||||
# are deterministic and meaningful enough to exercise product code paths.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import RerankResult, TextReranker
|
||||
|
||||
|
||||
class MockTextReranker(TextReranker):
|
||||
"""Deterministic reranker for tests. No model weights required.
|
||||
|
||||
Scoring: Jaccard similarity between query token set and candidate token set.
|
||||
Ties broken by candidate length (shorter wins) then lexicographic order,
|
||||
so test assertions can be written against a stable ordering.
|
||||
"""
|
||||
|
||||
_MODEL_ID = "mock"
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._MODEL_ID
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
q_tokens = set(query.lower().split())
|
||||
scores: list[float] = []
|
||||
for candidate in candidates:
|
||||
c_tokens = set(candidate.lower().split())
|
||||
union = q_tokens | c_tokens
|
||||
if not union:
|
||||
scores.append(0.0)
|
||||
else:
|
||||
scores.append(len(q_tokens & c_tokens) / len(union))
|
||||
return scores
|
||||
|
|
@ -1,239 +0,0 @@
|
|||
# circuitforge_core/reranker/adapters/qwen3.py — Qwen3-Reranker adapter
|
||||
#
|
||||
# Requires: pip install circuitforge-core[reranker-qwen3]
|
||||
# Tested with: Qwen/Qwen3-Reranker-0.6B, -1.5B, -8B
|
||||
#
|
||||
# Scoring mechanism (generative reranker):
|
||||
# Rather than generating a full response, we pre-fill the assistant turn with
|
||||
# the <think>\n\n</think>\n block and read the logits at the last input token
|
||||
# position. The softmax probability of "yes" vs "no" at that position is the
|
||||
# relevance score — one forward pass per batch, no generation loop.
|
||||
#
|
||||
# Prompt format (Qwen3 chat template):
|
||||
# system: "Judge whether the Document meets the requirements based on the
|
||||
# Query and the Instruct. Note that the answer can only be 'yes'
|
||||
# or 'no'."
|
||||
# user: "<Instruct>: {task}\n<Query>: {query}\n<Document>: {doc}"
|
||||
# assistant (pre-filled): "<think>\n\n</think>\n\n"
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Sequence
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Optional heavy deps — lazy-imported at load() time.
|
||||
try:
|
||||
import torch as _torch # type: ignore[import]
|
||||
except ImportError:
|
||||
_torch = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM as _AutoModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer as _AutoTokenizer # type: ignore[import]
|
||||
except ImportError:
|
||||
_AutoModel = None # type: ignore[assignment]
|
||||
_AutoTokenizer = None # type: ignore[assignment]
|
||||
|
||||
# System prompt used for all reranking tasks.
|
||||
_SYSTEM_PROMPT = (
|
||||
"Judge whether the Document meets the requirements based on the Query and "
|
||||
'the Instruct. Note that the answer can only be "yes" or "no".'
|
||||
)
|
||||
|
||||
# Default task instruction — products can override via make_reranker(task=...).
|
||||
_DEFAULT_TASK = "Given a query, retrieve the most relevant document that answers the query."
|
||||
|
||||
# The pre-filled assistant turn that puts the model past its thinking block
|
||||
# so the very next token position scores "yes" vs "no".
|
||||
_ASSISTANT_PREFILL = "<think>\n\n</think>\n\n"
|
||||
|
||||
|
||||
def _requires_deps() -> None:
|
||||
if _torch is None:
|
||||
raise ImportError(
|
||||
"torch is not installed. Run: pip install circuitforge-core[reranker-qwen3]"
|
||||
)
|
||||
if _AutoModel is None:
|
||||
raise ImportError(
|
||||
"transformers is not installed. Run: pip install circuitforge-core[reranker-qwen3]"
|
||||
)
|
||||
|
||||
|
||||
class Qwen3TextReranker(TextReranker):
|
||||
"""
|
||||
Generative reranker using the Qwen3-Reranker model family.
|
||||
|
||||
Scores candidates by reading yes/no token logits at the last input position
|
||||
after pre-filling the assistant thinking block. One forward pass covers an
|
||||
entire batch — efficient for ranking large candidate lists.
|
||||
|
||||
Model options (by tier):
|
||||
Free: Qwen/Qwen3-Reranker-0.6B (~1.2 GB VRAM fp16)
|
||||
Qwen/Qwen3-Reranker-1.5B (~3.0 GB VRAM fp16)
|
||||
Paid: Qwen/Qwen3-Reranker-8B (~16 GB VRAM fp16, or ~9 GB int8)
|
||||
|
||||
Usage:
|
||||
reranker = Qwen3TextReranker("Qwen/Qwen3-Reranker-0.6B")
|
||||
results = reranker.rerank("chicken soup recipe", ["recipe 1...", "recipe 2..."])
|
||||
|
||||
With a custom task instruction:
|
||||
reranker = Qwen3TextReranker(
|
||||
"Qwen/Qwen3-Reranker-1.5B",
|
||||
task="Given a job description, retrieve the most relevant resume.",
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "Qwen/Qwen3-Reranker-0.6B",
|
||||
task: str = _DEFAULT_TASK,
|
||||
device: str | None = None,
|
||||
dtype: str = "float16",
|
||||
batch_size: int = 32,
|
||||
) -> None:
|
||||
self._model_id = model_id
|
||||
self._task = task
|
||||
self._device = device # None = auto-detect at load time
|
||||
self._dtype_str = dtype
|
||||
self._batch_size = batch_size
|
||||
self._model: object | None = None
|
||||
self._tokenizer: object | None = None
|
||||
self._yes_id: int | None = None
|
||||
self._no_id: int | None = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
"""Explicitly load model weights. Called automatically on first rerank()."""
|
||||
_requires_deps()
|
||||
with self._lock:
|
||||
if self._model is not None:
|
||||
return
|
||||
device = self._device or ("cuda" if _torch.cuda.is_available() else "cpu")
|
||||
dtype_map: dict[str, object] = {
|
||||
"float16": _torch.float16,
|
||||
"bfloat16": _torch.bfloat16,
|
||||
"float32": _torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map.get(self._dtype_str, _torch.float16)
|
||||
|
||||
logger.info(
|
||||
"Loading Qwen3 reranker: %s (device=%s dtype=%s)",
|
||||
self._model_id, device, self._dtype_str,
|
||||
)
|
||||
tokenizer = _AutoTokenizer.from_pretrained(
|
||||
self._model_id, trust_remote_code=True
|
||||
)
|
||||
model = _AutoModel.from_pretrained(
|
||||
self._model_id,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# Resolve the token IDs for "yes" and "no" once at load time.
|
||||
# Qwen tokenizers encode single-word tokens without a leading space.
|
||||
yes_ids: list[int] = tokenizer.encode("yes", add_special_tokens=False)
|
||||
no_ids: list[int] = tokenizer.encode("no", add_special_tokens=False)
|
||||
if not yes_ids or not no_ids:
|
||||
raise RuntimeError(
|
||||
f"Could not resolve 'yes'/'no' token IDs from tokenizer {self._model_id!r}. "
|
||||
"This model may not be a Qwen3-Reranker variant."
|
||||
)
|
||||
|
||||
self._tokenizer = tokenizer
|
||||
self._model = model
|
||||
self._yes_id = yes_ids[0]
|
||||
self._no_id = no_ids[0]
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Release model weights. Useful for VRAM management between tasks."""
|
||||
with self._lock:
|
||||
self._model = None
|
||||
self._tokenizer = None
|
||||
self._yes_id = None
|
||||
self._no_id = None
|
||||
|
||||
def _build_prompt(self, query: str, document: str) -> str:
|
||||
"""Format a single (query, document) pair as a chat-template prompt."""
|
||||
messages = [
|
||||
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"<Instruct>: {self._task}\n"
|
||||
f"<Query>: {query}\n"
|
||||
f"<Document>: {document}"
|
||||
),
|
||||
},
|
||||
]
|
||||
# apply_chat_template without tokenization so we can append the prefill.
|
||||
text: str = self._tokenizer.apply_chat_template( # type: ignore[union-attr]
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text + _ASSISTANT_PREFILL
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
if self._model is None:
|
||||
self.load()
|
||||
return self._score_in_batches(query, candidates)
|
||||
|
||||
def _score_in_batches(self, query: str, candidates: list[str]) -> list[float]:
|
||||
"""Score all (query, candidate) pairs, splitting into sub-batches."""
|
||||
all_scores: list[float] = []
|
||||
for start in range(0, len(candidates), self._batch_size):
|
||||
batch = candidates[start : start + self._batch_size]
|
||||
all_scores.extend(self._score_batch(query, batch))
|
||||
return all_scores
|
||||
|
||||
def _score_batch(self, query: str, candidates: list[str]) -> list[float]:
|
||||
"""Single forward pass for one sub-batch. Returns a score per candidate."""
|
||||
prompts = [self._build_prompt(query, c) for c in candidates]
|
||||
|
||||
# Left-pad so the last token position is consistent across all sequences.
|
||||
tokenizer = self._tokenizer # type: ignore[union-attr]
|
||||
original_side = getattr(tokenizer, "padding_side", "right")
|
||||
tokenizer.padding_side = "left"
|
||||
try:
|
||||
encoded = tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=4096,
|
||||
)
|
||||
finally:
|
||||
tokenizer.padding_side = original_side
|
||||
|
||||
model = self._model # type: ignore[union-attr]
|
||||
device = next(model.parameters()).device # type: ignore[union-attr]
|
||||
input_ids = encoded["input_ids"].to(device)
|
||||
attention_mask = encoded["attention_mask"].to(device)
|
||||
|
||||
with self._lock:
|
||||
with _torch.no_grad():
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
# logits shape: (batch, seq_len, vocab_size)
|
||||
# Last position [-1] is the token the model would output next.
|
||||
last_logits = outputs.logits[:, -1, :] # (batch, vocab)
|
||||
yes_logits = last_logits[:, self._yes_id] # (batch,)
|
||||
no_logits = last_logits[:, self._no_id] # (batch,)
|
||||
|
||||
# Softmax over yes/no only — score = P(yes | query, doc).
|
||||
stacked = _torch.stack([yes_logits, no_logits], dim=-1) # (batch, 2)
|
||||
probs = _torch.softmax(stacked, dim=-1)
|
||||
scores: list[float] = probs[:, 0].float().cpu().tolist()
|
||||
return scores
|
||||
|
|
@ -1,131 +0,0 @@
|
|||
# circuitforge_core/reranker/adapters/remote.py — HTTP remote reranker adapter
|
||||
#
|
||||
# Calls a cf-reranker service endpoint (cf-orch allocated or static URL).
|
||||
# No model weights loaded locally — all inference runs on the remote node.
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
import requests
|
||||
|
||||
from circuitforge_core.reranker.base import TextReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default timeout for a single /rerank call (seconds).
|
||||
# Large candidate lists may take longer — callers can pass timeout= explicitly.
|
||||
_DEFAULT_TIMEOUT = 30
|
||||
|
||||
|
||||
class RemoteTextReranker(TextReranker):
|
||||
"""
|
||||
Reranker that delegates scoring to a remote cf-reranker HTTP service.
|
||||
|
||||
The remote service must implement POST /rerank with the request body::
|
||||
|
||||
{"query": str, "candidates": [str, ...], "top_n": int}
|
||||
|
||||
and return::
|
||||
|
||||
{"results": [{"candidate": str, "score": float, "rank": int}, ...]}
|
||||
|
||||
cf-orch allocation (recommended — starts service on-demand):
|
||||
reranker = RemoteTextReranker.from_cf_orch(
|
||||
orch_url="http://10.1.10.71:7700",
|
||||
service="cf-reranker",
|
||||
model_candidates=["qwen3-0.6b"],
|
||||
)
|
||||
|
||||
Static URL (e.g. dedicated node already running cf-reranker):
|
||||
reranker = RemoteTextReranker("http://10.1.10.10:8011")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
timeout: int = _DEFAULT_TIMEOUT,
|
||||
_model_id: str = "remote",
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._model_id_str = _model_id
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id_str
|
||||
|
||||
@classmethod
|
||||
def from_cf_orch(
|
||||
cls,
|
||||
orch_url: str,
|
||||
service: str = "cf-reranker",
|
||||
model_candidates: list[str] | None = None,
|
||||
ttl_s: float = 3600.0,
|
||||
timeout: int = _DEFAULT_TIMEOUT,
|
||||
) -> "RemoteTextReranker":
|
||||
"""
|
||||
Allocate a cf-reranker service via cf-orch and return a configured adapter.
|
||||
|
||||
Blocks until allocation succeeds or raises on failure. The returned
|
||||
adapter is valid for the duration of the TTL; create a new one if the
|
||||
lease expires.
|
||||
|
||||
This is a one-shot allocation — the caller owns the lifetime. For
|
||||
long-running services, prefer the static URL constructor and let
|
||||
cf-orch manage the process independently.
|
||||
"""
|
||||
try:
|
||||
from circuitforge_orch.client import CFOrchClient # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"circuitforge_orch is not installed — cannot allocate via cf-orch."
|
||||
) from exc
|
||||
|
||||
client = CFOrchClient(orch_url)
|
||||
ctx = client.allocate(
|
||||
service,
|
||||
model_candidates=model_candidates or [],
|
||||
ttl_s=ttl_s,
|
||||
caller="reranker-remote",
|
||||
)
|
||||
alloc = ctx.__enter__()
|
||||
# Note: caller is responsible for ctx.__exit__() when done.
|
||||
# We stash it on the instance so callers can call release().
|
||||
instance = cls(
|
||||
base_url=alloc.url,
|
||||
timeout=timeout,
|
||||
_model_id=f"remote:{service}",
|
||||
)
|
||||
instance._orch_ctx = ctx # type: ignore[attr-defined]
|
||||
return instance
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release the cf-orch allocation if this adapter was created via from_cf_orch()."""
|
||||
ctx = getattr(self, "_orch_ctx", None)
|
||||
if ctx is not None:
|
||||
try:
|
||||
ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
self._orch_ctx = None # type: ignore[attr-defined]
|
||||
|
||||
def _score_pairs(self, query: str, candidates: list[str]) -> list[float]:
|
||||
url = f"{self._base_url}/rerank"
|
||||
payload = {"query": query, "candidates": candidates, "top_n": 0}
|
||||
try:
|
||||
resp = requests.post(url, json=payload, timeout=self._timeout)
|
||||
resp.raise_for_status()
|
||||
except requests.RequestException as exc:
|
||||
raise RuntimeError(
|
||||
f"Remote reranker at {url!r} failed: {exc}"
|
||||
) from exc
|
||||
|
||||
data = resp.json()
|
||||
# Build a score-per-candidate list in the original order.
|
||||
score_map: dict[str, float] = {
|
||||
r["candidate"]: r["score"] for r in data["results"]
|
||||
}
|
||||
return [score_map.get(c, 0.0) for c in candidates]
|
||||
|
|
@ -1,169 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.reranker.app — cf-reranker FastAPI service.
|
||||
|
||||
Managed by cf-orch as a process-type service. cf-orch starts this via:
|
||||
|
||||
python -m circuitforge_core.reranker.app \
|
||||
--model BAAI/bge-reranker-base \
|
||||
--backend bge \
|
||||
--port 8011 \
|
||||
--gpu-id 0
|
||||
|
||||
Or with Qwen3:
|
||||
|
||||
python -m circuitforge_core.reranker.app \
|
||||
--model Qwen/Qwen3-Reranker-0.6B \
|
||||
--backend qwen3 \
|
||||
--port 8011 \
|
||||
--gpu-id 0 \
|
||||
--dtype float16
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": "...", "backend": "...", "vram_mb": n}
|
||||
POST /rerank → RerankResponse
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Request / response models ─────────────────────────────────────────────────
|
||||
|
||||
class RerankRequest(BaseModel):
|
||||
query: str
|
||||
candidates: list[str]
|
||||
top_n: int = 0
|
||||
|
||||
|
||||
class RerankResultItem(BaseModel):
|
||||
candidate: str
|
||||
score: float
|
||||
rank: int
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
|
||||
results: list[RerankResultItem]
|
||||
model: str
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
model: str
|
||||
backend: str
|
||||
vram_mb: int
|
||||
|
||||
|
||||
# ── VRAM estimates by backend/model family ────────────────────────────────────
|
||||
|
||||
_VRAM_TABLE: dict[str, int] = {
|
||||
"bge-reranker-base": 570,
|
||||
"bge-reranker-large": 1300,
|
||||
"bge-reranker-v2-m3": 570,
|
||||
"mxbai-rerank-base-v1": 570,
|
||||
"mxbai-rerank-large-v1": 1300,
|
||||
"ms-marco-MiniLM-L-6-v2": 90,
|
||||
"ms-marco-MiniLM-L-12-v2": 130,
|
||||
"Qwen3-Reranker-0.6B": 1200,
|
||||
"Qwen3-Reranker-1.5B": 3000,
|
||||
"Qwen3-Reranker-8B": 16000,
|
||||
}
|
||||
|
||||
def _estimate_vram(model_id: str) -> int:
|
||||
for key, mb in _VRAM_TABLE.items():
|
||||
if key in model_id:
|
||||
return mb
|
||||
return 1024 # safe default
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
def create_app(model_id: str, backend: str, dtype: str, mock: bool) -> FastAPI:
|
||||
from circuitforge_core.reranker import make_reranker
|
||||
|
||||
app = FastAPI(title="cf-reranker", version="0.1.0")
|
||||
_reranker = make_reranker(model_id=model_id, backend=backend, mock=mock)
|
||||
_vram_mb = _estimate_vram(model_id)
|
||||
|
||||
logger.info("cf-reranker ready: model=%r backend=%r vram=%dMB", model_id, backend, _vram_mb)
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
return HealthResponse(
|
||||
status="ok",
|
||||
model=_reranker.model_id,
|
||||
backend=backend,
|
||||
vram_mb=_vram_mb,
|
||||
)
|
||||
|
||||
@app.post("/rerank", response_model=RerankResponse)
|
||||
async def rerank(req: RerankRequest) -> RerankResponse:
|
||||
if not req.candidates:
|
||||
raise HTTPException(status_code=400, detail="candidates must not be empty")
|
||||
try:
|
||||
results = _reranker.rerank(req.query, req.candidates, top_n=req.top_n)
|
||||
except Exception as exc:
|
||||
logger.exception("rerank failed")
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
return RerankResponse(
|
||||
results=[
|
||||
RerankResultItem(candidate=r.candidate, score=r.score, rank=r.rank)
|
||||
for r in results
|
||||
],
|
||||
model=_reranker.model_id,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ── CLI entry point ───────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="cf-reranker — CircuitForge reranker service")
|
||||
parser.add_argument(
|
||||
"--model", default="BAAI/bge-reranker-base",
|
||||
help="HuggingFace model ID or local path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend", default="bge",
|
||||
choices=["bge", "qwen3", "cross-encoder", "mock"],
|
||||
help="Reranker backend",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8011)
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--gpu-id", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--dtype", default="float16",
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
)
|
||||
parser.add_argument("--mock", action="store_true",
|
||||
help="Run with mock backend (no GPU, for testing)")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
if args.backend != "mock" and not args.mock:
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))
|
||||
|
||||
mock = args.mock or os.environ.get("CF_RERANKER_MOCK", "") == "1"
|
||||
app = create_app(
|
||||
model_id=args.model,
|
||||
backend=args.backend,
|
||||
dtype=args.dtype,
|
||||
mock=mock,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,135 +0,0 @@
|
|||
# circuitforge_core/reranker/base.py — Reranker Protocol + modality branches
|
||||
#
|
||||
# MIT licensed. The Protocol and RerankResult are always importable.
|
||||
# Adapter implementations (BGE, Qwen3, cf-orch remote) require optional extras.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, Sequence, runtime_checkable
|
||||
|
||||
|
||||
# ── Result type ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerankResult:
|
||||
"""A single scored candidate returned by a reranker.
|
||||
|
||||
rank is 0-based (0 = highest score).
|
||||
candidate preserves the original object — text, Path, or any other type
|
||||
passed in by the caller, so products don't need to re-index the input list.
|
||||
"""
|
||||
candidate: Any
|
||||
score: float
|
||||
rank: int
|
||||
|
||||
|
||||
# ── Trunk: generic Reranker Protocol ─────────────────────────────────────────
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Reranker(Protocol):
|
||||
"""
|
||||
Abstract interface for all reranker adapters.
|
||||
|
||||
Implementations must be safe to construct once and call concurrently;
|
||||
internal state (loaded model weights) should be guarded by a lock if
|
||||
the backend is not thread-safe.
|
||||
|
||||
query — the reference item to rank against (typically a text query)
|
||||
candidates — ordered collection of items to score; ordering is preserved
|
||||
in the returned list, which is sorted by score descending
|
||||
top_n — return at most this many results; 0 means return all
|
||||
|
||||
Returns a list of RerankResult sorted by score descending (rank 0 first).
|
||||
"""
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
candidates: Sequence[Any],
|
||||
top_n: int = 0,
|
||||
) -> list[RerankResult]:
|
||||
...
|
||||
|
||||
def rerank_batch(
|
||||
self,
|
||||
queries: Sequence[str],
|
||||
candidates: Sequence[Sequence[Any]],
|
||||
top_n: int = 0,
|
||||
) -> list[list[RerankResult]]:
|
||||
"""Score multiple (query, candidates) pairs in one call.
|
||||
|
||||
Default implementation loops over rerank(); adapters may override
|
||||
with a true batched forward pass for efficiency.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
"""Identifier for the loaded model (name, path, or URL)."""
|
||||
...
|
||||
|
||||
|
||||
# ── Branch: text-specific reranker ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TextReranker:
|
||||
"""
|
||||
Base class for text-to-text rerankers.
|
||||
|
||||
Subclasses implement _score_pairs(query, candidates) and get rerank()
|
||||
and rerank_batch() for free. The default rerank_batch() loops over
|
||||
_score_pairs; override it in adapters that support native batching.
|
||||
|
||||
candidates must be strings. query is always a string.
|
||||
"""
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _score_pairs(
|
||||
self,
|
||||
query: str,
|
||||
candidates: list[str],
|
||||
) -> list[float]:
|
||||
"""Return a score per candidate (higher = more relevant).
|
||||
|
||||
Called by rerank() and rerank_batch(). Must return a list of the
|
||||
same length as candidates, in the same order.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
candidates: Sequence[str],
|
||||
top_n: int = 0,
|
||||
) -> list[RerankResult]:
|
||||
cands = list(candidates)
|
||||
if not cands:
|
||||
return []
|
||||
scores = self._score_pairs(query, cands)
|
||||
results = sorted(
|
||||
(RerankResult(candidate=c, score=s, rank=0) for c, s in zip(cands, scores)),
|
||||
key=lambda r: r.score,
|
||||
reverse=True,
|
||||
)
|
||||
if top_n > 0:
|
||||
results = results[:top_n]
|
||||
return [
|
||||
RerankResult(candidate=r.candidate, score=r.score, rank=i)
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
|
||||
def rerank_batch(
|
||||
self,
|
||||
queries: Sequence[str],
|
||||
candidates: Sequence[Sequence[str]],
|
||||
top_n: int = 0,
|
||||
) -> list[list[RerankResult]]:
|
||||
return [
|
||||
self.rerank(q, cs, top_n)
|
||||
for q, cs in zip(queries, candidates)
|
||||
]
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.stt — Speech-to-text service module.
|
||||
|
||||
Quick start (mock mode — no GPU or model required):
|
||||
|
||||
import os; os.environ["CF_STT_MOCK"] = "1"
|
||||
from circuitforge_core.stt import transcribe
|
||||
|
||||
result = transcribe(open("audio.wav", "rb").read())
|
||||
print(result.text, result.confidence)
|
||||
|
||||
Real inference (faster-whisper):
|
||||
|
||||
export CF_STT_MODEL=/Library/Assets/LLM/whisper/models/Whisper/faster-whisper/models--Systran--faster-whisper-medium/snapshots/<hash>
|
||||
from circuitforge_core.stt import transcribe
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-stt
|
||||
max_mb: 1024 (medium); 600 (base/small)
|
||||
max_concurrent: 3
|
||||
shared: true
|
||||
managed:
|
||||
exec: python -m circuitforge_core.stt.app
|
||||
args: --model <path> --port {port} --gpu-id {gpu_id}
|
||||
port: 8004
|
||||
health: /health
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.stt.backends.base import (
|
||||
STTBackend,
|
||||
STTResult,
|
||||
STTSegment,
|
||||
make_stt_backend,
|
||||
)
|
||||
from circuitforge_core.stt.backends.mock import MockSTTBackend
|
||||
|
||||
_backend: STTBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> STTBackend:
|
||||
global _backend
|
||||
if _backend is None:
|
||||
model_path = os.environ.get("CF_STT_MODEL", "mock")
|
||||
mock = model_path == "mock" or os.environ.get("CF_STT_MOCK", "") == "1"
|
||||
_backend = make_stt_backend(model_path, mock=mock)
|
||||
return _backend
|
||||
|
||||
|
||||
def transcribe(
|
||||
audio: bytes,
|
||||
*,
|
||||
language: str | None = None,
|
||||
confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
|
||||
) -> STTResult:
|
||||
"""Transcribe audio bytes using the process-level backend."""
|
||||
return _get_backend().transcribe(
|
||||
audio, language=language, confidence_threshold=confidence_threshold
|
||||
)
|
||||
|
||||
|
||||
def reset_backend() -> None:
|
||||
"""Reset the process-level singleton. Test teardown only."""
|
||||
global _backend
|
||||
_backend = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"STTBackend",
|
||||
"STTResult",
|
||||
"STTSegment",
|
||||
"MockSTTBackend",
|
||||
"make_stt_backend",
|
||||
"transcribe",
|
||||
"reset_backend",
|
||||
]
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.stt.app — cf-stt FastAPI service.
|
||||
|
||||
Managed by cf-orch as a process-type service. cf-orch starts this via:
|
||||
|
||||
python -m circuitforge_core.stt.app \
|
||||
--model /Library/Assets/LLM/whisper/models/Whisper/faster-whisper/models--Systran--faster-whisper-medium/snapshots/<hash> \
|
||||
--port 8004 \
|
||||
--gpu-id 0
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": "<name>", "vram_mb": <n>}
|
||||
POST /transcribe → STTTranscribeResponse (multipart: audio file)
|
||||
|
||||
Audio format: any format ffmpeg understands (WAV, MP3, OGG, FLAC).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.stt.backends.base import STTResult, make_stt_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Response model (mirrors circuitforge_orch.contracts.stt.STTTranscribeResponse) ──
|
||||
|
||||
class TranscribeResponse(BaseModel):
|
||||
text: str
|
||||
confidence: float
|
||||
below_threshold: bool
|
||||
language: str | None = None
|
||||
duration_s: float | None = None
|
||||
segments: list[dict] = []
|
||||
model: str = ""
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
def create_app(
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
compute_type: str = "float16",
|
||||
confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
|
||||
mock: bool = False,
|
||||
) -> FastAPI:
|
||||
app = FastAPI(title="cf-stt", version="0.1.0")
|
||||
backend = make_stt_backend(
|
||||
model_path, device=device, compute_type=compute_type, mock=mock
|
||||
)
|
||||
logger.info("cf-stt ready: model=%r vram=%dMB", backend.model_name, backend.vram_mb)
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "model": backend.model_name, "vram_mb": backend.vram_mb}
|
||||
|
||||
@app.post("/transcribe", response_model=TranscribeResponse)
|
||||
async def transcribe(
|
||||
audio: UploadFile = File(..., description="Audio file (WAV, MP3, OGG, FLAC, ...)"),
|
||||
language: str | None = Form(None, description="BCP-47 language code hint, e.g. 'en'"),
|
||||
confidence_threshold_override: float | None = Form(
|
||||
None,
|
||||
description="Override default confidence threshold for this request.",
|
||||
),
|
||||
) -> TranscribeResponse:
|
||||
audio_bytes = await audio.read()
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=400, detail="Empty audio file")
|
||||
|
||||
threshold = confidence_threshold_override or confidence_threshold
|
||||
try:
|
||||
result = backend.transcribe(
|
||||
audio_bytes, language=language, confidence_threshold=threshold
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Transcription failed")
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
return TranscribeResponse(
|
||||
text=result.text,
|
||||
confidence=result.confidence,
|
||||
below_threshold=result.below_threshold,
|
||||
language=result.language,
|
||||
duration_s=result.duration_s,
|
||||
segments=[
|
||||
{
|
||||
"start_s": s.start_s,
|
||||
"end_s": s.end_s,
|
||||
"text": s.text,
|
||||
"confidence": s.confidence,
|
||||
}
|
||||
for s in result.segments
|
||||
],
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ── CLI entry point ───────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="cf-stt — CircuitForge STT service")
|
||||
parser.add_argument("--model", required=True,
|
||||
help="Model path or size name (e.g. 'medium', or full local path)")
|
||||
parser.add_argument("--port", type=int, default=8004)
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--gpu-id", type=int, default=0,
|
||||
help="CUDA device index (sets CUDA_VISIBLE_DEVICES)")
|
||||
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
|
||||
parser.add_argument("--compute-type", default="float16",
|
||||
choices=["float16", "int8", "int8_float16", "float32"],
|
||||
help="Quantisation / compute type passed to faster-whisper")
|
||||
parser.add_argument("--confidence-threshold", type=float,
|
||||
default=STTResult.CONFIDENCE_DEFAULT_THRESHOLD)
|
||||
parser.add_argument("--mock", action="store_true",
|
||||
help="Run with mock backend (no GPU, for testing)")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
# Let cf-orch pass --gpu-id; map to CUDA_VISIBLE_DEVICES so the process
|
||||
# only sees its assigned GPU. This prevents accidental multi-GPU usage.
|
||||
if args.device == "cuda" and not args.mock:
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(args.gpu_id))
|
||||
|
||||
mock = args.mock or os.environ.get("CF_STT_MOCK", "") == "1"
|
||||
app = create_app(
|
||||
model_path=args.model,
|
||||
device=args.device,
|
||||
compute_type=args.compute_type,
|
||||
confidence_threshold=args.confidence_threshold,
|
||||
mock=mock,
|
||||
)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
from .base import STTBackend, STTResult, STTSegment, make_stt_backend
|
||||
from .mock import MockSTTBackend
|
||||
|
||||
__all__ = ["STTBackend", "STTResult", "STTSegment", "make_stt_backend", "MockSTTBackend"]
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
# circuitforge_core/stt/backends/base.py — STTBackend Protocol + factory
|
||||
#
|
||||
# MIT licensed. The Protocol and mock are always importable without GPU deps.
|
||||
# Real backends require optional extras:
|
||||
# pip install -e "circuitforge-core[stt-faster-whisper]"
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ── Result types ──────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class STTSegment:
|
||||
"""Word- or phrase-level segment (included when the backend supports it)."""
|
||||
start_s: float
|
||||
end_s: float
|
||||
text: str
|
||||
confidence: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class STTResult:
|
||||
"""
|
||||
Standard result from any STTBackend.transcribe() call.
|
||||
|
||||
confidence is normalised to 0.0–1.0 regardless of the backend's native metric.
|
||||
below_threshold is True when confidence < the configured threshold (default 0.75).
|
||||
This flag is safety-critical for products like Osprey: DTMF must NOT be sent
|
||||
when below_threshold is True.
|
||||
"""
|
||||
text: str
|
||||
confidence: float # 0.0–1.0
|
||||
below_threshold: bool
|
||||
language: str | None = None
|
||||
duration_s: float | None = None
|
||||
segments: list[STTSegment] = field(default_factory=list)
|
||||
model: str = ""
|
||||
|
||||
CONFIDENCE_DEFAULT_THRESHOLD: float = 0.75
|
||||
|
||||
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@runtime_checkable
|
||||
class STTBackend(Protocol):
|
||||
"""
|
||||
Abstract interface for speech-to-text backends.
|
||||
|
||||
All backends load their model once at construction time and are safe to
|
||||
call concurrently (the model weights are read-only after load).
|
||||
"""
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: bytes,
|
||||
*,
|
||||
language: str | None = None,
|
||||
confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
|
||||
) -> STTResult:
|
||||
"""Synchronous transcription. audio is raw PCM or any format ffmpeg understands."""
|
||||
...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Identifier for the loaded model (path stem or size name)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
"""Approximate VRAM footprint in MB. Used by cf-orch service registry."""
|
||||
...
|
||||
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def make_stt_backend(
|
||||
model_path: str,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
device: str = "cuda",
|
||||
compute_type: str = "float16",
|
||||
) -> STTBackend:
|
||||
"""
|
||||
Return an STTBackend for the given model.
|
||||
|
||||
mock=True or CF_STT_MOCK=1 → MockSTTBackend (no GPU, no model file needed)
|
||||
backend="faster-whisper" → FasterWhisperBackend (default)
|
||||
|
||||
device and compute_type are passed through to the backend and ignored by mock.
|
||||
"""
|
||||
use_mock = mock if mock is not None else os.environ.get("CF_STT_MOCK", "") == "1"
|
||||
if use_mock:
|
||||
from circuitforge_core.stt.backends.mock import MockSTTBackend
|
||||
return MockSTTBackend(model_name=model_path)
|
||||
|
||||
resolved = backend or os.environ.get("CF_STT_BACKEND", "faster-whisper")
|
||||
if resolved == "faster-whisper":
|
||||
from circuitforge_core.stt.backends.faster_whisper import FasterWhisperBackend
|
||||
return FasterWhisperBackend(
|
||||
model_path=model_path, device=device, compute_type=compute_type
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown STT backend {resolved!r}. "
|
||||
"Expected 'faster-whisper'. Set CF_STT_BACKEND or pass backend= explicitly."
|
||||
)
|
||||
|
|
@ -1,139 +0,0 @@
|
|||
# circuitforge_core/stt/backends/faster_whisper.py — FasterWhisperBackend
|
||||
#
|
||||
# MIT licensed. Requires: pip install -e "circuitforge-core[stt-faster-whisper]"
|
||||
#
|
||||
# Model path can be:
|
||||
# - A size name: "base", "small", "medium", "large-v3"
|
||||
# (faster-whisper downloads and caches it on first use)
|
||||
# - A local path: "/Library/Assets/LLM/whisper/models/Whisper/faster-whisper/..."
|
||||
# (preferred for air-gapped nodes — no download needed)
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from circuitforge_core.stt.backends.base import STTResult, STTSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# VRAM estimates by model size. Used by cf-orch for VRAM budgeting.
|
||||
_VRAM_MB_BY_SIZE: dict[str, int] = {
|
||||
"tiny": 200,
|
||||
"base": 350,
|
||||
"small": 600,
|
||||
"medium": 1024,
|
||||
"large": 2048,
|
||||
"large-v2": 2048,
|
||||
"large-v3": 2048,
|
||||
"distil-large-v3": 1500,
|
||||
}
|
||||
|
||||
# Aggregate confidence from per-segment no_speech_prob values.
|
||||
# faster-whisper doesn't expose a direct confidence score, so we invert the
|
||||
# mean no_speech_prob as a proxy. This is conservative but directionally correct.
|
||||
def _aggregate_confidence(segments: list) -> float:
|
||||
if not segments:
|
||||
return 0.0
|
||||
probs = [max(0.0, 1.0 - getattr(s, "no_speech_prob", 0.0)) for s in segments]
|
||||
return sum(probs) / len(probs)
|
||||
|
||||
|
||||
class FasterWhisperBackend:
|
||||
"""
|
||||
faster-whisper STT backend.
|
||||
|
||||
Thread-safe after construction: WhisperModel internally manages its own
|
||||
CUDA context and is safe to call from multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
compute_type: str = "float16",
|
||||
) -> None:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"faster-whisper is not installed. "
|
||||
"Run: pip install -e 'circuitforge-core[stt-faster-whisper]'"
|
||||
) from exc
|
||||
|
||||
logger.info("Loading faster-whisper model from %r (device=%s)", model_path, device)
|
||||
self._model_path = model_path
|
||||
self._device = device
|
||||
self._compute_type = compute_type
|
||||
self._model = WhisperModel(model_path, device=device, compute_type=compute_type)
|
||||
logger.info("faster-whisper model ready")
|
||||
|
||||
# Determine VRAM footprint from model name/path stem.
|
||||
stem = os.path.basename(model_path.rstrip("/")).lower()
|
||||
self._vram_mb = next(
|
||||
(v for k, v in _VRAM_MB_BY_SIZE.items() if k in stem),
|
||||
1024, # conservative default if size can't be inferred
|
||||
)
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: bytes,
|
||||
*,
|
||||
language: str | None = None,
|
||||
confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
|
||||
) -> STTResult:
|
||||
"""
|
||||
Transcribe raw audio bytes.
|
||||
|
||||
audio can be any format ffmpeg understands (WAV, MP3, OGG, FLAC, etc.).
|
||||
faster-whisper writes audio to a temp file internally; we follow the
|
||||
same pattern to avoid holding the bytes in memory longer than needed.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp:
|
||||
tmp.write(audio)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
segments_gen, info = self._model.transcribe(
|
||||
tmp_path,
|
||||
language=language,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
)
|
||||
segments = list(segments_gen)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
text = " ".join(s.text.strip() for s in segments).strip()
|
||||
confidence = _aggregate_confidence(segments)
|
||||
duration_s = info.duration if hasattr(info, "duration") else None
|
||||
detected_language = getattr(info, "language", language)
|
||||
|
||||
stt_segments = [
|
||||
STTSegment(
|
||||
start_s=s.start,
|
||||
end_s=s.end,
|
||||
text=s.text.strip(),
|
||||
confidence=max(0.0, 1.0 - getattr(s, "no_speech_prob", 0.0)),
|
||||
)
|
||||
for s in segments
|
||||
]
|
||||
|
||||
return STTResult(
|
||||
text=text,
|
||||
confidence=confidence,
|
||||
below_threshold=confidence < confidence_threshold,
|
||||
language=detected_language,
|
||||
duration_s=duration_s,
|
||||
segments=stt_segments,
|
||||
model=self._model_path,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_path
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return self._vram_mb
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
# circuitforge_core/stt/backends/mock.py — MockSTTBackend
|
||||
#
|
||||
# MIT licensed. No GPU, no model file required.
|
||||
# Used in tests and CI, and when CF_STT_MOCK=1.
|
||||
from __future__ import annotations
|
||||
|
||||
from circuitforge_core.stt.backends.base import STTBackend, STTResult
|
||||
|
||||
|
||||
class MockSTTBackend:
|
||||
"""
|
||||
Deterministic mock STT backend for testing.
|
||||
|
||||
Returns a fixed transcript so tests can assert on the response shape
|
||||
without needing a GPU or a model file.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "mock",
|
||||
fixed_text: str = "mock transcription",
|
||||
fixed_confidence: float = 0.95,
|
||||
) -> None:
|
||||
self._model_name = model_name
|
||||
self._fixed_text = fixed_text
|
||||
self._fixed_confidence = fixed_confidence
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: bytes,
|
||||
*,
|
||||
language: str | None = None,
|
||||
confidence_threshold: float = STTResult.CONFIDENCE_DEFAULT_THRESHOLD,
|
||||
) -> STTResult:
|
||||
return STTResult(
|
||||
text=self._fixed_text,
|
||||
confidence=self._fixed_confidence,
|
||||
below_threshold=self._fixed_confidence < confidence_threshold,
|
||||
language=language or "en",
|
||||
duration_s=float(len(audio)) / 32000, # rough estimate: 16kHz 16-bit mono
|
||||
model=self._model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
# Satisfy the Protocol at import time (no GPU needed)
|
||||
assert isinstance(MockSTTBackend(), STTBackend)
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.text — direct text generation service module.
|
||||
|
||||
Provides lightweight, low-overhead text generation that bypasses ollama/vllm
|
||||
for products that need fast, frequent inference from small local models.
|
||||
|
||||
Quick start (mock mode — no model required):
|
||||
|
||||
import os; os.environ["CF_TEXT_MOCK"] = "1"
|
||||
from circuitforge_core.text import generate, chat, ChatMessage
|
||||
|
||||
result = generate("Write a short cover letter intro.")
|
||||
print(result.text)
|
||||
|
||||
reply = chat([
|
||||
ChatMessage("system", "You are a helpful recipe assistant."),
|
||||
ChatMessage("user", "What can I make with eggs, spinach, and feta?"),
|
||||
])
|
||||
print(reply.text)
|
||||
|
||||
Real inference (GGUF model):
|
||||
|
||||
export CF_TEXT_MODEL=/Library/Assets/LLM/qwen2.5-3b-instruct-q4_k_m.gguf
|
||||
from circuitforge_core.text import generate
|
||||
result = generate("Summarise this job posting in 2 sentences: ...")
|
||||
|
||||
Backend selection (CF_TEXT_BACKEND env or explicit):
|
||||
|
||||
from circuitforge_core.text import make_backend
|
||||
backend = make_backend("/path/to/model.gguf", backend="llamacpp")
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-text
|
||||
max_mb: per-model (3B Q4 ≈ 2048, 7B Q4 ≈ 4096)
|
||||
preferred_compute: 7.5 minimum (INT8 tensor cores)
|
||||
max_concurrent: 2
|
||||
shared: true
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.text.backends.base import (
|
||||
ChatMessage,
|
||||
GenerateResult,
|
||||
TextBackend,
|
||||
make_text_backend,
|
||||
)
|
||||
from circuitforge_core.text.backends.mock import MockTextBackend
|
||||
|
||||
# ── Process-level singleton backend ──────────────────────────────────────────
|
||||
# Lazily initialised on first call to generate() or chat().
|
||||
# Products that need per-user or per-request backends should use make_backend().
|
||||
|
||||
_backend: TextBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> TextBackend:
|
||||
global _backend
|
||||
if _backend is None:
|
||||
model_path = os.environ.get("CF_TEXT_MODEL", "mock")
|
||||
mock = model_path == "mock" or os.environ.get("CF_TEXT_MOCK", "") == "1"
|
||||
_backend = make_text_backend(model_path, mock=mock)
|
||||
return _backend
|
||||
|
||||
|
||||
def make_backend(
|
||||
model_path: str,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
) -> TextBackend:
|
||||
"""
|
||||
Create a TextBackend for the given model.
|
||||
|
||||
Use this when you need a dedicated backend per request or per user,
|
||||
rather than the process-level singleton used by generate() and chat().
|
||||
"""
|
||||
return make_text_backend(model_path, backend=backend, mock=mock)
|
||||
|
||||
|
||||
# ── Convenience functions (singleton path) ────────────────────────────────────
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
stop: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Generate text from a prompt using the process-level backend.
|
||||
|
||||
stream=True returns an Iterator[str] of tokens instead of GenerateResult.
|
||||
model is accepted for API symmetry with LLMRouter but ignored by the
|
||||
singleton path — set CF_TEXT_MODEL to change the loaded model.
|
||||
"""
|
||||
backend = _get_backend()
|
||||
if stream:
|
||||
return backend.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
|
||||
return backend.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
|
||||
|
||||
|
||||
def chat(
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> GenerateResult:
|
||||
"""
|
||||
Chat completion using the process-level backend.
|
||||
|
||||
messages should be a list of ChatMessage(role, content) objects.
|
||||
stream=True is not yet supported on the chat path; pass stream=False.
|
||||
"""
|
||||
if stream:
|
||||
raise NotImplementedError(
|
||||
"stream=True is not yet supported for chat(). "
|
||||
"Use generate_stream() directly on a backend instance."
|
||||
)
|
||||
return _get_backend().chat(messages, max_tokens=max_tokens, temperature=temperature)
|
||||
|
||||
|
||||
def reset_backend() -> None:
|
||||
"""Reset the process-level singleton. Test teardown only."""
|
||||
global _backend
|
||||
_backend = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChatMessage",
|
||||
"GenerateResult",
|
||||
"TextBackend",
|
||||
"MockTextBackend",
|
||||
"make_backend",
|
||||
"generate",
|
||||
"chat",
|
||||
"reset_backend",
|
||||
]
|
||||
|
|
@ -1,251 +0,0 @@
|
|||
"""
|
||||
cf-text FastAPI service — managed by cf-orch.
|
||||
|
||||
Lightweight local text generation. Supports GGUF models via llama.cpp and
|
||||
HuggingFace transformers. Sits alongside vllm/ollama for products that need
|
||||
fast, frequent inference from small local models (3B–7B Q4).
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": str, "vram_mb": int, "backend": str}
|
||||
POST /generate → GenerateResponse
|
||||
POST /chat → GenerateResponse
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.text.app \
|
||||
--model /Library/Assets/LLM/qwen2.5-3b-instruct-q4_k_m.gguf \
|
||||
--port 8006 \
|
||||
--gpu-id 0
|
||||
|
||||
Multi-GPU (spans two GPUs via CUDA_VISIBLE_DEVICES, device_map=auto):
|
||||
python -m circuitforge_core.text.app \
|
||||
--model /Library/Assets/LLM/deepseek-14b \
|
||||
--port 8006 \
|
||||
--gpu-ids 0,1
|
||||
|
||||
Mock mode (no model or GPU required):
|
||||
CF_TEXT_MOCK=1 python -m circuitforge_core.text.app --port 8006
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage as BackendChatMessage
|
||||
from circuitforge_core.text.backends.base import make_text_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_backend = None
|
||||
|
||||
|
||||
# ── Request / response models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
stop: list[str] | None = None
|
||||
|
||||
|
||||
class ChatMessageModel(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: list[ChatMessageModel]
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
text: str
|
||||
tokens_used: int = 0
|
||||
model: str = ""
|
||||
|
||||
|
||||
# ── OpenAI-compat request / response (for LLMRouter openai_compat path) ──────
|
||||
|
||||
|
||||
class OAIMessageModel(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class OAIChatRequest(BaseModel):
|
||||
model: str = "cf-text"
|
||||
messages: list[OAIMessageModel]
|
||||
max_tokens: int | None = None
|
||||
temperature: float = 0.7
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class OAIChoice(BaseModel):
|
||||
index: int = 0
|
||||
message: OAIMessageModel
|
||||
finish_reason: str = "stop"
|
||||
|
||||
|
||||
class OAIUsage(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class OAIChatResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: list[OAIChoice]
|
||||
usage: OAIUsage
|
||||
|
||||
|
||||
# ── App factory ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_app(
|
||||
model_path: str,
|
||||
gpu_id: int = 0,
|
||||
gpu_ids: str | None = None,
|
||||
backend: str | None = None,
|
||||
mock: bool = False,
|
||||
) -> FastAPI:
|
||||
"""Start the cf-text FastAPI app.
|
||||
|
||||
``gpu_ids``: comma-separated CUDA device indices for multi-GPU spanning
|
||||
(e.g. "0,1"). When set, overrides ``gpu_id`` and sets
|
||||
``CUDA_VISIBLE_DEVICES`` to the full list so HuggingFace Accelerate's
|
||||
``device_map="auto"`` can shard the model across all listed devices.
|
||||
"""
|
||||
global _backend
|
||||
|
||||
if not mock and not model_path:
|
||||
raise ValueError(
|
||||
"cf-text: --model is required (got empty string). "
|
||||
"Pass a GGUF path, a HuggingFace model ID, or set CF_TEXT_MOCK=1 for mock mode."
|
||||
)
|
||||
|
||||
visible = gpu_ids if gpu_ids else str(gpu_id)
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible)
|
||||
|
||||
_backend = make_text_backend(model_path, backend=backend, mock=mock)
|
||||
logger.info("cf-text ready: model=%r vram=%dMB", _backend.model_name, _backend.vram_mb)
|
||||
|
||||
app = FastAPI(title="cf-text", version="0.1.0")
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
return {
|
||||
"status": "ok",
|
||||
"model": _backend.model_name,
|
||||
"vram_mb": _backend.vram_mb,
|
||||
}
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(req: GenerateRequest) -> GenerateResponse:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
result = await _backend.generate_async(
|
||||
req.prompt,
|
||||
max_tokens=req.max_tokens,
|
||||
temperature=req.temperature,
|
||||
stop=req.stop,
|
||||
)
|
||||
return GenerateResponse(
|
||||
text=result.text,
|
||||
tokens_used=result.tokens_used,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
@app.post("/chat")
|
||||
async def chat(req: ChatRequest) -> GenerateResponse:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
|
||||
# chat() is sync-only in the Protocol; run in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
partial(_backend.chat, messages,
|
||||
max_tokens=req.max_tokens, temperature=req.temperature),
|
||||
)
|
||||
return GenerateResponse(
|
||||
text=result.text,
|
||||
tokens_used=result.tokens_used,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def oai_chat_completions(req: OAIChatRequest) -> OAIChatResponse:
|
||||
"""OpenAI-compatible chat completions endpoint.
|
||||
|
||||
Allows LLMRouter (and any openai_compat client) to use cf-text
|
||||
without a custom backend type — just set base_url to this service's
|
||||
/v1 prefix.
|
||||
"""
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
messages = [BackendChatMessage(m.role, m.content) for m in req.messages]
|
||||
max_tok = req.max_tokens or 512
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
partial(_backend.chat, messages, max_tokens=max_tok, temperature=req.temperature),
|
||||
)
|
||||
return OAIChatResponse(
|
||||
id=f"cftext-{uuid.uuid4().hex[:12]}",
|
||||
created=int(time.time()),
|
||||
model=result.model or req.model,
|
||||
choices=[OAIChoice(message=OAIMessageModel(role="assistant", content=result.text))],
|
||||
usage=OAIUsage(completion_tokens=result.tokens_used, total_tokens=result.tokens_used),
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ── CLI entrypoint ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="cf-text inference server")
|
||||
parser.add_argument("--model", default=os.environ.get("CF_TEXT_MODEL", "mock"),
|
||||
help="Path to GGUF file or HF model ID")
|
||||
parser.add_argument("--port", type=int, default=8006)
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--gpu-id", type=int, default=0,
|
||||
help="CUDA device index to use (single GPU)")
|
||||
parser.add_argument("--gpu-ids", default=None,
|
||||
help="Comma-separated CUDA device indices for multi-GPU spanning "
|
||||
"(e.g. '0,1'). Overrides --gpu-id when set.")
|
||||
parser.add_argument("--backend", choices=["llamacpp", "transformers"], default=None)
|
||||
parser.add_argument("--mock", action="store_true",
|
||||
help="Run in mock mode (no model or GPU needed)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s — %(message)s")
|
||||
args = _parse_args()
|
||||
mock = args.mock or os.environ.get("CF_TEXT_MOCK", "") == "1" or args.model == "mock"
|
||||
app = create_app(
|
||||
model_path=args.model,
|
||||
gpu_id=args.gpu_id,
|
||||
gpu_ids=args.gpu_ids,
|
||||
backend=args.backend,
|
||||
mock=mock,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
from .base import ChatMessage, GenerateResult, TextBackend, make_text_backend
|
||||
from .mock import MockTextBackend
|
||||
|
||||
__all__ = [
|
||||
"ChatMessage",
|
||||
"GenerateResult",
|
||||
"TextBackend",
|
||||
"MockTextBackend",
|
||||
"make_text_backend",
|
||||
]
|
||||
|
|
@ -1,198 +0,0 @@
|
|||
# circuitforge_core/text/backends/base.py — TextBackend Protocol + factory
|
||||
#
|
||||
# MIT licensed. The Protocol and mock backend are always importable.
|
||||
# Real backends (LlamaCppBackend, TransformersBackend) require optional extras.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import AsyncIterator, Iterator, Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ── Shared result types ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Result from a single non-streaming generate() call."""
|
||||
|
||||
def __init__(self, text: str, tokens_used: int = 0, model: str = "") -> None:
|
||||
self.text = text
|
||||
self.tokens_used = tokens_used
|
||||
self.model = model
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"GenerateResult(text={self.text!r:.40}, tokens={self.tokens_used})"
|
||||
|
||||
|
||||
class ChatMessage:
|
||||
"""A single message in a chat conversation."""
|
||||
|
||||
def __init__(self, role: str, content: str) -> None:
|
||||
if role not in ("system", "user", "assistant"):
|
||||
raise ValueError(f"Invalid role {role!r}. Must be system, user, or assistant.")
|
||||
self.role = role
|
||||
self.content = content
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
# ── TextBackend Protocol ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TextBackend(Protocol):
|
||||
"""
|
||||
Abstract interface for direct text generation backends.
|
||||
|
||||
All generate/chat methods have both sync and async variants.
|
||||
Streaming variants yield str tokens rather than a complete result.
|
||||
|
||||
Implementations must be safe to construct once and call concurrently
|
||||
(the model is loaded at construction time and reused across calls).
|
||||
"""
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
"""Synchronous generate — blocks until the full response is produced."""
|
||||
...
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
"""Synchronous streaming — yields tokens as they are produced."""
|
||||
...
|
||||
|
||||
async def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
"""Async generate — runs in thread pool, never blocks the event loop."""
|
||||
...
|
||||
|
||||
async def generate_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Async streaming — yields tokens without blocking the event loop."""
|
||||
...
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> GenerateResult:
|
||||
"""Chat completion — formats messages into a prompt and generates."""
|
||||
...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Identifier for the loaded model (path stem or HF repo ID)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
"""Approximate VRAM footprint in MB. Used by cf-orch service registry."""
|
||||
...
|
||||
|
||||
|
||||
# ── Backend selection ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _select_backend(model_path: str, backend: str | None) -> str:
|
||||
"""
|
||||
Return "llamacpp", "transformers", "ollama", or "vllm" for the given model path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path Path to the model file, HuggingFace repo ID, "ollama://<name>",
|
||||
or "vllm://<model-id>".
|
||||
backend Explicit override from the caller
|
||||
("llamacpp" | "transformers" | "ollama" | "vllm" | None).
|
||||
When provided, trust it without inspection.
|
||||
|
||||
Raise ValueError for unrecognised override values.
|
||||
"""
|
||||
_VALID = ("llamacpp", "transformers", "ollama", "vllm")
|
||||
|
||||
# 1. Caller-supplied override — highest trust, no inspection needed.
|
||||
resolved = backend or os.environ.get("CF_TEXT_BACKEND")
|
||||
if resolved:
|
||||
if resolved not in _VALID:
|
||||
raise ValueError(
|
||||
f"CF_TEXT_BACKEND={resolved!r} is not valid. Choose: {', '.join(_VALID)}"
|
||||
)
|
||||
return resolved
|
||||
|
||||
# 2. Proxy prefixes — unambiguous routing regardless of model name format.
|
||||
if model_path.startswith("ollama://"):
|
||||
return "ollama"
|
||||
if model_path.startswith("vllm://"):
|
||||
return "vllm"
|
||||
|
||||
# 3. Format detection — GGUF files are unambiguously llama-cpp territory.
|
||||
if model_path.lower().endswith(".gguf"):
|
||||
return "llamacpp"
|
||||
|
||||
# 4. Safe default — transformers covers HF repo IDs and safetensors dirs.
|
||||
return "transformers"
|
||||
|
||||
|
||||
# ── Factory ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_text_backend(
|
||||
model_path: str,
|
||||
backend: str | None = None,
|
||||
mock: bool | None = None,
|
||||
) -> "TextBackend":
|
||||
"""
|
||||
Return a TextBackend for the given model.
|
||||
|
||||
mock=True or CF_TEXT_MOCK=1 → MockTextBackend (no GPU, no model file needed)
|
||||
Otherwise → backend resolved via _select_backend()
|
||||
"""
|
||||
use_mock = mock if mock is not None else os.environ.get("CF_TEXT_MOCK", "") == "1"
|
||||
if use_mock:
|
||||
from circuitforge_core.text.backends.mock import MockTextBackend
|
||||
return MockTextBackend(model_name=model_path)
|
||||
|
||||
resolved = _select_backend(model_path, backend)
|
||||
|
||||
if resolved == "llamacpp":
|
||||
from circuitforge_core.text.backends.llamacpp import LlamaCppBackend
|
||||
return LlamaCppBackend(model_path=model_path)
|
||||
|
||||
if resolved == "transformers":
|
||||
from circuitforge_core.text.backends.transformers import TransformersBackend
|
||||
return TransformersBackend(model_path=model_path)
|
||||
|
||||
if resolved == "ollama":
|
||||
from circuitforge_core.text.backends.ollama import OllamaBackend
|
||||
return OllamaBackend(model_path=model_path)
|
||||
|
||||
if resolved == "vllm":
|
||||
from circuitforge_core.text.backends.vllm import VllmBackend
|
||||
return VllmBackend(model_path=model_path)
|
||||
|
||||
raise ValueError(f"Unknown backend {resolved!r}. Expected 'llamacpp', 'transformers', 'ollama', or 'vllm'.")
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
# circuitforge_core/text/backends/llamacpp.py — llama-cpp-python backend
|
||||
#
|
||||
# BSL 1.1: real inference. Requires llama-cpp-python + a GGUF model file.
|
||||
# Install: pip install circuitforge-core[text-llamacpp]
|
||||
#
|
||||
# VRAM estimates (Q4_K_M quant):
|
||||
# 1B → ~700MB 3B → ~2048MB 7B → ~4096MB
|
||||
# 13B → ~7500MB 70B → ~40000MB
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Q4_K_M is the recommended default — best accuracy/size tradeoff for local use.
|
||||
_DEFAULT_N_CTX = int(os.environ.get("CF_TEXT_CTX", "4096"))
|
||||
_DEFAULT_N_GPU_LAYERS = int(os.environ.get("CF_TEXT_GPU_LAYERS", "-1")) # -1 = all layers
|
||||
|
||||
|
||||
def _estimate_vram_mb(model_path: str) -> int:
|
||||
"""Rough VRAM estimate from file size. Accurate enough for cf-orch budgeting."""
|
||||
try:
|
||||
size_mb = Path(model_path).stat().st_size // (1024 * 1024)
|
||||
# GGUF models typically need ~1.1× file size in VRAM (KV cache overhead)
|
||||
return int(size_mb * 1.1)
|
||||
except OSError:
|
||||
return 4096 # conservative default
|
||||
|
||||
|
||||
class LlamaCppBackend:
|
||||
"""
|
||||
Direct llama-cpp-python inference backend for GGUF models.
|
||||
|
||||
The model is loaded once at construction. All inference runs in a thread
|
||||
pool executor so async callers never block the event loop.
|
||||
|
||||
Context window, GPU layers, and thread count are configurable via env:
|
||||
CF_TEXT_CTX token context window (default 4096)
|
||||
CF_TEXT_GPU_LAYERS GPU layers to offload, -1 = all (default -1)
|
||||
CF_TEXT_THREADS CPU thread count (default: auto)
|
||||
|
||||
Requires: pip install circuitforge-core[text-llamacpp]
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str) -> None:
|
||||
try:
|
||||
from llama_cpp import Llama # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"llama-cpp-python is required for LlamaCppBackend. "
|
||||
"Install with: pip install circuitforge-core[text-llamacpp]"
|
||||
) from exc
|
||||
|
||||
if not Path(model_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"GGUF model not found: {model_path}\n"
|
||||
"Download a GGUF model and set CF_TEXT_MODEL to its path."
|
||||
)
|
||||
|
||||
n_threads = int(os.environ.get("CF_TEXT_THREADS", "0")) or None
|
||||
logger.info(
|
||||
"Loading GGUF model %s (ctx=%d, gpu_layers=%d)",
|
||||
model_path, _DEFAULT_N_CTX, _DEFAULT_N_GPU_LAYERS,
|
||||
)
|
||||
self._llm = Llama(
|
||||
model_path=model_path,
|
||||
n_ctx=_DEFAULT_N_CTX,
|
||||
n_gpu_layers=_DEFAULT_N_GPU_LAYERS,
|
||||
n_threads=n_threads,
|
||||
verbose=False,
|
||||
)
|
||||
self._model_path = model_path
|
||||
self._vram_mb = _estimate_vram_mb(model_path)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return Path(self._model_path).stem
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return self._vram_mb
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
output = self._llm(
|
||||
prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop or [],
|
||||
stream=False,
|
||||
)
|
||||
text = output["choices"][0]["text"]
|
||||
tokens_used = output["usage"]["completion_tokens"]
|
||||
return GenerateResult(text=text, tokens_used=tokens_used, model=self.model_name)
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
for chunk in self._llm(
|
||||
prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop or [],
|
||||
stream=True,
|
||||
):
|
||||
yield chunk["choices"][0]["text"]
|
||||
|
||||
async def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop),
|
||||
)
|
||||
|
||||
async def generate_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
# llama_cpp streaming is synchronous — run in executor and re-emit tokens
|
||||
import queue
|
||||
import threading
|
||||
|
||||
token_queue: queue.Queue = queue.Queue()
|
||||
_DONE = object()
|
||||
|
||||
def _produce() -> None:
|
||||
try:
|
||||
for chunk in self._llm(
|
||||
prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop or [],
|
||||
stream=True,
|
||||
):
|
||||
token_queue.put(chunk["choices"][0]["text"])
|
||||
finally:
|
||||
token_queue.put(_DONE)
|
||||
|
||||
thread = threading.Thread(target=_produce, daemon=True)
|
||||
thread.start()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
token = await loop.run_in_executor(None, token_queue.get)
|
||||
if token is _DONE:
|
||||
break
|
||||
yield token
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> GenerateResult:
|
||||
# llama-cpp-python has native chat_completion for instruct models
|
||||
output = self._llm.create_chat_completion(
|
||||
messages=[m.to_dict() for m in messages],
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
text = output["choices"][0]["message"]["content"]
|
||||
tokens_used = output["usage"]["completion_tokens"]
|
||||
return GenerateResult(text=text, tokens_used=tokens_used, model=self.model_name)
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
# circuitforge_core/text/backends/mock.py — synthetic text backend
|
||||
#
|
||||
# MIT licensed. No model file, no GPU, no extras required.
|
||||
# Used in dev, CI, and free-tier nodes below the minimum VRAM threshold.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
_MOCK_RESPONSE = (
|
||||
"This is a synthetic response from MockTextBackend. "
|
||||
"Install a real backend (llama-cpp-python or transformers) and provide a model path "
|
||||
"to generate real text."
|
||||
)
|
||||
|
||||
|
||||
class MockTextBackend:
|
||||
"""
|
||||
Deterministic synthetic text backend for development and CI.
|
||||
|
||||
Always returns the same fixed response so tests are reproducible without
|
||||
a GPU or model file. Streaming emits the response word-by-word with a
|
||||
configurable delay so UI streaming paths can be exercised.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "mock",
|
||||
token_delay_s: float = 0.0,
|
||||
) -> None:
|
||||
self._model_name = model_name
|
||||
self._token_delay_s = token_delay_s
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return 0
|
||||
|
||||
def _response_for(self, prompt_or_messages: str) -> str:
|
||||
return _MOCK_RESPONSE
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
text = self._response_for(prompt)
|
||||
return GenerateResult(text=text, tokens_used=len(text.split()), model=self._model_name)
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
import time
|
||||
for word in self._response_for(prompt).split():
|
||||
yield word + " "
|
||||
if self._token_delay_s:
|
||||
time.sleep(self._token_delay_s)
|
||||
|
||||
async def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
|
||||
|
||||
async def generate_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
for word in self._response_for(prompt).split():
|
||||
yield word + " "
|
||||
if self._token_delay_s:
|
||||
await asyncio.sleep(self._token_delay_s)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> GenerateResult:
|
||||
# Format messages into a simple prompt for the mock response
|
||||
prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
|
||||
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
||||
|
|
@ -1,201 +0,0 @@
|
|||
# circuitforge_core/text/backends/ollama.py — Ollama proxy backend for cf-text
|
||||
#
|
||||
# Routes inference requests to a running Ollama instance via its HTTP API.
|
||||
# cf-text itself holds no GPU memory; Ollama manages the model and VRAM.
|
||||
#
|
||||
# Model path format: "ollama://<model-name>" e.g. "ollama://llama3.1:8b"
|
||||
# The "ollama://" prefix is stripped before forwarding to the API.
|
||||
#
|
||||
# Environment:
|
||||
# CF_TEXT_OLLAMA_URL Base URL of the Ollama server (default: http://localhost:11434)
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
import httpx
|
||||
|
||||
from circuitforge_core.text.backends.base import GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaBackend:
|
||||
"""
|
||||
cf-text backend that proxies inference to a local Ollama instance.
|
||||
|
||||
This backend holds no GPU memory itself — Ollama owns the model and VRAM.
|
||||
vram_mb is therefore reported as 0 so cf-orch does not double-count VRAM
|
||||
against the separate ollama service budget.
|
||||
|
||||
Supports /generate, /chat, and /v1/chat/completions (via generate/chat).
|
||||
Streaming is implemented for all variants.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, *, vram_mb: int = 0) -> None:
|
||||
# Strip the "ollama://" prefix from catalog paths
|
||||
self._model = model_path.removeprefix("ollama://")
|
||||
self._url = os.environ.get("CF_TEXT_OLLAMA_URL", _DEFAULT_OLLAMA_URL).rstrip("/")
|
||||
self._vram_mb = vram_mb
|
||||
logger.info("OllamaBackend: model=%r url=%r", self._model, self._url)
|
||||
|
||||
# ── Protocol properties ───────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
# Ollama manages its own VRAM; cf-text holds nothing.
|
||||
return self._vram_mb
|
||||
|
||||
# ── Synchronous interface ─────────────────────────────────────────────────
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
t0 = time.monotonic()
|
||||
payload: dict = {
|
||||
"model": self._model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature, "num_predict": max_tokens},
|
||||
}
|
||||
if stop:
|
||||
payload["options"]["stop"] = stop
|
||||
with httpx.Client(timeout=180.0) as client:
|
||||
resp = client.post(f"{self._url}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
elapsed_ms = round((time.monotonic() - t0) * 1000)
|
||||
return GenerateResult(
|
||||
text=data.get("response", ""),
|
||||
tokens_used=data.get("eval_count", 0),
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
payload: dict = {
|
||||
"model": self._model,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
"options": {"temperature": temperature, "num_predict": max_tokens},
|
||||
}
|
||||
if stop:
|
||||
payload["options"]["stop"] = stop
|
||||
with httpx.Client(timeout=180.0) as client:
|
||||
with client.stream("POST", f"{self._url}/api/generate", json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
chunk = _json.loads(line)
|
||||
token = chunk.get("response", "")
|
||||
if token:
|
||||
yield token
|
||||
if chunk.get("done"):
|
||||
break
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> GenerateResult:
|
||||
t0 = time.monotonic()
|
||||
payload: dict = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature, "num_predict": max_tokens},
|
||||
}
|
||||
with httpx.Client(timeout=180.0) as client:
|
||||
resp = client.post(f"{self._url}/api/chat", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
elapsed_ms = round((time.monotonic() - t0) * 1000)
|
||||
return GenerateResult(
|
||||
text=data.get("message", {}).get("content", ""),
|
||||
tokens_used=data.get("eval_count", 0),
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
# ── Async interface ───────────────────────────────────────────────────────
|
||||
|
||||
async def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
t0 = time.monotonic()
|
||||
payload: dict = {
|
||||
"model": self._model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature, "num_predict": max_tokens},
|
||||
}
|
||||
if stop:
|
||||
payload["options"]["stop"] = stop
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
resp = await client.post(f"{self._url}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
elapsed_ms = round((time.monotonic() - t0) * 1000)
|
||||
return GenerateResult(
|
||||
text=data.get("response", ""),
|
||||
tokens_used=data.get("eval_count", 0),
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
async def generate_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
payload: dict = {
|
||||
"model": self._model,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
"options": {"temperature": temperature, "num_predict": max_tokens},
|
||||
}
|
||||
if stop:
|
||||
payload["options"]["stop"] = stop
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
async with client.stream("POST", f"{self._url}/api/generate", json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
async for line in resp.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
chunk = _json.loads(line)
|
||||
token = chunk.get("response", "")
|
||||
if token:
|
||||
yield token
|
||||
if chunk.get("done"):
|
||||
break
|
||||
|
|
@ -1,197 +0,0 @@
|
|||
# circuitforge_core/text/backends/transformers.py — HuggingFace transformers backend
|
||||
#
|
||||
# BSL 1.1: real inference. Requires torch + transformers + a model checkpoint.
|
||||
# Install: pip install circuitforge-core[text-transformers]
|
||||
#
|
||||
# Best for: HF repo IDs, safetensors checkpoints, models without GGUF versions.
|
||||
# For GGUF models prefer LlamaCppBackend — lower overhead, smaller install.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MAX_NEW_TOKENS = 512
|
||||
_LOAD_IN_4BIT = os.environ.get("CF_TEXT_4BIT", "0") == "1"
|
||||
_LOAD_IN_8BIT = os.environ.get("CF_TEXT_8BIT", "0") == "1"
|
||||
|
||||
|
||||
class TransformersBackend:
|
||||
"""
|
||||
HuggingFace transformers inference backend.
|
||||
|
||||
Loads any causal LM available on HuggingFace Hub or a local checkpoint dir.
|
||||
Supports 4-bit and 8-bit quantization via bitsandbytes when VRAM is limited:
|
||||
CF_TEXT_4BIT=1 — load_in_4bit (requires bitsandbytes)
|
||||
CF_TEXT_8BIT=1 — load_in_8bit (requires bitsandbytes)
|
||||
|
||||
Chat completion uses the tokenizer's apply_chat_template() when available,
|
||||
falling back to a simple "User: / Assistant:" prompt format.
|
||||
|
||||
Requires: pip install circuitforge-core[text-transformers]
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str) -> None:
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"torch and transformers are required for TransformersBackend. "
|
||||
"Install with: pip install circuitforge-core[text-transformers]"
|
||||
) from exc
|
||||
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info("Loading transformers model %s on %s", model_path, self._device)
|
||||
|
||||
load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
|
||||
if _LOAD_IN_4BIT:
|
||||
load_kwargs["load_in_4bit"] = True
|
||||
elif _LOAD_IN_8BIT:
|
||||
load_kwargs["load_in_8bit"] = True
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
|
||||
if self._device == "cpu":
|
||||
self._model = self._model.to("cpu")
|
||||
|
||||
self._model_path = model_path
|
||||
self._TextIteratorStreamer = TextIteratorStreamer
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
# HF repo IDs contain "/" — use the part after the slash as a short name
|
||||
return self._model_path.split("/")[-1]
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.memory_allocated() // (1024 * 1024)
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
def _build_inputs(self, prompt: str):
|
||||
return self._tokenizer(prompt, return_tensors="pt").to(self._device)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
inputs = self._build_inputs(prompt)
|
||||
input_len = inputs["input_ids"].shape[1]
|
||||
outputs = self._model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=temperature > 0,
|
||||
pad_token_id=self._tokenizer.eos_token_id,
|
||||
)
|
||||
new_tokens = outputs[0][input_len:]
|
||||
text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)
|
||||
return GenerateResult(text=text, tokens_used=len(new_tokens), model=self.model_name)
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
import threading
|
||||
|
||||
inputs = self._build_inputs(prompt)
|
||||
streamer = self._TextIteratorStreamer(
|
||||
self._tokenizer, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
gen_kwargs = dict(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=temperature > 0,
|
||||
streamer=streamer,
|
||||
pad_token_id=self._tokenizer.eos_token_id,
|
||||
)
|
||||
thread = threading.Thread(target=self._model.generate, kwargs=gen_kwargs, daemon=True)
|
||||
thread.start()
|
||||
yield from streamer
|
||||
|
||||
async def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.generate(prompt, max_tokens=max_tokens, temperature=temperature, stop=stop),
|
||||
)
|
||||
|
||||
async def generate_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
import queue
|
||||
import threading
|
||||
|
||||
token_queue: queue.Queue = queue.Queue()
|
||||
_DONE = object()
|
||||
|
||||
def _produce() -> None:
|
||||
try:
|
||||
for token in self.generate_stream(
|
||||
prompt, max_tokens=max_tokens, temperature=temperature
|
||||
):
|
||||
token_queue.put(token)
|
||||
finally:
|
||||
token_queue.put(_DONE)
|
||||
|
||||
threading.Thread(target=_produce, daemon=True).start()
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
token = await loop.run_in_executor(None, token_queue.get)
|
||||
if token is _DONE:
|
||||
break
|
||||
yield token
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> GenerateResult:
|
||||
# Use the tokenizer's chat template when available (instruct models)
|
||||
if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template:
|
||||
prompt = self._tokenizer.apply_chat_template(
|
||||
[m.to_dict() for m in messages],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
prompt = "\n".join(
|
||||
f"{'User' if m.role == 'user' else 'Assistant'}: {m.content}"
|
||||
for m in messages
|
||||
if m.role != "system"
|
||||
) + "\nAssistant:"
|
||||
|
||||
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
# circuitforge_core/text/backends/vllm.py — vllm proxy backend for cf-text
|
||||
#
|
||||
# Routes inference requests to a running vllm instance via its OpenAI-compatible
|
||||
# HTTP API (/v1/chat/completions, /v1/completions).
|
||||
# cf-text itself holds no GPU memory; vllm manages the model and VRAM.
|
||||
#
|
||||
# Model path format: "vllm://<model-id>" e.g. "vllm://Qwen/Qwen2.5-7B-Instruct"
|
||||
# The "vllm://" prefix is stripped; the remainder is the model_id sent to vllm.
|
||||
#
|
||||
# Environment:
|
||||
# CF_TEXT_VLLM_URL Base URL of the vllm server (default: http://localhost:8000)
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
import httpx
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_VLLM_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
class VllmBackend:
|
||||
"""
|
||||
cf-text backend that proxies inference to a local vllm instance.
|
||||
|
||||
vllm exposes an OpenAI-compatible API (/v1/chat/completions).
|
||||
This backend holds no GPU memory — vllm owns the model and VRAM.
|
||||
vram_mb is reported as 0 so cf-orch does not double-count VRAM
|
||||
against the separate vllm service budget.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, *, vram_mb: int = 0) -> None:
|
||||
# Strip the "vllm://" prefix from catalog paths
|
||||
self._model = model_path.removeprefix("vllm://")
|
||||
self._url = os.environ.get("CF_TEXT_VLLM_URL", _DEFAULT_VLLM_URL).rstrip("/")
|
||||
self._vram_mb = vram_mb
|
||||
logger.info("VllmBackend: model=%r url=%r", self._model, self._url)
|
||||
|
||||
# ── Protocol properties ───────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
# vllm manages its own VRAM; cf-text holds nothing.
|
||||
return self._vram_mb
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _chat_payload(
|
||||
self,
|
||||
messages: list[dict],
|
||||
*,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
stop: list[str] | None,
|
||||
stream: bool,
|
||||
) -> dict:
|
||||
payload: dict = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"stream": stream,
|
||||
}
|
||||
if stop:
|
||||
payload["stop"] = stop
|
||||
return payload
|
||||
|
||||
def _prompt_as_messages(self, prompt: str) -> list[dict]:
|
||||
return [{"role": "user", "content": prompt}]
|
||||
|
||||
# ── Synchronous interface ─────────────────────────────────────────────────
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
t0 = time.monotonic()
|
||||
payload = self._chat_payload(
|
||||
self._prompt_as_messages(prompt),
|
||||
max_tokens=max_tokens, temperature=temperature, stop=stop, stream=False,
|
||||
)
|
||||
with httpx.Client(timeout=180.0) as client:
|
||||
resp = client.post(f"{self._url}/v1/chat/completions", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return GenerateResult(
|
||||
text=data["choices"][0]["message"]["content"],
|
||||
tokens_used=data.get("usage", {}).get("completion_tokens", 0),
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
payload = self._chat_payload(
|
||||
self._prompt_as_messages(prompt),
|
||||
max_tokens=max_tokens, temperature=temperature, stop=stop, stream=True,
|
||||
)
|
||||
with httpx.Client(timeout=180.0) as client:
|
||||
with client.stream("POST", f"{self._url}/v1/chat/completions", json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
for line in resp.iter_lines():
|
||||
token = _parse_sse_token(line)
|
||||
if token:
|
||||
yield token
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> GenerateResult:
|
||||
dicts = [m.to_dict() if hasattr(m, "to_dict") else m for m in messages]
|
||||
payload = self._chat_payload(
|
||||
dicts, max_tokens=max_tokens, temperature=temperature, stop=None, stream=False,
|
||||
)
|
||||
with httpx.Client(timeout=180.0) as client:
|
||||
resp = client.post(f"{self._url}/v1/chat/completions", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return GenerateResult(
|
||||
text=data["choices"][0]["message"]["content"],
|
||||
tokens_used=data.get("usage", {}).get("completion_tokens", 0),
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
# ── Async interface ───────────────────────────────────────────────────────
|
||||
|
||||
async def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> GenerateResult:
|
||||
payload = self._chat_payload(
|
||||
self._prompt_as_messages(prompt),
|
||||
max_tokens=max_tokens, temperature=temperature, stop=stop, stream=False,
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
resp = await client.post(f"{self._url}/v1/chat/completions", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return GenerateResult(
|
||||
text=data["choices"][0]["message"]["content"],
|
||||
tokens_used=data.get("usage", {}).get("completion_tokens", 0),
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
async def generate_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
stop: list[str] | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
payload = self._chat_payload(
|
||||
self._prompt_as_messages(prompt),
|
||||
max_tokens=max_tokens, temperature=temperature, stop=stop, stream=True,
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
async with client.stream("POST", f"{self._url}/v1/chat/completions", json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
async for line in resp.aiter_lines():
|
||||
token = _parse_sse_token(line)
|
||||
if token:
|
||||
yield token
|
||||
|
||||
|
||||
# ── SSE parser (OpenAI/vllm format) ──────────────────────────────────────────
|
||||
|
||||
def _parse_sse_token(line: str) -> str:
|
||||
"""Extract content token from an OpenAI-format SSE line.
|
||||
|
||||
Lines look like: data: {"choices": [{"delta": {"content": "word"}}]}
|
||||
Terminal line: data: [DONE]
|
||||
Returns the token string, or "" for empty/done/non-data lines.
|
||||
"""
|
||||
if not line.startswith("data:"):
|
||||
return ""
|
||||
payload = line[5:].strip()
|
||||
if payload == "[DONE]":
|
||||
return ""
|
||||
try:
|
||||
chunk = _json.loads(payload)
|
||||
return chunk["choices"][0]["delta"].get("content", "") or ""
|
||||
except (KeyError, IndexError, _json.JSONDecodeError):
|
||||
return ""
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.tts — Text-to-speech service module.
|
||||
|
||||
Quick start (mock mode — no GPU or model required):
|
||||
|
||||
import os; os.environ["CF_TTS_MOCK"] = "1"
|
||||
from circuitforge_core.tts import synthesize
|
||||
|
||||
result = synthesize("Hello world")
|
||||
open("out.ogg", "wb").write(result.audio_bytes)
|
||||
|
||||
Real inference (chatterbox-turbo):
|
||||
|
||||
export CF_TTS_MODEL=/Library/Assets/LLM/chatterbox/hub/models--ResembleAI--chatterbox-turbo/snapshots/<hash>
|
||||
from circuitforge_core.tts import synthesize
|
||||
|
||||
cf-orch service profile:
|
||||
|
||||
service_type: cf-tts
|
||||
max_mb: 768
|
||||
max_concurrent: 1
|
||||
shared: true
|
||||
managed:
|
||||
exec: python -m circuitforge_core.tts.app
|
||||
args: --model <path> --port {port} --gpu-id {gpu_id}
|
||||
port: 8005
|
||||
health: /health
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from circuitforge_core.tts.backends.base import (
|
||||
AudioFormat,
|
||||
TTSBackend,
|
||||
TTSResult,
|
||||
make_tts_backend,
|
||||
)
|
||||
from circuitforge_core.tts.backends.mock import MockTTSBackend
|
||||
|
||||
_backend: TTSBackend | None = None
|
||||
|
||||
|
||||
def _get_backend() -> TTSBackend:
|
||||
global _backend
|
||||
if _backend is None:
|
||||
model_path = os.environ.get("CF_TTS_MODEL", "mock")
|
||||
mock = model_path == "mock" or os.environ.get("CF_TTS_MOCK", "") == "1"
|
||||
_backend = make_tts_backend(model_path, mock=mock)
|
||||
return _backend
|
||||
|
||||
|
||||
def synthesize(
|
||||
text: str,
|
||||
*,
|
||||
exaggeration: float = 0.5,
|
||||
cfg_weight: float = 0.5,
|
||||
temperature: float = 0.8,
|
||||
audio_prompt: bytes | None = None,
|
||||
format: AudioFormat = "ogg",
|
||||
) -> TTSResult:
|
||||
"""Synthesize speech from text using the process-level backend."""
|
||||
return _get_backend().synthesize(
|
||||
text,
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature,
|
||||
audio_prompt=audio_prompt,
|
||||
format=format,
|
||||
)
|
||||
|
||||
|
||||
def reset_backend() -> None:
|
||||
"""Reset the process-level singleton. Test teardown only."""
|
||||
global _backend
|
||||
_backend = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AudioFormat",
|
||||
"TTSBackend",
|
||||
"TTSResult",
|
||||
"MockTTSBackend",
|
||||
"make_tts_backend",
|
||||
"synthesize",
|
||||
"reset_backend",
|
||||
]
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
"""
|
||||
cf-tts FastAPI service — managed by cf-orch.
|
||||
|
||||
Endpoints:
|
||||
GET /health → {"status": "ok", "model": str, "vram_mb": int}
|
||||
POST /synthesize → audio bytes (Content-Type: audio/ogg or audio/wav or audio/mpeg)
|
||||
|
||||
Usage:
|
||||
python -m circuitforge_core.tts.app \
|
||||
--model /Library/Assets/LLM/chatterbox/hub/models--ResembleAI--chatterbox-turbo/snapshots/<hash> \
|
||||
--port 8005 \
|
||||
--gpu-id 0
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from fastapi import FastAPI, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from circuitforge_core.tts.backends.base import AudioFormat, TTSBackend, make_tts_backend
|
||||
|
||||
_CONTENT_TYPES: dict[str, str] = {
|
||||
"ogg": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
}
|
||||
|
||||
app = FastAPI(title="cf-tts")
|
||||
_backend = None # type: TTSBackend | None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
return {"status": "ok", "model": _backend.model_name, "vram_mb": _backend.vram_mb}
|
||||
|
||||
|
||||
@app.post("/synthesize")
|
||||
async def synthesize(
|
||||
text: Annotated[str, Form()],
|
||||
format: Annotated[AudioFormat, Form()] = "ogg",
|
||||
exaggeration: Annotated[float, Form()] = 0.5,
|
||||
cfg_weight: Annotated[float, Form()] = 0.5,
|
||||
temperature: Annotated[float, Form()] = 0.8,
|
||||
audio_prompt: UploadFile | None = None,
|
||||
) -> Response:
|
||||
if _backend is None:
|
||||
raise HTTPException(503, detail="backend not initialised")
|
||||
if not text.strip():
|
||||
raise HTTPException(422, detail="text must not be empty")
|
||||
|
||||
prompt_bytes: bytes | None = None
|
||||
if audio_prompt is not None:
|
||||
prompt_bytes = await audio_prompt.read()
|
||||
|
||||
result = _backend.synthesize(
|
||||
text,
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature,
|
||||
audio_prompt=prompt_bytes,
|
||||
format=format,
|
||||
)
|
||||
return Response(
|
||||
content=result.audio_bytes,
|
||||
media_type=_CONTENT_TYPES.get(result.format, "audio/ogg"),
|
||||
headers={
|
||||
"X-Duration-S": str(round(result.duration_s, 3)),
|
||||
"X-Model": result.model,
|
||||
"X-Sample-Rate": str(result.sample_rate),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="cf-tts service")
|
||||
p.add_argument("--model", required=True)
|
||||
p.add_argument("--port", type=int, default=8005)
|
||||
p.add_argument("--host", default="0.0.0.0")
|
||||
p.add_argument("--gpu-id", type=int, default=0)
|
||||
p.add_argument("--mock", action="store_true")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
args = _parse_args()
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
|
||||
|
||||
mock = args.mock or args.model == "mock"
|
||||
device = "cpu" if mock else "cuda"
|
||||
|
||||
_backend = make_tts_backend(args.model, mock=mock, device=device)
|
||||
print(f"cf-tts backend ready: {_backend.model_name} ({_backend.vram_mb} MB)")
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
from .base import AudioFormat, TTSBackend, TTSResult, make_tts_backend
|
||||
from .mock import MockTTSBackend
|
||||
|
||||
__all__ = ["AudioFormat", "TTSBackend", "TTSResult", "make_tts_backend", "MockTTSBackend"]
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
TTSBackend Protocol — backend-agnostic TTS interface.
|
||||
|
||||
All backends return TTSResult with audio bytes in the requested format.
|
||||
Supported formats: ogg (default, smallest), wav (uncompressed, always works), mp3.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
AudioFormat = Literal["ogg", "wav", "mp3"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TTSResult:
|
||||
audio_bytes: bytes
|
||||
sample_rate: int
|
||||
duration_s: float
|
||||
format: AudioFormat = "ogg"
|
||||
model: str = ""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TTSBackend(Protocol):
|
||||
def synthesize(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
exaggeration: float = 0.5,
|
||||
cfg_weight: float = 0.5,
|
||||
temperature: float = 0.8,
|
||||
audio_prompt: bytes | None = None,
|
||||
format: AudioFormat = "ogg",
|
||||
) -> TTSResult: ...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int: ...
|
||||
|
||||
|
||||
def _encode_audio(
|
||||
wav_tensor, # torch.Tensor shape [1, T] or [T]
|
||||
sample_rate: int,
|
||||
format: AudioFormat,
|
||||
) -> bytes:
|
||||
"""Convert a torch tensor to audio bytes in the requested format."""
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
wav = wav_tensor
|
||||
if wav.dim() == 1:
|
||||
wav = wav.unsqueeze(0)
|
||||
wav = wav.to(torch.float32).cpu()
|
||||
|
||||
buf = io.BytesIO()
|
||||
if format == "wav":
|
||||
torchaudio.save(buf, wav, sample_rate, format="wav")
|
||||
elif format == "ogg":
|
||||
# libvorbis may not be available on all torchaudio builds; fall back to wav
|
||||
try:
|
||||
torchaudio.save(buf, wav, sample_rate, format="ogg", encoding="vorbis")
|
||||
except Exception:
|
||||
buf = io.BytesIO()
|
||||
torchaudio.save(buf, wav, sample_rate, format="wav")
|
||||
elif format == "mp3":
|
||||
# torchaudio MP3 encode requires ffmpeg backend; fall back to wav on failure
|
||||
try:
|
||||
torchaudio.save(buf, wav, sample_rate, format="mp3")
|
||||
except Exception:
|
||||
buf = io.BytesIO()
|
||||
torchaudio.save(buf, wav, sample_rate, format="wav")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def make_tts_backend(
|
||||
model_path: str,
|
||||
*,
|
||||
mock: bool = False,
|
||||
device: str = "cuda",
|
||||
) -> TTSBackend:
|
||||
if mock:
|
||||
from circuitforge_core.tts.backends.mock import MockTTSBackend
|
||||
return MockTTSBackend()
|
||||
from circuitforge_core.tts.backends.chatterbox import ChatterboxTurboBackend
|
||||
return ChatterboxTurboBackend(model_path=model_path, device=device)
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
"""ChatterboxTurboBackend — ResembleAI chatterbox-turbo TTS via chatterbox-tts package."""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from circuitforge_core.tts.backends.base import (
|
||||
AudioFormat,
|
||||
TTSBackend,
|
||||
TTSResult,
|
||||
_encode_audio,
|
||||
)
|
||||
|
||||
_VRAM_MB = 768 # conservative estimate for chatterbox-turbo weights
|
||||
|
||||
|
||||
class ChatterboxTurboBackend:
|
||||
def __init__(self, model_path: str, device: str = "cuda") -> None:
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
|
||||
from chatterbox.models.s3gen import S3GEN_SR
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
|
||||
self._sr = S3GEN_SR
|
||||
self._device = device
|
||||
self._model = ChatterboxTTS.from_local(model_path, device=device)
|
||||
self._model_path = model_path
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return f"chatterbox-turbo@{os.path.basename(self._model_path)}"
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return _VRAM_MB
|
||||
|
||||
def synthesize(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
exaggeration: float = 0.5,
|
||||
cfg_weight: float = 0.5,
|
||||
temperature: float = 0.8,
|
||||
audio_prompt: bytes | None = None,
|
||||
format: AudioFormat = "ogg",
|
||||
) -> TTSResult:
|
||||
audio_prompt_path: str | None = None
|
||||
_tmp = None
|
||||
|
||||
if audio_prompt is not None:
|
||||
_tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
_tmp.write(audio_prompt)
|
||||
_tmp.flush()
|
||||
audio_prompt_path = _tmp.name
|
||||
|
||||
try:
|
||||
wav = self._model.generate(
|
||||
text,
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature,
|
||||
audio_prompt_path=audio_prompt_path,
|
||||
)
|
||||
finally:
|
||||
if _tmp is not None:
|
||||
_tmp.close()
|
||||
os.unlink(_tmp.name)
|
||||
|
||||
duration_s = wav.shape[-1] / self._sr
|
||||
audio_bytes = _encode_audio(wav, self._sr, format)
|
||||
return TTSResult(
|
||||
audio_bytes=audio_bytes,
|
||||
sample_rate=self._sr,
|
||||
duration_s=duration_s,
|
||||
format=format,
|
||||
model=self.model_name,
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(
|
||||
ChatterboxTurboBackend.__new__(ChatterboxTurboBackend), TTSBackend
|
||||
), "ChatterboxTurboBackend must satisfy TTSBackend Protocol"
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
"""MockTTSBackend — no GPU, no model required. Returns a silent WAV clip."""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from circuitforge_core.tts.backends.base import AudioFormat, TTSBackend, TTSResult
|
||||
|
||||
_SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
def _silent_wav(duration_s: float = 0.5, sample_rate: int = _SAMPLE_RATE) -> bytes:
|
||||
num_samples = int(duration_s * sample_rate)
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as w:
|
||||
w.setnchannels(1)
|
||||
w.setsampwidth(2)
|
||||
w.setframerate(sample_rate)
|
||||
w.writeframes(struct.pack(f"<{num_samples}h", *([0] * num_samples)))
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
class MockTTSBackend:
|
||||
"""Minimal TTSBackend implementation for tests and CI."""
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return "mock-tts"
|
||||
|
||||
@property
|
||||
def vram_mb(self) -> int:
|
||||
return 0
|
||||
|
||||
def synthesize(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
exaggeration: float = 0.5,
|
||||
cfg_weight: float = 0.5,
|
||||
temperature: float = 0.8,
|
||||
audio_prompt: bytes | None = None,
|
||||
format: AudioFormat = "ogg",
|
||||
) -> TTSResult:
|
||||
duration_s = max(0.1, len(text.split()) * 0.3)
|
||||
audio = _silent_wav(duration_s)
|
||||
return TTSResult(
|
||||
audio_bytes=audio,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
duration_s=duration_s,
|
||||
format="wav",
|
||||
model=self.model_name,
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(MockTTSBackend(), TTSBackend), "MockTTSBackend must satisfy TTSBackend Protocol"
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
from .base import VectorMatch, VectorStore
|
||||
from .sqlite_vec import LocalSQLiteVecStore
|
||||
|
||||
__all__ = ["VectorMatch", "VectorStore", "LocalSQLiteVecStore"]
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.vector.base — VectorStore ABC and shared types.
|
||||
|
||||
Concrete implementations: LocalSQLiteVecStore (local), QdrantStore (cloud Paid tier).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VectorMatch:
|
||||
"""A single result from a vector similarity search."""
|
||||
|
||||
entry_id: str
|
||||
score: float # lower is better (L2 / cosine distance)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class VectorStore(ABC):
|
||||
"""Abstract interface for vector storage backends."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert(
|
||||
self, entry_id: str, vector: list[float], metadata: dict[str, Any]
|
||||
) -> None:
|
||||
"""Insert or replace a vector and its metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
vector: list[float],
|
||||
top_k: int = 10,
|
||||
filter_metadata: dict[str, Any] | None = None,
|
||||
) -> list[VectorMatch]:
|
||||
"""Return the top_k nearest vectors. Optional metadata filter applied post-search."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, entry_id: str) -> None:
|
||||
"""Remove a single vector by string ID. No-op if not found."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_where(self, filter_metadata: dict[str, Any]) -> int:
|
||||
"""Remove all vectors whose metadata matches all key-value pairs. Returns count removed.
|
||||
|
||||
Raises ValueError if filter_metadata is empty (would delete entire store).
|
||||
"""
|
||||
|
|
@ -1,185 +0,0 @@
|
|||
# circuitforge_core/vector/sqlite_vec.py
|
||||
"""
|
||||
circuitforge_core.vector.sqlite_vec -- sqlite-vec backed VectorStore.
|
||||
|
||||
Suitable for single-user local deployments. Cloud Paid tier replaces
|
||||
this with QdrantStore via the same VectorStore ABC.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import sqlite_vec
|
||||
|
||||
from .base import VectorMatch, VectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
|
||||
def _serialize(vector: list[float]) -> bytes:
|
||||
return struct.pack(f"<{len(vector)}f", *vector)
|
||||
|
||||
|
||||
class LocalSQLiteVecStore(VectorStore):
|
||||
"""
|
||||
VectorStore backed by sqlite-vec virtual tables.
|
||||
|
||||
Uses two tables per logical store:
|
||||
- ``<table>_vecs``: vec0 virtual table (rowid-indexed float vectors)
|
||||
- ``<table>_meta``: companion table mapping rowid to string ID + JSON metadata
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file.
|
||||
table: Logical name prefix (default ``"vecs"``).
|
||||
dimensions: Vector length; must match the embedding model (default 768).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str | Path,
|
||||
table: str = "vecs",
|
||||
dimensions: int = 768,
|
||||
) -> None:
|
||||
if not _SAFE_IDENTIFIER.match(table):
|
||||
raise ValueError(
|
||||
f"table must be a valid SQL identifier (letters, digits, underscores): {table!r}"
|
||||
)
|
||||
self.db_path = str(db_path)
|
||||
self.table = table
|
||||
self.dimensions = dimensions
|
||||
self._init_tables()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
conn.enable_load_extension(False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _init_tables(self) -> None:
|
||||
with self._conn() as conn:
|
||||
conn.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.table}_vecs
|
||||
USING vec0(embedding float[{self.dimensions}])
|
||||
""")
|
||||
conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table}_meta (
|
||||
rowid INTEGER PRIMARY KEY,
|
||||
entry_id TEXT NOT NULL UNIQUE,
|
||||
metadata TEXT NOT NULL DEFAULT '{{}}'
|
||||
)
|
||||
""")
|
||||
|
||||
def upsert(
|
||||
self, entry_id: str, vector: list[float], metadata: dict[str, Any]
|
||||
) -> None:
|
||||
with self._conn() as conn:
|
||||
row = conn.execute(
|
||||
f"SELECT rowid FROM {self.table}_meta WHERE entry_id = ?", [entry_id]
|
||||
).fetchone()
|
||||
|
||||
if row:
|
||||
rowid = row["rowid"]
|
||||
conn.execute(
|
||||
f"UPDATE {self.table}_vecs SET embedding = ? WHERE rowid = ?",
|
||||
[_serialize(vector), rowid],
|
||||
)
|
||||
conn.execute(
|
||||
f"UPDATE {self.table}_meta SET metadata = ? WHERE rowid = ?",
|
||||
[json.dumps(metadata), rowid],
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
f"INSERT INTO {self.table}_meta(entry_id, metadata) VALUES (?, ?)",
|
||||
[entry_id, json.dumps(metadata)],
|
||||
)
|
||||
rowid = cursor.lastrowid
|
||||
conn.execute(
|
||||
f"INSERT INTO {self.table}_vecs(rowid, embedding) VALUES (?, ?)",
|
||||
[rowid, _serialize(vector)],
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
vector: list[float],
|
||||
top_k: int = 10,
|
||||
filter_metadata: dict[str, Any] | None = None,
|
||||
) -> list[VectorMatch]:
|
||||
with self._conn() as conn:
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT m.entry_id, v.distance, m.metadata
|
||||
FROM {self.table}_vecs v
|
||||
JOIN {self.table}_meta m ON m.rowid = v.rowid
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance
|
||||
""",
|
||||
[_serialize(vector), top_k],
|
||||
).fetchall()
|
||||
results = [
|
||||
VectorMatch(
|
||||
entry_id=r["entry_id"],
|
||||
score=r["distance"],
|
||||
metadata=json.loads(r["metadata"]),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
if filter_metadata:
|
||||
results = [
|
||||
r
|
||||
for r in results
|
||||
if all(r.metadata.get(k) == v for k, v in filter_metadata.items())
|
||||
]
|
||||
return results
|
||||
|
||||
def delete(self, entry_id: str) -> None:
|
||||
with self._conn() as conn:
|
||||
row = conn.execute(
|
||||
f"SELECT rowid FROM {self.table}_meta WHERE entry_id = ?", [entry_id]
|
||||
).fetchone()
|
||||
if row:
|
||||
rowid = row["rowid"]
|
||||
conn.execute(f"DELETE FROM {self.table}_vecs WHERE rowid = ?", [rowid])
|
||||
conn.execute(f"DELETE FROM {self.table}_meta WHERE rowid = ?", [rowid])
|
||||
|
||||
def delete_where(self, filter_metadata: dict[str, Any]) -> int:
|
||||
if not filter_metadata:
|
||||
raise ValueError(
|
||||
"delete_where requires a non-empty filter; refusing to delete entire store"
|
||||
)
|
||||
with self._conn() as conn:
|
||||
rows = conn.execute(
|
||||
f"SELECT rowid, metadata FROM {self.table}_meta"
|
||||
).fetchall()
|
||||
to_delete = [
|
||||
r["rowid"]
|
||||
for r in rows
|
||||
if all(
|
||||
json.loads(r["metadata"]).get(k) == v
|
||||
for k, v in filter_metadata.items()
|
||||
)
|
||||
]
|
||||
for rowid in to_delete:
|
||||
conn.execute(f"DELETE FROM {self.table}_vecs WHERE rowid = ?", [rowid])
|
||||
conn.execute(f"DELETE FROM {self.table}_meta WHERE rowid = ?", [rowid])
|
||||
return len(to_delete)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
"""
|
||||
circuitforge_core.video — cf-video service: video VLM inference via Marlin-2B.
|
||||
|
||||
Exposes a FastAPI process (managed by cf-orch) with endpoints:
|
||||
GET /health → {"status": "ok", "model": str, "vram_mb": int}
|
||||
POST /caption → CaptionResult (scene description + timestamped events)
|
||||
POST /find → FindResult (temporal grounding span for a natural-language event)
|
||||
|
||||
Run as:
|
||||
python -m circuitforge_core.video.app --model /path/to/NemoStation--Marlin-2B --port 8016 --gpu-id 0
|
||||
"""
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue