From 73fd1f722cd5d40a2ddf15c66e2d15e69e4d3b41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Tue, 4 Apr 2023 14:56:31 +0200 Subject: [PATCH 1/8] Change base model to apply parse method on every list item --- .../dependencies/base_model.py | 20 +++- .../dependencies/test_base_model.py | 93 +++++++++++++++++++ 2 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 tests/client_generators/dependencies/test_base_model.py diff --git a/ariadne_codegen/client_generators/dependencies/base_model.py b/ariadne_codegen/client_generators/dependencies/base_model.py index 2b6cc8d0..47e72abb 100644 --- a/ariadne_codegen/client_generators/dependencies/base_model.py +++ b/ariadne_codegen/client_generators/dependencies/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,10 +15,24 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) if decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: diff --git a/tests/client_generators/dependencies/test_base_model.py b/tests/client_generators/dependencies/test_base_model.py new file mode 100644 index 00000000..352717d7 --- /dev/null +++ b/tests/client_generators/dependencies/test_base_model.py @@ -0,0 +1,93 @@ +from typing import List, Optional + +import pytest + +from ariadne_codegen.client_generators.dependencies.base_model import BaseModel + + +@pytest.mark.parametrize( + "annotation, value, expected_args", + [ + (str, "a", ["a"]), + (Optional[str], "a", ["a"]), + (Optional[str], None, [None]), + (List[str], ["a", "b"], ["a", "b"]), + (List[Optional[str]], ["a", None], ["a", None]), + (Optional[List[str]], ["a", "b"], ["a", "b"]), + (Optional[List[str]], None, []), + (Optional[List[Optional[str]]], ["a", None], ["a", None]), + (List[List[str]], [["a", "b"], ["c", "d"]], ["a", "b", "c", "d"]), + (Optional[List[List[str]]], [["a", "b"], ["c", "d"]], ["a", "b", "c", "d"]), + (Optional[List[List[str]]], None, []), + ( + Optional[List[Optional[List[str]]]], + [["a", "b"], ["c", "d"]], + ["a", "b", "c", "d"], + ), + (Optional[List[Optional[List[str]]]], None, []), + (Optional[List[Optional[List[str]]]], [["a", "b"], None], ["a", "b"]), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", "b"], ["c", "d"]], + ["a", "b", "c", "d"], + ), + (Optional[List[Optional[List[Optional[str]]]]], None, []), + (Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], ["a", "b"]), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", None], ["b", None]], + ["a", None, "b", None], + ), + ], +) +def test_base_model_applies_parse_on_every_element( + annotation, value, expected_args, mocker +): + mocked_parse = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_PARSE_FUNCTIONS", + {str: mocked_parse}, + ) + + class TestModel(BaseModel): + field: annotation + + TestModel(field=value) + + assert mocked_parse.call_count == len(expected_args) + assert [c.args[0] for c in mocked_parse.call_args_list] == expected_args + + +def test_base_model_doesnt_apply_parse_on_not_matching_type(mocker): + mocked_parse = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_PARSE_FUNCTIONS", + {str: mocked_parse}, + ) + + class TestModel(BaseModel): + field_a: int + field_b: Optional[int] + field_c: Optional[int] + field_d: List[int] + field_e: Optional[List[int]] + field_f: Optional[List[int]] + field_g: Optional[List[Optional[int]]] + field_h: Optional[List[Optional[int]]] + field_i: Optional[List[Optional[int]]] + + TestModel( + field_a=1, + field_b=2, + field_c=None, + field_d=[3, 4], + field_e=[5, 6], + field_f=None, + field_g=[7, 8], + field_h=[9, None], + field_i=None, + ) + + assert not mocked_parse.called From fd29badbff786dbbe2f1d605c9685343e946e1cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 10:07:40 +0200 Subject: [PATCH 2/8] Change base model tests to use parse_obj method --- .../dependencies/test_base_model.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/client_generators/dependencies/test_base_model.py b/tests/client_generators/dependencies/test_base_model.py index 352717d7..52d1891f 100644 --- a/tests/client_generators/dependencies/test_base_model.py +++ b/tests/client_generators/dependencies/test_base_model.py @@ -40,7 +40,7 @@ ), ], ) -def test_base_model_applies_parse_on_every_element( +def test_parse_obj_applies_parse_on_every_element( annotation, value, expected_args, mocker ): mocked_parse = mocker.MagicMock(side_effect=lambda s: s) @@ -53,13 +53,13 @@ def test_base_model_applies_parse_on_every_element( class TestModel(BaseModel): field: annotation - TestModel(field=value) + TestModel.parse_obj({"field": value}) assert mocked_parse.call_count == len(expected_args) assert [c.args[0] for c in mocked_parse.call_args_list] == expected_args -def test_base_model_doesnt_apply_parse_on_not_matching_type(mocker): +def test_parse_obj_doesnt_apply_parse_on_not_matching_type(mocker): mocked_parse = mocker.MagicMock(side_effect=lambda s: s) mocker.patch( "ariadne_codegen.client_generators.dependencies.base_model." @@ -78,16 +78,18 @@ class TestModel(BaseModel): field_h: Optional[List[Optional[int]]] field_i: Optional[List[Optional[int]]] - TestModel( - field_a=1, - field_b=2, - field_c=None, - field_d=[3, 4], - field_e=[5, 6], - field_f=None, - field_g=[7, 8], - field_h=[9, None], - field_i=None, + TestModel.parse_obj( + { + "field_a": 1, + "field_b": 2, + "field_c": None, + "field_d": [3, 4], + "field_e": [5, 6], + "field_f": None, + "field_g": [7, 8], + "field_h": [9, None], + "field_i": None, + } ) assert not mocked_parse.called From 95e8b3f963bc8ed6cac3648f5440f11cd5fad373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 10:56:40 +0200 Subject: [PATCH 3/8] Change base model to apply serialize on every list item --- .../dependencies/base_model.py | 16 ++- .../dependencies/test_base_model.py | 117 ++++++++++++++++++ 2 files changed, 128 insertions(+), 5 deletions(-) diff --git a/ariadne_codegen/client_generators/dependencies/base_model.py b/ariadne_codegen/client_generators/dependencies/base_model.py index 47e72abb..6e640d76 100644 --- a/ariadne_codegen/client_generators/dependencies/base_model.py +++ b/ariadne_codegen/client_generators/dependencies/base_model.py @@ -37,8 +37,14 @@ def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/client_generators/dependencies/test_base_model.py b/tests/client_generators/dependencies/test_base_model.py index 52d1891f..606b3f70 100644 --- a/tests/client_generators/dependencies/test_base_model.py +++ b/tests/client_generators/dependencies/test_base_model.py @@ -93,3 +93,120 @@ class TestModel(BaseModel): ) assert not mocked_parse.called + +@pytest.mark.parametrize( + "annotation, value, expected_args", + [ + (str, "a", {"a"}), + (Optional[str], "a", {"a"}), + (Optional[str], None, set()), + (List[str], ["a", "b"], {"a", "b"}), + (List[Optional[str]], ["a", None], {"a"}), + (Optional[List[str]], ["a", "b"], {"a", "b"}), + (Optional[List[str]], None, set()), + (Optional[List[Optional[str]]], ["a", None], {"a"}), + (Optional[List[Optional[str]]], None, set()), + (List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], None, set()), + ( + Optional[List[Optional[List[str]]]], + [["a", "b"], ["c", "d"]], + {"a", "b", "c", "d"}, + ), + (Optional[List[Optional[List[str]]]], None, set()), + (Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", "b"], ["c", "d"]], + {"a", "b", "c", "d"}, + ), + (Optional[List[Optional[List[Optional[str]]]]], None, set()), + (Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}), + ( + Optional[List[Optional[List[Optional[str]]]]], + [["a", None], ["b", None]], + {"a", "b"}, + ), + ], +) +def test_dict_applies_serialize_on_every_list_element( + annotation, value, expected_args, mocker +): + mocked_serialize = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_SERIALIZE_FUNCTIONS", + {str: mocked_serialize}, + ) + + class TestModel(BaseModel): + field: annotation + + TestModel.parse_obj({"field": value}).dict() + + assert mocked_serialize.call_count == len(expected_args) + assert {c.args[0] for c in mocked_serialize.call_args_list} == expected_args + + +def test_dict_doesnt_apply_serialize_on_not_matching_type(mocker): + mocked_serialize = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_SERIALIZE_FUNCTIONS", + {str: mocked_serialize}, + ) + + class TestModel(BaseModel): + field_a: int + field_b: Optional[int] + field_c: Optional[int] + field_d: List[int] + field_e: Optional[List[int]] + field_f: Optional[List[int]] + field_g: Optional[List[Optional[int]]] + field_h: Optional[List[Optional[int]]] + field_i: Optional[List[Optional[int]]] + + TestModel.parse_obj( + { + "field_a": 1, + "field_b": 2, + "field_c": None, + "field_d": [3, 4], + "field_e": [5, 6], + "field_f": None, + "field_g": [7, 8], + "field_h": [9, None], + "field_i": None, + } + ).dict() + + assert not mocked_serialize.called + + +def test_dict_applies_serialize_only_once_for_every_element(mocker): + mocked_serialize = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_SERIALIZE_FUNCTIONS", + {str: mocked_serialize}, + ) + + class TestModelC(BaseModel): + value: str + + class TestModelB(BaseModel): + value: str + field_c: TestModelC + + class TestModelA(BaseModel): + value: str + field_b: TestModelB + + TestModelA.parse_obj( + {"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}} + ).dict() + + assert mocked_serialize.call_count == 3 + assert {c.args[0] for c in mocked_serialize.call_args_list} == {"a", "b", "c"} From a43792747bfac2e4a2939037727ae3c1302dd653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 10:57:08 +0200 Subject: [PATCH 4/8] Change base model to not apply parse on None --- .../dependencies/base_model.py | 2 +- .../dependencies/test_base_model.py | 69 +++++++++++++------ 2 files changed, 50 insertions(+), 21 deletions(-) diff --git a/ariadne_codegen/client_generators/dependencies/base_model.py b/ariadne_codegen/client_generators/dependencies/base_model.py index 6e640d76..aee81e59 100644 --- a/ariadne_codegen/client_generators/dependencies/base_model.py +++ b/ariadne_codegen/client_generators/dependencies/base_model.py @@ -30,7 +30,7 @@ def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: return cls._parse_custom_scalar_value(value, sub_type) decode = SCALARS_PARSE_FUNCTIONS.get(type_) - if decode and callable(decode): + if value and decode and callable(decode): return decode(value) return value diff --git a/tests/client_generators/dependencies/test_base_model.py b/tests/client_generators/dependencies/test_base_model.py index 606b3f70..c9dbe81c 100644 --- a/tests/client_generators/dependencies/test_base_model.py +++ b/tests/client_generators/dependencies/test_base_model.py @@ -8,39 +8,40 @@ @pytest.mark.parametrize( "annotation, value, expected_args", [ - (str, "a", ["a"]), - (Optional[str], "a", ["a"]), - (Optional[str], None, [None]), - (List[str], ["a", "b"], ["a", "b"]), - (List[Optional[str]], ["a", None], ["a", None]), - (Optional[List[str]], ["a", "b"], ["a", "b"]), - (Optional[List[str]], None, []), - (Optional[List[Optional[str]]], ["a", None], ["a", None]), - (List[List[str]], [["a", "b"], ["c", "d"]], ["a", "b", "c", "d"]), - (Optional[List[List[str]]], [["a", "b"], ["c", "d"]], ["a", "b", "c", "d"]), - (Optional[List[List[str]]], None, []), + (str, "a", {"a"}), + (Optional[str], "a", {"a"}), + (Optional[str], None, set()), + (List[str], ["a", "b"], {"a", "b"}), + (List[Optional[str]], ["a", None], {"a"}), + (Optional[List[str]], ["a", "b"], {"a", "b"}), + (Optional[List[str]], None, set()), + (Optional[List[Optional[str]]], ["a", None], {"a"}), + (Optional[List[Optional[str]]], None, set()), + (List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}), + (Optional[List[List[str]]], None, set()), ( Optional[List[Optional[List[str]]]], [["a", "b"], ["c", "d"]], - ["a", "b", "c", "d"], + {"a", "b", "c", "d"}, ), - (Optional[List[Optional[List[str]]]], None, []), - (Optional[List[Optional[List[str]]]], [["a", "b"], None], ["a", "b"]), + (Optional[List[Optional[List[str]]]], None, set()), + (Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}), ( Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], ["c", "d"]], - ["a", "b", "c", "d"], + {"a", "b", "c", "d"}, ), - (Optional[List[Optional[List[Optional[str]]]]], None, []), - (Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], ["a", "b"]), + (Optional[List[Optional[List[Optional[str]]]]], None, set()), + (Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}), ( Optional[List[Optional[List[Optional[str]]]]], [["a", None], ["b", None]], - ["a", None, "b", None], + {"a", "b"}, ), ], ) -def test_parse_obj_applies_parse_on_every_element( +def test_parse_obj_applies_parse_on_every_list_element( annotation, value, expected_args, mocker ): mocked_parse = mocker.MagicMock(side_effect=lambda s: s) @@ -56,7 +57,7 @@ class TestModel(BaseModel): TestModel.parse_obj({"field": value}) assert mocked_parse.call_count == len(expected_args) - assert [c.args[0] for c in mocked_parse.call_args_list] == expected_args + assert {c.args[0] for c in mocked_parse.call_args_list} == expected_args def test_parse_obj_doesnt_apply_parse_on_not_matching_type(mocker): @@ -94,6 +95,34 @@ class TestModel(BaseModel): assert not mocked_parse.called + +def test_parse_obj_applies_parse_only_once_for_every_element(mocker): + mocked_parse = mocker.MagicMock(side_effect=lambda s: s) + mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_model." + "SCALARS_PARSE_FUNCTIONS", + {str: mocked_parse}, + ) + + class TestModelC(BaseModel): + value: str + + class TestModelB(BaseModel): + value: str + field_c: TestModelC + + class TestModelA(BaseModel): + value: str + field_b: TestModelB + + TestModelA.parse_obj( + {"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}} + ).dict() + + assert mocked_parse.call_count == 3 + assert {c.args[0] for c in mocked_parse.call_args_list} == {"a", "b", "c"} + + @pytest.mark.parametrize( "annotation, value, expected_args", [ From 516c6884b62e40cc98e21407e92eae9c480d57ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 11:00:51 +0200 Subject: [PATCH 5/8] Change base model file in tests --- .../expected_client/base_model.py | 38 ++++++++++++++----- .../expected_client/base_model.py | 38 ++++++++++++++----- .../expected_client/base_model.py | 38 ++++++++++++++----- .../expected_client/base_model.py | 38 ++++++++++++++----- .../example/expected_client/base_model.py | 38 ++++++++++++++----- .../expected_client/base_model.py | 38 ++++++++++++++----- .../expected_client/base_model.py | 38 ++++++++++++++----- .../expected_client/base_model.py | 38 ++++++++++++++----- 8 files changed, 232 insertions(+), 72 deletions(-) diff --git a/tests/main/clients/custom_base_client/expected_client/base_model.py b/tests/main/clients/custom_base_client/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_base_client/expected_client/base_model.py +++ b/tests/main/clients/custom_base_client/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/custom_config_file/expected_client/base_model.py b/tests/main/clients/custom_config_file/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_config_file/expected_client/base_model.py +++ b/tests/main/clients/custom_config_file/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/custom_files_names/expected_client/base_model.py b/tests/main/clients/custom_files_names/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_files_names/expected_client/base_model.py +++ b/tests/main/clients/custom_files_names/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/custom_scalars/expected_client/base_model.py b/tests/main/clients/custom_scalars/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/custom_scalars/expected_client/base_model.py +++ b/tests/main/clients/custom_scalars/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/example/expected_client/base_model.py b/tests/main/clients/example/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/example/expected_client/base_model.py +++ b/tests/main/clients/example/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/extended_models/expected_client/base_model.py b/tests/main/clients/extended_models/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/extended_models/expected_client/base_model.py +++ b/tests/main/clients/extended_models/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/inline_fragments/expected_client/base_model.py b/tests/main/clients/inline_fragments/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/inline_fragments/expected_client/base_model.py +++ b/tests/main/clients/inline_fragments/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value diff --git a/tests/main/clients/remote_schema/expected_client/base_model.py b/tests/main/clients/remote_schema/expected_client/base_model.py index 2b6cc8d0..aee81e59 100644 --- a/tests/main/clients/remote_schema/expected_client/base_model.py +++ b/tests/main/clients/remote_schema/expected_client/base_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel from pydantic.class_validators import validator @@ -15,16 +15,36 @@ class Config: # pylint: disable=no-self-argument @validator("*", pre=True) - def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any: - decode = SCALARS_PARSE_FUNCTIONS.get(field.type_) - if decode and callable(decode): + def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any: + return cls._parse_custom_scalar_value(value, field.annotation) + + @classmethod + def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any: + origin = get_origin(type_) + args = get_args(type_) + if origin is list and isinstance(value, list): + return [cls._parse_custom_scalar_value(item, args[0]) for item in value] + + if origin is Union and type(None) in args: + sub_type: Any = list(filter(None, args))[0] + return cls._parse_custom_scalar_value(value, sub_type) + + decode = SCALARS_PARSE_FUNCTIONS.get(type_) + if value and decode and callable(decode): return decode(value) + return value def dict(self, **kwargs: Any) -> Dict[str, Any]: dict_ = super().dict(**kwargs) - for key, value in dict_.items(): - serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) - if serialize and callable(serialize): - dict_[key] = serialize(value) - return dict_ + return {key: self._serialize_value(value) for key, value in dict_.items()} + + def _serialize_value(self, value: Any) -> Any: + serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value)) + if serialize and callable(serialize): + return serialize(value) + + if isinstance(value, list): + return [self._serialize_value(item) for item in value] + + return value From 2262350cfd81203c550ed0db218ba6e06fc544e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 11:01:49 +0200 Subject: [PATCH 6/8] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3994533e..d5c9b624 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Unlocked versions of black, isort, autoflake and dev dependencies - Added `remote_schema_verify_ssl` option. - Changed how default values for inputs are generated to handle potential cycles. +- Changed `BaseModel` to apply `parse` and `serialize` methods on every list element. ## 0.4.0 (2023-03-20) From c6665dba6ef63362b57e3f0c6d74a4158fd4fb69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 11:07:43 +0200 Subject: [PATCH 7/8] Fix test --- tests/client_generators/dependencies/test_base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/client_generators/dependencies/test_base_model.py b/tests/client_generators/dependencies/test_base_model.py index c9dbe81c..39fc5ecc 100644 --- a/tests/client_generators/dependencies/test_base_model.py +++ b/tests/client_generators/dependencies/test_base_model.py @@ -117,7 +117,7 @@ class TestModelA(BaseModel): TestModelA.parse_obj( {"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}} - ).dict() + ) assert mocked_parse.call_count == 3 assert {c.args[0] for c in mocked_parse.call_args_list} == {"a", "b", "c"} From 292a2d43158024c34c1255de56875ea301ca2b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 5 Apr 2023 11:48:30 +0200 Subject: [PATCH 8/8] Update CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5c9b624..b679ba18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ - Unlocked versions of black, isort, autoflake and dev dependencies - Added `remote_schema_verify_ssl` option. - Changed how default values for inputs are generated to handle potential cycles. -- Changed `BaseModel` to apply `parse` and `serialize` methods on every list element. +- Fixed `BaseModel` incorrectly calling `parse` and `serialize` methods on entire list instead of its items for `List[Scalar]`. ## 0.4.0 (2023-03-20)