From ae37ba7dce4c6de1937233c4320d0647fd7f3a70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 10 Feb 2023 16:02:12 +0100 Subject: [PATCH 1/4] Fix using custom scalars as arguments types --- ariadne_codegen/generators/arguments.py | 13 ++++-- ariadne_codegen/generators/package.py | 20 ++++++---- tests/generators/test_arguments_generator.py | 42 +++++++++++++++++--- tests/generators/test_package_generator.py | 22 ++++++++++ 4 files changed, 79 insertions(+), 18 deletions(-) diff --git a/ariadne_codegen/generators/arguments.py b/ariadne_codegen/generators/arguments.py index 05a4807b..4f2fbf1a 100644 --- a/ariadne_codegen/generators/arguments.py +++ b/ariadne_codegen/generators/arguments.py @@ -2,6 +2,8 @@ from typing import List, Tuple, Union from graphql import ( + GraphQLScalarType, + GraphQLSchema, ListTypeNode, NamedTypeNode, NonNullTypeNode, @@ -19,12 +21,15 @@ generate_list_annotation, generate_name, ) -from .constants import SIMPLE_TYPE_MAP +from .constants import ANY, SIMPLE_TYPE_MAP from .utils import str_to_snake_case class ArgumentsGenerator: - def __init__(self, convert_to_snake_case: bool = True) -> None: + def __init__( + self, schema: GraphQLSchema, convert_to_snake_case: bool = True + ) -> None: + self.schema = schema self.convert_to_snake_case = convert_to_snake_case self.used_types: List[str] = [] @@ -51,8 +56,8 @@ def _parse_named_type_node( ) -> Union[ast.Name, ast.Subscript]: name = node.name.value - if name in SIMPLE_TYPE_MAP: - name = SIMPLE_TYPE_MAP[name] + if isinstance(self.schema.type_map[name], GraphQLScalarType): + name = SIMPLE_TYPE_MAP.get(name, ANY) else: self.used_types.append(name) diff --git a/ariadne_codegen/generators/package.py b/ariadne_codegen/generators/package.py index 73732427..71eb09a3 100644 --- a/ariadne_codegen/generators/package.py +++ b/ariadne_codegen/generators/package.py @@ -7,6 +7,7 @@ FragmentDefinitionNode, GraphQLEnumType, GraphQLInputObjectType, + GraphQLScalarType, GraphQLSchema, OperationDefinitionNode, ) @@ -98,7 +99,9 @@ def __init__( self.arguments_generator = ( arguments_generator if arguments_generator - else ArgumentsGenerator(convert_to_snake_case=self.convert_to_snake_case) + else ArgumentsGenerator( + schema=self.schema, convert_to_snake_case=self.convert_to_snake_case + ) ) self.input_types_generator = ( input_types_generator @@ -210,13 +213,14 @@ def _generate_client(self): input_types = [] enums = [] - for type_ in self.arguments_generator.used_types: - if isinstance(self.schema.type_map[type_], GraphQLInputObjectType): - input_types.append(type_) - elif isinstance(self.schema.type_map[type_], GraphQLEnumType): - enums.append(type_) - else: - raise ParsingError(f"Argument type {type_} not found in schema.") + for type_name in self.arguments_generator.used_types: + type_ = self.schema.type_map.get(type_name) + if isinstance(type_, GraphQLInputObjectType): + input_types.append(type_name) + elif isinstance(type_, GraphQLEnumType): + enums.append(type_name) + elif not isinstance(type_, GraphQLScalarType): + raise ParsingError(f"Argument type {type_name} not found in schema.") self.client_generator.add_import( names=input_types, from_=self.input_types_module_name, level=1 diff --git a/tests/generators/test_arguments_generator.py b/tests/generators/test_arguments_generator.py index d8ca2276..7b78c376 100644 --- a/tests/generators/test_arguments_generator.py +++ b/tests/generators/test_arguments_generator.py @@ -1,6 +1,11 @@ import ast -from graphql import OperationDefinitionNode, parse +from graphql import ( + OperationDefinitionNode, + build_schema, + parse, + GraphQLSchema, +) from ariadne_codegen.generators.arguments import ArgumentsGenerator from ariadne_codegen.generators.constants import OPTIONAL @@ -16,7 +21,19 @@ def _get_variable_definitions_from_query_str(query: str): def test_generate_returns_arguments_with_correct_non_optional_names_and_annotations(): - generator = ArgumentsGenerator() + schema_str = """ + schema { query: Query } + type Query { _skip: ID! } + + input CustomInputType { + fieldA: Int! + fieldB: Float! + fieldC: String! + fieldD: Boolean! + } + """ + schema = build_schema(schema_str) + generator = ArgumentsGenerator(schema=schema) query = ( "query q($id: ID!, $name: String!, $amount: Int!, $val: Float!, " "$flag: Boolean!, $custom_input: CustomInputType!) {r}" @@ -43,7 +60,12 @@ def test_generate_returns_arguments_with_correct_non_optional_names_and_annotati def test_generate_returns_arguments_with_correct_optional_annotation(): - generator = ArgumentsGenerator() + schema_str = """ + schema { query: Query } + type Query { _skip: ID! } + """ + schema = build_schema(schema_str) + generator = ArgumentsGenerator(schema=schema) query = "query q($id: ID) {r}" variable_definitions = _get_variable_definitions_from_query_str(query) @@ -62,7 +84,7 @@ def test_generate_returns_arguments_with_correct_optional_annotation(): def test_generate_returns_arguments_with_only_self_argument_without_annotation(): - generator = ArgumentsGenerator() + generator = ArgumentsGenerator(schema=GraphQLSchema()) query = "query q {r}" variable_definitions = _get_variable_definitions_from_query_str(query) @@ -76,7 +98,15 @@ def test_generate_returns_arguments_with_only_self_argument_without_annotation() def test_generate_saves_used_non_scalar_types(): - generator = ArgumentsGenerator() + schema_str = """ + schema { query: Query } + type Query { _skip: String! } + + input Type1 { fieldA: Int! } + input Type2 { fieldB: Int! } + """ + schema = build_schema(schema_str) + generator = ArgumentsGenerator(schema=schema) query = "query q($a1: String!, $a2: String, $a3: Type1!, $a4: Type2) {r}" variable_definitions = _get_variable_definitions_from_query_str(query) @@ -87,7 +117,7 @@ def test_generate_saves_used_non_scalar_types(): def test_generate_returns_arguments_and_dictionary_with_snake_case_names(): - generator = ArgumentsGenerator(convert_to_snake_case=True) + generator = ArgumentsGenerator(schema=GraphQLSchema(), convert_to_snake_case=True) query = "query q($camelCase: String!, $snake_case: String!) {r}" variable_definitions = _get_variable_definitions_from_query_str(query) diff --git a/tests/generators/test_package_generator.py b/tests/generators/test_package_generator.py index 53d29974..56d15720 100644 --- a/tests/generators/test_package_generator.py +++ b/tests/generators/test_package_generator.py @@ -45,6 +45,8 @@ input CustomInput { value: Int! } + +scalar CustomScalar """ @@ -487,6 +489,26 @@ class CustomQueryQuery1Field2(BaseModel): assert dedent(expected_types) in result_types_content +def test_generate_doesnt_raise_exception_for_custom_scalar_as_argument_type(tmp_path): + package_name = "test_graphql_client" + query_str = "query TestQuery($argument: CustomScalar!) { query1 { id } }" + + query_def = parse(query_str).definitions[0] + generator = PackageGenerator( + package_name, + tmp_path.as_posix(), + build_ast_schema(parse(SCHEMA_STR)), + ) + + generator.add_operation(query_def) + generator.generate() + + client_file_path = tmp_path / package_name / f"{generator.client_file_name}.py" + with client_file_path.open() as client_file: + client_file_content = client_file.read() + assert "(self, argument: Any)" in client_file_content + + def test_generate_returns_list_of_generated_files(tmp_path): generator = PackageGenerator( "test_graphql_client", From 47578481a28ec447862efa0d13448632d86199dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 10 Feb 2023 16:04:08 +0100 Subject: [PATCH 2/4] Fix imports order --- tests/generators/test_arguments_generator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/generators/test_arguments_generator.py b/tests/generators/test_arguments_generator.py index 7b78c376..0adcc881 100644 --- a/tests/generators/test_arguments_generator.py +++ b/tests/generators/test_arguments_generator.py @@ -1,11 +1,6 @@ import ast -from graphql import ( - OperationDefinitionNode, - build_schema, - parse, - GraphQLSchema, -) +from graphql import GraphQLSchema, OperationDefinitionNode, build_schema, parse from ariadne_codegen.generators.arguments import ArgumentsGenerator from ariadne_codegen.generators.constants import OPTIONAL From e6969b88aa3ba1ad6ac1128bd5e5782d977bf299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Mon, 13 Feb 2023 11:10:10 +0100 Subject: [PATCH 3/4] Refactor arguments generator to validate used types --- ariadne_codegen/generators/arguments.py | 63 +++++++++++++------- ariadne_codegen/generators/package.py | 29 +++------ tests/generators/test_arguments_generator.py | 38 +++++++++++- tests/generators/test_package_generator.py | 22 ------- 4 files changed, 82 insertions(+), 70 deletions(-) diff --git a/ariadne_codegen/generators/arguments.py b/ariadne_codegen/generators/arguments.py index 4f2fbf1a..37204869 100644 --- a/ariadne_codegen/generators/arguments.py +++ b/ariadne_codegen/generators/arguments.py @@ -2,6 +2,8 @@ from typing import List, Tuple, Union from graphql import ( + GraphQLEnumType, + GraphQLInputObjectType, GraphQLScalarType, GraphQLSchema, ListTypeNode, @@ -32,6 +34,35 @@ def __init__( self.schema = schema self.convert_to_snake_case = convert_to_snake_case self.used_types: List[str] = [] + self._used_enums: List[str] = [] + self._used_inputs: List[str] = [] + + def generate( + self, variable_definitions: Tuple[VariableDefinitionNode, ...] + ) -> Tuple[ast.arguments, ast.Dict]: + """Generate arguments from given variable definitions.""" + arguments = generate_arguments([generate_arg("self")]) + dict_ = generate_dict() + for variable_definition in variable_definitions: + org_name = variable_definition.variable.name.value + name = self._process_name(org_name) + annotation = self._parse_type_node(variable_definition.type) + + arguments.args.append(generate_arg(name, annotation)) + dict_.keys.append(generate_constant(org_name)) + dict_.values.append(generate_name(name)) + return arguments, dict_ + + def get_used_enums(self) -> List[str]: + return self._used_enums + + def get_used_inputs(self) -> List[str]: + return self._used_inputs + + def _process_name(self, name: str) -> str: + if self.convert_to_snake_case: + return str_to_snake_case(name) + return name def _parse_type_node( self, @@ -55,31 +86,17 @@ def _parse_named_type_node( self, node: NamedTypeNode, nullable: bool = True ) -> Union[ast.Name, ast.Subscript]: name = node.name.value + type_ = self.schema.type_map.get(name) + if not type_: + raise ParsingError(f"Argument type {name} not found in schema.") - if isinstance(self.schema.type_map[name], GraphQLScalarType): + if isinstance(type_, GraphQLInputObjectType): + self._used_inputs.append(name) + elif isinstance(type_, GraphQLEnumType): + self._used_enums.append(name) + elif isinstance(type_, GraphQLScalarType): name = SIMPLE_TYPE_MAP.get(name, ANY) else: - self.used_types.append(name) + raise ParsingError(f"Incorrect argument type {name}") return generate_annotation_name(name, nullable) - - def _process_name(self, name: str) -> str: - if self.convert_to_snake_case: - return str_to_snake_case(name) - return name - - def generate( - self, variable_definitions: Tuple[VariableDefinitionNode, ...] - ) -> Tuple[ast.arguments, ast.Dict]: - """Generate arguments from given variable definitions.""" - arguments = generate_arguments([generate_arg("self")]) - dict_ = generate_dict() - for variable_definition in variable_definitions: - org_name = variable_definition.variable.name.value - name = self._process_name(org_name) - annotation = self._parse_type_node(variable_definition.type) - - arguments.args.append(generate_arg(name, annotation)) - dict_.keys.append(generate_constant(org_name)) - dict_.values.append(generate_name(name)) - return arguments, dict_ diff --git a/ariadne_codegen/generators/package.py b/ariadne_codegen/generators/package.py index 71eb09a3..b3fcde88 100644 --- a/ariadne_codegen/generators/package.py +++ b/ariadne_codegen/generators/package.py @@ -3,14 +3,7 @@ from pathlib import Path from typing import Dict, List, Optional -from graphql import ( - FragmentDefinitionNode, - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLScalarType, - GraphQLSchema, - OperationDefinitionNode, -) +from graphql import FragmentDefinitionNode, GraphQLSchema, OperationDefinitionNode from ..exceptions import ParsingError from .arguments import ArgumentsGenerator @@ -211,24 +204,16 @@ def _validate_unique_file_names(self): def _generate_client(self): client_file_path = self.package_path / f"{self.client_file_name}.py" - input_types = [] - enums = [] - for type_name in self.arguments_generator.used_types: - type_ = self.schema.type_map.get(type_name) - if isinstance(type_, GraphQLInputObjectType): - input_types.append(type_name) - elif isinstance(type_, GraphQLEnumType): - enums.append(type_name) - elif not isinstance(type_, GraphQLScalarType): - raise ParsingError(f"Argument type {type_name} not found in schema.") - self.client_generator.add_import( - names=input_types, from_=self.input_types_module_name, level=1 + names=self.arguments_generator.get_used_inputs(), + from_=self.input_types_module_name, + level=1, ) self.client_generator.add_import( - names=enums, from_=self.enums_module_name, level=1 + names=self.arguments_generator.get_used_enums(), + from_=self.enums_module_name, + level=1, ) - self.client_generator.add_import( names=[self.base_client_name], from_=self.base_client_file_path.stem, diff --git a/tests/generators/test_arguments_generator.py b/tests/generators/test_arguments_generator.py index 0adcc881..c93d0ab1 100644 --- a/tests/generators/test_arguments_generator.py +++ b/tests/generators/test_arguments_generator.py @@ -3,7 +3,7 @@ from graphql import GraphQLSchema, OperationDefinitionNode, build_schema, parse from ariadne_codegen.generators.arguments import ArgumentsGenerator -from ariadne_codegen.generators.constants import OPTIONAL +from ariadne_codegen.generators.constants import ANY, OPTIONAL from ..utils import compare_ast @@ -107,8 +107,9 @@ def test_generate_saves_used_non_scalar_types(): generator.generate(variable_definitions) - assert len(generator.used_types) == 2 - assert generator.used_types == ["Type1", "Type2"] + used_inputs = generator.get_used_inputs() + assert len(used_inputs) == 2 + assert used_inputs == ["Type1", "Type2"] def test_generate_returns_arguments_and_dictionary_with_snake_case_names(): @@ -136,3 +137,34 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names(): assert compare_ast(arguments, expected_arguments) assert compare_ast(arguments_dict, expected_arguments_dict) + + +def test_generate_returns_arguments_with_used_custom_scalar(): + schema_str = """ + schema { query: Query } + type Query { _skip: String! } + scalar CustomScalar + """ + generator = ArgumentsGenerator(schema=build_schema(schema_str)) + query_str = "query q($arg: CustomScalar!) {r}" + + expected_arguments = ast.arguments( + posonlyargs=[], + args=[ + ast.arg(arg="self"), + ast.arg(arg="arg", annotation=ast.Name(id=ANY)), + ], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + expected_arguments_dict = ast.Dict( + keys=[ast.Constant(value="arg")], values=[ast.Name(id="arg")] + ) + + arguments, arguments_dict = generator.generate( + _get_variable_definitions_from_query_str(query_str) + ) + + assert compare_ast(arguments, expected_arguments) + assert compare_ast(arguments_dict, expected_arguments_dict) diff --git a/tests/generators/test_package_generator.py b/tests/generators/test_package_generator.py index 56d15720..53d29974 100644 --- a/tests/generators/test_package_generator.py +++ b/tests/generators/test_package_generator.py @@ -45,8 +45,6 @@ input CustomInput { value: Int! } - -scalar CustomScalar """ @@ -489,26 +487,6 @@ class CustomQueryQuery1Field2(BaseModel): assert dedent(expected_types) in result_types_content -def test_generate_doesnt_raise_exception_for_custom_scalar_as_argument_type(tmp_path): - package_name = "test_graphql_client" - query_str = "query TestQuery($argument: CustomScalar!) { query1 { id } }" - - query_def = parse(query_str).definitions[0] - generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), - ) - - generator.add_operation(query_def) - generator.generate() - - client_file_path = tmp_path / package_name / f"{generator.client_file_name}.py" - with client_file_path.open() as client_file: - client_file_content = client_file.read() - assert "(self, argument: Any)" in client_file_content - - def test_generate_returns_list_of_generated_files(tmp_path): generator = PackageGenerator( "test_graphql_client", From 459a3a9017ac0c53acff7349405fce69ccc7917c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Mon, 13 Feb 2023 11:15:49 +0100 Subject: [PATCH 4/4] Update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fe28aef4..ed4fdf5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## UNRELEASED + +- Fixed incorrectly raised exception when using custom scalar as query argument type. + + ## 0.2.0 (2023-02-02) - Added `remote_schema_url` and `remote_schema_headers` settings to support reading remote schemas.