feat: per-user database isolation for cloud instances (closes #4)
Implements Option A from the issue design: each cloud user gets their own
data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db,
pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged.
Key changes:
- app/startup.py: extract apply_migrations, reembed_docs,
check_and_rebuild_vec_schema out of main.py (no circular imports)
- app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper
- app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier
now returns user_id (str) instead of None
- app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir,
watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs
migrations + vec schema check once per process per user
- app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id;
add _get_bm25_for(); lifespan only runs startup checks in local mode
- app/api/library.py, search.py, chat.py: thread UserCtx through all
endpoints; remove module-level _mark_bm25_dirty injection pattern
- tests/conftest.py: override get_user_ctx in addition to get_db so all
endpoints get a consistent test UserCtx
This commit is contained in:
parent
df9e91ad89
commit
8eef52a054
11 changed files with 432 additions and 199 deletions
|
|
@ -10,9 +10,11 @@ from __future__ import annotations
|
|||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.cloud_session import require_paid_tier
|
||||
from app.deps import UserCtx, get_user_ctx
|
||||
from app.services.retriever import Retriever
|
||||
from app.services.synthesizer import Synthesizer
|
||||
|
||||
|
|
@ -56,21 +58,6 @@ def _get_llm_router():
|
|||
return LLMRouter(cfg)
|
||||
|
||||
|
||||
def _get_db_path() -> str:
|
||||
"""Read lazily so test fixtures take effect."""
|
||||
import pathlib
|
||||
|
||||
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||
return str(data_dir / "pagepiper.db")
|
||||
|
||||
|
||||
def _get_vec_db_path() -> str:
|
||||
import pathlib
|
||||
|
||||
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||
return str(data_dir / "pagepiper_vecs.db")
|
||||
|
||||
|
||||
def _require_llm():
|
||||
"""Return LLMRouter or raise 402."""
|
||||
llm = _get_llm_router()
|
||||
|
|
@ -89,18 +76,20 @@ def _require_llm():
|
|||
|
||||
|
||||
@router.post("")
|
||||
def chat(req: ChatRequest) -> ChatResponse:
|
||||
def chat(
|
||||
req: ChatRequest,
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
_tier: str = Depends(require_paid_tier),
|
||||
) -> ChatResponse:
|
||||
llm = _require_llm()
|
||||
|
||||
from app.main import _bm25
|
||||
|
||||
retriever = Retriever(_bm25)
|
||||
retriever = Retriever(ctx.bm25)
|
||||
chunks = retriever.hybrid_search(
|
||||
query=req.message,
|
||||
top_k=req.top_k,
|
||||
doc_ids=req.doc_ids,
|
||||
db_path=_get_db_path(),
|
||||
vec_db_path=_get_vec_db_path(),
|
||||
db_path=ctx.db_path,
|
||||
vec_db_path=ctx.vec_db_path,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
|
|
@ -141,7 +130,10 @@ def chat_feedback_status() -> dict:
|
|||
|
||||
|
||||
@router.post("/feedback")
|
||||
def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
|
||||
def submit_chat_feedback(
|
||||
req: ChatFeedbackRequest,
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> dict:
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
|
|
@ -149,8 +141,7 @@ def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
|
|||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=422, detail="rating must be 1 or -1")
|
||||
|
||||
db_path = _get_db_path()
|
||||
con = sqlite3.connect(db_path)
|
||||
con = sqlite3.connect(ctx.db_path)
|
||||
try:
|
||||
con.execute(
|
||||
"INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)",
|
||||
|
|
|
|||
|
|
@ -14,26 +14,25 @@ from typing import Callable
|
|||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile
|
||||
|
||||
from app.config import WATCH_DIR, DB_PATH, VEC_DB_PATH, DATA_DIR
|
||||
from app.deps import get_db
|
||||
from app.config import VEC_DIMENSIONS
|
||||
from app.deps import UserCtx, get_db, get_user_ctx
|
||||
|
||||
_MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/library", tags=["library"])
|
||||
|
||||
# Injected by main.py after _bm25 is created
|
||||
_mark_bm25_dirty: Callable[[], None] | None = None
|
||||
|
||||
|
||||
_INGEST_TASKS = {
|
||||
".pdf": "pagepiper/ingest_pdf",
|
||||
".epub": "pagepiper/ingest_epub",
|
||||
".docx": "pagepiper/ingest_docx",
|
||||
}
|
||||
|
||||
_INGEST_RUNNERS = {
|
||||
".pdf": "scripts.ingest_pdf",
|
||||
".epub": "scripts.ingest_epub",
|
||||
".docx": "scripts.ingest_docx",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -41,24 +40,22 @@ def _dispatch_ingest(
|
|||
doc_id: str,
|
||||
file_path: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
data_dir: Path,
|
||||
mark_dirty_fn: Callable[[], None],
|
||||
) -> str:
|
||||
"""Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks."""
|
||||
import importlib
|
||||
import os as _os
|
||||
from pathlib import Path as _Path
|
||||
|
||||
suffix = _Path(file_path).suffix.lower()
|
||||
suffix = Path(file_path).suffix.lower()
|
||||
task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf")
|
||||
runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf")
|
||||
|
||||
# Read lazily so test fixtures (monkeypatch.setenv) take effect
|
||||
_data_dir = _Path(_os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||
task_id = str(uuid.uuid4())
|
||||
args = {
|
||||
"doc_id": doc_id,
|
||||
"file_path": file_path,
|
||||
"db_path": str(_data_dir / "pagepiper.db"),
|
||||
"vec_db_path": str(_data_dir / "pagepiper_vecs.db"),
|
||||
"db_path": str(data_dir / "pagepiper.db"),
|
||||
"vec_db_path": str(data_dir / "pagepiper_vecs.db"),
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -67,7 +64,7 @@ def _dispatch_ingest(
|
|||
logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id)
|
||||
except Exception:
|
||||
mod = importlib.import_module(runner_module)
|
||||
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id)
|
||||
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id, mark_dirty_fn)
|
||||
logger.info(
|
||||
"cf-orch unavailable — running ingest in background thread (task %s)", task_id
|
||||
)
|
||||
|
|
@ -75,14 +72,19 @@ def _dispatch_ingest(
|
|||
return task_id
|
||||
|
||||
|
||||
def _run_ingest_background(run_fn: Callable[..., None], args: dict, task_id: str) -> None:
|
||||
def _run_ingest_background(
|
||||
run_fn: Callable[..., None],
|
||||
args: dict,
|
||||
task_id: str,
|
||||
mark_dirty_fn: Callable[[], None] | None = None,
|
||||
) -> None:
|
||||
from app.api.ingest import _task_registry
|
||||
_task_registry[task_id] = {"status": "running", "progress": 0}
|
||||
try:
|
||||
run_fn(**args)
|
||||
_task_registry[task_id] = {"status": "complete", "progress": 100}
|
||||
if _mark_bm25_dirty:
|
||||
_mark_bm25_dirty()
|
||||
if mark_dirty_fn:
|
||||
mark_dirty_fn()
|
||||
except Exception as exc:
|
||||
logger.exception("Ingest task %s failed", task_id)
|
||||
_task_registry[task_id] = {"status": "error", "error": str(exc)}
|
||||
|
|
@ -101,13 +103,18 @@ def list_library(db: sqlite3.Connection = Depends(get_db)) -> list[dict]:
|
|||
def scan_library(
|
||||
background_tasks: BackgroundTasks,
|
||||
db: sqlite3.Connection = Depends(get_db),
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> dict:
|
||||
"""Scan the watched directory and queue ingest for any new PDFs."""
|
||||
watch = WATCH_DIR
|
||||
watch = ctx.watch_dir
|
||||
if not watch.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}")
|
||||
|
||||
pdfs = list(watch.glob("**/*.pdf")) + list(watch.glob("**/*.epub"))
|
||||
pdfs = (
|
||||
list(watch.glob("**/*.pdf"))
|
||||
+ list(watch.glob("**/*.epub"))
|
||||
+ list(watch.glob("**/*.docx"))
|
||||
)
|
||||
queued = []
|
||||
|
||||
for pdf_path in pdfs:
|
||||
|
|
@ -117,7 +124,7 @@ def scan_library(
|
|||
).fetchone()
|
||||
|
||||
if existing and existing["status"] == "ready":
|
||||
continue # already indexed
|
||||
continue
|
||||
|
||||
if existing:
|
||||
doc_id = existing["id"]
|
||||
|
|
@ -129,7 +136,9 @@ def scan_library(
|
|||
).fetchone()[0]
|
||||
db.commit()
|
||||
|
||||
task_id = _dispatch_ingest(doc_id, path_str, background_tasks)
|
||||
task_id = _dispatch_ingest(
|
||||
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
|
||||
)
|
||||
db.execute(
|
||||
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
|
||||
[task_id, doc_id],
|
||||
|
|
@ -145,12 +154,15 @@ def reingest_document(
|
|||
doc_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: sqlite3.Connection = Depends(get_db),
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> dict:
|
||||
row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
task_id = _dispatch_ingest(doc_id, row["file_path"], background_tasks)
|
||||
task_id = _dispatch_ingest(
|
||||
doc_id, row["file_path"], background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
|
||||
)
|
||||
db.execute(
|
||||
"UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?",
|
||||
[task_id, doc_id],
|
||||
|
|
@ -163,6 +175,7 @@ def reingest_document(
|
|||
def delete_document(
|
||||
doc_id: str,
|
||||
db: sqlite3.Connection = Depends(get_db),
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> None:
|
||||
row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone()
|
||||
if not row:
|
||||
|
|
@ -171,23 +184,21 @@ def delete_document(
|
|||
db.execute("DELETE FROM documents WHERE id=?", [doc_id])
|
||||
db.commit()
|
||||
|
||||
# Remove embeddings from vector store
|
||||
try:
|
||||
from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import]
|
||||
from app.config import VEC_DIMENSIONS
|
||||
store = LocalSQLiteVecStore(db_path=VEC_DB_PATH, table="page_vecs", dimensions=VEC_DIMENSIONS)
|
||||
store = LocalSQLiteVecStore(
|
||||
db_path=ctx.vec_db_path, table="page_vecs", dimensions=VEC_DIMENSIONS
|
||||
)
|
||||
store.delete_where({"doc_id": doc_id})
|
||||
except Exception as exc:
|
||||
logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc)
|
||||
|
||||
if _mark_bm25_dirty:
|
||||
_mark_bm25_dirty()
|
||||
ctx.bm25.mark_dirty()
|
||||
|
||||
|
||||
def _get_vec_count(doc_id: str) -> int:
|
||||
"""Return how many vectors have been stored for this doc. Returns 0 on any error."""
|
||||
def _get_vec_count(doc_id: str, vec_db_path: str) -> int:
|
||||
try:
|
||||
conn = sqlite3.connect(VEC_DB_PATH)
|
||||
conn = sqlite3.connect(vec_db_path)
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?",
|
||||
[doc_id],
|
||||
|
|
@ -202,6 +213,7 @@ def _get_vec_count(doc_id: str) -> int:
|
|||
def document_status(
|
||||
doc_id: str,
|
||||
db: sqlite3.Connection = Depends(get_db),
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> dict:
|
||||
row = db.execute(
|
||||
"SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?",
|
||||
|
|
@ -210,7 +222,7 @@ def document_status(
|
|||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
result = dict(row)
|
||||
result["vec_count"] = _get_vec_count(doc_id)
|
||||
result["vec_count"] = _get_vec_count(doc_id, ctx.vec_db_path)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -219,18 +231,19 @@ def upload_document(
|
|||
file: UploadFile,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: sqlite3.Connection = Depends(get_db),
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> dict:
|
||||
"""Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing."""
|
||||
name = Path(file.filename or "").name
|
||||
suffix = Path(name).suffix.lower()
|
||||
if suffix not in _INGEST_TASKS:
|
||||
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB")
|
||||
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB, DOCX")
|
||||
|
||||
content = file.file.read()
|
||||
if len(content) > _MAX_UPLOAD_BYTES:
|
||||
raise HTTPException(status_code=413, detail="File exceeds 200 MB limit")
|
||||
|
||||
upload_dir = DATA_DIR / "uploads"
|
||||
upload_dir = ctx.data_dir / "uploads"
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = upload_dir / name
|
||||
dest.write_bytes(content)
|
||||
|
|
@ -253,7 +266,9 @@ def upload_document(
|
|||
).fetchone()[0]
|
||||
db.commit()
|
||||
|
||||
task_id = _dispatch_ingest(doc_id, path_str, background_tasks)
|
||||
task_id = _dispatch_ingest(
|
||||
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
|
||||
)
|
||||
db.execute(
|
||||
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
|
||||
[task_id, doc_id],
|
||||
|
|
|
|||
|
|
@ -7,13 +7,11 @@ MIT — no tier gate. No Ollama required.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.services.bm25_index import BM25Index
|
||||
from app.deps import UserCtx, get_user_ctx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||
|
|
@ -29,32 +27,17 @@ class SearchResult(BaseModel):
|
|||
chunk_id: str
|
||||
doc_id: str
|
||||
page_number: int
|
||||
text_snippet: str # first 300 chars of the page text
|
||||
text_snippet: str
|
||||
bm25_score: float
|
||||
|
||||
|
||||
def _get_bm25() -> BM25Index:
|
||||
import app.main as _main
|
||||
bm25 = getattr(_main, "_bm25", None)
|
||||
if bm25 is None:
|
||||
raise RuntimeError("BM25 index not initialised — app.main not loaded")
|
||||
return bm25
|
||||
|
||||
|
||||
def _get_db_path() -> str:
|
||||
"""Read lazily so test fixtures (monkeypatch.setattr) take effect."""
|
||||
import pathlib
|
||||
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||
return str(data_dir / "pagepiper.db")
|
||||
|
||||
|
||||
@router.post("")
|
||||
def search(
|
||||
req: SearchRequest,
|
||||
bm25: Annotated[BM25Index, Depends(_get_bm25)],
|
||||
ctx: UserCtx = Depends(get_user_ctx),
|
||||
) -> list[SearchResult]:
|
||||
bm25.ensure_fresh(_get_db_path())
|
||||
hits = bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
|
||||
ctx.bm25.ensure_fresh(ctx.db_path)
|
||||
hits = ctx.bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
|
||||
return [
|
||||
SearchResult(
|
||||
chunk_id=h.chunk_id,
|
||||
|
|
|
|||
131
app/cloud_session.py
Normal file
131
app/cloud_session.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
# app/cloud_session.py
|
||||
"""Cloud session auth for Pagepiper — validates cf_session cookie via Directus + Heimdall.
|
||||
|
||||
In local mode (CLOUD_MODE unset or false), require_paid_tier is a no-op.
|
||||
In cloud mode, the Caddy proxy forwards the browser's Cookie header as
|
||||
X-CF-Session. This module extracts cf_session, validates it against
|
||||
Directus /users/me, then checks the user's Pagepiper tier via Heimdall.
|
||||
Auto-provisions a free tier key for new users.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||
DIRECTUS_URL: str = os.environ.get("DIRECTUS_URL", "http://172.31.0.3:8055").rstrip("/")
|
||||
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech").rstrip("/")
|
||||
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||
|
||||
_TIER_ORDER = {"free": 0, "paid": 1, "premium": 2, "ultra": 3}
|
||||
|
||||
|
||||
def _extract_session_token(cookie_header: str) -> str:
|
||||
m = re.search(r"(?:^|;)\s*cf_session=([^;]+)", cookie_header)
|
||||
return m.group(1).strip() if m else ""
|
||||
|
||||
|
||||
def _get_user_id(jwt: str) -> str | None:
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{DIRECTUS_URL}/users/me",
|
||||
headers={"Authorization": f"Bearer {jwt}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
return resp.json().get("data", {}).get("id")
|
||||
except Exception as exc:
|
||||
log.warning("Directus session check failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _ensure_provisioned(user_id: str) -> None:
|
||||
if not HEIMDALL_ADMIN_TOKEN:
|
||||
return
|
||||
try:
|
||||
httpx.post(
|
||||
f"{HEIMDALL_URL}/admin/provision",
|
||||
json={"directus_user_id": user_id, "product": "pagepiper", "tier": "free"},
|
||||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||
|
||||
|
||||
def _get_tier(user_id: str) -> str:
|
||||
if not HEIMDALL_ADMIN_TOKEN:
|
||||
return "free"
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"{HEIMDALL_URL}/admin/cloud/resolve",
|
||||
params={"directus_user_id": user_id, "product": "pagepiper"},
|
||||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
return resp.json().get("tier", "free")
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall tier check failed for user %s: %s", user_id, exc)
|
||||
return "free"
|
||||
|
||||
|
||||
def resolve_authenticated_user(request: Request) -> str:
|
||||
"""Validate the session cookie and return the Directus user_id. Raises 401 if invalid."""
|
||||
cookie_header = request.headers.get("x-cf-session", "")
|
||||
jwt = _extract_session_token(cookie_header)
|
||||
|
||||
if not jwt:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": "auth_required",
|
||||
"message": "Sign in at circuitforge.tech to use Pagepiper cloud.",
|
||||
},
|
||||
)
|
||||
|
||||
user_id = _get_user_id(jwt)
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": "session_invalid",
|
||||
"message": "Your session has expired. Sign in again at circuitforge.tech.",
|
||||
},
|
||||
)
|
||||
|
||||
_ensure_provisioned(user_id)
|
||||
return user_id
|
||||
|
||||
|
||||
def require_paid_tier(request: Request) -> str:
|
||||
"""FastAPI dependency — 401 if no valid session, 402 if tier < paid. Returns user_id.
|
||||
|
||||
In local mode (CLOUD_MODE not set), returns LOCAL_USER_ID without any auth check.
|
||||
"""
|
||||
if not CLOUD_MODE:
|
||||
from app.config import LOCAL_USER_ID
|
||||
return LOCAL_USER_ID
|
||||
|
||||
user_id = resolve_authenticated_user(request)
|
||||
tier = _get_tier(user_id)
|
||||
|
||||
if _TIER_ORDER.get(tier, 0) < _TIER_ORDER["paid"]:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "upgrade_required",
|
||||
"message": (
|
||||
"RAG chat requires a Paid tier Pagepiper license. "
|
||||
"Upgrade at circuitforge.tech/software/pagepiper."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
|
@ -12,16 +12,36 @@ VEC_DB_PATH = str(DATA_DIR / "pagepiper_vecs.db")
|
|||
WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books"))
|
||||
VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024"))
|
||||
|
||||
LOCAL_USER_ID = "__local__"
|
||||
|
||||
|
||||
def user_data_dir(user_id: str) -> Path:
|
||||
"""Return (and create) the per-user data directory under DATA_DIR/users/."""
|
||||
d = DATA_DIR / "users" / user_id
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
def get_llm_config() -> dict | None:
|
||||
"""Build LLMRouter config from env vars. Returns None if PAGEPIPER_OLLAMA_URL is unset."""
|
||||
"""Build LLMRouter config from env vars.
|
||||
|
||||
Returns None only when neither PAGEPIPER_OLLAMA_URL nor CF_ORCH_URL is set.
|
||||
CF_ORCH_URL alone is sufficient — the coordinator resolves the service URL at
|
||||
allocation time so PAGEPIPER_OLLAMA_URL becomes optional.
|
||||
"""
|
||||
url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip()
|
||||
if not url:
|
||||
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
|
||||
|
||||
if not url and not orch_url:
|
||||
return None
|
||||
_clean = url.rstrip("/")
|
||||
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
|
||||
|
||||
chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b")
|
||||
|
||||
_base_url = ""
|
||||
if url:
|
||||
_clean = url.rstrip("/")
|
||||
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
|
||||
|
||||
backend: dict = {
|
||||
"type": "openai_compat",
|
||||
"base_url": _base_url,
|
||||
|
|
@ -30,12 +50,9 @@ def get_llm_config() -> dict | None:
|
|||
"supports_images": False,
|
||||
}
|
||||
|
||||
# Wire cf-orch allocation when coordinator is configured so the model stays warm
|
||||
# and cold-start latency doesn't cause chat timeouts.
|
||||
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
|
||||
if orch_url:
|
||||
backend["cf_orch"] = {
|
||||
"service": "ollama",
|
||||
"service": os.environ.get("PAGEPIPER_ORCH_SERVICE", "ollama"),
|
||||
"model_candidates": [chat_model],
|
||||
"ttl_s": 3600,
|
||||
}
|
||||
|
|
|
|||
68
app/deps.py
68
app/deps.py
|
|
@ -3,13 +3,75 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from app.config import DB_PATH
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from app.config import DATA_DIR, LOCAL_USER_ID
|
||||
from app.services.bm25_index import BM25Index
|
||||
|
||||
|
||||
def get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
|
||||
@dataclass
|
||||
class UserCtx:
|
||||
"""Per-request context routing DB paths and BM25 to the right user."""
|
||||
|
||||
user_id: str
|
||||
db_path: str
|
||||
vec_db_path: str
|
||||
data_dir: Path
|
||||
watch_dir: Path
|
||||
bm25: BM25Index
|
||||
|
||||
|
||||
_user_startup_done: set[str] = set()
|
||||
|
||||
|
||||
def _run_user_startup(user_id: str, user_dir: Path) -> None:
|
||||
"""Run migrations and vec schema check once per process lifetime per user."""
|
||||
if user_id in _user_startup_done:
|
||||
return
|
||||
_user_startup_done.add(user_id)
|
||||
from app.config import VEC_DIMENSIONS
|
||||
from app.startup import apply_migrations, check_and_rebuild_vec_schema
|
||||
apply_migrations(str(user_dir / "pagepiper.db"))
|
||||
check_and_rebuild_vec_schema(
|
||||
str(user_dir / "pagepiper_vecs.db"), VEC_DIMENSIONS, str(user_dir / "pagepiper.db")
|
||||
)
|
||||
|
||||
|
||||
def get_user_ctx(request: Request) -> UserCtx:
|
||||
"""Resolve the per-user data directory, DB paths, and BM25 instance for this request."""
|
||||
import app.main as _main
|
||||
from app.cloud_session import CLOUD_MODE
|
||||
|
||||
if CLOUD_MODE:
|
||||
from app.cloud_session import resolve_authenticated_user
|
||||
from app.config import user_data_dir
|
||||
user_id = resolve_authenticated_user(request)
|
||||
user_dir = user_data_dir(user_id)
|
||||
_run_user_startup(user_id, user_dir)
|
||||
watch_dir = user_dir / "books"
|
||||
watch_dir.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
from app.config import WATCH_DIR
|
||||
user_id = LOCAL_USER_ID
|
||||
user_dir = DATA_DIR
|
||||
watch_dir = WATCH_DIR
|
||||
|
||||
return UserCtx(
|
||||
user_id=user_id,
|
||||
db_path=str(user_dir / "pagepiper.db"),
|
||||
vec_db_path=str(user_dir / "pagepiper_vecs.db"),
|
||||
data_dir=user_dir,
|
||||
watch_dir=watch_dir,
|
||||
bm25=_main._get_bm25_for(user_id),
|
||||
)
|
||||
|
||||
|
||||
def get_db(ctx: UserCtx = Depends(get_user_ctx)) -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(ctx.db_path, check_same_thread=False)
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.execute("PRAGMA journal_mode = WAL")
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
|
|
|||
111
app/main.py
111
app/main.py
|
|
@ -4,9 +4,6 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
|
@ -16,110 +13,36 @@ from app.services.bm25_index import BM25Index
|
|||
|
||||
logger = logging.getLogger("pagepiper")
|
||||
|
||||
# Module-level BM25 singleton — shared across all requests
|
||||
_bm25 = BM25Index()
|
||||
# Per-user BM25 registry — keyed by user_id; "__local__" for single-user mode
|
||||
_bm25_map: dict[str, BM25Index] = {}
|
||||
|
||||
|
||||
def _apply_migrations() -> None:
|
||||
from scripts.db_migrate import migrate
|
||||
migrate(DB_PATH)
|
||||
|
||||
|
||||
def _reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
|
||||
"""Re-run full ingest for a list of (doc_id, file_path) sequentially."""
|
||||
for doc_id, file_path in docs:
|
||||
suffix = os.path.splitext(file_path)[1].lower()
|
||||
try:
|
||||
if suffix == ".epub":
|
||||
from scripts.ingest_epub import run
|
||||
else:
|
||||
from scripts.ingest_pdf import run
|
||||
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
|
||||
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
|
||||
except Exception as exc:
|
||||
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
|
||||
|
||||
|
||||
def _check_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
|
||||
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
|
||||
|
||||
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
|
||||
models requires dropping and recreating the whole file. Catches the mismatch at
|
||||
startup rather than surfacing it as an obscure OperationalError mid-request.
|
||||
"""
|
||||
if not os.path.exists(vec_db_path):
|
||||
return
|
||||
try:
|
||||
conn = sqlite3.connect(vec_db_path)
|
||||
row = conn.execute(
|
||||
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
|
||||
).fetchone()
|
||||
conn.close()
|
||||
except Exception as exc:
|
||||
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
|
||||
return
|
||||
|
||||
if not row:
|
||||
return # table not yet created — first embed will build it with the right dims
|
||||
|
||||
m = re.search(r'float\[(\d+)\]', row[0])
|
||||
if not m:
|
||||
return
|
||||
actual_dims = int(m.group(1))
|
||||
if actual_dims == expected_dims:
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
|
||||
actual_dims, expected_dims, vec_db_path,
|
||||
)
|
||||
try:
|
||||
os.remove(vec_db_path)
|
||||
except OSError as exc:
|
||||
logger.error(
|
||||
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
|
||||
)
|
||||
return
|
||||
|
||||
# Collect all ready docs so we can rebuild their embeddings in the background.
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
docs = conn.execute(
|
||||
"SELECT id, file_path FROM documents WHERE status='ready'"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
except Exception as exc:
|
||||
logger.warning("Could not query documents for re-embed: %s", exc)
|
||||
return
|
||||
|
||||
if not docs:
|
||||
return
|
||||
|
||||
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
|
||||
threading.Thread(
|
||||
target=_reembed_docs,
|
||||
args=(docs, db_path, vec_db_path),
|
||||
daemon=True,
|
||||
name="pagepiper-reembed",
|
||||
).start()
|
||||
def _get_bm25_for(user_id: str) -> BM25Index:
|
||||
if user_id not in _bm25_map:
|
||||
_bm25_map[user_id] = BM25Index()
|
||||
return _bm25_map[user_id]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
_apply_migrations()
|
||||
from app.cloud_session import CLOUD_MODE
|
||||
from app.config import LOCAL_USER_ID
|
||||
from app.startup import apply_migrations, check_and_rebuild_vec_schema
|
||||
|
||||
embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text")
|
||||
logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS)
|
||||
_check_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
|
||||
_bm25.mark_dirty() # will rebuild on first search
|
||||
|
||||
if not CLOUD_MODE:
|
||||
# In cloud mode, per-user migration and vec schema check run on first request (deps.py).
|
||||
apply_migrations(DB_PATH)
|
||||
check_and_rebuild_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
|
||||
_get_bm25_for(LOCAL_USER_ID).mark_dirty()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Pagepiper", lifespan=lifespan)
|
||||
|
||||
# Wire BM25 dirty callback into library router
|
||||
from app.api import library as _lib_module # noqa: E402
|
||||
_lib_module._mark_bm25_dirty = _bm25.mark_dirty
|
||||
|
||||
# Register routers
|
||||
from app.api.library import router as library_router # noqa: E402
|
||||
from app.api.ingest import router as ingest_router # noqa: E402
|
||||
|
|
|
|||
95
app/startup.py
Normal file
95
app/startup.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
# app/startup.py
|
||||
"""DB migration and vec schema check utilities — called at startup and on first user request."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger("pagepiper")
|
||||
|
||||
|
||||
def apply_migrations(db_path: str) -> None:
|
||||
from scripts.db_migrate import migrate
|
||||
migrate(db_path)
|
||||
|
||||
|
||||
def reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
|
||||
for doc_id, file_path in docs:
|
||||
suffix = os.path.splitext(file_path)[1].lower()
|
||||
try:
|
||||
if suffix == ".epub":
|
||||
from scripts.ingest_epub import run
|
||||
elif suffix == ".docx":
|
||||
from scripts.ingest_docx import run
|
||||
else:
|
||||
from scripts.ingest_pdf import run
|
||||
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
|
||||
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
|
||||
except Exception as exc:
|
||||
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
|
||||
|
||||
|
||||
def check_and_rebuild_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
|
||||
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
|
||||
|
||||
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
|
||||
models requires dropping and recreating the whole file. Catches the mismatch at
|
||||
startup rather than surfacing it as an obscure OperationalError mid-request.
|
||||
"""
|
||||
if not os.path.exists(vec_db_path):
|
||||
return
|
||||
try:
|
||||
conn = sqlite3.connect(vec_db_path)
|
||||
row = conn.execute(
|
||||
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
|
||||
).fetchone()
|
||||
conn.close()
|
||||
except Exception as exc:
|
||||
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
|
||||
return
|
||||
|
||||
if not row:
|
||||
return
|
||||
|
||||
m = re.search(r'float\[(\d+)\]', row[0])
|
||||
if not m:
|
||||
return
|
||||
actual_dims = int(m.group(1))
|
||||
if actual_dims == expected_dims:
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
|
||||
actual_dims, expected_dims, vec_db_path,
|
||||
)
|
||||
try:
|
||||
os.remove(vec_db_path)
|
||||
except OSError as exc:
|
||||
logger.error(
|
||||
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
docs = conn.execute(
|
||||
"SELECT id, file_path FROM documents WHERE status='ready'"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
except Exception as exc:
|
||||
logger.warning("Could not query documents for re-embed: %s", exc)
|
||||
return
|
||||
|
||||
if not docs:
|
||||
return
|
||||
|
||||
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
|
||||
threading.Thread(
|
||||
target=reembed_docs,
|
||||
args=(docs, db_path, vec_db_path),
|
||||
daemon=True,
|
||||
name="pagepiper-reembed",
|
||||
).start()
|
||||
|
|
@ -27,13 +27,32 @@ def client(test_db, tmp_path, monkeypatch):
|
|||
(tmp_path / "books").mkdir(exist_ok=True)
|
||||
|
||||
import app.main as _main_module
|
||||
from app.main import app, _bm25
|
||||
from app.deps import get_db
|
||||
from app.config import LOCAL_USER_ID
|
||||
from app.deps import UserCtx, get_db, get_user_ctx
|
||||
from app.main import app
|
||||
from app.services.bm25_index import BM25Index
|
||||
from app.startup import apply_migrations, check_and_rebuild_vec_schema
|
||||
|
||||
# Suppress startup side effects — test_db fixture already applies the schema,
|
||||
# and vec schema validation is tested separately in test_startup.py
|
||||
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None)
|
||||
monkeypatch.setattr(_main_module, "_check_vec_schema", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None, raising=False)
|
||||
monkeypatch.setattr(
|
||||
"app.startup.apply_migrations", lambda *a, **kw: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.startup.check_and_rebuild_vec_schema", lambda *a, **kw: None
|
||||
)
|
||||
|
||||
test_bm25 = BM25Index()
|
||||
test_bm25.mark_dirty()
|
||||
|
||||
def override_user_ctx():
|
||||
return UserCtx(
|
||||
user_id=LOCAL_USER_ID,
|
||||
db_path=test_db,
|
||||
vec_db_path=str(tmp_path / "test_vecs.db"),
|
||||
data_dir=Path(tmp_path),
|
||||
watch_dir=Path(tmp_path) / "books",
|
||||
bm25=test_bm25,
|
||||
)
|
||||
|
||||
def override_db():
|
||||
conn = sqlite3.connect(test_db)
|
||||
|
|
@ -44,7 +63,7 @@ def client(test_db, tmp_path, monkeypatch):
|
|||
finally:
|
||||
conn.close()
|
||||
|
||||
app.dependency_overrides[get_user_ctx] = override_user_ctx
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
_bm25.mark_dirty() # clear any state from previous tests
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@ def _add_chunks(db_path: str, doc_id: str, chunks: list[dict]) -> None:
|
|||
conn.close()
|
||||
|
||||
|
||||
def test_search_returns_results(client, test_db, monkeypatch):
|
||||
import app.api.search as _search_mod
|
||||
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
|
||||
def test_search_returns_results(client, test_db):
|
||||
# BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0).
|
||||
# Add a 3rd unrelated chunk so relevant terms score above zero.
|
||||
_add_chunks(test_db, "book-a", [
|
||||
|
|
@ -46,9 +44,7 @@ def test_search_empty_index_returns_empty(client):
|
|||
assert resp.json() == []
|
||||
|
||||
|
||||
def test_search_filters_by_doc_ids(client, test_db, monkeypatch):
|
||||
import app.api.search as _search_mod
|
||||
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
|
||||
def test_search_filters_by_doc_ids(client, test_db):
|
||||
# Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc.
|
||||
_add_chunks(test_db, "book-a", [
|
||||
{"page_number": 1, "text": "Grapple rules for melee attacks."},
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from app.main import _check_vec_schema, _reembed_docs
|
||||
from app.startup import check_and_rebuild_vec_schema as _check_vec_schema
|
||||
from app.startup import reembed_docs as _reembed_docs
|
||||
|
||||
|
||||
def _make_vec_db(path: str, dims: int) -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue