turnstone/app/services/blocklist.py

288 lines
8.8 KiB
Python

"""Blocklist candidate extraction, management, and telemetry matching."""
from __future__ import annotations
import dataclasses
import json
import re
import sqlite3
import uuid
from datetime import datetime, timezone
from pathlib import Path
import yaml
# ---------------------------------------------------------------------------
# Data models
# ---------------------------------------------------------------------------
@dataclasses.dataclass(frozen=True)
class TelemetryRule:
    """Immutable description of one telemetry rule.

    Instances are built by ``load_telemetry_rules`` from the ``rules`` list of
    a YAML file and are consumed by ``matches_telemetry``.
    """

    # Rule identifier; run_scan stores it as a candidate's ``matched_rule``.
    name: str
    # Domains covered by the rule. load_telemetry_rules lower-cases them and
    # strips leading/trailing dots; matches_telemetry also accepts subdomains.
    domains: tuple[str, ...]
    # Category label copied from the YAML entry (a required key there).
    # NOTE(review): the set of allowed categories is not defined in this module.
    category: str
    # Free-text description; load_telemetry_rules falls back to "" when absent.
    description: str
@dataclasses.dataclass
class BlocklistCandidate:
    """One row of the ``blocklist_candidates`` table.

    Field order mirrors the column order of ``_CANDIDATE_SELECT``, which is
    what ``_row_to_candidate`` relies on when building instances.
    """

    # Primary key; _upsert_candidate generates it with uuid.uuid4().
    id: str
    # Destination extracted from a log line (a queried name or a DST value).
    domain_or_ip: str
    # Source address of the device that produced the traffic, if known.
    source_device_ip: str | None
    # Name looked up for source_device_ip in run_scan's device_map.
    source_device_name: str | None
    # Timestamp strings from _now_iso(): set on insert / refreshed on each upsert.
    first_seen: str
    last_seen: str
    # Starts at 1 on insert and is incremented by every later upsert.
    hit_count: int
    # Inserted as 'pending'; update_candidate_status restricts changes to
    # the values in _VALID_STATUSES.
    status: str
    # Timestamp written by mark_pushed; NULL/None until then.
    pushed_at: str | None
    # IDs of supporting log_entries rows, decoded from a JSON array.
    # _upsert_candidate keeps at most the 10 most recent.
    log_evidence: list[str]
    # Name of the TelemetryRule that matched the destination, if any.
    matched_rule: str | None
    # NOTE(review): nothing in this module writes llm_score/llm_reason;
    # presumably filled in by a separate scoring step — confirm.
    llm_score: float | None
    llm_reason: str | None
# ---------------------------------------------------------------------------
# Telemetry list
# ---------------------------------------------------------------------------
def load_telemetry_rules(path: Path) -> list[TelemetryRule]:
    """Load telemetry rules from a YAML file.

    The document is expected to be a mapping with a ``rules`` list whose
    entries carry ``name``, ``domains`` and ``category`` (required) plus an
    optional ``description``. Domains are lower-cased and stripped of
    leading/trailing dots so they compare cleanly in ``matches_telemetry``.

    Args:
        path: Location of the YAML file (read as UTF-8).

    Returns:
        One ``TelemetryRule`` per entry, in file order. An empty file, or a
        file whose ``rules`` key is missing or empty, yields ``[]``.

    Raises:
        KeyError: If an entry lacks ``name``, ``domains`` or ``category``.
    """
    # safe_load returns None for an empty document; treat that as "no rules"
    # instead of failing on ``None.get``.
    data = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    return [
        TelemetryRule(
            name=r["name"],
            domains=tuple(d.lower().strip(".") for d in r["domains"]),
            category=r["category"],
            # ``description:`` with no value parses as None; keep the field a str.
            description=r.get("description") or "",
        )
        # ``rules:`` with no value parses as None, which is not iterable.
        for r in data.get("rules") or []
    ]
def matches_telemetry(domain: str, rules: list[TelemetryRule]) -> TelemetryRule | None:
    """Find the first rule that covers *domain*.

    The domain is lower-cased and stripped of leading/trailing dots, then
    compared against each rule's domains: a rule covers it when the domain
    equals one of them or is a subdomain (ends with ``"." + rule_domain``).
    Rules are tried in list order; ``None`` is returned when none applies.
    """
    needle = domain.lower().strip(".")
    for candidate in rules:
        covered = any(
            needle == suffix or needle.endswith("." + suffix)
            for suffix in candidate.domains
        )
        if covered:
            return candidate
    return None
