fix: assorted reliability fixes
- cloud_session: bypass IPs now honour valid JWT tokens so logged-in devs land on their own account DB; invalid/expired JWTs soft-fail to guest instead of hard 401 (public endpoints stay accessible with stale cookies)
- tasks/scheduler: log unhandled exceptions that escape run_task_fn to prevent silent task stalls in the batch worker
- reranker: add module-level logger for structured log output
- text/transformers: use BitsAndBytesConfig for quantization (deprecated load_in_4bit/load_in_8bit kwargs removed in transformers 4.40+)
- __init__: derive __version__ from installed package metadata so editable installs always report the correct version string
This commit is contained in:
parent
24c75925ee
commit
cdeb410f45
5 changed files with 55 additions and 12 deletions
|
|
@ -1,4 +1,9 @@
|
|||
__version__ = "0.18.0"
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
try:
|
||||
__version__ = version("circuitforge-core")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "dev" # running from source without an editable install
|
||||
|
||||
try:
|
||||
from circuitforge_core.community import CommunityDB, CommunityPost, SharedStore
|
||||
|
|
|
|||
|
|
@ -39,6 +39,13 @@ from dataclasses import dataclass, field
|
|||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from starlette.requests import Request as _Request
|
||||
from starlette.responses import Response as _Response
|
||||
except ImportError: # pragma: no cover — starlette may be absent in non-web envs
|
||||
_Request = Any # type: ignore[assignment,misc]
|
||||
_Response = Any # type: ignore[assignment,misc]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TIERS: list[str] = ["free", "paid", "premium", "ultra"]
|
||||
|
|
@ -248,22 +255,40 @@ class CloudSessionFactory:
|
|||
request.headers.get("x-real-ip", "")
|
||||
or (request.client.host if request.client else "")
|
||||
)
|
||||
if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips):
|
||||
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
|
||||
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
||||
is_bypass = _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips)
|
||||
|
||||
raw_session = (
|
||||
request.headers.get("x-cf-session", "").strip()
|
||||
or request.cookies.get("cf_session", "").strip()
|
||||
)
|
||||
|
||||
# Bypass IPs skip the JWT *requirement* but not JWT *validation*.
|
||||
# If a token is present (dev is logged in), honour it so they land on
|
||||
# their own account DB rather than the shared local-dev DB.
|
||||
if not raw_session:
|
||||
if is_bypass:
|
||||
log.debug("Bypass IP %s, no token — returning local-dev session for product %s", client_ip, self.product)
|
||||
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
token = _extract_session_token(raw_session)
|
||||
if not token:
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
user_id = self.validate_jwt(token)
|
||||
# Soft-fail on invalid/expired JWT: downgrade to guest rather than
|
||||
# hard-erroring with 401. Public endpoints (e.g. community blocklist)
|
||||
# should remain accessible even when the browser has a stale cookie.
|
||||
# Routes that genuinely require an authenticated identity should gate
|
||||
# themselves with require_tier() — that's where the 401/403 belongs.
|
||||
try:
|
||||
user_id = self.validate_jwt(token)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"JWT validation failed for product %s (expired or tampered) — falling back to guest",
|
||||
self.product,
|
||||
)
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
self._ensure_provisioned(user_id)
|
||||
tier_data = self._resolve_tier(user_id)
|
||||
tier = tier_data.get("tier", "free")
|
||||
|
|
@ -283,11 +308,11 @@ class CloudSessionFactory:
|
|||
meta=meta,
|
||||
)
|
||||
|
||||
def dependency(self) -> Callable[[Any, Any], CloudUser]:
|
||||
def dependency(self) -> Callable[["_Request", "_Response"], CloudUser]:
|
||||
"""Return a FastAPI-compatible dependency function (use with Depends())."""
|
||||
factory = self
|
||||
|
||||
def _get_session(request: Any, response: Any) -> CloudUser:
|
||||
def _get_session(request: _Request, response: _Response) -> CloudUser:
|
||||
return factory.resolve(request, response)
|
||||
|
||||
return _get_session
|
||||
|
|
|
|||
|
|
@ -51,9 +51,12 @@ cf-orch service profile (Phase 3 — remote backend):
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from circuitforge_core.reranker.base import RerankResult, Reranker, TextReranker
|
||||
from circuitforge_core.reranker.adapters.mock import MockTextReranker
|
||||
|
||||
|
|
|
|||
|
|
@ -169,7 +169,15 @@ class LocalScheduler:
|
|||
if not q:
|
||||
break
|
||||
task = q.popleft()
|
||||
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
|
||||
try:
|
||||
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
|
||||
except Exception as exc:
|
||||
# run_task_fn should handle its own exceptions. If it leaks one,
|
||||
# log it so the task doesn't silently stay 'queued' with no trace.
|
||||
logger.exception(
|
||||
"Unhandled exception in batch worker task %d (%s): %s",
|
||||
task.id, task_type, exc,
|
||||
)
|
||||
finally:
|
||||
with self._lock:
|
||||
self._active.pop(task_type, None)
|
||||
|
|
|
|||
|
|
@ -50,10 +50,12 @@ class TransformersBackend:
|
|||
logger.info("Loading transformers model %s on %s", model_path, self._device)
|
||||
|
||||
load_kwargs: dict = {"device_map": "auto" if self._device == "cuda" else None}
|
||||
if _LOAD_IN_4BIT:
|
||||
load_kwargs["load_in_4bit"] = True
|
||||
elif _LOAD_IN_8BIT:
|
||||
load_kwargs["load_in_8bit"] = True
|
||||
if _LOAD_IN_4BIT or _LOAD_IN_8BIT:
|
||||
from transformers import BitsAndBytesConfig
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=_LOAD_IN_4BIT,
|
||||
load_in_8bit=_LOAD_IN_8BIT,
|
||||
)
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
self._model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue