From dceb2d30ca54e19e11862359d46a9be8b585a261 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Fri, 15 May 2026 21:00:01 -0700 Subject: [PATCH] feat(blocklist): Pi-hole v5/v6 API client + tests PiholeClient dataclass supporting both Pi-hole v5 (PHP /admin/api.php) and v6 (REST /api/) with public block/unblock/test_connection methods. 9 tests covering both API versions, auth flow, and error handling. --- app/services/pihole.py | 89 ++++++++++++++++++++++++++++++++++++ tests/test_service_pihole.py | 80 ++++++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+) create mode 100644 app/services/pihole.py create mode 100644 tests/test_service_pihole.py diff --git a/app/services/pihole.py b/app/services/pihole.py new file mode 100644 index 0000000..0648072 --- /dev/null +++ b/app/services/pihole.py @@ -0,0 +1,89 @@ +"""Pi-hole API client supporting v5 (PHP) and v6 (REST) APIs.""" +from __future__ import annotations + +import dataclasses + +import httpx + + +@dataclasses.dataclass +class PiholeClient: + url: str + api_key: str + version: str = "v6" # "v5" | "v6" + + def __post_init__(self) -> None: + self.url = self.url.rstrip("/") + + # ── Public API ──────────────────────────────────────────────────────── + + def block(self, domain: str, comment: str = "Turnstone block") -> None: + if self.version == "v5": + self._v5_get("black", "add", domain) + else: + sid = self._v6_auth() + self._v6_post_domain(sid, domain, comment) + + def unblock(self, domain: str) -> None: + if self.version == "v5": + self._v5_get("black", "sub", domain) + else: + sid = self._v6_auth() + self._v6_delete_domain(sid, domain) + + def test_connection(self) -> dict: + try: + if self.version == "v5": + return self._v5_test() + return self._v6_test() + except Exception as exc: + return {"ok": False, "version": self.version, "domain_count": 0, "error": str(exc)} + + # ── v5 helpers ──────────────────────────────────────────────────────── + + def _v5_get(self, list_type: str, action: str, domain: str) -> None: + params = {"list": list_type, action: domain, "auth": self.api_key} + with httpx.Client(timeout=10) as c: + c.get(f"{self.url}/admin/api.php", params=params).raise_for_status() + + def _v5_test(self) -> dict: + with httpx.Client(timeout=10) as c: + r = c.get(f"{self.url}/admin/api.php", params={"summaryRaw": "", "auth": self.api_key}) + r.raise_for_status() + data = r.json() + return { + "ok": True, + "version": "v5", + "domain_count": int(data.get("domains_being_blocked", 0)), + "error": None, + } + + # ── v6 helpers ──────────────────────────────────────────────────────── + + def _v6_auth(self) -> str: + with httpx.Client(timeout=10) as c: + r = c.post(f"{self.url}/api/auth", json={"password": self.api_key}) + r.raise_for_status() + return r.json()["session"]["sid"] + + def _v6_post_domain(self, sid: str, domain: str, comment: str) -> None: + body = [{"domain": domain, "comment": comment, "enabled": True}] + with httpx.Client(timeout=10, cookies={"sid": sid}) as c: + c.post(f"{self.url}/api/domains/deny", json=body).raise_for_status() + + def _v6_delete_domain(self, sid: str, domain: str) -> None: + with httpx.Client(timeout=10, cookies={"sid": sid}) as c: + c.delete(f"{self.url}/api/domains/deny/{domain}").raise_for_status() + + def _v6_test(self) -> dict: + sid = self._v6_auth() + with httpx.Client(timeout=10, cookies={"sid": sid}) as c: + r = c.get(f"{self.url}/api/domains/deny") + r.raise_for_status() + data = r.json() + return { + "ok": True, + "version": "v6", + "domain_count": len(data.get("data", [])), + "error": None, + } diff --git a/tests/test_service_pihole.py b/tests/test_service_pihole.py new file mode 100644 index 0000000..ba4dd47 --- /dev/null +++ b/tests/test_service_pihole.py @@ -0,0 +1,80 @@ +"""Tests for the Pi-hole API client.""" +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock, patch + + +class TestV5Client: + def _client(self): + from app.services.pihole import PiholeClient + return PiholeClient(url="http://pi.hole", api_key="testkey", version="v5") + + def test_block_calls_v5_get(self): + client = self._client() + with patch.object(client, "_v5_get") as mock_get: + client.block("samsungads.com") + mock_get.assert_called_once_with("black", "add", "samsungads.com") + + def test_unblock_calls_v5_get_sub(self): + client = self._client() + with patch.object(client, "_v5_get") as mock_get: + client.unblock("samsungads.com") + mock_get.assert_called_once_with("black", "sub", "samsungads.com") + + def test_test_connection_returns_ok(self): + client = self._client() + with patch.object(client, "_v5_test", return_value={"ok": True, "version": "v5", "domain_count": 42, "error": None}): + result = client.test_connection() + assert result["ok"] is True + assert result["domain_count"] == 42 + + def test_test_connection_catches_error(self): + client = self._client() + with patch.object(client, "_v5_test", side_effect=Exception("connection refused")): + result = client.test_connection() + assert result["ok"] is False + assert "connection refused" in result["error"] + + +class TestV6Client: + def _client(self): + from app.services.pihole import PiholeClient + return PiholeClient(url="http://pi.hole", api_key="apppassword", version="v6") + + def test_block_auths_then_posts(self): + client = self._client() + with patch.object(client, "_v6_auth", return_value="test-sid") as mock_auth, \ + patch.object(client, "_v6_post_domain") as mock_post: + client.block("samsungads.com", "test comment") + mock_auth.assert_called_once() + mock_post.assert_called_once_with("test-sid", "samsungads.com", "test comment") + + def test_unblock_auths_then_deletes(self): + client = self._client() + with patch.object(client, "_v6_auth", return_value="sid123") as mock_auth, \ + patch.object(client, "_v6_delete_domain") as mock_del: + client.unblock("samsungads.com") + mock_auth.assert_called_once() + mock_del.assert_called_once_with("sid123", "samsungads.com") + + def test_test_connection_returns_ok(self): + client = self._client() + with patch.object(client, "_v6_test", return_value={"ok": True, "version": "v6", "domain_count": 5, "error": None}): + result = client.test_connection() + assert result["ok"] is True + + def test_test_connection_catches_error(self): + client = self._client() + with patch.object(client, "_v6_test", side_effect=Exception("timeout")): + result = client.test_connection() + assert result["ok"] is False + assert "timeout" in result["error"] + + def test_block_default_comment(self): + client = self._client() + with patch.object(client, "_v6_auth", return_value="sid"), \ + patch.object(client, "_v6_post_domain") as mock_post: + client.block("samsungads.com") + _, _, comment = mock_post.call_args.args + assert comment == "Turnstone block"