Skip to content

Commit 1956f68

Browse files
committed
betterbutstilltoomuchatonce
1 parent a0b1bb7 commit 1956f68

File tree

14 files changed

+655
-64
lines changed

14 files changed

+655
-64
lines changed

end_to_end_tests/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
""" Generate a complete client and verify that it is correct """
2+
import pytest
3+
4+
pytest.register_assert_rewrite('end_to_end_tests.end_to_end_live_tests')

end_to_end_tests/baseline_openapi_3.0.json

+4-2
Original file line numberDiff line numberDiff line change
@@ -2841,15 +2841,17 @@
28412841
"modelType": {
28422842
"type": "string"
28432843
}
2844-
}
2844+
},
2845+
"required": ["modelType"]
28452846
},
28462847
"ADiscriminatedUnionType2": {
28472848
"type": "object",
28482849
"properties": {
28492850
"modelType": {
28502851
"type": "string"
28512852
}
2852-
}
2853+
},
2854+
"required": ["modelType"]
28532855
}
28542856
},
28552857
"parameters": {

end_to_end_tests/baseline_openapi_3.1.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -2835,15 +2835,17 @@ info:
28352835
"modelType": {
28362836
"type": "string"
28372837
}
2838-
}
2838+
},
2839+
"required": ["modelType"]
28392840
},
28402841
"ADiscriminatedUnionType2": {
28412842
"type": "object",
28422843
"properties": {
28432844
"modelType": {
28442845
"type": "string"
28452846
}
2846-
}
2847+
},
2848+
"required": ["modelType"]
28472849
}
28482850
}
28492851
"parameters": {
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import importlib
2+
from typing import Any
3+
4+
import pytest
5+
6+
7+
def live_tests_3_x():
8+
_test_model_with_discriminated_union()
9+
10+
11+
def _import_model(module_name, class_name: str) -> Any:
12+
module = importlib.import_module(f"my_test_api_client.models.{module_name}")
13+
module = importlib.reload(module) # avoid test contamination from previous import
14+
return getattr(module, class_name)
15+
16+
17+
def _test_model_with_discriminated_union():
18+
ModelType1Class = _import_model("a_discriminated_union_type_1", "ADiscriminatedUnionType1")
19+
ModelType2Class = _import_model("a_discriminated_union_type_2", "ADiscriminatedUnionType2")
20+
ModelClass = _import_model("model_with_discriminated_union", "ModelWithDiscriminatedUnion")
21+
22+
assert (
23+
ModelClass.from_dict({"discriminated_union": {"modelType": "type1"}}) ==
24+
ModelClass(discriminated_union=ModelType1Class.from_dict({"modelType": "type1"}))
25+
)
26+
assert (
27+
ModelClass.from_dict({"discriminated_union": {"modelType": "type2"}}) ==
28+
ModelClass(discriminated_union=ModelType2Class.from_dict({"modelType": "type2"}))
29+
)
30+
with pytest.raises(TypeError):
31+
ModelClass.from_dict({"discriminated_union": {"modelType": "type3"}})
32+
with pytest.raises(TypeError):
33+
ModelClass.from_dict({"discriminated_union": {}})

end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Dict, List, Type, TypeVar
22

33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
from ..types import UNSET, Unset
7-
86
T = TypeVar("T", bound="ADiscriminatedUnionType1")
97

108

119
@_attrs_define
1210
class ADiscriminatedUnionType1:
1311
"""
1412
Attributes:
15-
model_type (Union[Unset, str]):
13+
model_type (str):
1614
"""
1715

18-
model_type: Union[Unset, str] = UNSET
16+
model_type: str
1917
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2018

2119
def to_dict(self) -> Dict[str, Any]:
2220
model_type = self.model_type
2321

2422
field_dict: Dict[str, Any] = {}
2523
field_dict.update(self.additional_properties)
26-
field_dict.update({})
27-
if model_type is not UNSET:
28-
field_dict["modelType"] = model_type
24+
field_dict.update(
25+
{
26+
"modelType": model_type,
27+
}
28+
)
2929

3030
return field_dict
3131

3232
@classmethod
3333
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
3434
d = src_dict.copy()
35-
model_type = d.pop("modelType", UNSET)
35+
model_type = d.pop("modelType")
3636

3737
a_discriminated_union_type_1 = cls(
3838
model_type=model_type,

end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Dict, List, Type, TypeVar
22

33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
from ..types import UNSET, Unset
7-
86
T = TypeVar("T", bound="ADiscriminatedUnionType2")
97

108

119
@_attrs_define
1210
class ADiscriminatedUnionType2:
1311
"""
1412
Attributes:
15-
model_type (Union[Unset, str]):
13+
model_type (str):
1614
"""
1715

18-
model_type: Union[Unset, str] = UNSET
16+
model_type: str
1917
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2018

2119
def to_dict(self) -> Dict[str, Any]:
2220
model_type = self.model_type
2321

2422
field_dict: Dict[str, Any] = {}
2523
field_dict.update(self.additional_properties)
26-
field_dict.update({})
27-
if model_type is not UNSET:
28-
field_dict["modelType"] = model_type
24+
field_dict.update(
25+
{
26+
"modelType": model_type,
27+
}
28+
)
2929

3030
return field_dict
3131

3232
@classmethod
3333
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
3434
d = src_dict.copy()
35-
model_type = d.pop("modelType", UNSET)
35+
model_type = d.pop("modelType")
3636

3737
a_discriminated_union_type_2 = cls(
3838
model_type=model_type,

end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,34 @@ def _parse_discriminated_union(
5959
return data
6060
if isinstance(data, Unset):
6161
return data
62-
try:
63-
if not isinstance(data, dict):
64-
raise TypeError()
65-
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)
66-
67-
return componentsschemas_a_discriminated_union_type_0
68-
except: # noqa: E722
69-
pass
70-
try:
71-
if not isinstance(data, dict):
72-
raise TypeError()
73-
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
74-
75-
return componentsschemas_a_discriminated_union_type_1
76-
except: # noqa: E722
77-
pass
78-
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
62+
if not isinstance(data, dict):
63+
raise TypeError()
64+
if "modelType" in data:
65+
_discriminator_value = data["modelType"]
66+
67+
def _parse_componentsschemas_a_discriminated_union_type_1(data: object) -> ADiscriminatedUnionType1:
68+
if not isinstance(data, dict):
69+
raise TypeError()
70+
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType1.from_dict(data)
71+
72+
return componentsschemas_a_discriminated_union_type_1
73+
74+
def _parse_componentsschemas_a_discriminated_union_type_2(data: object) -> ADiscriminatedUnionType2:
75+
if not isinstance(data, dict):
76+
raise TypeError()
77+
componentsschemas_a_discriminated_union_type_2 = ADiscriminatedUnionType2.from_dict(data)
78+
79+
return componentsschemas_a_discriminated_union_type_2
80+
81+
_discriminator_mapping = {
82+
"type1": _parse_componentsschemas_a_discriminated_union_type_1,
83+
"type2": _parse_componentsschemas_a_discriminated_union_type_2,
84+
}
85+
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
86+
return cast(
87+
Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data)
88+
)
89+
raise TypeError("unrecognized value for property modelType")
7990

8091
discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))
8192

