Skip to content

Commit 7fccd75

Browse files
committed
Merge branch 'snuderl-execute-optimization'
2 parents a5dc136 + f21015c commit 7fccd75

File tree

4 files changed

+378
-339
lines changed

4 files changed

+378
-339
lines changed

graphql_core_promise/execute/promise.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
class PromiseExecutionContext(ExecutionContext):
4949
"""
5050
Translate methods on the original graphql.execution.execute.ExecutionContext
51-
to be promise-aware and promise-based so that promise-based dataloaders and
51+
to be promise-aware and promise-based so that promise-based dataloaders and
5252
resolvers can continue to function
5353
"""
5454

@@ -153,6 +153,7 @@ def execute_field(
153153

154154
if self.is_promise(result):
155155
result: Promise = result
156+
156157
# noinspection PyShadowingNames
157158
def await_result() -> Any:
158159
def handle_error(raw_error):
@@ -214,16 +215,24 @@ def execute_fields(
214215
if not awaitable_fields:
215216
return results
216217

217-
def get_results() -> Dict[str, Any]:
218-
def on_all_resolve(resolved_results: List[Any]):
219-
for field, result in zip(awaitable_fields, resolved_results):
220-
results[field] = result
221-
return results
218+
def get_results() -> dict[str, Any]:
219+
r = [results[field] for field in awaitable_fields]
220+
if len(r) > 1:
222221

223-
p = Promise.all([results[field] for field in awaitable_fields]).then(
224-
on_all_resolve
225-
)
226-
return p
222+
def on_all_resolve(resolved_results: list[Any]):
223+
for field, result in zip(awaitable_fields, resolved_results):
224+
results[field] = result
225+
return results
226+
227+
p = Promise.all(r).then(on_all_resolve)
228+
else:
229+
230+
def on_single_resolve(resolved):
231+
results[awaitable_fields[0]] = resolved
232+
return results
233+
234+
return r[0].then(on_single_resolve)
235+
return p # type: ignore
227236

228237
return get_results()
229238

@@ -468,7 +477,17 @@ def on_error(raw_error):
468477
return completed_results
469478

470479
# noinspection PyShadowingNames
471-
def get_completed_results() -> List[Any]:
480+
def get_completed_results() -> list[Any]:
481+
if len(awaitable_indices) == 1:
482+
483+
def on_one_resolved(result):
484+
completed_results[index] = result
485+
return completed_results
486+
487+
# If there is only one index, avoid the overhead of parallelization.
488+
index = awaitable_indices[0]
489+
return completed_results[0].then(on_one_resolved)
490+
472491
def on_all_resolve(results):
473492
for index, result in zip(awaitable_indices, results):
474493
completed_results[index] = result
@@ -479,4 +498,4 @@ def on_all_resolve(results):
479498
).then(on_all_resolve)
480499

481500
res = get_completed_results()
482-
return res
501+
return res

0 commit comments

Comments
 (0)