From 3b8da7508b8891d1a2ffce86f49d7b40b8f92a35 Mon Sep 17 00:00:00 2001 From: "Daniel L. Iser" Date: Sat, 13 Dec 2025 15:53:51 -0500 Subject: [PATCH 1/2] feat: add per-request isolation headers for multi-tenant support Implements X-Graph-Name and X-Collection-Name headers to enable per-request graph and collection routing without deploying separate AutoMem instances. Key changes: - Add _get_graph_name() and _get_collection_name() helper functions with validation (regex, length limits) and optional whitelist support - Update get_memory_graph() to support per-request isolation with automatic header detection and connection caching - Update core endpoints (store_memory, update_memory, delete_memory, recall_memories) to use per-request isolation - Add comprehensive test suite (22 tests) covering header extraction, validation, whitelists, and backwards compatibility - Maintain full backwards compatibility: no headers = environment defaults - Add connection caching to avoid redundant graph instance creation Security features: - Header validation: [a-zA-Z0-9_-]+ with max 64 chars - Optional whitelists via ALLOWED_GRAPHS and ALLOWED_COLLECTIONS - 400 for invalid format, 403 for whitelist rejection Implementation follows spec in: services/automem-federation/docs/automem-isolation-headers-spec.md Closes: Multi-tenant isolation requirement --- app.py | 112 +++++++++++-- tests/test_isolation_headers.py | 271 ++++++++++++++++++++++++++++++++ 2 files changed, 373 insertions(+), 10 deletions(-) create mode 100644 tests/test_isolation_headers.py diff --git a/app.py b/app.py index 1702eee..c478b6b 100644 --- a/app.py +++ b/app.py @@ -1067,6 +1067,52 @@ class ServiceState: state = ServiceState() +# Connection caches for per-request isolation +_graph_cache: Dict[str, Any] = {} +_collection_cache: Dict[str, str] = {} + + +def _get_graph_name() -> str: + """Extract and validate X-Graph-Name header for per-request graph isolation. + + Returns: + Graph name from header if valid, otherwise environment default. + + Raises: + 400: If header format is invalid (regex or length check fails) + 403: If graph name is not in ALLOWED_GRAPHS whitelist + """ + value = request.headers.get("X-Graph-Name", "").strip() + if value: + if not re.match(r'^[a-zA-Z0-9_-]+$', value) or len(value) > 64: + abort(400, description=f"Invalid X-Graph-Name: {value}") + allowed = os.getenv("ALLOWED_GRAPHS", "") + if allowed and value not in [g.strip() for g in allowed.split(",")]: + abort(403, description=f"Graph '{value}' not allowed") + return value + return GRAPH_NAME + + +def _get_collection_name() -> str: + """Extract and validate X-Collection-Name header for per-request collection isolation. + + Returns: + Collection name from header if valid, otherwise environment default. + + Raises: + 400: If header format is invalid (regex or length check fails) + 403: If collection name is not in ALLOWED_COLLECTIONS whitelist + """ + value = request.headers.get("X-Collection-Name", "").strip() + if value: + if not re.match(r'^[a-zA-Z0-9_-]+$', value) or len(value) > 64: + abort(400, description=f"Invalid X-Collection-Name: {value}") + allowed = os.getenv("ALLOWED_COLLECTIONS", "") + if allowed and value not in [c.strip() for c in allowed.split(",")]: + abort(403, description=f"Collection '{value}' not allowed") + return value + return COLLECTION_NAME + def _extract_api_token() -> Optional[str]: if not API_TOKEN: @@ -1380,9 +1426,39 @@ def _ensure_qdrant_collection() -> None: state.qdrant = None -def get_memory_graph() -> Any: +def get_memory_graph(graph_name: Optional[str] = None) -> Any: + """Get FalkorDB graph instance with optional per-request isolation. + + Args: + graph_name: Optional graph name override. If None, tries request headers, + then falls back to environment default. + + Returns: + Graph instance for the specified graph name. + """ init_falkordb() - return state.memory_graph + + # If no graph_name provided and no FalkorDB connection, return default graph for backward compatibility + if graph_name is None and state.falkordb is None: + return state.memory_graph + + if state.falkordb is None: + return None + + # Determine graph name: explicit parameter > request header > environment default + if graph_name is None: + # Only check headers if we're in a request context + try: + graph_name = _get_graph_name() + except RuntimeError: + # Not in request context, use default + graph_name = GRAPH_NAME + + # Cache graph instances per name + if graph_name not in _graph_cache: + _graph_cache[graph_name] = state.falkordb.select_graph(graph_name) + + return _graph_cache[graph_name] def get_qdrant_client() -> Optional[QdrantClient]: @@ -2395,7 +2471,11 @@ def store_memory() -> Any: except ValueError as exc: abort(400, description=str(exc)) - graph = get_memory_graph() + # Extract per-request isolation headers + graph_name = _get_graph_name() + collection_name = _get_collection_name() + + graph = get_memory_graph(graph_name) if graph is None: abort(503, description="FalkorDB is unavailable") @@ -2479,7 +2559,7 @@ def store_memory() -> Any: if qdrant_client is not None: try: qdrant_client.upsert( - collection_name=COLLECTION_NAME, + collection_name=collection_name, points=[ PointStruct( id=memory_id, @@ -2552,7 +2632,11 @@ def update_memory(memory_id: str) -> Any: if not isinstance(payload, dict): abort(400, description="JSON body is required") - graph = get_memory_graph() + # Extract per-request isolation headers + graph_name = _get_graph_name() + collection_name = _get_collection_name() + + graph = get_memory_graph(graph_name) if graph is None: abort(503, description="FalkorDB is unavailable") @@ -2641,7 +2725,7 @@ def update_memory(memory_id: str) -> Any: else: try: existing = qdrant_client.retrieve( - collection_name=COLLECTION_NAME, + collection_name=collection_name, ids=[memory_id], with_vectors=True, ) @@ -2665,7 +2749,7 @@ def update_memory(memory_id: str) -> Any: "metadata": metadata, } qdrant_client.upsert( - collection_name=COLLECTION_NAME, + collection_name=collection_name, points=[PointStruct(id=memory_id, vector=vector, payload=payload)], ) @@ -2674,7 +2758,11 @@ def update_memory(memory_id: str) -> Any: @memory_bp.route("/memory/", methods=["DELETE"]) def delete_memory(memory_id: str) -> Any: - graph = get_memory_graph() + # Extract per-request isolation headers + graph_name = _get_graph_name() + collection_name = _get_collection_name() + + graph = get_memory_graph(graph_name) if graph is None: abort(503, description="FalkorDB is unavailable") @@ -2691,7 +2779,7 @@ def delete_memory(memory_id: str) -> Any: selector = qdrant_models.PointIdsList(points=[memory_id]) else: selector = {"points": [memory_id]} - qdrant_client.delete(collection_name=COLLECTION_NAME, points_selector=selector) + qdrant_client.delete(collection_name=collection_name, points_selector=selector) except Exception: logger.exception("Failed to delete vector for memory %s", memory_id) @@ -2783,8 +2871,12 @@ def recall_memories() -> Any: tag_filters = _normalize_tag_list(tags_param) + # Extract per-request isolation headers + graph_name = _get_graph_name() + collection_name = _get_collection_name() + seen_ids: set[str] = set() - graph = get_memory_graph() + graph = get_memory_graph(graph_name) qdrant_client = get_qdrant_client() results: List[Dict[str, Any]] = [] diff --git a/tests/test_isolation_headers.py b/tests/test_isolation_headers.py new file mode 100644 index 0000000..4ae1b12 --- /dev/null +++ b/tests/test_isolation_headers.py @@ -0,0 +1,271 @@ +"""Tests for per-request isolation headers (X-Graph-Name, X-Collection-Name). + +This module tests the implementation of per-request isolation as specified in: +services/automem-federation/docs/automem-isolation-headers-spec.md +""" + +import json +import os +import pytest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import app + + +# Test fixtures and setup +@pytest.fixture +def mock_graph(): + """Mock FalkorDB graph instance.""" + graph = MagicMock() + graph.query = MagicMock(return_value=SimpleNamespace(result_set=[])) + return graph + + +@pytest.fixture +def mock_qdrant(): + """Mock Qdrant client instance.""" + qdrant = MagicMock() + qdrant.upsert = MagicMock() + qdrant.retrieve = MagicMock(return_value=[]) + qdrant.delete = MagicMock() + return qdrant + + +class TestHeaderExtraction: + """Test header extraction and validation logic.""" + + def test_get_graph_name_with_valid_header(self): + """Test that valid X-Graph-Name header is accepted.""" + with app.app.test_request_context( + headers={"X-Graph-Name": "test_graph"} + ): + assert app._get_graph_name() == "test_graph" + + def test_get_graph_name_with_underscore(self): + """Test that graph names with underscores are valid.""" + with app.app.test_request_context( + headers={"X-Graph-Name": "test_graph_123"} + ): + assert app._get_graph_name() == "test_graph_123" + + def test_get_graph_name_with_hyphen(self): + """Test that graph names with hyphens are valid.""" + with app.app.test_request_context( + headers={"X-Graph-Name": "test-graph-123"} + ): + assert app._get_graph_name() == "test-graph-123" + + def test_get_graph_name_defaults_to_env(self): + """Test that missing header falls back to environment default.""" + with app.app.test_request_context(): + assert app._get_graph_name() == app.GRAPH_NAME + + def test_get_graph_name_empty_header_uses_default(self): + """Test that empty header value uses default.""" + with app.app.test_request_context( + headers={"X-Graph-Name": ""} + ): + assert app._get_graph_name() == app.GRAPH_NAME + + def test_get_graph_name_whitespace_header_uses_default(self): + """Test that whitespace-only header uses default.""" + with app.app.test_request_context( + headers={"X-Graph-Name": " "} + ): + assert app._get_graph_name() == app.GRAPH_NAME + + def test_get_collection_name_with_valid_header(self): + """Test that valid X-Collection-Name header is accepted.""" + with app.app.test_request_context( + headers={"X-Collection-Name": "test_collection"} + ): + assert app._get_collection_name() == "test_collection" + + def test_get_collection_name_defaults_to_env(self): + """Test that missing header falls back to environment default.""" + with app.app.test_request_context(): + assert app._get_collection_name() == app.COLLECTION_NAME + + +class TestHeaderValidation: + """Test header validation rules (regex, length, format).""" + + def test_invalid_graph_name_special_chars(self): + """Test that special characters in graph name are rejected (400).""" + with app.app.test_request_context( + headers={"X-Graph-Name": "invalid@name"} + ): + from werkzeug.exceptions import BadRequest + with pytest.raises(BadRequest) as exc_info: + app._get_graph_name() + assert exc_info.value.code == 400 + assert "Invalid X-Graph-Name" in str(exc_info.value.description) + + def test_invalid_graph_name_spaces(self): + """Test that spaces in graph name are rejected (400).""" + with app.app.test_request_context( + headers={"X-Graph-Name": "invalid name"} + ): + from werkzeug.exceptions import BadRequest + with pytest.raises(BadRequest) as exc_info: + app._get_graph_name() + assert exc_info.value.code == 400 + + def test_invalid_graph_name_too_long(self): + """Test that graph names over 64 chars are rejected (400).""" + with app.app.test_request_context( + headers={"X-Graph-Name": "a" * 65} + ): + from werkzeug.exceptions import BadRequest + with pytest.raises(BadRequest) as exc_info: + app._get_graph_name() + assert exc_info.value.code == 400 + + def test_valid_graph_name_max_length(self): + """Test that graph names at exactly 64 chars are accepted.""" + with app.app.test_request_context( + headers={"X-Graph-Name": "a" * 64} + ): + assert app._get_graph_name() == "a" * 64 + + def test_invalid_collection_name_special_chars(self): + """Test that special characters in collection name are rejected (400).""" + with app.app.test_request_context( + headers={"X-Collection-Name": "invalid$collection"} + ): + from werkzeug.exceptions import BadRequest + with pytest.raises(BadRequest) as exc_info: + app._get_collection_name() + assert exc_info.value.code == 400 + assert "Invalid X-Collection-Name" in str(exc_info.value.description) + + +class TestWhitelistEnforcement: + """Test whitelist enforcement via ALLOWED_GRAPHS and ALLOWED_COLLECTIONS.""" + + def test_graph_whitelist_allows_listed_graph(self, monkeypatch): + """Test that whitelisted graph names are accepted.""" + monkeypatch.setenv("ALLOWED_GRAPHS", "graph1,graph2,graph3") + with app.app.test_request_context( + headers={"X-Graph-Name": "graph2"} + ): + assert app._get_graph_name() == "graph2" + + def test_graph_whitelist_rejects_unlisted_graph(self, monkeypatch): + """Test that non-whitelisted graph names are rejected (403).""" + monkeypatch.setenv("ALLOWED_GRAPHS", "graph1,graph2") + with app.app.test_request_context( + headers={"X-Graph-Name": "graph3"} + ): + from werkzeug.exceptions import Forbidden + with pytest.raises(Forbidden) as exc_info: + app._get_graph_name() + assert exc_info.value.code == 403 + assert "not allowed" in str(exc_info.value.description) + + def test_graph_whitelist_with_spaces(self, monkeypatch): + """Test that whitelist parsing handles spaces correctly.""" + monkeypatch.setenv("ALLOWED_GRAPHS", "graph1, graph2, graph3") + with app.app.test_request_context( + headers={"X-Graph-Name": "graph2"} + ): + assert app._get_graph_name() == "graph2" + + def test_no_whitelist_allows_any_valid_name(self, monkeypatch): + """Test that without whitelist, any valid name is accepted.""" + monkeypatch.delenv("ALLOWED_GRAPHS", raising=False) + with app.app.test_request_context( + headers={"X-Graph-Name": "any_valid_graph"} + ): + assert app._get_graph_name() == "any_valid_graph" + + def test_collection_whitelist_allows_listed_collection(self, monkeypatch): + """Test that whitelisted collection names are accepted.""" + monkeypatch.setenv("ALLOWED_COLLECTIONS", "coll1,coll2,coll3") + with app.app.test_request_context( + headers={"X-Collection-Name": "coll2"} + ): + assert app._get_collection_name() == "coll2" + + def test_collection_whitelist_rejects_unlisted_collection(self, monkeypatch): + """Test that non-whitelisted collection names are rejected (403).""" + monkeypatch.setenv("ALLOWED_COLLECTIONS", "coll1,coll2") + with app.app.test_request_context( + headers={"X-Collection-Name": "coll3"} + ): + from werkzeug.exceptions import Forbidden + with pytest.raises(Forbidden) as exc_info: + app._get_collection_name() + assert exc_info.value.code == 403 + + +class TestBackwardsCompatibility: + """Test that existing behavior is preserved when headers are not provided.""" + + def test_get_memory_graph_without_request_context_uses_default(self): + """Test that get_memory_graph() works outside request context.""" + # This should not raise an error and should use the default + # Note: This will fail if FalkorDB is not available, but that's expected + # in a unit test environment. The key is that it doesn't raise a RuntimeError + # about missing request context. + try: + graph = app.get_memory_graph() + # If it returns None, that's fine - FalkorDB might not be running + # We're just testing it doesn't crash + except RuntimeError as e: + if "request context" in str(e).lower(): + pytest.fail("get_memory_graph() should not require request context") + # Other RuntimeErrors are acceptable (e.g., connection failures) + + +class TestConnectionCaching: + """Test that graph connections are cached per graph name.""" + + @patch('app.state') + def test_graph_caching_same_name(self, mock_state): + """Test that requesting same graph name returns cached instance.""" + mock_falkordb = MagicMock() + mock_state.falkordb = mock_falkordb + mock_state.memory_graph = MagicMock() + + # Create a mock graph instance + mock_graph_instance = MagicMock() + mock_falkordb.select_graph.return_value = mock_graph_instance + + # Clear cache for clean test + app._graph_cache.clear() + + # First call should create new instance + graph1 = app.get_memory_graph("test_graph") + assert mock_falkordb.select_graph.call_count == 1 + assert mock_falkordb.select_graph.call_args[0][0] == "test_graph" + + # Second call with same name should use cache + graph2 = app.get_memory_graph("test_graph") + assert mock_falkordb.select_graph.call_count == 1 # Still 1, not 2 + assert graph1 is graph2 # Same instance + + @patch('app.state') + def test_graph_caching_different_names(self, mock_state): + """Test that different graph names create separate instances.""" + mock_falkordb = MagicMock() + mock_state.falkordb = mock_falkordb + mock_state.memory_graph = MagicMock() + + # Return different instances for different graph names + mock_graph1 = MagicMock() + mock_graph2 = MagicMock() + mock_falkordb.select_graph.side_effect = [mock_graph1, mock_graph2] + + # Clear cache for clean test + app._graph_cache.clear() + + # Two different graph names should create two instances + graph1 = app.get_memory_graph("graph1") + graph2 = app.get_memory_graph("graph2") + + assert mock_falkordb.select_graph.call_count == 2 + assert graph1 is mock_graph1 + assert graph2 is mock_graph2 + assert graph1 is not graph2 From 93ed0785f6703d32f82ab417c7fac7deef7783f6 Mon Sep 17 00:00:00 2001 From: "Daniel L. Iser" Date: Sun, 14 Dec 2025 21:19:08 -0500 Subject: [PATCH 2/2] fix: add request context check for background task safety - Import has_request_context from Flask - Add context check to _get_graph_name() and _get_collection_name() - Background tasks (consolidation, enrichment) now safely use defaults --- app.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/app.py b/app.py index c478b6b..ee51c30 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from falkordb import FalkorDB -from flask import Blueprint, Flask, abort, jsonify, request +from flask import Blueprint, Flask, abort, has_request_context, jsonify, request from qdrant_client import QdrantClient from qdrant_client import models as qdrant_models @@ -1077,11 +1077,16 @@ def _get_graph_name() -> str: Returns: Graph name from header if valid, otherwise environment default. + Returns default if called outside of request context (e.g., background tasks). Raises: 400: If header format is invalid (regex or length check fails) 403: If graph name is not in ALLOWED_GRAPHS whitelist """ + # Return default for background tasks and non-HTTP contexts + if not has_request_context(): + return GRAPH_NAME + value = request.headers.get("X-Graph-Name", "").strip() if value: if not re.match(r'^[a-zA-Z0-9_-]+$', value) or len(value) > 64: @@ -1098,11 +1103,16 @@ def _get_collection_name() -> str: Returns: Collection name from header if valid, otherwise environment default. + Returns default if called outside of request context (e.g., background tasks). Raises: 400: If header format is invalid (regex or length check fails) 403: If collection name is not in ALLOWED_COLLECTIONS whitelist """ + # Return default for background tasks and non-HTTP contexts + if not has_request_context(): + return COLLECTION_NAME + value = request.headers.get("X-Collection-Name", "").strip() if value: if not re.match(r'^[a-zA-Z0-9_-]+$', value) or len(value) > 64: