Skip to content

Commit 33f61bf

Browse files
Merge pull request #866 from neo4j-contrib/rc/5.4.5
Rc/5.4.5
2 parents 0d38b1a + 8724bca commit 33f61bf

File tree

7 files changed

+174
-34
lines changed

7 files changed

+174
-34
lines changed

doc/source/configuration.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti
3232
config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default
3333
config.RESOLVER = None # default
3434
config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default
35-
config.USER_AGENT = neomodel/v5.4.4 # default
35+
config.USER_AGENT = neomodel/v5.4.5 # default
3636

3737
Setting the database name, if different from the default one::
3838

doc/source/traversal.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,26 @@ With both `traverse_relations` and `fetch_relations`, you can force the use of a
7777

7878
Person.nodes.fetch_relations('city__country', Optional('country')).all()
7979

80+
Unique variables
81+
----------------
82+
83+
If you want to use the same variable name for traversed nodes when chaining traversals, you can use the `unique_variables` method::
84+
85+
# This does not guarantee that coffees__species will traverse the same nodes as coffees
86+
# So coffees__species can traverse the Coffee node "Gold 3000"
87+
nodeset = (
88+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
89+
.filter(coffees__name="Nescafe")
90+
)
91+
92+
# This guarantees that coffees__species will traverse the same nodes as coffees
93+
# So when fetching species, it will only fetch those of the Coffee node "Nescafe"
94+
nodeset = (
95+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
96+
.filter(coffees__name="Nescafe")
97+
.unique_variables("coffees")
98+
)
99+
80100
Resolve results
81101
---------------
82102

neomodel/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "5.4.4"
1+
__version__ = "5.4.5"

neomodel/async_/match.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,11 @@ def create_relation_identifier(self) -> str:
503503
self._relation_identifier_count += 1
504504
return f"r{self._relation_identifier_count}"
505505

506-
def create_node_identifier(self, prefix: str) -> str:
507-
self._node_identifier_count += 1
508-
return f"{prefix}{self._node_identifier_count}"
506+
def create_node_identifier(self, prefix: str, path: str) -> str:
507+
if path not in self.node_set._unique_variables:
508+
self._node_identifier_count += 1
509+
return f"{prefix}{self._node_identifier_count}"
510+
return prefix
509511

