Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class PropertyType(BaseModel):
]
description: str = ""
required: bool = False

# unique: bool = False
model_config = ConfigDict(
frozen=True,
)
Expand Down Expand Up @@ -161,6 +161,22 @@ def property_type_from_name(self, name: str) -> Optional[PropertyType]:
return None


class ConstraintType(BaseModel):
"""
Represents a constraint on a node in the graph.
"""

type: Literal[
"UNIQUENESS"
] # TODO: add other constraint types ["propertyExistence", "propertyType", "key"]
node_type: str
property_name: str

model_config = ConfigDict(
frozen=True,
)


class GraphSchema(DataModel):
"""This model represents the expected
node and relationship types in the graph.
Expand All @@ -177,6 +193,7 @@ class GraphSchema(DataModel):
node_types: Tuple[NodeType, ...]
relationship_types: Tuple[RelationshipType, ...] = tuple()
patterns: Tuple[Tuple[str, str, str], ...] = tuple()
constraints: Tuple[ConstraintType, ...] = tuple()

additional_node_types: bool = Field(
default_factory=default_additional_item("node_types")
Expand Down Expand Up @@ -239,6 +256,31 @@ def validate_additional_parameters(self) -> Self:
)
return self

@model_validator(mode="after")
def validate_constraints_against_node_types(self) -> Self:
if not self.constraints:
return self
for constraint in self.constraints:
if not constraint.property_name:
raise SchemaValidationError(
f"Constraint has no property name: {constraint}. Property name is required."
)
if constraint.node_type not in self._node_type_index:
raise SchemaValidationError(
f"Constraint references undefined node type: {constraint.node_type}"
)
# Check if property_name exists on the node type (only if additional_properties is False)
node_type = self._node_type_index[constraint.node_type]
if not node_type.additional_properties:
valid_property_names = {p.name for p in node_type.properties}
if constraint.property_name not in valid_property_names:
raise SchemaValidationError(
f"Constraint references undefined property '{constraint.property_name}' "
f"on node type '{constraint.node_type}'. "
f"Valid properties: {valid_property_names}"
)
return self

def node_type_from_label(self, label: str) -> Optional[NodeType]:
return self._node_type_index.get(label)

Expand Down Expand Up @@ -382,6 +424,7 @@ def create_schema_model(
node_types: Sequence[NodeType],
relationship_types: Optional[Sequence[RelationshipType]] = None,
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
constraints: Optional[Sequence[ConstraintType]] = None,
**kwargs: Any,
) -> GraphSchema:
"""
Expand All @@ -403,6 +446,7 @@ def create_schema_model(
node_types=node_types,
relationship_types=relationship_types or (),
patterns=patterns or (),
constraints=constraints or (),
**kwargs,
)
)
Expand All @@ -415,6 +459,7 @@ async def run(
node_types: Sequence[NodeType],
relationship_types: Optional[Sequence[RelationshipType]] = None,
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
constraints: Optional[Sequence[ConstraintType]] = None,
**kwargs: Any,
) -> GraphSchema:
"""
Expand All @@ -432,6 +477,7 @@ async def run(
node_types,
relationship_types,
patterns,
constraints,
**kwargs,
)

Expand Down Expand Up @@ -555,6 +601,61 @@ def _filter_relationships_without_labels(
relationship_types, "relationship type"
)

def _filter_invalid_constraints(
self, constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Filter out constraints that reference undefined node types, have no property name,
or reference a property that doesn't exist on the node type."""
if not constraints:
return []

if not node_types:
logging.info(
"Filtering out all constraints because no node types are defined. "
"Constraints reference node types that must be defined."
)
return []

# Build a mapping of node_type label -> set of property names
node_type_properties: Dict[str, set[str]] = {}
for node_type_dict in node_types:
label = node_type_dict.get("label")
if label:
properties = node_type_dict.get("properties", [])
property_names = {p.get("name") for p in properties if p.get("name")}
node_type_properties[label] = property_names

valid_node_labels = set(node_type_properties.keys())

filtered_constraints = []
for constraint in constraints:
# check if the property_name is provided
if not constraint.get("property_name"):
logging.info(
f"Filtering out constraint: {constraint}. "
f"Property name is not provided."
)
continue
# check if the node_type is valid
node_type = constraint.get("node_type")
if node_type not in valid_node_labels:
logging.info(
f"Filtering out constraint: {constraint}. "
f"Node type '{node_type}' is not valid. Valid node types: {valid_node_labels}"
)
continue
# check if the property_name exists on the node type
property_name = constraint.get("property_name")
if property_name not in node_type_properties.get(node_type, set()):
logging.info(
f"Filtering out constraint: {constraint}. "
f"Property '{property_name}' does not exist on node type '{node_type}'. "
f"Valid properties: {node_type_properties.get(node_type, set())}"
)
continue
filtered_constraints.append(constraint)
return filtered_constraints

def _clean_json_content(self, content: str) -> str:
content = content.strip()

Expand Down Expand Up @@ -624,6 +725,9 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get(
"patterns"
)
extracted_constraints: Optional[List[Dict[str, Any]]] = extracted_schema.get(
"constraints"
)

# Filter out nodes and relationships without labels
extracted_node_types = self._filter_nodes_without_labels(extracted_node_types)
Expand All @@ -638,11 +742,18 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
extracted_patterns, extracted_node_types, extracted_relationship_types
)

# Filter out invalid constraints
if extracted_constraints:
extracted_constraints = self._filter_invalid_constraints(
extracted_constraints, extracted_node_types
)

return GraphSchema.model_validate(
{
"node_types": extracted_node_types,
"relationship_types": extracted_relationship_types,
"patterns": extracted_patterns,
"constraints": extracted_constraints or [],
}
)

Expand Down
18 changes: 16 additions & 2 deletions src/neo4j_graphrag/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ class SchemaExtractionTemplate(PromptTemplate):
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
6. Do not create node types that aren't clearly mentioned in the text.
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.
8. UNIQUENESS CONSTRAINTS:
8.1 UNIQUENESS is optional; each node_type may or may not have exactly one uniqueness constraint.
8.2 Only use properties that seem to not have too many missing values in the sample.
8.3 Constraints reference node_types by label and specify which property is unique.
8.4 If a property appears in a uniqueness constraint it MUST also appear in the corresponding node_type as a property.


Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME.
Expand All @@ -233,18 +239,26 @@ class SchemaExtractionTemplate(PromptTemplate):
"type": "STRING"
}}
]
}},
}}
...
],
"relationship_types": [
{{
"label": "WORKS_FOR"
}},
}}
...
],
"patterns": [
["Person", "WORKS_FOR", "Company"],
...
],
"constraints": [
{{
"type": "UNIQUENESS",
"node_type": "Person",
"property_name": "name"
}}
...
]
}}

Expand Down
Loading