diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 82066e732..e24ee84bd 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -576,7 +576,35 @@ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityE episodes = record['episodes'] if provider == GraphProvider.KUZU: attributes = json.loads(record['attributes']) if record['attributes'] else {} + elif provider == GraphProvider.NEO4J: + # Neo4j: Try new JSON format first, fall back to old spread format + raw_attrs = record.get('attributes', '') + if raw_attrs and isinstance(raw_attrs, str): + # New format: JSON string in e.attributes + attributes = json.loads(raw_attrs) + else: + # Old format: attributes spread as individual properties + all_props = record.get('all_properties', {}) + if all_props: + attributes = dict(all_props) + # Remove known system fields + attributes.pop('uuid', None) + attributes.pop('source_node_uuid', None) + attributes.pop('target_node_uuid', None) + attributes.pop('fact', None) + attributes.pop('fact_embedding', None) + attributes.pop('name', None) + attributes.pop('group_id', None) + attributes.pop('episodes', None) + attributes.pop('created_at', None) + attributes.pop('expired_at', None) + attributes.pop('valid_at', None) + attributes.pop('invalid_at', None) + attributes.pop('attributes', None) # Remove the empty attributes field + else: + attributes = {} else: + # FalkorDB, Neptune: Original behavior attributes = record['attributes'] attributes.pop('uuid', None) attributes.pop('source_node_uuid', None) diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index c57f36942..f3a52eab6 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -203,6 +203,23 @@ def get_entity_edge_return_query(provider: GraphProvider) -> str: properties(e) AS attributes """ + if provider == GraphProvider.NEO4J: + return """ + e.uuid AS uuid, + n.uuid 
AS source_node_uuid, + m.uuid AS target_node_uuid, + e.group_id AS group_id, + e.created_at AS created_at, + e.name AS name, + e.fact AS fact, + e.episodes AS episodes, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at, + COALESCE(e.attributes, '') AS attributes, + properties(e) AS all_properties + """ + return """ e.uuid AS uuid, n.uuid AS source_node_uuid, diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 34e3d8b8b..0791cedda 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -265,6 +265,18 @@ def get_entity_node_return_query(provider: GraphProvider) -> str: n.summary AS summary, n.attributes AS attributes """ + + if provider == GraphProvider.NEO4J: + return """ + n.uuid AS uuid, + n.name AS name, + n.group_id AS group_id, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + COALESCE(n.attributes, '') AS attributes, + properties(n) AS all_properties + """ return """ n.uuid AS uuid, diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index cd3d003d1..ace3eb672 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -754,7 +754,30 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode: def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode: if provider == GraphProvider.KUZU: attributes = json.loads(record['attributes']) if record['attributes'] else {} + elif provider == GraphProvider.NEO4J: + # Neo4j: Try new JSON format first, fall back to old spread format + raw_attrs = record.get('attributes', '') + if raw_attrs and isinstance(raw_attrs, str): + # New format: JSON string in n.attributes + attributes = json.loads(raw_attrs) + else: + # Old format: attributes spread as individual properties + all_props = record.get('all_properties', {}) + if all_props: + attributes = dict(all_props) + # Remove known system fields + 
attributes.pop('uuid', None) + attributes.pop('name', None) + attributes.pop('group_id', None) + attributes.pop('name_embedding', None) + attributes.pop('summary', None) + attributes.pop('created_at', None) + attributes.pop('labels', None) + attributes.pop('attributes', None) # Remove the empty attributes field + else: + attributes = {} else: + # FalkorDB, Neptune: Original behavior attributes = record['attributes'] attributes.pop('uuid', None) attributes.pop('name', None) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index eb3008e71..db3a9135f 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -181,7 +181,12 @@ async def add_nodes_and_edges_bulk_tx( if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} entity_data['attributes'] = json.dumps(attributes) + elif driver.provider == GraphProvider.NEO4J: + # Neo4j: Serialize attributes to JSON string to support nested structures + attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} + entity_data['attributes'] = json.dumps(attributes) if attributes else '{}' else: + # FalkorDB, Neptune: Keep original behavior (spread attributes) entity_data.update(node.attributes or {}) nodes.append(entity_data) @@ -208,7 +213,12 @@ async def add_nodes_and_edges_bulk_tx( if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {} edge_data['attributes'] = json.dumps(attributes) + elif driver.provider == GraphProvider.NEO4J: + # Neo4j: Serialize attributes to JSON string to support nested structures + attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {} + edge_data['attributes'] = json.dumps(attributes) if attributes else '{}' else: + # FalkorDB, Neptune: Keep original behavior (spread attributes) edge_data.update(edge.attributes or {}) 
# Patch residue preserved from the collapsed diff (not part of this module):
#   tail of the graphiti_core/utils/bulk_utils.py hunk: edges.append(edge_data)
#   diff --git a/tests/test_neo4j_nested_attributes_int.py b/tests/test_neo4j_nested_attributes_int.py
#   new file mode 100644, index 000000000..7de4ae6cd
"""Integration test for Neo4j nested attributes serialization.

Tests that entities and edges with complex nested attributes (Maps of Lists, Lists of Maps)
are properly serialized to JSON strings for Neo4j storage.

This test addresses a bug where Neo4j would reject entity/edge attributes containing
nested structures with the error:
Neo.ClientError.Statement.TypeError - Property values can only be of primitive types
or arrays thereof.
"""