510512
def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
511513
if "?" in source.order_by_elements:
@@ -613,14 +615,16 @@ def build_traversal_from_path(
613615
rhs_label = relationship.definition["node_class"].__label__
614616
if relation.get("relation_filtering"):
615617
rhs_name = rel_ident
618+
rhs_ident = f":{rhs_label}"
616619
else:
617620
if index + 1 == len(parts) and "alias" in relation:
618621
# If an alias is defined, use it to store the last hop in the path
619622
rhs_name = relation["alias"]
620623
else:
621624
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
622-
rhs_name = self.create_node_identifier(rhs_name)
623-
rhs_ident = f"{rhs_name}:{rhs_label}"
625+
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
626+
rhs_ident = f"{rhs_name}:{rhs_label}"
627+
624628
if relation["include_in_return"] and not already_present:
625629
self._additional_return(rhs_name)
626630

@@ -825,9 +829,11 @@ def add_to_target(statement: str, connector: Q, optional: bool) -> None:
825829
match_filters = [filter[0] for filter in target if not filter[1]]
826830
opt_match_filters = [filter[0] for filter in target if filter[1]]
827831
if q.connector == Q.OR and match_filters and opt_match_filters:
828-
raise ValueError(
829-
"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements"
830-
)
832+
# In this case, we can't split filters in two WHERE statements so we move
833+
# everything into the one applied after OPTIONAL MATCH statements...
834+
opt_match_filters += match_filters
835+
match_filters = []
836+
831837
ret = f" {q.connector} ".join(match_filters)
832838
if ret and q.negated:
833839
ret = f"NOT ({ret})"
@@ -1381,6 +1387,7 @@ def __init__(self, source: Any) -> None:
13811387
self._extra_results: list = []
13821388
self._subqueries: list[Subquery] = []
13831389
self._intermediate_transforms: list = []
1390+
self._unique_variables: list[str] = []
13841391

13851392
def __await__(self) -> Any:
13861393
return self.all().__await__() # type: ignore[attr-defined]
@@ -1552,6 +1559,11 @@ def _register_relation_to_fetch(
15521559
item["alias"] = alias
15531560
return item
15541561

1562+
def unique_variables(self, *pathes: tuple[str, ...]) -> "AsyncNodeSet":
1563+
"""Generate unique variable names for the given pathes."""
1564+
self._unique_variables = pathes
1565+
return self
1566+
15551567
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
15561568
"""Specify a set of relations to traverse and return."""
15571569
relations = []

neomodel/sync_/match.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,11 @@ def create_relation_identifier(self) -> str:
501501
self._relation_identifier_count += 1
502502
return f"r{self._relation_identifier_count}"
503503

504-
def create_node_identifier(self, prefix: str) -> str:
505-
self._node_identifier_count += 1
506-
return f"{prefix}{self._node_identifier_count}"
504+
def create_node_identifier(self, prefix: str, path: str) -> str:
505+
if path not in self.node_set._unique_variables:
506+
self._node_identifier_count += 1
507+
return f"{prefix}{self._node_identifier_count}"
508+
return prefix
507509

508510
def build_order_by(self, ident: str, source: "NodeSet") -> None:
509511
if "?" in source.order_by_elements:
@@ -611,14 +613,16 @@ def build_traversal_from_path(
611613
rhs_label = relationship.definition["node_class"].__label__
612614
if relation.get("relation_filtering"):
613615
rhs_name = rel_ident
616+
rhs_ident = f":{rhs_label}"
614617
else:
615618
if index + 1 == len(parts) and "alias" in relation:
616619
# If an alias is defined, use it to store the last hop in the path
617620
rhs_name = relation["alias"]
618621
else:
619622
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
620-
rhs_name = self.create_node_identifier(rhs_name)
621-
rhs_ident = f"{rhs_name}:{rhs_label}"
623+
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
624+
rhs_ident = f"{rhs_name}:{rhs_label}"
625+
622626
if relation["include_in_return"] and not already_present:
623627
self._additional_return(rhs_name)
624628

@@ -823,9 +827,11 @@ def add_to_target(statement: str, connector: Q, optional: bool) -> None:
823827
match_filters = [filter[0] for filter in target if not filter[1]]
824828
opt_match_filters = [filter[0] for filter in target if filter[1]]
825829
if q.connector == Q.OR and match_filters and opt_match_filters:
826-
raise ValueError(
827-
"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements"
828-
)
830+
# In this case, we can't split filters in two WHERE statements so we move
831+
# everything into the one applied after OPTIONAL MATCH statements...
832+
opt_match_filters += match_filters
833+
match_filters = []
834+
829835
ret = f" {q.connector} ".join(match_filters)
830836
if ret and q.negated:
831837
ret = f"NOT ({ret})"
@@ -1377,6 +1383,7 @@ def __init__(self, source: Any) -> None:
13771383
self._extra_results: list = []
13781384
self._subqueries: list[Subquery] = []
13791385
self._intermediate_transforms: list = []
1386+
self._unique_variables: list[str] = []
13801387

13811388
def __await__(self) -> Any:
13821389
return self.all().__await__() # type: ignore[attr-defined]
@@ -1548,6 +1555,11 @@ def _register_relation_to_fetch(
15481555
item["alias"] = alias
15491556
return item
15501557

1558+
def unique_variables(self, *pathes: tuple[str, ...]) -> "NodeSet":
1559+
"""Generate unique variable names for the given pathes."""
1560+
self._unique_variables = pathes
1561+
return self
1562+
15511563
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet":
15521564
"""Specify a set of relations to traverse and return."""
15531565
relations = []

test/async_/test_match_api.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from test._async_compat import mark_async_test
44

5-
import numpy as np
65
from pytest import raises, skip, warns
76

87
from neomodel import (
@@ -545,13 +544,14 @@ async def test_q_filters():
545544
assert c6 in combined_coffees
546545
assert c3 not in combined_coffees
547546

548-
with raises(
549-
ValueError,
550-
match=r"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements",
551-
):
552-
await Coffee.nodes.fetch_relations(Optional("species")).filter(
553-
Q(name="Latte") | Q(species__name="Robusta")
554-
).all()
547+
robusta = await Species(name="Robusta").save()
548+
await c4.species.connect(robusta)
549+
latte_or_robusta_coffee = (
550+
await Coffee.nodes.fetch_relations(Optional("species"))
551+
.filter(Q(name="Latte") | Q(species__name="Robusta"))
552+
.all()
553+
)
554+
assert len(latte_or_robusta_coffee) == 2
555555

556556
class QQ:
557557
pass
@@ -632,6 +632,11 @@ async def test_relation_prop_filtering():
632632
await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
633633
await nescafe.species.connect(arabica)
634634

635+
result = await Coffee.nodes.filter(
636+
**{"suppliers|since__gt": datetime(2010, 4, 1, 0, 0)}
637+
).all()
638+
assert len(result) == 1
639+
635640
results = await Supplier.nodes.filter(
636641
**{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
637642
).all()
@@ -1155,6 +1160,49 @@ async def test_in_filter_with_array_property():
11551160
), "Species found by tags with not match tags given"
11561161

11571162

1163+
@mark_async_test
1164+
async def test_unique_variables():
1165+
arabica = await Species(name="Arabica").save()
1166+
nescafe = await Coffee(name="Nescafe", price=99).save()
1167+
gold3000 = await Coffee(name="Gold 3000", price=11).save()
1168+
supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
1169+
supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
1170+
supplier3 = await Supplier(name="Supplier 3", delivery_cost=20).save()
1171+
1172+
await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1173+
await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
1174+
await nescafe.species.connect(arabica)
1175+
await gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1176+
await gold3000.species.connect(arabica)
1177+
1178+
nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
1179+
coffees__name="Nescafe"
1180+
)
1181+
ast = await nodeset.query_cls(nodeset).build_ast()
1182+
query = ast.build_query()
1183+
assert "coffee_coffees1" in query
1184+
assert "coffee_coffees2" in query
1185+
results = await nodeset.all()
1186+
# This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000
1187+
# Gold 3000 is traversed because coffees__species redefines the coffees traversal
1188+
assert len(results) == 3
1189+
1190+
nodeset = (
1191+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
1192+
.filter(coffees__name="Nescafe")
1193+
.unique_variables("coffees")
1194+
)
1195+
ast = await nodeset.query_cls(nodeset).build_ast()
1196+
query = ast.build_query()
1197+
assert "coffee_coffees" in query
1198+
assert "coffee_coffees1" not in query
1199+
assert "coffee_coffees2" not in query
1200+
results = await nodeset.all()
1201+
# This will 2 because Gold 3000 is excluded this time
1202+
# As coffees will be reused in coffees__species
1203+
assert len(results) == 2
1204+
1205+
11581206
@mark_async_test
11591207
async def test_async_iterator():
11601208
n = 10

test/sync_/test_match_api.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from test._async_compat import mark_sync_test
44

5-
import numpy as np
65
from pytest import raises, skip, warns
76

87
from neomodel import (
@@ -541,13 +540,14 @@ def test_q_filters():
541540
assert c6 in combined_coffees
542541
assert c3 not in combined_coffees
543542

544-
with raises(
545-
ValueError,
546-
match=r"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements",
547-
):
548-
Coffee.nodes.fetch_relations(Optional("species")).filter(
549-
Q(name="Latte") | Q(species__name="Robusta")
550-
).all()
543+
robusta = Species(name="Robusta").save()
544+
c4.species.connect(robusta)
545+
latte_or_robusta_coffee = (
546+
Coffee.nodes.fetch_relations(Optional("species"))
547+
.filter(Q(name="Latte") | Q(species__name="Robusta"))
548+
.all()
549+
)
550+
assert len(latte_or_robusta_coffee) == 2
551551

552552
class QQ:
553553
pass
@@ -624,6 +624,11 @@ def test_relation_prop_filtering():
624624
nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
625625
nescafe.species.connect(arabica)
626626

627+
result = Coffee.nodes.filter(
628+
**{"suppliers|since__gt": datetime(2010, 4, 1, 0, 0)}
629+
).all()
630+
assert len(result) == 1
631+
627632
results = Supplier.nodes.filter(
628633
**{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
629634
).all()
@@ -1139,6 +1144,49 @@ def test_in_filter_with_array_property():
11391144
), "Species found by tags with not match tags given"
11401145

11411146

1147+
@mark_sync_test
1148+
def test_unique_variables():
1149+
arabica = Species(name="Arabica").save()
1150+
nescafe = Coffee(name="Nescafe", price=99).save()
1151+
gold3000 = Coffee(name="Gold 3000", price=11).save()
1152+
supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
1153+
supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
1154+
supplier3 = Supplier(name="Supplier 3", delivery_cost=20).save()
1155+
1156+
nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1157+
nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
1158+
nescafe.species.connect(arabica)
1159+
gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1160+
gold3000.species.connect(arabica)
1161+
1162+
nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
1163+
coffees__name="Nescafe"
1164+
)
1165+
ast = nodeset.query_cls(nodeset).build_ast()
1166+
query = ast.build_query()
1167+
assert "coffee_coffees1" in query
1168+
assert "coffee_coffees2" in query
1169+
results = nodeset.all()
1170+
# This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000
1171+
# Gold 3000 is traversed because coffees__species redefines the coffees traversal
1172+
assert len(results) == 3
1173+
1174+
nodeset = (
1175+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
1176+
.filter(coffees__name="Nescafe")
1177+
.unique_variables("coffees")
1178+
)
1179+
ast = nodeset.query_cls(nodeset).build_ast()
1180+
query = ast.build_query()
1181+
assert "coffee_coffees" in query
1182+
assert "coffee_coffees1" not in query
1183+
assert "coffee_coffees2" not in query
1184+
results = nodeset.all()
1185+
# This will 2 because Gold 3000 is excluded this time
1186+
# As coffees will be reused in coffees__species
1187+
assert len(results) == 2
1188+
1189+
11421190
@mark_sync_test
11431191
def test_async_iterator():
11441192
n = 10

0 commit comments

Comments
 (0)