Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 113 additions & 11 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from falkordb import FalkorDB
from flask import Blueprint, Flask, abort, jsonify, request
from flask import Blueprint, Flask, abort, has_request_context, jsonify, request
from qdrant_client import QdrantClient
from qdrant_client import models as qdrant_models

Expand Down Expand Up @@ -1067,6 +1067,62 @@ class ServiceState:

state = ServiceState()

# Connection caches for per-request isolation
_graph_cache: Dict[str, Any] = {}
_collection_cache: Dict[str, str] = {}


def _get_graph_name() -> str:
"""Extract and validate X-Graph-Name header for per-request graph isolation.

Returns:
Graph name from header if valid, otherwise environment default.
Returns default if called outside of request context (e.g., background tasks).

Raises:
400: If header format is invalid (regex or length check fails)
403: If graph name is not in ALLOWED_GRAPHS whitelist
"""
# Return default for background tasks and non-HTTP contexts
if not has_request_context():
return GRAPH_NAME

value = request.headers.get("X-Graph-Name", "").strip()
if value:
if not re.match(r'^[a-zA-Z0-9_-]+$', value) or len(value) > 64:
abort(400, description=f"Invalid X-Graph-Name: {value}")
allowed = os.getenv("ALLOWED_GRAPHS", "")
if allowed and value not in [g.strip() for g in allowed.split(",")]:
abort(403, description=f"Graph '{value}' not allowed")
return value
return GRAPH_NAME


def _get_collection_name() -> str:
"""Extract and validate X-Collection-Name header for per-request collection isolation.

Returns:
Collection name from header if valid, otherwise environment default.
Returns default if called outside of request context (e.g., background tasks).

Raises:
400: If header format is invalid (regex or length check fails)
403: If collection name is not in ALLOWED_COLLECTIONS whitelist
"""
# Return default for background tasks and non-HTTP contexts
if not has_request_context():
return COLLECTION_NAME

value = request.headers.get("X-Collection-Name", "").strip()
if value:
if not re.match(r'^[a-zA-Z0-9_-]+$', value) or len(value) > 64:
abort(400, description=f"Invalid X-Collection-Name: {value}")
allowed = os.getenv("ALLOWED_COLLECTIONS", "")
if allowed and value not in [c.strip() for c in allowed.split(",")]:
abort(403, description=f"Collection '{value}' not allowed")
return value
return COLLECTION_NAME


def _extract_api_token() -> Optional[str]:
if not API_TOKEN:
Expand Down Expand Up @@ -1380,9 +1436,39 @@ def _ensure_qdrant_collection() -> None:
state.qdrant = None


def get_memory_graph() -> Any:
def get_memory_graph(graph_name: Optional[str] = None) -> Any:
"""Get FalkorDB graph instance with optional per-request isolation.

Args:
graph_name: Optional graph name override. If None, tries request headers,
then falls back to environment default.

Returns:
Graph instance for the specified graph name.
"""
init_falkordb()
return state.memory_graph

# If no graph_name provided and no FalkorDB connection, return default graph for backward compatibility
if graph_name is None and state.falkordb is None:
return state.memory_graph

if state.falkordb is None:
return None

# Determine graph name: explicit parameter > request header > environment default
if graph_name is None:
# Only check headers if we're in a request context
try:
graph_name = _get_graph_name()
except RuntimeError:
# Not in request context, use default
graph_name = GRAPH_NAME

# Cache graph instances per name
if graph_name not in _graph_cache:
_graph_cache[graph_name] = state.falkordb.select_graph(graph_name)

return _graph_cache[graph_name]


def get_qdrant_client() -> Optional[QdrantClient]:
Expand Down Expand Up @@ -2395,7 +2481,11 @@ def store_memory() -> Any:
except ValueError as exc:
abort(400, description=str(exc))

graph = get_memory_graph()
# Extract per-request isolation headers
graph_name = _get_graph_name()
collection_name = _get_collection_name()

graph = get_memory_graph(graph_name)
if graph is None:
abort(503, description="FalkorDB is unavailable")

Expand Down Expand Up @@ -2479,7 +2569,7 @@ def store_memory() -> Any:
if qdrant_client is not None:
try:
qdrant_client.upsert(
collection_name=COLLECTION_NAME,
collection_name=collection_name,
points=[
PointStruct(
id=memory_id,
Expand Down Expand Up @@ -2552,7 +2642,11 @@ def update_memory(memory_id: str) -> Any:
if not isinstance(payload, dict):
abort(400, description="JSON body is required")

graph = get_memory_graph()
# Extract per-request isolation headers
graph_name = _get_graph_name()
collection_name = _get_collection_name()

graph = get_memory_graph(graph_name)
if graph is None:
abort(503, description="FalkorDB is unavailable")

Expand Down Expand Up @@ -2641,7 +2735,7 @@ def update_memory(memory_id: str) -> Any:
else:
try:
existing = qdrant_client.retrieve(
collection_name=COLLECTION_NAME,
collection_name=collection_name,
ids=[memory_id],
with_vectors=True,
)
Expand All @@ -2665,7 +2759,7 @@ def update_memory(memory_id: str) -> Any:
"metadata": metadata,
}
qdrant_client.upsert(
collection_name=COLLECTION_NAME,
collection_name=collection_name,
points=[PointStruct(id=memory_id, vector=vector, payload=payload)],
)

Expand All @@ -2674,7 +2768,11 @@ def update_memory(memory_id: str) -> Any:

@memory_bp.route("/memory/<memory_id>", methods=["DELETE"])
def delete_memory(memory_id: str) -> Any:
graph = get_memory_graph()
# Extract per-request isolation headers
graph_name = _get_graph_name()
collection_name = _get_collection_name()

graph = get_memory_graph(graph_name)
if graph is None:
abort(503, description="FalkorDB is unavailable")

Expand All @@ -2691,7 +2789,7 @@ def delete_memory(memory_id: str) -> Any:
selector = qdrant_models.PointIdsList(points=[memory_id])
else:
selector = {"points": [memory_id]}
qdrant_client.delete(collection_name=COLLECTION_NAME, points_selector=selector)
qdrant_client.delete(collection_name=collection_name, points_selector=selector)
except Exception:
logger.exception("Failed to delete vector for memory %s", memory_id)

Expand Down Expand Up @@ -2783,8 +2881,12 @@ def recall_memories() -> Any:

tag_filters = _normalize_tag_list(tags_param)

# Extract per-request isolation headers
graph_name = _get_graph_name()
collection_name = _get_collection_name()

seen_ids: set[str] = set()
graph = get_memory_graph()
graph = get_memory_graph(graph_name)
qdrant_client = get_qdrant_client()

results: List[Dict[str, Any]] = []
Expand Down
Loading
Loading