Skip to content

Commit

Permalink
Refactor common validation into separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Jun 5, 2023
1 parent 1c9812e commit 78191f1
Showing 1 changed file with 118 additions and 110 deletions.
228 changes: 118 additions & 110 deletions strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,100 @@ def validate_document(
)


async def _parse_and_validate_async(
execution_context: ExecutionContext,
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None],
extensions_runner: SchemaExtensionsRunner,
allowed_operation_types: Optional[Iterable[OperationType]] = None,
):
assert execution_context.query
async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)

execution_context.errors = [error]
process_errors([error], execution_context)

return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

if (
allowed_operation_types
and execution_context.operation_type not in allowed_operation_types
):
raise InvalidOperationTypeError(execution_context.operation_type)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)


def _parse_and_validate_sync(
execution_context: ExecutionContext,
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None],
extensions_runner: SchemaExtensionsRunner,
allowed_operation_types: Iterable[OperationType],
) -> Optional[ExecutionResult]:
assert execution_context.query
with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)

execution_context.errors = [error]
process_errors([error], execution_context)

return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)
return None


def _run_validation(execution_context: ExecutionContext) -> None:
# Check if there are any validation rules or if validation has
# already been run by an extension
Expand Down Expand Up @@ -93,45 +187,18 @@ async def execute(
if not execution_context.query:
raise MissingQueryError()

async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)

execution_context.errors = [error]
process_errors([error], execution_context)

return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)
error_result = await _parse_and_validate_async(
execution_context,
process_errors,
extensions_runner,
allowed_operation_types,
)
if error_result is not None:
return error_result

async with extensions_runner.executing():
if not execution_context.result:
assert execution_context.graphql_document
result = original_execute(
schema,
execution_context.graphql_document,
Expand Down Expand Up @@ -186,44 +253,18 @@ def execute_sync(
if not execution_context.query:
raise MissingQueryError()

with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)

execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)
error_result = _parse_and_validate_sync(
execution_context,
process_errors,
extensions_runner,
allowed_operation_types,
)
if error_result is not None:
return error_result

with extensions_runner.executing():
if not execution_context.result:
assert execution_context.graphql_document
result = original_execute(
schema,
execution_context.graphql_document,
Expand Down Expand Up @@ -292,50 +333,17 @@ async def subscribe(
async with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
assert execution_context.query is not None

async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
raise SubscribeSingleResult(
ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)
execution_context.errors = [error]
process_errors([error], execution_context)

raise SubscribeSingleResult(
ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)
)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
raise SubscribeSingleResult(
ExecutionResult(data=None, errors=execution_context.errors)
)
error_result = await _parse_and_validate_async(
execution_context, process_errors, extensions_runner
)
if error_result is not None:
raise SubscribeSingleResult(error_result)

async with extensions_runner.executing():
# currently original_subscribe is an async function. A future release
# of graphql-core will make it optionally awaitable
assert execution_context.graphql_document
result: Union[AsyncIterable[GraphQLExecutionResult], GraphQLExecutionResult]
result_or_awaitable = original_subscribe(
schema,
Expand Down

0 comments on commit 78191f1

Please sign in to comment.