feat: per-user database isolation for cloud instances (closes #4)
Implements Option A from the issue design: each cloud user gets their own
data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db,
pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged.
Key changes:
- app/startup.py: extract apply_migrations, reembed_docs,
check_and_rebuild_vec_schema out of main.py (no circular imports)
- app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper
- app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier
now returns user_id (str) instead of None
- app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir,
watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs
migrations + vec schema check once per process per user
- app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id;
add _get_bm25_for(); lifespan only runs startup checks in local mode
- app/api/library.py, search.py, chat.py: thread UserCtx through all
endpoints; remove module-level _mark_bm25_dirty injection pattern
- tests/conftest.py: override get_user_ctx in addition to get_db so all
endpoints get a consistent test UserCtx
This commit is contained in:
parent
df9e91ad89
commit
8eef52a054
11 changed files with 432 additions and 199 deletions
|
|
@ -10,9 +10,11 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.cloud_session import require_paid_tier
|
||||||
|
from app.deps import UserCtx, get_user_ctx
|
||||||
from app.services.retriever import Retriever
|
from app.services.retriever import Retriever
|
||||||
from app.services.synthesizer import Synthesizer
|
from app.services.synthesizer import Synthesizer
|
||||||
|
|
||||||
|
|
@ -56,21 +58,6 @@ def _get_llm_router():
|
||||||
return LLMRouter(cfg)
|
return LLMRouter(cfg)
|
||||||
|
|
||||||
|
|
||||||
def _get_db_path() -> str:
|
|
||||||
"""Read lazily so test fixtures take effect."""
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
|
||||||
return str(data_dir / "pagepiper.db")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_vec_db_path() -> str:
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
|
||||||
return str(data_dir / "pagepiper_vecs.db")
|
|
||||||
|
|
||||||
|
|
||||||
def _require_llm():
|
def _require_llm():
|
||||||
"""Return LLMRouter or raise 402."""
|
"""Return LLMRouter or raise 402."""
|
||||||
llm = _get_llm_router()
|
llm = _get_llm_router()
|
||||||
|
|
@ -89,18 +76,20 @@ def _require_llm():
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
def chat(req: ChatRequest) -> ChatResponse:
|
def chat(
|
||||||
|
req: ChatRequest,
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
|
_tier: str = Depends(require_paid_tier),
|
||||||
|
) -> ChatResponse:
|
||||||
llm = _require_llm()
|
llm = _require_llm()
|
||||||
|
|
||||||
from app.main import _bm25
|
retriever = Retriever(ctx.bm25)
|
||||||
|
|
||||||
retriever = Retriever(_bm25)
|
|
||||||
chunks = retriever.hybrid_search(
|
chunks = retriever.hybrid_search(
|
||||||
query=req.message,
|
query=req.message,
|
||||||
top_k=req.top_k,
|
top_k=req.top_k,
|
||||||
doc_ids=req.doc_ids,
|
doc_ids=req.doc_ids,
|
||||||
db_path=_get_db_path(),
|
db_path=ctx.db_path,
|
||||||
vec_db_path=_get_vec_db_path(),
|
vec_db_path=ctx.vec_db_path,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -141,7 +130,10 @@ def chat_feedback_status() -> dict:
|
||||||
|
|
||||||
|
|
||||||
@router.post("/feedback")
|
@router.post("/feedback")
|
||||||
def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
|
def submit_chat_feedback(
|
||||||
|
req: ChatFeedbackRequest,
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
|
) -> dict:
|
||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
|
|
@ -149,8 +141,7 @@ def submit_chat_feedback(req: ChatFeedbackRequest) -> dict:
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
raise HTTPException(status_code=422, detail="rating must be 1 or -1")
|
raise HTTPException(status_code=422, detail="rating must be 1 or -1")
|
||||||
|
|
||||||
db_path = _get_db_path()
|
con = sqlite3.connect(ctx.db_path)
|
||||||
con = sqlite3.connect(db_path)
|
|
||||||
try:
|
try:
|
||||||
con.execute(
|
con.execute(
|
||||||
"INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)",
|
"INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)",
|
||||||
|
|
|
||||||
|
|
@ -14,26 +14,25 @@ from typing import Callable
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile
|
||||||
|
|
||||||
from app.config import WATCH_DIR, DB_PATH, VEC_DB_PATH, DATA_DIR
|
from app.config import VEC_DIMENSIONS
|
||||||
from app.deps import get_db
|
from app.deps import UserCtx, get_db, get_user_ctx
|
||||||
|
|
||||||
_MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB
|
_MAX_UPLOAD_BYTES = 200 * 1024 * 1024 # 200 MB
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/library", tags=["library"])
|
router = APIRouter(prefix="/api/library", tags=["library"])
|
||||||
|
|
||||||
# Injected by main.py after _bm25 is created
|
|
||||||
_mark_bm25_dirty: Callable[[], None] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
_INGEST_TASKS = {
|
_INGEST_TASKS = {
|
||||||
".pdf": "pagepiper/ingest_pdf",
|
".pdf": "pagepiper/ingest_pdf",
|
||||||
".epub": "pagepiper/ingest_epub",
|
".epub": "pagepiper/ingest_epub",
|
||||||
|
".docx": "pagepiper/ingest_docx",
|
||||||
}
|
}
|
||||||
|
|
||||||
_INGEST_RUNNERS = {
|
_INGEST_RUNNERS = {
|
||||||
".pdf": "scripts.ingest_pdf",
|
".pdf": "scripts.ingest_pdf",
|
||||||
".epub": "scripts.ingest_epub",
|
".epub": "scripts.ingest_epub",
|
||||||
|
".docx": "scripts.ingest_docx",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,24 +40,22 @@ def _dispatch_ingest(
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
|
data_dir: Path,
|
||||||
|
mark_dirty_fn: Callable[[], None],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks."""
|
"""Dispatch an ingest task. Tries cf-orch; falls back to BackgroundTasks."""
|
||||||
import importlib
|
import importlib
|
||||||
import os as _os
|
|
||||||
from pathlib import Path as _Path
|
|
||||||
|
|
||||||
suffix = _Path(file_path).suffix.lower()
|
suffix = Path(file_path).suffix.lower()
|
||||||
task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf")
|
task_name = _INGEST_TASKS.get(suffix, "pagepiper/ingest_pdf")
|
||||||
runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf")
|
runner_module = _INGEST_RUNNERS.get(suffix, "scripts.ingest_pdf")
|
||||||
|
|
||||||
# Read lazily so test fixtures (monkeypatch.setenv) take effect
|
|
||||||
_data_dir = _Path(_os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
args = {
|
args = {
|
||||||
"doc_id": doc_id,
|
"doc_id": doc_id,
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"db_path": str(_data_dir / "pagepiper.db"),
|
"db_path": str(data_dir / "pagepiper.db"),
|
||||||
"vec_db_path": str(_data_dir / "pagepiper_vecs.db"),
|
"vec_db_path": str(data_dir / "pagepiper_vecs.db"),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -67,7 +64,7 @@ def _dispatch_ingest(
|
||||||
logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id)
|
logger.info("Dispatched cf-orch ingest task %s for doc %s", task_id, doc_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
mod = importlib.import_module(runner_module)
|
mod = importlib.import_module(runner_module)
|
||||||
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id)
|
background_tasks.add_task(_run_ingest_background, mod.run, args, task_id, mark_dirty_fn)
|
||||||
logger.info(
|
logger.info(
|
||||||
"cf-orch unavailable — running ingest in background thread (task %s)", task_id
|
"cf-orch unavailable — running ingest in background thread (task %s)", task_id
|
||||||
)
|
)
|
||||||
|
|
@ -75,14 +72,19 @@ def _dispatch_ingest(
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
|
|
||||||
def _run_ingest_background(run_fn: Callable[..., None], args: dict, task_id: str) -> None:
|
def _run_ingest_background(
|
||||||
|
run_fn: Callable[..., None],
|
||||||
|
args: dict,
|
||||||
|
task_id: str,
|
||||||
|
mark_dirty_fn: Callable[[], None] | None = None,
|
||||||
|
) -> None:
|
||||||
from app.api.ingest import _task_registry
|
from app.api.ingest import _task_registry
|
||||||
_task_registry[task_id] = {"status": "running", "progress": 0}
|
_task_registry[task_id] = {"status": "running", "progress": 0}
|
||||||
try:
|
try:
|
||||||
run_fn(**args)
|
run_fn(**args)
|
||||||
_task_registry[task_id] = {"status": "complete", "progress": 100}
|
_task_registry[task_id] = {"status": "complete", "progress": 100}
|
||||||
if _mark_bm25_dirty:
|
if mark_dirty_fn:
|
||||||
_mark_bm25_dirty()
|
mark_dirty_fn()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Ingest task %s failed", task_id)
|
logger.exception("Ingest task %s failed", task_id)
|
||||||
_task_registry[task_id] = {"status": "error", "error": str(exc)}
|
_task_registry[task_id] = {"status": "error", "error": str(exc)}
|
||||||
|
|
@ -101,13 +103,18 @@ def list_library(db: sqlite3.Connection = Depends(get_db)) -> list[dict]:
|
||||||
def scan_library(
|
def scan_library(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
db: sqlite3.Connection = Depends(get_db),
|
db: sqlite3.Connection = Depends(get_db),
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Scan the watched directory and queue ingest for any new PDFs."""
|
"""Scan the watched directory and queue ingest for any new PDFs."""
|
||||||
watch = WATCH_DIR
|
watch = ctx.watch_dir
|
||||||
if not watch.exists():
|
if not watch.exists():
|
||||||
raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}")
|
raise HTTPException(status_code=404, detail=f"Watch directory not found: {watch}")
|
||||||
|
|
||||||
pdfs = list(watch.glob("**/*.pdf")) + list(watch.glob("**/*.epub"))
|
pdfs = (
|
||||||
|
list(watch.glob("**/*.pdf"))
|
||||||
|
+ list(watch.glob("**/*.epub"))
|
||||||
|
+ list(watch.glob("**/*.docx"))
|
||||||
|
)
|
||||||
queued = []
|
queued = []
|
||||||
|
|
||||||
for pdf_path in pdfs:
|
for pdf_path in pdfs:
|
||||||
|
|
@ -117,7 +124,7 @@ def scan_library(
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if existing and existing["status"] == "ready":
|
if existing and existing["status"] == "ready":
|
||||||
continue # already indexed
|
continue
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
doc_id = existing["id"]
|
doc_id = existing["id"]
|
||||||
|
|
@ -129,7 +136,9 @@ def scan_library(
|
||||||
).fetchone()[0]
|
).fetchone()[0]
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
task_id = _dispatch_ingest(doc_id, path_str, background_tasks)
|
task_id = _dispatch_ingest(
|
||||||
|
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
|
||||||
|
)
|
||||||
db.execute(
|
db.execute(
|
||||||
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
|
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
|
||||||
[task_id, doc_id],
|
[task_id, doc_id],
|
||||||
|
|
@ -145,12 +154,15 @@ def reingest_document(
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
db: sqlite3.Connection = Depends(get_db),
|
db: sqlite3.Connection = Depends(get_db),
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone()
|
row = db.execute("SELECT file_path FROM documents WHERE id=?", [doc_id]).fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
task_id = _dispatch_ingest(doc_id, row["file_path"], background_tasks)
|
task_id = _dispatch_ingest(
|
||||||
|
doc_id, row["file_path"], background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
|
||||||
|
)
|
||||||
db.execute(
|
db.execute(
|
||||||
"UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?",
|
"UPDATE documents SET status='processing', task_id=?, error_msg=NULL WHERE id=?",
|
||||||
[task_id, doc_id],
|
[task_id, doc_id],
|
||||||
|
|
@ -163,6 +175,7 @@ def reingest_document(
|
||||||
def delete_document(
|
def delete_document(
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
db: sqlite3.Connection = Depends(get_db),
|
db: sqlite3.Connection = Depends(get_db),
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
) -> None:
|
) -> None:
|
||||||
row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone()
|
row = db.execute("SELECT id FROM documents WHERE id=?", [doc_id]).fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
|
|
@ -171,23 +184,21 @@ def delete_document(
|
||||||
db.execute("DELETE FROM documents WHERE id=?", [doc_id])
|
db.execute("DELETE FROM documents WHERE id=?", [doc_id])
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Remove embeddings from vector store
|
|
||||||
try:
|
try:
|
||||||
from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import]
|
from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore # type: ignore[import]
|
||||||
from app.config import VEC_DIMENSIONS
|
store = LocalSQLiteVecStore(
|
||||||
store = LocalSQLiteVecStore(db_path=VEC_DB_PATH, table="page_vecs", dimensions=VEC_DIMENSIONS)
|
db_path=ctx.vec_db_path, table="page_vecs", dimensions=VEC_DIMENSIONS
|
||||||
|
)
|
||||||
store.delete_where({"doc_id": doc_id})
|
store.delete_where({"doc_id": doc_id})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc)
|
logger.warning("Could not remove vectors for doc %s: %s", doc_id, exc)
|
||||||
|
|
||||||
if _mark_bm25_dirty:
|
ctx.bm25.mark_dirty()
|
||||||
_mark_bm25_dirty()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_vec_count(doc_id: str) -> int:
|
def _get_vec_count(doc_id: str, vec_db_path: str) -> int:
|
||||||
"""Return how many vectors have been stored for this doc. Returns 0 on any error."""
|
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(VEC_DB_PATH)
|
conn = sqlite3.connect(vec_db_path)
|
||||||
count = conn.execute(
|
count = conn.execute(
|
||||||
"SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?",
|
"SELECT COUNT(*) FROM page_vecs_meta WHERE json_extract(metadata, '$.doc_id') = ?",
|
||||||
[doc_id],
|
[doc_id],
|
||||||
|
|
@ -202,6 +213,7 @@ def _get_vec_count(doc_id: str) -> int:
|
||||||
def document_status(
|
def document_status(
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
db: sqlite3.Connection = Depends(get_db),
|
db: sqlite3.Connection = Depends(get_db),
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
row = db.execute(
|
row = db.execute(
|
||||||
"SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?",
|
"SELECT id, status, task_id, page_count, error_msg FROM documents WHERE id=?",
|
||||||
|
|
@ -210,7 +222,7 @@ def document_status(
|
||||||
if not row:
|
if not row:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
result = dict(row)
|
result = dict(row)
|
||||||
result["vec_count"] = _get_vec_count(doc_id)
|
result["vec_count"] = _get_vec_count(doc_id, ctx.vec_db_path)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -219,18 +231,19 @@ def upload_document(
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
db: sqlite3.Connection = Depends(get_db),
|
db: sqlite3.Connection = Depends(get_db),
|
||||||
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing."""
|
"""Accept a PDF/EPUB upload, save to data/uploads/, and queue for indexing."""
|
||||||
name = Path(file.filename or "").name
|
name = Path(file.filename or "").name
|
||||||
suffix = Path(name).suffix.lower()
|
suffix = Path(name).suffix.lower()
|
||||||
if suffix not in _INGEST_TASKS:
|
if suffix not in _INGEST_TASKS:
|
||||||
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB")
|
raise HTTPException(status_code=400, detail="Supported formats: PDF, EPUB, DOCX")
|
||||||
|
|
||||||
content = file.file.read()
|
content = file.file.read()
|
||||||
if len(content) > _MAX_UPLOAD_BYTES:
|
if len(content) > _MAX_UPLOAD_BYTES:
|
||||||
raise HTTPException(status_code=413, detail="File exceeds 200 MB limit")
|
raise HTTPException(status_code=413, detail="File exceeds 200 MB limit")
|
||||||
|
|
||||||
upload_dir = DATA_DIR / "uploads"
|
upload_dir = ctx.data_dir / "uploads"
|
||||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||||
dest = upload_dir / name
|
dest = upload_dir / name
|
||||||
dest.write_bytes(content)
|
dest.write_bytes(content)
|
||||||
|
|
@ -253,7 +266,9 @@ def upload_document(
|
||||||
).fetchone()[0]
|
).fetchone()[0]
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
task_id = _dispatch_ingest(doc_id, path_str, background_tasks)
|
task_id = _dispatch_ingest(
|
||||||
|
doc_id, path_str, background_tasks, ctx.data_dir, ctx.bm25.mark_dirty
|
||||||
|
)
|
||||||
db.execute(
|
db.execute(
|
||||||
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
|
"UPDATE documents SET status='processing', task_id=? WHERE id=?",
|
||||||
[task_id, doc_id],
|
[task_id, doc_id],
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,11 @@ MIT — no tier gate. No Ollama required.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.services.bm25_index import BM25Index
|
from app.deps import UserCtx, get_user_ctx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||||
|
|
@ -29,32 +27,17 @@ class SearchResult(BaseModel):
|
||||||
chunk_id: str
|
chunk_id: str
|
||||||
doc_id: str
|
doc_id: str
|
||||||
page_number: int
|
page_number: int
|
||||||
text_snippet: str # first 300 chars of the page text
|
text_snippet: str
|
||||||
bm25_score: float
|
bm25_score: float
|
||||||
|
|
||||||
|
|
||||||
def _get_bm25() -> BM25Index:
|
|
||||||
import app.main as _main
|
|
||||||
bm25 = getattr(_main, "_bm25", None)
|
|
||||||
if bm25 is None:
|
|
||||||
raise RuntimeError("BM25 index not initialised — app.main not loaded")
|
|
||||||
return bm25
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_path() -> str:
|
|
||||||
"""Read lazily so test fixtures (monkeypatch.setattr) take effect."""
|
|
||||||
import pathlib
|
|
||||||
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
|
||||||
return str(data_dir / "pagepiper.db")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
def search(
|
def search(
|
||||||
req: SearchRequest,
|
req: SearchRequest,
|
||||||
bm25: Annotated[BM25Index, Depends(_get_bm25)],
|
ctx: UserCtx = Depends(get_user_ctx),
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
bm25.ensure_fresh(_get_db_path())
|
ctx.bm25.ensure_fresh(ctx.db_path)
|
||||||
hits = bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
|
hits = ctx.bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
|
||||||
return [
|
return [
|
||||||
SearchResult(
|
SearchResult(
|
||||||
chunk_id=h.chunk_id,
|
chunk_id=h.chunk_id,
|
||||||
|
|
|
||||||
131
app/cloud_session.py
Normal file
131
app/cloud_session.py
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
# app/cloud_session.py
|
||||||
|
"""Cloud session auth for Pagepiper — validates cf_session cookie via Directus + Heimdall.
|
||||||
|
|
||||||
|
In local mode (CLOUD_MODE unset or false), require_paid_tier is a no-op.
|
||||||
|
In cloud mode, the Caddy proxy forwards the browser's Cookie header as
|
||||||
|
X-CF-Session. This module extracts cf_session, validates it against
|
||||||
|
Directus /users/me, then checks the user's Pagepiper tier via Heimdall.
|
||||||
|
Auto-provisions a free tier key for new users.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||||
|
DIRECTUS_URL: str = os.environ.get("DIRECTUS_URL", "http://172.31.0.3:8055").rstrip("/")
|
||||||
|
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech").rstrip("/")
|
||||||
|
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||||
|
|
||||||
|
_TIER_ORDER = {"free": 0, "paid": 1, "premium": 2, "ultra": 3}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_session_token(cookie_header: str) -> str:
|
||||||
|
m = re.search(r"(?:^|;)\s*cf_session=([^;]+)", cookie_header)
|
||||||
|
return m.group(1).strip() if m else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id(jwt: str) -> str | None:
|
||||||
|
try:
|
||||||
|
resp = httpx.get(
|
||||||
|
f"{DIRECTUS_URL}/users/me",
|
||||||
|
headers={"Authorization": f"Bearer {jwt}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
return resp.json().get("data", {}).get("id")
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("Directus session check failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_provisioned(user_id: str) -> None:
|
||||||
|
if not HEIMDALL_ADMIN_TOKEN:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
httpx.post(
|
||||||
|
f"{HEIMDALL_URL}/admin/provision",
|
||||||
|
json={"directus_user_id": user_id, "product": "pagepiper", "tier": "free"},
|
||||||
|
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tier(user_id: str) -> str:
|
||||||
|
if not HEIMDALL_ADMIN_TOKEN:
|
||||||
|
return "free"
|
||||||
|
try:
|
||||||
|
resp = httpx.get(
|
||||||
|
f"{HEIMDALL_URL}/admin/cloud/resolve",
|
||||||
|
params={"directus_user_id": user_id, "product": "pagepiper"},
|
||||||
|
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
return resp.json().get("tier", "free")
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("Heimdall tier check failed for user %s: %s", user_id, exc)
|
||||||
|
return "free"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_authenticated_user(request: Request) -> str:
|
||||||
|
"""Validate the session cookie and return the Directus user_id. Raises 401 if invalid."""
|
||||||
|
cookie_header = request.headers.get("x-cf-session", "")
|
||||||
|
jwt = _extract_session_token(cookie_header)
|
||||||
|
|
||||||
|
if not jwt:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={
|
||||||
|
"error": "auth_required",
|
||||||
|
"message": "Sign in at circuitforge.tech to use Pagepiper cloud.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = _get_user_id(jwt)
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={
|
||||||
|
"error": "session_invalid",
|
||||||
|
"message": "Your session has expired. Sign in again at circuitforge.tech.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_ensure_provisioned(user_id)
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
|
def require_paid_tier(request: Request) -> str:
|
||||||
|
"""FastAPI dependency — 401 if no valid session, 402 if tier < paid. Returns user_id.
|
||||||
|
|
||||||
|
In local mode (CLOUD_MODE not set), returns LOCAL_USER_ID without any auth check.
|
||||||
|
"""
|
||||||
|
if not CLOUD_MODE:
|
||||||
|
from app.config import LOCAL_USER_ID
|
||||||
|
return LOCAL_USER_ID
|
||||||
|
|
||||||
|
user_id = resolve_authenticated_user(request)
|
||||||
|
tier = _get_tier(user_id)
|
||||||
|
|
||||||
|
if _TIER_ORDER.get(tier, 0) < _TIER_ORDER["paid"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={
|
||||||
|
"error": "upgrade_required",
|
||||||
|
"message": (
|
||||||
|
"RAG chat requires a Paid tier Pagepiper license. "
|
||||||
|
"Upgrade at circuitforge.tech/software/pagepiper."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return user_id
|
||||||
|
|
@ -12,16 +12,36 @@ VEC_DB_PATH = str(DATA_DIR / "pagepiper_vecs.db")
|
||||||
WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books"))
|
WATCH_DIR = Path(os.environ.get("PAGEPIPER_WATCH_DIR", "books"))
|
||||||
VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024"))
|
VEC_DIMENSIONS = int(os.environ.get("PAGEPIPER_EMBED_DIMS", "1024"))
|
||||||
|
|
||||||
|
LOCAL_USER_ID = "__local__"
|
||||||
|
|
||||||
|
|
||||||
|
def user_data_dir(user_id: str) -> Path:
|
||||||
|
"""Return (and create) the per-user data directory under DATA_DIR/users/."""
|
||||||
|
d = DATA_DIR / "users" / user_id
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def get_llm_config() -> dict | None:
|
def get_llm_config() -> dict | None:
|
||||||
"""Build LLMRouter config from env vars. Returns None if PAGEPIPER_OLLAMA_URL is unset."""
|
"""Build LLMRouter config from env vars.
|
||||||
|
|
||||||
|
Returns None only when neither PAGEPIPER_OLLAMA_URL nor CF_ORCH_URL is set.
|
||||||
|
CF_ORCH_URL alone is sufficient — the coordinator resolves the service URL at
|
||||||
|
allocation time so PAGEPIPER_OLLAMA_URL becomes optional.
|
||||||
|
"""
|
||||||
url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip()
|
url = os.environ.get("PAGEPIPER_OLLAMA_URL", "").strip()
|
||||||
if not url:
|
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
|
||||||
|
|
||||||
|
if not url and not orch_url:
|
||||||
return None
|
return None
|
||||||
_clean = url.rstrip("/")
|
|
||||||
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
|
|
||||||
chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b")
|
chat_model = os.environ.get("PAGEPIPER_CHAT_MODEL", "mistral:7b")
|
||||||
|
|
||||||
|
_base_url = ""
|
||||||
|
if url:
|
||||||
|
_clean = url.rstrip("/")
|
||||||
|
_base_url = _clean if _clean.endswith("/v1") else _clean + "/v1"
|
||||||
|
|
||||||
backend: dict = {
|
backend: dict = {
|
||||||
"type": "openai_compat",
|
"type": "openai_compat",
|
||||||
"base_url": _base_url,
|
"base_url": _base_url,
|
||||||
|
|
@ -30,12 +50,9 @@ def get_llm_config() -> dict | None:
|
||||||
"supports_images": False,
|
"supports_images": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Wire cf-orch allocation when coordinator is configured so the model stays warm
|
|
||||||
# and cold-start latency doesn't cause chat timeouts.
|
|
||||||
orch_url = os.environ.get("CF_ORCH_URL", "").strip()
|
|
||||||
if orch_url:
|
if orch_url:
|
||||||
backend["cf_orch"] = {
|
backend["cf_orch"] = {
|
||||||
"service": "ollama",
|
"service": os.environ.get("PAGEPIPER_ORCH_SERVICE", "ollama"),
|
||||||
"model_candidates": [chat_model],
|
"model_candidates": [chat_model],
|
||||||
"ttl_s": 3600,
|
"ttl_s": 3600,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
68
app/deps.py
68
app/deps.py
|
|
@ -3,13 +3,75 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from app.config import DB_PATH
|
from fastapi import Depends, Request
|
||||||
|
|
||||||
|
from app.config import DATA_DIR, LOCAL_USER_ID
|
||||||
|
from app.services.bm25_index import BM25Index
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> Generator[sqlite3.Connection, None, None]:
|
@dataclass
|
||||||
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
|
class UserCtx:
|
||||||
|
"""Per-request context routing DB paths and BM25 to the right user."""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
db_path: str
|
||||||
|
vec_db_path: str
|
||||||
|
data_dir: Path
|
||||||
|
watch_dir: Path
|
||||||
|
bm25: BM25Index
|
||||||
|
|
||||||
|
|
||||||
|
_user_startup_done: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_user_startup(user_id: str, user_dir: Path) -> None:
|
||||||
|
"""Run migrations and vec schema check once per process lifetime per user."""
|
||||||
|
if user_id in _user_startup_done:
|
||||||
|
return
|
||||||
|
_user_startup_done.add(user_id)
|
||||||
|
from app.config import VEC_DIMENSIONS
|
||||||
|
from app.startup import apply_migrations, check_and_rebuild_vec_schema
|
||||||
|
apply_migrations(str(user_dir / "pagepiper.db"))
|
||||||
|
check_and_rebuild_vec_schema(
|
||||||
|
str(user_dir / "pagepiper_vecs.db"), VEC_DIMENSIONS, str(user_dir / "pagepiper.db")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_ctx(request: Request) -> UserCtx:
|
||||||
|
"""Resolve the per-user data directory, DB paths, and BM25 instance for this request."""
|
||||||
|
import app.main as _main
|
||||||
|
from app.cloud_session import CLOUD_MODE
|
||||||
|
|
||||||
|
if CLOUD_MODE:
|
||||||
|
from app.cloud_session import resolve_authenticated_user
|
||||||
|
from app.config import user_data_dir
|
||||||
|
user_id = resolve_authenticated_user(request)
|
||||||
|
user_dir = user_data_dir(user_id)
|
||||||
|
_run_user_startup(user_id, user_dir)
|
||||||
|
watch_dir = user_dir / "books"
|
||||||
|
watch_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
from app.config import WATCH_DIR
|
||||||
|
user_id = LOCAL_USER_ID
|
||||||
|
user_dir = DATA_DIR
|
||||||
|
watch_dir = WATCH_DIR
|
||||||
|
|
||||||
|
return UserCtx(
|
||||||
|
user_id=user_id,
|
||||||
|
db_path=str(user_dir / "pagepiper.db"),
|
||||||
|
vec_db_path=str(user_dir / "pagepiper_vecs.db"),
|
||||||
|
data_dir=user_dir,
|
||||||
|
watch_dir=watch_dir,
|
||||||
|
bm25=_main._get_bm25_for(user_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db(ctx: UserCtx = Depends(get_user_ctx)) -> Generator[sqlite3.Connection, None, None]:
|
||||||
|
conn = sqlite3.connect(ctx.db_path, check_same_thread=False)
|
||||||
conn.execute("PRAGMA foreign_keys = ON")
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
conn.execute("PRAGMA journal_mode = WAL")
|
conn.execute("PRAGMA journal_mode = WAL")
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
|
||||||
111
app/main.py
111
app/main.py
|
|
@ -4,9 +4,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import sqlite3
|
|
||||||
import threading
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
@ -16,110 +13,36 @@ from app.services.bm25_index import BM25Index
|
||||||
|
|
||||||
logger = logging.getLogger("pagepiper")
|
logger = logging.getLogger("pagepiper")
|
||||||
|
|
||||||
# Module-level BM25 singleton — shared across all requests
|
# Per-user BM25 registry — keyed by user_id; "__local__" for single-user mode
|
||||||
_bm25 = BM25Index()
|
_bm25_map: dict[str, BM25Index] = {}
|
||||||
|
|
||||||
|
|
||||||
def _apply_migrations() -> None:
|
def _get_bm25_for(user_id: str) -> BM25Index:
|
||||||
from scripts.db_migrate import migrate
|
if user_id not in _bm25_map:
|
||||||
migrate(DB_PATH)
|
_bm25_map[user_id] = BM25Index()
|
||||||
|
return _bm25_map[user_id]
|
||||||
|
|
||||||
def _reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
|
|
||||||
"""Re-run full ingest for a list of (doc_id, file_path) sequentially."""
|
|
||||||
for doc_id, file_path in docs:
|
|
||||||
suffix = os.path.splitext(file_path)[1].lower()
|
|
||||||
try:
|
|
||||||
if suffix == ".epub":
|
|
||||||
from scripts.ingest_epub import run
|
|
||||||
else:
|
|
||||||
from scripts.ingest_pdf import run
|
|
||||||
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
|
|
||||||
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
|
|
||||||
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
|
|
||||||
|
|
||||||
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
|
|
||||||
models requires dropping and recreating the whole file. Catches the mismatch at
|
|
||||||
startup rather than surfacing it as an obscure OperationalError mid-request.
|
|
||||||
"""
|
|
||||||
if not os.path.exists(vec_db_path):
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
conn = sqlite3.connect(vec_db_path)
|
|
||||||
row = conn.execute(
|
|
||||||
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
|
|
||||||
).fetchone()
|
|
||||||
conn.close()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not row:
|
|
||||||
return # table not yet created — first embed will build it with the right dims
|
|
||||||
|
|
||||||
m = re.search(r'float\[(\d+)\]', row[0])
|
|
||||||
if not m:
|
|
||||||
return
|
|
||||||
actual_dims = int(m.group(1))
|
|
||||||
if actual_dims == expected_dims:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
|
|
||||||
actual_dims, expected_dims, vec_db_path,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
os.remove(vec_db_path)
|
|
||||||
except OSError as exc:
|
|
||||||
logger.error(
|
|
||||||
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Collect all ready docs so we can rebuild their embeddings in the background.
|
|
||||||
try:
|
|
||||||
conn = sqlite3.connect(db_path)
|
|
||||||
docs = conn.execute(
|
|
||||||
"SELECT id, file_path FROM documents WHERE status='ready'"
|
|
||||||
).fetchall()
|
|
||||||
conn.close()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Could not query documents for re-embed: %s", exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not docs:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
|
|
||||||
threading.Thread(
|
|
||||||
target=_reembed_docs,
|
|
||||||
args=(docs, db_path, vec_db_path),
|
|
||||||
daemon=True,
|
|
||||||
name="pagepiper-reembed",
|
|
||||||
).start()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
_apply_migrations()
|
from app.cloud_session import CLOUD_MODE
|
||||||
|
from app.config import LOCAL_USER_ID
|
||||||
|
from app.startup import apply_migrations, check_and_rebuild_vec_schema
|
||||||
|
|
||||||
embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text")
|
embed_model = os.environ.get("PAGEPIPER_EMBED_MODEL", "nomic-embed-text")
|
||||||
logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS)
|
logger.info("Pagepiper starting — embed model: %s, dims: %d", embed_model, VEC_DIMENSIONS)
|
||||||
_check_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
|
|
||||||
_bm25.mark_dirty() # will rebuild on first search
|
if not CLOUD_MODE:
|
||||||
|
# In cloud mode, per-user migration and vec schema check run on first request (deps.py).
|
||||||
|
apply_migrations(DB_PATH)
|
||||||
|
check_and_rebuild_vec_schema(VEC_DB_PATH, VEC_DIMENSIONS, DB_PATH)
|
||||||
|
_get_bm25_for(LOCAL_USER_ID).mark_dirty()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Pagepiper", lifespan=lifespan)
|
app = FastAPI(title="Pagepiper", lifespan=lifespan)
|
||||||
|
|
||||||
# Wire BM25 dirty callback into library router
|
|
||||||
from app.api import library as _lib_module # noqa: E402
|
|
||||||
_lib_module._mark_bm25_dirty = _bm25.mark_dirty
|
|
||||||
|
|
||||||
# Register routers
|
# Register routers
|
||||||
from app.api.library import router as library_router # noqa: E402
|
from app.api.library import router as library_router # noqa: E402
|
||||||
from app.api.ingest import router as ingest_router # noqa: E402
|
from app.api.ingest import router as ingest_router # noqa: E402
|
||||||
|
|
|
||||||
95
app/startup.py
Normal file
95
app/startup.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
# app/startup.py
|
||||||
|
"""DB migration and vec schema check utilities — called at startup and on first user request."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger("pagepiper")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_migrations(db_path: str) -> None:
|
||||||
|
from scripts.db_migrate import migrate
|
||||||
|
migrate(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
def reembed_docs(docs: list[tuple[str, str]], db_path: str, vec_db_path: str) -> None:
|
||||||
|
for doc_id, file_path in docs:
|
||||||
|
suffix = os.path.splitext(file_path)[1].lower()
|
||||||
|
try:
|
||||||
|
if suffix == ".epub":
|
||||||
|
from scripts.ingest_epub import run
|
||||||
|
elif suffix == ".docx":
|
||||||
|
from scripts.ingest_docx import run
|
||||||
|
else:
|
||||||
|
from scripts.ingest_pdf import run
|
||||||
|
logger.info("Auto re-embed: starting %s", os.path.basename(file_path))
|
||||||
|
run(doc_id=doc_id, file_path=file_path, db_path=db_path, vec_db_path=vec_db_path)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Auto re-embed failed for doc %s: %s", doc_id[:8], exc)
|
||||||
|
|
||||||
|
|
||||||
|
def check_and_rebuild_vec_schema(vec_db_path: str, expected_dims: int, db_path: str) -> None:
|
||||||
|
"""Drop the vec DB if its stored dimension doesn't match config, then queue re-embed.
|
||||||
|
|
||||||
|
sqlite-vec bakes the embedding dimension into the virtual table DDL, so changing
|
||||||
|
models requires dropping and recreating the whole file. Catches the mismatch at
|
||||||
|
startup rather than surfacing it as an obscure OperationalError mid-request.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(vec_db_path):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(vec_db_path)
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT sql FROM sqlite_master WHERE name='page_vecs_vecs'"
|
||||||
|
).fetchone()
|
||||||
|
conn.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Vec schema check could not read %s (non-fatal): %s", vec_db_path, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return
|
||||||
|
|
||||||
|
m = re.search(r'float\[(\d+)\]', row[0])
|
||||||
|
if not m:
|
||||||
|
return
|
||||||
|
actual_dims = int(m.group(1))
|
||||||
|
if actual_dims == expected_dims:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"Vec DB dimension mismatch: stored=%d, configured=%d — dropping %s and queuing re-embed",
|
||||||
|
actual_dims, expected_dims, vec_db_path,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.remove(vec_db_path)
|
||||||
|
except OSError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Could not delete stale vec DB %s: %s — fix permissions and restart", vec_db_path, exc
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
docs = conn.execute(
|
||||||
|
"SELECT id, file_path FROM documents WHERE status='ready'"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not query documents for re-embed: %s", exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not docs:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Queuing re-embed for %d document(s) in background", len(docs))
|
||||||
|
threading.Thread(
|
||||||
|
target=reembed_docs,
|
||||||
|
args=(docs, db_path, vec_db_path),
|
||||||
|
daemon=True,
|
||||||
|
name="pagepiper-reembed",
|
||||||
|
).start()
|
||||||
|
|
@ -27,13 +27,32 @@ def client(test_db, tmp_path, monkeypatch):
|
||||||
(tmp_path / "books").mkdir(exist_ok=True)
|
(tmp_path / "books").mkdir(exist_ok=True)
|
||||||
|
|
||||||
import app.main as _main_module
|
import app.main as _main_module
|
||||||
from app.main import app, _bm25
|
from app.config import LOCAL_USER_ID
|
||||||
from app.deps import get_db
|
from app.deps import UserCtx, get_db, get_user_ctx
|
||||||
|
from app.main import app
|
||||||
|
from app.services.bm25_index import BM25Index
|
||||||
|
from app.startup import apply_migrations, check_and_rebuild_vec_schema
|
||||||
|
|
||||||
# Suppress startup side effects — test_db fixture already applies the schema,
|
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None, raising=False)
|
||||||
# and vec schema validation is tested separately in test_startup.py
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr(_main_module, "_apply_migrations", lambda: None)
|
"app.startup.apply_migrations", lambda *a, **kw: None
|
||||||
monkeypatch.setattr(_main_module, "_check_vec_schema", lambda *a, **kw: None)
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.startup.check_and_rebuild_vec_schema", lambda *a, **kw: None
|
||||||
|
)
|
||||||
|
|
||||||
|
test_bm25 = BM25Index()
|
||||||
|
test_bm25.mark_dirty()
|
||||||
|
|
||||||
|
def override_user_ctx():
|
||||||
|
return UserCtx(
|
||||||
|
user_id=LOCAL_USER_ID,
|
||||||
|
db_path=test_db,
|
||||||
|
vec_db_path=str(tmp_path / "test_vecs.db"),
|
||||||
|
data_dir=Path(tmp_path),
|
||||||
|
watch_dir=Path(tmp_path) / "books",
|
||||||
|
bm25=test_bm25,
|
||||||
|
)
|
||||||
|
|
||||||
def override_db():
|
def override_db():
|
||||||
conn = sqlite3.connect(test_db)
|
conn = sqlite3.connect(test_db)
|
||||||
|
|
@ -44,7 +63,7 @@ def client(test_db, tmp_path, monkeypatch):
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
app.dependency_overrides[get_user_ctx] = override_user_ctx
|
||||||
app.dependency_overrides[get_db] = override_db
|
app.dependency_overrides[get_db] = override_db
|
||||||
_bm25.mark_dirty() # clear any state from previous tests
|
|
||||||
yield TestClient(app)
|
yield TestClient(app)
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,7 @@ def _add_chunks(db_path: str, doc_id: str, chunks: list[dict]) -> None:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def test_search_returns_results(client, test_db, monkeypatch):
|
def test_search_returns_results(client, test_db):
|
||||||
import app.api.search as _search_mod
|
|
||||||
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
|
|
||||||
# BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0).
|
# BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0).
|
||||||
# Add a 3rd unrelated chunk so relevant terms score above zero.
|
# Add a 3rd unrelated chunk so relevant terms score above zero.
|
||||||
_add_chunks(test_db, "book-a", [
|
_add_chunks(test_db, "book-a", [
|
||||||
|
|
@ -46,9 +44,7 @@ def test_search_empty_index_returns_empty(client):
|
||||||
assert resp.json() == []
|
assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
def test_search_filters_by_doc_ids(client, test_db, monkeypatch):
|
def test_search_filters_by_doc_ids(client, test_db):
|
||||||
import app.api.search as _search_mod
|
|
||||||
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
|
|
||||||
# Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc.
|
# Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc.
|
||||||
_add_chunks(test_db, "book-a", [
|
_add_chunks(test_db, "book-a", [
|
||||||
{"page_number": 1, "text": "Grapple rules for melee attacks."},
|
{"page_number": 1, "text": "Grapple rules for melee attacks."},
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,8 @@ from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.main import _check_vec_schema, _reembed_docs
|
from app.startup import check_and_rebuild_vec_schema as _check_vec_schema
|
||||||
|
from app.startup import reembed_docs as _reembed_docs
|
||||||
|
|
||||||
|
|
||||||
def _make_vec_db(path: str, dims: int) -> None:
|
def _make_vec_db(path: str, dims: int) -> None:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue