Skip to content

Commit f125b4e

Browse files
committed
refactor: convert IncrementalPublisher class members to method args
Replicates graphql/graphql-js@6fa7d55
1 parent 57cab52 commit f125b4e

File tree

1 file changed

+55
-38
lines changed

1 file changed

+55
-38
lines changed

src/graphql/execution/incremental_publisher.py

Lines changed: 55 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
TYPE_CHECKING,
99
Any,
1010
AsyncGenerator,
11+
NamedTuple,
1112
Sequence,
1213
cast,
1314
)
@@ -37,6 +38,7 @@
3738
DeferredFragmentRecord,
3839
DeferredGroupedFieldSetResult,
3940
IncrementalDataRecord,
41+
IncrementalDataRecordResult,
4042
IncrementalResult,
4143
ReconcilableDeferredGroupedFieldSetResult,
4244
StreamItemsResult,
@@ -60,6 +62,14 @@ class IncrementalPublisherContext(Protocol):
6062
cancellable_streams: set[CancellableStreamRecord] | None
6163

6264

65+
class SubsequentIncrementalExecutionResultContext(NamedTuple):
66+
"""The context for subsequent incremental execution results."""
67+
68+
pending: list[PendingResult]
69+
incremental: list[IncrementalResult]
70+
completed: list[CompletedResult]
71+
72+
6373
class IncrementalPublisher:
6474
"""Publish incremental results.
6575
@@ -72,15 +82,11 @@ class IncrementalPublisher:
7282
_context: IncrementalPublisherContext
7383
_next_id: int
7484
_incremental_graph: IncrementalGraph
75-
_incremental: list[IncrementalResult]
76-
_completed: list[CompletedResult]
7785

7886
def __init__(self, context: IncrementalPublisherContext) -> None:
7987
self._context = context
8088
self._next_id = 0
8189
self._incremental_graph = IncrementalGraph()
82-
self._incremental = []
83-
self._completed = []
8490

8591
def build_response(
8692
self,
@@ -131,36 +137,26 @@ async def _subscribe(
131137
self,
132138
) -> AsyncGenerator[SubsequentIncrementalExecutionResult, None]:
133139
"""Subscribe to the incremental results."""
140+
incremental_graph = self._incremental_graph
141+
check_has_next = incremental_graph.has_next
142+
handle_completed_incremental_data = self._handle_completed_incremental_data
143+
completed_incremental_data = incremental_graph.completed_incremental_data()
144+
# use the raw iterator rather than 'async for' so as not to end the iterator
145+
# when exiting the loop with the next value
146+
get_next_results = completed_incremental_data.__aiter__().__anext__
147+
is_done = False
134148
try:
135-
incremental_graph = self._incremental_graph
136-
get_new_pending = incremental_graph.get_new_pending
137-
check_has_next = incremental_graph.has_next
138-
pending_sources_to_results = self._pending_sources_to_results
139-
completed_incremental_data = incremental_graph.completed_incremental_data()
140-
# use the raw iterator rather than 'async for' so as not to end the iterator
141-
# when exiting the loop with the next value
142-
get_next_results = completed_incremental_data.__aiter__().__anext__
143-
is_done = False
144149
while not is_done:
145150
try:
146151
completed_results = await get_next_results()
147152
except StopAsyncIteration: # pragma: no cover
148153
break
149-
pending: list[PendingResult] = []
150154

155+
context = SubsequentIncrementalExecutionResultContext([], [], [])
151156
for completed_result in completed_results:
152-
if is_deferred_grouped_field_set_result(completed_result):
153-
self._handle_completed_deferred_grouped_field_set(
154-
completed_result
155-
)
156-
else:
157-
completed_result = cast("StreamItemsResult", completed_result)
158-
await self._handle_completed_stream_items(completed_result)
157+
await handle_completed_incremental_data(completed_result, context)
159158

160-
new_pending = get_new_pending()
161-
pending.extend(pending_sources_to_results(new_pending))
162-
163-
if self._incremental or self._completed:
159+
if context.incremental or context.completed:
164160
has_next = check_has_next()
165161

166162
if not has_next:
@@ -169,15 +165,12 @@ async def _subscribe(
169165
subsequent_incremental_execution_result = (
170166
SubsequentIncrementalExecutionResult(
171167
has_next=has_next,
172-
pending=pending or None,
173-
incremental=self._incremental or None,
174-
completed=self._completed or None,
168+
pending=context.pending or None,
169+
incremental=context.incremental or None,
170+
completed=context.completed or None,
175171
)
176172
)
177173

178-
self._incremental = []
179-
self._completed = []
180-
181174
yield subsequent_incremental_execution_result
182175
finally:
183176
await self._stop_async_iterators()
@@ -194,12 +187,34 @@ async def _stop_async_iterators(self) -> None:
194187
if early_returns:
195188
await gather(*early_returns, return_exceptions=True)
196189

190+
async def _handle_completed_incremental_data(
191+
self,
192+
completed_incremental_data: IncrementalDataRecordResult,
193+
context: SubsequentIncrementalExecutionResultContext,
194+
) -> None:
195+
if is_deferred_grouped_field_set_result(completed_incremental_data):
196+
self._handle_completed_deferred_grouped_field_set(
197+
completed_incremental_data, context
198+
)
199+
else:
200+
completed_incremental_data = cast(
201+
"StreamItemsResult", completed_incremental_data
202+
)
203+
await self._handle_completed_stream_items(
204+
completed_incremental_data, context
205+
)
206+
207+
new_pending = self._incremental_graph.get_new_pending()
208+
context.pending.extend(self._pending_sources_to_results(new_pending))
209+
197210
def _handle_completed_deferred_grouped_field_set(
198-
self, deferred_grouped_field_set_result: DeferredGroupedFieldSetResult
211+
self,
212+
deferred_grouped_field_set_result: DeferredGroupedFieldSetResult,
213+
context: SubsequentIncrementalExecutionResultContext,
199214
) -> None:
200215
"""Handle completed deferred grouped field set result."""
201-
append_completed = self._completed.append
202-
append_incremental = self._incremental.append
216+
append_completed = context.completed.append
217+
append_incremental = context.incremental.append
203218
if is_non_reconcilable_deferred_grouped_field_set_result(
204219
deferred_grouped_field_set_result
205220
):
@@ -260,7 +275,9 @@ def _handle_completed_deferred_grouped_field_set(
260275
append_completed(CompletedResult(id_))
261276

262277
async def _handle_completed_stream_items(
263-
self, stream_items_result: StreamItemsResult
278+
self,
279+
stream_items_result: StreamItemsResult,
280+
context: SubsequentIncrementalExecutionResultContext,
264281
) -> None:
265282
"""Handle completed stream."""
266283
stream_record = stream_items_result.stream_record
@@ -269,7 +286,7 @@ async def _handle_completed_stream_items(
269286
return # pragma: no cover
270287
incremental_graph = self._incremental_graph
271288
if stream_items_result.errors is not None:
272-
self._completed.append(CompletedResult(id_, stream_items_result.errors))
289+
context.completed.append(CompletedResult(id_, stream_items_result.errors))
273290
incremental_graph.remove_subsequent_result_record(stream_record)
274291
if is_cancellable_stream_record(stream_record):
275292
cancellable_streams = self._context.cancellable_streams
@@ -278,7 +295,7 @@ async def _handle_completed_stream_items(
278295
with suppress(Exception):
279296
await stream_record.early_return
280297
elif stream_items_result.result is None:
281-
self._completed.append(CompletedResult(id_))
298+
context.completed.append(CompletedResult(id_))
282299
incremental_graph.remove_subsequent_result_record(stream_record)
283300
if is_cancellable_stream_record(stream_record):
284301
cancellable_streams = self._context.cancellable_streams
@@ -289,7 +306,7 @@ async def _handle_completed_stream_items(
289306
incremental_entry = IncrementalStreamResult(
290307
items=result.items, id=id_, errors=result.errors
291308
)
292-
self._incremental.append(incremental_entry)
309+
context.incremental.append(incremental_entry)
293310
if stream_items_result.incremental_data_records: # pragma: no branch
294311
incremental_graph.add_incremental_data_records(
295312
stream_items_result.incremental_data_records

0 commit comments

Comments (0)