Skip to content

Commit 14ce0fd

Browse files
Cache through type rather than string
1 parent 0b80ff0 commit 14ce0fd

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

serde/core.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dataclasses import dataclass
1414

1515
from beartype.door import is_bearable
16-
from collections.abc import Mapping, Sequence, Callable
16+
from collections.abc import Mapping, Sequence, Callable, Hashable
1717
from typing import (
1818
overload,
1919
TypeVar,
@@ -86,6 +86,12 @@
8686
SETTINGS = {"debug": False}
8787

8888

89+
@dataclass(frozen=True)
90+
class UnionCacheKey:
91+
union: Hashable
92+
tagging: Tagging
93+
94+
8995
def init(debug: bool = False) -> None:
9096
SETTINGS["debug"] = debug
9197

@@ -115,15 +121,14 @@ class Union_Foo_bar:
115121
should be only once.
116122
"""
117123

118-
classes: dict[str, type[Any]] = dataclasses.field(default_factory=dict)
124+
classes: dict[Hashable, type[Any]] = dataclasses.field(default_factory=dict)
119125

120126
def _get_class(self, cls: type[Any]) -> type[Any]:
121127
"""
122128
Get a wrapper class from the the cache. If not found, it will generate
123129
the class and store it in the cache.
124130
"""
125-
class_name = f"Wrapper{typename(cls)}"
126-
wrapper = self.classes.get(class_name)
131+
wrapper = self.classes.get(cls)
127132
return wrapper or self._generate_class(cls)
128133

129134
def _generate_class(self, cls: type[Any]) -> type[Any]:
@@ -139,7 +144,7 @@ def _generate_class(self, cls: type[Any]) -> type[Any]:
139144
wrapper = dataclasses.make_dataclass(class_name, [("v", cls)])
140145

141146
serde(wrapper)
142-
self.classes[class_name] = wrapper
147+
self.classes[cls] = wrapper
143148

144149
logger.debug(f"(de)serializing code for {class_name} was generated")
145150
return wrapper
@@ -170,10 +175,8 @@ def _get_union_class(self, cls: type[Any]) -> Optional[type[Any]]:
170175
the class and store it in the cache.
171176
"""
172177
union_cls, tagging = _extract_from_with_tagging(cls)
173-
class_name = union_func_name(
174-
f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls))
175-
)
176-
wrapper = self.classes.get(class_name)
178+
cache_key = UnionCacheKey(union=union_cls, tagging=tagging)
179+
wrapper = self.classes.get(cache_key)
177180
return wrapper or self._generate_union_class(cls)
178181

179182
def _generate_union_class(self, cls: type[Any]) -> type[Any]:
@@ -184,12 +187,13 @@ def _generate_union_class(self, cls: type[Any]) -> type[Any]:
184187
import serde
185188

186189
union_cls, tagging = _extract_from_with_tagging(cls)
190+
cache_key = UnionCacheKey(union=union_cls, tagging=tagging)
187191
class_name = union_func_name(
188192
f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls))
189193
)
190194
wrapper = dataclasses.make_dataclass(class_name, [("v", union_cls)])
191195
serde.serde(wrapper, tagging=tagging)
192-
self.classes[class_name] = wrapper
196+
self.classes[cache_key] = wrapper
193197
return wrapper
194198

195199
def serialize_union(self, cls: type[Any], obj: Any) -> Any:

0 commit comments

Comments
 (0)