Skip to content

Commit

Permalink
migrate custom schema extensions to the new design
Browse files Browse the repository at this point in the history
  • Loading branch information
nrbnlulu committed Oct 1, 2024
1 parent e60f1ef commit 069bedb
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 83 deletions.
8 changes: 5 additions & 3 deletions strawberry/extensions/add_validation_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
if TYPE_CHECKING:
from graphql import ASTValidationRule

from strawberry.types.execution import ExecutionContext


class AddValidationRules(SchemaExtension):
"""Add graphql-core validation rules.
Expand Down Expand Up @@ -42,9 +44,9 @@ def enter_field(self, node, *args) -> None:
def __init__(self, validation_rules: List[Type[ASTValidationRule]]) -> None:
self.validation_rules = validation_rules

def on_operation(self) -> Iterator[None]:
self.execution_context.validation_rules = (
self.execution_context.validation_rules + tuple(self.validation_rules)
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.validation_rules = execution_context.validation_rules + tuple(

Check warning on line 48 in strawberry/extensions/add_validation_rules.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/add_validation_rules.py#L48

Added line #L48 was not covered by tests
self.validation_rules
)
yield

Expand Down
11 changes: 6 additions & 5 deletions strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ class LifecycleStep(Enum):


class SchemaExtension:
# to support extensions that still use the old signature
# we have an optional argument here for ease of initialization.
def __init__(
self, *, execution_context: ExecutionContext | None = None
) -> None: ...
if not TYPE_CHECKING:
# to support extensions that still use the old signature
# we have an optional argument here for ease of initialization.
def __init__(
self, *, execution_context: ExecutionContext | None = None
) -> None: ...

def on_operation( # type: ignore
self, execution_context: ExecutionContext
Expand Down
5 changes: 3 additions & 2 deletions strawberry/extensions/disable_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Iterator

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionContext


class DisableValidation(SchemaExtension):
Expand All @@ -26,8 +27,8 @@ def __init__(self) -> None:
# some in the future
pass

def on_operation(self) -> Iterator[None]:
self.execution_context.validation_rules = () # remove all validation_rules
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.validation_rules = () # remove all validation_rules

Check warning on line 31 in strawberry/extensions/disable_validation.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/disable_validation.py#L31

Added line #L31 was not covered by tests
yield


Expand Down
5 changes: 3 additions & 2 deletions strawberry/extensions/mask_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from graphql.error import GraphQLError

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionContext


def default_should_mask_error(_: GraphQLError) -> bool:
Expand Down Expand Up @@ -32,9 +33,9 @@ def anonymise_error(self, error: GraphQLError) -> GraphQLError:
original_error=None,
)

def on_operation(self) -> Iterator[None]:
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
yield
result = self.execution_context.result
result = execution_context.result
if result and result.errors:

Check warning on line 39 in strawberry/extensions/mask_errors.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/mask_errors.py#L38-L39

Added lines #L38 - L39 were not covered by tests
processed_errors: List[GraphQLError] = []
for error in result.errors:
Expand Down
5 changes: 3 additions & 2 deletions strawberry/extensions/max_tokens.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Iterator

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionContext


class MaxTokensLimiter(SchemaExtension):
Expand Down Expand Up @@ -34,8 +35,8 @@ def __init__(
"""
self.max_token_count = max_token_count

def on_operation(self) -> Iterator[None]:
self.execution_context.parse_options["max_tokens"] = self.max_token_count
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.parse_options["max_tokens"] = self.max_token_count
yield


Expand Down
5 changes: 2 additions & 3 deletions strawberry/extensions/parser_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.schema.execute import parse_document
from strawberry.types.execution import ExecutionContext


