Skip to content

Commit a4c3654

Browse files
authored
Merge pull request #38 from NavAbility/33/async_calls
Async calls
2 parents 3d51d0c + 49aeef6 commit a4c3654

File tree

12 files changed

+142
-88
lines changed

12 files changed

+142
-88
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@
3636
"black==21.9b0",
3737
"flake8==4.0.1",
3838
"pytest==6.2.5",
39+
"pytest-asyncio==0.18.1",
3940
],
4041
)

src/navability/entities/factor/__init__.py

Whitespace-only changes.

src/navability/entities/navabilityclient.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,38 @@ def mutate(self, options: MutationOptions):
3030
class NavAbilityWebsocketClient(NavAbilityClient):
3131
def __init__(self, url: str = "wss://api.d1.navability.io/graphql") -> None:
3232
super().__init__()
33-
transport = WebsocketsTransport(url=url)
34-
self.client = GQLCLient(transport=transport, fetch_schema_from_transport=True)
33+
self.transport = WebsocketsTransport(url=url)
3534

36-
def query(self, options: QueryOptions):
37-
return self.client.execute(options.query, options.variables)
35+
async def query(self, options: QueryOptions):
36+
async with GQLCLient(
37+
transport=self.transport, fetch_schema_from_transport=False
38+
) as client:
39+
result = await client.execute(options.query, options.variables)
40+
return result
3841

39-
def mutate(self, options: MutationOptions):
40-
return self.client.execute(options.mutation, options.variables)
42+
async def mutate(self, options: MutationOptions):
43+
async with GQLCLient(
44+
transport=self.transport, fetch_schema_from_transport=False
45+
) as client:
46+
result = await client.execute(options.mutation, options.variables)
47+
return result
4148

4249

4350
class NavAbilityHttpsClient(NavAbilityClient):
4451
def __init__(self, url: str = "https://api.d1.navability.io") -> None:
4552
super().__init__()
46-
transport = AIOHTTPTransport(url=url)
47-
self.client = GQLCLient(transport=transport, fetch_schema_from_transport=True)
48-
49-
def query(self, options: QueryOptions):
50-
return self.client.execute(options.query, options.variables)
51-
52-
def mutate(self, options: MutationOptions):
53-
return self.client.execute(options.mutation, options.variables)
53+
self.transport = AIOHTTPTransport(url=url)
54+
55+
async def query(self, options: QueryOptions):
56+
async with GQLCLient(
57+
transport=self.transport, fetch_schema_from_transport=True
58+
) as client:
59+
result = await client.execute(options.query, options.variables)
60+
return result
61+
62+
async def mutate(self, options: MutationOptions):
63+
async with GQLCLient(
64+
transport=self.transport, fetch_schema_from_transport=True
65+
) as client:
66+
result = await client.execute(options.mutation, options.variables)
67+
return result

src/navability/services/factor.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,40 +34,39 @@
3434
logger = logging.getLogger(__name__)
3535

3636

37-
def addFactor(navAbilityClient: NavAbilityClient, client: Client, f: Factor):
38-
return navAbilityClient.mutate(
37+
async def addFactor(navAbilityClient: NavAbilityClient, client: Client, f: Factor):
38+
result = await navAbilityClient.mutate(
3939
MutationOptions(
4040
gql(GQL_ADDFACTOR),
4141
{"factor": {"client": client.dump(), "packedData": f.dumps()}},
4242
)
43-
)["addFactor"]
43+
)
44+
return result["addFactor"]
4445

4546

