Skip to content

Commit

Permalink
Merge pull request #75 from mirumee/fix_custom_scalar_as_argument_type
Browse files Browse the repository at this point in the history
Fix custom scalar as argument type
  • Loading branch information
mat-sop committed Feb 13, 2023
2 parents 647c072 + 459a3a9 commit c53c734
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 56 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## UNRELEASED

- Fixed incorrectly raised exception when using custom scalar as query argument type.


## 0.2.0 (2023-02-02)

- Added `remote_schema_url` and `remote_schema_headers` settings to support reading remote schemas.
Expand Down
74 changes: 48 additions & 26 deletions ariadne_codegen/generators/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from typing import List, Tuple, Union

from graphql import (
GraphQLEnumType,
GraphQLInputObjectType,
GraphQLScalarType,
GraphQLSchema,
ListTypeNode,
NamedTypeNode,
NonNullTypeNode,
Expand All @@ -19,14 +23,46 @@
generate_list_annotation,
generate_name,
)
from .constants import SIMPLE_TYPE_MAP
from .constants import ANY, SIMPLE_TYPE_MAP
from .utils import str_to_snake_case


class ArgumentsGenerator:
def __init__(self, convert_to_snake_case: bool = True) -> None:
def __init__(
self, schema: GraphQLSchema, convert_to_snake_case: bool = True
) -> None:
self.schema = schema
self.convert_to_snake_case = convert_to_snake_case
self.used_types: List[str] = []
self._used_enums: List[str] = []
self._used_inputs: List[str] = []

def generate(
self, variable_definitions: Tuple[VariableDefinitionNode, ...]
) -> Tuple[ast.arguments, ast.Dict]:
"""Generate arguments from given variable definitions."""
arguments = generate_arguments([generate_arg("self")])
dict_ = generate_dict()
for variable_definition in variable_definitions:
org_name = variable_definition.variable.name.value
name = self._process_name(org_name)
annotation = self._parse_type_node(variable_definition.type)

arguments.args.append(generate_arg(name, annotation))
dict_.keys.append(generate_constant(org_name))
dict_.values.append(generate_name(name))
return arguments, dict_

def get_used_enums(self) -> List[str]:
return self._used_enums

def get_used_inputs(self) -> List[str]:
return self._used_inputs

def _process_name(self, name: str) -> str:
if self.convert_to_snake_case:
return str_to_snake_case(name)
return name

