feat(api): add retriever, synthesizer, and chat endpoint (BSL — BYOK gate)
- app/services/retriever.py: hybrid BM25 + semantic Retriever with BM25-only fallback when llm=None - app/services/synthesizer.py: LLM answer synthesis with citation assembly over retrieved chunks - app/api/chat.py: POST /api/chat endpoint with 402 gate when PAGEPIPER_OLLAMA_URL is unset - tests/test_synthesizer.py: 3 TDD unit tests (mocked LLM, context building, system prompt) - tests/test_chat_api.py: 2 integration tests (402 without Ollama, 200 with mocked retriever+LLM)
This commit is contained in:
parent
eb5c7383ed
commit
0e493ab560
5 changed files with 415 additions and 2 deletions
125
app/api/chat.py
125
app/api/chat.py
|
|
@ -1,5 +1,126 @@
|
||||||
# app/api/chat.py
|
# app/api/chat.py
|
||||||
"""RAG chat API — streaming LLM responses with page citations (Task 6)."""
|
"""
|
||||||
from fastapi import APIRouter
|
RAG chat endpoint — retrieves relevant page chunks and synthesizes an answer.
|
||||||
|
|
||||||
|
BSL 1.1 — BYOK gate: requires PAGEPIPER_OLLAMA_URL or a Paid tier license.
|
||||||
|
Returns 402 with clear upgrade message if neither is configured.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.services.retriever import Retriever
|
||||||
|
from app.services.synthesizer import Synthesizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/chat", tags=["chat"])
|
router = APIRouter(prefix="/api/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTurn(BaseModel):
|
||||||
|
role: str # "user" | "assistant"
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
message: str
|
||||||
|
history: list[ChatTurn] = []
|
||||||
|
doc_ids: list[str] | None = None
|
||||||
|
top_k: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
answer: str
|
||||||
|
citations: list[dict]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_router():
|
||||||
|
"""Return LLMRouter if Ollama configured, else None."""
|
||||||
|
from app.config import get_llm_config
|
||||||
|
|
||||||
|
cfg = get_llm_config()
|
||||||
|
if cfg is None:
|
||||||
|
return None
|
||||||
|
from circuitforge_core.llm import LLMRouter
|
||||||
|
|
||||||
|
return LLMRouter(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_db_path() -> str:
|
||||||
|
"""Read lazily so test fixtures take effect."""
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||||
|
return str(data_dir / "pagepiper.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vec_db_path() -> str:
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||||
|
return str(data_dir / "pagepiper_vecs.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _require_llm():
|
||||||
|
"""Return LLMRouter or raise 402."""
|
||||||
|
llm = _get_llm_router()
|
||||||
|
if llm is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={
|
||||||
|
"error": "ollama_required",
|
||||||
|
"message": (
|
||||||
|
"RAG chat requires Ollama. Set PAGEPIPER_OLLAMA_URL in your .env file, "
|
||||||
|
"then restart. Run: ollama pull nomic-embed-text && ollama pull mistral:7b"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
def chat(req: ChatRequest) -> ChatResponse:
|
||||||
|
llm = _require_llm()
|
||||||
|
|
||||||
|
from app.main import _bm25
|
||||||
|
|
||||||
|
retriever = Retriever(_bm25)
|
||||||
|
chunks = retriever.hybrid_search(
|
||||||
|
query=req.message,
|
||||||
|
top_k=req.top_k,
|
||||||
|
doc_ids=req.doc_ids,
|
||||||
|
db_path=_get_db_path(),
|
||||||
|
vec_db_path=_get_vec_db_path(),
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
return ChatResponse(
|
||||||
|
answer=(
|
||||||
|
"I couldn't find any relevant passages. "
|
||||||
|
"Try a different query or check which documents are indexed."
|
||||||
|
),
|
||||||
|
citations=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
synth = Synthesizer(llm)
|
||||||
|
result = synth.synthesize(
|
||||||
|
message=req.message,
|
||||||
|
history=[t.model_dump() for t in req.history],
|
||||||
|
chunks=chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
answer=result.answer,
|
||||||
|
citations=[
|
||||||
|
{
|
||||||
|
"doc_id": c.doc_id,
|
||||||
|
"page_number": c.page_number,
|
||||||
|
"snippet": c.snippet,
|
||||||
|
}
|
||||||
|
for c in result.citations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
||||||
121
app/services/retriever.py
Normal file
121
app/services/retriever.py
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
# app/services/retriever.py
|
||||||
|
"""
|
||||||
|
Hybrid BM25 + semantic retriever.
|
||||||
|
|
||||||
|
BSL 1.1 — semantic path requires PAGEPIPER_OLLAMA_URL (BYOK gate).
|
||||||
|
BM25-only path is MIT and has no gate.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.services.bm25_index import BM25Index
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RetrievedChunk:
|
||||||
|
"""A chunk returned by the retriever, with source scores."""
|
||||||
|
|
||||||
|
chunk_id: str
|
||||||
|
doc_id: str
|
||||||
|
page_number: int
|
||||||
|
text: str
|
||||||
|
bm25_score: float
|
||||||
|
vector_score: float | None
|
||||||
|
|
||||||
|
|
||||||
|
class Retriever:
|
||||||
|
def __init__(self, bm25: BM25Index) -> None:
|
||||||
|
self._bm25 = bm25
|
||||||
|
|
||||||
|
def hybrid_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int,
|
||||||
|
doc_ids: list[str] | None,
|
||||||
|
db_path: str,
|
||||||
|
vec_db_path: str,
|
||||||
|
llm, # LLMRouter | None — caller must pass
|
||||||
|
) -> list[RetrievedChunk]:
|
||||||
|
"""
|
||||||
|
Merge BM25 and semantic results.
|
||||||
|
Falls back to BM25-only if llm is None.
|
||||||
|
"""
|
||||||
|
if llm is None:
|
||||||
|
return self._bm25_only(query, top_k, doc_ids, db_path)
|
||||||
|
|
||||||
|
from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore
|
||||||
|
|
||||||
|
self._bm25.ensure_fresh(db_path)
|
||||||
|
bm25_hits = {
|
||||||
|
r.chunk_id: r
|
||||||
|
for r in self._bm25.query(query, top_k=top_k * 2, doc_ids=doc_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
vec = llm.embed([query])[0]
|
||||||
|
store = LocalSQLiteVecStore(db_path=vec_db_path, table="page_vecs", dimensions=768)
|
||||||
|
filter_meta = {"doc_id": doc_ids[0]} if doc_ids and len(doc_ids) == 1 else None
|
||||||
|
vec_hits = store.query(vec, top_k=top_k * 2, filter_metadata=filter_meta)
|
||||||
|
|
||||||
|
if doc_ids and len(doc_ids) > 1:
|
||||||
|
vec_hits = [h for h in vec_hits if h.metadata.get("doc_id") in doc_ids]
|
||||||
|
|
||||||
|
# Merge: BM25 hits take priority; vector hits fill in additional results
|
||||||
|
merged: dict[str, RetrievedChunk] = {}
|
||||||
|
for cid, r in bm25_hits.items():
|
||||||
|
merged[cid] = RetrievedChunk(
|
||||||
|
chunk_id=cid,
|
||||||
|
doc_id=r.doc_id,
|
||||||
|
page_number=r.page_number,
|
||||||
|
text=r.text,
|
||||||
|
bm25_score=r.score,
|
||||||
|
vector_score=None,
|
||||||
|
)
|
||||||
|
for vh in vec_hits:
|
||||||
|
text = next((c["text"] for c in self._bm25._chunks if c["id"] == vh.id), "")
|
||||||
|
if vh.id in merged:
|
||||||
|
existing = merged[vh.id]
|
||||||
|
merged[vh.id] = RetrievedChunk(
|
||||||
|
chunk_id=existing.chunk_id,
|
||||||
|
doc_id=existing.doc_id,
|
||||||
|
page_number=existing.page_number,
|
||||||
|
text=existing.text,
|
||||||
|
bm25_score=existing.bm25_score,
|
||||||
|
vector_score=vh.score,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
merged[vh.id] = RetrievedChunk(
|
||||||
|
chunk_id=vh.id,
|
||||||
|
doc_id=vh.metadata.get("doc_id", ""),
|
||||||
|
page_number=int(vh.metadata.get("page_number", 0)),
|
||||||
|
text=text,
|
||||||
|
bm25_score=0.0,
|
||||||
|
vector_score=vh.score,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _combined(r: RetrievedChunk) -> float:
|
||||||
|
bm25 = r.bm25_score
|
||||||
|
vec = (1.0 / (1.0 + r.vector_score)) if r.vector_score is not None else 0.0
|
||||||
|
return bm25 * 0.5 + vec * 0.5
|
||||||
|
|
||||||
|
ranked = sorted(merged.values(), key=_combined, reverse=True)
|
||||||
|
return ranked[:top_k]
|
||||||
|
|
||||||
|
def _bm25_only(
|
||||||
|
self, query: str, top_k: int, doc_ids: list[str] | None, db_path: str
|
||||||
|
) -> list[RetrievedChunk]:
|
||||||
|
self._bm25.ensure_fresh(db_path)
|
||||||
|
return [
|
||||||
|
RetrievedChunk(
|
||||||
|
chunk_id=r.chunk_id,
|
||||||
|
doc_id=r.doc_id,
|
||||||
|
page_number=r.page_number,
|
||||||
|
text=r.text,
|
||||||
|
bm25_score=r.score,
|
||||||
|
vector_score=None,
|
||||||
|
)
|
||||||
|
for r in self._bm25.query(query, top_k=top_k, doc_ids=doc_ids)
|
||||||
|
]
|
||||||
58
app/services/synthesizer.py
Normal file
58
app/services/synthesizer.py
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
# app/services/synthesizer.py
|
||||||
|
"""
|
||||||
|
LLM answer synthesis over retrieved chunks.
|
||||||
|
|
||||||
|
BSL 1.1 — requires LLMRouter (Ollama BYOK or cloud tier).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.services.retriever import RetrievedChunk
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a helpful document assistant. "
|
||||||
|
"Answer the user's question using ONLY the provided document excerpts. "
|
||||||
|
"For each claim, cite the source page as [p.N]. "
|
||||||
|
"If the excerpts are insufficient, say so. Do not invent information."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Citation:
|
||||||
|
doc_id: str
|
||||||
|
page_number: int
|
||||||
|
snippet: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SynthesisResult:
|
||||||
|
answer: str
|
||||||
|
citations: list[Citation]
|
||||||
|
|
||||||
|
|
||||||
|
class Synthesizer:
|
||||||
|
def __init__(self, llm) -> None: # LLMRouter
|
||||||
|
self._llm = llm
|
||||||
|
|
||||||
|
def synthesize(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
history: list[dict],
|
||||||
|
chunks: list[RetrievedChunk],
|
||||||
|
) -> SynthesisResult:
|
||||||
|
context_parts = [f"[p.{c.page_number}]\n{c.text[:500]}" for c in chunks]
|
||||||
|
context = "\n\n---\n\n".join(context_parts)
|
||||||
|
prompt = f"Document excerpts:\n\n{context}\n\nQuestion: {message}"
|
||||||
|
|
||||||
|
answer = self._llm.complete(prompt, system=_SYSTEM_PROMPT)
|
||||||
|
|
||||||
|
citations = [
|
||||||
|
Citation(
|
||||||
|
doc_id=c.doc_id,
|
||||||
|
page_number=c.page_number,
|
||||||
|
snippet=c.text[:200],
|
||||||
|
)
|
||||||
|
for c in chunks
|
||||||
|
]
|
||||||
|
return SynthesisResult(answer=answer, citations=citations)
|
||||||
59
tests/test_chat_api.py
Normal file
59
tests/test_chat_api.py
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
# tests/test_chat_api.py
|
||||||
|
"""Tests for POST /api/chat — RAG chat (BSL, BYOK gate)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from app.services.retriever import RetrievedChunk
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_returns_402_without_ollama(client, monkeypatch):
|
||||||
|
monkeypatch.delenv("PAGEPIPER_OLLAMA_URL", raising=False)
|
||||||
|
resp = client.post("/api/chat", json={"message": "How does Fireball work?", "history": []})
|
||||||
|
assert resp.status_code == 402
|
||||||
|
body = resp.json()
|
||||||
|
assert "detail" in body
|
||||||
|
assert "Ollama" in body["detail"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_returns_answer_with_mocked_ollama(client, test_db, monkeypatch):
|
||||||
|
monkeypatch.setenv("PAGEPIPER_OLLAMA_URL", "http://localhost:11434")
|
||||||
|
|
||||||
|
conn = sqlite3.connect(test_db)
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO documents(id, title, file_path, status) VALUES ('b1','PHB','phb.pdf','ready')"
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO page_chunks(doc_id, page_number, text, source, word_count) "
|
||||||
|
"VALUES ('b1',15,'Fireball deals 8d6 fire damage.','text_layer',6)"
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.complete.return_value = "Fireball deals 8d6 fire damage [p.15]."
|
||||||
|
|
||||||
|
mock_chunks = [
|
||||||
|
RetrievedChunk(
|
||||||
|
chunk_id="c1",
|
||||||
|
doc_id="b1",
|
||||||
|
page_number=15,
|
||||||
|
text="Fireball deals 8d6 fire damage.",
|
||||||
|
bm25_score=1.0,
|
||||||
|
vector_score=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("app.api.chat.Retriever.hybrid_search", return_value=mock_chunks):
|
||||||
|
with patch("app.api.chat._get_llm_router", return_value=mock_llm):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/chat",
|
||||||
|
json={"message": "How does Fireball work?", "history": [], "doc_ids": ["b1"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert "answer" in body
|
||||||
|
assert "citations" in body
|
||||||
|
assert "Fireball" in body["answer"]
|
||||||
54
tests/test_synthesizer.py
Normal file
54
tests/test_synthesizer.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
# tests/test_synthesizer.py
|
||||||
|
"""Tests for Synthesizer — mocked LLM, citation assembly."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from app.services.retriever import RetrievedChunk
|
||||||
|
from app.services.synthesizer import Synthesizer, SynthesisResult
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk(doc_id: str = "book-a", page: int = 5, text: str = "Fireball rules") -> RetrievedChunk:
|
||||||
|
return RetrievedChunk(
|
||||||
|
chunk_id="c1", doc_id=doc_id, page_number=page, text=text,
|
||||||
|
bm25_score=1.0, vector_score=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_synthesizer_returns_answer_and_citations():
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.complete.return_value = "Fireball deals 8d6 damage [p.5]."
|
||||||
|
|
||||||
|
synth = Synthesizer(mock_llm)
|
||||||
|
result = synth.synthesize(
|
||||||
|
message="How does Fireball work?",
|
||||||
|
history=[],
|
||||||
|
chunks=[_chunk()],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, SynthesisResult)
|
||||||
|
assert "Fireball" in result.answer
|
||||||
|
assert len(result.citations) == 1
|
||||||
|
assert result.citations[0].page_number == 5
|
||||||
|
assert result.citations[0].doc_id == "book-a"
|
||||||
|
|
||||||
|
|
||||||
|
def test_synthesizer_builds_context_from_chunks():
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.complete.return_value = "Answer."
|
||||||
|
|
||||||
|
synth = Synthesizer(mock_llm)
|
||||||
|
synth.synthesize("Q?", [], [_chunk(text="Detailed rule text here.")])
|
||||||
|
|
||||||
|
call_args = mock_llm.complete.call_args
|
||||||
|
assert "Detailed rule text here." in call_args[0][0] or "Detailed rule text here." in str(call_args)
|
||||||
|
|
||||||
|
|
||||||
|
def test_synthesizer_uses_system_prompt():
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.complete.return_value = "Answer."
|
||||||
|
synth = Synthesizer(mock_llm)
|
||||||
|
synth.synthesize("Q?", [], [_chunk()])
|
||||||
|
|
||||||
|
call_kwargs = mock_llm.complete.call_args
|
||||||
|
assert call_kwargs.kwargs.get("system") or call_kwargs[1].get("system")
|
||||||
Loading…
Reference in a new issue