feat(community): add CommunityDB connection pool and migration runner

This commit is contained in:
pyr0ball 2026-04-12 21:38:14 -07:00
parent d78310d4fd
commit f74457d11f
2 changed files with 180 additions and 0 deletions

View file

@ -0,0 +1,117 @@
# circuitforge_core/community/db.py
# MIT License
from __future__ import annotations
import importlib.resources
import logging
from pathlib import Path
import psycopg2
from psycopg2.pool import ThreadedConnectionPool
logger = logging.getLogger(__name__)
_MIN_CONN = 1
_MAX_CONN = 10
class CommunityDB:
"""Shared PostgreSQL connection pool + migration runner for the community module.
Products instantiate one CommunityDB at startup and pass it to SharedStore
subclasses. The pool is thread-safe (ThreadedConnectionPool).
Usage:
db = CommunityDB.from_env() # reads COMMUNITY_DB_URL
db.run_migrations()
store = MyProductStore(db)
db.close() # at shutdown
"""
def __init__(self, dsn: str | None) -> None:
if not dsn:
raise ValueError(
"CommunityDB requires a DSN. "
"Set COMMUNITY_DB_URL or pass dsn= explicitly."
)
self._pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, dsn=dsn)
logger.debug("CommunityDB pool created (min=%d, max=%d)", _MIN_CONN, _MAX_CONN)
@classmethod
def from_env(cls) -> "CommunityDB":
"""Construct from the COMMUNITY_DB_URL environment variable."""
import os
dsn = os.environ.get("COMMUNITY_DB_URL")
return cls(dsn=dsn)
# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def getconn(self):
"""Borrow a connection from the pool. Must be returned via putconn()."""
return self._pool.getconn()
def putconn(self, conn) -> None:
"""Return a borrowed connection to the pool."""
self._pool.putconn(conn)
def close(self) -> None:
"""Close all pool connections. Call at application shutdown."""
self._pool.closeall()
logger.debug("CommunityDB pool closed")
# ------------------------------------------------------------------
# Migration runner
# ------------------------------------------------------------------
def _discover_migrations(self) -> list[Path]:
"""Return sorted list of .sql migration files from the community migrations dir."""
pkg = importlib.resources.files("circuitforge_core.community.migrations")
files = sorted(
[Path(str(p)) for p in pkg.iterdir() if str(p).endswith(".sql")],
key=lambda p: p.name,
)
return files
def run_migrations(self) -> None:
"""Apply all community migration SQL files in numeric order.
Uses a simple applied-migrations table to avoid re-running already
applied migrations. Idempotent.
"""
conn = self.getconn()
try:
with conn.cursor() as cur:
cur.execute("""
CREATE TABLE IF NOT EXISTS _community_migrations (
filename TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
""")
conn.commit()
for migration_file in self._discover_migrations():
name = migration_file.name
cur.execute(
"SELECT 1 FROM _community_migrations WHERE filename = %s",
(name,),
)
if cur.fetchone():
logger.debug("Migration %s already applied, skipping", name)
continue
sql = migration_file.read_text()
logger.info("Applying community migration: %s", name)
cur.execute(sql)
cur.execute(
"INSERT INTO _community_migrations (filename) VALUES (%s)",
(name,),
)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
self.putconn(conn)

View file

@ -0,0 +1,63 @@
# tests/community/test_db.py
import os
import pytest
from unittest.mock import MagicMock, patch, call
from circuitforge_core.community.db import CommunityDB
@pytest.fixture
def mock_pool():
"""Patch psycopg2.pool.ThreadedConnectionPool to avoid needing a real PG instance."""
with patch("circuitforge_core.community.db.ThreadedConnectionPool") as mock_cls:
mock_instance = MagicMock()
mock_cls.return_value = mock_instance
yield mock_cls, mock_instance
def test_community_db_requires_url():
with pytest.raises(ValueError, match="COMMUNITY_DB_URL"):
CommunityDB(dsn=None)
def test_community_db_init_creates_pool(mock_pool):
mock_cls, _ = mock_pool
CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
mock_cls.assert_called_once()
def test_community_db_close_puts_pool(mock_pool):
_, mock_instance = mock_pool
db = CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
db.close()
mock_instance.closeall.assert_called_once()
def test_community_db_migration_files_discovered():
"""Migration runner must find at least 001 and 002 SQL files."""
db = CommunityDB.__new__(CommunityDB)
files = db._discover_migrations()
names = [f.name for f in files]
assert any("001" in n for n in names)
assert any("002" in n for n in names)
# Must be sorted numerically
assert files == sorted(files, key=lambda p: p.name)
def test_community_db_run_migrations_executes_sql(mock_pool):
_, mock_instance = mock_pool
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_instance.getconn.return_value = mock_conn
mock_conn.cursor.return_value.__enter__.return_value = mock_cur
db = CommunityDB(dsn="postgresql://user:pass@localhost/cf_community")
db.run_migrations()
# At least one execute call must have happened
assert mock_cur.execute.called
def test_community_db_from_env(monkeypatch, mock_pool):
monkeypatch.setenv("COMMUNITY_DB_URL", "postgresql://u:p@host/db")
db = CommunityDB.from_env()
assert db is not None