Skip to content

Commit 0bcd53c

Browse files
authored
feat(DTO): Support nested union types with single composites (#4293)
1 parent 490ebba commit 0bcd53c

File tree

2 files changed

+162
-16
lines changed

2 files changed

+162
-16
lines changed

litestar/dto/_codegen_backend.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -569,28 +569,79 @@ def _create_transfer_nested_union_type_data(
569569
source_value_name: str,
570570
assignment_target: str,
571571
) -> None:
572-
for inner_type in transfer_type.inner_types:
573-
if isinstance(inner_type, CompositeType):
574-
continue
572+
def _handle_transfer_instance(simple_type_: SimpleType, conditional_: str) -> None:
573+
if simple_type_.field_definition.is_none_type:
574+
with self._start_block(f"{conditional_} {source_value_name} is None:"):
575+
self._add_stmt(f"{assignment_target} = {source_value_name}")
576+
return
575577

576-
if inner_type.nested_field_info:
577-
if self.is_data_field:
578-
constraint_type = inner_type.nested_field_info.model
579-
destination_type = inner_type.field_definition.annotation
580-
else:
581-
constraint_type = inner_type.field_definition.annotation
582-
destination_type = inner_type.nested_field_info.model
578+
field_definitions: tuple[TransferDTOFieldDefinition, ...] | None
579+
if simple_type_.nested_field_info and self.is_data_field:
580+
constraint_type = simple_type_.nested_field_info.model
581+
destination_type = simple_type_.field_definition.annotation
582+
field_definitions = simple_type_.nested_field_info.field_definitions
583+
else:
584+
constraint_type = simple_type_.field_definition.annotation
585+
destination_type = (
586+
simple_type_.nested_field_info.model
587+
if simple_type_.nested_field_info and not self.is_data_field
588+
else simple_type_.field_definition.annotation
589+
)
590+
field_definitions = (
591+
simple_type_.nested_field_info.field_definitions if simple_type_.nested_field_info else None
592+
)
583593

584-
constraint_type_name = self._add_to_fn_globals("constraint_type", constraint_type)
585-
destination_type_name = self._add_to_fn_globals("destination_type", destination_type)
594+
constraint_type_name = self._add_to_fn_globals("constraint_type", constraint_type)
595+
destination_type_name = self._add_to_fn_globals("destination_type", destination_type)
586596

587-
with self._start_block(f"if isinstance({source_value_name}, {constraint_type_name}):"):
597+
with self._start_block(f"{conditional_} isinstance({source_value_name}, {constraint_type_name}):"):
598+
if field_definitions:
588599
self._create_transfer_instance_data(
589600
destination_type_name=destination_type_name,
590601
destination_type_is_dict=destination_type is dict,
591-
field_definitions=inner_type.nested_field_info.field_definitions,
602+
field_definitions=field_definitions,
592603
source_instance_name=source_value_name,
593604
tmp_return_type_name=assignment_target,
594605
)
595-
return
606+
else:
607+
self._add_stmt(f"{assignment_target} = {source_value_name}")
608+
609+
simple_types: list[SimpleType] = []
610+
non_simple_types: list[CompositeType] = []
611+
for inner_type in transfer_type.inner_types:
612+
if isinstance(inner_type, SimpleType):
613+
simple_types.append(inner_type)
614+
else:
615+
non_simple_types.append(inner_type)
616+
617+
if len(non_simple_types) > 1:
618+
# we've got something like 'Union[list[str], dict[str, int]]. Since checking against these goes beyond the
619+
# scope of simple 'isinstance' or 'type' checks, we cannot generate code that handles these correctly.
620+
# so we give up with an exception
621+
raise RuntimeError(
622+
"Multiple composite types within unions are not supported. Received: "
623+
f"{', '.join(str(t.field_definition.raw) for t in non_simple_types)}"
624+
)
625+
626+
# special case: simple + one non-simple
627+
if len(non_simple_types) == 1 and simple_types:
628+
conditional = "if"
629+
for simple_type in simple_types:
630+
_handle_transfer_instance(simple_type, conditional_=conditional)
631+
conditional = "elif"
632+
633+
with self._start_block("else:"):
634+
self._create_transfer_type_data_body(
635+
transfer_type=non_simple_types[0],
636+
nested_as_dict=False,
637+
source_value_name=source_value_name,
638+
assignment_target=assignment_target,
639+
)
640+
return
641+
642+
for inner_type in simple_types:
643+
if inner_type.nested_field_info:
644+
_handle_transfer_instance(inner_type, conditional_="if")
645+
return
646+
596647
self._add_stmt(f"{assignment_target} = {source_value_name}")

tests/unit/test_dto/test_factory/test_backends/test_backends.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from __future__ import annotations
22

33
import inspect
4+
import re
45
from dataclasses import dataclass, field
56
from types import ModuleType
67
from typing import TYPE_CHECKING, Callable, Optional
78
from unittest.mock import MagicMock
89

10+
import msgspec
911
import pytest
1012
from msgspec import Meta, Struct, to_builtins
1113

1214
from litestar import Litestar, Request, get, post
1315
from litestar._openapi.schema_generation import SchemaCreator
14-
from litestar.dto import DataclassDTO, DTOConfig, DTOField
16+
from litestar.dto import DataclassDTO, DTOConfig, DTOField, MsgspecDTO
1517
from litestar.dto._backend import DTOBackend, _create_struct_field_meta_for_field_definition
1618
from litestar.dto._codegen_backend import DTOCodegenBackend
1719
from litestar.dto._types import CollectionType, SimpleType, TransferDTOFieldDefinition
@@ -549,3 +551,96 @@ def test_create_struct_field_meta_for_field_definition(constraint_kwargs: Any) -
549551
title="test",
550552
**constraint_kwargs,
551553
)
554+
555+
556+
@pytest.mark.parametrize(
557+
"simple_type, value",
558+
[
559+
("None", [{"value": "hello"}]),
560+
("None", None),
561+
("None,int", None),
562+
("None,int", 1),
563+
("None,int", [{"value": "hello"}]),
564+
("int", [{"value": "hello"}]),
565+
("int", 1),
566+
("bool", [{"value": "hello"}]),
567+
("bool", True),
568+
("bool,str,int", True),
569+
("bool,str,int", 1),
570+
("bool,str,int", "hello"),
571+
("bool,str,int", [{"value": "hello"}]),
572+
("bool,Inner", {"value": "hello"}),
573+
("bool,Inner", [{"value": "hello"}]),
574+
("bool,Inner", True),
575+
],
576+
)
577+
def test_transfer_nested_simple_type_union(
578+
asgi_connection: Request[Any, Any, Any],
579+
create_module: Callable[[str], ModuleType],
580+
simple_type: str,
581+
value: Any,
582+
) -> None:
583+
# https://github.com/litestar-org/litestar/issues/4273
584+
585+
module = create_module(f"""
586+
from typing import Union
587+
import msgspec
588+
589+
class Inner(msgspec.Struct):
590+
value: str
591+
592+
class Outer(msgspec.Struct):
593+
some_field: Union[{simple_type}, list[Inner]]
594+
""")
595+
596+
backend = DTOCodegenBackend(
597+
handler_id="test",
598+
dto_factory=MsgspecDTO[module.Outer], # type: ignore[name-defined]
599+
field_definition=TransferDTOFieldDefinition.from_annotation(module.Outer),
600+
model_type=module.Outer,
601+
wrapper_attribute_name=None,
602+
is_data_field=True,
603+
)
604+
605+
data = backend.populate_data_from_builtins({"some_field": value}, asgi_connection)
606+
assert isinstance(data, module.Outer)
607+
if isinstance(value, list):
608+
assert data.some_field == msgspec.convert(value, type=list[module.Inner]) # type: ignore[name-defined]
609+
elif isinstance(value, dict):
610+
assert data.some_field == msgspec.convert(value, type=module.Inner)
611+
else:
612+
assert data.some_field == value
613+
614+
615+
def test_nested_union_with_multiple_composite_types_raises(
616+
asgi_connection: Request[Any, Any, Any],
617+
create_module: Callable[[str], ModuleType],
618+
) -> None:
619+
module = create_module("""
620+
from typing import Union
621+
import dataclasses
622+
623+
@dataclasses.dataclass
624+
class Inner:
625+
value: str
626+
627+
628+
@dataclasses.dataclass
629+
class Outer:
630+
some_field: Union[list[str], dict[str, str], Inner]
631+
""")
632+
633+
with pytest.raises(
634+
RuntimeError,
635+
match=re.escape(
636+
"Multiple composite types within unions are not supported. Received: list[str], dict[str, str]"
637+
),
638+
):
639+
DTOCodegenBackend(
640+
handler_id="test",
641+
dto_factory=DataclassDTO[module.Outer], # type: ignore[name-defined]
642+
field_definition=TransferDTOFieldDefinition.from_annotation(module.Outer),
643+
model_type=module.Outer,
644+
wrapper_attribute_name=None,
645+
is_data_field=True,
646+
)

0 commit comments

Comments
 (0)