diff --git a/CHANGELOG.md b/CHANGELOG.md index 3994533e..b679ba18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Unlocked versions of black, isort, autoflake and dev dependencies - Added `remote_schema_verify_ssl` option. - Changed how default values for inputs are generated to handle potential cycles. +- Fixed `BaseModel` incorrectly calling `parse` and `serialize` methods on entire list instead of its items for `List[Scalar]`. ## 0.4.0 (2023-03-20) diff --git a/ariadne_codegen/client_generators/dependencies/base_model.py b/ariadne_codegen/client_generators/dependencies/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/ariadne_codegen/client_generators/dependencies/base_model.py +++ b/ariadne_codegen/client_generators/dependencies/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/client_generators/dependencies/test_base_model.py b/tests/client_generators/dependencies/test_base_model.py new file mode 100644 index 00000000..39fc5ecc --- /dev/null +++ b/tests/client_generators/dependencies/test_base_model.py @@ -0,0 +1,241 @@ +from typing import List, Optional + +import pytest + +from ariadne_codegen.client_generators.dependencies.base_model import BaseModel + + +@pytest.mark.parametrize( + "annotation, value, expected_args", + [ + (str, "a", {"a"}), + (Optional[str], "a", {"a"}), + (Optional[str], None, set()), + (List[str], ["a", "b"], {"a", "b"}), + (List[Optional[str]], ["a", None], {"a"}), + (Optional[List[str]], ["a", "b"], {"a", "b"}), + (Optional[List[str]], None, set()), + (Optional[List[Optional[str]]], ["a", None], {"a"}), + (Optional[List[Optional[str]]], None, set()), + (List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], None, set()), + ( + Optional[List[Optional[List[str]]]], + [["a", "b"], ["c", "d"]], + {"a", "b", "c", "d"}, + ), + (Optional[List[Optional[List[str]]]], None, set()), + (Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", "b"], ["c", "d"]], + {"a", "b", "c", "d"}, + ), + (Optional[List[Optional[List[Optional[str]]]]], None, set()), + (Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", None], ["b", None]], + {"a", "b"}, + ), + ], +) +def test_parse_obj_applies_parse_on_every_list_element( + annotation, value, expected_args, mocker +): + mocked_parse = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_PARSE_FUNCTIONS", + {str: mocked_parse}, + ) + + class TestModel(BaseModel): + field: annotation + + TestModel.parse_obj({"field": value}) + + assert mocked_parse.call_count == len(expected_args) + assert {c.args[0] for c in mocked_parse.call_args_list} == expected_args + + +def test_parse_obj_doesnt_apply_parse_on_not_matching_type(mocker): + mocked_parse = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_PARSE_FUNCTIONS", + {str: mocked_parse}, + ) + + class TestModel(BaseModel): + field_a: int + field_b: Optional[int] + field_c: Optional[int] + field_d: List[int] + field_e: Optional[List[int]] + field_f: Optional[List[int]] + field_g: Optional[List[Optional[int]]] + field_h: Optional[List[Optional[int]]] + field_i: Optional[List[Optional[int]]] + + TestModel.parse_obj( + { + "field_a": 1, + "field_b": 2, + "field_c": None, + "field_d": [3, 4], + "field_e": [5, 6], + "field_f": None, + "field_g": [7, 8], + "field_h": [9, None], + "field_i": None, + } + ) + + assert not mocked_parse.called + + +def test_parse_obj_applies_parse_only_once_for_every_element(mocker): + mocked_parse = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_PARSE_FUNCTIONS", + {str: mocked_parse}, + ) + + class TestModelC(BaseModel): + value: str + + class TestModelB(BaseModel): + value: str + field_c: TestModelC + + class TestModelA(BaseModel): + value: str + field_b: TestModelB + + TestModelA.parse_obj( + {"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}} + ) + + assert mocked_parse.call_count == 3 + assert {c.args[0] for c in mocked_parse.call_args_list} == {"a", "b", "c"} + + +@pytest.mark.parametrize( + "annotation, value, expected_args", + [ + (str, "a", {"a"}), + (Optional[str], "a", {"a"}), + (Optional[str], None, set()), + (List[str], ["a", "b"], {"a", "b"}), + (List[Optional[str]], ["a", None], {"a"}), + (Optional[List[str]], ["a", "b"], {"a", "b"}), + (Optional[List[str]], None, set()), + (Optional[List[Optional[str]]], ["a", None], {"a"}), + (Optional[List[Optional[str]]], None, set()), + (List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], None, set()), + ( + Optional[List[Optional[List[str]]]], + [["a", "b"], ["c", "d"]], + {"a", "b", "c", "d"}, + ), + (Optional[List[Optional[List[str]]]], None, set()), + (Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", "b"], ["c", "d"]], + {"a", "b", "c", "d"}, + ), + (Optional[List[Optional[List[Optional[str]]]]], None, set()), + (Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", None], ["b", None]], + {"a", "b"}, + ), + ], +) +def test_dict_applies_serialize_on_every_list_element( + annotation, value, expected_args, mocker +): + mocked_serialize = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_SERIALIZE_FUNCTIONS", + {str: mocked_serialize}, + ) + + class TestModel(BaseModel): + field: annotation + + TestModel.parse_obj({"field": value}).dict() + + assert mocked_serialize.call_count == len(expected_args) + assert {c.args[0] for c in mocked_serialize.call_args_list} == expected_args + + +def test_dict_doesnt_apply_serialize_on_not_matching_type(mocker): + mocked_serialize = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_SERIALIZE_FUNCTIONS", + {str: mocked_serialize}, + ) + + class TestModel(BaseModel): + field_a: int + field_b: Optional[int] + field_c: Optional[int] + field_d: List[int] + field_e: Optional[List[int]] + field_f: Optional[List[int]] + field_g: Optional[List[Optional[int]]] + field_h: Optional[List[Optional[int]]] + field_i: Optional[List[Optional[int]]] + + TestModel.parse_obj( + { + "field_a": 1, + "field_b": 2, + "field_c": None, + "field_d": [3, 4], + "field_e": [5, 6], + "field_f": None, + "field_g": [7, 8], + "field_h": [9, None], + "field_i": None, + } + ).dict() + + assert not mocked_serialize.called + + +def test_dict_applies_serialize_only_once_for_every_element(mocker): + mocked_serialize = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_SERIALIZE_FUNCTIONS", + {str: mocked_serialize}, + ) + + class TestModelC(BaseModel): + value: str + + class TestModelB(BaseModel): + value: str + field_c: TestModelC + + class TestModelA(BaseModel): + value: str + field_b: TestModelB + + TestModelA.parse_obj( + {"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}} + ).dict() + + assert mocked_serialize.call_count == 3 + assert {c.args[0] for c in mocked_serialize.call_args_list} == {"a", "b", "c"} diff --git a/tests/main/clients/custom_base_client/expected_client/base_model.py b/tests/main/clients/custom_base_client/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_base_client/expected_client/base_model.py +++ b/tests/main/clients/custom_base_client/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/custom_config_file/expected_client/base_model.py b/tests/main/clients/custom_config_file/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_config_file/expected_client/base_model.py +++ b/tests/main/clients/custom_config_file/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/custom_files_names/expected_client/base_model.py b/tests/main/clients/custom_files_names/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_files_names/expected_client/base_model.py +++ b/tests/main/clients/custom_files_names/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/custom_scalars/expected_client/base_model.py b/tests/main/clients/custom_scalars/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_scalars/expected_client/base_model.py +++ b/tests/main/clients/custom_scalars/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/example/expected_client/base_model.py b/tests/main/clients/example/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/example/expected_client/base_model.py +++ b/tests/main/clients/example/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/extended_models/expected_client/base_model.py b/tests/main/clients/extended_models/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/extended_models/expected_client/base_model.py +++ b/tests/main/clients/extended_models/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/inline_fragments/expected_client/base_model.py b/tests/main/clients/inline_fragments/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/inline_fragments/expected_client/base_model.py +++ b/tests/main/clients/inline_fragments/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/remote_schema/expected_client/base_model.py b/tests/main/clients/remote_schema/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/remote_schema/expected_client/base_model.py +++ b/tests/main/clients/remote_schema/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value