From 069bedb814e2ea8224717da20a9688c7684b5d47 Mon Sep 17 00:00:00 2001 From: nir Date: Tue, 1 Oct 2024 09:52:14 +0300 Subject: [PATCH] migrate custom schema extensions to the new design --- strawberry/extensions/add_validation_rules.py | 8 ++-- strawberry/extensions/base_extension.py | 11 ++--- strawberry/extensions/disable_validation.py | 5 ++- strawberry/extensions/mask_errors.py | 5 ++- strawberry/extensions/max_tokens.py | 5 ++- strawberry/extensions/parser_cache.py | 5 +-- strawberry/extensions/pyinstrument.py | 7 ++- strawberry/extensions/tracing/apollo.py | 27 ++++++++---- strawberry/extensions/tracing/datadog.py | 35 +++++++++------ .../extensions/tracing/opentelemetry.py | 22 ++++++---- strawberry/extensions/tracing/sentry.py | 43 ++++++++++++------- strawberry/extensions/validation_cache.py | 5 +-- strawberry/schema/schema.py | 31 +++++++------ strawberry/types/execution.py | 3 ++ tests/schema/extensions/test_datadog.py | 4 +- tests/views/schema.py | 2 +- 16 files changed, 135 insertions(+), 83 deletions(-) diff --git a/strawberry/extensions/add_validation_rules.py b/strawberry/extensions/add_validation_rules.py index 763ef70b05..b0869f3885 100644 --- a/strawberry/extensions/add_validation_rules.py +++ b/strawberry/extensions/add_validation_rules.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: from graphql import ASTValidationRule + from strawberry.types.execution import ExecutionContext + class AddValidationRules(SchemaExtension): """Add graphql-core validation rules. @@ -42,9 +44,9 @@ def enter_field(self, node, *args) -> None: def __init__(self, validation_rules: List[Type[ASTValidationRule]]) -> None: self.validation_rules = validation_rules - def on_operation(self) -> Iterator[None]: - self.execution_context.validation_rules = ( - self.execution_context.validation_rules + tuple(self.validation_rules) + def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]: + execution_context.validation_rules = execution_context.validation_rules + tuple( + self.validation_rules ) yield diff --git a/strawberry/extensions/base_extension.py b/strawberry/extensions/base_extension.py index 3ba2fb6f0c..03a59cd85c 100644 --- a/strawberry/extensions/base_extension.py +++ b/strawberry/extensions/base_extension.py @@ -19,11 +19,12 @@ class LifecycleStep(Enum): class SchemaExtension: - # to support extensions that still use the old signature - # we have an optional argument here for ease of initialization. - def __init__( - self, *, execution_context: ExecutionContext | None = None - ) -> None: ... + if not TYPE_CHECKING: + # to support extensions that still use the old signature + # we have an optional argument here for ease of initialization. + def __init__( + self, *, execution_context: ExecutionContext | None = None + ) -> None: ... def on_operation( # type: ignore self, execution_context: ExecutionContext diff --git a/strawberry/extensions/disable_validation.py b/strawberry/extensions/disable_validation.py index cd9aeafaed..ee1fae4050 100644 --- a/strawberry/extensions/disable_validation.py +++ b/strawberry/extensions/disable_validation.py @@ -1,6 +1,7 @@ from typing import Iterator from strawberry.extensions.base_extension import SchemaExtension +from strawberry.types.execution import ExecutionContext class DisableValidation(SchemaExtension): @@ -26,8 +27,8 @@ def __init__(self) -> None: # some in the future pass - def on_operation(self) -> Iterator[None]: - self.execution_context.validation_rules = () # remove all validation_rules + def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]: + execution_context.validation_rules = () # remove all validation_rules yield diff --git a/strawberry/extensions/mask_errors.py b/strawberry/extensions/mask_errors.py index 5cb0ffbaa7..e7814c5581 100644 --- a/strawberry/extensions/mask_errors.py +++ b/strawberry/extensions/mask_errors.py @@ -3,6 +3,7 @@ from graphql.error import GraphQLError from strawberry.extensions.base_extension import SchemaExtension +from strawberry.types.execution import ExecutionContext def default_should_mask_error(_: GraphQLError) -> bool: @@ -32,9 +33,9 @@ def anonymise_error(self, error: GraphQLError) -> GraphQLError: original_error=None, ) - def on_operation(self) -> Iterator[None]: + def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]: yield - result = self.execution_context.result + result = execution_context.result if result and result.errors: processed_errors: List[GraphQLError] = [] for error in result.errors: diff --git a/strawberry/extensions/max_tokens.py b/strawberry/extensions/max_tokens.py index 60accd8a8a..3c0d0d243f 100644 --- a/strawberry/extensions/max_tokens.py +++ b/strawberry/extensions/max_tokens.py @@ -1,6 +1,7 @@ from typing import Iterator from strawberry.extensions.base_extension import SchemaExtension +from strawberry.types.execution import ExecutionContext class MaxTokensLimiter(SchemaExtension): @@ -34,8 +35,8 @@ def __init__( """ self.max_token_count = max_token_count - def on_operation(self) -> Iterator[None]: - self.execution_context.parse_options["max_tokens"] = self.max_token_count + def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]: + execution_context.parse_options["max_tokens"] = self.max_token_count yield diff --git a/strawberry/extensions/parser_cache.py b/strawberry/extensions/parser_cache.py index 39b28c039b..7c4b0fe8f4 100644 --- a/strawberry/extensions/parser_cache.py +++ b/strawberry/extensions/parser_cache.py @@ -3,6 +3,7 @@ from strawberry.extensions.base_extension import SchemaExtension from strawberry.schema.execute import parse_document +from strawberry.types.execution import ExecutionContext class ParserCache(SchemaExtension): @@ -33,9 +34,7 @@ def __init__(self, maxsize: Optional[int] = None) -> None: """ self.cached_parse_document = lru_cache(maxsize=maxsize)(parse_document) - def on_parse(self) -> Iterator[None]: - execution_context = self.execution_context - + def on_parse(self, execution_context: ExecutionContext) -> Iterator[None]: execution_context.graphql_document = self.cached_parse_document( execution_context.query, **execution_context.parse_options ) diff --git a/strawberry/extensions/pyinstrument.py b/strawberry/extensions/pyinstrument.py index 53dd9fe66a..b34fd4b7bf 100644 --- a/strawberry/extensions/pyinstrument.py +++ b/strawberry/extensions/pyinstrument.py @@ -1,12 +1,15 @@ from __future__ import annotations from pathlib import Path -from typing import Iterator +from typing import TYPE_CHECKING, Iterator from pyinstrument import Profiler from strawberry.extensions.base_extension import SchemaExtension +if TYPE_CHECKING: + from strawberry.types.execution import ExecutionContext + class PyInstrument(SchemaExtension): """Extension to profile the execution time of resolvers using PyInstrument.""" @@ -17,7 +20,7 @@ def __init__( ) -> None: self._report_path = report_path - def on_operation(self) -> Iterator[None]: + def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]: profiler = Profiler() profiler.start() diff --git a/strawberry/extensions/tracing/apollo.py b/strawberry/extensions/tracing/apollo.py index 2245f54643..5409d98884 100644 --- a/strawberry/extensions/tracing/apollo.py +++ b/strawberry/extensions/tracing/apollo.py @@ -14,6 +14,8 @@ if TYPE_CHECKING: from graphql import GraphQLResolveInfo + from strawberry.types.execution import ExecutionContext + DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" if TYPE_CHECKING: @@ -80,23 +82,30 @@ def to_json(self) -> Dict[str, Any]: class ApolloTracingExtension(SchemaExtension): - def __init__(self, execution_context: ExecutionContext) -> None: - self._resolver_stats: List[ApolloResolverStats] = [] - self.execution_context = execution_context - - def on_operation(self) -> Generator[None, None, None]: + def __init__(self) -> None: + self._resolver_stats: List[ + ApolloResolverStats + ] = [] # TODO: this is probably a bug just as using self.execution_context was a bug. + + def on_operation( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self.start_timestamp = self.now() self.start_time = datetime.now(timezone.utc) yield self.end_timestamp = self.now() self.end_time = datetime.now(timezone.utc) - def on_parse(self) -> Generator[None, None, None]: + def on_parse( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self._start_parsing = self.now() yield self._end_parsing = self.now() - def on_validate(self) -> Generator[None, None, None]: + def on_validate( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self._start_validation = self.now() yield self._end_validation = self.now() @@ -121,7 +130,9 @@ def stats(self) -> ApolloTracingStats: ), ) - def get_results(self) -> Dict[str, Dict[str, Any]]: + def get_results( + self, execution_context: ExecutionContext + ) -> Dict[str, Dict[str, Any]]: return {"tracing": self.stats.to_json()} async def resolve( diff --git a/strawberry/extensions/tracing/datadog.py b/strawberry/extensions/tracing/datadog.py index 2b8c676ca2..229823fa87 100644 --- a/strawberry/extensions/tracing/datadog.py +++ b/strawberry/extensions/tracing/datadog.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib -from functools import cached_property +from functools import lru_cache from inspect import isawaitable from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional @@ -25,15 +25,15 @@ def __init__( if execution_context: self.execution_context = execution_context - @cached_property - def _resource_name(self) -> str: - if self.execution_context.query is None: + @lru_cache + def _resource_name(self, execution_context: ExecutionContext) -> str: + if execution_context.query is None: return "query_missing" - query_hash = self.hash_query(self.execution_context.query) + query_hash = self.hash_query(execution_context.query) - if self.execution_context.operation_name: - return f"{self.execution_context.operation_name}:{query_hash}" + if execution_context.operation_name: + return f"{execution_context.operation_name}:{query_hash}" return query_hash @@ -54,7 +54,7 @@ class CustomExtension(DatadogTracingExtension): def create_span(self, lifecycle_step, name, **kwargs): span = super().create_span(lifecycle_step, name, **kwargs) if lifecycle_step == LifeCycleStep.OPERATION: - span.set_tag("graphql.query", self.execution_context.query) + span.set_tag("graphql.query", execution_context.query) return span ``` """ @@ -67,8 +67,8 @@ def create_span(self, lifecycle_step, name, **kwargs): def hash_query(self, query: str) -> str: return hashlib.md5(query.encode("utf-8")).hexdigest() - def on_operation(self) -> Iterator[None]: - self._operation_name = self.execution_context.operation_name + def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]: + self._operation_name = execution_context.operation_name span_name = ( f"{self._operation_name}" if self._operation_name else "Anonymous Query" ) @@ -76,12 +76,12 @@ def on_operation(self) -> Iterator[None]: self.request_span = self.create_span( LifecycleStep.OPERATION, span_name, - resource=self._resource_name, + resource=self._resource_name(execution_context), service="strawberry", ) self.request_span.set_tag("graphql.operation_name", self._operation_name) - query = self.execution_context.query + query = execution_context.query if query is not None: query = query.strip() @@ -100,7 +100,9 @@ def on_operation(self) -> Iterator[None]: self.request_span.finish() - def on_validate(self) -> Generator[None, None, None]: + def on_validate( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self.validation_span = self.create_span( lifecycle_step=LifecycleStep.VALIDATION, name="Validation", @@ -108,7 +110,9 @@ def on_validate(self) -> Generator[None, None, None]: yield self.validation_span.finish() - def on_parse(self) -> Generator[None, None, None]: + def on_parse( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self.parsing_span = self.create_span( lifecycle_step=LifecycleStep.PARSE, name="Parsing", @@ -150,6 +154,9 @@ async def resolve( return result + def __hash__(self) -> int: + return id(self) + class DatadogTracingExtensionSync(DatadogTracingExtension): def resolve( diff --git a/strawberry/extensions/tracing/opentelemetry.py b/strawberry/extensions/tracing/opentelemetry.py index 285c2e5f92..c8eb2fb4e0 100644 --- a/strawberry/extensions/tracing/opentelemetry.py +++ b/strawberry/extensions/tracing/opentelemetry.py @@ -51,8 +51,10 @@ def __init__( if execution_context: self.execution_context = execution_context - def on_operation(self) -> Generator[None, None, None]: - self._operation_name = self.execution_context.operation_name + def on_operation( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: + self._operation_name = execution_context.operation_name span_name = ( f"GraphQL Query: {self._operation_name}" if self._operation_name @@ -64,9 +66,9 @@ def on_operation(self) -> Generator[None, None, None]: ) self._span_holder[LifecycleStep.OPERATION].set_attribute("component", "graphql") - if self.execution_context.query: + if execution_context.query: self._span_holder[LifecycleStep.OPERATION].set_attribute( - "query", self.execution_context.query + "query", execution_context.query ) yield @@ -75,12 +77,14 @@ def on_operation(self) -> Generator[None, None, None]: # operation but we don't know until the parsing stage has finished. If # that's the case we want to update the span name so that we have a more # useful name in our trace. - if not self._operation_name and self.execution_context.operation_name: - span_name = f"GraphQL Query: {self.execution_context.operation_name}" + if not self._operation_name and execution_context.operation_name: + span_name = f"GraphQL Query: {execution_context.operation_name}" self._span_holder[LifecycleStep.OPERATION].update_name(span_name) self._span_holder[LifecycleStep.OPERATION].end() - def on_validate(self) -> Generator[None, None, None]: + def on_validate( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: ctx = trace.set_span_in_context(self._span_holder[LifecycleStep.OPERATION]) self._span_holder[LifecycleStep.VALIDATION] = self._tracer.start_span( "GraphQL Validation", @@ -89,7 +93,9 @@ def on_validate(self) -> Generator[None, None, None]: yield self._span_holder[LifecycleStep.VALIDATION].end() - def on_parse(self) -> Generator[None, None, None]: + def on_parse( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: ctx = trace.set_span_in_context(self._span_holder[LifecycleStep.OPERATION]) self._span_holder[LifecycleStep.PARSE] = self._tracer.start_span( "GraphQL Parsing", context=ctx diff --git a/strawberry/extensions/tracing/sentry.py b/strawberry/extensions/tracing/sentry.py index 7a0c6188b4..10dc19d38c 100644 --- a/strawberry/extensions/tracing/sentry.py +++ b/strawberry/extensions/tracing/sentry.py @@ -2,7 +2,7 @@ import hashlib import warnings -from functools import cached_property +from functools import lru_cache from inspect import isawaitable from typing import TYPE_CHECKING, Any, Callable, Generator, Optional @@ -32,22 +32,24 @@ def __init__( if execution_context: self.execution_context = execution_context - @cached_property - def _resource_name(self) -> str: - assert self.execution_context.query + @lru_cache + def _resource_name(self, execution_context: ExecutionContext) -> str: + assert execution_context.query - query_hash = self.hash_query(self.execution_context.query) + query_hash = self.hash_query(execution_context.query) - if self.execution_context.operation_name: - return f"{self.execution_context.operation_name}:{query_hash}" + if execution_context.operation_name: + return f"{execution_context.operation_name}:{query_hash}" return query_hash def hash_query(self, query: str) -> str: return hashlib.md5(query.encode("utf-8")).hexdigest() - def on_operation(self) -> Generator[None, None, None]: - self._operation_name = self.execution_context.operation_name + def on_operation( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: + self._operation_name = execution_context.operation_name name = f"{self._operation_name}" if self._operation_name else "Anonymous Query" with configure_scope() as scope: @@ -63,22 +65,26 @@ def on_operation(self) -> Generator[None, None, None]: operation_type = "query" - assert self.execution_context.query + assert execution_context.query - if self.execution_context.query.strip().startswith("mutation"): + if execution_context.query.strip().startswith("mutation"): operation_type = "mutation" - if self.execution_context.query.strip().startswith("subscription"): + if execution_context.query.strip().startswith("subscription"): operation_type = "subscription" self.gql_span.set_tag("graphql.operation_type", operation_type) - self.gql_span.set_tag("graphql.resource_name", self._resource_name) - self.gql_span.set_data("graphql.query", self.execution_context.query) + self.gql_span.set_tag( + "graphql.resource_name", self._resource_name(execution_context) + ) + self.gql_span.set_data("graphql.query", execution_context.query) yield self.gql_span.finish() - def on_validate(self) -> Generator[None, None, None]: + def on_validate( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self.validation_span = self.gql_span.start_child( op="validation", description="Validation" ) @@ -87,7 +93,9 @@ def on_validate(self) -> Generator[None, None, None]: self.validation_span.finish() - def on_parse(self) -> Generator[None, None, None]: + def on_parse( + self, execution_context: ExecutionContext + ) -> Generator[None, None, None]: self.parsing_span = self.gql_span.start_child( op="parsing", description="Parsing" ) @@ -132,6 +140,9 @@ async def resolve( return result + def __hash__(self) -> int: + return id(self) + class SentryTracingExtensionSync(SentryTracingExtension): def resolve( diff --git a/strawberry/extensions/validation_cache.py b/strawberry/extensions/validation_cache.py index 6c9cc153c4..af3fb28dbc 100644 --- a/strawberry/extensions/validation_cache.py +++ b/strawberry/extensions/validation_cache.py @@ -3,6 +3,7 @@ from strawberry.extensions.base_extension import SchemaExtension from strawberry.schema.execute import validate_document +from strawberry.types.execution import ExecutionContext class ValidationCache(SchemaExtension): @@ -33,9 +34,7 @@ def __init__(self, maxsize: Optional[int] = None) -> None: """ self.cached_validate_document = lru_cache(maxsize=maxsize)(validate_document) - def on_validate(self) -> Iterator[None]: - execution_context = self.execution_context - + def on_validate(self, execution_context: ExecutionContext) -> Iterator[None]: errors = self.cached_validate_document( execution_context.schema._schema, execution_context.graphql_document, diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 27ae063201..a50c122e27 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import warnings from functools import cached_property, lru_cache from typing import ( @@ -214,17 +215,21 @@ class Query: raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}") def get_extensions(self, sync: bool = False) -> List[SchemaExtension]: - extensions = [] + ret: List[SchemaExtension] = [] if self.directives: - extensions = [ - *self.extensions, - DirectivesExtensionSync if sync else DirectivesExtension, - ] - extensions.extend(self.extensions) - return [ - ext if isinstance(ext, SchemaExtension) else ext(execution_context=None) - for ext in extensions - ] + ret.append( + DirectivesExtensionSync() if sync else DirectivesExtension(), + ) + + for extension in self.extensions: + if isinstance(extension, SchemaExtension): + ret.append(extension) + + elif inspect.signature(extension).parameters.get("execution_context"): + ret.append(extension(execution_context=None)) # type: ignore + ret.append(extension()) + + return ret @cached_property def _sync_extensions(self) -> List[SchemaExtension]: @@ -237,13 +242,13 @@ def _async_extensions(self) -> List[SchemaExtension]: @cached_property def sync_extension_runner(self) -> SchemaExtensionsRunner: return SchemaExtensionsRunner( - extensions=self.get_extensions(sync=True), + extensions=self._sync_extensions, ) @cached_property def async_extension_runner(self) -> SchemaExtensionsRunner: return SchemaExtensionsRunner( - extensions=self.get_extensions(sync=False), + extensions=self._async_extensions, ) def _get_middleware_manager( @@ -347,7 +352,7 @@ async def execute( root_value=root_value, operation_name=operation_name, ) - extensions = self.get_extensions() + extensions = self._async_extensions return await execute( self._schema, execution_context=execution_context, diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 94f3982a6f..5007abb60c 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -88,6 +88,9 @@ def _get_first_operation(self) -> Optional[OperationDefinitionNode]: return get_first_operation(graphql_document) + def __hash__(self) -> int: + return id(self) + @dataclasses.dataclass class ExecutionResult: diff --git a/tests/schema/extensions/test_datadog.py b/tests/schema/extensions/test_datadog.py index 0af13877a7..cb42c59f1e 100644 --- a/tests/schema/extensions/test_datadog.py +++ b/tests/schema/extensions/test_datadog.py @@ -4,6 +4,7 @@ import pytest import strawberry +from strawberry.types.execution import ExecutionContext if typing.TYPE_CHECKING: from strawberry.extensions.tracing.datadog import DatadogTracingExtension @@ -263,13 +264,14 @@ async def test_create_span_override(datadog_extension): class CustomExtension(extension): def create_span( self, + execution_context: ExecutionContext, lifecycle_step: LifecycleStep, name: str, **kwargs, # noqa: ANN003 ): span = super().create_span(lifecycle_step, name, **kwargs) if lifecycle_step == LifecycleStep.OPERATION: - span.set_tag("graphql.query", self.execution_context.query) + span.set_tag("graphql.query", execution_context.query) return span schema = strawberry.Schema( diff --git a/tests/views/schema.py b/tests/views/schema.py index b0c14bfd76..56af800f00 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -21,7 +21,7 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b class MyExtension(SchemaExtension): - def get_results(self) -> Dict[str, str]: + def get_results(self, execution_context: ExecutionContext) -> Dict[str, str]: return {"example": "example"}