diff --git a/poetry.lock b/poetry.lock index 33a69bb97f..45bd964152 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3611,6 +3611,24 @@ typing-extensions = "*" [package.extras] dev = ["flake8", "flake8-docstrings", "mypy", "packaging", "pre-commit", "pytest", "pytest-cov", "types-setuptools"] +[[package]] +name = "rustberry" +version = "0.0.14" +description = "" +optional = false +python-versions = ">=3.11" +files = [ + {file = "rustberry-0.0.14-cp311-cp311-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:afcde10419392e4f1c9be5114afdec821917add92d12751765e1b1932a6f57ed"}, + {file = "rustberry-0.0.14-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:54545317941ca0c8c3566fbcc4ea7c3968459f7d259c27dce08e8313761d627d"}, + {file = "rustberry-0.0.14-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6c55b7e966698f4eed211bdbd6e978a8a0350886aecd48b1fe4adf49a9dd82d"}, + {file = "rustberry-0.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff9672deac947fcb47ba4f001329c75f93ec483b993c2a0737c8dd0a956ef12c"}, + {file = "rustberry-0.0.14-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1cda938e46cd429d6a2878aff06001610731dc86f1c8b64a05763b4a38fd4f40"}, + {file = "rustberry-0.0.14-cp311-none-win32.whl", hash = "sha256:e7b0b0c66373f032c7d53f1f9aa0f0451b5481e6fc691480bc4eefc06bc661df"}, + {file = "rustberry-0.0.14-cp311-none-win_amd64.whl", hash = "sha256:229782618feefedd98beaab12db45e5ef0ea7c5c6546d75bbff0c54fb45e33f7"}, + {file = "rustberry-0.0.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9272de768502c889eff7643f9468e97ace706d563e809d126822a546db10e986"}, + {file = "rustberry-0.0.14-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a4b6ef204a73459be06a5a2775948d0ae6a75139661cc3b6e9348c1878ec8b6"}, +] + [[package]] name = "sanic" version = "23.12.1" @@ -4840,4 +4858,4 @@ starlite = ["starlite"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "3adcc49a22af2f096570400d4a9e8687423f8fad1014a234e8e5668f6059ac31" +content-hash = "c78cbc096d5ae8d64936328fea42e3be94e724ca8a79e1399be98a2be9ee4ad2" diff --git a/pyproject.toml b/pyproject.toml index 01d01c5551..bb34537ca5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ libcst = {version = ">=0.4.7", optional = true} rich = {version = ">=12.0.0", optional = true} pyinstrument = {version = ">=4.0.0", optional = true} graphlib_backport = {version = "*", python = "<3.9", optional = true} - +rustberry = {version = ">=0.0.14", python=">=3.11"} [tool.poetry.group.dev.dependencies] asgiref = "^3.2" ddtrace = ">=1.6.4" diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index af0bd07a7f..acafec39b5 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -17,9 +17,8 @@ cast, ) -from graphql import GraphQLError, parse +from graphql import GraphQLError, parse, validate from graphql import execute as original_execute -from graphql.validation import validate from strawberry.exceptions import MissingQueryError from strawberry.extensions.runner import SchemaExtensionsRunner @@ -38,6 +37,7 @@ from strawberry.extensions import SchemaExtension from strawberry.types import ExecutionContext + from strawberry.types.execution import Executor, ParseOptions from strawberry.types.graphql import OperationType @@ -82,6 +82,7 @@ async def execute( execution_context: ExecutionContext, execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + executor: Executor, ) -> ExecutionResult: extensions_runner = SchemaExtensionsRunner( execution_context=execution_context, @@ -95,30 +96,28 @@ async def execute( if not execution_context.query: raise MissingQueryError() - async with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options - ) + async with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + executor.parse(execution_context) - except GraphQLError as exc: - execution_context.errors = [exc] - process_errors([exc], execution_context) - return ExecutionResult( - data=None, - errors=[exc], - extensions=await extensions_runner.get_extensions_results(), - ) + except GraphQLError as exc: + execution_context.errors = [exc] + process_errors([exc], execution_context) + return ExecutionResult( + data=None, + errors=[exc], + extensions=await extensions_runner.get_extensions_results(), + ) if execution_context.operation_type not in allowed_operation_types: raise InvalidOperationTypeError(execution_context.operation_type) - async with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) + async with extensions_runner.validation(): + executor.validate(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) async with extensions_runner.executing(): if not execution_context.result: @@ -180,6 +179,7 @@ def execute_sync( execution_context: ExecutionContext, execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + executor: Executor, ) -> ExecutionResult: extensions_runner = SchemaExtensionsRunner( execution_context=execution_context, @@ -193,30 +193,28 @@ def execute_sync( if not execution_context.query: raise MissingQueryError() - with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options - ) + with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + executor.parse(execution_context) - except GraphQLError as exc: - execution_context.errors = [exc] - process_errors([exc], execution_context) - return ExecutionResult( - data=None, - errors=[exc], - extensions=extensions_runner.get_extensions_results_sync(), - ) + except GraphQLError as exc: + execution_context.errors = [exc] + process_errors([exc], execution_context) + return ExecutionResult( + data=None, + errors=[exc], + extensions=extensions_runner.get_extensions_results_sync(), + ) if execution_context.operation_type not in allowed_operation_types: raise InvalidOperationTypeError(execution_context.operation_type) - with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) + with extensions_runner.validation(): + executor.validate(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) with extensions_runner.executing(): if not execution_context.result: diff --git a/strawberry/schema/executors.py b/strawberry/schema/executors.py new file mode 100644 index 0000000000..b9ac0aff0e --- /dev/null +++ b/strawberry/schema/executors.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from graphql import GraphQLError, parse +from rustberry import QueryCompiler + +from strawberry import Schema +from strawberry.types.execution import ExecutionContext, Executor + +if TYPE_CHECKING: + from rustberry._rustberry import Document + +RUSTBERRY_DOCUMENT_FIELD = "__rustberry_document" + + +class RustberryExecutor(Executor): + def __init__(self, schema: Schema) -> None: + super().__init__(schema) + self.compiler = QueryCompiler(schema.as_str()) + + def parse(self, execution_context: ExecutionContext) -> None: + document = self.compiler.parse(execution_context.query) + setattr(execution_context, RUSTBERRY_DOCUMENT_FIELD, document) + execution_context.graphql_document = self.compiler.gql_core_ast_mirror(document) + + def validate( + self, + execution_context: ExecutionContext, + ) -> None: + assert execution_context.graphql_document + document: Document = getattr(execution_context, RUSTBERRY_DOCUMENT_FIELD, None) + assert document, "Document not set - Required for Rustberry use" + validation_successful = self.compiler.validate(document) + if not validation_successful: + execution_context.errors = execution_context.errors or [] + execution_context.errors.append(GraphQLError("Validation failed")) + + +class RustberryExecutorV2(Executor): + def __init__(self, schema: Schema) -> None: + super().__init__(schema) + self.compiler = QueryCompiler(schema.as_str()) + + def parse(self, execution_context: ExecutionContext) -> None: + document = self.compiler.parse(execution_context.query) + setattr(execution_context, RUSTBERRY_DOCUMENT_FIELD, document) + execution_context.graphql_document = parse(execution_context.query) + + def validate( + self, + execution_context: ExecutionContext, + ) -> None: + assert execution_context.graphql_document + document: Document = getattr(execution_context, RUSTBERRY_DOCUMENT_FIELD, None) + assert document, "Document not set - Required for Rustberry use" + validation_successful = self.compiler.validate(document) + if not validation_successful: + execution_context.errors = execution_context.errors or [] + execution_context.errors.append(GraphQLError("Validation failed")) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index b43963d9b5..f1f2ea8600 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -40,6 +40,7 @@ from strawberry.types.types import StrawberryObjectDefinition from ..printer import print_schema +from ..types.execution import Executor, GraphQlCoreExecutor from . import compat from .base import BaseSchema from .config import StrawberryConfig @@ -82,6 +83,7 @@ def __init__( Dict[object, Union[Type, ScalarWrapper, ScalarDefinition]] ] = None, schema_directives: Iterable[object] = (), + executor_class: Optional[Type[Executor]] = None, ) -> None: self.query = query self.mutation = mutation @@ -176,6 +178,11 @@ def __init__( formatted_errors = "\n\n".join(f"❌ {error.message}" for error in errors) raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}") + if executor_class: + self.executor = executor_class(self) + else: + self.executor = GraphQlCoreExecutor(self) + def get_extensions( self, sync: bool = False ) -> List[Union[Type[SchemaExtension], SchemaExtension]]: @@ -267,6 +274,7 @@ async def execute( execution_context=execution_context, allowed_operation_types=allowed_operation_types, process_errors=self.process_errors, + executor=self.executor, ) return result @@ -299,6 +307,7 @@ def execute_sync( execution_context=execution_context, allowed_operation_types=allowed_operation_types, process_errors=self.process_errors, + executor=self.executor, ) return result diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 9dc7ff7ef3..adb77c612c 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import dataclasses from typing import ( TYPE_CHECKING, @@ -12,7 +13,7 @@ ) from typing_extensions import TypedDict -from graphql import specified_rules +from graphql import parse, specified_rules, validate from strawberry.utils.operation import get_first_operation, get_operation_type @@ -29,6 +30,45 @@ from .graphql import OperationType +class Executor(abc.ABC): + def __init__(self, schema: Schema) -> None: + self.schema = schema + + @abc.abstractmethod + def parse(self, execution_context: ExecutionContext) -> None: ... + + @abc.abstractmethod + def validate( + self, + execution_context: ExecutionContext, + ) -> None: ... + + +class GraphQlCoreExecutor(Executor): + def __init__(self, schema: Schema) -> None: + super().__init__(schema) + + def parse(self, execution_context: ExecutionContext) -> None: + execution_context.graphql_document = parse( + execution_context.query, **execution_context.parse_options + ) + + def validate( + self, + execution_context: ExecutionContext, + ) -> None: + if ( + len(execution_context.validation_rules) > 0 + and execution_context.errors is None + ): + assert execution_context.graphql_document + execution_context.errors = validate( + execution_context.schema._schema, + execution_context.graphql_document, + execution_context.validation_rules, + ) + + @dataclasses.dataclass class ExecutionContext: query: Optional[str] diff --git a/tests/benchmarks/api.py b/tests/benchmarks/api.py index 157b18b556..ac368cd5f6 100644 --- a/tests/benchmarks/api.py +++ b/tests/benchmarks/api.py @@ -2,6 +2,7 @@ import strawberry from strawberry.directive import DirectiveLocation +from strawberry.schema.executors import RustberryExecutor, RustberryExecutorV2 @strawberry.type @@ -79,7 +80,12 @@ def uppercase(value: str) -> str: return value.upper() -schema = strawberry.Schema(query=Query, subscription=Subscription) +schema = strawberry.Schema( + query=Query, subscription=Subscription, executor_class=RustberryExecutor +) schema_with_directives = strawberry.Schema( - query=Query, directives=[uppercase], subscription=Subscription + query=Query, + directives=[uppercase], + subscription=Subscription, + executor_class=RustberryExecutorV2, ) diff --git a/tests/benchmarks/schema.py b/tests/benchmarks/schema.py index fd7b306895..5624d27311 100644 --- a/tests/benchmarks/schema.py +++ b/tests/benchmarks/schema.py @@ -5,6 +5,7 @@ from typing_extensions import Annotated import strawberry +from strawberry.schema.executors import RustberryExecutor @strawberry.enum @@ -178,4 +179,4 @@ async def search( ] -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema(query=Query, executor_class=RustberryExecutor) diff --git a/tests/benchmarks/test_execute.py b/tests/benchmarks/test_execute.py index 865d49c700..db63a5ae7c 100644 --- a/tests/benchmarks/test_execute.py +++ b/tests/benchmarks/test_execute.py @@ -9,6 +9,7 @@ import strawberry from strawberry.scalars import ID +from strawberry.schema.executors import RustberryExecutorV2 @pytest.mark.benchmark @@ -54,7 +55,7 @@ def patrons(self) -> List[Patron]: for i in range(1000) ] - schema = strawberry.Schema(query=Query) + schema = strawberry.Schema(query=Query, executor_class=RustberryExecutorV2) query = """ query something{ @@ -92,7 +93,9 @@ class Item: class Query: items: List[Item] - schema = strawberry.Schema(query=Query, types=CONCRETE_TYPES) + schema = strawberry.Schema( + query=Query, types=CONCRETE_TYPES, executor_class=RustberryExecutorV2 + ) query = "query { items { id } }" def run():