Skip to content

Commit 8724bca

Browse files
Merge pull request #864 from neo4j-contrib/feature/unique_variable_names
Added method to generate unique variable names for specific pathes.
2 parents 8d41f58 + 4c7be9b commit 8724bca

File tree

5 files changed

+132
-10
lines changed

5 files changed

+132
-10
lines changed

doc/source/traversal.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,26 @@ With both `traverse_relations` and `fetch_relations`, you can force the use of a
7777

7878
Person.nodes.fetch_relations('city__country', Optional('country')).all()
7979

80+
Unique variables
81+
----------------
82+
83+
If you want to use the same variable name for traversed nodes when chaining traversals, you can use the `unique_variables` method::
84+
85+
# This does not guarantee that coffees__species will traverse the same nodes as coffees
86+
# So coffees__species can traverse the Coffee node "Gold 3000"
87+
nodeset = (
88+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
89+
.filter(coffees__name="Nescafe")
90+
)
91+
92+
# This guarantees that coffees__species will traverse the same nodes as coffees
93+
# So when fetching species, it will only fetch those of the Coffee node "Nescafe"
94+
nodeset = (
95+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
96+
.filter(coffees__name="Nescafe")
97+
.unique_variables("coffees")
98+
)
99+
80100
Resolve results
81101
---------------
82102

neomodel/async_/match.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,11 @@ def create_relation_identifier(self) -> str:
503503
self._relation_identifier_count += 1
504504
return f"r{self._relation_identifier_count}"
505505

506-
def create_node_identifier(self, prefix: str) -> str:
507-
self._node_identifier_count += 1
508-
return f"{prefix}{self._node_identifier_count}"
506+
def create_node_identifier(self, prefix: str, path: str) -> str:
507+
if path not in self.node_set._unique_variables:
508+
self._node_identifier_count += 1
509+
return f"{prefix}{self._node_identifier_count}"
510+
return prefix
509511

