feat: per-user database isolation for cloud instances (closes #4)

Implements Option A from the issue design: each cloud user gets their own
data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db,
pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged.

Key changes:
- app/startup.py: extract apply_migrations, reembed_docs,
  check_and_rebuild_vec_schema out of main.py (no circular imports)
- app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper
- app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier
  now returns user_id (str) instead of None
- app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir,
  watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs
  migrations + vec schema check once per process per user
- app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id;
  add _get_bm25_for(); lifespan only runs startup checks in local mode
- app/api/library.py, search.py, chat.py: thread UserCtx through all
  endpoints; remove module-level _mark_bm25_dirty injection pattern
- tests/conftest.py: override get_user_ctx in addition to get_db so all
  endpoints get a consistent test UserCtx
This commit is contained in:
pyr0ball 2026-05-13 16:31:51 -07:00
parent df9e91ad89
commit 8eef52a054
11 changed files with 432 additions and 199 deletions

View file

@ -10,9 +10,11 @@ from __future__ import annotations
import logging import logging
import os import os
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from app.cloud_session import require_paid_tier
from app.deps import UserCtx, get_user_ctx
from app.services.retriever import Retriever from app.services.retriever import Retriever
from app.services.synthesizer import Synthesizer from app.services.synthesizer import Synthesizer
@ -56,21 +58,6 @@ def _get_llm_router():
return LLMRouter(cfg) return LLMRouter(cfg)
def _get_db_path() -> str:
"""Read lazily so test fixtures take effect."""
import pathlib
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
return str(data_dir / "pagepiper.db")
def _get_vec_db_path() -> str:
import pathlib
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
return str(data_dir / "pagepiper_vecs.db")
def _require_llm(): def _require_llm():
"""Return LLMRouter or raise 402.""" """Return LLMRouter or raise 402."""
llm = _get_llm_router() llm = _get_llm_router()
@ -89,18 +76,20 @@ def _require_llm():
@router.post("") @router.post("")
def chat(req: ChatRequest) -> ChatResponse: def chat(
req: ChatRequest,
ctx: UserCtx = Depends(get_user_ctx),
_tier: str = Depends(require_paid_tier),
) -> ChatResponse:
llm = _require_llm() llm = _require_llm()
from app.main import _bm25 retriever = Retriever(ctx.bm25)
retriever = Retriever(_bm25)
chunks = retriever.hybrid_search( chunks = retriever.hybrid_search(
query=req.message, query=req.message,
top_k=req.top_k, top_k=req.top_k,
doc_ids=req.doc_ids, doc_ids=req.doc_ids,
db_path=_get_db_path(), db_path=ctx.db_path,
vec_db_path=_get_vec_db_path(), vec_db_path=ctx.vec_db_path,
llm=llm, llm=llm,
) )
@ -141,7 +130,10 @@ def chat_feedback_status() -> dict:
@router.post("/feedback") @router.post("/feedback")
def submit_chat_feedback(req: ChatFeedbackRequest) -> dict: def submit_chat_feedback(
req: ChatFeedbackRequest,
ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
import json import json
import sqlite3 import sqlite3
@ -149,8 +141,7 @@ def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
from fastapi import HTTPException from fastapi import HTTPException
raise HTTPException(status_code=422, detail="rating must be 1 or -1") raise HTTPException(status_code=422, detail="rating must be 1 or -1")
db_path = _get_db_path() con = sqlite3.connect(ctx.db_path)
con = sqlite3.connect(db_path)
try: try:
con.execute( con.execute(
"INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)", "INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)",

View file

@ -14,26 +14,25 @@ from typing import Callable
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile
from app.config import WATCH_DIR, DB_PATH, VEC_DB_PATH, DATA_DIR from app.config import VEC_DIMENSIONS
from app.deps import get_db from app.deps import UserCtx, get_db, get_user_ctx
_MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB _MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/library", tags=["library"]) router = APIRouter(prefix="/api/library", tags=["library"])
# Injected by main.py after _bm25 is created
_mark_bm25_dirty: Callable[[], None] | None = None
_INGEST_TASKS = { _INGEST_TASKS = {
".pdf": "pagepiper/ingest_pdf", ".pdf": "pagepiper/ingest_pdf",
".epub": "pagepiper/ingest_epub", ".epub": "pagepiper/ingest_epub",
".docx": "pagepiper/ingest_docx",
} }
_INGEST_RUNNERS = { _INGEST_RUNNERS = {
".pdf": "scripts.ingest_pdf", ".pdf": "scripts.ingest_pdf",
".epub": "scripts.ingest_epub", ".epub": "scripts.ingest_epub",
".docx": "scripts.ingest_docx",
} }
@ -41,24 +40,22 @@ def _dispatch_ingest(
doc_id: str, doc_id: str,
file_path: str, file_path: str,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
data_dir: Path,
mark_dirty_fn: Callable[[], None],
) -> str: ) -> str:
"""Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks.""" """Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks."""
import importlib import importlib
import os as _os
from pathlib import Path as _Path
suffix = _Path(file_path).suffix.lower() suffix = Path(file_path).suffix.lower()
task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf") task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf")
runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf") runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf")
# Read lazily so test fixtures (monkeypatch.setenv) take effect
_data_dir = _Path(_os.environ.get("PAGEPIPER_DATA_DIR", "data"))
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
args = { args = {
"doc_id": doc_id, "doc_id": doc_id,
"file_path": file_path, "file_path": file_path,
"db_path": str(_data_dir / "pagepiper.db"), "db_path": str(data_dir / "pagepiper.db"),
"vec_db_path": str(_data_dir / "pagepiper_vecs.db"), "vec_db_path": str(data_dir / "pagepiper_vecs.db"),
} }
try: try:
@ -67,7 +64,7 @@ def _dispatch_ingest(
logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id) logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id)
except Exception: except Exception:
mod = importlib.import_module(runner_module) mod = importlib.import_module(runner_module)
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id) background_tasks.add_task(_run_ingest_background, mod.run, args, task_id, mark_dirty_fn)
logger.info( logger.info(
"cf-orch unavailable — running ingest in background thread (task %s)", task_id "cf-orch unavailable — running ingest in background thread (task %s)", task_id
) )
@ -75,14 +72,19 @@ def _dispatch_ingest(
return task_id return task_id
def _run_ingest_background(run_fn: Callable[..., None], args: dict, task_id: str) -> None: def _run_ingest_background(
run_fn: Callable[..., None],
args: dict,
task_id: str,
mark_dirty_fn: Callable[[], None] | None = None,
) -> None:
from app.api.ingest import _task_registry from app.api.ingest import _task_registry
_task_registry[task_id] = {"status": "running", "progress": 0} _task_registry[task_id] = {"status": "running", "progress": 0}
try: try:
run_fn(**args) run_fn(**args)
_task_registry[task_id] = {"status": "complete", "progress": 100} _task_registry[task_id] = {"status": "complete", "progress": 100}
if _mark_bm25_dirty: if mark_dirty_fn:
_mark_bm25_dirty() mark_dirty_fn()
except Exception as exc: except Exception as exc:
logger.exception("Ingest task %s failed", task_id) logger.exception("Ingest task %s failed", task_id)
_task_registry[task_id] = {"status": "error", "error": str(exc)} _task_registry[task_id] = {"status": "error", "error": str(exc)}
@ -101,13 +103,18 @@ def list_library(db: sqlite3.Connection = Depends(get_db)) -> list[dict]:
def scan_library( def scan_library(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: sqlite3.Connection = Depends(get_db), db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict: ) -> dict:
"""Scan the watched directory and queue ingest for any new PDFs.""" """Scan the watched directory and queue ingest for any new PDFs."""
watch = WATCH_DIR watch = ctx.watch_dir
if not watch.exists(): if not watch.exists():
raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}") raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}")
pdfs = list(watch.glob("**/*.pdf")) + list(watch.glob("**/*.epub")) pdfs = (
list(watch.glob("**/*.pdf"))
+ list(watch.glob("**/*.epub"))
+ list(watch.glob("**/*.docx"))
)
queued = [] queued = []
for pdf_path in pdfs: for pdf_path in pdfs:
@ -117,7 +124,7 @@ def scan_library(
).fetchone() ).fetchone()
if existing and existing["status"] == "ready": if existing and existing["status"] == "ready":
continue # already indexed continue
if existing: if existing:
doc_id = existing["id"] doc_id = existing["id"]
@ -129,7 +136,9 @@ def scan_library(
).fetchone()[0] ).fetchone()[0]
db.commit() db.commit()
task_id = _dispatch_ingest(doc_id, path_str, background_tasks) task_id = _dispatch_ingest(
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
)
db.execute( db.execute(
"UPDATE documents SET status='processing', task_id=? WHERE id=?", "UPDATE documents SET status='processing', task_id=? WHERE id=?",
[task_id, doc_id], [task_id, doc_id],
@ -145,12 +154,15 @@ def reingest_document(
doc_id: str, doc_id: str,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: sqlite3.Connection = Depends(get_db), db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict: ) -> dict:
row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone() row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone()
if not row: if not row:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
task_id = _dispatch_ingest(doc_id, row["file_path"], background_tasks) task_id = _dispatch_ingest(
doc_id, row["file_path"], background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
)
db.execute( db.execute(
"UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?", "UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?",
[task_id, doc_id], [task_id, doc_id],
@ -163,6 +175,7 @@ def reingest_document(
def delete_document( def delete_document(
doc_id: str, doc_id: str,
db: sqlite3.Connection = Depends(get_db), db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> None: ) -> None:
row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone() row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone()
if not row: if not row:
@ -171,23 +184,21 @@ def delete_document(
db.execute("DELETE FROM documents WHERE id=?", [doc_id]) db.execute("DELETE FROM documents WHERE id=?", [doc_id])
db.commit() db.commit()
# Remove embeddings from vector store
try: try:
from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import] from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import]
from app.config import VEC_DIMENSIONS store = LocalSQLiteVecStore(
store = LocalSQLiteVecStore(db_path=VEC_DB_PATH, table="page_vecs", dimensions=VEC_DIMENSIONS) db_path=ctx.vec_db_path, table="page_vecs", dimensions=VEC_DIMENSIONS
)
store.delete_where({"doc_id": doc_id}) store.delete_where({"doc_id": doc_id})
except Exception as exc: except Exception as exc:
logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc) logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc)
if _mark_bm25_dirty: ctx.bm25.mark_dirty()
_mark_bm25_dirty()
def _get_vec_count(doc_id: str) -> int: def _get_vec_count(doc_id: str, vec_db_path: str) -> int:
"""Return how many vectors have been stored for this doc. Returns 0 on any error."""
try: try:
conn = sqlite3.connect(VEC_DB_PATH) conn = sqlite3.connect(vec_db_path)
count = conn.execute( count = conn.execute(
"SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?", "SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?",
[doc_id], [doc_id],
@ -202,6 +213,7 @@ def _get_vec_count(doc_id: str) -> int:
def document_status( def document_status(
doc_id: str, doc_id: str,
db: sqlite3.Connection = Depends(get_db), db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict: ) -> dict:
row = db.execute( row = db.execute(
"SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?", "SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?",
@ -210,7 +222,7 @@ def document_status(
if not row: if not row:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
result = dict(row) result = dict(row)
result["vec_count"] = _get_vec_count(doc_id) result["vec_count"] = _get_vec_count(doc_id, ctx.vec_db_path)
return result return result
@ -219,18 +231,19 @@ def upload_document(
file: UploadFile, file: UploadFile,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: sqlite3.Connection = Depends(get_db), db: sqlite3.Connection = Depends(get_db),
ctx: UserCtx = Depends(get_user_ctx),
) -> dict: ) -> dict:
"""Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing.""" """Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing."""
name = Path(file.filename or "").name name = Path(file.filename or "").name
suffix = Path(name).suffix.lower() suffix = Path(name).suffix.lower()
if suffix not in _INGEST_TASKS: if suffix not in _INGEST_TASKS:
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB") raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB, DOCX")
content = file.file.read() content = file.file.read()
if len(content) > _MAX_UPLOAD_BYTES: if len(content) > _MAX_UPLOAD_BYTES:
raise HTTPException(status_code=413, detail="File exceeds 200 MB limit") raise HTTPException(status_code=413, detail="File exceeds 200 MB limit")
upload_dir = DATA_DIR / "uploads" upload_dir = ctx.data_dir / "uploads"
upload_dir.mkdir(parents=True, exist_ok=True) upload_dir.mkdir(parents=True, exist_ok=True)
dest = upload_dir / name dest = upload_dir / name
dest.write_bytes(content) dest.write_bytes(content)
@ -253,7 +266,9 @@ def upload_document(
).fetchone()[0] ).fetchone()[0]
db.commit() db.commit()
task_id = _dispatch_ingest(doc_id, path_str, background_tasks) task_id = _dispatch_ingest(
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
)
db.execute( db.execute(
"UPDATE documents SET status='processing', task_id=? WHERE id=?", "UPDATE documents SET status='processing', task_id=? WHERE id=?",
[task_id, doc_id], [task_id, doc_id],

View file

@ -7,13 +7,11 @@ MIT — no tier gate. No Ollama required.
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.services.bm25_index import BM25Index from app.deps import UserCtx, get_user_ctx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/search", tags=["search"]) router = APIRouter(prefix="/api/search", tags=["search"])
@ -29,32 +27,17 @@ class SearchResult(BaseModel):
chunk_id: str chunk_id: str
doc_id: str doc_id: str
page_number: int page_number: int
text_snippet: str # first 300 chars of the page text text_snippet: str
bm25_score: float bm25_score: float
def _get_bm25() -> BM25Index:
import app.main as _main
bm25 = getattr(_main, "_bm25", None)
if bm25 is None:
raise RuntimeError("BM25 index not initialised — app.main not loaded")
return bm25
def _get_db_path() -> str:
"""Read lazily so test fixtures (monkeypatch.setattr) take effect."""
import pathlib
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
return str(data_dir / "pagepiper.db")
@router.post("") @router.post("")
def search( def search(
req: SearchRequest, req: SearchRequest,
bm25: Annotated[BM25Index, Depends(_get_bm25)], ctx: UserCtx = Depends(get_user_ctx),
) -> list[SearchResult]: ) -> list[SearchResult]:
bm25.ensure_fresh(_get_db_path()) ctx.bm25.ensure_fresh(ctx.db_path)
hits = bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids) hits = ctx.bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
return [ return [
SearchResult( SearchResult(
chunk_id=h.chunk_id, chunk_id=h.chunk_id,

131
app/cloud_session.py Normal file
View file

@ -0,0 +1,131 @@
# app/cloud_session.py
"""Cloud session auth for Pagepiper — validates cf_session cookie via Directus + Heimdall.
In local mode (CLOUD_MODE unset or false), require_paid_tier is a no-op.
In cloud mode, the Caddy proxy forwards the browser's Cookie header as
X-CF-Session. This module extracts cf_session, validates it against
Directus /users/me, then checks the user's Pagepiper tier via Heimdall.
Auto-provisions a free tier key for new users.
"""
from __future__ import annotations
import logging
import os
import re
import httpx
from fastapi import HTTPException, Request
log = logging.getLogger(__name__)
# Cloud mode is on when CLOUD_MODE is "1", "true", or "yes" (case-insensitive).
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
# Base URLs of the Directus and Heimdall services; any trailing "/" is stripped
# so paths can be appended with a single leading slash.
DIRECTUS_URL: str = os.environ.get("DIRECTUS_URL", "http://172.31.0.3:8055").rstrip("/")
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech").rstrip("/")
# Bearer token for the Heimdall admin endpoints. When empty, provisioning is
# skipped and every user resolves to the "free" tier.
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
# Rank of each tier name; require_paid_tier compares ranks against "paid".
_TIER_ORDER = {"free": 0, "paid": 1, "premium": 2, "ultra": 3}
def _extract_session_token(cookie_header: str) -> str:
m = re.search(r"(?:^|;)\s*cf_session=([^;]+)", cookie_header)
return m.group(1).strip() if m else ""
def _get_user_id(jwt: str) -> str | None:
    """Look up the owner of *jwt* via Directus ``/users/me``.

    Returns the ``data.id`` field of the response on HTTP 200. Returns None
    for any other status, or when the request or response parsing raises
    (the exception is logged as a warning).
    """
    auth_headers = {"Authorization": f"Bearer {jwt}"}
    try:
        response = httpx.get(
            f"{DIRECTUS_URL}/users/me",
            headers=auth_headers,
            timeout=5.0,
        )
        if response.status_code != 200:
            return None
        return response.json().get("data", {}).get("id")
    except Exception as exc:
        log.warning("Directus session check failed: %s", exc)
        return None
def _ensure_provisioned(user_id: str) -> None:
    """Best-effort request asking Heimdall to provision a free Pagepiper tier for *user_id*.

    Does nothing when no admin token is configured. Request failures are
    logged and swallowed; the HTTP response itself is not inspected.
    """
    if HEIMDALL_ADMIN_TOKEN:
        payload = {"directus_user_id": user_id, "product": "pagepiper", "tier": "free"}
        auth_headers = {"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}
        try:
            httpx.post(
                f"{HEIMDALL_URL}/admin/provision",
                json=payload,
                headers=auth_headers,
                timeout=5.0,
            )
        except Exception as exc:
            log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
def _get_tier(user_id: str) -> str:
    """Return the Pagepiper tier Heimdall reports for *user_id*.

    Falls back to "free" when no admin token is configured, when Heimdall
    answers with a non-200 status, or when the lookup raises (logged as a
    warning).
    """
    fallback_tier = "free"
    if not HEIMDALL_ADMIN_TOKEN:
        return fallback_tier

    lookup = {"directus_user_id": user_id, "product": "pagepiper"}
    auth_headers = {"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}
    try:
        response = httpx.get(
            f"{HEIMDALL_URL}/admin/cloud/resolve",
            params=lookup,
            headers=auth_headers,
            timeout=5.0,
        )
        if response.status_code == 200:
            return response.json().get("tier", "free")
    except Exception as exc:
        log.warning("Heimdall tier check failed for user %s: %s", user_id, exc)
    return fallback_tier
def resolve_authenticated_user(request: Request) -> str:
    """Validate the session cookie and return the Directus user_id.

    The cookie is read from the ``x-cf-session`` request header. On success
    the user is also passed to ``_ensure_provisioned``.

    The resolved id is memoized on ``request.state`` so that several
    dependencies on one endpoint (the chat route depends on both
    ``get_user_ctx`` and ``require_paid_tier``, each of which calls this
    function) trigger only one Directus validation and one Heimdall
    provisioning call per request.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The Directus user id owning the session.

    Raises:
        HTTPException: 401 ``auth_required`` when no ``cf_session`` cookie is
            present; 401 ``session_invalid`` when Directus does not resolve
            the token to a user.
    """
    # request.state is scoped to this single request, so the memoized value
    # can never leak between users or outlive the session header it came from.
    state = getattr(request, "state", None)
    cached = getattr(state, "cf_user_id", None)
    if isinstance(cached, str) and cached:
        return cached

    cookie_header = request.headers.get("x-cf-session", "")
    jwt = _extract_session_token(cookie_header)
    if not jwt:
        raise HTTPException(
            status_code=401,
            detail={
                "error": "auth_required",
                "message": "Sign in at circuitforge.tech to use Pagepiper cloud.",
            },
        )

    user_id = _get_user_id(jwt)
    if not user_id:
        raise HTTPException(
            status_code=401,
            detail={
                "error": "session_invalid",
                "message": "Your session has expired. Sign in again at circuitforge.tech.",
            },
        )

    _ensure_provisioned(user_id)
    if state is not None:
        state.cf_user_id = user_id
    return user_id
def require_paid_tier(request: Request) -> str:
    """FastAPI dependency gating an endpoint on a paid (or higher) tier.

    Cloud mode: resolves the authenticated user (401 on a missing or invalid
    session), then raises 402 unless the user's tier ranks at least "paid".
    Local mode (CLOUD_MODE not set): returns LOCAL_USER_ID with no auth check.

    Returns the user_id in both modes.
    """
    if CLOUD_MODE:
        user_id = resolve_authenticated_user(request)
        # Unknown tier names rank as 0, i.e. the same as "free".
        rank = _TIER_ORDER.get(_get_tier(user_id), 0)
        if rank >= _TIER_ORDER["paid"]:
            return user_id
        upgrade_detail = {
            "error": "upgrade_required",
            "message": (
                "RAG chat requires a Paid tier Pagepiper license. "
                "Upgrade at circuitforge.tech/software/pagepiper."
            ),
        }
        raise HTTPException(status_code=402, detail=upgrade_detail)

    from app.config import LOCAL_USER_ID

    return LOCAL_USER_ID

View file

@ -12,16 +12,36 @@ VEC_DB_PATH = str(DATA_DIR / "pagepiper_vecs.db")
WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books")) WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books"))
VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024")) VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024"))
LOCAL_USER_ID = "__local__"
def user_data_dir(user_id: str) -> Path:
    """Ensure ``DATA_DIR/users/<user_id>`` exists on disk and return that path."""
    per_user_root = DATA_DIR.joinpath("users", user_id)
    per_user_root.mkdir(parents=True, exist_ok=True)
    return per_user_root
def get_llm_config() -> dict | None: def get_llm_config() -> dict | None:
"""Build LLMRouter config from env vars. Returns None if PAGEPIPER_OLLAMA_URL is unset.""" """Build LLMRouter config from env vars.
Returns None only when neither PAGEPIPER_OLLAMA_URL nor CF_ORCH_URL is set.
CF_ORCH_URL alone is sufficient — the coordinator resolves the service URL at
allocation time so PAGEPIPER_OLLAMA_URL becomes optional.
"""
url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip() url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip()
if not url: orch_url = os.environ.get("CF_ORCH_URL", "").strip()
if not url and not orch_url:
return None return None
_clean = url.rstrip("/")
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b") chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b")
_base_url = ""
if url:
_clean = url.rstrip("/")
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
backend: dict = { backend: dict = {
"type": "openai_compat", "type": "openai_compat",
"base_url": _base_url, "base_url": _base_url,
@ -30,12 +50,9 @@ def get_llm_config() -> dict | None:
"supports_images": False, "supports_images": False,
} }
# Wire cf-orch allocation when coordinator is configured so the model stays warm
# and cold-start latency doesn't cause chat timeouts.
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
if orch_url: if orch_url:
backend["cf_orch"] = { backend["cf_orch"] = {
"service": "ollama", "service": os.environ.get("PAGEPIPER_ORCH_SERVICE", "ollama"),
"model_candidates": [chat_model], "model_candidates": [chat_model],
"ttl_s": 3600, "ttl_s": 3600,
} }

View file

@ -3,13 +3,75 @@
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Generator from typing import Generator
from app.config import DB_PATH from fastapi import Depends, Request
from app.config import DATA_DIR, LOCAL_USER_ID
from app.services.bm25_index import BM25Index
def get_db() -> Generator[sqlite3.Connection, None, None]: @dataclass
conn = sqlite3.connect(DB_PATH, check_same_thread=False) class UserCtx:
"""Per-request context routing DB paths and BM25 to the right user."""
user_id: str
db_path: str
vec_db_path: str
data_dir: Path
watch_dir: Path
bm25: BM25Index
_user_startup_done: set[str] = set()
def _run_user_startup(user_id: str, user_dir: Path) -> None:
"""Run migrations and vec schema check once per process lifetime per user."""
if user_id in _user_startup_done:
return
_user_startup_done.add(user_id)
from app.config import VEC_DIMENSIONS
from app.startup import apply_migrations, check_and_rebuild_vec_schema
apply_migrations(str(user_dir / "pagepiper.db"))
check_and_rebuild_vec_schema(
str(user_dir / "pagepiper_vecs.db"), VEC_DIMENSIONS, str(user_dir / "pagepiper.db")
)
def get_user_ctx(request: Request) -> UserCtx:
    """Build the UserCtx (data dir, DB paths, watch dir, BM25 index) for *request*.

    Cloud mode authenticates the caller, uses their per-user directory, runs
    the once-per-user startup step, and creates a ``books`` watch directory
    inside it. Local mode uses DATA_DIR, WATCH_DIR and LOCAL_USER_ID.
    """
    import app.main as _main
    from app.cloud_session import CLOUD_MODE

    if not CLOUD_MODE:
        from app.config import WATCH_DIR

        user_id, user_dir, watch_dir = LOCAL_USER_ID, DATA_DIR, WATCH_DIR
    else:
        from app.cloud_session import resolve_authenticated_user
        from app.config import user_data_dir

        user_id = resolve_authenticated_user(request)
        user_dir = user_data_dir(user_id)
        _run_user_startup(user_id, user_dir)
        watch_dir = user_dir / "books"
        watch_dir.mkdir(parents=True, exist_ok=True)

    main_db_file = user_dir / "pagepiper.db"
    vec_db_file = user_dir / "pagepiper_vecs.db"
    return UserCtx(
        user_id=user_id,
        db_path=str(main_db_file),
        vec_db_path=str(vec_db_file),
        data_dir=user_dir,
        watch_dir=watch_dir,
        bm25=_main._get_bm25_for(user_id),
    )
def get_db(ctx: UserCtx = Depends(get_user_ctx)) -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(ctx.db_path, check_same_thread=False)
conn.execute("PRAGMA foreign_keys = ON") conn.execute("PRAGMA foreign_keys = ON")
conn.execute("PRAGMA journal_mode = WAL") conn.execute("PRAGMA journal_mode = WAL")
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row

View file

@ -4,9 +4,6 @@ from __future__ import annotations
import logging import logging
import os import os
import re
import sqlite3
import threading
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
@ -16,110 +13,36 @@ from app.services.bm25_index import BM25Index
logger = logging.getLogger("pagepiper") logger = logging.getLogger("pagepiper")
# Module-level BM25 singleton — shared across all requests # Per-user BM25 registry — keyed by user_id; "__local__" for single-user mode
_bm25 = BM25Index() _bm25_map: dict[str, BM25Index] = {}
def _apply_migrations() -> None: def _get_bm25_for(user_id: str) -> BM25Index:
from scripts.db_migrate import migrate if user_id not in _bm25_map:
migrate(DB_PATH) _bm25_map[user_id] = BM25Index()
return _bm25_map[user_id]
def _reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
"""Re-run full ingest for a list of (doc_id, file_path) sequentially."""
for doc_id, file_path in docs:
suffix = os.path.splitext(file_path)[1].lower()
try:
if suffix == ".epub":
from scripts.ingest_epub import run
else:
from scripts.ingest_pdf import run
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
except Exception as exc:
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
def _check_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
models requires dropping and recreating the whole file. Catches the mismatch at
startup rather than surfacing it as an obscure OperationalError mid-request.
"""
if not os.path.exists(vec_db_path):
return
try:
conn = sqlite3.connect(vec_db_path)
row = conn.execute(
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
).fetchone()
conn.close()
except Exception as exc:
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
return
if not row:
return # table not yet created — first embed will build it with the right dims
m = re.search(r'float\[(\d+)\]', row[0])
if not m:
return
actual_dims = int(m.group(1))
if actual_dims == expected_dims:
return
logger.warning(
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
actual_dims, expected_dims, vec_db_path,
)
try:
os.remove(vec_db_path)
except OSError as exc:
logger.error(
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
)
return
# Collect all ready docs so we can rebuild their embeddings in the background.
try:
conn = sqlite3.connect(db_path)
docs = conn.execute(
"SELECT id, file_path FROM documents WHERE status='ready'"
).fetchall()
conn.close()
except Exception as exc:
logger.warning("Could not query documents for re-embed: %s", exc)
return
if not docs:
return
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
threading.Thread(
target=_reembed_docs,
args=(docs, db_path, vec_db_path),
daemon=True,
name="pagepiper-reembed",
).start()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
_apply_migrations() from app.cloud_session import CLOUD_MODE
from app.config import LOCAL_USER_ID
from app.startup import apply_migrations, check_and_rebuild_vec_schema
embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text") embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text")
logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS) logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS)
_check_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
_bm25.mark_dirty() # will rebuild on first search if not CLOUD_MODE:
# In cloud mode, per-user migration and vec schema check run on first request (deps.py).
apply_migrations(DB_PATH)
check_and_rebuild_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
_get_bm25_for(LOCAL_USER_ID).mark_dirty()
yield yield
app = FastAPI(title="Pagepiper", lifespan=lifespan) app = FastAPI(title="Pagepiper", lifespan=lifespan)
# Wire BM25 dirty callback into library router
from app.api import library as _lib_module # noqa: E402
_lib_module._mark_bm25_dirty = _bm25.mark_dirty
# Register routers # Register routers
from app.api.library import router as library_router # noqa: E402 from app.api.library import router as library_router # noqa: E402
from app.api.ingest import router as ingest_router # noqa: E402 from app.api.ingest import router as ingest_router # noqa: E402

95
app/startup.py Normal file
View file

@ -0,0 +1,95 @@
# app/startup.py
"""DB migration and vec schema check utilities — called at startup and on first user request."""
from __future__ import annotations
import logging
import os
import re
import sqlite3
import threading
logger = logging.getLogger("pagepiper")
def apply_migrations(db_path: str) -> None:
    """Run the ``scripts.db_migrate`` migration routine against the database at *db_path*."""
    # Imported at call time rather than at module import time.
    import scripts.db_migrate as db_migrate

    db_migrate.migrate(db_path)
def reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
    """Re-run the full ingest for each ``(doc_id, file_path)`` pair, one after another.

    The ingest script is picked from the file extension (.epub, .docx,
    anything else is treated as PDF). A failure on one document is logged and
    does not stop the remaining documents.
    """
    for doc_id, file_path in docs:
        _, raw_ext = os.path.splitext(file_path)
        extension = raw_ext.lower()
        try:
            if extension == ".docx":
                from scripts.ingest_docx import run
            elif extension == ".epub":
                from scripts.ingest_epub import run
            else:
                from scripts.ingest_pdf import run
            logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
            run(
                doc_id=doc_id,
                file_path=file_path,
                db_path=db_path,
                vec_db_path=vec_db_path,
            )
        except Exception as exc:
            logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
def check_and_rebuild_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
    """Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.

    sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
    models requires dropping and recreating the whole file. Catches the mismatch at
    startup rather than surfacing it as an obscure OperationalError mid-request.

    Args:
        vec_db_path: Path of the vector database file to inspect (and delete on mismatch).
        expected_dims: Embedding dimension the current configuration expects.
        db_path: Path of the main database, queried for documents with status 'ready'.

    Every failure is logged and swallowed; this function never raises for an
    unreadable or missing database.
    """
    if not os.path.exists(vec_db_path):
        return

    try:
        conn = sqlite3.connect(vec_db_path)
        try:
            row = conn.execute(
                "SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
            ).fetchone()
        finally:
            # Close even when the query fails so the handle is not leaked.
            conn.close()
    except Exception as exc:
        logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
        return

    # sqlite_master.sql may be NULL for some schema entries; treat that like
    # "nothing to check" instead of letting re.search() raise on None.
    if not row or not row[0]:
        return
    m = re.search(r'float\[(\d+)\]', row[0])
    if not m:
        return
    actual_dims = int(m.group(1))
    if actual_dims == expected_dims:
        return

    logger.warning(
        "Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
        actual_dims, expected_dims, vec_db_path,
    )
    try:
        os.remove(vec_db_path)
    except OSError as exc:
        logger.error(
            "Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
        )
        return

    try:
        conn = sqlite3.connect(db_path)
        try:
            docs = conn.execute(
                "SELECT id, file_path FROM documents WHERE status='ready'"
            ).fetchall()
        finally:
            # Same as above: never leave the connection open on a failed query.
            conn.close()
    except Exception as exc:
        logger.warning("Could not query documents for re-embed: %s", exc)
        return

    if not docs:
        return

    logger.info("Queuing re-embed for %d document(s) in background", len(docs))
    # Daemon thread: the re-embed runs in the background and does not keep the
    # process alive at shutdown.
    threading.Thread(
        target=reembed_docs,
        args=(docs, db_path, vec_db_path),
        daemon=True,
        name="pagepiper-reembed",
    ).start()

View file

@ -27,13 +27,32 @@ def client(test_db, tmp_path, monkeypatch):
(tmp_path / "books").mkdir(exist_ok=True) (tmp_path / "books").mkdir(exist_ok=True)
import app.main as _main_module import app.main as _main_module
from app.main import app, _bm25 from app.config import LOCAL_USER_ID
from app.deps import get_db from app.deps import UserCtx, get_db, get_user_ctx
from app.main import app
from app.services.bm25_index import BM25Index
from app.startup import apply_migrations, check_and_rebuild_vec_schema
# Suppress startup side effects — test_db fixture already applies the schema, monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None, raising=False)
# and vec schema validation is tested separately in test_startup.py monkeypatch.setattr(
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None) "app.startup.apply_migrations", lambda *a, **kw: None
monkeypatch.setattr(_main_module, "_check_vec_schema", lambda *a, **kw: None) )
monkeypatch.setattr(
"app.startup.check_and_rebuild_vec_schema", lambda *a, **kw: None
)
test_bm25 = BM25Index()
test_bm25.mark_dirty()
def override_user_ctx():
return UserCtx(
user_id=LOCAL_USER_ID,
db_path=test_db,
vec_db_path=str(tmp_path / "test_vecs.db"),
data_dir=Path(tmp_path),
watch_dir=Path(tmp_path) / "books",
bm25=test_bm25,
)
def override_db(): def override_db():
conn = sqlite3.connect(test_db) conn = sqlite3.connect(test_db)
@ -44,7 +63,7 @@ def client(test_db, tmp_path, monkeypatch):
finally: finally:
conn.close() conn.close()
app.dependency_overrides[get_user_ctx] = override_user_ctx
app.dependency_overrides[get_db] = override_db app.dependency_overrides[get_db] = override_db
_bm25.mark_dirty() # clear any state from previous tests
yield TestClient(app) yield TestClient(app)
app.dependency_overrides.clear() app.dependency_overrides.clear()

View file

@ -20,9 +20,7 @@ def _add_chunks(db_path: str, doc_id: str, chunks: list[dict]) -> None:
conn.close() conn.close()
def test_search_returns_results(client, test_db, monkeypatch): def test_search_returns_results(client, test_db):
import app.api.search as _search_mod
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
# BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0). # BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0).
# Add a 3rd unrelated chunk so relevant terms score above zero. # Add a 3rd unrelated chunk so relevant terms score above zero.
_add_chunks(test_db, "book-a", [ _add_chunks(test_db, "book-a", [
@ -46,9 +44,7 @@ def test_search_empty_index_returns_empty(client):
assert resp.json() == [] assert resp.json() == []
def test_search_filters_by_doc_ids(client, test_db, monkeypatch): def test_search_filters_by_doc_ids(client, test_db):
import app.api.search as _search_mod
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
# Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc. # Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc.
_add_chunks(test_db, "book-a", [ _add_chunks(test_db, "book-a", [
{"page_number": 1, "text": "Grapple rules for melee attacks."}, {"page_number": 1, "text": "Grapple rules for melee attacks."},

View file

@ -9,7 +9,8 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from app.main import _check_vec_schema, _reembed_docs from app.startup import check_and_rebuild_vec_schema as _check_vec_schema
from app.startup import reembed_docs as _reembed_docs
def _make_vec_db(path: str, dims: int) -> None: def _make_vec_db(path: str, dims: int) -> None: