|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | import inspect |
| 5 | +import weakref |
5 | 6 | from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast |
6 | 7 | from datetime import date, datetime |
7 | 8 | from typing_extensions import ( |
@@ -573,6 +574,9 @@ class CachedDiscriminatorType(Protocol): |
573 | 574 | __discriminator__: DiscriminatorDetails |
574 | 575 |
|
575 | 576 |
|
| 577 | +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() |
| 578 | + |
| 579 | + |
576 | 580 | class DiscriminatorDetails: |
577 | 581 | field_name: str |
578 | 582 | """The name of the discriminator field in the variant class, e.g. |
@@ -615,8 +619,9 @@ def __init__( |
615 | 619 |
|
616 | 620 |
|
617 | 621 | def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: |
618 | | - if isinstance(union, CachedDiscriminatorType): |
619 | | - return union.__discriminator__ |
| 622 | + cached = DISCRIMINATOR_CACHE.get(union) |
| 623 | + if cached is not None: |
| 624 | + return cached |
620 | 625 |
|
621 | 626 | discriminator_field_name: str | None = None |
622 | 627 |
|
@@ -669,7 +674,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, |
669 | 674 | discriminator_field=discriminator_field_name, |
670 | 675 | discriminator_alias=discriminator_alias, |
671 | 676 | ) |
672 | | - cast(CachedDiscriminatorType, union).__discriminator__ = details |
| 677 | + DISCRIMINATOR_CACHE.setdefault(union, details) |
673 | 678 | return details |
674 | 679 |
|
675 | 680 |
|
|
0 commit comments