510512
def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
511513
if "?" in source.order_by_elements:
@@ -620,8 +622,9 @@ def build_traversal_from_path(
620622
rhs_name = relation["alias"]
621623
else:
622624
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
623-
rhs_name = self.create_node_identifier(rhs_name)
625+
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
624626
rhs_ident = f"{rhs_name}:{rhs_label}"
627+
625628
if relation["include_in_return"] and not already_present:
626629
self._additional_return(rhs_name)
627630

@@ -1384,6 +1387,7 @@ def __init__(self, source: Any) -> None:
13841387
self._extra_results: list = []
13851388
self._subqueries: list[Subquery] = []
13861389
self._intermediate_transforms: list = []
1390+
self._unique_variables: list[str] = []
13871391

13881392
def __await__(self) -> Any:
13891393
return self.all().__await__() # type: ignore[attr-defined]
@@ -1555,6 +1559,11 @@ def _register_relation_to_fetch(
15551559
item["alias"] = alias
15561560
return item
15571561

1562+
def unique_variables(self, *pathes: tuple[str, ...]) -> "AsyncNodeSet":
1563+
"""Generate unique variable names for the given pathes."""
1564+
self._unique_variables = pathes
1565+
return self
1566+
15581567
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
15591568
"""Specify a set of relations to traverse and return."""
15601569
relations = []

neomodel/sync_/match.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,11 @@ def create_relation_identifier(self) -> str:
501501
self._relation_identifier_count += 1
502502
return f"r{self._relation_identifier_count}"
503503

504-
def create_node_identifier(self, prefix: str) -> str:
505-
self._node_identifier_count += 1
506-
return f"{prefix}{self._node_identifier_count}"
504+
def create_node_identifier(self, prefix: str, path: str) -> str:
505+
if path not in self.node_set._unique_variables:
506+
self._node_identifier_count += 1
507+
return f"{prefix}{self._node_identifier_count}"
508+
return prefix
507509

508510
def build_order_by(self, ident: str, source: "NodeSet") -> None:
509511
if "?" in source.order_by_elements:
@@ -618,8 +620,9 @@ def build_traversal_from_path(
618620
rhs_name = relation["alias"]
619621
else:
620622
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
621-
rhs_name = self.create_node_identifier(rhs_name)
623+
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
622624
rhs_ident = f"{rhs_name}:{rhs_label}"
625+
623626
if relation["include_in_return"] and not already_present:
624627
self._additional_return(rhs_name)
625628

@@ -1380,6 +1383,7 @@ def __init__(self, source: Any) -> None:
13801383
self._extra_results: list = []
13811384
self._subqueries: list[Subquery] = []
13821385
self._intermediate_transforms: list = []
1386+
self._unique_variables: list[str] = []
13831387

13841388
def __await__(self) -> Any:
13851389
return self.all().__await__() # type: ignore[attr-defined]
@@ -1551,6 +1555,11 @@ def _register_relation_to_fetch(
15511555
item["alias"] = alias
15521556
return item
15531557

1558+
def unique_variables(self, *pathes: tuple[str, ...]) -> "NodeSet":
1559+
"""Generate unique variable names for the given pathes."""
1560+
self._unique_variables = pathes
1561+
return self
1562+
15541563
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet":
15551564
"""Specify a set of relations to traverse and return."""
15561565
relations = []

test/async_/test_match_api.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from test._async_compat import mark_async_test
44

5-
import numpy as np
65
from pytest import raises, skip, warns
76

87
from neomodel import (
@@ -1161,6 +1160,49 @@ async def test_in_filter_with_array_property():
11611160
), "Species found by tags with not match tags given"
11621161

11631162

1163+
@mark_async_test
1164+
async def test_unique_variables():
1165+
arabica = await Species(name="Arabica").save()
1166+
nescafe = await Coffee(name="Nescafe", price=99).save()
1167+
gold3000 = await Coffee(name="Gold 3000", price=11).save()
1168+
supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
1169+
supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
1170+
supplier3 = await Supplier(name="Supplier 3", delivery_cost=20).save()
1171+
1172+
await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1173+
await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
1174+
await nescafe.species.connect(arabica)
1175+
await gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1176+
await gold3000.species.connect(arabica)
1177+
1178+
nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
1179+
coffees__name="Nescafe"
1180+
)
1181+
ast = await nodeset.query_cls(nodeset).build_ast()
1182+
query = ast.build_query()
1183+
assert "coffee_coffees1" in query
1184+
assert "coffee_coffees2" in query
1185+
results = await nodeset.all()
1186+
# This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000
1187+
# Gold 3000 is traversed because coffees__species redefines the coffees traversal
1188+
assert len(results) == 3
1189+
1190+
nodeset = (
1191+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
1192+
.filter(coffees__name="Nescafe")
1193+
.unique_variables("coffees")
1194+
)
1195+
ast = await nodeset.query_cls(nodeset).build_ast()
1196+
query = ast.build_query()
1197+
assert "coffee_coffees" in query
1198+
assert "coffee_coffees1" not in query
1199+
assert "coffee_coffees2" not in query
1200+
results = await nodeset.all()
1201+
# This will 2 because Gold 3000 is excluded this time
1202+
# As coffees will be reused in coffees__species
1203+
assert len(results) == 2
1204+
1205+
11641206
@mark_async_test
11651207
async def test_async_iterator():
11661208
n = 10

test/sync_/test_match_api.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from test._async_compat import mark_sync_test
44

5-
import numpy as np
65
from pytest import raises, skip, warns
76

87
from neomodel import (
@@ -1145,6 +1144,49 @@ def test_in_filter_with_array_property():
11451144
), "Species found by tags with not match tags given"
11461145

11471146

1147+
@mark_sync_test
1148+
def test_unique_variables():
1149+
arabica = Species(name="Arabica").save()
1150+
nescafe = Coffee(name="Nescafe", price=99).save()
1151+
gold3000 = Coffee(name="Gold 3000", price=11).save()
1152+
supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
1153+
supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
1154+
supplier3 = Supplier(name="Supplier 3", delivery_cost=20).save()
1155+
1156+
nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1157+
nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
1158+
nescafe.species.connect(arabica)
1159+
gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1160+
gold3000.species.connect(arabica)
1161+
1162+
nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
1163+
coffees__name="Nescafe"
1164+
)
1165+
ast = nodeset.query_cls(nodeset).build_ast()
1166+
query = ast.build_query()
1167+
assert "coffee_coffees1" in query
1168+
assert "coffee_coffees2" in query
1169+
results = nodeset.all()
1170+
# This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000
1171+
# Gold 3000 is traversed because coffees__species redefines the coffees traversal
1172+
assert len(results) == 3
1173+
1174+
nodeset = (
1175+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
1176+
.filter(coffees__name="Nescafe")
1177+
.unique_variables("coffees")
1178+
)
1179+
ast = nodeset.query_cls(nodeset).build_ast()
1180+
query = ast.build_query()
1181+
assert "coffee_coffees" in query
1182+
assert "coffee_coffees1" not in query
1183+
assert "coffee_coffees2" not in query
1184+
results = nodeset.all()
1185+
# This will 2 because Gold 3000 is excluded this time
1186+
# As coffees will be reused in coffees__species
1187+
assert len(results) == 2
1188+
1189+
11481190
@mark_sync_test
11491191
def test_async_iterator():
11501192
n = 10

0 commit comments

Comments
 (0)