Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add lazy execution prototype #155

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Path,
Undefined,
)
from ..utilities.deferred_value import DeferredValue, deferred_dict, deferred_list
from ..type import (
GraphQLAbstractType,
GraphQLField,
Expand Down Expand Up @@ -222,6 +223,11 @@ def __init__(
self.is_awaitable = is_awaitable
self._subfields_cache: Dict[Tuple, Dict[str, List[FieldNode]]] = {}

self._deferred_values: List[Tuple[DeferredValue, Any]] = []

def is_lazy(self, value: Any) -> bool:
return False

@classmethod
def build(
cls,
Expand Down Expand Up @@ -350,12 +356,25 @@ def execute_operation(

path = None

return (
result = (
self.execute_fields_serially
if operation.operation == OperationType.MUTATION
else self.execute_fields
)(root_type, root_value, path, root_fields)

while len(self._deferred_values) > 0:
for d in list(self._deferred_values):
self._deferred_values.remove(d)
res = d[1].get()
d[0].resolve(res)

if isinstance(result, DeferredValue):
if result.is_rejected:
raise cast(Exception, result.reason)
return result.value

return result

def execute_fields_serially(
self,
parent_type: GraphQLObjectType,
Expand Down Expand Up @@ -432,6 +451,7 @@ def execute_fields(
is_awaitable = self.is_awaitable
awaitable_fields: List[str] = []
append_awaitable = awaitable_fields.append
contains_deferred = False
for response_name, field_nodes in fields.items():
field_path = Path(path, response_name, parent_type.name)
result = self.execute_field(
Expand All @@ -441,6 +461,11 @@ def execute_fields(
results[response_name] = result
if is_awaitable(result):
append_awaitable(response_name)
if isinstance(result, DeferredValue):
contains_deferred = True

if contains_deferred:
return deferred_dict(results)

# If there are no coroutines, we can just return the object
if not awaitable_fields:
Expand Down Expand Up @@ -634,6 +659,23 @@ def complete_value(
if result is None or result is Undefined:
return None

if self.is_lazy(result):
def handle_resolve(resolved: Any) -> Any:
return self.complete_value(
return_type, field_nodes, info, path, resolved
)

def handle_error(raw_error: Exception) -> None:
raise raw_error

deferred = DeferredValue()
self._deferred_values.append((
deferred, result
))

completed = deferred.then(handle_resolve, handle_error)
return completed

# If field type is List, complete each item in the list with inner type
if is_list_type(return_type):
return self.complete_list_value(
Expand Down Expand Up @@ -705,6 +747,7 @@ async def async_iterable_to_list(
append_awaitable = awaitable_indices.append
completed_results: List[Any] = []
append_result = completed_results.append
contains_deferred = False
for index, item in enumerate(result):
# No need to modify the info object containing the path, since from here on
# it is not ever accessed by resolver functions.
Expand Down Expand Up @@ -746,6 +789,9 @@ async def await_completed(item: Any, item_path: Path) -> Any:
return None

completed_item = await_completed(completed_item, item_path)
if isinstance(completed_item, DeferredValue):
contains_deferred = True

except Exception as raw_error:
error = located_error(raw_error, field_nodes, item_path.as_list())
self.handle_field_error(error, item_type)
Expand All @@ -755,6 +801,9 @@ async def await_completed(item: Any, item_path: Path) -> Any:
append_awaitable(index)
append_result(completed_item)

if contains_deferred is True:
return deferred_list(completed_results)

if not awaitable_indices:
return completed_results

Expand Down
210 changes: 210 additions & 0 deletions src/graphql/utilities/deferred_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from typing import Any, Optional, List, Callable, cast, Dict


OnSuccessCallback = Callable[[Any], None]
OnErrorCallback = Callable[[Exception], None]


class DeferredValue:
PENDING = -1
REJECTED = 0
RESOLVED = 1

_value: Optional[Any]
_reason: Optional[Exception]
_callbacks: List[OnSuccessCallback]
_errbacks: List[OnErrorCallback]

def __init__(
self,
on_complete: Optional[OnSuccessCallback] = None,
on_error: Optional[OnErrorCallback] = None,
):
self._state = self.PENDING
self._value = None
self._reason = None
if on_complete:
self._callbacks = [on_complete]
else:
self._callbacks = []
if on_error:
self._errbacks = [on_error]
else:
self._errbacks = []

def resolve(self, value: Any) -> None:
if self._state != DeferredValue.PENDING:
return

if isinstance(value, DeferredValue):
value.add_callback(self.resolve)
value.add_errback(self.reject)
return

self._value = value
self._state = self.RESOLVED

callbacks = self._callbacks
self._callbacks = []
for callback in callbacks:
try:
callback(value)
except Exception:
# Ignore errors in callbacks
pass

def reject(self, reason: Exception) -> None:
if self._state != DeferredValue.PENDING:
return

self._reason = reason
self._state = self.REJECTED

errbacks = self._errbacks
self._errbacks = []
for errback in errbacks:
try:
errback(reason)
except Exception:
# Ignore errors in errback
pass

def then(
self,
on_complete: Optional[OnSuccessCallback] = None,
on_error: Optional[OnErrorCallback] = None,
) -> "DeferredValue":
ret = DeferredValue()

def call_and_resolve(v: Any) -> None:
try:
if on_complete:
ret.resolve(on_complete(v))
else:
ret.resolve(v)
except Exception as e:
ret.reject(e)

def call_and_reject(r: Exception) -> None:
try:
if on_error:
ret.resolve(on_error(r))
else:
ret.reject(r)
except Exception as e:
ret.reject(e)

self.add_callback(call_and_resolve)
self.add_errback(call_and_resolve)

return ret

def add_callback(self, callback: OnSuccessCallback) -> None:
if self._state == self.PENDING:
self._callbacks.append(callback)
return

if self._state == self.RESOLVED:
callback(self._value)

def add_errback(self, callback: OnErrorCallback) -> None:
if self._state == self.PENDING:
self._errbacks.append(callback)
return

if self._state == self.REJECTED:
callback(cast(Exception, self._reason))

@property
def is_resolved(self) -> bool:
return self._state == self.RESOLVED

@property
def is_rejected(self) -> bool:
return self._state == self.REJECTED

@property
def value(self) -> Any:
return self._value

@property
def reason(self) -> Optional[Exception]:
return self._reason


def deferred_dict(m: Dict[str, Any]) -> DeferredValue:
"""
A special function that takes a dictionary of deferred values
and turns them into a deferred value that will ultimately resolve
into a dictionary of values.
"""
if len(m) == 0:
raise TypeError("Empty dict")

ret = DeferredValue()

plain_values = {
key: value for key, value in m.items() if not isinstance(value, DeferredValue)
}
deferred_values = {
key: value for key, value in m.items() if isinstance(value, DeferredValue)
}

count = len(deferred_values)

def handle_success(_: Any) -> None:
nonlocal count
count -= 1
if count == 0:
value = plain_values

for k, p in deferred_values.items():
value[k] = p.value

ret.resolve(value)

for p in deferred_values.values():
p.add_callback(handle_success)
p.add_errback(ret.reject)

return ret


def deferred_list(l: List[Any]) -> DeferredValue:
"""
A special function that takes a list of deferred values
and turns them into a deferred value for a list of values.
"""
if len(l) == 0:
raise TypeError("Empty list")

ret = DeferredValue()

plain_values = {}
deferred_values = {}
for index, value in enumerate(l):
if isinstance(value, DeferredValue):
deferred_values[index] = value
else:
plain_values[index] = value

count = len(deferred_values)

def handle_success(_: Any) -> None:
nonlocal count
count -= 1
if count == 0:
values = []

for k in sorted(list(plain_values.keys()) + list(deferred_values.keys())):
value = plain_values.get(k, None)
if not value:
value = deferred_values[k].value
values.append(value)
ret.resolve(values)

for p in l:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for p in l:
for p in deferred_values.values():

Because the List l will have plain_values too ?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The work presented in this PR is really good 👍🏼
We were able to increase our GQL performance by at least 5X using dataloaders.

p.add_callback(handle_success)
p.add_errback(ret.reject)

return ret
Loading