class ParserCache(SchemaExtension):
Expand Down Expand Up @@ -33,9 +34,7 @@ def __init__(self, maxsize: Optional[int] = None) -> None:
"""
self.cached_parse_document = lru_cache(maxsize=maxsize)(parse_document)

def on_parse(self) -> Iterator[None]:
execution_context = self.execution_context

def on_parse(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.graphql_document = self.cached_parse_document(
execution_context.query, **execution_context.parse_options
)
Expand Down
7 changes: 5 additions & 2 deletions strawberry/extensions/pyinstrument.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from pathlib import Path
from typing import Iterator
from typing import TYPE_CHECKING, Iterator

from pyinstrument import Profiler

from strawberry.extensions.base_extension import SchemaExtension

if TYPE_CHECKING:
from strawberry.types.execution import ExecutionContext


class PyInstrument(SchemaExtension):
"""Extension to profile the execution time of resolvers using PyInstrument."""
Expand All @@ -17,7 +20,7 @@ def __init__(
) -> None:
self._report_path = report_path

def on_operation(self) -> Iterator[None]:
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
profiler = Profiler()
profiler.start()

Expand Down
27 changes: 19 additions & 8 deletions strawberry/extensions/tracing/apollo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if TYPE_CHECKING:
from graphql import GraphQLResolveInfo

from strawberry.types.execution import ExecutionContext

DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,23 +82,30 @@ def to_json(self) -> Dict[str, Any]:


class ApolloTracingExtension(SchemaExtension):
def __init__(self, execution_context: ExecutionContext) -> None:
self._resolver_stats: List[ApolloResolverStats] = []
self.execution_context = execution_context

def on_operation(self) -> Generator[None, None, None]:
def __init__(self) -> None:
self._resolver_stats: List[
ApolloResolverStats
] = [] # TODO: this is probably a bug just as using self.execution_context was a bug.

def on_operation(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
self.start_timestamp = self.now()
self.start_time = datetime.now(timezone.utc)
yield
self.end_timestamp = self.now()
self.end_time = datetime.now(timezone.utc)

def on_parse(self) -> Generator[None, None, None]:
def on_parse(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
self._start_parsing = self.now()
yield
self._end_parsing = self.now()

def on_validate(self) -> Generator[None, None, None]:
def on_validate(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
self._start_validation = self.now()
yield
self._end_validation = self.now()
Expand All @@ -121,7 +130,9 @@ def stats(self) -> ApolloTracingStats:
),
)

def get_results(self) -> Dict[str, Dict[str, Any]]:
def get_results(
self, execution_context: ExecutionContext
) -> Dict[str, Dict[str, Any]]:
return {"tracing": self.stats.to_json()}

async def resolve(
Expand Down
35 changes: 21 additions & 14 deletions strawberry/extensions/tracing/datadog.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import hashlib
from functools import cached_property
from functools import lru_cache
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional

Expand All @@ -25,15 +25,15 @@ def __init__(
if execution_context:
self.execution_context = execution_context

@cached_property
def _resource_name(self) -> str:
if self.execution_context.query is None:
@lru_cache
def _resource_name(self, execution_context: ExecutionContext) -> str:
if execution_context.query is None:
return "query_missing"

query_hash = self.hash_query(self.execution_context.query)
query_hash = self.hash_query(execution_context.query)

if self.execution_context.operation_name:
return f"{self.execution_context.operation_name}:{query_hash}"
if execution_context.operation_name:
return f"{execution_context.operation_name}:{query_hash}"

return query_hash

Expand All @@ -54,7 +54,7 @@ class CustomExtension(DatadogTracingExtension):
def create_span(self, lifecycle_step, name, **kwargs):
span = super().create_span(lifecycle_step, name, **kwargs)
if lifecycle_step == LifeCycleStep.OPERATION:
span.set_tag("graphql.query", self.execution_context.query)
span.set_tag("graphql.query", execution_context.query)
return span
```
"""
Expand All @@ -67,21 +67,21 @@ def create_span(self, lifecycle_step, name, **kwargs):
def hash_query(self, query: str) -> str:
return hashlib.md5(query.encode("utf-8")).hexdigest()

def on_operation(self) -> Iterator[None]:
self._operation_name = self.execution_context.operation_name
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
self._operation_name = execution_context.operation_name
span_name = (
f"{self._operation_name}" if self._operation_name else "Anonymous Query"
)

self.request_span = self.create_span(
LifecycleStep.OPERATION,
span_name,
resource=self._resource_name,
resource=self._resource_name(execution_context),
service="strawberry",
)
self.request_span.set_tag("graphql.operation_name", self._operation_name)

query = self.execution_context.query
query = execution_context.query

if query is not None:
query = query.strip()
Expand All @@ -100,15 +100,19 @@ def on_operation(self) -> Iterator[None]:

self.request_span.finish()

def on_validate(self) -> Generator[None, None, None]:
def on_validate(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
self.validation_span = self.create_span(
lifecycle_step=LifecycleStep.VALIDATION,
name="Validation",
)
yield
self.validation_span.finish()

def on_parse(self) -> Generator[None, None, None]:
def on_parse(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
self.parsing_span = self.create_span(
lifecycle_step=LifecycleStep.PARSE,
name="Parsing",
Expand Down Expand Up @@ -150,6 +154,9 @@ async def resolve(

return result

def __hash__(self) -> int:
return id(self)


class DatadogTracingExtensionSync(DatadogTracingExtension):
def resolve(
Expand Down
22 changes: 14 additions & 8 deletions strawberry/extensions/tracing/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def __init__(
if execution_context:
self.execution_context = execution_context

def on_operation(self) -> Generator[None, None, None]:
self._operation_name = self.execution_context.operation_name
def on_operation(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
self._operation_name = execution_context.operation_name
span_name = (
f"GraphQL Query: {self._operation_name}"
if self._operation_name
Expand All @@ -64,9 +66,9 @@ def on_operation(self) -> Generator[None, None, None]:
)
self._span_holder[LifecycleStep.OPERATION].set_attribute("component", "graphql")

if self.execution_context.query:
if execution_context.query:
self._span_holder[LifecycleStep.OPERATION].set_attribute(
"query", self.execution_context.query
"query", execution_context.query
)

yield
Expand All @@ -75,12 +77,14 @@ def on_operation(self) -> Generator[None, None, None]:
# operation but we don't know until the parsing stage has finished. If
# that's the case we want to update the span name so that we have a more
# useful name in our trace.
if not self._operation_name and self.execution_context.operation_name:
span_name = f"GraphQL Query: {self.execution_context.operation_name}"
if not self._operation_name and execution_context.operation_name:
span_name = f"GraphQL Query: {execution_context.operation_name}"
self._span_holder[LifecycleStep.OPERATION].update_name(span_name)
self._span_holder[LifecycleStep.OPERATION].end()

def on_validate(self) -> Generator[None, None, None]:
def on_validate(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
ctx = trace.set_span_in_context(self._span_holder[LifecycleStep.OPERATION])
self._span_holder[LifecycleStep.VALIDATION] = self._tracer.start_span(
"GraphQL Validation",
Expand All @@ -89,7 +93,9 @@ def on_validate(self) -> Generator[None, None, None]:
yield
self._span_holder[LifecycleStep.VALIDATION].end()

def on_parse(self) -> Generator[None, None, None]:
def on_parse(
self, execution_context: ExecutionContext
) -> Generator[None, None, None]:
ctx = trace.set_span_in_context(self._span_holder[LifecycleStep.OPERATION])
self._span_holder[LifecycleStep.PARSE] = self._tracer.start_span(
"GraphQL Parsing", context=ctx
Expand Down
Loading

0 comments on commit 069bedb

Please sign in to comment.