|
1 |
| -from typing import get_type_hints |
| 1 | +import typing |
2 | 2 |
|
3 | 3 | import dataclasses
|
4 | 4 | from graphql import GraphQLField
|
5 | 5 |
|
6 | 6 | from .constants import IS_STRAWBERRY_FIELD, IS_STRAWBERRY_INPUT
|
7 | 7 | from .exceptions import MissingArgumentsAnnotationsError, MissingReturnAnnotationError
|
8 |
| -from .type_converter import get_graphql_type_for_annotation |
| 8 | +from .type_converter import REGISTRY, get_graphql_type_for_annotation |
9 | 9 | from .utils.dict_to_type import dict_to_type
|
10 | 10 | from .utils.inspect import get_func_args
|
| 11 | +from .utils.lazy_property import lazy_property |
11 | 12 | from .utils.str_converters import to_camel_case, to_snake_case
|
12 | 13 | from .utils.typing import (
|
13 | 14 | get_list_annotation,
|
|
17 | 18 | )
|
18 | 19 |
|
19 | 20 |
|
| 21 | +class LazyFieldWrapper: |
| 22 | + """A lazy wrapper for a strawberry field. |
| 23 | + This allows to use cyclic dependencies in a strawberry fields: |
| 24 | +
|
| 25 | + >>> @strawberry.type |
| 26 | + >>> class TypeA: |
| 27 | + >>> @strawberry.field |
| 28 | + >>> def type_b(self, info) -> "TypeB": |
| 29 | + >>> from .type_b import TypeB |
| 30 | + >>> return TypeB() |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, obj, is_subscription, **kwargs): |
| 34 | + self._wrapped_obj = obj |
| 35 | + self.is_subscription = is_subscription |
| 36 | + self.kwargs = kwargs |
| 37 | + |
| 38 | + if callable(self._wrapped_obj): |
| 39 | + self._check_has_annotations() |
| 40 | + |
| 41 | + def _check_has_annotations(self): |
| 42 | + # using annotations without passing from typing.get_type_hints |
| 43 | + # as we don't the actually types for the annotations |
| 44 | + annotations = self._wrapped_obj.__annotations__ |
| 45 | + name = self._wrapped_obj.__name__ |
| 46 | + |
| 47 | + if "return" not in annotations: |
| 48 | + raise MissingReturnAnnotationError(name) |
| 49 | + |
| 50 | + function_arguments = set(get_func_args(self._wrapped_obj)) - {"self", "info"} |
| 51 | + |
| 52 | + arguments_annotations = { |
| 53 | + key: value |
| 54 | + for key, value in annotations.items() |
| 55 | + if key not in ["info", "return"] |
| 56 | + } |
| 57 | + |
| 58 | + annotated_function_arguments = set(arguments_annotations.keys()) |
| 59 | + arguments_missing_annotations = ( |
| 60 | + function_arguments - annotated_function_arguments |
| 61 | + ) |
| 62 | + |
| 63 | + if len(arguments_missing_annotations) > 0: |
| 64 | + raise MissingArgumentsAnnotationsError(name, arguments_missing_annotations) |
| 65 | + |
| 66 | + def __getattr__(self, attr): |
| 67 | + if attr in self.__dict__: |
| 68 | + return getattr(self, attr) |
| 69 | + |
| 70 | + return getattr(self._wrapped_obj, attr) |
| 71 | + |
| 72 | + def __call__(self, *args, **kwargs): |
| 73 | + return self._wrapped_obj(self, *args, **kwargs) |
| 74 | + |
| 75 | + @lazy_property |
| 76 | + def field(self): |
| 77 | + return _get_field( |
| 78 | + self._wrapped_obj, is_subscription=self.is_subscription, **self.kwargs |
| 79 | + ) |
| 80 | + |
| 81 | + |
20 | 82 | class strawberry_field:
|
21 | 83 | """A small wrapper for a field in strawberry.
|
22 | 84 |
|
@@ -51,10 +113,7 @@ def __call__(self, wrap):
|
51 | 113 |
|
52 | 114 | self.kwargs["description"] = self.description or wrap.__doc__
|
53 | 115 |
|
54 |
| - wrap.field = _get_field( |
55 |
| - wrap, is_subscription=self.is_subscription, **self.kwargs |
56 |
| - ) |
57 |
| - return wrap |
| 116 | + return LazyFieldWrapper(wrap, self.is_subscription, **self.kwargs) |
58 | 117 |
|
59 | 118 |
|
60 | 119 | def convert_args(args, annotations):
|
@@ -91,29 +150,18 @@ def convert_args(args, annotations):
|
91 | 150 |
|
92 | 151 |
|
93 | 152 | def _get_field(wrap, *, is_subscription=False, **kwargs):
|
94 |
| - annotations = get_type_hints(wrap) |
| 153 | + annotations = typing.get_type_hints(wrap, None, REGISTRY) |
95 | 154 |
|
96 | 155 | name = wrap.__name__
|
97 | 156 |
|
98 |
| - if "return" not in annotations: |
99 |
| - raise MissingReturnAnnotationError(name) |
100 |
| - |
101 | 157 | field_type = get_graphql_type_for_annotation(annotations["return"], name)
|
102 | 158 |
|
103 |
| - function_arguments = set(get_func_args(wrap)) - {"self", "info"} |
104 |
| - |
105 | 159 | arguments_annotations = {
|
106 | 160 | key: value
|
107 | 161 | for key, value in annotations.items()
|
108 | 162 | if key not in ["info", "return"]
|
109 | 163 | }
|
110 | 164 |
|
111 |
| - annotated_function_arguments = set(arguments_annotations.keys()) |
112 |
| - arguments_missing_annotations = function_arguments - annotated_function_arguments |
113 |
| - |
114 |
| - if len(arguments_missing_annotations) > 0: |
115 |
| - raise MissingArgumentsAnnotationsError(name, arguments_missing_annotations) |
116 |
| - |
117 | 165 | arguments = {
|
118 | 166 | to_camel_case(name): get_graphql_type_for_annotation(annotation, name)
|
119 | 167 | for name, annotation in arguments_annotations.items()
|
|
0 commit comments