def _parse_type_node(
self,
Expand All @@ -50,31 +86,17 @@ def _parse_named_type_node(
self, node: NamedTypeNode, nullable: bool = True
) -> Union[ast.Name, ast.Subscript]:
name = node.name.value
type_ = self.schema.type_map.get(name)
if not type_:
raise ParsingError(f"Argument type {name} not found in schema.")

if name in SIMPLE_TYPE_MAP:
name = SIMPLE_TYPE_MAP[name]
if isinstance(type_, GraphQLInputObjectType):
self._used_inputs.append(name)
elif isinstance(type_, GraphQLEnumType):
self._used_enums.append(name)
elif isinstance(type_, GraphQLScalarType):
name = SIMPLE_TYPE_MAP.get(name, ANY)
else:
self.used_types.append(name)
raise ParsingError(f"Incorrect argument type {name}")

return generate_annotation_name(name, nullable)

def _process_name(self, name: str) -> str:
if self.convert_to_snake_case:
return str_to_snake_case(name)
return name

def generate(
self, variable_definitions: Tuple[VariableDefinitionNode, ...]
) -> Tuple[ast.arguments, ast.Dict]:
"""Generate arguments from given variable definitions."""
arguments = generate_arguments([generate_arg("self")])
dict_ = generate_dict()
for variable_definition in variable_definitions:
org_name = variable_definition.variable.name.value
name = self._process_name(org_name)
annotation = self._parse_type_node(variable_definition.type)

arguments.args.append(generate_arg(name, annotation))
dict_.keys.append(generate_constant(org_name))
dict_.values.append(generate_name(name))
return arguments, dict_
31 changes: 10 additions & 21 deletions ariadne_codegen/generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
from pathlib import Path
from typing import Dict, List, Optional

from graphql import (
FragmentDefinitionNode,
GraphQLEnumType,
GraphQLInputObjectType,
GraphQLSchema,
OperationDefinitionNode,
)
from graphql import FragmentDefinitionNode, GraphQLSchema, OperationDefinitionNode

from ..exceptions import ParsingError
from .arguments import ArgumentsGenerator
Expand Down Expand Up @@ -98,7 +92,9 @@ def __init__(
self.arguments_generator = (
arguments_generator
if arguments_generator
else ArgumentsGenerator(convert_to_snake_case=self.convert_to_snake_case)
else ArgumentsGenerator(
schema=self.schema, convert_to_snake_case=self.convert_to_snake_case
)
)
self.input_types_generator = (
input_types_generator
Expand Down Expand Up @@ -208,23 +204,16 @@ def _validate_unique_file_names(self):
def _generate_client(self):
client_file_path = self.package_path / f"{self.client_file_name}.py"

input_types = []
enums = []
for type_ in self.arguments_generator.used_types:
if isinstance(self.schema.type_map[type_], GraphQLInputObjectType):
input_types.append(type_)
elif isinstance(self.schema.type_map[type_], GraphQLEnumType):
enums.append(type_)
else:
raise ParsingError(f"Argument type {type_} not found in schema.")

self.client_generator.add_import(
names=input_types, from_=self.input_types_module_name, level=1
names=self.arguments_generator.get_used_inputs(),
from_=self.input_types_module_name,
level=1,
)
self.client_generator.add_import(
names=enums, from_=self.enums_module_name, level=1
names=self.arguments_generator.get_used_enums(),
from_=self.enums_module_name,
level=1,
)

self.client_generator.add_import(
names=[self.base_client_name],
from_=self.base_client_file_path.stem,
Expand Down
75 changes: 66 additions & 9 deletions tests/generators/test_arguments_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ast

from graphql import OperationDefinitionNode, parse
from graphql import GraphQLSchema, OperationDefinitionNode, build_schema, parse

from ariadne_codegen.generators.arguments import ArgumentsGenerator
from ariadne_codegen.generators.constants import OPTIONAL
from ariadne_codegen.generators.constants import ANY, OPTIONAL

from ..utils import compare_ast

Expand All @@ -16,7 +16,19 @@ def _get_variable_definitions_from_query_str(query: str):


def test_generate_returns_arguments_with_correct_non_optional_names_and_annotations():
generator = ArgumentsGenerator()
schema_str = """
schema { query: Query }
type Query { _skip: ID! }
input CustomInputType {
fieldA: Int!
fieldB: Float!
fieldC: String!
fieldD: Boolean!
}
"""
schema = build_schema(schema_str)
generator = ArgumentsGenerator(schema=schema)
query = (
"query q($id: ID!, $name: String!, $amount: Int!, $val: Float!, "
"$flag: Boolean!, $custom_input: CustomInputType!) {r}"
Expand All @@ -43,7 +55,12 @@ def test_generate_returns_arguments_with_correct_non_optional_names_and_annotati


def test_generate_returns_arguments_with_correct_optional_annotation():
generator = ArgumentsGenerator()
schema_str = """
schema { query: Query }
type Query { _skip: ID! }
"""
schema = build_schema(schema_str)
generator = ArgumentsGenerator(schema=schema)
query = "query q($id: ID) {r}"
variable_definitions = _get_variable_definitions_from_query_str(query)

Expand All @@ -62,7 +79,7 @@ def test_generate_returns_arguments_with_correct_optional_annotation():


def test_generate_returns_arguments_with_only_self_argument_without_annotation():
generator = ArgumentsGenerator()
generator = ArgumentsGenerator(schema=GraphQLSchema())
query = "query q {r}"
variable_definitions = _get_variable_definitions_from_query_str(query)

Expand All @@ -76,18 +93,27 @@ def test_generate_returns_arguments_with_only_self_argument_without_annotation()


def test_generate_saves_used_non_scalar_types():
generator = ArgumentsGenerator()
schema_str = """
schema { query: Query }
type Query { _skip: String! }
input Type1 { fieldA: Int! }
input Type2 { fieldB: Int! }
"""
schema = build_schema(schema_str)
generator = ArgumentsGenerator(schema=schema)
query = "query q($a1: String!, $a2: String, $a3: Type1!, $a4: Type2) {r}"
variable_definitions = _get_variable_definitions_from_query_str(query)

generator.generate(variable_definitions)

assert len(generator.used_types) == 2
assert generator.used_types == ["Type1", "Type2"]
used_inputs = generator.get_used_inputs()
assert len(used_inputs) == 2
assert used_inputs == ["Type1", "Type2"]


def test_generate_returns_arguments_and_dictionary_with_snake_case_names():
generator = ArgumentsGenerator(convert_to_snake_case=True)
generator = ArgumentsGenerator(schema=GraphQLSchema(), convert_to_snake_case=True)
query = "query q($camelCase: String!, $snake_case: String!) {r}"
variable_definitions = _get_variable_definitions_from_query_str(query)

Expand All @@ -111,3 +137,34 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names():

assert compare_ast(arguments, expected_arguments)
assert compare_ast(arguments_dict, expected_arguments_dict)


def test_generate_returns_arguments_with_used_custom_scalar():
schema_str = """
schema { query: Query }
type Query { _skip: String! }
scalar CustomScalar
"""
generator = ArgumentsGenerator(schema=build_schema(schema_str))
query_str = "query q($arg: CustomScalar!) {r}"

expected_arguments = ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="self"),
ast.arg(arg="arg", annotation=ast.Name(id=ANY)),
],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
)
expected_arguments_dict = ast.Dict(
keys=[ast.Constant(value="arg")], values=[ast.Name(id="arg")]
)

arguments, arguments_dict = generator.generate(
_get_variable_definitions_from_query_str(query_str)
)

assert compare_ast(arguments, expected_arguments)
assert compare_ast(arguments_dict, expected_arguments_dict)

0 comments on commit c53c734

Please sign in to comment.