diff --git a/circuitforge_core/db/__init__.py b/circuitforge_core/db/__init__.py new file mode 100644 index 0000000..eafc00b --- /dev/null +++ b/circuitforge_core/db/__init__.py @@ -0,0 +1,4 @@ +from .base import get_connection +from .migrations import run_migrations + +__all__ = ["get_connection", "run_migrations"] diff --git a/circuitforge_core/db/base.py b/circuitforge_core/db/base.py new file mode 100644 index 0000000..baed4d0 --- /dev/null +++ b/circuitforge_core/db/base.py @@ -0,0 +1,28 @@ +""" +SQLite connection factory for CircuitForge products. +Supports plain SQLite and SQLCipher (AES-256) when CLOUD_MODE is active. +""" +from __future__ import annotations +import os +import sqlite3 +from pathlib import Path + + +def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection: + """ + Open a SQLite database connection. + + In cloud mode with a key: uses SQLCipher (API-identical to sqlite3). + Otherwise: plain sqlite3. + + Args: + db_path: Path to the database file. Created if absent. + key: SQLCipher encryption key. Empty = unencrypted. + """ + cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes") + if cloud_mode and key: + from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore + conn = _sqlcipher.connect(str(db_path)) + conn.execute(f"PRAGMA key='{key}'") + return conn + return sqlite3.connect(str(db_path)) diff --git a/circuitforge_core/db/migrations.py b/circuitforge_core/db/migrations.py new file mode 100644 index 0000000..ddcf331 --- /dev/null +++ b/circuitforge_core/db/migrations.py @@ -0,0 +1,27 @@ +""" +Sequential SQL migration runner. +Applies *.sql files from migrations_dir in filename order. +Tracks applied migrations in a _migrations table — safe to call multiple times. +""" +from __future__ import annotations +import sqlite3 +from pathlib import Path + + +def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None: + """Apply any unapplied *.sql migrations from migrations_dir.""" + conn.execute( + "CREATE TABLE IF NOT EXISTS _migrations " + "(name TEXT PRIMARY KEY, applied_at TEXT DEFAULT CURRENT_TIMESTAMP)" + ) + conn.commit() + + applied = {row[0] for row in conn.execute("SELECT name FROM _migrations")} + sql_files = sorted(migrations_dir.glob("*.sql")) + + for sql_file in sql_files: + if sql_file.name in applied: + continue + conn.executescript(sql_file.read_text()) + conn.execute("INSERT INTO _migrations (name) VALUES (?)", (sql_file.name,)) + conn.commit() diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..1b2a018 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,63 @@ +import sqlite3 +import tempfile +from pathlib import Path +import pytest +from circuitforge_core.db import get_connection, run_migrations + + +def test_get_connection_returns_sqlite_connection(tmp_path): + db = tmp_path / "test.db" + conn = get_connection(db) + assert isinstance(conn, sqlite3.Connection) + conn.close() + + +def test_get_connection_creates_file(tmp_path): + db = tmp_path / "test.db" + assert not db.exists() + conn = get_connection(db) + conn.close() + assert db.exists() + + +def test_run_migrations_applies_sql_files(tmp_path): + db = tmp_path / "test.db" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + (migrations_dir / "001_create_foo.sql").write_text( + "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);" + ) + conn = get_connection(db) + run_migrations(conn, migrations_dir) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='foo'") + assert cursor.fetchone() is not None + conn.close() + + +def test_run_migrations_is_idempotent(tmp_path): + db = tmp_path / "test.db" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + (migrations_dir / "001_create_foo.sql").write_text( + "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);" + ) + conn = get_connection(db) + run_migrations(conn, migrations_dir) + run_migrations(conn, migrations_dir) # second run must not raise + conn.close() + + +def test_run_migrations_applies_in_order(tmp_path): + db = tmp_path / "test.db" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + (migrations_dir / "001_create_foo.sql").write_text( + "CREATE TABLE foo (id INTEGER PRIMARY KEY);" + ) + (migrations_dir / "002_add_name.sql").write_text( + "ALTER TABLE foo ADD COLUMN name TEXT;" + ) + conn = get_connection(db) + run_migrations(conn, migrations_dir) + conn.execute("INSERT INTO foo (name) VALUES ('bar')") + conn.close()