# ---------------------------------------------------------------------------
# Regex extractors for router log entries
# ---------------------------------------------------------------------------
# Intended for dnsmasq-style query lines, e.g.
# "query[A] example.com from 192.168.1.5". ``A{1,4}`` accepts a bracketed
# type made of one to four "A" characters (so both A and AAAA); other query
# types do not match. ``domain`` is the queried name, ``src`` the address
# after "from" (digits and dots only, so IPv6 sources do not match).
_DNSMASQ_RE = re.compile(
    r"query\[A{1,4}\]\s+(?P<domain>\S+)\s+from\s+(?P<src>[\d.]+)"
)
# Intended for iptables-style "SRC=... DST=..." lines. ``src`` is digits and
# dots; ``dst`` additionally allows letters and hyphens. The non-greedy
# ``.*?`` binds ``dst`` to the first "DST=" that follows the matched "SRC=".
_IPTABLES_RE = re.compile(
    r"SRC=(?P<src>[\d.]+).*?DST=(?P<dst>[\d.a-zA-Z.-]+)"
)
# Status values accepted by update_candidate_status.
_VALID_STATUSES = {"pending", "approved", "rejected", "pushed", "unblocked"}
# ---------------------------------------------------------------------------
# DB helpers
# ---------------------------------------------------------------------------
def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _row_to_candidate(row: tuple) -> BlocklistCandidate:
    """Convert a row in ``_CANDIDATE_SELECT`` column order into a BlocklistCandidate.

    Values are taken by position (index 0 through 12). The ``log_evidence``
    column holds JSON text and is decoded into a list; NULL or an empty
    string decodes to ``[]``.
    """
    column_order = (
        "id",
        "domain_or_ip",
        "source_device_ip",
        "source_device_name",
        "first_seen",
        "last_seen",
        "hit_count",
        "status",
        "pushed_at",
        "log_evidence",
        "matched_rule",
        "llm_score",
        "llm_reason",
    )
    values = {column: row[position] for position, column in enumerate(column_order)}
    values["log_evidence"] = json.loads(values["log_evidence"] or "[]")
    return BlocklistCandidate(**values)
def _upsert_candidate(
    conn: sqlite3.Connection,
    domain_or_ip: str,
    source_device_ip: str | None,
    source_device_name: str | None,
    matched_rule: str | None,
    entry_id: str,
    now: str,
) -> bool:
    """Insert or update the candidate for one (destination, source device) pair.

    A candidate is identified by ``domain_or_ip`` together with
    ``source_device_ip``. When none exists, a new ``pending`` row is created
    with ``hit_count`` 1 and ``entry_id`` as its only evidence. Otherwise the
    existing row gets ``last_seen`` refreshed, ``hit_count`` incremented and
    ``entry_id`` appended to its evidence (kept to the 10 most recent IDs).
    On update, a non-None ``source_device_name`` / ``matched_rule`` replaces
    the stored value, so candidates first seen before a device was named or a
    rule existed are tagged later; passing None leaves the stored value as is.

    The caller owns the transaction: nothing is committed here.

    NOTE(review): ``hit_count`` is incremented on every call, including calls
    that repeat an ``entry_id`` already present in the evidence list.

    Args:
        conn: Open connection to the database holding ``blocklist_candidates``.
        domain_or_ip: Destination the candidate is about.
        source_device_ip: Address of the originating device, or None.
        source_device_name: Display name of that device, or None.
        matched_rule: Name of the telemetry rule that matched, or None.
        entry_id: ID of the log entry that produced this hit.
        now: Timestamp string stored in ``first_seen`` / ``last_seen``.

    Returns:
        True if a new row was inserted, False if an existing row was updated.
    """
    # ``IS ?`` instead of ``= ?`` so that a NULL source_device_ip matches NULL.
    row = conn.execute(
        "SELECT id, hit_count, log_evidence FROM blocklist_candidates "
        "WHERE domain_or_ip = ? AND source_device_ip IS ?",
        (domain_or_ip, source_device_ip),
    ).fetchone()
    if row is None:
        conn.execute(
            """INSERT INTO blocklist_candidates
            (id, domain_or_ip, source_device_ip, source_device_name,
            first_seen, last_seen, hit_count, status, pushed_at, log_evidence, matched_rule)
            VALUES (?, ?, ?, ?, ?, ?, 1, 'pending', NULL, ?, ?)""",
            (
                str(uuid.uuid4()), domain_or_ip, source_device_ip, source_device_name,
                now, now, json.dumps([entry_id]), matched_rule,
            ),
        )
        return True

    existing_id, hit_count, existing_evidence = row
    evidence = json.loads(existing_evidence or "[]")
    if entry_id not in evidence:
        evidence.append(entry_id)
        evidence = evidence[-10:]  # cap at 10, dropping the oldest IDs
    conn.execute(
        "UPDATE blocklist_candidates SET last_seen=?, hit_count=?, log_evidence=?, "
        # COALESCE keeps the stored value when the new one is NULL.
        "source_device_name=COALESCE(?, source_device_name), "
        "matched_rule=COALESCE(?, matched_rule) "
        "WHERE id=?",
        (
            now,
            (hit_count or 0) + 1,  # tolerate a NULL counter in old rows
            json.dumps(evidence),
            source_device_name,
            matched_rule,
            existing_id,
        ),
    )
    return False
