fix(vector): explicit rollback, table identifier guard, query scope fix
This commit is contained in:
parent
0489f1111c
commit
a6d906bcbb
2 changed files with 25 additions and 11 deletions
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
from contextlib import contextmanager
|
||||
|
|
@ -22,6 +23,8 @@ from .base import VectorMatch, VectorStore
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
|
||||
def _serialize(vector: list[float]) -> bytes:
|
||||
return struct.pack(f"<{len(vector)}f", *vector)
|
||||
|
|
@ -47,6 +50,10 @@ class LocalSQLiteVecStore(VectorStore):
|
|||
table: str = "vecs",
|
||||
dimensions: int = 768,
|
||||
) -> None:
|
||||
if not _SAFE_IDENTIFIER.match(table):
|
||||
raise ValueError(
|
||||
f"table must be a valid SQL identifier (letters, digits, underscores): {table!r}"
|
||||
)
|
||||
self.db_path = str(db_path)
|
||||
self.table = table
|
||||
self.dimensions = dimensions
|
||||
|
|
@ -62,6 +69,9 @@ class LocalSQLiteVecStore(VectorStore):
|
|||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
|
@ -125,7 +135,6 @@ class LocalSQLiteVecStore(VectorStore):
|
|||
""",
|
||||
[_serialize(vector), top_k],
|
||||
).fetchall()
|
||||
|
||||
results = [
|
||||
VectorMatch(
|
||||
entry_id=r["entry_id"],
|
||||
|
|
|
|||
|
|
@ -29,9 +29,14 @@ def test_upsert_and_query_returns_match(store):
|
|||
|
||||
def test_upsert_replaces_existing(store):
    """Upserting an existing id must replace both its metadata and its vector."""
    store.upsert("chunk-1", _vec(0.1), {"page": 1})
    store.upsert("chunk-1", _vec(0.9), {"page": 99})

    old_results = store.query(_vec(0.1), top_k=5)
    new_results = store.query(_vec(0.9), top_k=5)

    # Metadata check: the second upsert's metadata wins.
    assert new_results[0].metadata["page"] == 99

    # Vector check: querying with the new vector should score better than
    # querying with the old one.
    # NOTE(review): this presumes a lower score means a closer match — confirm
    # against the store's scoring convention.
    assert new_results[0].score < old_results[0].score
|
||||
|
||||
|
||||
def test_query_respects_top_k(store):
|
||||
|
|
|
|||
Loading…
Reference in a new issue