fix(vector): explicit rollback, table identifier guard, query scope fix

This commit is contained in:
pyr0ball 2026-05-04 15:55:05 -07:00
parent 0489f1111c
commit a6d906bcbb
2 changed files with 25 additions and 11 deletions

View file

@ -10,6 +10,7 @@ from __future__ import annotations
import json
import logging
import re
import sqlite3
import struct
from contextlib import contextmanager
@ -22,6 +23,8 @@ from .base import VectorMatch, VectorStore
logger = logging.getLogger(__name__)
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def _serialize(vector: list[float]) -> bytes:
return struct.pack(f"<{len(vector)}f", *vector)
@ -47,6 +50,10 @@ class LocalSQLiteVecStore(VectorStore):
table: str = "vecs",
dimensions: int = 768,
) -> None:
if not _SAFE_IDENTIFIER.match(table):
raise ValueError(
f"table must be a valid SQL identifier (letters, digits, underscores): {table!r}"
)
self.db_path = str(db_path)
self.table = table
self.dimensions = dimensions
@ -62,6 +69,9 @@ class LocalSQLiteVecStore(VectorStore):
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
@ -125,15 +135,14 @@ class LocalSQLiteVecStore(VectorStore):
""",
[_serialize(vector), top_k],
).fetchall()
results = [
VectorMatch(
entry_id=r["entry_id"],
score=r["distance"],
metadata=json.loads(r["metadata"]),
)
for r in rows
]
results = [
VectorMatch(
entry_id=r["entry_id"],
score=r["distance"],
metadata=json.loads(r["metadata"]),
)
for r in rows
]
if filter_metadata:
results = [

View file

@ -29,9 +29,14 @@ def test_upsert_and_query_returns_match(store):
def test_upsert_replaces_existing(store):
store.upsert("chunk-1", _vec(0.1), {"page": 1})
store.upsert("chunk-1", _vec(0.2), {"page": 99})
results = store.query(_vec(0.2), top_k=5)
store.upsert("chunk-1", _vec(0.9), {"page": 99})
# Metadata check
results = store.query(_vec(0.9), top_k=5)
assert results[0].metadata["page"] == 99
# Vector check: querying with new vector should score better than querying with old
old_results = store.query(_vec(0.1), top_k=5)
new_results = store.query(_vec(0.9), top_k=5)
assert new_results[0].score < old_results[0].score
def test_query_respects_top_k(store):