# Reconstructed from git patch 4ac99403bd7d0739afe48d5e69bbd688adb49c8b
# "feat(community): add CommunityDB connection pool and migration runner"
# (pyr0ball, Sun 12 Apr 2026). The original paste had lost its line breaks;
# the two files the patch introduces are restored below in full.

# ======================================================================
# File: circuitforge_core/community/db.py
# ======================================================================
# MIT License

from __future__ import annotations

import importlib.resources
import logging
import os
from pathlib import Path

import psycopg2  # noqa: F401 -- kept in scope so callers can catch psycopg2 errors
from psycopg2.pool import ThreadedConnectionPool

logger = logging.getLogger(__name__)

# Default pool bounds; override per instance via CommunityDB(minconn=, maxconn=).
_MIN_CONN = 1
_MAX_CONN = 10


class CommunityDB:
    """Shared PostgreSQL connection pool + migration runner for the community module.

    Products instantiate one CommunityDB at startup and pass it to SharedStore
    subclasses. The pool is thread-safe (ThreadedConnectionPool).

    Usage:
        db = CommunityDB.from_env()  # reads COMMUNITY_DB_URL
        db.run_migrations()
        store = MyProductStore(db)
        db.close()  # at shutdown
    """

    def __init__(
        self,
        dsn: str | None,
        *,
        minconn: int = _MIN_CONN,
        maxconn: int = _MAX_CONN,
    ) -> None:
        """Create the connection pool.

        Args:
            dsn: libpq connection string. Required; an empty/None value raises
                ValueError so a missing COMMUNITY_DB_URL fails fast at startup.
            minconn: minimum pooled connections (keyword-only; default 1).
            maxconn: maximum pooled connections (keyword-only; default 10).

        Raises:
            ValueError: if *dsn* is falsy.
        """
        if not dsn:
            raise ValueError(
                "CommunityDB requires a DSN. "
                "Set COMMUNITY_DB_URL or pass dsn= explicitly."
            )
        self._pool = ThreadedConnectionPool(minconn, maxconn, dsn=dsn)
        logger.debug("CommunityDB pool created (min=%d, max=%d)", minconn, maxconn)

    @classmethod
    def from_env(cls) -> "CommunityDB":
        """Construct from the COMMUNITY_DB_URL environment variable."""
        # os is imported at module scope (the original imported it inside
        # this method, which re-resolves the import on every call).
        return cls(dsn=os.environ.get("COMMUNITY_DB_URL"))

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def getconn(self):
        """Borrow a connection from the pool. Must be returned via putconn()."""
        return self._pool.getconn()

    def putconn(self, conn) -> None:
        """Return a borrowed connection to the pool."""
        self._pool.putconn(conn)

    def close(self) -> None:
        """Close all pool connections. Call at application shutdown."""
        self._pool.closeall()
        logger.debug("CommunityDB pool closed")

    # ------------------------------------------------------------------
    # Migration runner
    # ------------------------------------------------------------------

    def _discover_migrations(self) -> list[Path]:
        """Return the community migration ``.sql`` files, sorted by filename.

        NOTE(review): ``Path(str(p))`` assumes the package is installed on a
        real filesystem (not inside a zip/wheel); confirm packaging before
        relying on this in frozen deployments.
        """
        pkg = importlib.resources.files("circuitforge_core.community.migrations")
        return sorted(
            (Path(str(p)) for p in pkg.iterdir() if p.name.endswith(".sql")),
            key=lambda p: p.name,
        )

    def run_migrations(self) -> None:
        """Apply all community migration SQL files in filename order.

        A ``_community_migrations`` bookkeeping table records applied files so
        repeated calls are idempotent. Each migration is committed
        individually; a mid-run failure rolls back only the in-flight
        migration (earlier ones stay committed) and re-raises.
        """
        conn = self.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS _community_migrations (
                        filename TEXT PRIMARY KEY,
                        applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                    )
                """)
                conn.commit()

                for migration_file in self._discover_migrations():
                    name = migration_file.name
                    cur.execute(
                        "SELECT 1 FROM _community_migrations WHERE filename = %s",
                        (name,),
                    )
                    if cur.fetchone():
                        logger.debug("Migration %s already applied, skipping", name)
                        continue

                    sql = migration_file.read_text()
                    logger.info("Applying community migration: %s", name)
                    cur.execute(sql)
                    cur.execute(
                        "INSERT INTO _community_migrations (filename) VALUES (%s)",
                        (name,),
                    )
                    conn.commit()
        except Exception:
            # Roll back the in-flight migration before propagating.
            conn.rollback()
            raise
        finally:
            self.putconn(conn)


# ======================================================================
# File: tests/community/test_db.py
# ======================================================================
# (The original file imported `os` without using it; that import is dropped.)
import pytest
from unittest.mock import MagicMock, patch

from circuitforge_core.community.db import CommunityDB


@pytest.fixture
def mock_pool():
    """Patch ThreadedConnectionPool so no real PostgreSQL instance is needed."""
    with patch("circuitforge_core.community.db.ThreadedConnectionPool") as mock_cls:
        mock_instance = MagicMock()
        mock_cls.return_value = mock_instance
        yield mock_cls, mock_instance


def test_community_db_requires_url():
    with pytest.raises(ValueError, match="COMMUNITY_DB_URL"):
        CommunityDB(dsn=None)


def test_community_db_init_creates_pool(mock_pool):
    mock_cls, _ = mock_pool
    CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
    mock_cls.assert_called_once()


def test_community_db_close_puts_pool(mock_pool):
    _, mock_instance = mock_pool
    db = CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
    db.close()
    mock_instance.closeall.assert_called_once()


def test_community_db_migration_files_discovered():
    """Migration runner must find at least 001 and 002 SQL files."""
    db = CommunityDB.__new__(CommunityDB)  # bypass __init__: no pool needed
    files = db._discover_migrations()
    names = [f.name for f in files]
    assert any("001" in n for n in names)
    assert any("002" in n for n in names)
    # Must be sorted by filename
    assert files == sorted(files, key=lambda p: p.name)


def test_community_db_run_migrations_executes_sql(mock_pool):
    _, mock_instance = mock_pool
    mock_conn = MagicMock()
    mock_cur = MagicMock()
    mock_instance.getconn.return_value = mock_conn
    mock_conn.cursor.return_value.__enter__.return_value = mock_cur
    mock_cur.fetchone.return_value = None  # no migrations applied yet

    db = CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
    db.run_migrations()

    # At least one execute call must have happened
    assert mock_cur.execute.called


def test_community_db_from_env(monkeypatch, mock_pool):
    mock_cls, _ = mock_pool
    monkeypatch.setenv("COMMUNITY_DB_URL", "postgresql://u:p@host/db")
    db = CommunityDB.from_env()
    assert db is not None
    # The env DSN must actually reach the pool (original test never checked this).
    mock_cls.assert_called_once_with(1, 10, dsn="postgresql://u:p@host/db")


def test_community_db_from_env_missing(monkeypatch, mock_pool):
    """from_env must fail fast when COMMUNITY_DB_URL is unset."""
    monkeypatch.delenv("COMMUNITY_DB_URL", raising=False)
    with pytest.raises(ValueError, match="COMMUNITY_DB_URL"):
        CommunityDB.from_env()