46-
def listFactors(
47+
async def listFactors(
4748
navAbilityClient: NavAbilityClient,
4849
client: Client,
4950
regexFilter: str = ".*",
5051
tags: List[str] = None,
5152
solvable: int = 0,
5253
) -> List[str]:
53-
return [
54-
v.label
55-
for v in getFactors(
56-
navAbilityClient,
57-
client,
58-
detail=QueryDetail.SKELETON,
59-
regexFilter=regexFilter,
60-
tags=tags,
61-
solvable=solvable,
62-
)
63-
]
54+
factors = await getFactors(
55+
navAbilityClient,
56+
client,
57+
detail=QueryDetail.SKELETON,
58+
regexFilter=regexFilter,
59+
tags=tags,
60+
solvable=solvable,
61+
)
62+
return [f.label for f in factors]
6463

6564

6665
# Alias
6766
lsf = listFactors
6867

6968

70-
def getFactors(
69+
async def getFactors(
7170
navAbilityClient: NavAbilityClient,
7271
client: Client,
7372
detail: QueryDetail = QueryDetail.SKELETON,
@@ -86,7 +85,7 @@ def getFactors(
8685
"fields_full": detail == QueryDetail.FULL,
8786
}
8887
logger.debug(f"Query params: {params}")
89-
res = navAbilityClient.query(
88+
res = await navAbilityClient.query(
9089
QueryOptions(gql(GQL_FRAGMENT_FACTORS + GQL_GETFACTORS), params)
9190
)
9291
logger.debug(f"Query result: {res}")
@@ -116,11 +115,11 @@ def getFactors(
116115
]
117116

118117

119-
def getFactor(navAbilityClient: NavAbilityClient, client: Client, label: str):
118+
async def getFactor(navAbilityClient: NavAbilityClient, client: Client, label: str):
120119
params = client.dump()
121120
params["label"] = label
122121
logger.debug(f"Query params: {params}")
123-
res = navAbilityClient.query(
122+
res = await navAbilityClient.query(
124123
QueryOptions(gql(GQL_FRAGMENT_FACTORS + GQL_GETFACTOR), params)
125124
)
126125
logger.debug(f"Query result: {res}")

src/navability/services/solve.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from navability.entities.navabilityclient import MutationOptions, NavAbilityClient
66

77

8-
def solveSession(navAbilityClient: NavAbilityClient, client: Client):
9-
return navAbilityClient.mutate(
8+
async def solveSession(navAbilityClient: NavAbilityClient, client: Client):
9+
result = await navAbilityClient.mutate(
1010
MutationOptions(gql(GQL_SOLVESESSION), {"client": client.dump()})
11-
)["solveSession"]
11+
)
12+
return result["solveSession"]

src/navability/services/status.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from navability.entities.statusmessage import StatusMessageSchema
66

77

8-
def getStatusMessages(navAbilityClient: NavAbilityClient, id: str):
9-
statusMessages = navAbilityClient.query(
8+
async def getStatusMessages(navAbilityClient: NavAbilityClient, id: str):
9+
statusMessages = await navAbilityClient.query(
1010
QueryOptions(gql(GQL_GETSTATUSMESSAGES), {"id": id})
1111
)
1212
schema = StatusMessageSchema(many=True)
1313
return schema.load(statusMessages["statusMessages"])
1414

1515

16-
def getStatusLatest(navAbilityClient: NavAbilityClient, id: str):
17-
statusMessages = navAbilityClient.query(
16+
async def getStatusLatest(navAbilityClient: NavAbilityClient, id: str):
17+
statusMessages = await navAbilityClient.query(
1818
QueryOptions(gql(GQL_GETSTATUSLATEST), {"id": id})
1919
)
2020
schema = StatusMessageSchema()

src/navability/services/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from navability.services import getStatusLatest
66

77

8-
def waitForCompletion(
8+
async def waitForCompletion(
99
navAbilityClient: NavAbilityClient,
1010
requestIds: List[str],
1111
maxSeconds: int = 60,
@@ -21,13 +21,13 @@ def waitForCompletion(
2121
Defaults to "Complete".
2222
"""
2323
wait_time = maxSeconds
24-
while any(
25-
[
26-
getStatusLatest(navAbilityClient, res).state != expectedStatus
27-
for res in requestIds
28-
]
29-
):
24+
tasksInProgress = True
25+
while tasksInProgress:
3026
time.sleep(2)
3127
wait_time -= 2
3228
if wait_time <= 0:
3329
raise Exception(exceptionMessage)
30+
tasksInProgress = False
31+
for requestId in requestIds:
32+
result = await getStatusLatest(navAbilityClient, requestId)
33+
tasksInProgress |= result.state != expectedStatus

src/navability/services/variable.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,40 +34,40 @@
3434
logger = logging.getLogger(__name__)
3535

3636

37-
def addVariable(navAbilityClient: NavAbilityClient, client: Client, v: Variable):
38-
return navAbilityClient.mutate(
37+
async def addVariable(navAbilityClient: NavAbilityClient, client: Client, v: Variable):
38+
result = await navAbilityClient.mutate(
3939
MutationOptions(
4040
gql(GQL_ADDVARIABLE),
4141
{"variable": {"client": client.dump(), "packedData": v.dumpsPacked()}},
4242
)
43-
)["addVariable"]
43+
)
44+
return result["addVariable"]
4445

4546

46-
def listVariables(
47+
async def listVariables(
4748
navAbilityClient: NavAbilityClient,
4849
client: Client,
4950
regexFilter: str = ".*",
5051
tags: List[str] = None,
5152
solvable: int = 0,
5253
) -> List[str]:
53-
return [
54-
v.label
55-
for v in getVariables(
56-
navAbilityClient,
57-
client,
58-
detail=QueryDetail.SKELETON,
59-
regexFilter=regexFilter,
60-
tags=tags,
61-
solvable=solvable,
62-
)
63-
]
54+
variables = await getVariables(
55+
navAbilityClient,
56+
client,
57+
detail=QueryDetail.SKELETON,
58+
regexFilter=regexFilter,
59+
tags=tags,
60+
solvable=solvable,
61+
)
62+
result = [v.label for v in variables]
63+
return result
6464

6565

6666
# Alias
6767
ls = listVariables
6868

6969

70-
def getVariables(
70+
async def getVariables(
7171
navAbilityClient: NavAbilityClient,
7272
client: Client,
7373
detail: QueryDetail = QueryDetail.SKELETON,
@@ -86,7 +86,7 @@ def getVariables(
8686
"fields_full": detail == QueryDetail.FULL,
8787
}
8888
logger.debug(f"Query params: {params}")
89-
res = navAbilityClient.query(
89+
res = await navAbilityClient.query(
9090
QueryOptions(gql(GQL_FRAGMENT_VARIABLES + GQL_GETVARIABLES), params)
9191
)
9292
logger.debug(f"Query result: {res}")
@@ -116,11 +116,11 @@ def getVariables(
116116
]
117117

118118

119-
def getVariable(navAbilityClient: NavAbilityClient, client: Client, label: str):
119+
async def getVariable(navAbilityClient: NavAbilityClient, client: Client, label: str):
120120
params = client.dump()
121121
params["label"] = label
122122
logger.debug(f"Query params: {params}")
123-
res = navAbilityClient.query(
123+
res = await navAbilityClient.query(
124124
QueryOptions(gql(GQL_FRAGMENT_VARIABLES + GQL_GETVARIABLE), params)
125125
)
126126
logger.debug(f"Query result: {res}")

tests/conftest.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def client(env_config) -> Client:
6565

6666

6767
@pytest.fixture(scope="module")
68-
def example_graph(navability_https_client: NavAbilityClient, client: Client):
68+
async def example_graph(navability_https_client: NavAbilityClient, client: Client):
6969
variables = [
7070
Variable("x0", VariableType.Pose2.value),
7171
Variable("x1", VariableType.Pose2.value),
@@ -105,23 +105,23 @@ def example_graph(navability_https_client: NavAbilityClient, client: Client):
105105
]
106106
# Variables
107107
result_ids = [
108-
addVariable(navability_https_client, client, v) for v in variables
109-
] + [addFactor(navability_https_client, client, f) for f in factors]
108+
await addVariable(navability_https_client, client, v) for v in variables
109+
] + [await addFactor(navability_https_client, client, f) for f in factors]
110110

111111
logging.info(f"[Fixture] Adding variables and factors, waiting for completion")
112112

113-
waitForCompletion(navability_https_client, result_ids, maxSeconds=120)
113+
await waitForCompletion(navability_https_client, result_ids, maxSeconds=120)
114114

115115
return (navability_https_client, client, variables, factors)
116116

117117

118118
@pytest.fixture(scope="module")
119-
def example_graph_solved(example_graph):
119+
async def example_graph_solved(example_graph):
120120
"""Get the graph after it has been solved.
121121
NOTE this changes the graph, so tests need to be defensive.
122122
"""
123123
navability_https_client, client, variables, factors = example_graph
124124
logging.info(f"[Fixture] Solving graph, client = {client.dumps()}")
125-
requestId = solveSession(navability_https_client, client)
126-
waitForCompletion(navability_https_client, [requestId], maxSeconds=180)
125+
requestId = await solveSession(navability_https_client, client)
126+
await waitForCompletion(navability_https_client, [requestId], maxSeconds=180)
127127
return (navability_https_client, client, variables, factors)

tests/test_factor.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1+
import asyncio
2+
3+
import pytest
4+
15
from navability.entities import Client
26
from navability.services import lsf # getFactor
37

48

5-
def test_lsf(example_graph):
9+
@pytest.mark.asyncio
10+
async def test_lsf(example_graph):
611
navability_client, client, variables, factors = example_graph
7-
assert set(lsf(navability_client, client)) == set([f.label for f in factors])
12+
assert set(await lsf(navability_client, client)) == set([f.label for f in factors])
813

914

10-
def test_lsf_no_session(example_graph):
15+
@pytest.mark.asyncio
16+
async def test_lsf_no_session(example_graph):
1117
navability_client, client, variables, factors = example_graph
1218
noSessionClient = Client(client.userId, client.robotId, "DoesntExist")
13-
assert lsf(navability_client, noSessionClient) == []
19+
assert await lsf(navability_client, noSessionClient) == []
1420

1521

1622
# def test_getFactor(example_graph):
@@ -19,3 +25,10 @@ def test_lsf_no_session(example_graph):
1925
# getFactor(navability_client, client, factors[0].label).label
2026
# == variables[0].label
2127
# )
28+
29+
# Redefining the event loop so we can we can use module-level fixtures.
30+
@pytest.fixture(scope="module")
31+
def event_loop():
32+
loop = asyncio.get_event_loop()
33+
yield loop
34+
loop.close()

0 commit comments

Comments
 (0)