feat: per-user database isolation for cloud instances (closes #4)

Implements Option A from the issue design: each cloud user gets their own
data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db,
pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged.

Key changes:
- app/startup.py: extract apply_migrations, reembed_docs,
  check_and_rebuild_vec_schema out of main.py (no circular imports)
- app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper
- app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier
  now returns user_id (str) instead of None
- app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir,
  watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs
  migrations + vec schema check once per process per user
- app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id;
  add _get_bm25_for(); lifespan only runs startup checks in local mode
- app/api/library.py, search.py, chat.py: thread UserCtx through all
  endpoints; remove module-level _mark_bm25_dirty injection pattern
- tests/conftest.py: override get_user_ctx in addition to get_db so all
  endpoints get a consistent test UserCtx
This commit is contained in:
pyr0ball 2026-05-13 16:31:51 -07:00
parent df9e91ad89
commit 8eef52a054
11 changed files with 432 additions and 199 deletions

View file

@ -10,9 +10,11 @@ from __future__ import annotations
import logging
import os
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from app.cloud_session import require_paid_tier
from app.deps import UserCtx, get_user_ctx
from app.services.retriever import Retriever
from app.services.synthesizer import Synthesizer
@ -56,21 +58,6 @@ def _get_llm_router():
return LLMRouter(cfg)
def _get_db_path() -> str:
"""Read lazily so test fixtures take effect."""
import pathlib
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
return str(data_dir / "pagepiper.db")
def _get_vec_db_path() -> str:
import pathlib
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
return str(data_dir / "pagepiper_vecs.db")
def _require_llm():
"""Return LLMRouter or raise 402."""
llm = _get_llm_router()
@ -89,18 +76,20 @@ def _require_llm():
@router.post("")
def chat(req: ChatRequest) -> ChatResponse:
def chat(
req: ChatRequest,
ctx: UserCtx = Depends(get_user_ctx),
_tier: str = Depends(require_paid_tier),
) -> ChatResponse:
llm = _require_llm()
from app.main import _bm25
retriever = Retriever(_bm25)
retriever = Retriever(ctx.bm25)
chunks = retriever.hybrid_search(
query=req.message,
top_k=req.top_k,
doc_ids=req.doc_ids,
db_path=_get_db_path(),
vec_db_path=_get_vec_db_path(),
db_path=ctx.db_path,
vec_db_path=ctx.vec_db_path,
llm=llm,
)
@ -141,7 +130,10 @@ def chat_feedback_status() -> dict:
@router.post("/feedback")
def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
def submit_chat_feedback(
req: ChatFeedbackRequest,
ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
import json
import sqlite3
@ -149,8 +141,7 @@ def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
from fastapi import HTTPException
raise HTTPException(status_code=422, detail="rating must be 1 or -1")
db_path = _get_db_path()
con = sqlite3.connect(db_path)
con = sqlite3.connect(ctx.db_path)
try:
con.execute(
"INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)",

View file

@ -14,26 +14,25 @@ from typing import Callable
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile
from app.config import WATCH_DIR, DB_PATH, VEC_DB_PATH, DATA_DIR
from app.deps import get_db
from app.config import VEC_DIMENSIONS
from app.deps import UserCtx, get_db, get_user_ctx
_MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/library", tags=["library"])
# Injected by main.py after _bm25 is created
_mark_bm25_dirty: Callable[[], None] | None = None
_INGEST_TASKS = {
".pdf": "pagepiper/ingest_pdf",
".epub": "pagepiper/ingest_epub",
".docx": "pagepiper/ingest_docx",
}
_INGEST_RUNNERS = {
".pdf": "scripts.ingest_pdf",
".epub": "scripts.ingest_epub",
".docx": "scripts.ingest_docx",
}
@ -41,24 +40,22 @@ def _dispatch_ingest(
doc_id: str,
file_path: str,
background_tasks: BackgroundTasks,
data_dir: Path,
mark_dirty_fn: Callable[[], None],
) -> str:
"""Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks."""
import importlib
import os as _os
from pathlib import Path as _Path
suffix = _Path(file_path).suffix.lower()
suffix = Path(file_path).suffix.lower()
task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf")
runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf")
# Read lazily so test fixtures (monkeypatch.setenv) take effect
_data_dir = _Path(_os.environ.get("PAGEPIPER_DATA_DIR", "data"))
task_id = str(uuid.uuid4())
args = {
"doc_id": doc_id,
"file_path": file_path,
"db_path": str(_data_dir / "pagepiper.db"),
"vec_db_path": str(_data_dir / "pagepiper_vecs.db"),
"db_path": str(data_dir / "pagepiper.db"),
"vec_db_path": str(data_dir / "pagepiper_vecs.db"),
}
try:
@ -67,7 +64,7 @@ def _dispatch_ingest(
logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id)
except Exception:
mod = importlib.import_module(runner_module)
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id)
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id, mark_dirty_fn)
logger.info(
"cf-orch unavailable — running ingest in background thread (task %s)", task_id
)
@ -75,14 +72,19 @@ def _dispatch_ingest(
return task_id
def _run_ingest_background(run_fn: Callable[..., None], args: dict, task_id: str) -> None:
def _run_ingest_background(
run_fn: Callable[..., None],
args: dict,
task_id: str,
mark_dirty_fn: Callable[[], None] | None = None,
) -> None:
from app.api.ingest import _task_registry
_task_registry[task_id] = {"status": "running", "progress": 0}
try:
run_fn(**args)
_task_registry[task_id] = {"status": "complete", "progress": 100}
if _mark_bm25_dirty:
_mark_bm25_dirty()
if mark_dirty_fn:
mark_dirty_fn()
except Exception as exc:
logger.exception("Ingest task %s failed", task_id)
_task_registry[task_id] = {"status": "error", "error": str(exc)}
@ -101,13 +103,18 @@ def list_library(db: sqlite3.Connection = Depends(get_db)) -> list[dict]:
def scan_library(
background_tasks: BackgroundTasks,
db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
"""Scan the watched directory and queue ingest for any new PDFs."""
watch = WATCH_DIR
watch = ctx.watch_dir
if not watch.exists():
raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}")
pdfs = list(watch.glob("**/*.pdf")) + list(watch.glob("**/*.epub"))
pdfs = (
list(watch.glob("**/*.pdf"))
+ list(watch.glob("**/*.epub"))
+ list(watch.glob("**/*.docx"))
)
queued = []
for pdf_path in pdfs:
@ -117,7 +124,7 @@ def scan_library(
).fetchone()
if existing and existing["status"] == "ready":
continue # already indexed
continue
if existing:
doc_id = existing["id"]
@ -129,7 +136,9 @@ def scan_library(
).fetchone()[0]
db.commit()
task_id = _dispatch_ingest(doc_id, path_str, background_tasks)
task_id = _dispatch_ingest(
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
)
db.execute(
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
[task_id, doc_id],
@ -145,12 +154,15 @@ def reingest_document(
doc_id: str,
background_tasks: BackgroundTasks,
db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone()
if not row:
raise HTTPException(status_code=404, detail="Document not found")
task_id = _dispatch_ingest(doc_id, row["file_path"], background_tasks)
task_id = _dispatch_ingest(
doc_id, row["file_path"], background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
)
db.execute(
"UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?",
[task_id, doc_id],
@ -163,6 +175,7 @@ def reingest_document(
def delete_document(
doc_id: str,
db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> None:
row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone()
if not row:
@ -171,23 +184,21 @@ def delete_document(
db.execute("DELETE FROM documents WHERE id=?", [doc_id])
db.commit()
# Remove embeddings from vector store
try:
from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import]
from app.config import VEC_DIMENSIONS
store = LocalSQLiteVecStore(db_path=VEC_DB_PATH, table="page_vecs", dimensions=VEC_DIMENSIONS)
store = LocalSQLiteVecStore(
db_path=ctx.vec_db_path, table="page_vecs", dimensions=VEC_DIMENSIONS
)
store.delete_where({"doc_id": doc_id})
except Exception as exc:
logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc)
if _mark_bm25_dirty:
_mark_bm25_dirty()
ctx.bm25.mark_dirty()
def _get_vec_count(doc_id: str) -> int:
"""Return how many vectors have been stored for this doc. Returns 0 on any error."""
def _get_vec_count(doc_id: str, vec_db_path: str) -> int:
try:
conn = sqlite3.connect(VEC_DB_PATH)
conn = sqlite3.connect(vec_db_path)
count = conn.execute(
"SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?",
[doc_id],
@ -202,6 +213,7 @@ def _get_vec_count(doc_id: str) -> int:
def document_status(
doc_id: str,
db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
row = db.execute(
"SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?",
@ -210,7 +222,7 @@ def document_status(
if not row:
raise HTTPException(status_code=404, detail="Document not found")
result = dict(row)
result["vec_count"] = _get_vec_count(doc_id)
result["vec_count"] = _get_vec_count(doc_id, ctx.vec_db_path)
return result
@ -219,18 +231,19 @@ def upload_document(
file: UploadFile,
background_tasks: BackgroundTasks,
db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
"""Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing."""
name = Path(file.filename or "").name
suffix = Path(name).suffix.lower()
if suffix not in _INGEST_TASKS:
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB")
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB, DOCX")
content = file.file.read()
if len(content) > _MAX_UPLOAD_BYTES:
raise HTTPException(status_code=413, detail="File exceeds 200 MB limit")
upload_dir = DATA_DIR / "uploads"
upload_dir = ctx.data_dir / "uploads"
upload_dir.mkdir(parents=True, exist_ok=True)
dest = upload_dir / name
dest.write_bytes(content)
@ -253,7 +266,9 @@ def upload_document(
).fetchone()[0]
db.commit()
task_id = _dispatch_ingest(doc_id, path_str, background_tasks)
task_id = _dispatch_ingest(
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
)
db.execute(
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
[task_id, doc_id],

View file

@ -7,13 +7,11 @@ MIT — no tier gate. No Ollama required.
from __future__ import annotations
import logging
import os
from typing import Annotated
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from app.services.bm25_index import BM25Index
from app.deps import UserCtx, get_user_ctx
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/search", tags=["search"])
@ -29,32 +27,17 @@ class SearchResult(BaseModel):
chunk_id: str
doc_id: str
page_number: int
text_snippet: str # first 300 chars of the page text
text_snippet: str
bm25_score: float
def _get_bm25() -> BM25Index:
import app.main as _main
bm25 = getattr(_main, "_bm25", None)
if bm25 is None:
raise RuntimeError("BM25 index not initialised — app.main not loaded")
return bm25
def _get_db_path() -> str:
"""Read lazily so test fixtures (monkeypatch.setattr) take effect."""
import pathlib
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
return str(data_dir / "pagepiper.db")
@router.post("")
def search(
req: SearchRequest,
bm25: Annotated[BM25Index, Depends(_get_bm25)],
ctx: UserCtx = Depends(get_user_ctx),
) -> list[SearchResult]:
bm25.ensure_fresh(_get_db_path())
hits = bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
ctx.bm25.ensure_fresh(ctx.db_path)
hits = ctx.bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
return [
SearchResult(
chunk_id=h.chunk_id,

131
app/cloud_session.py Normal file
View file

@ -0,0 +1,131 @@
# app/cloud_session.py
"""Cloud session auth for Pagepiper — validates cf_session cookie via Directus + Heimdall.
In local mode (CLOUD_MODE unset or false), require_paid_tier is a no-op.
In cloud mode, the Caddy proxy forwards the browser's Cookie header as
X-CF-Session. This module extracts cf_session, validates it against
Directus /users/me, then checks the user's Pagepiper tier via Heimdall.
Auto-provisions a free tier key for new users.
"""
from __future__ import annotations
import logging
import os
import re
import httpx
from fastapi import HTTPException, Request
log = logging.getLogger(__name__)
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
DIRECTUS_URL: str = os.environ.get("DIRECTUS_URL", "http://172.31.0.3:8055").rstrip("/")
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech").rstrip("/")
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
_TIER_ORDER = {"free": 0, "paid": 1, "premium": 2, "ultra": 3}
def _extract_session_token(cookie_header: str) -> str:
m = re.search(r"(?:^|;)\s*cf_session=([^;]+)", cookie_header)
return m.group(1).strip() if m else ""
def _get_user_id(jwt: str) -> str | None:
try:
resp = httpx.get(
f"{DIRECTUS_URL}/users/me",
headers={"Authorization": f"Bearer {jwt}"},
timeout=5.0,
)
if resp.status_code == 200:
return resp.json().get("data", {}).get("id")
except Exception as exc:
log.warning("Directus session check failed: %s", exc)
return None
def _ensure_provisioned(user_id: str) -> None:
if not HEIMDALL_ADMIN_TOKEN:
return
try:
httpx.post(
f"{HEIMDALL_URL}/admin/provision",
json={"directus_user_id": user_id, "product": "pagepiper", "tier": "free"},
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
timeout=5.0,
)
except Exception as exc:
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
def _get_tier(user_id: str) -> str:
if not HEIMDALL_ADMIN_TOKEN:
return "free"
try:
resp = httpx.get(
f"{HEIMDALL_URL}/admin/cloud/resolve",
params={"directus_user_id": user_id, "product": "pagepiper"},
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
timeout=5.0,
)
if resp.status_code == 200:
return resp.json().get("tier", "free")
except Exception as exc:
log.warning("Heimdall tier check failed for user %s: %s", user_id, exc)
return "free"
def resolve_authenticated_user(request: Request) -> str:
"""Validate the session cookie and return the Directus user_id. Raises 401 if invalid."""
cookie_header = request.headers.get("x-cf-session", "")
jwt = _extract_session_token(cookie_header)
if not jwt:
raise HTTPException(
status_code=401,
detail={
"error": "auth_required",
"message": "Sign in at circuitforge.tech to use Pagepiper cloud.",
},
)
user_id = _get_user_id(jwt)
if not user_id:
raise HTTPException(
status_code=401,
detail={
"error": "session_invalid",
"message": "Your session has expired. Sign in again at circuitforge.tech.",
},
)
_ensure_provisioned(user_id)
return user_id
def require_paid_tier(request: Request) -> str:
"""FastAPI dependency — 401 if no valid session, 402 if tier < paid. Returns user_id.
In local mode (CLOUD_MODE not set), returns LOCAL_USER_ID without any auth check.
"""
if not CLOUD_MODE:
from app.config import LOCAL_USER_ID
return LOCAL_USER_ID
user_id = resolve_authenticated_user(request)
tier = _get_tier(user_id)
if _TIER_ORDER.get(tier, 0) < _TIER_ORDER["paid"]:
raise HTTPException(
status_code=402,
detail={
"error": "upgrade_required",
"message": (
"RAG chat requires a Paid tier Pagepiper license. "
"Upgrade at circuitforge.tech/software/pagepiper."
),
},
)
return user_id

View file

@ -12,16 +12,36 @@ VEC_DB_PATH = str(DATA_DIR / "pagepiper_vecs.db")
WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books"))
VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024"))
LOCAL_USER_ID = "__local__"
def user_data_dir(user_id: str) -> Path:
"""Return (and create) the per-user data directory under DATA_DIR/users/."""
d = DATA_DIR / "users" / user_id
d.mkdir(parents=True, exist_ok=True)
return d
def get_llm_config() -> dict | None:
"""Build LLMRouter config from env vars. Returns None if PAGEPIPER_OLLAMA_URL is unset."""
"""Build LLMRouter config from env vars.
Returns None only when neither PAGEPIPER_OLLAMA_URL nor CF_ORCH_URL is set.
CF_ORCH_URL alone is sufficient the coordinator resolves the service URL at
allocation time so PAGEPIPER_OLLAMA_URL becomes optional.
"""
url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip()
if not url:
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
if not url and not orch_url:
return None
_clean = url.rstrip("/")
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b")
_base_url = ""
if url:
_clean = url.rstrip("/")
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
backend: dict = {
"type": "openai_compat",
"base_url": _base_url,
@ -30,12 +50,9 @@ def get_llm_config() -> dict | None:
"supports_images": False,
}
# Wire cf-orch allocation when coordinator is configured so the model stays warm
# and cold-start latency doesn't cause chat timeouts.
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
if orch_url:
backend["cf_orch"] = {
"service": "ollama",
"service": os.environ.get("PAGEPIPER_ORCH_SERVICE", "ollama"),
"model_candidates": [chat_model],
"ttl_s": 3600,
}

View file

@ -3,13 +3,75 @@
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Generator
from app.config import DB_PATH
from fastapi import Depends, Request
from app.config import DATA_DIR, LOCAL_USER_ID
from app.services.bm25_index import BM25Index
def get_db() -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
@dataclass
class UserCtx:
"""Per-request context routing DB paths and BM25 to the right user."""
user_id: str
db_path: str
vec_db_path: str
data_dir: Path
watch_dir: Path
bm25: BM25Index
_user_startup_done: set[str] = set()
def _run_user_startup(user_id: str, user_dir: Path) -> None:
"""Run migrations and vec schema check once per process lifetime per user."""
if user_id in _user_startup_done:
return
_user_startup_done.add(user_id)
from app.config import VEC_DIMENSIONS
from app.startup import apply_migrations, check_and_rebuild_vec_schema
apply_migrations(str(user_dir / "pagepiper.db"))
check_and_rebuild_vec_schema(
str(user_dir / "pagepiper_vecs.db"), VEC_DIMENSIONS, str(user_dir / "pagepiper.db")
)
def get_user_ctx(request: Request) -> UserCtx:
"""Resolve the per-user data directory, DB paths, and BM25 instance for this request."""
import app.main as _main
from app.cloud_session import CLOUD_MODE
if CLOUD_MODE:
from app.cloud_session import resolve_authenticated_user
from app.config import user_data_dir
user_id = resolve_authenticated_user(request)
user_dir = user_data_dir(user_id)
_run_user_startup(user_id, user_dir)
watch_dir = user_dir / "books"
watch_dir.mkdir(parents=True, exist_ok=True)
else:
from app.config import WATCH_DIR
user_id = LOCAL_USER_ID
user_dir = DATA_DIR
watch_dir = WATCH_DIR
return UserCtx(
user_id=user_id,
db_path=str(user_dir / "pagepiper.db"),
vec_db_path=str(user_dir / "pagepiper_vecs.db"),
data_dir=user_dir,
watch_dir=watch_dir,
bm25=_main._get_bm25_for(user_id),
)
def get_db(ctx: UserCtx = Depends(get_user_ctx)) -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(ctx.db_path, check_same_thread=False)
conn.execute("PRAGMA foreign_keys = ON")
conn.execute("PRAGMA journal_mode = WAL")
conn.row_factory = sqlite3.Row

View file

@ -4,9 +4,6 @@ from __future__ import annotations
import logging
import os
import re
import sqlite3
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI
@ -16,110 +13,36 @@ from app.services.bm25_index import BM25Index
logger = logging.getLogger("pagepiper")
# Module-level BM25 singleton — shared across all requests
_bm25 = BM25Index()
# Per-user BM25 registry — keyed by user_id; "__local__" for single-user mode
_bm25_map: dict[str, BM25Index] = {}
def _apply_migrations() -> None:
from scripts.db_migrate import migrate
migrate(DB_PATH)
def _reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
"""Re-run full ingest for a list of (doc_id, file_path) sequentially."""
for doc_id, file_path in docs:
suffix = os.path.splitext(file_path)[1].lower()
try:
if suffix == ".epub":
from scripts.ingest_epub import run
else:
from scripts.ingest_pdf import run
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
except Exception as exc:
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
def _check_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
models requires dropping and recreating the whole file. Catches the mismatch at
startup rather than surfacing it as an obscure OperationalError mid-request.
"""
if not os.path.exists(vec_db_path):
return
try:
conn = sqlite3.connect(vec_db_path)
row = conn.execute(
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
).fetchone()
conn.close()
except Exception as exc:
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
return
if not row:
return # table not yet created — first embed will build it with the right dims
m = re.search(r'float\[(\d+)\]', row[0])
if not m:
return
actual_dims = int(m.group(1))
if actual_dims == expected_dims:
return
logger.warning(
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
actual_dims, expected_dims, vec_db_path,
)
try:
os.remove(vec_db_path)
except OSError as exc:
logger.error(
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
)
return
# Collect all ready docs so we can rebuild their embeddings in the background.
try:
conn = sqlite3.connect(db_path)
docs = conn.execute(
"SELECT id, file_path FROM documents WHERE status='ready'"
).fetchall()
conn.close()
except Exception as exc:
logger.warning("Could not query documents for re-embed: %s", exc)
return
if not docs:
return
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
threading.Thread(
target=_reembed_docs,
args=(docs, db_path, vec_db_path),
daemon=True,
name="pagepiper-reembed",
).start()
def _get_bm25_for(user_id: str) -> BM25Index:
if user_id not in _bm25_map:
_bm25_map[user_id] = BM25Index()
return _bm25_map[user_id]
@asynccontextmanager
async def lifespan(app: FastAPI):
_apply_migrations()
from app.cloud_session import CLOUD_MODE
from app.config import LOCAL_USER_ID
from app.startup import apply_migrations, check_and_rebuild_vec_schema
embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text")
logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS)
_check_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
_bm25.mark_dirty() # will rebuild on first search
if not CLOUD_MODE:
# In cloud mode, per-user migration and vec schema check run on first request (deps.py).
apply_migrations(DB_PATH)
check_and_rebuild_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
_get_bm25_for(LOCAL_USER_ID).mark_dirty()
yield
app = FastAPI(title="Pagepiper", lifespan=lifespan)
# Wire BM25 dirty callback into library router
from app.api import library as _lib_module # noqa: E402
_lib_module._mark_bm25_dirty = _bm25.mark_dirty
# Register routers
from app.api.library import router as library_router # noqa: E402
from app.api.ingest import router as ingest_router # noqa: E402

95
app/startup.py Normal file
View file

@ -0,0 +1,95 @@
# app/startup.py
"""DB migration and vec schema check utilities — called at startup and on first user request."""
from __future__ import annotations
import logging
import os
import re
import sqlite3
import threading
logger = logging.getLogger("pagepiper")
def apply_migrations(db_path: str) -> None:
from scripts.db_migrate import migrate
migrate(db_path)
def reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
for doc_id, file_path in docs:
suffix = os.path.splitext(file_path)[1].lower()
try:
if suffix == ".epub":
from scripts.ingest_epub import run
elif suffix == ".docx":
from scripts.ingest_docx import run
else:
from scripts.ingest_pdf import run
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
except Exception as exc:
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
def check_and_rebuild_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
models requires dropping and recreating the whole file. Catches the mismatch at
startup rather than surfacing it as an obscure OperationalError mid-request.
"""
if not os.path.exists(vec_db_path):
return
try:
conn = sqlite3.connect(vec_db_path)
row = conn.execute(
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
).fetchone()
conn.close()
except Exception as exc:
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
return
if not row:
return
m = re.search(r'float\[(\d+)\]', row[0])
if not m:
return
actual_dims = int(m.group(1))
if actual_dims == expected_dims:
return
logger.warning(
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
actual_dims, expected_dims, vec_db_path,
)
try:
os.remove(vec_db_path)
except OSError as exc:
logger.error(
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
)
return
try:
conn = sqlite3.connect(db_path)
docs = conn.execute(
"SELECT id, file_path FROM documents WHERE status='ready'"
).fetchall()
conn.close()
except Exception as exc:
logger.warning("Could not query documents for re-embed: %s", exc)
return
if not docs:
return
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
threading.Thread(
target=reembed_docs,
args=(docs, db_path, vec_db_path),
daemon=True,
name="pagepiper-reembed",
).start()

View file

@ -27,13 +27,32 @@ def client(test_db, tmp_path, monkeypatch):
(tmp_path / "books").mkdir(exist_ok=True)
import app.main as _main_module
from app.main import app, _bm25
from app.deps import get_db
from app.config import LOCAL_USER_ID
from app.deps import UserCtx, get_db, get_user_ctx
from app.main import app
from app.services.bm25_index import BM25Index
from app.startup import apply_migrations, check_and_rebuild_vec_schema
# Suppress startup side effects — test_db fixture already applies the schema,
# and vec schema validation is tested separately in test_startup.py
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None)
monkeypatch.setattr(_main_module, "_check_vec_schema", lambda *a, **kw: None)
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None, raising=False)
monkeypatch.setattr(
"app.startup.apply_migrations", lambda *a, **kw: None
)
monkeypatch.setattr(
"app.startup.check_and_rebuild_vec_schema", lambda *a, **kw: None
)
test_bm25 = BM25Index()
test_bm25.mark_dirty()
def override_user_ctx():
return UserCtx(
user_id=LOCAL_USER_ID,
db_path=test_db,
vec_db_path=str(tmp_path / "test_vecs.db"),
data_dir=Path(tmp_path),
watch_dir=Path(tmp_path) / "books",
bm25=test_bm25,
)
def override_db():
conn = sqlite3.connect(test_db)
@ -44,7 +63,7 @@ def client(test_db, tmp_path, monkeypatch):
finally:
conn.close()
app.dependency_overrides[get_user_ctx] = override_user_ctx
app.dependency_overrides[get_db] = override_db
_bm25.mark_dirty() # clear any state from previous tests
yield TestClient(app)
app.dependency_overrides.clear()

View file

@ -20,9 +20,7 @@ def _add_chunks(db_path: str, doc_id: str, chunks: list[dict]) -> None:
conn.close()
def test_search_returns_results(client, test_db, monkeypatch):
import app.api.search as _search_mod
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
def test_search_returns_results(client, test_db):
# BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0).
# Add a 3rd unrelated chunk so relevant terms score above zero.
_add_chunks(test_db, "book-a", [
@ -46,9 +44,7 @@ def test_search_empty_index_returns_empty(client):
assert resp.json() == []
def test_search_filters_by_doc_ids(client, test_db, monkeypatch):
import app.api.search as _search_mod
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
def test_search_filters_by_doc_ids(client, test_db):
# Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc.
_add_chunks(test_db, "book-a", [
{"page_number": 1, "text": "Grapple rules for melee attacks."},

View file

@ -9,7 +9,8 @@ from unittest.mock import MagicMock, patch
import pytest
from app.main import _check_vec_schema, _reembed_docs
from app.startup import check_and_rebuild_vec_schema as _check_vec_schema
from app.startup import reembed_docs as _reembed_docs
def _make_vec_db(path: str, dims: int) -> None: