diff --git a/app/api/chat.py b/app/api/chat.py index 8260801..1ec6547 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -1,5 +1,126 @@ # app/api/chat.py -"""RAG chat API — streaming LLM responses with page citations (Task 6).""" -from fastapi import APIRouter +""" +RAG chat endpoint — retrieves relevant page chunks and synthesizes an answer. +BSL 1.1 — BYOK gate: requires PAGEPIPER_OLLAMA_URL or a Paid tier license. +Returns 402 with clear upgrade message if neither is configured. +""" +from __future__ import annotations + +import logging +import os + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.services.retriever import Retriever +from app.services.synthesizer import Synthesizer + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/chat", tags=["chat"]) + + +class ChatTurn(BaseModel): + role: str # "user" | "assistant" + content: str + + +class ChatRequest(BaseModel): + message: str + history: list[ChatTurn] = [] + doc_ids: list[str] | None = None + top_k: int = 5 + + +class ChatResponse(BaseModel): + answer: str + citations: list[dict] + + +def _get_llm_router(): + """Return LLMRouter if Ollama configured, else None.""" + from app.config import get_llm_config + + cfg = get_llm_config() + if cfg is None: + return None + from circuitforge_core.llm import LLMRouter + + return LLMRouter(cfg) + + +def _get_db_path() -> str: + """Read lazily so test fixtures take effect.""" + import pathlib + + data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data")) + return str(data_dir / "pagepiper.db") + + +def _get_vec_db_path() -> str: + import pathlib + + data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data")) + return str(data_dir / "pagepiper_vecs.db") + + +def _require_llm(): + """Return LLMRouter or raise 402.""" + llm = _get_llm_router() + if llm is None: + raise HTTPException( + status_code=402, + detail={ + "error": "ollama_required", + "message": ( + "RAG chat requires Ollama. Set PAGEPIPER_OLLAMA_URL in your .env file, " + "then restart. Run: ollama pull nomic-embed-text && ollama pull mistral:7b" + ), + }, + ) + return llm + + +@router.post("") +def chat(req: ChatRequest) -> ChatResponse: + llm = _require_llm() + + from app.main import _bm25 + + retriever = Retriever(_bm25) + chunks = retriever.hybrid_search( + query=req.message, + top_k=req.top_k, + doc_ids=req.doc_ids, + db_path=_get_db_path(), + vec_db_path=_get_vec_db_path(), + llm=llm, + ) + + if not chunks: + return ChatResponse( + answer=( + "I couldn't find any relevant passages. " + "Try a different query or check which documents are indexed." + ), + citations=[], + ) + + synth = Synthesizer(llm) + result = synth.synthesize( + message=req.message, + history=[t.model_dump() for t in req.history], + chunks=chunks, + ) + + return ChatResponse( + answer=result.answer, + citations=[ + { + "doc_id": c.doc_id, + "page_number": c.page_number, + "snippet": c.snippet, + } + for c in result.citations + ], + ) diff --git a/app/services/retriever.py b/app/services/retriever.py new file mode 100644 index 0000000..b2d09a5 --- /dev/null +++ b/app/services/retriever.py @@ -0,0 +1,121 @@ +# app/services/retriever.py +""" +Hybrid BM25 + semantic retriever. + +BSL 1.1 — semantic path requires PAGEPIPER_OLLAMA_URL (BYOK gate). +BM25-only path is MIT and has no gate. +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass + +from app.services.bm25_index import BM25Index + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RetrievedChunk: + """A chunk returned by the retriever, with source scores.""" + + chunk_id: str + doc_id: str + page_number: int + text: str + bm25_score: float + vector_score: float | None + + +class Retriever: + def __init__(self, bm25: BM25Index) -> None: + self._bm25 = bm25 + + def hybrid_search( + self, + query: str, + top_k: int, + doc_ids: list[str] | None, + db_path: str, + vec_db_path: str, + llm, # LLMRouter | None — caller must pass + ) -> list[RetrievedChunk]: + """ + Merge BM25 and semantic results. + Falls back to BM25-only if llm is None. + """ + if llm is None: + return self._bm25_only(query, top_k, doc_ids, db_path) + + from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore + + self._bm25.ensure_fresh(db_path) + bm25_hits = { + r.chunk_id: r + for r in self._bm25.query(query, top_k=top_k * 2, doc_ids=doc_ids) + } + + vec = llm.embed([query])[0] + store = LocalSQLiteVecStore(db_path=vec_db_path, table="page_vecs", dimensions=768) + filter_meta = {"doc_id": doc_ids[0]} if doc_ids and len(doc_ids) == 1 else None + vec_hits = store.query(vec, top_k=top_k * 2, filter_metadata=filter_meta) + + if doc_ids and len(doc_ids) > 1: + vec_hits = [h for h in vec_hits if h.metadata.get("doc_id") in doc_ids] + + # Merge: BM25 hits take priority; vector hits fill in additional results + merged: dict[str, RetrievedChunk] = {} + for cid, r in bm25_hits.items(): + merged[cid] = RetrievedChunk( + chunk_id=cid, + doc_id=r.doc_id, + page_number=r.page_number, + text=r.text, + bm25_score=r.score, + vector_score=None, + ) + for vh in vec_hits: + text = next((c["text"] for c in self._bm25._chunks if c["id"] == vh.id), "") + if vh.id in merged: + existing = merged[vh.id] + merged[vh.id] = RetrievedChunk( + chunk_id=existing.chunk_id, + doc_id=existing.doc_id, + page_number=existing.page_number, + text=existing.text, + bm25_score=existing.bm25_score, + vector_score=vh.score, + ) + else: + merged[vh.id] = RetrievedChunk( + chunk_id=vh.id, + doc_id=vh.metadata.get("doc_id", ""), + page_number=int(vh.metadata.get("page_number", 0)), + text=text, + bm25_score=0.0, + vector_score=vh.score, + ) + + def _combined(r: RetrievedChunk) -> float: + bm25 = r.bm25_score + vec = (1.0 / (1.0 + r.vector_score)) if r.vector_score is not None else 0.0 + return bm25 * 0.5 + vec * 0.5 + + ranked = sorted(merged.values(), key=_combined, reverse=True) + return ranked[:top_k] + + def _bm25_only( + self, query: str, top_k: int, doc_ids: list[str] | None, db_path: str + ) -> list[RetrievedChunk]: + self._bm25.ensure_fresh(db_path) + return [ + RetrievedChunk( + chunk_id=r.chunk_id, + doc_id=r.doc_id, + page_number=r.page_number, + text=r.text, + bm25_score=r.score, + vector_score=None, + ) + for r in self._bm25.query(query, top_k=top_k, doc_ids=doc_ids) + ] diff --git a/app/services/synthesizer.py b/app/services/synthesizer.py new file mode 100644 index 0000000..d11640b --- /dev/null +++ b/app/services/synthesizer.py @@ -0,0 +1,58 @@ +# app/services/synthesizer.py +""" +LLM answer synthesis over retrieved chunks. + +BSL 1.1 — requires LLMRouter (Ollama BYOK or cloud tier). +""" +from __future__ import annotations + +from dataclasses import dataclass + +from app.services.retriever import RetrievedChunk + +_SYSTEM_PROMPT = ( + "You are a helpful document assistant. " + "Answer the user's question using ONLY the provided document excerpts. " + "For each claim, cite the source page as [p.N]. " + "If the excerpts are insufficient, say so. Do not invent information." +) + + +@dataclass(frozen=True) +class Citation: + doc_id: str + page_number: int + snippet: str + + +@dataclass(frozen=True) +class SynthesisResult: + answer: str + citations: list[Citation] + + +class Synthesizer: + def __init__(self, llm) -> None: # LLMRouter + self._llm = llm + + def synthesize( + self, + message: str, + history: list[dict], + chunks: list[RetrievedChunk], + ) -> SynthesisResult: + context_parts = [f"[p.{c.page_number}]\n{c.text[:500]}" for c in chunks] + context = "\n\n---\n\n".join(context_parts) + prompt = f"Document excerpts:\n\n{context}\n\nQuestion: {message}" + + answer = self._llm.complete(prompt, system=_SYSTEM_PROMPT) + + citations = [ + Citation( + doc_id=c.doc_id, + page_number=c.page_number, + snippet=c.text[:200], + ) + for c in chunks + ] + return SynthesisResult(answer=answer, citations=citations) diff --git a/tests/test_chat_api.py b/tests/test_chat_api.py new file mode 100644 index 0000000..476d5b4 --- /dev/null +++ b/tests/test_chat_api.py @@ -0,0 +1,59 @@ +# tests/test_chat_api.py +"""Tests for POST /api/chat — RAG chat (BSL, BYOK gate).""" +from __future__ import annotations + +import sqlite3 +from unittest.mock import MagicMock, patch + +from app.services.retriever import RetrievedChunk + + +def test_chat_returns_402_without_ollama(client, monkeypatch): + monkeypatch.delenv("PAGEPIPER_OLLAMA_URL", raising=False) + resp = client.post("/api/chat", json={"message": "How does Fireball work?", "history": []}) + assert resp.status_code == 402 + body = resp.json() + assert "detail" in body + assert "Ollama" in body["detail"]["message"] + + +def test_chat_returns_answer_with_mocked_ollama(client, test_db, monkeypatch): + monkeypatch.setenv("PAGEPIPER_OLLAMA_URL", "http://localhost:11434") + + conn = sqlite3.connect(test_db) + conn.execute( + "INSERT OR IGNORE INTO documents(id, title, file_path, status) VALUES ('b1','PHB','phb.pdf','ready')" + ) + conn.execute( + "INSERT INTO page_chunks(doc_id, page_number, text, source, word_count) " + "VALUES ('b1',15,'Fireball deals 8d6 fire damage.','text_layer',6)" + ) + conn.commit() + conn.close() + + mock_llm = MagicMock() + mock_llm.complete.return_value = "Fireball deals 8d6 fire damage [p.15]." + + mock_chunks = [ + RetrievedChunk( + chunk_id="c1", + doc_id="b1", + page_number=15, + text="Fireball deals 8d6 fire damage.", + bm25_score=1.0, + vector_score=None, + ) + ] + + with patch("app.api.chat.Retriever.hybrid_search", return_value=mock_chunks): + with patch("app.api.chat._get_llm_router", return_value=mock_llm): + resp = client.post( + "/api/chat", + json={"message": "How does Fireball work?", "history": [], "doc_ids": ["b1"]}, + ) + + assert resp.status_code == 200 + body = resp.json() + assert "answer" in body + assert "citations" in body + assert "Fireball" in body["answer"] diff --git a/tests/test_synthesizer.py b/tests/test_synthesizer.py new file mode 100644 index 0000000..b173960 --- /dev/null +++ b/tests/test_synthesizer.py @@ -0,0 +1,54 @@ +# tests/test_synthesizer.py +"""Tests for Synthesizer — mocked LLM, citation assembly.""" +from __future__ import annotations + +from unittest.mock import MagicMock + +from app.services.retriever import RetrievedChunk +from app.services.synthesizer import Synthesizer, SynthesisResult + + +def _chunk(doc_id: str = "book-a", page: int = 5, text: str = "Fireball rules") -> RetrievedChunk: + return RetrievedChunk( + chunk_id="c1", doc_id=doc_id, page_number=page, text=text, + bm25_score=1.0, vector_score=None, + ) + + +def test_synthesizer_returns_answer_and_citations(): + mock_llm = MagicMock() + mock_llm.complete.return_value = "Fireball deals 8d6 damage [p.5]." + + synth = Synthesizer(mock_llm) + result = synth.synthesize( + message="How does Fireball work?", + history=[], + chunks=[_chunk()], + ) + + assert isinstance(result, SynthesisResult) + assert "Fireball" in result.answer + assert len(result.citations) == 1 + assert result.citations[0].page_number == 5 + assert result.citations[0].doc_id == "book-a" + + +def test_synthesizer_builds_context_from_chunks(): + mock_llm = MagicMock() + mock_llm.complete.return_value = "Answer." + + synth = Synthesizer(mock_llm) + synth.synthesize("Q?", [], [_chunk(text="Detailed rule text here.")]) + + call_args = mock_llm.complete.call_args + assert "Detailed rule text here." in call_args[0][0] or "Detailed rule text here." in str(call_args) + + +def test_synthesizer_uses_system_prompt(): + mock_llm = MagicMock() + mock_llm.complete.return_value = "Answer." + synth = Synthesizer(mock_llm) + synth.synthesize("Q?", [], [_chunk()]) + + call_kwargs = mock_llm.complete.call_args + assert call_kwargs.kwargs.get("system") or call_kwargs[1].get("system")