From b4d9944a2f48e371d248820ebb29e2bbdb014909 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Fri, 15 May 2026 21:15:09 -0700 Subject: [PATCH] feat(blocklist): 6 REST endpoints + Pi-hole settings fields Add blocklist candidate listing, scan trigger, status update, push/unblock to Pi-hole, and connection test endpoints. Add pihole_url/version/api_key and router_source_ids/device_names fields to SettingsBody and prefs handling in patch_settings. Add PiholeClient.__post_init__ validation so 503 fires naturally when url/api_key are unconfigured (mock-safe: bypassed in tests). --- app/rest.py | 126 ++++++++++++++++++++++++ app/services/pihole.py | 2 + tests/test_blocklist_endpoints.py | 157 ++++++++++++++++++++++++++++++ 3 files changed, 285 insertions(+) create mode 100644 tests/test_blocklist_endpoints.py diff --git a/app/rest.py b/app/rest.py index c3fa7e8..94ee843 100644 --- a/app/rest.py +++ b/app/rest.py @@ -27,6 +27,16 @@ from pydantic import BaseModel from app.ingest.pipeline import ensure_schema from app.ingest.base import load_compiled_patterns from app.ingest.tautulli import parse_webhook as _parse_tautulli +from app.services.blocklist import ( + BlocklistCandidate, + list_candidates, + load_telemetry_rules, + mark_pushed, + mark_unblocked, + run_scan, + update_candidate_status, +) +from app.services.pihole import PiholeClient from app.services.incidents import ( build_bundle, create_incident, @@ -147,6 +157,11 @@ class SettingsBody(BaseModel): llm_api_key: str | None = None tautulli_token: str | None = None severity_overrides: list[SeverityOverride] | None = None + pihole_url: str | None = None + pihole_version: str | None = None + pihole_api_key: str | None = None + router_source_ids: str | None = None + device_names: str | None = None class IncidentCreate(BaseModel): @@ -347,6 +362,16 @@ def patch_settings(body: SettingsBody) -> dict: prefs["tautulli_token"] = body.tautulli_token if body.severity_overrides is not None: prefs["severity_overrides"] = [o.model_dump() for o in body.severity_overrides] + if body.pihole_url is not None: + prefs["pihole_url"] = body.pihole_url + if body.pihole_version is not None: + prefs["pihole_version"] = body.pihole_version + if body.pihole_api_key is not None: + prefs["pihole_api_key"] = body.pihole_api_key + if body.router_source_ids is not None: + prefs["router_source_ids"] = body.router_source_ids + if body.device_names is not None: + prefs["device_names"] = body.device_names _save_prefs(prefs) return prefs @@ -523,6 +548,107 @@ def ingest_tautulli( return {"stored": 1, "entry_id": entry.entry_id, "action": payload.get("action")} +class BlocklistStatusBody(BaseModel): + status: str + + +def _make_pihole_client() -> PiholeClient: + """Build PiholeClient from prefs. Raises 503 if not configured. + + The 503 is raised by catching ValueError from PiholeClient.__post_init__, + which validates that url and api_key are non-empty. When PiholeClient is + mocked in tests, __post_init__ is never called and no 503 is raised. + """ + prefs = _load_prefs() + url = prefs.get("pihole_url", "") + key = prefs.get("pihole_api_key", "") + version = prefs.get("pihole_version", "v6") + try: + return PiholeClient(url=url, api_key=key, version=version) + except ValueError as exc: + raise HTTPException( + status_code=503, + detail="Pi-hole not configured — set pihole_url and pihole_api_key in Settings", + ) from exc + + +@router.get("/api/blocklist/candidates") +def list_blocklist_candidates( + status: Annotated[str | None, Query()] = None, + device_ip: Annotated[str | None, Query()] = None, +) -> dict: + candidates = list_candidates(DB_PATH, status=status, device_ip=device_ip) + return {"candidates": [dataclasses.asdict(c) for c in candidates], "total": len(candidates)} + + +@router.post("/api/blocklist/scan") +def scan_blocklist(background_tasks: BackgroundTasks) -> dict: + prefs = _load_prefs() + source_ids = [s.strip() for s in prefs.get("router_source_ids", "").split(",") if s.strip()] + device_map: dict[str, str] = {} + raw_devices = prefs.get("device_names", "") + if raw_devices: + try: + device_map = json.loads(raw_devices) + except (ValueError, TypeError): + pass + telemetry_path = PATTERN_DIR / "telemetry.yaml" + telemetry_rules = load_telemetry_rules(telemetry_path) if telemetry_path.exists() else [] + background_tasks.add_task(run_scan, DB_PATH, source_ids, device_map, telemetry_rules) + return {"started": True} + + +@router.patch("/api/blocklist/candidates/{candidate_id}") +def update_blocklist_status(candidate_id: str, body: BlocklistStatusBody) -> dict: + try: + candidate = update_candidate_status(DB_PATH, candidate_id, body.status) + return dataclasses.asdict(candidate) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except KeyError: + raise HTTPException(status_code=404, detail="Candidate not found") + + +@router.post("/api/blocklist/push/{candidate_id}") +def push_to_pihole(candidate_id: str) -> dict: + candidates = list_candidates(DB_PATH) + candidate = next((c for c in candidates if c.id == candidate_id), None) + if candidate is None: + raise HTTPException(status_code=404, detail="Candidate not found") + if candidate.status != "approved": + raise HTTPException( + status_code=400, + detail=f"Candidate must be approved before pushing (current status: {candidate.status!r})", + ) + pihole = _make_pihole_client() + pihole.block(candidate.domain_or_ip) + mark_pushed(DB_PATH, candidate_id) + return {"pushed": True, "domain": candidate.domain_or_ip} + + +@router.delete("/api/blocklist/push/{candidate_id}") +def unblock_from_pihole(candidate_id: str) -> dict: + candidates = list_candidates(DB_PATH) + candidate = next((c for c in candidates if c.id == candidate_id), None) + if candidate is None: + raise HTTPException(status_code=404, detail="Candidate not found") + if candidate.status != "pushed": + raise HTTPException( + status_code=400, + detail=f"Candidate is not currently pushed (status: {candidate.status!r})", + ) + pihole = _make_pihole_client() + pihole.unblock(candidate.domain_or_ip) + mark_unblocked(DB_PATH, candidate_id) + return {"unblocked": True, "domain": candidate.domain_or_ip} + + +@router.post("/api/blocklist/test") +def test_pihole_connection() -> dict: + pihole = _make_pihole_client() + return pihole.test_connection() + + app.include_router(router) _ctx = APIRouter(prefix="/turnstone/api/context") diff --git a/app/services/pihole.py b/app/services/pihole.py index f952d46..1a77c5c 100644 --- a/app/services/pihole.py +++ b/app/services/pihole.py @@ -14,6 +14,8 @@ class PiholeClient: def __post_init__(self) -> None: self.url = self.url.rstrip("/") + if not self.url or not self.api_key: + raise ValueError("PiholeClient requires a non-empty url and api_key") # ── Public API ──────────────────────────────────────────────────────── diff --git a/tests/test_blocklist_endpoints.py b/tests/test_blocklist_endpoints.py new file mode 100644 index 0000000..1c4289a --- /dev/null +++ b/tests/test_blocklist_endpoints.py @@ -0,0 +1,157 @@ +"""Tests for the blocklist REST endpoints.""" +from __future__ import annotations + +import json +import pytest +from unittest.mock import MagicMock, patch + + +@pytest.fixture +def client(tmp_path): + from fastapi.testclient import TestClient + from app.ingest.pipeline import ensure_schema + import app.rest as rest_module + + db = tmp_path / "test.db" + ensure_schema(db) + + with patch.object(rest_module, "DB_PATH", db), \ + patch.object(rest_module, "PREFS_PATH", tmp_path / "prefs.json"), \ + patch.object(rest_module, "_compiled_patterns", []): + with TestClient(rest_module.app, raise_server_exceptions=True) as c: + yield c + + +@pytest.fixture +def client_with_candidate(tmp_path): + from fastapi.testclient import TestClient + from app.ingest.pipeline import ensure_schema + import app.rest as rest_module + import sqlite3, uuid + + db = tmp_path / "test.db" + ensure_schema(db) + cid = str(uuid.uuid4()) + conn = sqlite3.connect(str(db)) + conn.execute( + "INSERT INTO blocklist_candidates (id, domain_or_ip, first_seen, last_seen) VALUES (?, 'samsungads.com', '2026-05-14T00:00:00+00:00', '2026-05-14T00:00:00+00:00')", + (cid,), + ) + conn.commit() + conn.close() + + with patch.object(rest_module, "DB_PATH", db), \ + patch.object(rest_module, "PREFS_PATH", tmp_path / "prefs.json"), \ + patch.object(rest_module, "_compiled_patterns", []): + with TestClient(rest_module.app, raise_server_exceptions=True) as c: + yield c, cid + + +class TestListCandidates: + def test_empty_returns_empty_list(self, client): + resp = client.get("/turnstone/api/blocklist/candidates") + assert resp.status_code == 200 + assert resp.json()["candidates"] == [] + + def test_returns_candidate(self, client_with_candidate): + c, _ = client_with_candidate + resp = c.get("/turnstone/api/blocklist/candidates") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 1 + assert data["candidates"][0]["domain_or_ip"] == "samsungads.com" + + def test_filter_by_status(self, client_with_candidate): + c, _ = client_with_candidate + assert c.get("/turnstone/api/blocklist/candidates?status=pending").json()["total"] == 1 + assert c.get("/turnstone/api/blocklist/candidates?status=pushed").json()["total"] == 0 + + +class TestScan: + def test_scan_returns_started(self, client): + with patch("app.rest.run_scan", return_value=0): + resp = client.post("/turnstone/api/blocklist/scan") + assert resp.status_code == 200 + assert resp.json()["started"] is True + + +class TestUpdateStatus: + def test_approve_candidate(self, client_with_candidate): + c, cid = client_with_candidate + resp = c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) + assert resp.status_code == 200 + assert resp.json()["status"] == "approved" + + def test_reject_candidate(self, client_with_candidate): + c, cid = client_with_candidate + resp = c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "rejected"}) + assert resp.status_code == 200 + assert resp.json()["status"] == "rejected" + + def test_invalid_status_returns_400(self, client_with_candidate): + c, cid = client_with_candidate + resp = c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "hacked"}) + assert resp.status_code == 400 + + def test_unknown_id_returns_404(self, client): + resp = client.patch("/turnstone/api/blocklist/candidates/no-such-id", json={"status": "approved"}) + assert resp.status_code == 404 + + +class TestPush: + def test_push_approved_candidate(self, client_with_candidate): + c, cid = client_with_candidate + c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) + mock_pihole = MagicMock() + with patch("app.rest.PiholeClient", return_value=mock_pihole): + resp = c.post(f"/turnstone/api/blocklist/push/{cid}") + assert resp.status_code == 200 + data = resp.json() + assert data["pushed"] is True + assert data["domain"] == "samsungads.com" + mock_pihole.block.assert_called_once_with("samsungads.com") + + def test_push_unapproved_returns_400(self, client_with_candidate): + c, cid = client_with_candidate + with patch("app.rest.PiholeClient"): + resp = c.post(f"/turnstone/api/blocklist/push/{cid}") + assert resp.status_code == 400 + + def test_push_no_pihole_config_returns_503(self, client_with_candidate): + c, cid = client_with_candidate + c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) + resp = c.post(f"/turnstone/api/blocklist/push/{cid}") + assert resp.status_code == 503 + + +class TestUnblock: + def test_unblock_pushed_candidate(self, client_with_candidate): + c, cid = client_with_candidate + c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) + mock_pihole = MagicMock() + with patch("app.rest.PiholeClient", return_value=mock_pihole): + c.post(f"/turnstone/api/blocklist/push/{cid}") + resp = c.delete(f"/turnstone/api/blocklist/push/{cid}") + assert resp.status_code == 200 + assert resp.json()["unblocked"] is True + mock_pihole.unblock.assert_called_once_with("samsungads.com") + + def test_unblock_not_pushed_returns_400(self, client_with_candidate): + c, cid = client_with_candidate + with patch("app.rest.PiholeClient"): + resp = c.delete(f"/turnstone/api/blocklist/push/{cid}") + assert resp.status_code == 400 + + +class TestPiholeTest: + def test_returns_connection_result(self, client): + mock_pihole = MagicMock() + mock_pihole.test_connection.return_value = {"ok": True, "version": "v6", "domain_count": 5, "error": None} + with patch("app.rest.PiholeClient", return_value=mock_pihole): + resp = client.post("/turnstone/api/blocklist/test") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + def test_no_pihole_config_returns_503(self, client): + resp = client.post("/turnstone/api/blocklist/test") + assert resp.status_code == 503