From cdeb410f45db550fa87920aafac7a8e31cf7826e Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Fri, 5 Jun 2026 10:19:31 -0700 Subject: [PATCH] fix: assorted reliability fixes - cloud_session: bypass IPs now honour valid JWT tokens so logged-in devs land on their own account DB; invalid/expired JWTs soft-fail to guest instead of hard 401 (public endpoints stay accessible with stale cookies) - tasks/scheduler: log unhandled exceptions that escape run_task_fn to prevent silent task stalls in the batch worker - reranker: add module-level logger for structured log output - text/transformers: use BitsAndBytesConfig for quantization (deprecated load_in_4bit/load_in_8bit kwargs removed in transformers 4.40+) - __init__: derive __version__ from installed package metadata so editable installs always report the correct version string --- circuitforge_core/__init__.py | 7 +++- circuitforge_core/cloud_session/__init__.py | 37 ++++++++++++++++--- circuitforge_core/reranker/__init__.py | 3 ++ circuitforge_core/tasks/scheduler.py | 10 ++++- .../text/backends/transformers.py | 10 +++-- 5 files changed, 55 insertions(+), 12 deletions(-) diff --git a/circuitforge_core/__init__.py b/circuitforge_core/__init__.py index 2ed946a..68957a6 100644 --- a/circuitforge_core/__init__.py +++ b/circuitforge_core/__init__.py @@ -1,4 +1,9 @@ -__version__ = "0.18.0" +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("circuitforge-core") +except PackageNotFoundError: + __version__ = "dev" # running from source without an editable install try: from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore diff --git a/circuitforge_core/cloud_session/__init__.py b/circuitforge_core/cloud_session/__init__.py index a90d212..abdda4e 100644 --- a/circuitforge_core/cloud_session/__init__.py +++ b/circuitforge_core/cloud_session/__init__.py @@ -39,6 +39,13 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable +try: + from starlette.requests import Request as _Request + from starlette.responses import Response as _Response +except ImportError: # pragma: no cover — starlette may be absent in non-web envs + _Request = Any # type: ignore[assignment,misc] + _Response = Any # type: ignore[assignment,misc] + log = logging.getLogger(__name__) TIERS: list[str] = ["free", "paid", "premium", "ultra"] @@ -248,22 +255,40 @@ class CloudSessionFactory: request.headers.get("x-real-ip", "") or (request.client.host if request.client else "") ) - if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips): - log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product) - return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok) + is_bypass = _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips) raw_session = ( request.headers.get("x-cf-session", "").strip() or request.cookies.get("cf_session", "").strip() ) + + # Bypass IPs skip the JWT *requirement* but not JWT *validation*. + # If a token is present (dev is logged in), honour it so they land on + # their own account DB rather than the shared local-dev DB. if not raw_session: + if is_bypass: + log.debug("Bypass IP %s, no token — returning local-dev session for product %s", client_ip, self.product) + return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok) return self._resolve_guest(request, response) token = _extract_session_token(raw_session) if not token: return self._resolve_guest(request, response) - user_id = self.validate_jwt(token) + # Soft-fail on invalid/expired JWT: downgrade to guest rather than + # hard-erroring with 401. Public endpoints (e.g. community blocklist) + # should remain accessible even when the browser has a stale cookie. + # Routes that genuinely require an authenticated identity should gate + # themselves with require_tier() — that's where the 401/403 belongs. + try: + user_id = self.validate_jwt(token) + except Exception: + log.warning( + "JWT validation failed for product %s (expired or tampered) — falling back to guest", + self.product, + ) + return self._resolve_guest(request, response) + self._ensure_provisioned(user_id) tier_data = self._resolve_tier(user_id) tier = tier_data.get("tier", "free") @@ -283,11 +308,11 @@ class CloudSessionFactory: meta=meta, ) - def dependency(self) -> Callable[[Any, Any], CloudUser]: + def dependency(self) -> Callable[["_Request", "_Response"], CloudUser]: """Return a FastAPI-compatible dependency function (use with Depends()).""" factory = self - def _get_session(request: Any, response: Any) -> CloudUser: + def _get_session(request: _Request, response: _Response) -> CloudUser: return factory.resolve(request, response) return _get_session diff --git a/circuitforge_core/reranker/__init__.py b/circuitforge_core/reranker/__init__.py index 81452d1..37d15a5 100644 --- a/circuitforge_core/reranker/__init__.py +++ b/circuitforge_core/reranker/__init__.py @@ -51,9 +51,12 @@ cf-orch service profile (Phase 3 — remote backend): """ from __future__ import annotations +import logging import os from typing import Sequence +logger = logging.getLogger(__name__) + from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker from circuitforge_core.reranker.adapters.mock import MockTextReranker diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py index 49cf9ba..ff68b45 100644 --- a/circuitforge_core/tasks/scheduler.py +++ b/circuitforge_core/tasks/scheduler.py @@ -169,7 +169,15 @@ class LocalScheduler: if not q: break task = q.popleft() - self._run_task(self._db_path, task.id, task_type, task.job_id, task.params) + try: + self._run_task(self._db_path, task.id, task_type, task.job_id, task.params) + except Exception as exc: + # run_task_fn should handle its own exceptions. If it leaks one, + # log it so the task doesn't silently stay 'queued' with no trace. + logger.exception( + "Unhandled exception in batch worker task %d (%s): %s", + task.id, task_type, exc, + ) finally: with self._lock: self._active.pop(task_type, None) diff --git a/circuitforge_core/text/backends/transformers.py b/circuitforge_core/text/backends/transformers.py index bb17591..1a414cd 100644 --- a/circuitforge_core/text/backends/transformers.py +++ b/circuitforge_core/text/backends/transformers.py @@ -50,10 +50,12 @@ class TransformersBackend: logger.info("Loading transformers model %s on %s", model_path, self._device) load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None} - if _LOAD_IN_4BIT: - load_kwargs["load_in_4bit"] = True - elif _LOAD_IN_8BIT: - load_kwargs["load_in_8bit"] = True + if _LOAD_IN_4BIT or _LOAD_IN_8BIT: + from transformers import BitsAndBytesConfig + load_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=_LOAD_IN_4BIT, + load_in_8bit=_LOAD_IN_8BIT, + ) self._tokenizer = AutoTokenizer.from_pretrained(model_path) self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)