From ea961d6da98242365c66e19d23b630bdce0b336a Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Mon, 20 Apr 2026 11:55:43 -0700 Subject: [PATCH] feat: messaging DB helpers + unit tests (#74) --- scripts/messaging.py | 269 +++++++++++++++++++++++++++ tests/test_messaging.py | 393 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 662 insertions(+) create mode 100644 scripts/messaging.py create mode 100644 tests/test_messaging.py diff --git a/scripts/messaging.py b/scripts/messaging.py new file mode 100644 index 0000000..2d5fb3f --- /dev/null +++ b/scripts/messaging.py @@ -0,0 +1,269 @@ +""" +DB helpers for the messaging feature. + +Messages table: manual log entries and LLM drafts (one row per message). +Message templates table: built-in seeds and user-created templates. + +Conventions (match scripts/db.py): +- All functions take db_path: Path as first argument. +- sqlite3.connect(db_path), row_factory = sqlite3.Row +- Return plain dicts (dict(row)) +- Always close connection in finally +""" +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _connect(db_path: Path) -> sqlite3.Connection: + con = sqlite3.connect(db_path) + con.row_factory = sqlite3.Row + return con + + +def _now_utc() -> str: + """Return current UTC time as ISO 8601 string.""" + return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + + +# --------------------------------------------------------------------------- +# Messages +# --------------------------------------------------------------------------- + +def create_message( + db_path: Path, + *, + job_id: Optional[int], + job_contact_id: Optional[int], + type: str, + direction: str, + subject: Optional[str], + body: Optional[str], + from_addr: Optional[str], + to_addr: Optional[str], + template_id: Optional[int], +) -> dict: + """Insert a new message row and return it as a dict.""" + con = _connect(db_path) + try: + cur = con.execute( + """ + INSERT INTO messages + (job_id, job_contact_id, type, direction, subject, body, + from_addr, to_addr, template_id) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (job_id, job_contact_id, type, direction, subject, body, + from_addr, to_addr, template_id), + ) + con.commit() + row = con.execute( + "SELECT * FROM messages WHERE id = ?", (cur.lastrowid,) + ).fetchone() + return dict(row) + finally: + con.close() + + +def list_messages( + db_path: Path, + *, + job_id: Optional[int] = None, + type: Optional[str] = None, + direction: Optional[str] = None, + limit: int = 100, +) -> list[dict]: + """Return messages, optionally filtered. Ordered by logged_at DESC.""" + conditions: list[str] = [] + params: list = [] + + if job_id is not None: + conditions.append("job_id = ?") + params.append(job_id) + if type is not None: + conditions.append("type = ?") + params.append(type) + if direction is not None: + conditions.append("direction = ?") + params.append(direction) + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + params.append(limit) + + con = _connect(db_path) + try: + rows = con.execute( + f"SELECT * FROM messages {where} ORDER BY logged_at DESC LIMIT ?", + params, + ).fetchall() + return [dict(r) for r in rows] + finally: + con.close() + + +def delete_message(db_path: Path, message_id: int) -> None: + """Delete a message by id. Raises KeyError if not found.""" + con = _connect(db_path) + try: + row = con.execute( + "SELECT id FROM messages WHERE id = ?", (message_id,) + ).fetchone() + if row is None: + raise KeyError(f"Message {message_id} not found") + con.execute("DELETE FROM messages WHERE id = ?", (message_id,)) + con.commit() + finally: + con.close() + + +def approve_message(db_path: Path, message_id: int) -> dict: + """Set approved_at to now for the given message. Raises KeyError if not found.""" + con = _connect(db_path) + try: + row = con.execute( + "SELECT id FROM messages WHERE id = ?", (message_id,) + ).fetchone() + if row is None: + raise KeyError(f"Message {message_id} not found") + con.execute( + "UPDATE messages SET approved_at = ? WHERE id = ?", + (_now_utc(), message_id), + ) + con.commit() + updated = con.execute( + "SELECT * FROM messages WHERE id = ?", (message_id,) + ).fetchone() + return dict(updated) + finally: + con.close() + + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + +def list_templates(db_path: Path) -> list[dict]: + """Return all templates ordered by is_builtin DESC, then title ASC.""" + con = _connect(db_path) + try: + rows = con.execute( + "SELECT * FROM message_templates ORDER BY is_builtin DESC, title ASC" + ).fetchall() + return [dict(r) for r in rows] + finally: + con.close() + + +def create_template( + db_path: Path, + *, + title: str, + category: str = "custom", + subject_template: Optional[str] = None, + body_template: str, +) -> dict: + """Insert a new user-defined template and return it as a dict.""" + con = _connect(db_path) + try: + cur = con.execute( + """ + INSERT INTO message_templates + (title, category, subject_template, body_template, is_builtin) + VALUES + (?, ?, ?, ?, 0) + """, + (title, category, subject_template, body_template), + ) + con.commit() + row = con.execute( + "SELECT * FROM message_templates WHERE id = ?", (cur.lastrowid,) + ).fetchone() + return dict(row) + finally: + con.close() + + +def update_template(db_path: Path, template_id: int, **fields) -> dict: + """ + Update allowed fields on a user-defined template. + + Raises PermissionError if the template is a built-in (is_builtin=1). + Raises KeyError if the template is not found. + """ + if not fields: + # Nothing to update — just return current state + con = _connect(db_path) + try: + row = con.execute( + "SELECT * FROM message_templates WHERE id = ?", (template_id,) + ).fetchone() + if row is None: + raise KeyError(f"Template {template_id} not found") + return dict(row) + finally: + con.close() + + _ALLOWED_FIELDS = { + "title", "category", "subject_template", "body_template", + } + invalid = set(fields) - _ALLOWED_FIELDS + if invalid: + raise ValueError(f"Cannot update field(s): {invalid}") + + con = _connect(db_path) + try: + row = con.execute( + "SELECT id, is_builtin FROM message_templates WHERE id = ?", + (template_id,), + ).fetchone() + if row is None: + raise KeyError(f"Template {template_id} not found") + if row["is_builtin"]: + raise PermissionError( + f"Template {template_id} is a built-in and cannot be modified" + ) + + set_clause = ", ".join(f"{col} = ?" for col in fields) + values = list(fields.values()) + [_now_utc(), template_id] + con.execute( + f"UPDATE message_templates SET {set_clause}, updated_at = ? WHERE id = ?", + values, + ) + con.commit() + updated = con.execute( + "SELECT * FROM message_templates WHERE id = ?", (template_id,) + ).fetchone() + return dict(updated) + finally: + con.close() + + +def delete_template(db_path: Path, template_id: int) -> None: + """ + Delete a user-defined template. + + Raises PermissionError if the template is a built-in (is_builtin=1). + Raises KeyError if the template is not found. + """ + con = _connect(db_path) + try: + row = con.execute( + "SELECT id, is_builtin FROM message_templates WHERE id = ?", + (template_id,), + ).fetchone() + if row is None: + raise KeyError(f"Template {template_id} not found") + if row["is_builtin"]: + raise PermissionError( + f"Template {template_id} is a built-in and cannot be deleted" + ) + con.execute("DELETE FROM message_templates WHERE id = ?", (template_id,)) + con.commit() + finally: + con.close() diff --git a/tests/test_messaging.py b/tests/test_messaging.py new file mode 100644 index 0000000..4e69648 --- /dev/null +++ b/tests/test_messaging.py @@ -0,0 +1,393 @@ +""" +Unit tests for scripts/messaging.py — DB helpers for messages and message_templates. + +TDD approach: tests written before implementation. +""" +import sqlite3 +from pathlib import Path + +import pytest + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _apply_migration_008(db_path: Path) -> None: + """Apply migration 008 directly so tests run without the full migrate_db stack.""" + migration = ( + Path(__file__).parent.parent / "migrations" / "008_messaging.sql" + ) + sql = migration.read_text(encoding="utf-8") + con = sqlite3.connect(db_path) + try: + # Create jobs table stub so FK references don't break + con.execute(""" + CREATE TABLE IF NOT EXISTS jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT + ) + """) + con.execute(""" + CREATE TABLE IF NOT EXISTS job_contacts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id INTEGER + ) + """) + # Execute migration statements + statements = [s.strip() for s in sql.split(";") if s.strip()] + for stmt in statements: + stripped = "\n".join( + ln for ln in stmt.splitlines() if not ln.strip().startswith("--") + ).strip() + if stripped: + con.execute(stripped) + con.commit() + finally: + con.close() + + +@pytest.fixture() +def db_path(tmp_path: Path) -> Path: + """Temporary SQLite DB with migration 008 applied.""" + path = tmp_path / "test.db" + _apply_migration_008(path) + return path + + +@pytest.fixture() +def job_id(db_path: Path) -> int: + """Insert a dummy job and return its id.""" + con = sqlite3.connect(db_path) + try: + cur = con.execute("INSERT INTO jobs (title) VALUES ('Test Job')") + con.commit() + return cur.lastrowid + finally: + con.close() + + +# --------------------------------------------------------------------------- +# Message tests +# --------------------------------------------------------------------------- + +class TestCreateMessage: + def test_create_returns_dict(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import create_message + + msg = create_message( + db_path, + job_id=job_id, + job_contact_id=None, + type="email", + direction="outbound", + subject="Hello", + body="Body text", + from_addr="me@example.com", + to_addr="them@example.com", + template_id=None, + ) + + assert isinstance(msg, dict) + assert msg["subject"] == "Hello" + assert msg["body"] == "Body text" + assert msg["direction"] == "outbound" + assert msg["type"] == "email" + assert "id" in msg + assert msg["id"] > 0 + + def test_create_persists_to_db(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import create_message + + create_message( + db_path, + job_id=job_id, + job_contact_id=None, + type="email", + direction="outbound", + subject="Persisted", + body="Stored body", + from_addr="a@b.com", + to_addr="c@d.com", + template_id=None, + ) + + con = sqlite3.connect(db_path) + try: + row = con.execute( + "SELECT subject FROM messages WHERE subject='Persisted'" + ).fetchone() + assert row is not None + finally: + con.close() + + +class TestListMessages: + def _make_message( + self, + db_path: Path, + job_id: int, + *, + type: str = "email", + direction: str = "outbound", + subject: str = "Subject", + ) -> dict: + from scripts.messaging import create_message + return create_message( + db_path, + job_id=job_id, + job_contact_id=None, + type=type, + direction=direction, + subject=subject, + body="body", + from_addr="a@b.com", + to_addr="c@d.com", + template_id=None, + ) + + def test_list_returns_all_messages(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import list_messages + + self._make_message(db_path, job_id, subject="First") + self._make_message(db_path, job_id, subject="Second") + + result = list_messages(db_path) + assert len(result) == 2 + + def test_list_filtered_by_job_id(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import list_messages + + # Create a second job + con = sqlite3.connect(db_path) + try: + cur = con.execute("INSERT INTO jobs (title) VALUES ('Other Job')") + con.commit() + other_job_id = cur.lastrowid + finally: + con.close() + + self._make_message(db_path, job_id, subject="For job 1") + self._make_message(db_path, other_job_id, subject="For job 2") + + result = list_messages(db_path, job_id=job_id) + assert len(result) == 1 + assert result[0]["subject"] == "For job 1" + + def test_list_filtered_by_type(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import list_messages + + self._make_message(db_path, job_id, type="email", subject="Email msg") + self._make_message(db_path, job_id, type="sms", subject="SMS msg") + + emails = list_messages(db_path, type="email") + assert len(emails) == 1 + assert emails[0]["type"] == "email" + + def test_list_filtered_by_direction(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import list_messages + + self._make_message(db_path, job_id, direction="outbound") + self._make_message(db_path, job_id, direction="inbound") + + outbound = list_messages(db_path, direction="outbound") + assert len(outbound) == 1 + assert outbound[0]["direction"] == "outbound" + + def test_list_respects_limit(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import list_messages + + for i in range(5): + self._make_message(db_path, job_id, subject=f"Msg {i}") + + result = list_messages(db_path, limit=3) + assert len(result) == 3 + + +class TestDeleteMessage: + def test_delete_removes_message(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import create_message, delete_message, list_messages + + msg = create_message( + db_path, + job_id=job_id, + job_contact_id=None, + type="email", + direction="outbound", + subject="To delete", + body="bye", + from_addr="a@b.com", + to_addr="c@d.com", + template_id=None, + ) + + delete_message(db_path, msg["id"]) + assert list_messages(db_path) == [] + + def test_delete_raises_key_error_when_not_found(self, db_path: Path) -> None: + from scripts.messaging import delete_message + + with pytest.raises(KeyError): + delete_message(db_path, 99999) + + +class TestApproveMessage: + def test_approve_sets_approved_at(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import approve_message, create_message + + msg = create_message( + db_path, + job_id=job_id, + job_contact_id=None, + type="email", + direction="outbound", + subject="Draft", + body="Draft body", + from_addr="a@b.com", + to_addr="c@d.com", + template_id=None, + ) + assert msg.get("approved_at") is None + + updated = approve_message(db_path, msg["id"]) + assert updated["approved_at"] is not None + assert updated["id"] == msg["id"] + + def test_approve_returns_full_dict(self, db_path: Path, job_id: int) -> None: + from scripts.messaging import approve_message, create_message + + msg = create_message( + db_path, + job_id=job_id, + job_contact_id=None, + type="email", + direction="outbound", + subject="Draft", + body="Body here", + from_addr="a@b.com", + to_addr="c@d.com", + template_id=None, + ) + + updated = approve_message(db_path, msg["id"]) + assert updated["body"] == "Body here" + assert updated["subject"] == "Draft" + + def test_approve_raises_key_error_when_not_found(self, db_path: Path) -> None: + from scripts.messaging import approve_message + + with pytest.raises(KeyError): + approve_message(db_path, 99999) + + +# --------------------------------------------------------------------------- +# Template tests +# --------------------------------------------------------------------------- + +class TestListTemplates: + def test_includes_four_builtins(self, db_path: Path) -> None: + from scripts.messaging import list_templates + + templates = list_templates(db_path) + builtin_keys = {t["key"] for t in templates if t["is_builtin"]} + assert builtin_keys == { + "follow_up", + "thank_you", + "accommodation_request", + "withdrawal", + } + + def test_returns_list_of_dicts(self, db_path: Path) -> None: + from scripts.messaging import list_templates + + templates = list_templates(db_path) + assert isinstance(templates, list) + assert all(isinstance(t, dict) for t in templates) + + +class TestCreateTemplate: + def test_create_returns_dict(self, db_path: Path) -> None: + from scripts.messaging import create_template + + tmpl = create_template( + db_path, + title="My Template", + category="custom", + subject_template="Hello {{name}}", + body_template="Dear {{name}}, ...", + ) + + assert isinstance(tmpl, dict) + assert tmpl["title"] == "My Template" + assert tmpl["category"] == "custom" + assert tmpl["is_builtin"] == 0 + assert "id" in tmpl + + def test_create_default_category(self, db_path: Path) -> None: + from scripts.messaging import create_template + + tmpl = create_template( + db_path, + title="No Category", + body_template="Body", + ) + assert tmpl["category"] == "custom" + + def test_create_appears_in_list(self, db_path: Path) -> None: + from scripts.messaging import create_template, list_templates + + create_template(db_path, title="Listed", body_template="Body") + titles = [t["title"] for t in list_templates(db_path)] + assert "Listed" in titles + + +class TestUpdateTemplate: + def test_update_user_template(self, db_path: Path) -> None: + from scripts.messaging import create_template, update_template + + tmpl = create_template(db_path, title="Original", body_template="Old body") + updated = update_template(db_path, tmpl["id"], title="Updated", body_template="New body") + + assert updated["title"] == "Updated" + assert updated["body_template"] == "New body" + + def test_update_returns_persisted_values(self, db_path: Path) -> None: + from scripts.messaging import create_template, list_templates, update_template + + tmpl = create_template(db_path, title="Before", body_template="x") + update_template(db_path, tmpl["id"], title="After") + + templates = list_templates(db_path) + titles = [t["title"] for t in templates] + assert "After" in titles + assert "Before" not in titles + + def test_update_builtin_raises_permission_error(self, db_path: Path) -> None: + from scripts.messaging import list_templates, update_template + + builtin = next(t for t in list_templates(db_path) if t["is_builtin"]) + with pytest.raises(PermissionError): + update_template(db_path, builtin["id"], title="Hacked") + + +class TestDeleteTemplate: + def test_delete_user_template(self, db_path: Path) -> None: + from scripts.messaging import create_template, delete_template, list_templates + + tmpl = create_template(db_path, title="To Delete", body_template="bye") + initial_count = len(list_templates(db_path)) + delete_template(db_path, tmpl["id"]) + assert len(list_templates(db_path)) == initial_count - 1 + + def test_delete_builtin_raises_permission_error(self, db_path: Path) -> None: + from scripts.messaging import delete_template, list_templates + + builtin = next(t for t in list_templates(db_path) if t["is_builtin"]) + with pytest.raises(PermissionError): + delete_template(db_path, builtin["id"]) + + def test_delete_missing_raises_key_error(self, db_path: Path) -> None: + from scripts.messaging import delete_template + + with pytest.raises(KeyError): + delete_template(db_path, 99999)