diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d1d7983..3994533e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Fixed support of custom operation types names. - Unlocked versions of black, isort, autoflake and dev dependencies - Added `remote_schema_verify_ssl` option. +- Changed how default values for inputs are generated to handle potential cycles. ## 0.4.0 (2023-03-20) diff --git a/EXAMPLE.md b/EXAMPLE.md index 59767849..40257d0e 100644 --- a/EXAMPLE.md +++ b/EXAMPLE.md @@ -259,11 +259,6 @@ from .base_model import BaseModel from .enums import Color -class LocationInput(BaseModel): - city: Optional[str] - country: Optional[str] - - class UserCreateInput(BaseModel): first_name: Optional[str] = Field(alias="firstName") last_name: Optional[str] = Field(alias="lastName") @@ -272,11 +267,9 @@ class UserCreateInput(BaseModel): location: Optional["LocationInput"] -class NotificationsPreferencesInput(BaseModel): - receive_mails: bool = Field(alias="receiveMails") - receive_push_notifications: bool = Field(alias="receivePushNotifications") - receive_sms: bool = Field(alias="receiveSms") - title: str +class LocationInput(BaseModel): + city: Optional[str] + country: Optional[str] class UserPreferencesInput(BaseModel): @@ -288,7 +281,7 @@ class UserPreferencesInput(BaseModel): ) notifications_preferences: "NotificationsPreferencesInput" = Field( alias="notificationsPreferences", - default=NotificationsPreferencesInput.parse_obj( + default_factory=lambda: globals()["NotificationsPreferencesInput"].parse_obj( { "receiveMails": True, "receivePushNotifications": True, @@ -299,10 +292,17 @@ class UserPreferencesInput(BaseModel): ) -LocationInput.update_forward_refs() +class NotificationsPreferencesInput(BaseModel): + receive_mails: bool = Field(alias="receiveMails") + receive_push_notifications: bool = Field(alias="receivePushNotifications") + receive_sms: bool = Field(alias="receiveSms") + title: str + + UserCreateInput.update_forward_refs() -NotificationsPreferencesInput.update_forward_refs() +LocationInput.update_forward_refs() UserPreferencesInput.update_forward_refs() +NotificationsPreferencesInput.update_forward_refs() ``` ### Enums diff --git a/ariadne_codegen/client_generators/input_fields.py b/ariadne_codegen/client_generators/input_fields.py index 0483495b..dd2b9d46 100644 --- a/ariadne_codegen/client_generators/input_fields.py +++ b/ariadne_codegen/client_generators/input_fields.py @@ -21,6 +21,7 @@ from ..codegen import ( generate_annotation_name, + generate_attribute, generate_call, generate_constant, generate_dict, @@ -28,8 +29,8 @@ generate_lambda, generate_list, generate_list_annotation, - generate_method_call, generate_name, + generate_subscript, ) from ..exceptions import ParsingError from .constants import ANY, FIELD_CLASS, SIMPLE_TYPE_MAP @@ -161,7 +162,28 @@ def parse_input_const_value_node( ], ) if not nested_object: - return generate_method_call(field_type, "parse_obj", [dict_]) + return generate_call( + func=generate_name(FIELD_CLASS), + keywords=[ + generate_keyword( + arg="default_factory", + value=generate_lambda( + body=generate_call( + func=generate_attribute( + value=generate_subscript( + value=generate_call( + func=generate_name("globals") + ), + slice_=generate_constant(field_type), + ), + attr="parse_obj", + ), + args=[dict_], + ) + ), + ) + ], + ) return dict_ return None diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 9416aca6..e89fd179 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -80,14 +80,13 @@ def generate(self) -> ast.Module: names=scalar_data.names_to_import, from_=scalar_data.import_ ) ) - sorted_class_defs = self._get_sorted_class_defs() update_forward_refs_calls = [ generate_expr(generate_method_call(c.name, UPDATE_FORWARD_REFS_METHOD)) - for c in sorted_class_defs + for c in self._class_defs ] module_body = ( cast(List[ast.stmt], self._imports) - + cast(List[ast.stmt], sorted_class_defs) + + cast(List[ast.stmt], self._class_defs) + cast(List[ast.stmt], update_forward_refs_calls) ) module = generate_module(body=module_body) @@ -136,9 +135,7 @@ def _parse_input_definition( field_implementation, input_field=field, field_name=org_name ) class_def.body.append(field_implementation) - self._save_used_enums_scalars_and_dependencies( - class_name=class_def.name, field_type=field_type - ) + self._save_used_enums_and_scalars(field_type=field_type) if self.plugin_manager: class_def = self.plugin_manager.generate_input_class( @@ -174,40 +171,10 @@ def _process_field_value( ) return field_with_alias - def _save_used_enums_scalars_and_dependencies( - self, class_name: str, field_type: str = "" - ) -> None: + def _save_used_enums_and_scalars(self, field_type: str = "") -> None: if not field_type: return - if isinstance(self.schema.type_map[field_type], GraphQLInputObjectType): - self._dependencies[class_name].append(field_type) - elif isinstance(self.schema.type_map[field_type], GraphQLEnumType): + if isinstance(self.schema.type_map[field_type], GraphQLEnumType): self._used_enums.append(field_type) elif isinstance(self.schema.type_map[field_type], GraphQLScalarType): self._used_scalars.append(field_type) - - def _get_sorted_class_defs(self) -> List[ast.ClassDef]: - input_class_defs_dict_ = {c.name: c for c in self._class_defs} - - processed_names = [] - for class_ in self._class_defs: - if class_.name not in processed_names: - processed_names.extend(self._get_dependant_names(class_.name)) - processed_names.append(class_.name) - - names_without_duplicates = self._get_list_without_duplicates(processed_names) - return [input_class_defs_dict_[n] for n in names_without_duplicates] - - def _get_dependant_names(self, name: str) -> List[str]: - result = [] - for dependency_name in self._dependencies[name]: - result.extend(self._get_dependant_names(dependency_name)) - result.append(dependency_name) - return result - - def _get_list_without_duplicates(self, list_: list) -> list: - result = [] - for element in list_: - if not element in result: - result.append(element) - return result diff --git a/tests/client_generators/input_types_generator/test_default_values.py b/tests/client_generators/input_types_generator/test_default_values.py index 471a814c..ba6fe64d 100644 --- a/tests/client_generators/input_types_generator/test_default_values.py +++ b/tests/client_generators/input_types_generator/test_default_values.py @@ -148,19 +148,49 @@ def test_generate_returns_module_with_parsed_inputs_object_field_with_default_va } """ expected_field_value = ast.Call( - func=ast.Attribute(value=ast.Name(id="SecondInput"), attr="parse_obj"), - args=[ - ast.Dict(keys=[ast.Constant(value="val")], values=[ast.Constant(value=5)]) + func=ast.Name(id="Field"), + args=[], + keywords=[ + ast.keyword( + arg="default_factory", + value=ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Call( + func=ast.Name(id="globals"), args=[], keywords=[] + ), + slice=ast.Constant(value="SecondInput"), + ), + attr="parse_obj", + ), + args=[ + ast.Dict( + keys=[ast.Constant(value="val")], + values=[ast.Constant(value=5)], + ) + ], + keywords=[], + ), + ), + ) ], - keywords=[], ) + generator = InputTypesGenerator( schema=build_ast_schema(parse(schema_str)), enums_module="enums" ) module = generator.generate() - class_def = get_class_def(module, 1) + class_def = get_class_def(module, 0) assert isinstance(class_def, ast.ClassDef) assert class_def.name == "TestInput" assert len(class_def.body) == 1 @@ -185,19 +215,45 @@ def test_generate_returns_module_with_parsed_nested_object_as_default_value(): } """ expected_field_value = ast.Call( - func=ast.Attribute(value=ast.Name(id="SecondInput"), attr="parse_obj"), - args=[ - ast.Dict( - keys=[ast.Constant(value="nested")], - values=[ - ast.Dict( - keys=[ast.Constant(value="val")], - values=[ast.Constant(value=1.5)], - ) - ], + func=ast.Name(id="Field"), + args=[], + keywords=[ + ast.keyword( + arg="default_factory", + value=ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Call( + func=ast.Name(id="globals"), args=[], keywords=[] + ), + slice=ast.Constant(value="SecondInput"), + ), + attr="parse_obj", + ), + args=[ + ast.Dict( + keys=[ast.Constant(value="nested")], + values=[ + ast.Dict( + keys=[ast.Constant(value="val")], + values=[ast.Constant(value=1.5)], + ) + ], + ) + ], + keywords=[], + ), + ), ) ], - keywords=[], ) generator = InputTypesGenerator( schema=build_ast_schema(parse(schema_str)), enums_module="enums" @@ -205,7 +261,7 @@ def test_generate_returns_module_with_parsed_nested_object_as_default_value(): module = generator.generate() - class_def = get_class_def(module, 2) + class_def = get_class_def(module, 0) assert isinstance(class_def, ast.ClassDef) assert class_def.name == "TestInput" assert len(class_def.body) == 1 diff --git a/tests/client_generators/input_types_generator/test_method_calls.py b/tests/client_generators/input_types_generator/test_method_calls.py index 88c8e7ae..a6d846b9 100644 --- a/tests/client_generators/input_types_generator/test_method_calls.py +++ b/tests/client_generators/input_types_generator/test_method_calls.py @@ -35,7 +35,7 @@ def test_generate_returns_modules_with_update_forward_refs_calls(): ast.Expr( value=ast.Call( func=ast.Attribute( - value=ast.Name(id="NestedInput"), attr=UPDATE_FORWARD_REFS_METHOD + value=ast.Name(id="TestInput"), attr=UPDATE_FORWARD_REFS_METHOD ), args=[], keywords=[], @@ -44,7 +44,7 @@ def test_generate_returns_modules_with_update_forward_refs_calls(): ast.Expr( value=ast.Call( func=ast.Attribute( - value=ast.Name(id="TestInput"), attr=UPDATE_FORWARD_REFS_METHOD + value=ast.Name(id="NestedInput"), attr=UPDATE_FORWARD_REFS_METHOD ), args=[], keywords=[], diff --git a/tests/client_generators/input_types_generator/test_parsing_inputs.py b/tests/client_generators/input_types_generator/test_parsing_inputs.py index 78ecf367..a9b33efc 100644 --- a/tests/client_generators/input_types_generator/test_parsing_inputs.py +++ b/tests/client_generators/input_types_generator/test_parsing_inputs.py @@ -25,34 +25,34 @@ """, [ ast.ClassDef( - name="CustomInput2", + name="CustomInput", bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)], keywords=[], decorator_list=[], body=[ ast.AnnAssign( - target=ast.Name(id="field"), + target=ast.Name(id="field1"), + annotation=ast.Name(id='"CustomInput2"'), + simple=1, + ), + ast.AnnAssign( + target=ast.Name(id="field2"), annotation=ast.Name(id="int"), simple=1, - ) + ), ], ), ast.ClassDef( - name="CustomInput", + name="CustomInput2", bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)], keywords=[], decorator_list=[], body=[ ast.AnnAssign( - target=ast.Name(id="field1"), - annotation=ast.Name(id='"CustomInput2"'), - simple=1, - ), - ast.AnnAssign( - target=ast.Name(id="field2"), + target=ast.Name(id="field"), annotation=ast.Name(id="int"), simple=1, - ), + ) ], ), ], @@ -72,7 +72,7 @@ def test_generate_returns_module_with_parsed_input_types( assert compare_ast(class_defs, expected_class_defs) -def test_generate_returns_module_with_correct_order_of_classes(): +def test_generate_returns_module_with_classes_in_the_same_order_as_declared(): schema_str = """ input BeforeInput { field: Boolean! @@ -96,9 +96,9 @@ def test_generate_returns_module_with_correct_order_of_classes(): """ expected_order = [ "BeforeInput", - "NestedInput", - "SecondInput", "TestInput", + "SecondInput", + "NestedInput", "AfterInput", ] generator = InputTypesGenerator( diff --git a/tests/client_generators/test_input_fields.py b/tests/client_generators/test_input_fields.py index e7beb9f6..924a814f 100644 --- a/tests/client_generators/test_input_fields.py +++ b/tests/client_generators/test_input_fields.py @@ -282,17 +282,48 @@ def test_parse_input_const_value_node_given_list_returns_correct_method_call( ), "TestInput", ast.Call( - func=ast.Attribute(value=ast.Name(id="TestInput"), attr="parse_obj"), - args=[ - ast.Dict( - keys=[ - ast.Constant(value="fieldA"), - ast.Constant(value="fieldB"), - ], - values=[ast.Constant(value="a"), ast.Constant(value="B")], + func=ast.Name(id="Field"), + args=[], + keywords=[ + ast.keyword( + arg="default_factory", + value=ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Call( + func=ast.Name(id="globals"), + args=[], + keywords=[], + ), + slice=ast.Constant(value="TestInput"), + ), + attr="parse_obj", + ), + args=[ + ast.Dict( + keys=[ + ast.Constant(value="fieldA"), + ast.Constant(value="fieldB"), + ], + values=[ + ast.Constant(value="a"), + ast.Constant(value="B"), + ], + ) + ], + keywords=[], + ), + ), ) ], - keywords=[], ), ), ( @@ -313,19 +344,47 @@ def test_parse_input_const_value_node_given_list_returns_correct_method_call( ), "TestInput", ast.Call( - func=ast.Attribute(value=ast.Name(id="TestInput"), attr="parse_obj"), - args=[ - ast.Dict( - keys=[ast.Constant(value="nestedField")], - values=[ - ast.Dict( - keys=[ast.Constant(value="fieldA")], - values=[ast.Constant(value="a")], - ) - ], + func=ast.Name(id="Field"), + args=[], + keywords=[ + ast.keyword( + arg="default_factory", + value=ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Call( + func=ast.Name(id="globals"), + args=[], + keywords=[], + ), + slice=ast.Constant(value="TestInput"), + ), + attr="parse_obj", + ), + args=[ + ast.Dict( + keys=[ast.Constant(value="nestedField")], + values=[ + ast.Dict( + keys=[ast.Constant(value="fieldA")], + values=[ast.Constant(value="a")], + ) + ], + ) + ], + keywords=[], + ), + ), ) ], - keywords=[], ), ), ], diff --git a/tests/main/clients/example/expected_client/input_types.py b/tests/main/clients/example/expected_client/input_types.py index db4e18d2..1937c7d8 100644 --- a/tests/main/clients/example/expected_client/input_types.py +++ b/tests/main/clients/example/expected_client/input_types.py @@ -6,11 +6,6 @@ from .enums import Color -class LocationInput(BaseModel): - city: Optional[str] - country: Optional[str] - - class UserCreateInput(BaseModel): first_name: Optional[str] = Field(alias="firstName") last_name: Optional[str] = Field(alias="lastName") @@ -19,11 +14,9 @@ class UserCreateInput(BaseModel): location: Optional["LocationInput"] -class NotificationsPreferencesInput(BaseModel): - receive_mails: bool = Field(alias="receiveMails") - receive_push_notifications: bool = Field(alias="receivePushNotifications") - receive_sms: bool = Field(alias="receiveSms") - title: str +class LocationInput(BaseModel): + city: Optional[str] + country: Optional[str] class UserPreferencesInput(BaseModel): @@ -35,7 +28,7 @@ class UserPreferencesInput(BaseModel): ) notifications_preferences: "NotificationsPreferencesInput" = Field( alias="notificationsPreferences", - default=NotificationsPreferencesInput.parse_obj( + default_factory=lambda: globals()["NotificationsPreferencesInput"].parse_obj( { "receiveMails": True, "receivePushNotifications": True, @@ -46,7 +39,14 @@ class UserPreferencesInput(BaseModel): ) -LocationInput.update_forward_refs() +class NotificationsPreferencesInput(BaseModel): + receive_mails: bool = Field(alias="receiveMails") + receive_push_notifications: bool = Field(alias="receivePushNotifications") + receive_sms: bool = Field(alias="receiveSms") + title: str + + UserCreateInput.update_forward_refs() -NotificationsPreferencesInput.update_forward_refs() +LocationInput.update_forward_refs() UserPreferencesInput.update_forward_refs() +NotificationsPreferencesInput.update_forward_refs()