Skip to content

Commit

Permalink
Merge pull request #119 from mirumee/default_values_cycles
Browse files Browse the repository at this point in the history
Default values cycles
  • Loading branch information
mat-sop committed Apr 4, 2023
2 parents 24b9968 + 08cba1e commit 65b3410
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 119 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Fixed support of custom operation types names.
- Unlocked versions of black, isort, autoflake and dev dependencies
- Added `remote_schema_verify_ssl` option.
- Changed how default values for inputs are generated to handle potential cycles.


## 0.4.0 (2023-03-20)
Expand Down
26 changes: 13 additions & 13 deletions EXAMPLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,6 @@ from .base_model import BaseModel
from .enums import Color


class LocationInput(BaseModel):
city: Optional[str]
country: Optional[str]


class UserCreateInput(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")
Expand All @@ -272,11 +267,9 @@ class UserCreateInput(BaseModel):
location: Optional["LocationInput"]


class NotificationsPreferencesInput(BaseModel):
receive_mails: bool = Field(alias="receiveMails")
receive_push_notifications: bool = Field(alias="receivePushNotifications")
receive_sms: bool = Field(alias="receiveSms")
title: str
class LocationInput(BaseModel):
city: Optional[str]
country: Optional[str]


class UserPreferencesInput(BaseModel):
Expand All @@ -288,7 +281,7 @@ class UserPreferencesInput(BaseModel):
)
notifications_preferences: "NotificationsPreferencesInput" = Field(
alias="notificationsPreferences",
default=NotificationsPreferencesInput.parse_obj(
default_factory=lambda: globals()["NotificationsPreferencesInput"].parse_obj(
{
"receiveMails": True,
"receivePushNotifications": True,
Expand All @@ -299,10 +292,17 @@ class UserPreferencesInput(BaseModel):
)


LocationInput.update_forward_refs()
class NotificationsPreferencesInput(BaseModel):
receive_mails: bool = Field(alias="receiveMails")
receive_push_notifications: bool = Field(alias="receivePushNotifications")
receive_sms: bool = Field(alias="receiveSms")
title: str


UserCreateInput.update_forward_refs()
NotificationsPreferencesInput.update_forward_refs()
LocationInput.update_forward_refs()
UserPreferencesInput.update_forward_refs()
NotificationsPreferencesInput.update_forward_refs()
```

### Enums
Expand Down
26 changes: 24 additions & 2 deletions ariadne_codegen/client_generators/input_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@

from ..codegen import (
generate_annotation_name,
generate_attribute,
generate_call,
generate_constant,
generate_dict,
generate_keyword,
generate_lambda,
generate_list,
generate_list_annotation,
generate_method_call,
generate_name,
generate_subscript,
)
from ..exceptions import ParsingError
from .constants import ANY, FIELD_CLASS, SIMPLE_TYPE_MAP
Expand Down Expand Up @@ -161,7 +162,28 @@ def parse_input_const_value_node(
],
)
if not nested_object:
return generate_method_call(field_type, "parse_obj", [dict_])
return generate_call(
func=generate_name(FIELD_CLASS),
keywords=[
generate_keyword(
arg="default_factory",
value=generate_lambda(
body=generate_call(
func=generate_attribute(
value=generate_subscript(
value=generate_call(
func=generate_name("globals")
),
slice_=generate_constant(field_type),
),
attr="parse_obj",
),
args=[dict_],
)
),
)
],
)
return dict_

return None
43 changes: 5 additions & 38 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,13 @@ def generate(self) -> ast.Module:
names=scalar_data.names_to_import, from_=scalar_data.import_
)
)
sorted_class_defs = self._get_sorted_class_defs()
update_forward_refs_calls = [
generate_expr(generate_method_call(c.name, UPDATE_FORWARD_REFS_METHOD))
for c in sorted_class_defs
for c in self._class_defs
]
module_body = (
cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], sorted_class_defs)
+ cast(List[ast.stmt], self._class_defs)
+ cast(List[ast.stmt], update_forward_refs_calls)
)
module = generate_module(body=module_body)
Expand Down Expand Up @@ -136,9 +135,7 @@ def _parse_input_definition(
field_implementation, input_field=field, field_name=org_name
)
class_def.body.append(field_implementation)
self._save_used_enums_scalars_and_dependencies(
class_name=class_def.name, field_type=field_type
)
self._save_used_enums_and_scalars(field_type=field_type)

if self.plugin_manager:
class_def = self.plugin_manager.generate_input_class(
Expand Down Expand Up @@ -174,40 +171,10 @@ def _process_field_value(
)
return field_with_alias

def _save_used_enums_scalars_and_dependencies(
self, class_name: str, field_type: str = ""
) -> None:
def _save_used_enums_and_scalars(self, field_type: str = "") -> None:
if not field_type:
return
if isinstance(self.schema.type_map[field_type], GraphQLInputObjectType):
self._dependencies[class_name].append(field_type)
elif isinstance(self.schema.type_map[field_type], GraphQLEnumType):
if isinstance(self.schema.type_map[field_type], GraphQLEnumType):
self._used_enums.append(field_type)
elif isinstance(self.schema.type_map[field_type], GraphQLScalarType):
self._used_scalars.append(field_type)

def _get_sorted_class_defs(self) -> List[ast.ClassDef]:
input_class_defs_dict_ = {c.name: c for c in self._class_defs}

processed_names = []
for class_ in self._class_defs:
if class_.name not in processed_names:
processed_names.extend(self._get_dependant_names(class_.name))
processed_names.append(class_.name)

names_without_duplicates = self._get_list_without_duplicates(processed_names)
return [input_class_defs_dict_[n] for n in names_without_duplicates]

def _get_dependant_names(self, name: str) -> List[str]:
result = []
for dependency_name in self._dependencies[name]:
result.extend(self._get_dependant_names(dependency_name))
result.append(dependency_name)
return result

def _get_list_without_duplicates(self, list_: list) -> list:
result = []
for element in list_:
if not element in result:
result.append(element)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,49 @@ def test_generate_returns_module_with_parsed_inputs_object_field_with_default_va
}
"""
expected_field_value = ast.Call(
func=ast.Attribute(value=ast.Name(id="SecondInput"), attr="parse_obj"),
args=[
ast.Dict(keys=[ast.Constant(value="val")], values=[ast.Constant(value=5)])
func=ast.Name(id="Field"),
args=[],
keywords=[
ast.keyword(
arg="default_factory",
value=ast.Lambda(
args=ast.arguments(
posonlyargs=[],
args=[],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=ast.Call(
func=ast.Attribute(
value=ast.Subscript(
value=ast.Call(
func=ast.Name(id="globals"), args=[], keywords=[]
),
slice=ast.Constant(value="SecondInput"),
),
attr="parse_obj",
),
args=[
ast.Dict(
keys=[ast.Constant(value="val")],
values=[ast.Constant(value=5)],
)
],
keywords=[],
),
),
)
],
keywords=[],
)

generator = InputTypesGenerator(
schema=build_ast_schema(parse(schema_str)), enums_module="enums"
)

module = generator.generate()

class_def = get_class_def(module, 1)
class_def = get_class_def(module, 0)
assert isinstance(class_def, ast.ClassDef)
assert class_def.name == "TestInput"
assert len(class_def.body) == 1
Expand All @@ -185,27 +215,53 @@ def test_generate_returns_module_with_parsed_nested_object_as_default_value():
}
"""
expected_field_value = ast.Call(
func=ast.Attribute(value=ast.Name(id="SecondInput"), attr="parse_obj"),
args=[
ast.Dict(
keys=[ast.Constant(value="nested")],
values=[
ast.Dict(
keys=[ast.Constant(value="val")],
values=[ast.Constant(value=1.5)],
)
],
func=ast.Name(id="Field"),
args=[],
keywords=[
ast.keyword(
arg="default_factory",
value=ast.Lambda(
args=ast.arguments(
posonlyargs=[],
args=[],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=ast.Call(
func=ast.Attribute(
value=ast.Subscript(
value=ast.Call(
func=ast.Name(id="globals"), args=[], keywords=[]
),
slice=ast.Constant(value="SecondInput"),
),
attr="parse_obj",
),
args=[
ast.Dict(
keys=[ast.Constant(value="nested")],
values=[
ast.Dict(
keys=[ast.Constant(value="val")],
values=[ast.Constant(value=1.5)],
)
],
)
],
keywords=[],
),
),
)
],
keywords=[],
)
generator = InputTypesGenerator(
schema=build_ast_schema(parse(schema_str)), enums_module="enums"
)

module = generator.generate()

class_def = get_class_def(module, 2)
class_def = get_class_def(module, 0)
assert isinstance(class_def, ast.ClassDef)
assert class_def.name == "TestInput"
assert len(class_def.body) == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_generate_returns_modules_with_update_forward_refs_calls():
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="NestedInput"), attr=UPDATE_FORWARD_REFS_METHOD
value=ast.Name(id="TestInput"), attr=UPDATE_FORWARD_REFS_METHOD
),
args=[],
keywords=[],
Expand All @@ -44,7 +44,7 @@ def test_generate_returns_modules_with_update_forward_refs_calls():
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="TestInput"), attr=UPDATE_FORWARD_REFS_METHOD
value=ast.Name(id="NestedInput"), attr=UPDATE_FORWARD_REFS_METHOD
),
args=[],
keywords=[],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,34 @@
""",
[
ast.ClassDef(
name="CustomInput2",
name="CustomInput",
bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)],
keywords=[],
decorator_list=[],
body=[
ast.AnnAssign(
target=ast.Name(id="field"),
target=ast.Name(id="field1"),
annotation=ast.Name(id='"CustomInput2"'),
simple=1,
),
ast.AnnAssign(
target=ast.Name(id="field2"),
annotation=ast.Name(id="int"),
simple=1,
)
),
],
),
ast.ClassDef(
name="CustomInput",
name="CustomInput2",
bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)],
keywords=[],
decorator_list=[],
body=[
ast.AnnAssign(
target=ast.Name(id="field1"),
annotation=ast.Name(id='"CustomInput2"'),
simple=1,
),
ast.AnnAssign(
target=ast.Name(id="field2"),
target=ast.Name(id="field"),
annotation=ast.Name(id="int"),
simple=1,
),
)
],
),
],
Expand All @@ -72,7 +72,7 @@ def test_generate_returns_module_with_parsed_input_types(
assert compare_ast(class_defs, expected_class_defs)


def test_generate_returns_module_with_correct_order_of_classes():
def test_generate_returns_module_with_classes_in_the_same_order_as_declared():
schema_str = """
input BeforeInput {
field: Boolean!
Expand All @@ -96,9 +96,9 @@ def test_generate_returns_module_with_correct_order_of_classes():
"""
expected_order = [
"BeforeInput",
"NestedInput",
"SecondInput",
"TestInput",
"SecondInput",
"NestedInput",
"AfterInput",
]
generator = InputTypesGenerator(
Expand Down
Loading

0 comments on commit 65b3410

Please sign in to comment.