|
2 | 2 | from datetime import datetime
|
3 | 3 | from test._async_compat import mark_async_test
|
4 | 4 |
|
5 |
| -import numpy as np |
6 | 5 | from pytest import raises, skip, warns
|
7 | 6 |
|
8 | 7 | from neomodel import (
|
@@ -545,13 +544,14 @@ async def test_q_filters():
|
545 | 544 | assert c6 in combined_coffees
|
546 | 545 | assert c3 not in combined_coffees
|
547 | 546 |
|
548 |
| - with raises( |
549 |
| - ValueError, |
550 |
| - match=r"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements", |
551 |
| - ): |
552 |
| - await Coffee.nodes.fetch_relations(Optional("species")).filter( |
553 |
| - Q(name="Latte") | Q(species__name="Robusta") |
554 |
| - ).all() |
| 547 | + robusta = await Species(name="Robusta").save() |
| 548 | + await c4.species.connect(robusta) |
| 549 | + latte_or_robusta_coffee = ( |
| 550 | + await Coffee.nodes.fetch_relations(Optional("species")) |
| 551 | + .filter(Q(name="Latte") | Q(species__name="Robusta")) |
| 552 | + .all() |
| 553 | + ) |
| 554 | + assert len(latte_or_robusta_coffee) == 2 |
555 | 555 |
|
556 | 556 | class QQ:
|
557 | 557 | pass
|
@@ -632,6 +632,11 @@ async def test_relation_prop_filtering():
|
632 | 632 | await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
|
633 | 633 | await nescafe.species.connect(arabica)
|
634 | 634 |
|
| 635 | + result = await Coffee.nodes.filter( |
| 636 | + **{"suppliers|since__gt": datetime(2010, 4, 1, 0, 0)} |
| 637 | + ).all() |
| 638 | + assert len(result) == 1 |
| 639 | + |
635 | 640 | results = await Supplier.nodes.filter(
|
636 | 641 | **{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
|
637 | 642 | ).all()
|
@@ -1155,6 +1160,49 @@ async def test_in_filter_with_array_property():
|
1155 | 1160 | ), "Species found by tags with not match tags given"
|
1156 | 1161 |
|
1157 | 1162 |
|
| 1163 | +@mark_async_test |
| 1164 | +async def test_unique_variables(): |
| 1165 | + arabica = await Species(name="Arabica").save() |
| 1166 | + nescafe = await Coffee(name="Nescafe", price=99).save() |
| 1167 | + gold3000 = await Coffee(name="Gold 3000", price=11).save() |
| 1168 | + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() |
| 1169 | + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() |
| 1170 | + supplier3 = await Supplier(name="Supplier 3", delivery_cost=20).save() |
| 1171 | + |
| 1172 | + await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) |
| 1173 | + await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) |
| 1174 | + await nescafe.species.connect(arabica) |
| 1175 | + await gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) |
| 1176 | + await gold3000.species.connect(arabica) |
| 1177 | + |
| 1178 | + nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter( |
| 1179 | + coffees__name="Nescafe" |
| 1180 | + ) |
| 1181 | + ast = await nodeset.query_cls(nodeset).build_ast() |
| 1182 | + query = ast.build_query() |
| 1183 | + assert "coffee_coffees1" in query |
| 1184 | + assert "coffee_coffees2" in query |
| 1185 | + results = await nodeset.all() |
| 1186 | + # This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000 |
| 1187 | + # Gold 3000 is traversed because coffees__species redefines the coffees traversal |
| 1188 | + assert len(results) == 3 |
| 1189 | + |
| 1190 | + nodeset = ( |
| 1191 | + Supplier.nodes.fetch_relations("coffees", "coffees__species") |
| 1192 | + .filter(coffees__name="Nescafe") |
| 1193 | + .unique_variables("coffees") |
| 1194 | + ) |
| 1195 | + ast = await nodeset.query_cls(nodeset).build_ast() |
| 1196 | + query = ast.build_query() |
| 1197 | + assert "coffee_coffees" in query |
| 1198 | + assert "coffee_coffees1" not in query |
| 1199 | + assert "coffee_coffees2" not in query |
| 1200 | + results = await nodeset.all() |
| 1201 | + # This will 2 because Gold 3000 is excluded this time |
| 1202 | + # As coffees will be reused in coffees__species |
| 1203 | + assert len(results) == 2 |
| 1204 | + |
| 1205 | + |
1158 | 1206 | @mark_async_test
|
1159 | 1207 | async def test_async_iterator():
|
1160 | 1208 | n = 10
|
|
0 commit comments