diff --git a/CHANGELOG.md b/CHANGELOG.md index fe28aef4..ed4fdf5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## UNRELEASED + +- Fixed incorrectly raised exception when using custom scalar as query argument type. + + ## 0.2.0 (2023-02-02) - Added `remote_schema_url` and `remote_schema_headers` settings to support reading remote schemas. diff --git a/ariadne_codegen/generators/arguments.py b/ariadne_codegen/generators/arguments.py index 05a4807b..37204869 100644 --- a/ariadne_codegen/generators/arguments.py +++ b/ariadne_codegen/generators/arguments.py @@ -2,6 +2,10 @@ from typing import List, Tuple, Union from graphql import ( + GraphQLEnumType, + GraphQLInputObjectType, + GraphQLScalarType, + GraphQLSchema, ListTypeNode, NamedTypeNode, NonNullTypeNode, @@ -19,14 +23,46 @@ generate_list_annotation, generate_name, ) -from .constants import SIMPLE_TYPE_MAP +from .constants import ANY, SIMPLE_TYPE_MAP from .utils import str_to_snake_case class ArgumentsGenerator: - def __init__(self, convert_to_snake_case: bool = True) -> None: + def __init__( + self, schema: GraphQLSchema, convert_to_snake_case: bool = True + ) -> None: + self.schema = schema self.convert_to_snake_case = convert_to_snake_case self.used_types: List[str] = [] + self._used_enums: List[str] = [] + self._used_inputs: List[str] = [] + + def generate( + self, variable_definitions: Tuple[VariableDefinitionNode, ...] + ) -> Tuple[ast.arguments, ast.Dict]: + """Generate arguments from given variable definitions.""" + arguments = generate_arguments([generate_arg("self")]) + dict_ = generate_dict() + for variable_definition in variable_definitions: + org_name = variable_definition.variable.name.value + name = self._process_name(org_name) + annotation = self._parse_type_node(variable_definition.type) + + arguments.args.append(generate_arg(name, annotation)) + dict_.keys.append(generate_constant(org_name)) + dict_.values.append(generate_name(name)) + return arguments, dict_ + + def get_used_enums(self) -> List[str]: + return self._used_enums + + def get_used_inputs(self) -> List[str]: + return self._used_inputs + + def _process_name(self, name: str) -> str: + if self.convert_to_snake_case: + return str_to_snake_case(name) + return name def _parse_type_node( self, @@ -50,31 +86,17 @@ def _parse_named_type_node( self, node: NamedTypeNode, nullable: bool = True ) -> Union[ast.Name, ast.Subscript]: name = node.name.value + type_ = self.schema.type_map.get(name) + if not type_: + raise ParsingError(f"Argument type {name} not found in schema.") - if name in SIMPLE_TYPE_MAP: - name = SIMPLE_TYPE_MAP[name] + if isinstance(type_, GraphQLInputObjectType): + self._used_inputs.append(name) + elif isinstance(type_, GraphQLEnumType): + self._used_enums.append(name) + elif isinstance(type_, GraphQLScalarType): + name = SIMPLE_TYPE_MAP.get(name, ANY) else: - self.used_types.append(name) + raise ParsingError(f"Incorrect argument type {name}") return generate_annotation_name(name, nullable) - - def _process_name(self, name: str) -> str: - if self.convert_to_snake_case: - return str_to_snake_case(name) - return name - - def generate( - self, variable_definitions: Tuple[VariableDefinitionNode, ...] - ) -> Tuple[ast.arguments, ast.Dict]: - """Generate arguments from given variable definitions.""" - arguments = generate_arguments([generate_arg("self")]) - dict_ = generate_dict() - for variable_definition in variable_definitions: - org_name = variable_definition.variable.name.value - name = self._process_name(org_name) - annotation = self._parse_type_node(variable_definition.type) - - arguments.args.append(generate_arg(name, annotation)) - dict_.keys.append(generate_constant(org_name)) - dict_.values.append(generate_name(name)) - return arguments, dict_ diff --git a/ariadne_codegen/generators/package.py b/ariadne_codegen/generators/package.py index 73732427..b3fcde88 100644 --- a/ariadne_codegen/generators/package.py +++ b/ariadne_codegen/generators/package.py @@ -3,13 +3,7 @@ from pathlib import Path from typing import Dict, List, Optional -from graphql import ( - FragmentDefinitionNode, - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLSchema, - OperationDefinitionNode, -) +from graphql import FragmentDefinitionNode, GraphQLSchema, OperationDefinitionNode from ..exceptions import ParsingError from .arguments import ArgumentsGenerator @@ -98,7 +92,9 @@ def __init__( self.arguments_generator = ( arguments_generator if arguments_generator - else ArgumentsGenerator(convert_to_snake_case=self.convert_to_snake_case) + else ArgumentsGenerator( + schema=self.schema, convert_to_snake_case=self.convert_to_snake_case + ) ) self.input_types_generator = ( input_types_generator @@ -208,23 +204,16 @@ def _validate_unique_file_names(self): def _generate_client(self): client_file_path = self.package_path / f"{self.client_file_name}.py" - input_types = [] - enums = [] - for type_ in self.arguments_generator.used_types: - if isinstance(self.schema.type_map[type_], GraphQLInputObjectType): - input_types.append(type_) - elif isinstance(self.schema.type_map[type_], GraphQLEnumType): - enums.append(type_) - else: - raise ParsingError(f"Argument type {type_} not found in schema.") - self.client_generator.add_import( - names=input_types, from_=self.input_types_module_name, level=1 + names=self.arguments_generator.get_used_inputs(), + from_=self.input_types_module_name, + level=1, ) self.client_generator.add_import( - names=enums, from_=self.enums_module_name, level=1 + names=self.arguments_generator.get_used_enums(), + from_=self.enums_module_name, + level=1, ) - self.client_generator.add_import( names=[self.base_client_name], from_=self.base_client_file_path.stem, diff --git a/tests/generators/test_arguments_generator.py b/tests/generators/test_arguments_generator.py index d8ca2276..c93d0ab1 100644 --- a/tests/generators/test_arguments_generator.py +++ b/tests/generators/test_arguments_generator.py @@ -1,9 +1,9 @@ import ast -from graphql import OperationDefinitionNode, parse +from graphql import GraphQLSchema, OperationDefinitionNode, build_schema, parse from ariadne_codegen.generators.arguments import ArgumentsGenerator -from ariadne_codegen.generators.constants import OPTIONAL +from ariadne_codegen.generators.constants import ANY, OPTIONAL from ..utils import compare_ast @@ -16,7 +16,19 @@ def _get_variable_definitions_from_query_str(query: str): def test_generate_returns_arguments_with_correct_non_optional_names_and_annotations(): - generator = ArgumentsGenerator() + schema_str = """ + schema { query: Query } + type Query { _skip: ID! } + + input CustomInputType { + fieldA: Int! + fieldB: Float! + fieldC: String! + fieldD: Boolean! + } + """ + schema = build_schema(schema_str) + generator = ArgumentsGenerator(schema=schema) query = ( "query q($id: ID!, $name: String!, $amount: Int!, $val: Float!, " "$flag: Boolean!, $custom_input: CustomInputType!) {r}" @@ -43,7 +55,12 @@ def test_generate_returns_arguments_with_correct_non_optional_names_and_annotati def test_generate_returns_arguments_with_correct_optional_annotation(): - generator = ArgumentsGenerator() + schema_str = """ + schema { query: Query } + type Query { _skip: ID! } + """ + schema = build_schema(schema_str) + generator = ArgumentsGenerator(schema=schema) query = "query q($id: ID) {r}" variable_definitions = _get_variable_definitions_from_query_str(query) @@ -62,7 +79,7 @@ def test_generate_returns_arguments_with_correct_optional_annotation(): def test_generate_returns_arguments_with_only_self_argument_without_annotation(): - generator = ArgumentsGenerator() + generator = ArgumentsGenerator(schema=GraphQLSchema()) query = "query q {r}" variable_definitions = _get_variable_definitions_from_query_str(query) @@ -76,18 +93,27 @@ def test_generate_returns_arguments_with_only_self_argument_without_annotation() def test_generate_saves_used_non_scalar_types(): - generator = ArgumentsGenerator() + schema_str = """ + schema { query: Query } + type Query { _skip: String! } + + input Type1 { fieldA: Int! } + input Type2 { fieldB: Int! } + """ + schema = build_schema(schema_str) + generator = ArgumentsGenerator(schema=schema) query = "query q($a1: String!, $a2: String, $a3: Type1!, $a4: Type2) {r}" variable_definitions = _get_variable_definitions_from_query_str(query) generator.generate(variable_definitions) - assert len(generator.used_types) == 2 - assert generator.used_types == ["Type1", "Type2"] + used_inputs = generator.get_used_inputs() + assert len(used_inputs) == 2 + assert used_inputs == ["Type1", "Type2"] def test_generate_returns_arguments_and_dictionary_with_snake_case_names(): - generator = ArgumentsGenerator(convert_to_snake_case=True) + generator = ArgumentsGenerator(schema=GraphQLSchema(), convert_to_snake_case=True) query = "query q($camelCase: String!, $snake_case: String!) {r}" variable_definitions = _get_variable_definitions_from_query_str(query) @@ -111,3 +137,34 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names(): assert compare_ast(arguments, expected_arguments) assert compare_ast(arguments_dict, expected_arguments_dict) + + +def test_generate_returns_arguments_with_used_custom_scalar(): + schema_str = """ + schema { query: Query } + type Query { _skip: String! } + scalar CustomScalar + """ + generator = ArgumentsGenerator(schema=build_schema(schema_str)) + query_str = "query q($arg: CustomScalar!) {r}" + + expected_arguments = ast.arguments( + posonlyargs=[], + args=[ + ast.arg(arg="self"), + ast.arg(arg="arg", annotation=ast.Name(id=ANY)), + ], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + expected_arguments_dict = ast.Dict( + keys=[ast.Constant(value="arg")], values=[ast.Name(id="arg")] + ) + + arguments, arguments_dict = generator.generate( + _get_variable_definitions_from_query_str(query_str) + ) + + assert compare_ast(arguments, expected_arguments) + assert compare_ast(arguments_dict, expected_arguments_dict)