Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/powermem/storage/sqlite/sqlite_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@

logger = logging.getLogger(__name__)

def _json_path_for_key(key: str) -> str:
"""
Build a SQLite JSON1 path for `json_extract(payload, ?)` from a dotted key.

We preserve the previous semantics where dots indicate nesting, but avoid
SQL injection by *parameterizing the path* instead of interpolating it into SQL.

Example:
key="user_id" -> $."user_id"
key="a.b" -> $."a"."b"
key="x') OR 1=1 -- " -> $."x') OR 1=1 -- "
"""
segments = key.split(".") if key is not None else [""]
path = "$"
for segment in segments:
path += f".{json.dumps(segment)}"
return path


class SQLiteVectorStore(VectorStoreBase):
"""Simple SQLite-based vector store implementation."""
Expand Down Expand Up @@ -123,8 +141,8 @@ def search(self, query: str, vectors: List[List[float]] = None, limit: int = 5,
conditions = []
for key, value in filters.items():
# Filter by JSON field in payload
conditions.append(f"json_extract(payload, '$.{key}') = ?")
query_params.append(value)
conditions.append("(json_extract(payload, ?) = ?)")
query_params.extend([_json_path_for_key(key), value])

if conditions:
query_sql += " WHERE " + " AND ".join(conditions)
Expand Down Expand Up @@ -250,8 +268,8 @@ def list(self, filters=None, limit=None, offset=None, order_by=None, order="desc
conditions = []
for key, value in filters.items():
# Filter by JSON field in payload
conditions.append(f"json_extract(payload, '$.{key}') = ?")
query_params.append(value)
conditions.append("(json_extract(payload, ?) = ?)")
query_params.extend([_json_path_for_key(key), value])

if conditions:
query += " WHERE " + " AND ".join(conditions)
Expand Down Expand Up @@ -312,8 +330,8 @@ def count(self, filters=None) -> int:
conditions = []
for key, value in filters.items():
# Filter by JSON field in payload
conditions.append(f"json_extract(payload, '$.{key}') = ?")
query_params.append(value)
conditions.append("(json_extract(payload, ?) = ?)")
query_params.extend([_json_path_for_key(key), value])

if conditions:
query += " WHERE " + " AND ".join(conditions)
Expand Down Expand Up @@ -343,8 +361,8 @@ def get_statistics(
if filters:
conditions = []
for key, value in filters.items():
conditions.append(f"json_extract(payload, '$.{key}') = ?")
query_params.append(value)
conditions.append("(json_extract(payload, ?) = ?)")
query_params.extend([_json_path_for_key(key), value])
if conditions:
query += " WHERE " + " AND ".join(conditions)

Expand Down
53 changes: 53 additions & 0 deletions tests/integration/test_storage_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,56 @@ def test_storage_collection_isolation(self, sqlite_memory):
assert len(mem2_results) > 0
finally:
patcher2.stop()

def test_sqlite_filters_key_sql_injection_regression(self):
"""
Regression test for CWE-89: filters key must not be interpolated into SQL.

Before the fix, a malicious key could inject SQL and bypass the intended
`user_id` filter, returning other users' rows.
"""
store = SQLiteVectorStore(
database_path=":memory:",
collection_name=f"test_collection_sqli_{uuid.uuid4().hex[:8]}",
)

dangerous_key = "x') OR (?=?) -- "
vec = [0.1] * 10

store.insert(
[vec],
[
{
"user_id": "alice",
dangerous_key: "alice",
"content": "ALICE_SECRET",
}
],
)
store.insert(
[vec],
[
{
"user_id": "bob",
dangerous_key: "alice",
"content": "BOB_SECRET",
}
],
)

injected_filters = {dangerous_key: "alice", "user_id": "alice"}

results = store.search(query="", vectors=[vec], limit=10, filters=injected_filters)
assert len(results) == 1
assert results[0].payload is not None
assert results[0].payload.get("user_id") == "alice"

listed = store.list(filters=injected_filters)
assert len(listed) == 1
assert listed[0].payload is not None
assert listed[0].payload.get("user_id") == "alice"

assert store.count(filters=injected_filters) == 1

stats = store.get_statistics(filters=injected_filters)
assert stats["total_memories"] == 1
Loading