"""Tests for the blocklist REST endpoints.""" from __future__ import annotations import json import pytest from unittest.mock import MagicMock, patch @pytest.fixture def client(tmp_path): from fastapi.testclient import TestClient from app.glean.pipeline import ensure_schema import app.rest as rest_module db = tmp_path / "test.db" ensure_schema(db) with patch.object(rest_module, "DB_PATH", db), \ patch.object(rest_module, "CONTEXT_DB_PATH", tmp_path / "context.db"), \ patch.object(rest_module, "INCIDENTS_DB_PATH", tmp_path / "incidents.db"), \ patch.object(rest_module, "PREFS_PATH", tmp_path / "prefs.json"), \ patch.object(rest_module, "_compiled_patterns", []): with TestClient(rest_module.app, raise_server_exceptions=True) as c: yield c @pytest.fixture def client_with_candidate(tmp_path): from fastapi.testclient import TestClient from app.glean.pipeline import ensure_schema import app.rest as rest_module import sqlite3, uuid db = tmp_path / "test.db" ensure_schema(db) cid = str(uuid.uuid4()) conn = sqlite3.connect(str(db)) conn.execute( "INSERT INTO blocklist_candidates (id, domain_or_ip, first_seen, last_seen) VALUES (?, 'samsungads.com', '2026-05-14T00:00:00+00:00', '2026-05-14T00:00:00+00:00')", (cid,), ) conn.commit() conn.close() with patch.object(rest_module, "DB_PATH", db), \ patch.object(rest_module, "CONTEXT_DB_PATH", tmp_path / "context.db"), \ patch.object(rest_module, "INCIDENTS_DB_PATH", tmp_path / "incidents.db"), \ patch.object(rest_module, "PREFS_PATH", tmp_path / "prefs.json"), \ patch.object(rest_module, "_compiled_patterns", []): with TestClient(rest_module.app, raise_server_exceptions=True) as c: yield c, cid class TestListCandidates: def test_empty_returns_empty_list(self, client): resp = client.get("/turnstone/api/blocklist/candidates") assert resp.status_code == 200 assert resp.json()["candidates"] == [] def test_returns_candidate(self, client_with_candidate): c, _ = client_with_candidate resp = c.get("/turnstone/api/blocklist/candidates") assert resp.status_code == 200 data = resp.json() assert data["total"] == 1 assert data["candidates"][0]["domain_or_ip"] == "samsungads.com" def test_filter_by_status(self, client_with_candidate): c, _ = client_with_candidate assert c.get("/turnstone/api/blocklist/candidates?status=pending").json()["total"] == 1 assert c.get("/turnstone/api/blocklist/candidates?status=pushed").json()["total"] == 0 class TestScan: def test_scan_returns_started(self, client): with patch("app.rest.run_scan", return_value=0): resp = client.post("/turnstone/api/blocklist/scan") assert resp.status_code == 200 assert resp.json()["started"] is True class TestUpdateStatus: def test_approve_candidate(self, client_with_candidate): c, cid = client_with_candidate resp = c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) assert resp.status_code == 200 assert resp.json()["status"] == "approved" def test_reject_candidate(self, client_with_candidate): c, cid = client_with_candidate resp = c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "rejected"}) assert resp.status_code == 200 assert resp.json()["status"] == "rejected" def test_invalid_status_returns_400(self, client_with_candidate): c, cid = client_with_candidate resp = c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "hacked"}) assert resp.status_code == 400 def test_unknown_id_returns_404(self, client): resp = client.patch("/turnstone/api/blocklist/candidates/no-such-id", json={"status": "approved"}) assert resp.status_code == 404 class TestPush: def test_push_approved_candidate(self, client_with_candidate): c, cid = client_with_candidate c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) mock_pihole = MagicMock() with patch("app.rest.PiholeClient", return_value=mock_pihole): resp = c.post(f"/turnstone/api/blocklist/push/{cid}") assert resp.status_code == 200 data = resp.json() assert data["pushed"] is True assert data["domain"] == "samsungads.com" mock_pihole.block.assert_called_once_with("samsungads.com") def test_push_unapproved_returns_400(self, client_with_candidate): c, cid = client_with_candidate with patch("app.rest.PiholeClient"): resp = c.post(f"/turnstone/api/blocklist/push/{cid}") assert resp.status_code == 400 def test_push_no_pihole_config_returns_503(self, client_with_candidate): c, cid = client_with_candidate c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) resp = c.post(f"/turnstone/api/blocklist/push/{cid}") assert resp.status_code == 503 class TestUnblock: def test_unblock_pushed_candidate(self, client_with_candidate): c, cid = client_with_candidate c.patch(f"/turnstone/api/blocklist/candidates/{cid}", json={"status": "approved"}) mock_pihole = MagicMock() with patch("app.rest.PiholeClient", return_value=mock_pihole): c.post(f"/turnstone/api/blocklist/push/{cid}") resp = c.delete(f"/turnstone/api/blocklist/push/{cid}") assert resp.status_code == 200 assert resp.json()["unblocked"] is True mock_pihole.unblock.assert_called_once_with("samsungads.com") def test_unblock_not_pushed_returns_400(self, client_with_candidate): c, cid = client_with_candidate with patch("app.rest.PiholeClient"): resp = c.delete(f"/turnstone/api/blocklist/push/{cid}") assert resp.status_code == 400 class TestPiholeTest: def test_returns_connection_result(self, client): mock_pihole = MagicMock() mock_pihole.test_connection.return_value = {"ok": True, "version": "v6", "domain_count": 5, "error": None} with patch("app.rest.PiholeClient", return_value=mock_pihole): resp = client.post("/turnstone/api/blocklist/test") assert resp.status_code == 200 assert resp.json()["ok"] is True def test_no_pihole_config_returns_503(self, client): resp = client.post("/turnstone/api/blocklist/test") assert resp.status_code == 503