# ---------------------------------------------------------------------------
# Extraction scan
# ---------------------------------------------------------------------------
def run_scan(
    db_path: Path,
    router_source_ids: list[str],
    device_map: dict[str, str],
    telemetry_rules: list[TelemetryRule],
) -> int:
    """Scan router log entries and upsert blocklist candidates.

    Every ``log_entries`` row whose ``source_id`` is in *router_source_ids* is
    matched first against ``_DNSMASQ_RE`` and then against ``_IPTABLES_RE``.
    Entries that match neither, have no text, or whose source IP is not a key
    of *device_map* are skipped. The destination is lower-cased and stripped
    of leading/trailing dots before it is stored, the same normalization
    ``matches_telemetry`` applies, so differently-cased spellings of one
    domain share a single candidate.

    All upserts are committed together at the end; if anything raises, the
    connection is closed without committing.

    NOTE(review): each call re-reads every matching log entry, so repeated
    scans over the same entries keep incrementing ``hit_count``.

    Args:
        db_path: SQLite database holding ``log_entries`` and
            ``blocklist_candidates``.
        router_source_ids: ``source_id`` values of the router log sources.
        device_map: Maps a source IP address to a device name.
        telemetry_rules: Rules used to tag candidates with ``matched_rule``.

    Returns:
        The number of log entries that were recorded (one per upsert, so a
        candidate hit several times is counted several times). ``0`` when
        *router_source_ids* or *device_map* is empty.
    """
    if not router_source_ids or not device_map:
        return 0
    placeholders = ",".join("?" for _ in router_source_ids)
    now = _now_iso()
    count = 0
    conn = sqlite3.connect(str(db_path))
    try:
        rows = conn.execute(
            f"SELECT id, text FROM log_entries WHERE source_id IN ({placeholders})",
            router_source_ids,
        ).fetchall()
        for entry_id, text in rows:
            # A NULL/empty text column cannot match and would make re.search raise.
            if not text:
                continue
            m = _DNSMASQ_RE.search(text)
            dst_group = "domain"
            if m is None:
                m = _IPTABLES_RE.search(text)
                dst_group = "dst"
            if m is None:
                continue
            src_ip = m.group("src")
            if src_ip not in device_map:
                continue
            device_name = device_map[src_ip]
            # DNS names are case-insensitive and may carry a trailing root dot;
            # normalize so one destination maps to one candidate row.
            dst = m.group(dst_group).lower().strip(".")
            rule = matches_telemetry(dst, telemetry_rules) if dst else None
            matched_rule_name = rule.name if rule else None
            # A destination that was only dots normalizes to "" -> "unknown".
            _upsert_candidate(
                conn, dst or "unknown", src_ip, device_name, matched_rule_name, entry_id, now
            )
            count += 1
        conn.commit()
    finally:
        conn.close()
    return count