# NOTE: ``timezone.utc`` is used instead of ``datetime.UTC`` because the ``UTC``
# alias only exists on Python 3.11+; ``timezone.utc`` is the same object and
# keeps this module importable on older interpreters.
from datetime import datetime, timezone

import pytest

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode

_NEO4J_ONLY_REASON = "This test is specific to Neo4j nested attributes serialization"


def _skip_unless_neo4j(graph_driver):
    """Skip the calling test when ``graph_driver`` is not backed by Neo4j.

    ``pytest.skip`` raises, so the calling test stops immediately when the
    provider does not match.
    """
    if graph_driver.provider != GraphProvider.NEO4J:
        pytest.skip(_NEO4J_ONLY_REASON)


@pytest.mark.integration
async def test_nested_entity_attributes(graph_driver, embedder):
    """Test that entities with nested attributes are stored and retrieved correctly in Neo4j."""
    _skip_unless_neo4j(graph_driver)

    # Create entity with nested attributes (Maps of Lists, Lists of Maps)
    entity = EntityNode(
        uuid="test-entity-nested-attrs-001",
        name="Test Entity with Nested Attributes",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(timezone.utc),
        summary="Test entity for nested attributes",
        attributes={
            # Simple array of primitives - should work
            "discovered_resources": ["resource1", "resource2", "resource3"],
            # Nested map with list values - the problematic case
            "metadata": {
                "analysis": ["analysis_item1", "analysis_item2"],
                "nested_map": {"key1": "value1", "key2": "value2"},
            },
            # Map with complex nested structure
            "activity_log": {
                "initiated_actions": ["action1", "action2"],
                "completed_tasks": {
                    "task_list": ["task1", "task2"],
                    "priority": "high",
                },
            },
            # Simple primitive attributes
            "count": 42,
            "status": "active",
        },
    )

    await entity.generate_name_embedding(embedder)

    # Save entity - Neo4j previously rejected nested structures at this point
    await entity.save(graph_driver)

    # Retrieve entity and verify attributes are preserved
    retrieved = await EntityNode.get_by_uuid(graph_driver, entity.uuid)

    assert retrieved is not None, "Entity should be retrievable"
    assert retrieved.uuid == entity.uuid
    assert retrieved.name == entity.name

    # Verify nested attributes are correctly preserved
    assert retrieved.attributes == entity.attributes, "Attributes should be preserved exactly"
    assert retrieved.attributes["discovered_resources"] == ["resource1", "resource2", "resource3"]
    assert retrieved.attributes["metadata"]["analysis"] == ["analysis_item1", "analysis_item2"]
    assert retrieved.attributes["metadata"]["nested_map"]["key1"] == "value1"
    assert retrieved.attributes["activity_log"]["completed_tasks"]["task_list"] == ["task1", "task2"]
    assert retrieved.attributes["count"] == 42
    assert retrieved.attributes["status"] == "active"


