|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import inspect
|
| 4 | +import re |
4 | 5 | from dataclasses import dataclass, field
|
5 | 6 | from types import ModuleType
|
6 | 7 | from typing import TYPE_CHECKING, Callable, Optional
|
7 | 8 | from unittest.mock import MagicMock
|
8 | 9 |
|
| 10 | +import msgspec |
9 | 11 | import pytest
|
10 | 12 | from msgspec import Meta, Struct, to_builtins
|
11 | 13 |
|
12 | 14 | from litestar import Litestar, Request, get, post
|
13 | 15 | from litestar._openapi.schema_generation import SchemaCreator
|
14 |
| -from litestar.dto import DataclassDTO, DTOConfig, DTOField |
| 16 | +from litestar.dto import DataclassDTO, DTOConfig, DTOField, MsgspecDTO |
15 | 17 | from litestar.dto._backend import DTOBackend, _create_struct_field_meta_for_field_definition
|
16 | 18 | from litestar.dto._codegen_backend import DTOCodegenBackend
|
17 | 19 | from litestar.dto._types import CollectionType, SimpleType, TransferDTOFieldDefinition
|
@@ -549,3 +551,96 @@ def test_create_struct_field_meta_for_field_definition(constraint_kwargs: Any) -
|
549 | 551 | title="test",
|
550 | 552 | **constraint_kwargs,
|
551 | 553 | )
|
| 554 | + |
| 555 | + |
| 556 | +@pytest.mark.parametrize( |
| 557 | + "simple_type, value", |
| 558 | + [ |
| 559 | + ("None", [{"value": "hello"}]), |
| 560 | + ("None", None), |
| 561 | + ("None,int", None), |
| 562 | + ("None,int", 1), |
| 563 | + ("None,int", [{"value": "hello"}]), |
| 564 | + ("int", [{"value": "hello"}]), |
| 565 | + ("int", 1), |
| 566 | + ("bool", [{"value": "hello"}]), |
| 567 | + ("bool", True), |
| 568 | + ("bool,str,int", True), |
| 569 | + ("bool,str,int", 1), |
| 570 | + ("bool,str,int", "hello"), |
| 571 | + ("bool,str,int", [{"value": "hello"}]), |
| 572 | + ("bool,Inner", {"value": "hello"}), |
| 573 | + ("bool,Inner", [{"value": "hello"}]), |
| 574 | + ("bool,Inner", True), |
| 575 | + ], |
| 576 | +) |
| 577 | +def test_transfer_nested_simple_type_union( |
| 578 | + asgi_connection: Request[Any, Any, Any], |
| 579 | + create_module: Callable[[str], ModuleType], |
| 580 | + simple_type: str, |
| 581 | + value: Any, |
| 582 | +) -> None: |
| 583 | + # https://github.com/litestar-org/litestar/issues/4273 |
| 584 | + |
| 585 | + module = create_module(f""" |
| 586 | +from typing import Union |
| 587 | +import msgspec |
| 588 | +
|
| 589 | +class Inner(msgspec.Struct): |
| 590 | + value: str |
| 591 | +
|
| 592 | +class Outer(msgspec.Struct): |
| 593 | + some_field: Union[{simple_type}, list[Inner]] |
| 594 | +""") |
| 595 | + |
| 596 | + backend = DTOCodegenBackend( |
| 597 | + handler_id="test", |
| 598 | + dto_factory=MsgspecDTO[module.Outer], # type: ignore[name-defined] |
| 599 | + field_definition=TransferDTOFieldDefinition.from_annotation(module.Outer), |
| 600 | + model_type=module.Outer, |
| 601 | + wrapper_attribute_name=None, |
| 602 | + is_data_field=True, |
| 603 | + ) |
| 604 | + |
| 605 | + data = backend.populate_data_from_builtins({"some_field": value}, asgi_connection) |
| 606 | + assert isinstance(data, module.Outer) |
| 607 | + if isinstance(value, list): |
| 608 | + assert data.some_field == msgspec.convert(value, type=list[module.Inner]) # type: ignore[name-defined] |
| 609 | + elif isinstance(value, dict): |
| 610 | + assert data.some_field == msgspec.convert(value, type=module.Inner) |
| 611 | + else: |
| 612 | + assert data.some_field == value |
| 613 | + |
| 614 | + |
| 615 | +def test_nested_union_with_multiple_composite_types_raises( |
| 616 | + asgi_connection: Request[Any, Any, Any], |
| 617 | + create_module: Callable[[str], ModuleType], |
| 618 | +) -> None: |
| 619 | + module = create_module(""" |
| 620 | +from typing import Union |
| 621 | +import dataclasses |
| 622 | +
|
| 623 | +@dataclasses.dataclass |
| 624 | +class Inner: |
| 625 | + value: str |
| 626 | +
|
| 627 | +
|
| 628 | +@dataclasses.dataclass |
| 629 | +class Outer: |
| 630 | + some_field: Union[list[str], dict[str, str], Inner] |
| 631 | +""") |
| 632 | + |
| 633 | + with pytest.raises( |
| 634 | + RuntimeError, |
| 635 | + match=re.escape( |
| 636 | + "Multiple composite types within unions are not supported. Received: list[str], dict[str, str]" |
| 637 | + ), |
| 638 | + ): |
| 639 | + DTOCodegenBackend( |
| 640 | + handler_id="test", |
| 641 | + dto_factory=DataclassDTO[module.Outer], # type: ignore[name-defined] |
| 642 | + field_definition=TransferDTOFieldDefinition.from_annotation(module.Outer), |
| 643 | + model_type=module.Outer, |
| 644 | + wrapper_attribute_name=None, |
| 645 | + is_data_field=True, |
| 646 | + ) |
0 commit comments