pagepiper/tests/test_startup.py
pyr0ball 8eef52a054 feat: per-user database isolation for cloud instances (closes #4)
Implements Option A from the issue design: each cloud user gets their own
data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db,
pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged.

Key changes:
- app/startup.py: extract apply_migrations, reembed_docs,
  check_and_rebuild_vec_schema out of main.py (no circular imports)
- app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper
- app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier
  now returns user_id (str) instead of None
- app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir,
  watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs
  migrations + vec schema check once per process per user
- app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id;
  add _get_bm25_for(); lifespan only runs startup checks in local mode
- app/api/library.py, search.py, chat.py: thread UserCtx through all
  endpoints; remove module-level _mark_bm25_dirty injection pattern
- tests/conftest.py: override get_user_ctx in addition to get_db so all
  endpoints get a consistent test UserCtx
2026-05-13 16:31:51 -07:00

171 lines
6.4 KiB
Python

# tests/test_startup.py
"""Tests for startup vec DB schema validation (_check_vec_schema)."""
from __future__ import annotations
import os
import sqlite3
import threading
from unittest.mock import MagicMock, patch
import pytest
from app.startup import check_and_rebuild_vec_schema as _check_vec_schema
from app.startup import reembed_docs as _reembed_docs
def _make_vec_db(path: str, dims: int) -> None:
    """Create a minimal stand-in for a sqlite-vec DB with the given dimension.

    sqlite-vec cannot be loaded in tests, so instead of the real
    ``CREATE VIRTUAL TABLE page_vecs_vecs USING vec0(...)`` this creates an
    ordinary table with the same name and column spec. SQLite keeps the
    original CREATE text in ``sqlite_master``, so the ``float[<dims>]`` marker
    is recorded there -- the same table shape the tests in this module build
    inline.

    Args:
        path: Filesystem path of the DB file to create.
        dims: Embedding dimension written into the ``float[<dims>]`` marker.
    """
    conn = sqlite3.connect(path)
    try:
        # WAL mode is persistent, so it stays recorded in the DB file.
        conn.execute("PRAGMA journal_mode=WAL")
        # Same table name as the vec store; a plain table stands in for the
        # vec0 virtual table, which needs the sqlite-vec extension.
        conn.execute(f"CREATE TABLE page_vecs_vecs (embedding float[{dims}])")
        conn.commit()
    finally:
        conn.close()
def _make_real_vec_db(path: str, dims: int) -> None:
    """Create a vec DB whose ``sqlite_master`` contains the dimension DDL.

    sqlite-vec cannot be loaded in tests, so ``page_vecs_vecs`` is created as
    an ordinary table rather than a ``vec0`` virtual table. SQLite records the
    CREATE text, including ``float[<dims>]``, in ``sqlite_master``.

    Args:
        path: Filesystem path of the DB file to create.
        dims: Embedding dimension written into the ``float[<dims>]`` marker.
    """
    conn = sqlite3.connect(path)
    try:
        conn.execute(f"CREATE TABLE page_vecs_vecs (embedding float[{dims}])")
        conn.commit()
    finally:
        # Close even if the CREATE fails so the file handle is not leaked.
        conn.close()
class TestCheckVecSchema:
    """Tests for ``check_and_rebuild_vec_schema`` (imported as ``_check_vec_schema``).

    sqlite-vec is not loaded in these tests, so every "vec DB" is an ordinary
    SQLite file containing a plain ``page_vecs_vecs`` table whose DDL carries
    the ``float[<dims>]`` dimension marker.
    """

    # Layout of the ``documents`` table created in the main DB.
    _DOCUMENTS_SCHEMA = (
        "CREATE TABLE documents ("
        "id TEXT PRIMARY KEY, title TEXT, file_path TEXT, "
        "status TEXT, task_id TEXT, page_count INTEGER, "
        "error_msg TEXT, created_at TEXT, updated_at TEXT)"
    )

    # One document in the 'ready' state, in _DOCUMENTS_SCHEMA column order.
    _READY_DOC = (
        "abc123", "Book", "/tmp/book.pdf", "ready", None, 10, None,
        "2026-01-01", "2026-01-01",
    )

    @staticmethod
    def _write_vec_db(path: str, dims: int) -> None:
        """Create a vec DB at *path* whose ``page_vecs_vecs`` DDL declares ``float[dims]``."""
        conn = sqlite3.connect(path)
        try:
            conn.execute(f"CREATE TABLE page_vecs_vecs (embedding float[{dims}])")
            conn.commit()
        finally:
            conn.close()

    @classmethod
    def _write_main_db(cls, path: str, rows: tuple[tuple, ...] = ()) -> None:
        """Create the main DB at *path* with a ``documents`` table holding *rows*."""
        conn = sqlite3.connect(path)
        try:
            conn.execute(cls._DOCUMENTS_SCHEMA)
            conn.executemany(
                "INSERT INTO documents VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", rows
            )
            conn.commit()
        finally:
            conn.close()

    def test_no_file_is_noop(self, tmp_path):
        """Missing vec DB should not raise."""
        _check_vec_schema(str(tmp_path / "missing.db"), 1024, str(tmp_path / "main.db"))

    def test_matching_dims_keeps_file(self, tmp_path):
        """Correct dimensions: vec DB must not be deleted."""
        vec_path = str(tmp_path / "vecs.db")
        self._write_vec_db(vec_path, 1024)
        _check_vec_schema(vec_path, 1024, str(tmp_path / "main.db"))
        assert os.path.exists(vec_path), "Vec DB should not be deleted when dims match"

    def test_mismatched_dims_deletes_file(self, tmp_path):
        """Dimension mismatch: vec DB must be deleted."""
        vec_path = str(tmp_path / "vecs.db")
        self._write_vec_db(vec_path, 768)
        _check_vec_schema(vec_path, 1024, str(tmp_path / "main.db"))
        assert not os.path.exists(vec_path), "Vec DB should be deleted on dimension mismatch"

    def test_mismatched_dims_queues_reembed(self, tmp_path):
        """Dimension mismatch: re-embed thread must be started for ready docs."""
        vec_path = str(tmp_path / "vecs.db")
        db_path = str(tmp_path / "main.db")
        self._write_vec_db(vec_path, 768)
        self._write_main_db(db_path, (self._READY_DOC,))

        started = []
        real_thread_start = threading.Thread.start

        def _capture_start(thread):
            started.append(thread)
            # Replace the work with a no-op so the re-embed itself never runs.
            thread.run = lambda: None
            real_thread_start(thread)

        with patch.object(threading.Thread, "start", _capture_start):
            _check_vec_schema(vec_path, 1024, db_path)

        # The captured threads were really started; join them so none outlive the test.
        for thread in started:
            thread.join(timeout=5)

        assert len(started) == 1, "Exactly one re-embed thread should be started"
        assert started[0].name == "pagepiper-reembed"

    def test_no_ready_docs_skips_thread(self, tmp_path):
        """Mismatch with no ready docs: no thread should be started."""
        vec_path = str(tmp_path / "vecs.db")
        db_path = str(tmp_path / "main.db")
        self._write_vec_db(vec_path, 768)
        self._write_main_db(db_path)

        started = []
        # Record start() calls without actually launching anything.
        with patch.object(threading.Thread, "start", lambda thread: started.append(thread)):
            _check_vec_schema(vec_path, 1024, db_path)
        assert len(started) == 0

    def test_empty_db_no_table_is_noop(self, tmp_path):
        """Vec DB exists but has no page_vecs_vecs table yet: no deletion."""
        vec_path = str(tmp_path / "vecs.db")
        sqlite3.connect(vec_path).close()  # create empty file
        _check_vec_schema(vec_path, 1024, str(tmp_path / "main.db"))
        assert os.path.exists(vec_path)

    def test_corrupt_db_does_not_raise(self, tmp_path):
        """Corrupt or unreadable vec DB must not propagate exceptions."""
        vec_path = tmp_path / "vecs.db"
        vec_path.write_text("not a sqlite database", encoding="utf-8")
        # Passing means the call returned without raising.
        _check_vec_schema(str(vec_path), 1024, str(tmp_path / "main.db"))