fix(vector): explicit rollback, table identifier guard, query scope fix
This commit is contained in:
parent
0489f1111c
commit
a6d906bcbb
2 changed files with 25 additions and 11 deletions
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
@ -22,6 +23,8 @@ from .base import VectorMatch, VectorStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||||
|
|
||||||
|
|
||||||
def _serialize(vector: list[float]) -> bytes:
|
def _serialize(vector: list[float]) -> bytes:
|
||||||
return struct.pack(f"<{len(vector)}f", *vector)
|
return struct.pack(f"<{len(vector)}f", *vector)
|
||||||
|
|
@ -47,6 +50,10 @@ class LocalSQLiteVecStore(VectorStore):
|
||||||
table: str = "vecs",
|
table: str = "vecs",
|
||||||
dimensions: int = 768,
|
dimensions: int = 768,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not _SAFE_IDENTIFIER.match(table):
|
||||||
|
raise ValueError(
|
||||||
|
f"table must be a valid SQL identifier (letters, digits, underscores): {table!r}"
|
||||||
|
)
|
||||||
self.db_path = str(db_path)
|
self.db_path = str(db_path)
|
||||||
self.table = table
|
self.table = table
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
|
|
@ -62,6 +69,9 @@ class LocalSQLiteVecStore(VectorStore):
|
||||||
try:
|
try:
|
||||||
yield conn
|
yield conn
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
@ -125,15 +135,14 @@ class LocalSQLiteVecStore(VectorStore):
|
||||||
""",
|
""",
|
||||||
[_serialize(vector), top_k],
|
[_serialize(vector), top_k],
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
results = [
|
||||||
results = [
|
VectorMatch(
|
||||||
VectorMatch(
|
entry_id=r["entry_id"],
|
||||||
entry_id=r["entry_id"],
|
score=r["distance"],
|
||||||
score=r["distance"],
|
metadata=json.loads(r["metadata"]),
|
||||||
metadata=json.loads(r["metadata"]),
|
)
|
||||||
)
|
for r in rows
|
||||||
for r in rows
|
]
|
||||||
]
|
|
||||||
|
|
||||||
if filter_metadata:
|
if filter_metadata:
|
||||||
results = [
|
results = [
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,14 @@ def test_upsert_and_query_returns_match(store):
|
||||||
|
|
||||||
def test_upsert_replaces_existing(store):
|
def test_upsert_replaces_existing(store):
|
||||||
store.upsert("chunk-1", _vec(0.1), {"page": 1})
|
store.upsert("chunk-1", _vec(0.1), {"page": 1})
|
||||||
store.upsert("chunk-1", _vec(0.2), {"page": 99})
|
store.upsert("chunk-1", _vec(0.9), {"page": 99})
|
||||||
results = store.query(_vec(0.2), top_k=5)
|
# Metadata check
|
||||||
|
results = store.query(_vec(0.9), top_k=5)
|
||||||
assert results[0].metadata["page"] == 99
|
assert results[0].metadata["page"] == 99
|
||||||
|
# Vector check: querying with new vector should score better than querying with old
|
||||||
|
old_results = store.query(_vec(0.1), top_k=5)
|
||||||
|
new_results = store.query(_vec(0.9), top_k=5)
|
||||||
|
assert new_results[0].score < old_results[0].score
|
||||||
|
|
||||||
|
|
||||||
def test_query_respects_top_k(store):
|
def test_query_respects_top_k(store):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue