fix: assorted reliability fixes

- cloud_session: bypass IPs now honour a valid session JWT so logged-in
  devs land on their own account DB (a bypass IP with no token still gets
  the local-dev session); invalid/expired JWTs soft-fail to guest instead
  of a hard 401, so public endpoints stay accessible with stale cookies
- tasks/scheduler: catch and log exceptions that escape run_task_fn so a
  failing task leaves a trace instead of silently stalling the batch worker
- reranker: add module-level logger for structured log output
- text/transformers: use BitsAndBytesConfig for quantization (deprecated
  load_in_4bit/load_in_8bit kwargs removed in transformers 4.40+)
- __init__: derive __version__ from installed package metadata so editable
  installs always report the correct version string (falls back to "dev"
  when the package is not installed)
This commit is contained in:
pyr0ball 2026-06-05 10:19:31 -07:00
parent 24c75925ee
commit cdeb410f45
5 changed files with 55 additions and 12 deletions

View file

@ -1,4 +1,9 @@
__version__ = "0.18.0" from importlib.metadata import PackageNotFoundError, version
try:
__version__ = version("circuitforge-core")
except PackageNotFoundError:
__version__ = "dev" # running from source without an editable install
try: try:
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore

View file

@ -39,6 +39,13 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable
try:
from starlette.requests import Request as _Request
from starlette.responses import Response as _Response
except ImportError: # pragma: no cover — starlette may be absent in non-web envs
_Request = Any # type: ignore[assignment,misc]
_Response = Any # type: ignore[assignment,misc]
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
TIERS: list[str] = ["free", "paid", "premium", "ultra"] TIERS: list[str] = ["free", "paid", "premium", "ultra"]
@ -248,22 +255,40 @@ class CloudSessionFactory:
request.headers.get("x-real-ip", "") request.headers.get("x-real-ip", "")
or (request.client.host if request.client else "") or (request.client.host if request.client else "")
) )
if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips): is_bypass = _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips)
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
raw_session = ( raw_session = (
request.headers.get("x-cf-session", "").strip() request.headers.get("x-cf-session", "").strip()
or request.cookies.get("cf_session", "").strip() or request.cookies.get("cf_session", "").strip()
) )
# Bypass IPs skip the JWT *requirement* but not JWT *validation*.
# If a token is present (dev is logged in), honour it so they land on
# their own account DB rather than the shared local-dev DB.
if not raw_session: if not raw_session:
if is_bypass:
log.debug("Bypass IP %s, no token — returning local-dev session for product %s", client_ip, self.product)
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
return self._resolve_guest(request, response) return self._resolve_guest(request, response)
token = _extract_session_token(raw_session) token = _extract_session_token(raw_session)
if not token: if not token:
return self._resolve_guest(request, response) return self._resolve_guest(request, response)
user_id = self.validate_jwt(token) # Soft-fail on invalid/expired JWT: downgrade to guest rather than
# hard-erroring with 401. Public endpoints (e.g. community blocklist)
# should remain accessible even when the browser has a stale cookie.
# Routes that genuinely require an authenticated identity should gate
# themselves with require_tier() — that's where the 401/403 belongs.
try:
user_id = self.validate_jwt(token)
except Exception:
log.warning(
"JWT validation failed for product %s (expired or tampered) — falling back to guest",
self.product,
)
return self._resolve_guest(request, response)
self._ensure_provisioned(user_id) self._ensure_provisioned(user_id)
tier_data = self._resolve_tier(user_id) tier_data = self._resolve_tier(user_id)
tier = tier_data.get("tier", "free") tier = tier_data.get("tier", "free")
@ -283,11 +308,11 @@ class CloudSessionFactory:
meta=meta, meta=meta,
) )
def dependency(self) -> Callable[[Any, Any], CloudUser]: def dependency(self) -> Callable[["_Request", "_Response"], CloudUser]:
"""Return a FastAPI-compatible dependency function (use with Depends()).""" """Return a FastAPI-compatible dependency function (use with Depends())."""
factory = self factory = self
def _get_session(request: Any, response: Any) -> CloudUser: def _get_session(request: _Request, response: _Response) -> CloudUser:
return factory.resolve(request, response) return factory.resolve(request, response)
return _get_session return _get_session

View file

@ -51,9 +51,12 @@ cf-orch service profile (Phase 3 — remote backend):
""" """
from __future__ import annotations from __future__ import annotations
import logging
import os import os
from typing import Sequence from typing import Sequence
logger = logging.getLogger(__name__)
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
from circuitforge_core.reranker.adapters.mock import MockTextReranker from circuitforge_core.reranker.adapters.mock import MockTextReranker

View file

@ -169,7 +169,15 @@ class LocalScheduler:
if not q: if not q:
break break
task = q.popleft() task = q.popleft()
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params) try:
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
except Exception as exc:
# run_task_fn should handle its own exceptions. If it leaks one,
# log it so the task doesn't silently stay 'queued' with no trace.
logger.exception(
"Unhandled exception in batch worker task %d (%s): %s",
task.id, task_type, exc,
)
finally: finally:
with self._lock: with self._lock:
self._active.pop(task_type, None) self._active.pop(task_type, None)

View file

@ -50,10 +50,12 @@ class TransformersBackend:
logger.info("Loading transformers model %s on %s", model_path, self._device) logger.info("Loading transformers model %s on %s", model_path, self._device)
load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None} load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
if _LOAD_IN_4BIT: if _LOAD_IN_4BIT or _LOAD_IN_8BIT:
load_kwargs["load_in_4bit"] = True from transformers import BitsAndBytesConfig
elif _LOAD_IN_8BIT: load_kwargs["quantization_config"] = BitsAndBytesConfig(
load_kwargs["load_in_8bit"] = True load_in_4bit=_LOAD_IN_4BIT,
load_in_8bit=_LOAD_IN_8BIT,
)
self._tokenizer = AutoTokenizer.from_pretrained(model_path) self._tokenizer = AutoTokenizer.from_pretrained(model_path)
self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs) self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)