Skip to content

Commit 48cbf9d

Browse files
authored
Fix broken circular imports (#50)
Fix broken circular imports
2 parents b4c342a + 9427dff commit 48cbf9d

File tree

8 files changed

+161
-19
lines changed

8 files changed

+161
-19
lines changed

.isort.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ known_pytest=pytest
1111
known_future_library=future
1212
known_standard_library=types,requests
1313
default_section=THIRDPARTY
14-
known_third_party = click,graphql,hupper,pygments,starlette,uvicorn
14+
known_third_party = base,click,github_release,graphql,hupper,pygments,starlette,uvicorn
1515
sections=FUTURE,STDLIB,PYTEST,DJANGO,THIRDPARTY,FIRSTPARTY,LOCALFOLDER

RELEASE.md

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Release type: minor
2+
3+
This changes field to be lazy by default, allowing to use circular dependencies
4+
when declaring types.

strawberry/field.py

+66-18
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from typing import get_type_hints
1+
import typing
22

33
import dataclasses
44
from graphql import GraphQLField
55

66
from .constants import IS_STRAWBERRY_FIELD, IS_STRAWBERRY_INPUT
77
from .exceptions import MissingArgumentsAnnotationsError, MissingReturnAnnotationError
8-
from .type_converter import get_graphql_type_for_annotation
8+
from .type_converter import REGISTRY, get_graphql_type_for_annotation
99
from .utils.dict_to_type import dict_to_type
1010
from .utils.inspect import get_func_args
11+
from .utils.lazy_property import lazy_property
1112
from .utils.str_converters import to_camel_case, to_snake_case
1213
from .utils.typing import (
1314
get_list_annotation,
@@ -17,6 +18,67 @@
1718
)
1819

1920

21+
class LazyFieldWrapper:
22+
"""A lazy wrapper for a strawberry field.
23+
This allows to use cyclic dependencies in a strawberry fields:
24+
25+
>>> @strawberry.type
26+
>>> class TypeA:
27+
>>> @strawberry.field
28+
>>> def type_b(self, info) -> "TypeB":
29+
>>> from .type_b import TypeB
30+
>>> return TypeB()
31+
"""
32+
33+
def __init__(self, obj, is_subscription, **kwargs):
34+
self._wrapped_obj = obj
35+
self.is_subscription = is_subscription
36+
self.kwargs = kwargs
37+
38+
if callable(self._wrapped_obj):
39+
self._check_has_annotations()
40+
41+
def _check_has_annotations(self):
42+
# using annotations without passing from typing.get_type_hints
43+
# as we don't the actually types for the annotations
44+
annotations = self._wrapped_obj.__annotations__
45+
name = self._wrapped_obj.__name__
46+
47+
if "return" not in annotations:
48+
raise MissingReturnAnnotationError(name)
49+
50+
function_arguments = set(get_func_args(self._wrapped_obj)) - {"self", "info"}
51+
52+
arguments_annotations = {
53+
key: value
54+
for key, value in annotations.items()
55+
if key not in ["info", "return"]
56+
}
57+
58+
annotated_function_arguments = set(arguments_annotations.keys())
59+
arguments_missing_annotations = (
60+
function_arguments - annotated_function_arguments
61+
)
62+
63+
if len(arguments_missing_annotations) > 0:
64+
raise MissingArgumentsAnnotationsError(name, arguments_missing_annotations)
65+
66+
def __getattr__(self, attr):
67+
if attr in self.__dict__:
68+
return getattr(self, attr)
69+
70+
return getattr(self._wrapped_obj, attr)
71+
72+
def __call__(self, *args, **kwargs):
73+
return self._wrapped_obj(self, *args, **kwargs)
74+
75+
@lazy_property
76+
def field(self):
77+
return _get_field(
78+
self._wrapped_obj, is_subscription=self.is_subscription, **self.kwargs
79+
)
80+
81+
2082
class strawberry_field:
2183
"""A small wrapper for a field in strawberry.
2284
@@ -51,10 +113,7 @@ def __call__(self, wrap):
51113

52114
self.kwargs["description"] = self.description or wrap.__doc__
53115

54-
wrap.field = _get_field(
55-
wrap, is_subscription=self.is_subscription, **self.kwargs
56-
)
57-
return wrap
116+
return LazyFieldWrapper(wrap, self.is_subscription, **self.kwargs)
58117

59118

60119
def convert_args(args, annotations):
@@ -91,29 +150,18 @@ def convert_args(args, annotations):
91150

92151

93152
def _get_field(wrap, *, is_subscription=False, **kwargs):
94-
annotations = get_type_hints(wrap)
153+
annotations = typing.get_type_hints(wrap, None, REGISTRY)
95154

96155
name = wrap.__name__
97156

98-
if "return" not in annotations:
99-
raise MissingReturnAnnotationError(name)
100-
101157
field_type = get_graphql_type_for_annotation(annotations["return"], name)
102158

103-
function_arguments = set(get_func_args(wrap)) - {"self", "info"}
104-
105159
arguments_annotations = {
106160
key: value
107161
for key, value in annotations.items()
108162
if key not in ["info", "return"]
109163
}
110164

111-
annotated_function_arguments = set(arguments_annotations.keys())
112-
arguments_missing_annotations = function_arguments - annotated_function_arguments
113-
114-
if len(arguments_missing_annotations) > 0:
115-
raise MissingArgumentsAnnotationsError(name, arguments_missing_annotations)
116-
117165
arguments = {
118166
to_camel_case(name): get_graphql_type_for_annotation(annotation, name)
119167
for name, annotation in arguments_annotations.items()

strawberry/utils/lazy_property.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
def lazy_property(fn):
2+
"""Decorator that makes a property lazy-evaluated."""
3+
attr_name = "_lazy_" + fn.__name__
4+
5+
@property
6+
def _lazy_property(self):
7+
if not hasattr(self, attr_name):
8+
setattr(self, attr_name, fn(self))
9+
return getattr(self, attr_name)
10+
11+
return _lazy_property

tests/test_cyclic/__init__.py

Whitespace-only changes.

tests/test_cyclic/test_cyclic.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import textwrap
2+
3+
import strawberry
4+
5+
6+
def test_cyclic_import():
7+
from .type_a import TypeA
8+
from .type_b import TypeB
9+
10+
@strawberry.type
11+
class Query:
12+
a: TypeA
13+
b: TypeB
14+
15+
assert (
16+
repr(Query(None, None))
17+
== textwrap.dedent(
18+
"""
19+
type Query {
20+
a: TypeA!
21+
b: TypeB!
22+
}
23+
"""
24+
).strip()
25+
)
26+
27+
assert (
28+
repr(TypeA())
29+
== textwrap.dedent(
30+
"""
31+
type TypeA {
32+
typeB: TypeB!
33+
}
34+
"""
35+
).strip()
36+
)
37+
38+
assert (
39+
repr(TypeB())
40+
== textwrap.dedent(
41+
"""
42+
type TypeB {
43+
typeA: TypeA!
44+
}
45+
"""
46+
).strip()
47+
)

tests/test_cyclic/type_a.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import typing
2+
3+
import strawberry
4+
5+
6+
if typing.TYPE_CHECKING:
7+
from .type_b import TypeB
8+
9+
10+
@strawberry.type
11+
class TypeA:
12+
@strawberry.field
13+
def type_b(self, info) -> "TypeB":
14+
from .type_b import TypeB
15+
16+
return TypeB()

tests/test_cyclic/type_b.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import typing
2+
3+
import strawberry
4+
5+
6+
if typing.TYPE_CHECKING:
7+
from .type_a import TypeA
8+
9+
10+
@strawberry.type
11+
class TypeB:
12+
@strawberry.field
13+
def type_a(self, info) -> "TypeA":
14+
from .type_a import TypeA
15+
16+
return TypeA()

0 commit comments

Comments
 (0)