end_to_end_tests/test_end_to_end.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
import os
12
import shutil
23
from filecmp import cmpfiles, dircmp
34
from pathlib import Path
4-
from typing import Dict, List, Optional, Set
5+
import sys
6+
from typing import Callable, Dict, List, Optional, Set
57

68
import pytest
79
from click.testing import Result
810
from typer.testing import CliRunner
911

1012
from openapi_python_client.cli import app
13+
from .end_to_end_live_tests import live_tests_3_x
14+
1115

1216

1317
def _compare_directories(
@@ -83,6 +87,7 @@ def run_e2e_test(
8387
golden_record_path: str = "golden-record",
8488
output_path: str = "my-test-api-client",
8589
expected_missing: Optional[Set[str]] = None,
90+
live_tests: Optional[Callable[[str], None]] = None,
8691
) -> Result:
8792
output_path = Path.cwd() / output_path
8893
shutil.rmtree(output_path, ignore_errors=True)
@@ -97,6 +102,13 @@ def run_e2e_test(
97102
_compare_directories(
98103
gr_path, output_path, expected_differences=expected_differences, expected_missing=expected_missing
99104
)
105+
if live_tests:
106+
old_path = sys.path.copy()
107+
sys.path.insert(0, str(output_path))
108+
try:
109+
live_tests()
110+
finally:
111+
sys.path = old_path
100112

101113
import mypy.api
102114

@@ -131,11 +143,11 @@ def _run_command(command: str, extra_args: Optional[List[str]] = None, openapi_d
131143

132144

133145
def test_baseline_end_to_end_3_0():
134-
run_e2e_test("baseline_openapi_3.0.json", [], {})
146+
run_e2e_test("baseline_openapi_3.0.json", [], {}, live_tests=live_tests_3_x)
135147

136148

137149
def test_baseline_end_to_end_3_1():
138-
run_e2e_test("baseline_openapi_3.1.yaml", [], {})
150+
run_e2e_test("baseline_openapi_3.1.yaml", [], {}, live_tests=live_tests_3_x)
139151

140152

141153
def test_3_1_specific_features():

openapi_python_client/parser/properties/schemas.py

+7
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]:
4646
return cast(ReferencePath, parsed.fragment)
4747

4848

49+
def get_reference_simple_name(ref_path: str) -> str:
50+
"""
51+
Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`.
52+
"""
53+
return ref_path.split("/", 3)[-1]
54+
55+
4956
@define
5057
class Class:
5158
"""Represents Python class which will be generated from an OpenAPI schema"""

0 commit comments

Comments
 (0)