From 8eef52a0546f13714970be5cd8d45856c467795c Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Wed, 13 May 2026 16:31:51 -0700 Subject: [PATCH] feat: per-user database isolation for cloud instances (closes #4) Implements Option A from the issue design: each cloud user gets their own data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db, pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged. Key changes: - app/startup.py: extract apply_migrations, reembed_docs, check_and_rebuild_vec_schema out of main.py (no circular imports) - app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper - app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier now returns user_id (str) instead of None - app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir, watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs migrations + vec schema check once per process per user - app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id; add _get_bm25_for(); lifespan only runs startup checks in local mode - app/api/library.py, search.py, chat.py: thread UserCtx through all endpoints; remove module-level _mark_bm25_dirty injection pattern - tests/conftest.py: override get_user_ctx in addition to get_db so all endpoints get a consistent test UserCtx --- app/api/chat.py | 41 +++++------- app/api/library.py | 81 ++++++++++++++---------- app/api/search.py | 27 ++------ app/cloud_session.py | 131 +++++++++++++++++++++++++++++++++++++++ app/config.py | 33 +++++++--- app/deps.py | 68 +++++++++++++++++++- app/main.py | 111 +++++---------------------------- app/startup.py | 95 ++++++++++++++++++++++++++++ tests/conftest.py | 33 +++++++--- tests/test_search_api.py | 8 +-- tests/test_startup.py | 3 +- 11 files changed, 432 insertions(+), 199 deletions(-) create mode 100644 app/cloud_session.py create mode 100644 app/startup.py diff --git a/app/api/chat.py b/app/api/chat.py index dcfc35a..0fe3815 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -10,9 +10,11 @@ from __future__ import annotations import logging import os -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel +from app.cloud_session import require_paid_tier +from app.deps import UserCtx, get_user_ctx from app.services.retriever import Retriever from app.services.synthesizer import Synthesizer @@ -56,21 +58,6 @@ def _get_llm_router(): return LLMRouter(cfg) -def _get_db_path() -> str: - """Read lazily so test fixtures take effect.""" - import pathlib - - data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data")) - return str(data_dir / "pagepiper.db") - - -def _get_vec_db_path() -> str: - import pathlib - - data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data")) - return str(data_dir / "pagepiper_vecs.db") - - def _require_llm(): """Return LLMRouter or raise 402.""" llm = _get_llm_router() @@ -89,18 +76,20 @@ def _require_llm(): @router.post("") -def chat(req: ChatRequest) -> ChatResponse: +def chat( + req: ChatRequest, + ctx: UserCtx = Depends(get_user_ctx), + _tier: str = Depends(require_paid_tier), +) -> ChatResponse: llm = _require_llm() - from app.main import _bm25 - - retriever = Retriever(_bm25) + retriever = Retriever(ctx.bm25) chunks = retriever.hybrid_search( query=req.message, top_k=req.top_k, doc_ids=req.doc_ids, - db_path=_get_db_path(), - vec_db_path=_get_vec_db_path(), + db_path=ctx.db_path, + vec_db_path=ctx.vec_db_path, llm=llm, ) @@ -141,7 +130,10 @@ def chat_feedback_status() -> dict: @router.post("/feedback") -def submit_chat_feedback(req: ChatFeedbackRequest) -> dict: +def submit_chat_feedback( + req: ChatFeedbackRequest, + ctx: UserCtx = Depends(get_user_ctx), +) -> dict: import json import sqlite3 @@ -149,8 +141,7 @@ def submit_chat_feedback(req: ChatFeedbackRequest) -> dict: from fastapi import HTTPException raise HTTPException(status_code=422, detail="rating must be 1 or -1") - db_path = _get_db_path() - con = sqlite3.connect(db_path) + con = sqlite3.connect(ctx.db_path) try: con.execute( "INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)", diff --git a/app/api/library.py b/app/api/library.py index ad96a8e..2ce57d0 100644 --- a/app/api/library.py +++ b/app/api/library.py @@ -14,26 +14,25 @@ from typing import Callable from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile -from app.config import WATCH_DIR, DB_PATH, VEC_DB_PATH, DATA_DIR -from app.deps import get_db +from app.config import VEC_DIMENSIONS +from app.deps import UserCtx, get_db, get_user_ctx _MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/library", tags=["library"]) -# Injected by main.py after _bm25 is created -_mark_bm25_dirty: Callable[[], None] | None = None - _INGEST_TASKS = { ".pdf": "pagepiper/ingest_pdf", ".epub": "pagepiper/ingest_epub", + ".docx": "pagepiper/ingest_docx", } _INGEST_RUNNERS = { ".pdf": "scripts.ingest_pdf", ".epub": "scripts.ingest_epub", + ".docx": "scripts.ingest_docx", } @@ -41,24 +40,22 @@ def _dispatch_ingest( doc_id: str, file_path: str, background_tasks: BackgroundTasks, + data_dir: Path, + mark_dirty_fn: Callable[[], None], ) -> str: """Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks.""" import importlib - import os as _os - from pathlib import Path as _Path - suffix = _Path(file_path).suffix.lower() + suffix = Path(file_path).suffix.lower() task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf") runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf") - # Read lazily so test fixtures (monkeypatch.setenv) take effect - _data_dir = _Path(_os.environ.get("PAGEPIPER_DATA_DIR", "data")) task_id = str(uuid.uuid4()) args = { "doc_id": doc_id, "file_path": file_path, - "db_path": str(_data_dir / "pagepiper.db"), - "vec_db_path": str(_data_dir / "pagepiper_vecs.db"), + "db_path": str(data_dir / "pagepiper.db"), + "vec_db_path": str(data_dir / "pagepiper_vecs.db"), } try: @@ -67,7 +64,7 @@ def _dispatch_ingest( logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id) except Exception: mod = importlib.import_module(runner_module) - background_tasks.add_task(_run_ingest_background, mod.run, args, task_id) + background_tasks.add_task(_run_ingest_background, mod.run, args, task_id, mark_dirty_fn) logger.info( "cf-orch unavailable — running ingest in background thread (task %s)", task_id ) @@ -75,14 +72,19 @@ def _dispatch_ingest( return task_id -def _run_ingest_background(run_fn: Callable[..., None], args: dict, task_id: str) -> None: +def _run_ingest_background( + run_fn: Callable[..., None], + args: dict, + task_id: str, + mark_dirty_fn: Callable[[], None] | None = None, +) -> None: from app.api.ingest import _task_registry _task_registry[task_id] = {"status": "running", "progress": 0} try: run_fn(**args) _task_registry[task_id] = {"status": "complete", "progress": 100} - if _mark_bm25_dirty: - _mark_bm25_dirty() + if mark_dirty_fn: + mark_dirty_fn() except Exception as exc: logger.exception("Ingest task %s failed", task_id) _task_registry[task_id] = {"status": "error", "error": str(exc)} @@ -101,13 +103,18 @@ def list_library(db: sqlite3.Connection = Depends(get_db)) -> list[dict]: def scan_library( background_tasks: BackgroundTasks, db: sqlite3.Connection = Depends(get_db), + ctx: UserCtx = Depends(get_user_ctx), ) -> dict: """Scan the watched directory and queue ingest for any new PDFs.""" - watch = WATCH_DIR + watch = ctx.watch_dir if not watch.exists(): raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}") - pdfs = list(watch.glob("**/*.pdf")) + list(watch.glob("**/*.epub")) + pdfs = ( + list(watch.glob("**/*.pdf")) + + list(watch.glob("**/*.epub")) + + list(watch.glob("**/*.docx")) + ) queued = [] for pdf_path in pdfs: @@ -117,7 +124,7 @@ def scan_library( ).fetchone() if existing and existing["status"] == "ready": - continue # already indexed + continue if existing: doc_id = existing["id"] @@ -129,7 +136,9 @@ def scan_library( ).fetchone()[0] db.commit() - task_id = _dispatch_ingest(doc_id, path_str, background_tasks) + task_id = _dispatch_ingest( + doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty + ) db.execute( "UPDATE documents SET status='processing', task_id=? WHERE id=?", [task_id, doc_id], @@ -145,12 +154,15 @@ def reingest_document( doc_id: str, background_tasks: BackgroundTasks, db: sqlite3.Connection = Depends(get_db), + ctx: UserCtx = Depends(get_user_ctx), ) -> dict: row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone() if not row: raise HTTPException(status_code=404, detail="Document not found") - task_id = _dispatch_ingest(doc_id, row["file_path"], background_tasks) + task_id = _dispatch_ingest( + doc_id, row["file_path"], background_tasks, ctx.data_dir, ctx.bm25.mark_dirty + ) db.execute( "UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?", [task_id, doc_id], @@ -163,6 +175,7 @@ def reingest_document( def delete_document( doc_id: str, db: sqlite3.Connection = Depends(get_db), + ctx: UserCtx = Depends(get_user_ctx), ) -> None: row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone() if not row: @@ -171,23 +184,21 @@ def delete_document( db.execute("DELETE FROM documents WHERE id=?", [doc_id]) db.commit() - # Remove embeddings from vector store try: from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import] - from app.config import VEC_DIMENSIONS - store = LocalSQLiteVecStore(db_path=VEC_DB_PATH, table="page_vecs", dimensions=VEC_DIMENSIONS) + store = LocalSQLiteVecStore( + db_path=ctx.vec_db_path, table="page_vecs", dimensions=VEC_DIMENSIONS + ) store.delete_where({"doc_id": doc_id}) except Exception as exc: logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc) - if _mark_bm25_dirty: - _mark_bm25_dirty() + ctx.bm25.mark_dirty() -def _get_vec_count(doc_id: str) -> int: - """Return how many vectors have been stored for this doc. Returns 0 on any error.""" +def _get_vec_count(doc_id: str, vec_db_path: str) -> int: try: - conn = sqlite3.connect(VEC_DB_PATH) + conn = sqlite3.connect(vec_db_path) count = conn.execute( "SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?", [doc_id], @@ -202,6 +213,7 @@ def _get_vec_count(doc_id: str) -> int: def document_status( doc_id: str, db: sqlite3.Connection = Depends(get_db), + ctx: UserCtx = Depends(get_user_ctx), ) -> dict: row = db.execute( "SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?", @@ -210,7 +222,7 @@ def document_status( if not row: raise HTTPException(status_code=404, detail="Document not found") result = dict(row) - result["vec_count"] = _get_vec_count(doc_id) + result["vec_count"] = _get_vec_count(doc_id, ctx.vec_db_path) return result @@ -219,18 +231,19 @@ def upload_document( file: UploadFile, background_tasks: BackgroundTasks, db: sqlite3.Connection = Depends(get_db), + ctx: UserCtx = Depends(get_user_ctx), ) -> dict: """Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing.""" name = Path(file.filename or "").name suffix = Path(name).suffix.lower() if suffix not in _INGEST_TASKS: - raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB") + raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB, DOCX") content = file.file.read() if len(content) > _MAX_UPLOAD_BYTES: raise HTTPException(status_code=413, detail="File exceeds 200 MB limit") - upload_dir = DATA_DIR / "uploads" + upload_dir = ctx.data_dir / "uploads" upload_dir.mkdir(parents=True, exist_ok=True) dest = upload_dir / name dest.write_bytes(content) @@ -253,7 +266,9 @@ def upload_document( ).fetchone()[0] db.commit() - task_id = _dispatch_ingest(doc_id, path_str, background_tasks) + task_id = _dispatch_ingest( + doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty + ) db.execute( "UPDATE documents SET status='processing', task_id=? WHERE id=?", [task_id, doc_id], diff --git a/app/api/search.py b/app/api/search.py index e6d7ce0..5329939 100644 --- a/app/api/search.py +++ b/app/api/search.py @@ -7,13 +7,11 @@ MIT — no tier gate. No Ollama required. from __future__ import annotations import logging -import os -from typing import Annotated from fastapi import APIRouter, Depends from pydantic import BaseModel, Field -from app.services.bm25_index import BM25Index +from app.deps import UserCtx, get_user_ctx logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/search", tags=["search"]) @@ -29,32 +27,17 @@ class SearchResult(BaseModel): chunk_id: str doc_id: str page_number: int - text_snippet: str # first 300 chars of the page text + text_snippet: str bm25_score: float -def _get_bm25() -> BM25Index: - import app.main as _main - bm25 = getattr(_main, "_bm25", None) - if bm25 is None: - raise RuntimeError("BM25 index not initialised — app.main not loaded") - return bm25 - - -def _get_db_path() -> str: - """Read lazily so test fixtures (monkeypatch.setattr) take effect.""" - import pathlib - data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data")) - return str(data_dir / "pagepiper.db") - - @router.post("") def search( req: SearchRequest, - bm25: Annotated[BM25Index, Depends(_get_bm25)], + ctx: UserCtx = Depends(get_user_ctx), ) -> list[SearchResult]: - bm25.ensure_fresh(_get_db_path()) - hits = bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids) + ctx.bm25.ensure_fresh(ctx.db_path) + hits = ctx.bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids) return [ SearchResult( chunk_id=h.chunk_id, diff --git a/app/cloud_session.py b/app/cloud_session.py new file mode 100644 index 0000000..75bcd98 --- /dev/null +++ b/app/cloud_session.py @@ -0,0 +1,131 @@ +# app/cloud_session.py +"""Cloud session auth for Pagepiper — validates cf_session cookie via Directus + Heimdall. + +In local mode (CLOUD_MODE unset or false), require_paid_tier is a no-op. +In cloud mode, the Caddy proxy forwards the browser's Cookie header as +X-CF-Session. This module extracts cf_session, validates it against +Directus /users/me, then checks the user's Pagepiper tier via Heimdall. +Auto-provisions a free tier key for new users. +""" +from __future__ import annotations + +import logging +import os +import re + +import httpx +from fastapi import HTTPException, Request + +log = logging.getLogger(__name__) + +CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes") +DIRECTUS_URL: str = os.environ.get("DIRECTUS_URL", "http://172.31.0.3:8055").rstrip("/") +HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech").rstrip("/") +HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "") + +_TIER_ORDER = {"free": 0, "paid": 1, "premium": 2, "ultra": 3} + + +def _extract_session_token(cookie_header: str) -> str: + m = re.search(r"(?:^|;)\s*cf_session=([^;]+)", cookie_header) + return m.group(1).strip() if m else "" + + +def _get_user_id(jwt: str) -> str | None: + try: + resp = httpx.get( + f"{DIRECTUS_URL}/users/me", + headers={"Authorization": f"Bearer {jwt}"}, + timeout=5.0, + ) + if resp.status_code == 200: + return resp.json().get("data", {}).get("id") + except Exception as exc: + log.warning("Directus session check failed: %s", exc) + return None + + +def _ensure_provisioned(user_id: str) -> None: + if not HEIMDALL_ADMIN_TOKEN: + return + try: + httpx.post( + f"{HEIMDALL_URL}/admin/provision", + json={"directus_user_id": user_id, "product": "pagepiper", "tier": "free"}, + headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}, + timeout=5.0, + ) + except Exception as exc: + log.warning("Heimdall provision failed for user %s: %s", user_id, exc) + + +def _get_tier(user_id: str) -> str: + if not HEIMDALL_ADMIN_TOKEN: + return "free" + try: + resp = httpx.get( + f"{HEIMDALL_URL}/admin/cloud/resolve", + params={"directus_user_id": user_id, "product": "pagepiper"}, + headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}, + timeout=5.0, + ) + if resp.status_code == 200: + return resp.json().get("tier", "free") + except Exception as exc: + log.warning("Heimdall tier check failed for user %s: %s", user_id, exc) + return "free" + + +def resolve_authenticated_user(request: Request) -> str: + """Validate the session cookie and return the Directus user_id. Raises 401 if invalid.""" + cookie_header = request.headers.get("x-cf-session", "") + jwt = _extract_session_token(cookie_header) + + if not jwt: + raise HTTPException( + status_code=401, + detail={ + "error": "auth_required", + "message": "Sign in at circuitforge.tech to use Pagepiper cloud.", + }, + ) + + user_id = _get_user_id(jwt) + if not user_id: + raise HTTPException( + status_code=401, + detail={ + "error": "session_invalid", + "message": "Your session has expired. Sign in again at circuitforge.tech.", + }, + ) + + _ensure_provisioned(user_id) + return user_id + + +def require_paid_tier(request: Request) -> str: + """FastAPI dependency — 401 if no valid session, 402 if tier < paid. Returns user_id. + + In local mode (CLOUD_MODE not set), returns LOCAL_USER_ID without any auth check. + """ + if not CLOUD_MODE: + from app.config import LOCAL_USER_ID + return LOCAL_USER_ID + + user_id = resolve_authenticated_user(request) + tier = _get_tier(user_id) + + if _TIER_ORDER.get(tier, 0) < _TIER_ORDER["paid"]: + raise HTTPException( + status_code=402, + detail={ + "error": "upgrade_required", + "message": ( + "RAG chat requires a Paid tier Pagepiper license. " + "Upgrade at circuitforge.tech/software/pagepiper." + ), + }, + ) + + return user_id diff --git a/app/config.py b/app/config.py index b18568f..88c0a71 100644 --- a/app/config.py +++ b/app/config.py @@ -12,16 +12,36 @@ VEC_DB_PATH = str(DATA_DIR / "pagepiper_vecs.db") WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books")) VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024")) +LOCAL_USER_ID = "__local__" + + +def user_data_dir(user_id: str) -> Path: + """Return (and create) the per-user data directory under DATA_DIR/users/.""" + d = DATA_DIR / "users" / user_id + d.mkdir(parents=True, exist_ok=True) + return d + def get_llm_config() -> dict | None: - """Build LLMRouter config from env vars. Returns None if PAGEPIPER_OLLAMA_URL is unset.""" + """Build LLMRouter config from env vars. + + Returns None only when neither PAGEPIPER_OLLAMA_URL nor CF_ORCH_URL is set. + CF_ORCH_URL alone is sufficient — the coordinator resolves the service URL at + allocation time so PAGEPIPER_OLLAMA_URL becomes optional. + """ url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip() - if not url: + orch_url = os.environ.get("CF_ORCH_URL", "").strip() + + if not url and not orch_url: return None - _clean = url.rstrip("/") - _base_url = _clean if _clean.endswith("/v1") else _clean + "/v1" + chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b") + _base_url = "" + if url: + _clean = url.rstrip("/") + _base_url = _clean if _clean.endswith("/v1") else _clean + "/v1" + backend: dict = { "type": "openai_compat", "base_url": _base_url, @@ -30,12 +50,9 @@ def get_llm_config() -> dict | None: "supports_images": False, } - # Wire cf-orch allocation when coordinator is configured so the model stays warm - # and cold-start latency doesn't cause chat timeouts. - orch_url = os.environ.get("CF_ORCH_URL", "").strip() if orch_url: backend["cf_orch"] = { - "service": "ollama", + "service": os.environ.get("PAGEPIPER_ORCH_SERVICE", "ollama"), "model_candidates": [chat_model], "ttl_s": 3600, } diff --git a/app/deps.py b/app/deps.py index 78cf87c..4826f59 100644 --- a/app/deps.py +++ b/app/deps.py @@ -3,13 +3,75 @@ from __future__ import annotations import sqlite3 +from dataclasses import dataclass +from pathlib import Path from typing import Generator -from app.config import DB_PATH +from fastapi import Depends, Request + +from app.config import DATA_DIR, LOCAL_USER_ID +from app.services.bm25_index import BM25Index -def get_db() -> Generator[sqlite3.Connection, None, None]: - conn = sqlite3.connect(DB_PATH, check_same_thread=False) +@dataclass +class UserCtx: + """Per-request context routing DB paths and BM25 to the right user.""" + + user_id: str + db_path: str + vec_db_path: str + data_dir: Path + watch_dir: Path + bm25: BM25Index + + +_user_startup_done: set[str] = set() + + +def _run_user_startup(user_id: str, user_dir: Path) -> None: + """Run migrations and vec schema check once per process lifetime per user.""" + if user_id in _user_startup_done: + return + _user_startup_done.add(user_id) + from app.config import VEC_DIMENSIONS + from app.startup import apply_migrations, check_and_rebuild_vec_schema + apply_migrations(str(user_dir / "pagepiper.db")) + check_and_rebuild_vec_schema( + str(user_dir / "pagepiper_vecs.db"), VEC_DIMENSIONS, str(user_dir / "pagepiper.db") + ) + + +def get_user_ctx(request: Request) -> UserCtx: + """Resolve the per-user data directory, DB paths, and BM25 instance for this request.""" + import app.main as _main + from app.cloud_session import CLOUD_MODE + + if CLOUD_MODE: + from app.cloud_session import resolve_authenticated_user + from app.config import user_data_dir + user_id = resolve_authenticated_user(request) + user_dir = user_data_dir(user_id) + _run_user_startup(user_id, user_dir) + watch_dir = user_dir / "books" + watch_dir.mkdir(parents=True, exist_ok=True) + else: + from app.config import WATCH_DIR + user_id = LOCAL_USER_ID + user_dir = DATA_DIR + watch_dir = WATCH_DIR + + return UserCtx( + user_id=user_id, + db_path=str(user_dir / "pagepiper.db"), + vec_db_path=str(user_dir / "pagepiper_vecs.db"), + data_dir=user_dir, + watch_dir=watch_dir, + bm25=_main._get_bm25_for(user_id), + ) + + +def get_db(ctx: UserCtx = Depends(get_user_ctx)) -> Generator[sqlite3.Connection, None, None]: + conn = sqlite3.connect(ctx.db_path, check_same_thread=False) conn.execute("PRAGMA foreign_keys = ON") conn.execute("PRAGMA journal_mode = WAL") conn.row_factory = sqlite3.Row diff --git a/app/main.py b/app/main.py index e52f09a..1389472 100644 --- a/app/main.py +++ b/app/main.py @@ -4,9 +4,6 @@ from __future__ import annotations import logging import os -import re -import sqlite3 -import threading from contextlib import asynccontextmanager from fastapi import FastAPI @@ -16,110 +13,36 @@ from app.services.bm25_index import BM25Index logger = logging.getLogger("pagepiper") -# Module-level BM25 singleton — shared across all requests -_bm25 = BM25Index() +# Per-user BM25 registry — keyed by user_id; "__local__" for single-user mode +_bm25_map: dict[str, BM25Index] = {} -def _apply_migrations() -> None: - from scripts.db_migrate import migrate - migrate(DB_PATH) - - -def _reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None: - """Re-run full ingest for a list of (doc_id, file_path) sequentially.""" - for doc_id, file_path in docs: - suffix = os.path.splitext(file_path)[1].lower() - try: - if suffix == ".epub": - from scripts.ingest_epub import run - else: - from scripts.ingest_pdf import run - logger.info("Auto re-embed: starting %s", os.path.basename(file_path)) - run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path) - except Exception as exc: - logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc) - - -def _check_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None: - """Drop the vec DB if its stored dimension doesn't match config, then queue re-embed. - - sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing - models requires dropping and recreating the whole file. Catches the mismatch at - startup rather than surfacing it as an obscure OperationalError mid-request. - """ - if not os.path.exists(vec_db_path): - return - try: - conn = sqlite3.connect(vec_db_path) - row = conn.execute( - "SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'" - ).fetchone() - conn.close() - except Exception as exc: - logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc) - return - - if not row: - return # table not yet created — first embed will build it with the right dims - - m = re.search(r'float\[(\d+)\]', row[0]) - if not m: - return - actual_dims = int(m.group(1)) - if actual_dims == expected_dims: - return - - logger.warning( - "Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed", - actual_dims, expected_dims, vec_db_path, - ) - try: - os.remove(vec_db_path) - except OSError as exc: - logger.error( - "Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc - ) - return - - # Collect all ready docs so we can rebuild their embeddings in the background. - try: - conn = sqlite3.connect(db_path) - docs = conn.execute( - "SELECT id, file_path FROM documents WHERE status='ready'" - ).fetchall() - conn.close() - except Exception as exc: - logger.warning("Could not query documents for re-embed: %s", exc) - return - - if not docs: - return - - logger.info("Queuing re-embed for %d document(s) in background", len(docs)) - threading.Thread( - target=_reembed_docs, - args=(docs, db_path, vec_db_path), - daemon=True, - name="pagepiper-reembed", - ).start() +def _get_bm25_for(user_id: str) -> BM25Index: + if user_id not in _bm25_map: + _bm25_map[user_id] = BM25Index() + return _bm25_map[user_id] @asynccontextmanager async def lifespan(app: FastAPI): - _apply_migrations() + from app.cloud_session import CLOUD_MODE + from app.config import LOCAL_USER_ID + from app.startup import apply_migrations, check_and_rebuild_vec_schema + embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text") logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS) - _check_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH) - _bm25.mark_dirty() # will rebuild on first search + + if not CLOUD_MODE: + # In cloud mode, per-user migration and vec schema check run on first request (deps.py). + apply_migrations(DB_PATH) + check_and_rebuild_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH) + _get_bm25_for(LOCAL_USER_ID).mark_dirty() + yield app = FastAPI(title="Pagepiper", lifespan=lifespan) -# Wire BM25 dirty callback into library router -from app.api import library as _lib_module # noqa: E402 -_lib_module._mark_bm25_dirty = _bm25.mark_dirty - # Register routers from app.api.library import router as library_router # noqa: E402 from app.api.ingest import router as ingest_router # noqa: E402 diff --git a/app/startup.py b/app/startup.py new file mode 100644 index 0000000..1a01f18 --- /dev/null +++ b/app/startup.py @@ -0,0 +1,95 @@ +# app/startup.py +"""DB migration and vec schema check utilities — called at startup and on first user request.""" +from __future__ import annotations + +import logging +import os +import re +import sqlite3 +import threading + +logger = logging.getLogger("pagepiper") + + +def apply_migrations(db_path: str) -> None: + from scripts.db_migrate import migrate + migrate(db_path) + + +def reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None: + for doc_id, file_path in docs: + suffix = os.path.splitext(file_path)[1].lower() + try: + if suffix == ".epub": + from scripts.ingest_epub import run + elif suffix == ".docx": + from scripts.ingest_docx import run + else: + from scripts.ingest_pdf import run + logger.info("Auto re-embed: starting %s", os.path.basename(file_path)) + run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path) + except Exception as exc: + logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc) + + +def check_and_rebuild_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None: + """Drop the vec DB if its stored dimension doesn't match config, then queue re-embed. + + sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing + models requires dropping and recreating the whole file. Catches the mismatch at + startup rather than surfacing it as an obscure OperationalError mid-request. + """ + if not os.path.exists(vec_db_path): + return + try: + conn = sqlite3.connect(vec_db_path) + row = conn.execute( + "SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'" + ).fetchone() + conn.close() + except Exception as exc: + logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc) + return + + if not row: + return + + m = re.search(r'float\[(\d+)\]', row[0]) + if not m: + return + actual_dims = int(m.group(1)) + if actual_dims == expected_dims: + return + + logger.warning( + "Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed", + actual_dims, expected_dims, vec_db_path, + ) + try: + os.remove(vec_db_path) + except OSError as exc: + logger.error( + "Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc + ) + return + + try: + conn = sqlite3.connect(db_path) + docs = conn.execute( + "SELECT id, file_path FROM documents WHERE status='ready'" + ).fetchall() + conn.close() + except Exception as exc: + logger.warning("Could not query documents for re-embed: %s", exc) + return + + if not docs: + return + + logger.info("Queuing re-embed for %d document(s) in background", len(docs)) + threading.Thread( + target=reembed_docs, + args=(docs, db_path, vec_db_path), + daemon=True, + name="pagepiper-reembed", + ).start() diff --git a/tests/conftest.py b/tests/conftest.py index 91e1f12..65fe8de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,13 +27,32 @@ def client(test_db, tmp_path, monkeypatch): (tmp_path / "books").mkdir(exist_ok=True) import app.main as _main_module - from app.main import app, _bm25 - from app.deps import get_db + from app.config import LOCAL_USER_ID + from app.deps import UserCtx, get_db, get_user_ctx + from app.main import app + from app.services.bm25_index import BM25Index + from app.startup import apply_migrations, check_and_rebuild_vec_schema - # Suppress startup side effects — test_db fixture already applies the schema, - # and vec schema validation is tested separately in test_startup.py - monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None) - monkeypatch.setattr(_main_module, "_check_vec_schema", lambda *a, **kw: None) + monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None, raising=False) + monkeypatch.setattr( + "app.startup.apply_migrations", lambda *a, **kw: None + ) + monkeypatch.setattr( + "app.startup.check_and_rebuild_vec_schema", lambda *a, **kw: None + ) + + test_bm25 = BM25Index() + test_bm25.mark_dirty() + + def override_user_ctx(): + return UserCtx( + user_id=LOCAL_USER_ID, + db_path=test_db, + vec_db_path=str(tmp_path / "test_vecs.db"), + data_dir=Path(tmp_path), + watch_dir=Path(tmp_path) / "books", + bm25=test_bm25, + ) def override_db(): conn = sqlite3.connect(test_db) @@ -44,7 +63,7 @@ def client(test_db, tmp_path, monkeypatch): finally: conn.close() + app.dependency_overrides[get_user_ctx] = override_user_ctx app.dependency_overrides[get_db] = override_db - _bm25.mark_dirty() # clear any state from previous tests yield TestClient(app) app.dependency_overrides.clear() diff --git a/tests/test_search_api.py b/tests/test_search_api.py index 8f4f13a..c693c53 100644 --- a/tests/test_search_api.py +++ b/tests/test_search_api.py @@ -20,9 +20,7 @@ def _add_chunks(db_path: str, doc_id: str, chunks: list[dict]) -> None: conn.close() -def test_search_returns_results(client, test_db, monkeypatch): - import app.api.search as _search_mod - monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db) +def test_search_returns_results(client, test_db): # BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0). # Add a 3rd unrelated chunk so relevant terms score above zero. _add_chunks(test_db, "book-a", [ @@ -46,9 +44,7 @@ def test_search_empty_index_returns_empty(client): assert resp.json() == [] -def test_search_filters_by_doc_ids(client, test_db, monkeypatch): - import app.api.search as _search_mod - monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db) +def test_search_filters_by_doc_ids(client, test_db): # Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc. _add_chunks(test_db, "book-a", [ {"page_number": 1, "text": "Grapple rules for melee attacks."}, diff --git a/tests/test_startup.py b/tests/test_startup.py index 7eca538..e18f111 100644 --- a/tests/test_startup.py +++ b/tests/test_startup.py @@ -9,7 +9,8 @@ from unittest.mock import MagicMock, patch import pytest -from app.main import _check_vec_schema, _reembed_docs +from app.startup import check_and_rebuild_vec_schema as _check_vec_schema +from app.startup import reembed_docs as _reembed_docs def _make_vec_db(path: str, dims: int) -> None: