pagepiper/app/api/chat.py
pyr0ball 8eef52a054 feat: per-user database isolation for cloud instances (closes #4)
Implements Option A from the issue design: each cloud user gets their own
data directory (DATA_DIR/users/{user_id}/) with separate pagepiper.db,
pagepiper_vecs.db, uploads/, and books/. Local mode is unchanged.

Key changes:
- app/startup.py: extract apply_migrations, reembed_docs,
  check_and_rebuild_vec_schema out of main.py (no circular imports)
- app/config.py: add LOCAL_USER_ID constant and user_data_dir() helper
- app/cloud_session.py: extract resolve_authenticated_user(); require_paid_tier
  now returns user_id (str) instead of None
- app/deps.py: add UserCtx dataclass (db_path, vec_db_path, data_dir,
  watch_dir, bm25) + get_user_ctx dependency; per-user startup guard runs
  migrations + vec schema check once per process per user
- app/main.py: _bm25 singleton -> _bm25_map dict keyed by user_id;
  add _get_bm25_for(); lifespan only runs startup checks in local mode
- app/api/library.py, search.py, chat.py: thread UserCtx through all
  endpoints; remove module-level _mark_bm25_dirty injection pattern
- tests/conftest.py: override get_user_ctx in addition to get_db so all
  endpoints get a consistent test UserCtx
2026-05-13 16:31:51 -07:00

153 lines
3.9 KiB
Python

# app/api/chat.py
"""
RAG chat endpoint — retrieves relevant page chunks and synthesizes an answer.
BSL 1.1 — BYOK gate: requires PAGEPIPER_OLLAMA_URL or a Paid tier license.
Returns 402 with clear upgrade message if neither is configured.
"""
from __future__ import annotations
import logging
import os
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from app.cloud_session import require_paid_tier
from app.deps import UserCtx, get_user_ctx
from app.services.retriever import Retriever
from app.services.synthesizer import Synthesizer
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Router for this module: every route declared on it gets the "/api/chat"
# path prefix and the "chat" tag.
router = APIRouter(prefix="/api/chat", tags=["chat"])
class ChatTurn(BaseModel):
    """One earlier message of a conversation, as carried in ChatRequest.history."""

    # Who produced the turn: "user" | "assistant". The type is a plain str,
    # so other values are not rejected by this model.
    role: str
    # Text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoint."""

    # The user's message; the chat handler uses it both as the retrieval
    # query and as the message handed to the synthesizer.
    message: str
    # Earlier conversation turns; the handler forwards them to the synthesizer.
    history: list[ChatTurn] = []
    # Forwarded as-is to the retriever's doc_ids argument; defaults to None.
    doc_ids: list[str] | None = None
    # Forwarded as-is to the retriever's top_k argument.
    top_k: int = 10
class ChatResponse(BaseModel):
    """Response body of the chat endpoint: the answer text and its citations."""

    # Answer text returned to the client.
    answer: str
    # One dict per citation. The chat handler fills each with the keys
    # doc_id, page_number, snippet and bm25_score; empty when nothing matched.
    citations: list[dict]
class ChatFeedbackRequest(BaseModel):
    """Request body accepted by the chat feedback endpoint."""

    # 1 = thumbs up, -1 = thumbs down. The model itself accepts any int;
    # the feedback handler rejects every other value.
    rating: int
    # Question being rated; the handler stores at most the first 2000 characters.
    question: str = ""
    # Answer being rated; the handler stores at most the first 4000 characters.
    answer: str = ""
    # Document ids attached to the feedback; the handler stores them as JSON.
    doc_ids: list[str] = []
def _get_llm_router():
    """Build an LLMRouter from the application's LLM config, if there is one.

    Returns:
        An ``LLMRouter`` constructed from ``get_llm_config()``'s result, or
        ``None`` when that function returns ``None``.
    """
    from app.config import get_llm_config

    llm_config = get_llm_config()
    if llm_config is not None:
        # Imported lazily so circuitforge_core is only loaded when an LLM
        # configuration is actually present.
        from circuitforge_core.llm import LLMRouter

        return LLMRouter(llm_config)
    return None
def _require_llm():
    """Return the configured LLMRouter, refusing the request when there is none.

    Returns:
        The router produced by ``_get_llm_router()``.

    Raises:
        HTTPException: status 402 with an ``ollama_required`` error payload
            when ``_get_llm_router()`` returns ``None``.
    """
    llm_router = _get_llm_router()
    if llm_router is not None:
        return llm_router

    # No router available: tell the client how to configure Ollama.
    setup_hint = (
        "RAG chat requires Ollama. Set PAGEPIPER_OLLAMA_URL in your .env file, "
        "then restart. Run: ollama pull nomic-embed-text && ollama pull mistral:7b"
    )
    raise HTTPException(
        status_code=402,
        detail={"error": "ollama_required", "message": setup_hint},
    )
@router.post("")
def chat(
    req: ChatRequest,
    ctx: UserCtx = Depends(get_user_ctx),
    _tier: str = Depends(require_paid_tier),
) -> ChatResponse:
    """Answer a chat message from retrieved chunks of the caller's documents.

    Runs a hybrid search over the databases named by ``ctx`` and, when
    anything is found, asks the synthesizer for an answer with citations.

    Args:
        req: The message, prior turns, optional doc_ids and top_k.
        ctx: Per-user context; supplies bm25, db_path and vec_db_path.
        _tier: Value of the ``require_paid_tier`` dependency. It is not read
            here; declaring it makes the dependency run for this route.

    Returns:
        A ChatResponse. When the search returns nothing, the answer is a
        fixed "no relevant passages" message and citations is empty.

    Raises:
        HTTPException: 402 from ``_require_llm()`` when no LLM router is
            available.
    """
    llm_router = _require_llm()

    hits = Retriever(ctx.bm25).hybrid_search(
        query=req.message,
        top_k=req.top_k,
        doc_ids=req.doc_ids,
        db_path=ctx.db_path,
        vec_db_path=ctx.vec_db_path,
        llm=llm_router,
    )
    if not hits:
        # Nothing retrieved: skip synthesis and return the fixed fallback.
        fallback_answer = (
            "I couldn't find any relevant passages. "
            "Try a different query or check which documents are indexed."
        )
        return ChatResponse(answer=fallback_answer, citations=[])

    synthesizer = Synthesizer(llm_router)
    prior_turns = [turn.model_dump() for turn in req.history]
    result = synthesizer.synthesize(
        message=req.message,
        history=prior_turns,
        chunks=hits,
    )

    answer_text = result.answer
    # Flatten each citation object into a plain dict for the response body.
    citation_rows = [
        {
            "doc_id": cite.doc_id,
            "page_number": cite.page_number,
            "snippet": cite.snippet,
            "bm25_score": cite.bm25_score,
        }
        for cite in result.citations
    ]
    return ChatResponse(answer=answer_text, citations=citation_rows)
@router.get("/feedback/status")
def chat_feedback_status() -> dict:
    """Report whether the PAGEPIPER_CHAT_FEEDBACK environment flag is set.

    The flag counts as enabled when its value, compared case-insensitively,
    is "1", "true" or "yes". An unset variable or any other value counts as
    disabled.

    Returns:
        ``{"enabled": <bool>}``.
    """
    raw_flag = os.environ.get("PAGEPIPER_CHAT_FEEDBACK", "")
    accepted = ("1", "true", "yes")
    return {"enabled": raw_flag.lower() in accepted}
@router.post("/feedback")
def submit_chat_feedback(
    req: ChatFeedbackRequest,
    ctx: UserCtx = Depends(get_user_ctx),
) -> dict:
    """Store one thumbs-up/down rating in the caller's chat_feedback table.

    Args:
        req: The rating plus the optional question, answer and doc_ids.
        ctx: Per-user context; ``ctx.db_path`` names the SQLite database
            that receives the row.

    Returns:
        ``{"ok": True}`` once the row has been committed.

    Raises:
        HTTPException: 422 when ``req.rating`` is neither 1 nor -1.
    """
    import json
    import sqlite3
    from contextlib import closing

    # HTTPException comes from the module-level fastapi import; re-importing
    # it inside the function only shadowed that name.
    if req.rating not in (1, -1):
        raise HTTPException(status_code=422, detail="rating must be 1 or -1")

    # closing() guarantees the connection is closed. The inner `with con`
    # commits on success and rolls back if the insert raises.
    with closing(sqlite3.connect(ctx.db_path)) as con:
        with con:
            con.execute(
                "INSERT INTO chat_feedback (rating, question, answer, doc_ids) VALUES (?, ?, ?, ?)",
                (
                    req.rating,
                    # Cap the stored text so one request cannot write
                    # arbitrarily large rows.
                    req.question[:2000],
                    req.answer[:4000],
                    json.dumps(req.doc_ids),
                ),
            )
    return {"ok": True}