@pytest.mark.integration
async def test_nested_edge_attributes(graph_driver, embedder):
    """Test that edges with nested attributes are stored and retrieved correctly in Neo4j."""
    _skip_unless_neo4j(graph_driver)

    # First create two entity nodes to connect
    source_entity = EntityNode(
        uuid="test-source-entity-001",
        name="Source Entity",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(timezone.utc),
        summary="Source entity for edge test",
        attributes={},
    )

    target_entity = EntityNode(
        uuid="test-target-entity-001",
        name="Target Entity",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(timezone.utc),
        summary="Target entity for edge test",
        attributes={},
    )

    await source_entity.generate_name_embedding(embedder)
    await target_entity.generate_name_embedding(embedder)
    await source_entity.save(graph_driver)
    await target_entity.save(graph_driver)

    # Create edge with nested attributes
    edge = EntityEdge(
        uuid="test-edge-nested-attrs-001",
        source_node_uuid=source_entity.uuid,
        target_node_uuid=target_entity.uuid,
        name="RELATES_TO",
        fact="Source entity relates to target entity with complex metadata",
        group_id="test-group-nested",
        episodes=["episode1", "episode2"],
        created_at=datetime.now(timezone.utc),
        valid_at=datetime.now(timezone.utc),
        attributes={
            # Nested map with list values
            "relationship_metadata": {
                "interaction_types": ["collaboration", "communication"],
                "details": {
                    "frequency": "daily",
                    "confidence": 0.95,
                },
            },
            # Map with complex structure
            "historical_data": {
                "events": ["event1", "event2", "event3"],
                "analysis": {
                    "trends": ["increasing", "positive"],
                    "factors": {"external": True, "internal": False},
                },
            },
            # Simple attributes
            "weight": 0.85,
            "verified": True,
        },
    )

    await edge.generate_embedding(embedder)

    # Save edge - Neo4j previously rejected nested structures at this point
    await edge.save(graph_driver)

    # Retrieve edge and verify attributes are preserved
    retrieved = await EntityEdge.get_by_uuid(graph_driver, edge.uuid)

    assert retrieved is not None, "Edge should be retrievable"
    assert retrieved.uuid == edge.uuid
    assert retrieved.fact == edge.fact

    # Verify nested attributes are correctly preserved
    assert retrieved.attributes == edge.attributes, "Edge attributes should be preserved exactly"
    assert retrieved.attributes["relationship_metadata"]["interaction_types"] == ["collaboration", "communication"]
    assert retrieved.attributes["relationship_metadata"]["details"]["frequency"] == "daily"
    assert retrieved.attributes["historical_data"]["events"] == ["event1", "event2", "event3"]
    assert retrieved.attributes["historical_data"]["analysis"]["factors"]["external"] is True
    assert retrieved.attributes["weight"] == 0.85
    assert retrieved.attributes["verified"] is True


@pytest.mark.integration
async def test_empty_and_none_attributes(graph_driver, embedder):
    """Test that empty and None attributes are handled correctly."""
    _skip_unless_neo4j(graph_driver)

    # Entity with empty attributes
    entity_empty = EntityNode(
        uuid="test-entity-empty-attrs-001",
        name="Entity with Empty Attributes",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(timezone.utc),
        summary="Test entity with empty attributes",
        attributes={},
    )

    await entity_empty.generate_name_embedding(embedder)
    await entity_empty.save(graph_driver)

    retrieved_empty = await EntityNode.get_by_uuid(graph_driver, entity_empty.uuid)
    assert retrieved_empty is not None
    assert retrieved_empty.attributes == {}

    # Entity with None-valued attributes
    entity_none = EntityNode(
        uuid="test-entity-none-attrs-001",
        name="Entity with None Attributes",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(timezone.utc),
        summary="Test entity with None attributes",
        attributes={"key1": None, "key2": "value2"},
    )

    await entity_none.generate_name_embedding(embedder)
    await entity_none.save(graph_driver)

    retrieved_none = await EntityNode.get_by_uuid(graph_driver, entity_none.uuid)
    assert retrieved_none is not None
    assert retrieved_none.attributes["key1"] is None
    assert retrieved_none.attributes["key2"] == "value2"