# ---------------------------------------------------------------------------
# Candidate CRUD
# ---------------------------------------------------------------------------
# Shared SELECT prefix for reading candidates. The column order must stay in
# sync with the positional indexes used by _row_to_candidate; callers append
# their own WHERE / ORDER BY clauses.
_CANDIDATE_SELECT = (
    "SELECT id,domain_or_ip,source_device_ip,source_device_name,"
    "first_seen,last_seen,hit_count,status,pushed_at,log_evidence,"
    "matched_rule,llm_score,llm_reason FROM blocklist_candidates"
)
def list_candidates(
    db_path: Path,
    status: str | None = None,
    device_ip: str | None = None,
) -> list[BlocklistCandidate]:
    """Return candidates, most recently seen first, optionally filtered.

    Args:
        db_path: SQLite database holding ``blocklist_candidates``.
        status: Keep only rows with this status. ``None``, ``""`` and
            ``"all"`` disable the filter.
        device_ip: Keep only rows with this ``source_device_ip``. ``None``
            and ``""`` disable the filter.

    Returns:
        Matching candidates ordered by ``last_seen`` descending.
    """
    # "1=1" is an always-true seed so every real filter can be joined with AND.
    conditions = ["1=1"]
    bindings: list = []
    if status and status != "all":
        conditions.append("status = ?")
        bindings.append(status)
    if device_ip:
        conditions.append("source_device_ip = ?")
        bindings.append(device_ip)
    sql = (
        f"{_CANDIDATE_SELECT} WHERE "
        + " AND ".join(conditions)
        + " ORDER BY last_seen DESC"
    )

    conn = sqlite3.connect(str(db_path))
    try:
        fetched = conn.execute(sql, bindings).fetchall()
    finally:
        conn.close()
    return list(map(_row_to_candidate, fetched))
def _get_candidate(conn: sqlite3.Connection, candidate_id: str) -> BlocklistCandidate:
    """Fetch a single candidate by primary key on an already-open connection.

    Raises:
        KeyError: If no row has the given ``candidate_id``.
    """
    cursor = conn.execute(f"{_CANDIDATE_SELECT} WHERE id=?", (candidate_id,))
    found = cursor.fetchone()
    if found is not None:
        return _row_to_candidate(found)
    raise KeyError(f"Candidate {candidate_id!r} not found")
def update_candidate_status(db_path: Path, candidate_id: str, new_status: str) -> BlocklistCandidate:
    """Set a candidate's status and return the refreshed candidate.

    Only the ``status`` column is written; in particular, setting the status
    to ``"pushed"`` here does not touch ``pushed_at``.

    Raises:
        ValueError: If *new_status* is not in ``_VALID_STATUSES`` (checked
            before the database is opened).
        KeyError: If no candidate has the given ID (raised after the
            no-op UPDATE has been committed).
    """
    if new_status not in _VALID_STATUSES:
        raise ValueError(f"Invalid status {new_status!r}. Must be one of {_VALID_STATUSES}")

    statement = "UPDATE blocklist_candidates SET status=? WHERE id=?"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.execute(statement, (new_status, candidate_id))
        conn.commit()
        refreshed = _get_candidate(conn, candidate_id)
    finally:
        conn.close()
    return refreshed
def mark_pushed(db_path: Path, candidate_id: str) -> BlocklistCandidate:
    """Mark a candidate as ``pushed`` and stamp ``pushed_at`` with the current UTC time.

    Returns:
        The candidate as stored after the update.

    Raises:
        KeyError: If no candidate has the given ID.
    """
    conn = sqlite3.connect(str(db_path))
    try:
        bindings = (_now_iso(), candidate_id)
        conn.execute(
            "UPDATE blocklist_candidates SET status='pushed', pushed_at=? WHERE id=?",
            bindings,
        )
        conn.commit()
        refreshed = _get_candidate(conn, candidate_id)
    finally:
        conn.close()
    return refreshed
def mark_unblocked(db_path: Path, candidate_id: str) -> BlocklistCandidate:
    """Mark a candidate as ``unblocked``; ``pushed_at`` is left unchanged.

    Returns:
        The candidate as stored after the update.

    Raises:
        KeyError: If no candidate has the given ID.
    """
    statement = "UPDATE blocklist_candidates SET status='unblocked' WHERE id=?"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.execute(statement, (candidate_id,))
        conn.commit()
        refreshed = _get_candidate(conn, candidate_id)
    finally:
        conn.close()
    return refreshed