13
13
from dataclasses import dataclass
14
14
15
15
from beartype .door import is_bearable
16
- from collections .abc import Mapping , Sequence , Callable
16
+ from collections .abc import Mapping , Sequence , Callable , Hashable
17
17
from typing import (
18
18
overload ,
19
19
TypeVar ,
86
86
SETTINGS = {"debug" : False }
87
87
88
88
89
+ @dataclass (frozen = True )
90
+ class UnionCacheKey :
91
+ union : Hashable
92
+ tagging : Tagging
93
+
94
+
89
95
def init (debug : bool = False ) -> None :
90
96
SETTINGS ["debug" ] = debug
91
97
@@ -115,15 +121,14 @@ class Union_Foo_bar:
115
121
should be only once.
116
122
"""
117
123
118
- classes : dict [str , type [Any ]] = dataclasses .field (default_factory = dict )
124
+ classes : dict [Hashable , type [Any ]] = dataclasses .field (default_factory = dict )
119
125
120
126
def _get_class (self , cls : type [Any ]) -> type [Any ]:
121
127
"""
122
128
Get a wrapper class from the the cache. If not found, it will generate
123
129
the class and store it in the cache.
124
130
"""
125
- class_name = f"Wrapper{ typename (cls )} "
126
- wrapper = self .classes .get (class_name )
131
+ wrapper = self .classes .get (cls )
127
132
return wrapper or self ._generate_class (cls )
128
133
129
134
def _generate_class (self , cls : type [Any ]) -> type [Any ]:
@@ -139,7 +144,7 @@ def _generate_class(self, cls: type[Any]) -> type[Any]:
139
144
wrapper = dataclasses .make_dataclass (class_name , [("v" , cls )])
140
145
141
146
serde (wrapper )
142
- self .classes [class_name ] = wrapper
147
+ self .classes [cls ] = wrapper
143
148
144
149
logger .debug (f"(de)serializing code for { class_name } was generated" )
145
150
return wrapper
@@ -170,10 +175,8 @@ def _get_union_class(self, cls: type[Any]) -> Optional[type[Any]]:
170
175
the class and store it in the cache.
171
176
"""
172
177
union_cls , tagging = _extract_from_with_tagging (cls )
173
- class_name = union_func_name (
174
- f"{ tagging .produce_unique_class_name ()} Union" , list (type_args (union_cls ))
175
- )
176
- wrapper = self .classes .get (class_name )
178
+ cache_key = UnionCacheKey (union = union_cls , tagging = tagging )
179
+ wrapper = self .classes .get (cache_key )
177
180
return wrapper or self ._generate_union_class (cls )
178
181
179
182
def _generate_union_class (self , cls : type [Any ]) -> type [Any ]:
@@ -184,12 +187,13 @@ def _generate_union_class(self, cls: type[Any]) -> type[Any]:
184
187
import serde
185
188
186
189
union_cls , tagging = _extract_from_with_tagging (cls )
190
+ cache_key = UnionCacheKey (union = union_cls , tagging = tagging )
187
191
class_name = union_func_name (
188
192
f"{ tagging .produce_unique_class_name ()} Union" , list (type_args (union_cls ))
189
193
)
190
194
wrapper = dataclasses .make_dataclass (class_name , [("v" , union_cls )])
191
195
serde .serde (wrapper , tagging = tagging )
192
- self .classes [class_name ] = wrapper
196
+ self .classes [cache_key ] = wrapper
193
197
return wrapper
194
198
195
199
def serialize_union (self , cls : type [Any ], obj : Any ) -> Any :
0 commit comments