diff --git a/app/db/migrations/001_chains.sql b/app/db/migrations/001_chains.sql new file mode 100644 index 0000000..dfdd3f3 --- /dev/null +++ b/app/db/migrations/001_chains.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS chains ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + created_at REAL NOT NULL +); diff --git a/app/db/migrations/002_nodes.sql b/app/db/migrations/002_nodes.sql new file mode 100644 index 0000000..1de0ac9 --- /dev/null +++ b/app/db/migrations/002_nodes.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS nodes ( + id TEXT PRIMARY KEY, + chain_id TEXT NOT NULL REFERENCES chains(id) ON DELETE CASCADE, + parent_id TEXT REFERENCES nodes(id) ON DELETE CASCADE, + audio_path TEXT, + duration_s REAL, + status TEXT NOT NULL DEFAULT 'pending', + is_committed INTEGER NOT NULL DEFAULT 0, + prompt TEXT NOT NULL DEFAULT '', + energy REAL, + tempo_feel REAL, + density REAL, + cfg_coef REAL NOT NULL DEFAULT 3.0, + prompt_duration_s REAL NOT NULL DEFAULT 10.0, + error_msg TEXT, + created_at REAL NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_nodes_chain_id ON nodes(chain_id); +CREATE INDEX IF NOT EXISTS idx_nodes_parent_id ON nodes(parent_id); diff --git a/app/db/store.py b/app/db/store.py new file mode 100644 index 0000000..28c4bc8 --- /dev/null +++ b/app/db/store.py @@ -0,0 +1,19 @@ +from __future__ import annotations +import sqlite3 +from pathlib import Path + +_MIGRATIONS_DIR = Path(__file__).parent / "migrations" + + +def get_connection(db_path: str) -> sqlite3.Connection: + conn = sqlite3.connect(db_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + return conn + + +def run_migrations(conn: sqlite3.Connection) -> None: + for migration in sorted(_MIGRATIONS_DIR.glob("*.sql")): + conn.executescript(migration.read_text()) + conn.commit() diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 0000000..c09c9e4 --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,27 @@ +import sqlite3 +import tempfile +import pytest +from app.db.store import get_connection, run_migrations + + +def test_migrations_create_tables(): + with tempfile.NamedTemporaryFile(suffix=".db") as f: + conn = get_connection(f.name) + run_migrations(conn) + tables = {r[0] for r in conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + )} + assert "chains" in tables + assert "nodes" in tables + + +def test_foreign_keys_enforced(): + with tempfile.NamedTemporaryFile(suffix=".db") as f: + conn = get_connection(f.name) + run_migrations(conn) + with pytest.raises(sqlite3.IntegrityError): + conn.execute( + "INSERT INTO nodes (id, chain_id, status, created_at) " + "VALUES ('n1', 'nonexistent', 'pending', 0.0)" + ) + conn.commit()